├── environment.yml
├── utils
│   ├── __init__.py
│   ├── general.py
│   ├── pid_grid_search.py
│   ├── data_collection.py
│   ├── data_processing.py
│   ├── evaluation.py
│   └── parameters.py
├── README.md
├── .gitignore
├── Offline_RL_Comparison.ipynb
├── TD3_BC.py
├── BCQ.py
├── CQL.py
└── SAC_RNN.py

/environment.yml:
--------------------------------------------------------------------------------
1 | name: offline-glucose
2 | channels:
3 |   - defaults
4 |   - conda-forge
5 | dependencies:
6 |   - python=3.8
7 |   - conda>=4.9.2
8 |   - pip>=21.2.4
9 |   - cudatoolkit=10.2
10 |   - jupyterlab
11 |   - pip:
12 |     - -f https://download.pytorch.org/whl/torch_stable.html
13 |     - gym==0.9.4
14 |     - torch==1.7.1+cu101
15 |     - torchvision==0.8.2+cu101
16 |     - matplotlib==3.5.1
17 |     - numpy==1.22.3
18 |     - git+https://github.com/hemerson1/simglucose.git
19 | 
20 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Apr 2 14:02:57 2022
5 | 
6 | """
7 | 
8 | from .general import calculate_bolus, calculate_risk, is_in_range, PID_action
9 | from .parameters import create_env, get_params
10 | from .data_collection import fill_replay, fill_replay_split
11 | from .data_processing import unpackage_replay, get_batch
12 | from .evaluation import test_algorithm, create_graph
13 | from .pid_grid_search import optimal_pid_search
14 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Offline RL for Safer Glucose Control
2 | 
3 | The code release for *"Offline Reinforcement Learning for Safer Blood
4 | Glucose Control in People with Type 1 Diabetes"*.
5 | 
6 | ## Installation
7 | 
8 | All Python dependencies are listed in ```environment.yml```. Install with:
9 | 
10 | ```
11 | conda env create -f environment.yml
12 | conda activate offline-glucose
13 | pip install -e .
14 | ```
15 | 
16 | ## Usage
17 | 
18 | An example of the data workflow for TD3-BC is provided in ```Offline_RL_Comparison.ipynb```.
19 | 
20 | ## License
21 | [MIT](https://choosealicense.com/licenses/mit/)
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Apr 2 13:29:03 2022
5 | """
6 | 
7 | """
8 | Functions for general use across data collection, training and evaluation.
9 | """
10 | 
11 | import math
12 | import numpy as np
13 | 
14 | 
15 | """
16 | When provided with a blood glucose output from the UVA/Padova simulator it
17 | calculates the corresponding Magni risk and returns it as a float.
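For reference, the single expression below is the Magni risk function written out:
risk(bg) = 10 * (3.5506 * (ln(max(bg, 1))^0.8353 - 3.7932))^2, with bg in mg/dL
and the clamp at 1 keeping the logarithm defined.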
18 | """
19 | def calculate_risk(blood_glucose):
20 |     return 10 * math.pow((3.5506 * (math.pow(math.log(max(1, blood_glucose[0])), 0.8353) - 3.7932)), 2)
21 | 
22 | """
23 | Uses the current blood glucose value, the meal history and the current meal's
24 | carbohydrates to calculate the optimal bolus dose for a patient's meal.
25 | """
26 | def calculate_bolus(blood_glucose, meal_history, current_meal,
27 |                     carbohyrdate_ratio, correction_factor, target_blood_glucose):
28 | 
29 |     # calculate the meal bolus using meal carbs
30 |     bolus = current_meal / carbohyrdate_ratio
31 | 
32 |     # if a meal hasn't occurred in meal history
33 |     if np.sum(meal_history) == 0:
34 | 
35 |         # correct the bolus for high or low blood glucose
36 |         bolus += (blood_glucose[0] - target_blood_glucose) / correction_factor
37 | 
38 |     return bolus / 3
39 | 
40 | """
41 | When given the current blood glucose value, determine whether it falls in range
42 | and return a value indicating its position.
43 | """
44 | def is_in_range(blood_glucose, hypo_threshold, hyper_threshold, sig_hypo_threshold, sig_hyper_threshold):
45 | 
46 |     # output: 0 = in range, 1 = hyper, 2 = significant hyper, -1 = hypo, -2 = significant hypo
47 |     if blood_glucose > sig_hyper_threshold: return 2
48 |     elif blood_glucose > hyper_threshold: return 1
49 |     elif blood_glucose < sig_hypo_threshold: return -2
50 |     elif blood_glucose < hypo_threshold: return -1
51 |     else: return 0
52 | 
53 | """
54 | Calculate the recommended basal dose for a patient based on their current
55 | blood glucose and their parameters.
56 | """
57 | def PID_action(blood_glucose, previous_error, integrated_state,
58 |                target_blood_glucose, kp, ki, kd, basal_default):
59 | 
60 |     # proportional control
61 |     error = target_blood_glucose - blood_glucose[0]
62 |     p_act = kp * error
63 | 
64 |     # integral control
65 |     integrated_state += error
66 |     i_act = ki * integrated_state
67 | 
68 |     # derivative control
69 |     d_act = kd * (error - previous_error)
70 |     previous_error = error
71 | 
72 |     # get the final dose output
73 |     calculated_dose = np.array([(p_act + i_act + d_act + basal_default) / 3], dtype=np.float32)
74 | 
75 |     return calculated_dose, previous_error, integrated_state
76 | 
77 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/
--------------------------------------------------------------------------------
/utils/pid_grid_search.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Apr 2 13:44:34 2022
5 | 
6 | """
7 | 
8 | import numpy as np
9 | 
10 | from .general import calculate_risk
11 | 
12 | """
13 | Perform a grid search using the provided values to determine the optimal
14 | PID parameters for a given patient.
15 | """
16 | def optimal_pid_search(env, bas, cr, cf, k_ps, k_is, k_ds, num_days=10, target_bg=144):
17 | 
18 |     # initialise the parameters
19 |     listed_rewards = list()
20 |     max_timesteps = 480 * num_days
21 |     counter, current_index = 0, 0
22 |     current_max = -1000000000
23 | 
24 | 
25 |     for k_p in k_ps:
26 |         for k_i in k_is:
27 |             for k_d in k_ds:
28 | 
29 |                 max_val = 0
30 | 
31 |                 # reset the seed
32 |                 env.seed(0)
33 | 
34 |                 done, bg_val = False, env.reset()
35 |                 rewards, timesteps, meal = 0, 0, 0
36 | 
37 |                 # create the state
38 |                 meal_history = np.zeros(int((3 * 60) / 3))
39 |                 integrated_state = 0
40 |                 previous_error = 0
41 | 
42 |                 while not done and timesteps < max_timesteps:
43 | 
44 |                     # proportional control
45 |                     error = target_bg - bg_val[0]
46 |                     p_act = k_p * error
47 | 
48 |                     # integral control
49 |                     integrated_state += error
50 |                     i_act = k_i * integrated_state
51 | 
52 |                     # derivative control
53 |                     d_act = k_d * (error - previous_error)
54 | 
55 |                     # get the combined pid action
56 |                     previous_error = error
57 |                     action = (p_act + i_act + d_act + bas) / 3
58 |                     chosen_action = max(action, 0)
59 | 
60 |                     # keep track of the max insulin dose
61 |                     if action > max_val: max_val = action
62 | 
63 |                     # get the bolus dose
64 |                     bolus = 0
65 |                     if meal > 0:
66 |                         bolus = meal / cr
67 |                         if np.sum(meal_history) == 0:
68 |                             bolus += (bg_val[0] - target_bg) / cf
69 |                         chosen_action += max(bolus/3, 0)
70 | 
71 |                     # step the environment
72 |                     next_bg_val, _, done, info = env.step(chosen_action)
73 |                     reward = - calculate_risk(next_bg_val[0])
74 |                     if done: reward -= 1e5
75 | 
76 |                     # update the state
77 |                     meal_history = np.append(meal_history, meal)
78 |                     meal_history = np.delete(meal_history, 0)
79 | 
80 |                     # update the state and memory
81 |                     bg_val = next_bg_val
82 |                     rewards += reward
83 |                     timesteps += 1
84 |                     meal = info['meal']
85 | 
86 |                 counter += 1
87 | 
88 |                 # keep track of the best cumulative reward over completed runs
89 |                 if timesteps == max_timesteps:
90 | 
91 |                     if rewards > current_max:
92 |                         current_max = rewards
93 |                         current_index = counter
94 | 
95 |                 data = {
96 |                     "params" : "kp: {}, ki: {}, kd: {}".format(k_p, k_i, k_d),
97 |                     "reward" : rewards, "max_val": max_val/bas
98 |                 }
99 |                 listed_rewards.append(data)
100 | 
101 |                 # display the run results
102 |                 print('#{} kp:{} ki:{} kd:{} -- Reward: {} Timesteps: {}'.format(counter, str(k_p), str(k_i), str(k_d), rewards, timesteps))
103 |                 print('Max Action {}'.format(max_val/bas))
104 |                 print('--------------------------------')
105 | 
106 |     # display the best completed runs
107 |     print('Best run {}'.format(current_index))
108 |     sorted_rewards = sorted(listed_rewards, key=lambda d: d['reward'], reverse=True)
109 |     for idx, val in enumerate(sorted_rewards):
110 |         print("Rank: {} | Reward {} | {} | {}".format(idx + 1, val["reward"], val["params"], val["max_val"]))
--------------------------------------------------------------------------------
/Offline_RL_Comparison.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | 
"execution_count": 1, 6 | "id": "6045578b-41ab-4a1d-9488-0138d89bfb5c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# INITIALISE THE ENVIRONMENT ----------------------\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "\n", 14 | "from utils import create_env\n", 15 | "\n", 16 | "# Set the parameters for the meal scenario\n", 17 | "prob = [0.95, 0.1, 0.95, 0.1, 0.95, 0.1]\n", 18 | "time_lb = np.array([5, 9, 10, 14, 16, 20])\n", 19 | "time_ub = np.array([9, 10, 14, 16, 20, 23])\n", 20 | "time_mu = np.array([7, 9.5, 12, 15, 18, 21.5])\n", 21 | "time_sigma = np.array([30, 15, 30, 15, 30, 15])\n", 22 | "amount_mu = [50, 15, 70, 15, 90, 30]\n", 23 | "amount_sigma = [10, 5, 10, 5, 10, 5] \n", 24 | "schedule=[prob, time_lb, time_ub, time_mu, time_sigma, amount_mu, amount_sigma]\n", 25 | "\n", 26 | "# Incorporate the schedule into the environment\n", 27 | "create_env(schedule=schedule)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "04b98287-e72a-481e-9350-1602f8ccb10d", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# SPECIFY THE PARAMETERS -----------------------\n", 38 | "\n", 39 | "from utils import get_params\n", 40 | "\n", 41 | "# Get the parameters for a specified patient\n", 42 | "patient_params = get_params()[\"adult#1\"]\n", 43 | "bas = patient_params[\"u2ss\"] * (patient_params[\"BW\"] / 6000) * 3\n", 44 | "\n", 45 | "# Set the parameters\n", 46 | "params = {\n", 47 | " \n", 48 | " # Environmental\n", 49 | " \"state_size\": 3,\n", 50 | " \"basal_default\": bas, \n", 51 | " \"target_blood_glucose\": 144.0 ,\n", 52 | " \"days\": 10, \n", 53 | " \n", 54 | " # PID and Bolus\n", 55 | " \"carbohydrate_ratio\": patient_params[\"carbohydrate_ratio\"],\n", 56 | " \"correction_factor\": patient_params[\"correction_factor\"],\n", 57 | " \"kp\": patient_params[\"kp\"],\n", 58 | " \"ki\": patient_params[\"ki\"],\n", 59 | " \"kd\": patient_params[\"kd\"],\n", 60 | " \n", 61 | " # RL \n", 62 | " \"training_timesteps\": int(1e5),\n", 63 | " \"device\": \"cpu\",\n", 64 | " \"rnn\": None\n", 65 | "}" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "500164b6-ab56-47cf-a09b-50381b55f8c5", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# COLLECT THE DATA --------------------------------\n", 76 | "\n", 77 | "from utils import fill_replay_split\n", 78 | "\n", 79 | "import gym\n", 80 | "\n", 81 | "# initialise the environment\n", 82 | "env = gym.make(patient_params[\"env_name\"])\n", 83 | "\n", 84 | "# Fill the replay\n", 85 | "full_replay = fill_replay_split(\n", 86 | " env=env, \n", 87 | " replay_name=patient_params[\"replay_name\"],\n", 88 | " data_split=0.0,\n", 89 | " noise=True,\n", 90 | " bolus_noise=0.1,\n", 91 | " seed=0,\n", 92 | " params=params\n", 93 | ")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "abd3bd31-d092-4e34-b604-a4e9b14a14f9", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# TRAIN THE MODEL ---------------------------\n", 104 | "\n", 105 | "from utils import get_params\n", 106 | "from TD3_BC import td3_bc\n", 107 | "\n", 108 | "# Initialise the agent\n", 109 | "agent = td3_bc(\n", 110 | " init_seed=0,\n", 111 | " patient_params=patient_params,\n", 112 | " params=params\n", 113 | ")\n", 114 | "\n", 115 | "# Train the agent\n", 116 | "agent.train_model()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "ecb7f832-bc43-4a5f-861d-db43d18890cb", 123 | 
"metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "# TEST THE MODEL ---------------------------\n", 127 | "\n", 128 | "from utils import get_params\n", 129 | "from TD3_BC import td3_bc\n", 130 | "\n", 131 | "# Initialise the agent\n", 132 | "agent = td3_bc(\n", 133 | " init_seed=0,\n", 134 | " patient_params=patient_params,\n", 135 | " params=params\n", 136 | ")\n", 137 | "\n", 138 | "# Train the agent\n", 139 | "agent.test_model()" 140 | ] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "Python 3 (ipykernel)", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.8.13" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /utils/data_collection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:28:57 2022 5 | """ 6 | 7 | """ 8 | Functions for generating datasets from the glucose dynamics simulator. 9 | """ 10 | 11 | import numpy as np 12 | import pickle 13 | from collections import defaultdict 14 | 15 | from .general import PID_action, calculate_bolus, calculate_risk 16 | 17 | 18 | """ 19 | Create a replay with a mixture of expert data and random data. 20 | """ 21 | def fill_replay_split( 22 | env, replay_name, data_split=0.5, replay_length=100_000, meal_announce=0, 23 | noise=False, bolus_noise=None, bolus_overestimate=0.0, seed=0, missing_data_prob=0.0, 24 | compression_prob=0.0, params=None): 25 | 26 | # determine the split of the two datasets 27 | random_timesteps = int(data_split * replay_length) 28 | expert_timesteps = int(replay_length - random_timesteps) 29 | 30 | # Random data generation ------------------------------------------------- 31 | 32 | new_replay = None 33 | if random_timesteps > 0: 34 | 35 | new_replay = fill_replay(replay_length=random_timesteps,env=env, 36 | replay_name=replay_name, 37 | player="random", bolus_noise=bolus_noise, 38 | noise=False, meal_announce=meal_announce, 39 | bolus_overestimate=bolus_overestimate, 40 | missing_data_prob=missing_data_prob, 41 | compression_prob=compression_prob, 42 | seed=seed, params=params 43 | ) 44 | print('Buffer Full with Random policy of size {}'.format(random_timesteps)) 45 | 46 | # Expert data generation ------------------------------------------------- 47 | 48 | if expert_timesteps > 0: 49 | 50 | full_replay = fill_replay(replay_length=expert_timesteps, 51 | replay_name=replay_name, 52 | replay=new_replay, 53 | env=env, player="expert", 54 | bolus_noise=bolus_noise, seed=seed, 55 | bolus_overestimate=bolus_overestimate, 56 | meal_announce=meal_announce, 57 | missing_data_prob=missing_data_prob, 58 | compression_prob=compression_prob, 59 | noise=noise, 60 | params=params 61 | ) 62 | print('Buffer Full with Expert policy of size {}'.format(expert_timesteps)) 63 | else: 64 | return new_replay 65 | 66 | # return the finished replay 67 | return full_replay 68 | 69 | 70 | """ 71 | Create a named replay of specified size with either a random or expert 72 | demonstrator. 
The replay produced is a list containing individual trajectories 73 | stopping when the agent terminates or the max number of days is reached. 74 | """ 75 | 76 | def fill_replay( 77 | env, replay_name, replay=None, replay_length=100_000, player='random', 78 | meal_announce=0.0, bolus_noise=None, seed=0, params=None, noise=False, 79 | bolus_overestimate=0.0, missing_data_prob=0.0, compression_prob=0.0): 80 | 81 | # Unpack the additional parameters 82 | 83 | # Environment 84 | days = params.get("days", 10) 85 | 86 | # Diabetes 87 | basal_default = params.get("basal_default") 88 | target_blood_glucose = params.get("target_blood_glucose") 89 | 90 | # PID 91 | kp, ki, kd = params.get("kp"), params.get("ki"), params.get("kd") 92 | 93 | # Bolus 94 | cr, cf = params.get("carbohydrate_ratio"), params.get("correction_factor") 95 | 96 | # OU Noise 97 | sigma = params.get("ou_sigma", 0.2) 98 | theta = params.get("ou_theta", 0.0015) 99 | dt = params.get("ou_dt", 0.9) 100 | 101 | # seed numpy and the environment 102 | seed = seed 103 | np.random.seed(seed) 104 | env.seed(seed) 105 | 106 | # create the replay 107 | if replay is None: replay = [] 108 | buffer_not_full = True 109 | replay_progress_freq = replay_length // 10 110 | 111 | # Specify the counter for total timesteps 112 | counter = 0 113 | episode_max = 480 * days 114 | 115 | while buffer_not_full: 116 | 117 | # get the starting state 118 | insulin_dose = np.array([1/3 * basal_default], dtype=np.float32) 119 | meal, done, bg_val = 0, False, env.reset() 120 | time = ((env.env.time.hour * 60) / 3 + env.env.time.minute / 3) / 479 121 | state = np.array([bg_val[0], meal, insulin_dose[0], time], dtype=np.float32) 122 | 123 | # get the meal history for the last 3 hrs 124 | meal_history = np.zeros(60) 125 | 126 | # intiialise the PID and OU noise parameters 127 | integrated_state, previous_error = 0, 0 128 | prev_ou_noise = 0 129 | 130 | # record the trajectory and the current timestep 131 | trajectory = defaultdict(list) 132 | episode_timestep = 0 133 | 134 | # count missing data period 135 | missing_period = 0 136 | 137 | # add compression error 138 | compression_period = 0 139 | compression_size = 0 140 | 141 | while not done: 142 | 143 | # select the basal dose ------------------------------------------ 144 | 145 | # calculate the OU noise from the initial parameters 146 | if noise: 147 | ou_noise = (prev_ou_noise + theta * (0 - prev_ou_noise) * dt + sigma * np.sqrt(dt) * np.random.normal(size=(1,))[0]) 148 | ou_noise = ou_noise * basal_default 149 | prev_ou_noise = ou_noise 150 | 151 | if player == "random": 152 | 153 | # add the noise to a baseline 154 | action = np.array([1/3 * basal_default]) 155 | agent_action = np.copy(action + ou_noise) 156 | 157 | elif player == "expert": 158 | 159 | # add the noise to a PID agent dose 160 | action, previous_error, integrated_state = PID_action( 161 | blood_glucose=bg_val, 162 | previous_error=previous_error, 163 | integrated_state=integrated_state, 164 | target_blood_glucose=target_blood_glucose, 165 | kp=kp, ki=ki, kd=kd, basal_default=basal_default 166 | ) 167 | agent_action = np.copy(action) 168 | 169 | # add on the noise 170 | if noise: agent_action += ou_noise 171 | chosen_action = agent_action 172 | 173 | # select the bolus dose ------------------------------------------ 174 | 175 | if meal > 0: 176 | 177 | # save the adjusted meal 178 | adjusted_meal = meal 179 | 180 | # add some calculation error in bolus 181 | if bolus_noise: 182 | adjusted_meal += bolus_noise * adjusted_meal * 
np.random.uniform(-1, 1, 1)[0] 183 | 184 | # add a bias to the bolus estimation 185 | adjusted_meal += bolus_overestimate * meal 186 | adjusted_meal = max(0, adjusted_meal) 187 | 188 | # calculate the bolus dose 189 | chosen_action = calculate_bolus( 190 | blood_glucose=bg_val, meal_history=meal_history, 191 | current_meal=adjusted_meal, carbohyrdate_ratio=cr, 192 | correction_factor=cf, 193 | target_blood_glucose=target_blood_glucose 194 | ) 195 | 196 | # amend the agent action to the dose 197 | chosen_action += agent_action 198 | 199 | # take a step in the environment ---------------------------------- 200 | 201 | # update the state and get the true reward 202 | next_bg_val, _, done, info = env.step(chosen_action) 203 | reward = -calculate_risk(next_bg_val) 204 | 205 | # announce a meal ------------------------------------------- 206 | 207 | # meal announcement 208 | meal_input = meal 209 | if meal_announce != 0.0: 210 | 211 | # get times + meal schedule 212 | current_time = env.env.time.hour * 60 + env.env.time.minute 213 | future_time = current_time + meal_announce - 1 214 | meal_scenario = env.env.scenario.scenario["meal"] 215 | future_meal = 0 216 | 217 | # check for future meal 218 | if future_time in meal_scenario["time"]: 219 | index = meal_scenario["time"].index(future_time) 220 | future_meal = meal_scenario["amount"][index] 221 | 222 | meal_input = future_meal / 3 223 | 224 | # configure the next state ---------------------------------------- 225 | 226 | # add missing data to the dataset 227 | rand = np.random.rand() 228 | if (rand < missing_data_prob) or (missing_period > 0): 229 | if missing_period < 1: 230 | prev_bg = 144 # bg_val[0] 231 | missing_period = np.random.randint(10) 232 | next_bg_val = [next_bg_val[0]] 233 | next_bg_val[0] = 144 # prev_bg 234 | 235 | # add compression error 236 | rand = np.random.rand() 237 | if (rand < compression_prob) or (compression_period > 0): 238 | if compression_period < 1: 239 | compression_period = np.random.randint(10) 240 | compression_size = np.random.randint(30) 241 | next_bg_val = [next_bg_val[0]] 242 | next_bg_val[0] -= compression_size 243 | 244 | # step forward in time 245 | missing_period -= 1 246 | compression_period -= 1 247 | 248 | time = ((env.env.time.hour * 60)/3 + env.env.time.minute/3)/479 249 | next_state = np.array([next_bg_val[0], meal_input, chosen_action[0], time], dtype=np.float32) 250 | 251 | # add a termination penalty 252 | if done: 253 | reward = -1e5 254 | 255 | # update the replay --------------------------------------------- 256 | 257 | # update the replay with trajectory 258 | sample = [('reward', reward), ('state', state), ('next_state', next_state), 259 | ('action', agent_action), ('done', done)] 260 | 261 | for key, value in sample: 262 | trajectory[key].append(value) 263 | 264 | # add the new trajectory to the replay 265 | if done or episode_timestep == episode_max: 266 | replay.append(trajectory) 267 | break 268 | 269 | # update the variables ----------------------------------------- 270 | 271 | # update the meal history 272 | meal_history = np.append(meal_history, meal) 273 | meal_history = np.delete(meal_history, 0) 274 | 275 | # update the state 276 | bg_val, state, meal = next_bg_val, next_state, info['meal'] 277 | counter += 1 278 | episode_timestep += 1 279 | 280 | # save the replay ---------------------------------------------- 281 | 282 | # visualise replay size 283 | if counter % replay_progress_freq == 0: 284 | print('Replay size: {}'.format(counter)) 285 | with open("./Replays/" + 
replay_name + ".txt", "wb") as file: 286 | pickle.dump(replay, file) 287 | 288 | # full termination condition ----------------------------------- 289 | 290 | # add termination when full 291 | if counter == replay_length: 292 | buffer_not_full = False 293 | replay.append(trajectory) 294 | return replay 295 | -------------------------------------------------------------------------------- /utils/data_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:29:00 2022 5 | """ 6 | 7 | """ 8 | Functions for converting the replay output into the correct form for the 9 | chosen agent. 10 | """ 11 | 12 | import numpy as np 13 | import random, torch 14 | 15 | """ 16 | Converts a list of trajectories gathered from the data_colleciton algorithm into 17 | a replay with samples appropriate for model training along with state and action 18 | means and stds. 19 | """ 20 | def unpackage_replay(trajectories, empty_replay, data_processing="condensed", sequence_length=80, params=None): 21 | 22 | # TODO: add functionality to change gamma if necessary 23 | gamma = 1.0 24 | 25 | # initialise the data lists 26 | states, rewards, actions, dones = [], [], [], [] 27 | 28 | for path in trajectories: 29 | 30 | # states include blood glucose, meal carbs, insulin dose, time 31 | states += path['state'] 32 | rewards += path['reward'] 33 | actions += path['action'] 34 | dones += path['done'] 35 | 36 | # ensure that the last state is always a terminal state 37 | dones[-1] = True 38 | 39 | # initialise the lists 40 | processed_states, processed_next_states, processed_rewards, processed_actions = [], [], [], [] 41 | processed_dones, processed_timesteps, processed_reward_to_go, processed_last_actions = [], [], [], [] 42 | decay_state = np.arange(1 / (sequence_length + 2), 1, 1 / (sequence_length + 2)) 43 | counter = 0 44 | 45 | # Condense the state ------------------------------------------------- 46 | 47 | # 4hr | 3.5hr | 3hr | 2.5hr | 2hr | 1.5hr | 1hr | 0.5hr | 0hr | meal_on_board | insulin_on_board 48 | if data_processing == "condensed": 49 | 50 | for idx, state in enumerate(states): 51 | 52 | # find the next done index 53 | if idx == 0 or dones[max(idx - 1, 0)]: 54 | done_index = idx + dones[idx:].index(True) 55 | 56 | # if there are 80 states previously 57 | if counter >= (sequence_length) and idx + 1 != len(states): 58 | 59 | # add rewards, actions, dones and timestep label 60 | processed_rewards.append(rewards[idx]) 61 | processed_actions.append(actions[idx]) 62 | processed_last_actions.append(actions[idx - 1]) 63 | processed_dones.append(dones[idx]) 64 | processed_timesteps.append(counter) 65 | processed_reward_to_go.append(sum(rewards[idx: done_index])) 66 | 67 | # current state ----------------------------------------- 68 | 69 | # unpackage the values 70 | related_states = states[idx - sequence_length: idx + 1] 71 | related_bgs, related_meals, related_insulins, _ = zip(*related_states) 72 | 73 | # extract the correct metrics 74 | extracted_bg = related_bgs[::10] 75 | meals_on_board = [np.sum(np.array(related_meals) * decay_state)] 76 | insulin_on_board = [np.sum(np.array(related_insulins) * decay_state)] 77 | 78 | # append the state 79 | processed_state = list(extracted_bg) + meals_on_board + insulin_on_board 80 | processed_states.append(processed_state) 81 | 82 | # next state ----------------------------------------- 83 | 84 | # unpackage the values 85 | related_next_states = 
states[(idx - sequence_length) + 1: idx + 1 + 1] 86 | related_next_bgs, related_next_meals, related_next_insulins, _ = zip(*related_next_states) 87 | 88 | # extract the correct metrics 89 | extracted_next_bg = related_next_bgs[::10] 90 | next_meals_on_board = [np.sum(np.array(related_next_meals) * decay_state)] 91 | next_insulin_on_board = [np.sum(np.array(related_next_insulins) * decay_state)] 92 | 93 | # append the state 94 | processed_next_state = list(extracted_next_bg) + next_meals_on_board + next_insulin_on_board 95 | processed_next_states.append(processed_next_state) 96 | 97 | # update the counter 98 | counter += 1 99 | if dones[idx]: 100 | counter = 0 101 | 102 | 103 | # Create a sequence ------------------------------------------------- 104 | 105 | elif data_processing == "sequence": 106 | 107 | for idx, state in enumerate(states): 108 | 109 | # find the next done index 110 | if idx == 0 or dones[max(idx - 1, 0)]: 111 | done_index = idx + dones[idx:].index(True) 112 | 113 | # if there are 80 states previously 114 | if counter >= (sequence_length) and idx + 1 != len(states): 115 | 116 | # add rewards, actions and dones 117 | processed_rewards.append(rewards[idx - sequence_length:idx]) 118 | processed_actions.append(actions[idx - sequence_length:idx]) 119 | processed_last_actions.append(actions[(idx - sequence_length) - 1:(idx - 1)]) 120 | processed_dones.append(dones[idx - sequence_length:idx]) 121 | processed_timesteps.append(list(range(counter - sequence_length, counter))) 122 | 123 | # get the reward_to_go 124 | rewards_to_go = [sum(rewards[(idx + 1) : done_index])] 125 | for i in range(sequence_length - 1): 126 | rewards_to_go.append(rewards_to_go[-1] + rewards[idx - i]) 127 | processed_reward_to_go.append(rewards_to_go[::-1]) 128 | 129 | # add the state and next_state 130 | extracted_states = [state[:3] for state in states[idx - sequence_length:idx]] 131 | processed_states.append(extracted_states) 132 | processed_next_states.append(extracted_states[1:] + [[0, 0, 0]]) 133 | 134 | # update the counter 135 | counter += 1 136 | if dones[idx]: 137 | counter = 0 138 | 139 | # Normalisation ------------------------------------------------------ 140 | array_states = np.array(processed_states) 141 | array_actions = np.array(processed_actions) 142 | array_rewards = np.array(processed_rewards) 143 | 144 | if data_processing == "condensed": 145 | 146 | # ensure the state mean and std are consistent across blood glucose 147 | state_mean, state_std = np.mean(array_states, axis=0), np.std(array_states, axis=0) 148 | action_mean, action_std = np.mean(array_actions, axis=0), np.std(array_actions, axis=0) 149 | reward_mean, reward_std = np.mean(array_rewards, axis=0), np.std(array_rewards, axis=0) 150 | state_mean[:-2], state_std[:-2] = state_mean[0], state_std[0] 151 | 152 | elif data_processing == "sequence": 153 | 154 | # reshape array and calculate mean and std 155 | state_size, action_size = array_states.shape[2], array_actions.shape[2] 156 | state_mean = np.mean(array_states.reshape(-1, state_size), axis=0) 157 | state_std = np.std(array_states.reshape(-1, state_size), axis=0) 158 | action_mean = np.mean(array_actions.reshape(-1, action_size), axis=0) 159 | action_std = np.std(array_actions.reshape(-1, action_size), axis=0) 160 | reward_mean = np.mean(array_rewards.reshape(-1, 1), axis=0) 161 | reward_std = np.std(array_rewards.reshape(-1, 1), axis=0) 162 | 163 | # load in new replay ---------------------------------------------------- 164 | 165 | # TODO: do hidden_in and hidden_out need 
to be explicitly added 166 | 167 | for idx, state in enumerate(processed_states): 168 | empty_replay.append((state, processed_actions[idx], processed_rewards[idx], processed_next_states[idx], 169 | processed_dones[idx], processed_timesteps[idx], processed_reward_to_go[idx], 170 | processed_actions[idx], None, None 171 | )) 172 | 173 | full_replay = empty_replay 174 | 175 | return full_replay, state_mean, state_std, action_mean, action_std, reward_mean, reward_std 176 | 177 | """ 178 | Extracts a batch of data from the full replay and puts it in an appropriate form 179 | """ 180 | def get_batch(replay, batch_size, data_processing="condensed", sequence_length=80, device='cpu', online=True, params=None): 181 | 182 | # Environment 183 | state_size = params.get("state_size") 184 | state_mean = params.get("state_mean") 185 | state_std = params.get("state_std") 186 | action_mean = params.get("action_mean") 187 | action_std = params.get("action_std") 188 | reward_mean = params.get("reward_mean") 189 | reward_std = params.get("reward_std") 190 | reward_scale = params.get("reward_scale", 1.0) 191 | 192 | # sample a minibatch 193 | minibatch = random.sample(replay, batch_size) 194 | 195 | if data_processing == "condensed": 196 | state = np.zeros((batch_size, state_size), dtype=np.float32) 197 | action = np.zeros(batch_size, dtype=np.float32) 198 | reward = np.zeros(batch_size, dtype=np.float32) 199 | next_state = np.zeros((batch_size, state_size), dtype=np.float32) 200 | done = np.zeros(batch_size, dtype=np.uint8) 201 | timestep = np.zeros(batch_size, dtype=np.float32) 202 | reward_to_go = np.zeros(batch_size, dtype=np.float32) 203 | 204 | last_action = np.zeros(batch_size, dtype=np.float32) 205 | hidden_in = [0] * batch_size 206 | hidden_out = [0] * batch_size 207 | 208 | elif data_processing == "sequence": 209 | 210 | state = np.zeros((batch_size, sequence_length, state_size), dtype=np.float32) 211 | action = np.zeros((batch_size, sequence_length), dtype=np.float32) 212 | reward = np.zeros((batch_size, sequence_length), dtype=np.float32) 213 | next_state = np.zeros((batch_size, sequence_length, state_size), dtype=np.float32) 214 | done = np.zeros((batch_size, sequence_length), dtype=np.uint8) 215 | timestep = np.zeros((batch_size, sequence_length), dtype=np.float32) 216 | reward_to_go = np.zeros((batch_size, sequence_length), dtype=np.float32) 217 | last_action = np.zeros((batch_size, sequence_length), dtype=np.float32) 218 | hidden_in = [0] * batch_size 219 | hidden_out = [0] * batch_size 220 | 221 | # unpack the batch 222 | for i in range(len(minibatch)): 223 | state[i], action[i], reward[i], next_state[i], done[i], timestep[i], reward_to_go[i], last_action[i], hidden_in[i], hidden_out[i] = minibatch[i] 224 | # convert to torch 225 | state = torch.FloatTensor((state - state_mean) / state_std).to(device) 226 | action = torch.FloatTensor((action - action_mean) / action_std).to(device) 227 | last_action = torch.FloatTensor((last_action - action_mean) / action_std).to(device) 228 | next_state = torch.FloatTensor((next_state - state_mean) / state_std).to(device) 229 | done = torch.FloatTensor(1 - done).to(device) 230 | reward_to_go = torch.FloatTensor(reward_to_go / reward_scale).to(device) 231 | timestep = torch.tensor(timestep, dtype=torch.int32).to(device) 232 | 233 | # get norm of reward 234 | if reward_mean: reward = torch.FloatTensor(reward_scale * (reward - reward_mean) / reward_std).to(device) 235 | else: reward = torch.FloatTensor(reward).to(device) 236 | 237 | if hidden_in[0] is not None and 
online: 238 | 239 | # process lstm layers 240 | if len(hidden_in[0]) > 1: 241 | layer_in, cell_in = list(zip(*hidden_in)) 242 | layer_out, cell_out = list(zip(*hidden_out)) 243 | layer_in, cell_in = torch.cat(layer_in, 1).to(device).detach(), torch.cat(cell_in, 1).to(device).detach() 244 | layer_out, cell_out = torch.cat(layer_out, 1).to(device).detach(), torch.cat(cell_out, 1).to(device).detach() 245 | hidden_in, hidden_out = (layer_in, cell_in), (layer_out, cell_out) 246 | 247 | # process gru layers 248 | else: 249 | layer_in = torch.cat(hidden_in, 1).to(device).detach() 250 | layer_out = torch.cat(hidden_out, 1).to(device).detach() 251 | hidden_in, hidden_out = layer_in, layer_out 252 | 253 | # Modify Dimensions 254 | action = action.unsqueeze(-1) 255 | last_action = last_action.unsqueeze(-1) 256 | reward = reward.unsqueeze(-1) 257 | reward_to_go = reward_to_go.unsqueeze(-1) 258 | done = done.unsqueeze(-1) 259 | 260 | return state, action, reward, next_state, done, timestep, reward_to_go, last_action, hidden_in, hidden_out 261 | -------------------------------------------------------------------------------- /TD3_BC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:35:30 2022 5 | 6 | """ 7 | 8 | import numpy as np 9 | import copy, random, torch, gym, pickle 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import deque 13 | 14 | from utils import unpackage_replay, get_batch, test_algorithm, create_graph 15 | 16 | 17 | """ 18 | Simple feedforward neural network for the Actor. 19 | """ 20 | class Actor(nn.Module): 21 | def __init__(self, state_dim, action_dim, max_action): 22 | super(Actor, self).__init__() 23 | 24 | self.l1 = nn.Linear(state_dim, 256) 25 | self.l2 = nn.Linear(256, 256) 26 | self.l3 = nn.Linear(256, action_dim) 27 | 28 | self.max_action = max_action 29 | 30 | def forward(self, state): 31 | a = F.relu(self.l1(state)) 32 | a = F.relu(self.l2(a)) 33 | return self.max_action * torch.tanh(self.l3(a)) 34 | 35 | 36 | """ 37 | Simple feedforward neural network for the Critic. 
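Two Q-networks are held in a single module so that training can use clipped double
Q-learning; as a plain-text sketch of the target computed in train_model below
(where get_batch has already flipped the done flag into not_done):

    target_Q = reward + gamma * not_done * min(Q1_target(s', a'), Q2_target(s', a'))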
38 | """ 39 | class Critic(nn.Module): 40 | def __init__(self, state_dim, action_dim): 41 | super(Critic, self).__init__() 42 | 43 | # Q1 architecture 44 | self.l1 = nn.Linear(state_dim + action_dim, 256) 45 | self.l2 = nn.Linear(256, 256) 46 | self.l3 = nn.Linear(256, 1) 47 | 48 | # Q2 architecture 49 | self.l4 = nn.Linear(state_dim + action_dim, 256) 50 | self.l5 = nn.Linear(256, 256) 51 | self.l6 = nn.Linear(256, 1) 52 | 53 | 54 | def forward(self, state, action): 55 | sa = torch.cat([state, action], 1) 56 | 57 | q1 = F.relu(self.l1(sa)) 58 | q1 = F.relu(self.l2(q1)) 59 | q1 = self.l3(q1) 60 | 61 | q2 = F.relu(self.l4(sa)) 62 | q2 = F.relu(self.l5(q2)) 63 | q2 = self.l6(q2) 64 | return q1, q2 65 | 66 | 67 | def Q1(self, state, action): 68 | sa = torch.cat([state, action], 1) 69 | 70 | q1 = F.relu(self.l1(sa)) 71 | q1 = F.relu(self.l2(q1)) 72 | q1 = self.l3(q1) 73 | return q1 74 | 75 | 76 | class td3_bc: 77 | 78 | def __init__(self, init_seed, patient_params, params): 79 | 80 | # ENVIRONMENT 81 | self.params = params 82 | self.env_name = patient_params["env_name"] 83 | self.folder_name = patient_params["folder_name"] 84 | self.replay_name = patient_params["replay_name"] 85 | self.bas = patient_params["u2ss"] * (patient_params["BW"] / 6000) * 3 86 | self.env = gym.make(self.env_name) 87 | self.action_size, self.state_size = 1, 11 88 | self.params["state_size"] = self.state_size 89 | self.sequence_length = 80 90 | self.data_processing = "condensed" 91 | 92 | # HYPERPARAMETERS 93 | self.device = params["device"] 94 | self.batch_size = 256 95 | self.actor_lr = 3e-4 96 | self.critic_lr = 3e-4 97 | self.gamma = 0.99 98 | self.tau = 0.005 99 | self.policy_noise = 0.2 100 | self.noise_clip = 0.5 101 | self.policy_freq = 2 102 | self.alpha = 2.5 103 | 104 | # DISPLAY 105 | self.pid_bg, self.pid_insulin, self.pid_action, self.pid_reward = [], [], [], 0 106 | self.training_timesteps = params["training_timesteps"] 107 | self.training_progress_freq = int(self.training_timesteps // 10) 108 | 109 | # SEEDING 110 | self.train_seed = init_seed 111 | self.env.seed(self.train_seed) 112 | np.random.seed(self.train_seed) 113 | torch.manual_seed(self.train_seed) 114 | random.seed(self.train_seed) 115 | 116 | # MEMORY 117 | self.memory_size = self.training_timesteps 118 | self.memory = deque(maxlen=self.memory_size) 119 | 120 | """ 121 | Initialise the Actor and the Critic. 122 | """ 123 | def init_model(self): 124 | 125 | # actor 126 | self.actor = Actor(self.state_size, self.action_size, self.max_action).to(self.device) 127 | self.actor_target = copy.deepcopy(self.actor) 128 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr) 129 | 130 | # critic 131 | self.critic = Critic(self.state_size, self.action_size).to(self.device) 132 | self.critic_target = copy.deepcopy(self.critic) 133 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr) 134 | 135 | 136 | """ 137 | Save the learned models. 138 | """ 139 | def save_model(self): 140 | 141 | torch.save(self.actor.state_dict(), './Models/'+ str(self.env_name) + str(self.train_seed) +'TD3_offline_BC_weights_actor' + self.replay_name.split("-")[-1]) 142 | torch.save(self.critic.state_dict(), './Models/'+ str(self.env_name) + str(self.train_seed) +'TD3_offline_BC_weights_critic' + self.replay_name.split("-")[-1]) 143 | 144 | """ 145 | Load pre-trained weights for testing. 
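The name argument is treated as a path prefix; '_actor' and '_critic' are appended
when the state dicts are loaded (see test_model for the expected layout under ./Models/).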
146 | """ 147 | def load_model(self, name): 148 | 149 | # load actor 150 | self.actor.load_state_dict(torch.load(name + '_actor')) 151 | self.actor_target = copy.deepcopy(self.actor) 152 | self.actor.eval() 153 | 154 | # load critic 155 | self.critic.load_state_dict(torch.load(name + '_critic')) 156 | self.critic_target = copy.deepcopy(self.critic) 157 | self.critic.eval() 158 | 159 | """ 160 | Determine the action based on the state. 161 | """ 162 | def select_action(self, state, action, timestep, prev_reward): 163 | 164 | # Feed state into model 165 | with torch.no_grad(): 166 | tensor_state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 167 | tensor_action = self.actor(tensor_state) 168 | 169 | return tensor_action.cpu().data.numpy().flatten() 170 | 171 | """ 172 | Train the model on a pre-collected sample of training data. 173 | """ 174 | def train_model(self): 175 | 176 | # load the replay buffer 177 | with open("./Replays/" + self.replay_name + ".txt", "rb") as file: 178 | trajectories = pickle.load(file) 179 | 180 | # Process the replay -------------------------------------------------- 181 | 182 | # unpackage the replay 183 | self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay( 184 | trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length 185 | ) 186 | 187 | # update the parameters 188 | self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas) 189 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 190 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 191 | self.max_action = float(((self.bas * 3.0) - self.action_mean) / self.action_std) 192 | 193 | # initialise the networks 194 | self.init_model() 195 | 196 | print('Processing Complete.') 197 | 198 | for t in range(1, self.training_timesteps + 1): 199 | 200 | # Get the batch ------------------------------------------------ 201 | 202 | # unpackage the samples and split 203 | state, action, reward, next_state, done, _, _, _, _, _ = get_batch( 204 | replay=self.memory, batch_size=self.batch_size, 205 | data_processing=self.data_processing, 206 | sequence_length=self.sequence_length, device=self.device, 207 | params=self.params 208 | ) 209 | 210 | # Training ----------------------------------------------- 211 | 212 | with torch.no_grad(): 213 | 214 | # Select action according to policy and add clipped noise 215 | noise = ( 216 | torch.randn_like(action) * self.policy_noise 217 | ).clamp(-self.noise_clip, self.noise_clip) 218 | 219 | next_action = ( 220 | self.actor_target(next_state) + noise 221 | ).clamp(-self.max_action, self.max_action) 222 | 223 | # Compute the target Q value 224 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 225 | target_Q = torch.min(target_Q1, target_Q2) 226 | target_Q = reward + done * self.gamma * target_Q 227 | 228 | # Update the critic ------------------------------------------- 229 | 230 | # Get current Q estimates 231 | current_Q1, current_Q2 = self.critic(state, action) 232 | 233 | # Compute critic loss 234 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 235 | 236 | # Optimize the critic 237 | self.critic_optimizer.zero_grad() 238 | critic_loss.backward() 239 | self.critic_optimizer.step() 240 | 241 | # Perform the actor update --------------------------------------------------- 242 | 243 | # Delayed policy updates 244 | if 
t % self.policy_freq == 0: 245 | 246 | # Compute the modfied actor loss 247 | pi = self.actor(state) 248 | Q = self.critic.Q1(state, pi) 249 | lmbda = self.alpha / Q.abs().mean().detach() 250 | actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action) 251 | 252 | # Optimize the actor 253 | self.actor_optimizer.zero_grad() 254 | actor_loss.backward() 255 | self.actor_optimizer.step() 256 | 257 | # Update the frozen target models 258 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 259 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 260 | 261 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 262 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 263 | 264 | # Show progress 265 | if t % self.training_progress_freq == 0: 266 | 267 | # show the updated loss 268 | print('Timesteps {} - Actor Loss {} - Critic Loss {}'.format(t, actor_loss, critic_loss)) 269 | self.save_model() 270 | 271 | """ 272 | Test the learned weights against the PID controller. 273 | """ 274 | def test_model(self, input_seed=0, input_max_timesteps=4800): 275 | 276 | # initialise the environment 277 | env = gym.make(self.env_name) 278 | 279 | # load the replay buffer 280 | with open("./Replays/" + self.replay_name + ".txt", "rb") as file: 281 | trajectories = pickle.load(file) 282 | 283 | # Process the replay -------------------------------------------------- 284 | 285 | # unpackage the replay 286 | self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay( 287 | trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length 288 | ) 289 | 290 | # update the parameters 291 | self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas) 292 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 293 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 294 | self.max_action = float(((self.bas * 3) - self.action_mean) / self.action_std) 295 | self.init_model() 296 | 297 | # load the learned model 298 | self.load_model('./Models/' + self.folder_name + "/" + "Seed" + str(self.train_seed) + "/" + 'TD3_offline_BC_weights') 299 | test_seed, max_timesteps = input_seed, input_max_timesteps 300 | 301 | # TESTING ------------------------------------------------------------------------------------------- 302 | 303 | # test the algorithm's performance vs pid algorithm 304 | rl_reward, rl_bg, rl_action, rl_insulin, rl_meals, pid_reward, pid_bg, pid_action = test_algorithm( 305 | env=env, agent_action=self.select_action, seed=test_seed, max_timesteps=max_timesteps, 306 | sequence_length=self.sequence_length, data_processing=self.data_processing, 307 | pid_run=False, params=self.params 308 | ) 309 | 310 | # display the results 311 | create_graph( 312 | rl_reward=rl_reward, rl_blood_glucose=rl_bg, rl_action=rl_action, rl_insulin=rl_insulin, 313 | rl_meals=rl_meals, pid_reward=pid_reward, pid_blood_glucose=pid_bg, 314 | pid_action=pid_action, params=self.params 315 | ) -------------------------------------------------------------------------------- /BCQ.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:36:19 2022 5 | 6 | """ 7 | 8 | import numpy as np 9 | import copy, random, 
torch, gym, pickle 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import deque 13 | 14 | from utils import unpackage_replay, get_batch, test_algorithm, create_graph 15 | 16 | """ 17 | Simple feedforward neural network for the Actor. 18 | """ 19 | class Actor(nn.Module): 20 | def __init__(self, state_dim, action_dim, max_action, phi=0.05): 21 | super(Actor, self).__init__() 22 | self.l1 = nn.Linear(state_dim + action_dim, 400) 23 | self.l2 = nn.Linear(400, 300) 24 | self.l3 = nn.Linear(300, action_dim) 25 | 26 | self.max_action = max_action 27 | self.phi = phi 28 | 29 | def forward(self, state, action): 30 | a = F.relu(self.l1(torch.cat([state, action], 1))) 31 | a = F.relu(self.l2(a)) 32 | a = self.phi * self.max_action * torch.tanh(self.l3(a)) 33 | return (a + action).clamp(-self.max_action, self.max_action) 34 | 35 | """ 36 | Simple feedforward neural network for the Critic. 37 | """ 38 | class Critic(nn.Module): 39 | def __init__(self, state_dim, action_dim): 40 | super(Critic, self).__init__() 41 | self.l1 = nn.Linear(state_dim + action_dim, 400) 42 | self.l2 = nn.Linear(400, 300) 43 | self.l3 = nn.Linear(300, 1) 44 | 45 | self.l4 = nn.Linear(state_dim + action_dim, 400) 46 | self.l5 = nn.Linear(400, 300) 47 | self.l6 = nn.Linear(300, 1) 48 | 49 | def forward(self, state, action): 50 | q1 = F.relu(self.l1(torch.cat([state, action], 1))) 51 | q1 = F.relu(self.l2(q1)) 52 | q1 = self.l3(q1) 53 | 54 | q2 = F.relu(self.l4(torch.cat([state, action], 1))) 55 | q2 = F.relu(self.l5(q2)) 56 | q2 = self.l6(q2) 57 | return q1, q2 58 | 59 | def q1(self, state, action): 60 | q1 = F.relu(self.l1(torch.cat([state, action], 1))) 61 | q1 = F.relu(self.l2(q1)) 62 | q1 = self.l3(q1) 63 | return q1 64 | 65 | """ 66 | Vanilla Variational Auto-Encoder 67 | """ 68 | class VAE(nn.Module): 69 | def __init__(self, state_dim, action_dim, latent_dim, max_action, device): 70 | super(VAE, self).__init__() 71 | self.e1 = nn.Linear(state_dim + action_dim, 750) 72 | self.e2 = nn.Linear(750, 750) 73 | 74 | self.mean = nn.Linear(750, latent_dim) 75 | self.log_std = nn.Linear(750, latent_dim) 76 | 77 | self.d1 = nn.Linear(state_dim + latent_dim, 750) 78 | self.d2 = nn.Linear(750, 750) 79 | self.d3 = nn.Linear(750, action_dim) 80 | 81 | self.max_action = max_action 82 | self.latent_dim = latent_dim 83 | self.device = device 84 | 85 | def forward(self, state, action): 86 | z = F.relu(self.e1(torch.cat([state, action], 1))) 87 | z = F.relu(self.e2(z)) 88 | 89 | mean = self.mean(z) 90 | # Clamped for numerical stability 91 | log_std = self.log_std(z).clamp(-4, 15) 92 | std = torch.exp(log_std) 93 | z = mean + std * torch.randn_like(std) 94 | 95 | u = self.decode(state, z) 96 | 97 | return u, mean, std 98 | 99 | def decode(self, state, z=None): 100 | # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5] 101 | if z is None: 102 | z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5,0.5) 103 | 104 | a = F.relu(self.d1(torch.cat([state, z], 1))) 105 | a = F.relu(self.d2(a)) 106 | return self.max_action * torch.tanh(self.d3(a)) 107 | 108 | 109 | class bcq: 110 | def __init__(self, init_seed, patient_params, params): 111 | 112 | # ENVIRONMENT 113 | self.params = params 114 | self.env_name = patient_params["env_name"] 115 | self.folder_name = patient_params["folder_name"] 116 | self.replay_name = patient_params["replay_name"] 117 | self.bas = patient_params["u2ss"] * (patient_params["BW"] / 6000) * 3 118 | self.env = gym.make(self.env_name) 119 | 
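# condensed state below: 9 down-sampled blood glucose readings plus meals-on-board
# and insulin-on-board, the 11 features built in utils/data_processing.py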
self.action_size, self.state_size = 1, 11 120 | self.params["state_size"] = self.state_size 121 | self.latent_size = self.action_size * 2 122 | self.sequence_length = 80 123 | self.data_processing = "condensed" 124 | self.device = params["device"] 125 | 126 | # HYPERPARAMETERS 127 | self.batch_size = 100 128 | self.actor_lr = 1e-3 129 | self.critic_lr = 1e-3 130 | self.gamma = 0.99 131 | self.tau = 0.005 132 | self.phi = 0.05 133 | self.lmbda = 0.75 134 | 135 | # DISPLAY 136 | self.training_timesteps = params["training_timesteps"] 137 | self.training_progress_freq = int(self.training_timesteps // 10) 138 | 139 | # SEEDING 140 | self.train_seed = init_seed 141 | self.env.seed(self.train_seed) 142 | np.random.seed(self.train_seed) 143 | torch.manual_seed(self.train_seed) 144 | random.seed(self.train_seed) 145 | 146 | # MEMORY 147 | self.memory_size = self.training_timesteps 148 | self.memory = deque(maxlen=self.memory_size) 149 | 150 | """ 151 | Initalise the neural networks. 152 | """ 153 | def init_model(self): 154 | 155 | # Actor 156 | self.actor = Actor(self.state_size, self.action_size, self.max_action, self.phi).to(self.device) 157 | self.actor_target = copy.deepcopy(self.actor) 158 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr) 159 | 160 | # Critic 161 | self.critic = Critic(self.state_size, self.action_size).to(self.device) 162 | self.critic_target = copy.deepcopy(self.critic) 163 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr) 164 | 165 | # VAE 166 | self.vae = VAE(self.state_size, self.action_size, self.latent_size, self.max_action, self.device).to(self.device) 167 | self.vae_optimizer = torch.optim.Adam(self.vae.parameters()) 168 | 169 | """ 170 | Save the learned models. 171 | """ 172 | def save_model(self): 173 | torch.save(self.actor.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'BCQ_weights_actor') 174 | torch.save(self.critic.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'BCQ_weights_critic') 175 | torch.save(self.vae.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'BCQ_weights_vae') 176 | 177 | """ 178 | Load pre-trained weights for testing. 179 | """ 180 | def load_model(self, name): 181 | 182 | # load actor 183 | self.actor.load_state_dict(torch.load(name + '_actor')) 184 | self.actor_target = copy.deepcopy(self.actor) 185 | self.actor.eval() 186 | 187 | # load critic 188 | self.critic.load_state_dict(torch.load(name + '_critic')) 189 | self.critic_target = copy.deepcopy(self.critic) 190 | self.critic.eval() 191 | 192 | # load vae 193 | self.vae.load_state_dict(torch.load(name + '_vae')) 194 | self.vae_target = copy.deepcopy(self.vae) 195 | self.vae.eval() 196 | 197 | """ 198 | Determine the action based on the state. 199 | """ 200 | def select_action(self, state, action, timestep, prev_reward): 201 | 202 | # Feed state into model 203 | with torch.no_grad(): 204 | tensor_state = torch.FloatTensor(state.reshape(1, -1)).repeat(self.batch_size, 1).to(self.device) 205 | tensor_action = self.actor(tensor_state, self.vae.decode(tensor_state)) 206 | q1 = self.critic.q1(tensor_state, tensor_action) 207 | ind = q1.argmax(0) 208 | 209 | return tensor_action[ind].cpu().data.numpy().flatten() 210 | 211 | """ 212 | Train the model on a pre-collected sample of training data. 
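As a rough sketch of the update performed below: a VAE is trained to reconstruct the
batch actions, the critic is regressed onto a soft clipped double-Q target,
lmbda * min(Q1, Q2) + (1 - lmbda) * max(Q1, Q2), maximised over 10 VAE-sampled and
perturbed actions per next state, and the actor is a perturbation model trained to
maximise Q1 of the perturbed VAE actions.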
213 | """ 214 | def train_model(self): 215 | 216 | # load the replay buffer 217 | with open("./Replays/" + self.replay_name + ".txt", "rb") as file: 218 | trajectories = pickle.load(file) 219 | 220 | # Process the replay -------------------------------------------------- 221 | 222 | # unpackage the replay 223 | self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay( 224 | trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length 225 | ) 226 | 227 | # update the parameters 228 | self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas) 229 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 230 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 231 | self.max_action = float(((self.bas * 3) - self.action_mean) / self.action_std) 232 | 233 | # initialise the networks 234 | self.init_model() 235 | 236 | print('Processing Complete.') 237 | # ------------------------------------------------------------------------ 238 | 239 | for t in range(1, self.training_timesteps + 1): 240 | 241 | # Get the batch ------------------------------------------------ 242 | 243 | # unpackage the samples and split 244 | state, action, reward, next_state, done, _, _, _, _, _ = get_batch( 245 | replay=self.memory, batch_size=self.batch_size, 246 | data_processing=self.data_processing, 247 | sequence_length=self.sequence_length, device=self.device, 248 | params=self.params 249 | ) 250 | 251 | # Variational Auto-Encoder Training -------------------------------------------------------- 252 | 253 | recon, mean, std = self.vae(state, action) 254 | 255 | recon_loss = F.mse_loss(recon, action) 256 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 257 | vae_loss = recon_loss + 0.5 * KL_loss 258 | 259 | self.vae_optimizer.zero_grad() 260 | vae_loss.backward() 261 | self.vae_optimizer.step() 262 | 263 | # Critic Training -------------------------------------------------------- 264 | 265 | with torch.no_grad(): 266 | # Duplicate next state 10 times 267 | next_state = torch.repeat_interleave(next_state, 10, 0) 268 | 269 | # Compute value of perturbed actions sampled from the VAE 270 | target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state))) 271 | 272 | # Soft Clipped Double Q-learning 273 | target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. 
- self.lmbda) * torch.max(target_Q1, target_Q2) 274 | 275 | # Take max over each action sampled from the VAE 276 | target_Q = target_Q.reshape(self.batch_size, -1).max(1)[0].reshape(-1, 1) 277 | target_Q = reward + done * self.gamma * target_Q 278 | 279 | current_Q1, current_Q2 = self.critic(state, action) 280 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 281 | 282 | self.critic_optimizer.zero_grad() 283 | critic_loss.backward() 284 | self.critic_optimizer.step() 285 | 286 | # Pertubation Model / Action Training -------------------------------------------------------- 287 | 288 | sampled_actions = self.vae.decode(state) 289 | perturbed_actions = self.actor(state, sampled_actions) 290 | 291 | # Update through DPG 292 | actor_loss = -self.critic.q1(state, perturbed_actions).mean() 293 | 294 | self.actor_optimizer.zero_grad() 295 | actor_loss.backward() 296 | self.actor_optimizer.step() 297 | 298 | # Update Target Networks 299 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 300 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 301 | 302 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 303 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 304 | 305 | # Show progress 306 | if t % self.training_progress_freq == 0: 307 | 308 | # show the updated loss 309 | print('Timesteps {} - Actor Loss {} - Critic Loss {} - Encoder Loss {}'.format(t, actor_loss, critic_loss, vae_loss)) 310 | self.save_model() 311 | 312 | """ 313 | Test the learned weights against the PID controller. 314 | """ 315 | def test_model(self, training=False, input_seed=0, input_max_timesteps=4800): 316 | 317 | # initialise the environment 318 | env = gym.make(self.env_name) 319 | 320 | # load the replay buffer 321 | with open("./Replays/" + self.replay_name + ".txt", "rb") as file: 322 | trajectories = pickle.load(file) 323 | 324 | # Process the replay -------------------------------------------------- 325 | 326 | # unpackage the replay 327 | self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay( 328 | trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length 329 | ) 330 | 331 | # adding this allows better results? 
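        # the following lines mirror train_model: the action normalisation is
        # rescaled around the patient's basal rate and max_action is recomputed
        # from it, so the reloaded weights are evaluated with the same
        # normalisation they were trained with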
332 | self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas) 333 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 334 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 335 | self.max_action = float(((self.bas * 3) - self.action_mean) / self.action_std) 336 | self.init_model() 337 | 338 | # load the learned model 339 | self.load_model('./Models/' + self.folder_name + "/" + "Seed" + str(self.train_seed) + "/" + 'BCQ_weights') 340 | test_seed, max_timesteps = input_seed, input_max_timesteps 341 | 342 | # test the algorithm's performance vs pid algorithm 343 | rl_reward, rl_bg, rl_action, rl_insulin, rl_meals, pid_reward, pid_bg, pid_action = test_algorithm( 344 | env=env, agent_action=self.select_action, seed=test_seed, max_timesteps=max_timesteps, 345 | sequence_length=self.sequence_length, data_processing=self.data_processing, 346 | pid_run=False, params=self.params 347 | ) 348 | 349 | # display the results 350 | create_graph( 351 | rl_reward=rl_reward, rl_blood_glucose=rl_bg, rl_action=rl_action, rl_insulin=rl_insulin, 352 | rl_meals=rl_meals, pid_reward=pid_reward, pid_blood_glucose=pid_bg, 353 | pid_action=pid_action, params=self.params 354 | ) 355 | 356 | 357 | -------------------------------------------------------------------------------- /CQL.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:36:04 2022 5 | 6 | """ 7 | 8 | import numpy as np 9 | import copy, random, torch, gym, pickle 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import deque 13 | 14 | from utils import unpackage_replay, get_batch, test_algorithm, create_graph 15 | 16 | 17 | """ 18 | Create a scalar constant 19 | """ 20 | class Scalar(nn.Module): 21 | def __init__(self, init_value): 22 | super().__init__() 23 | self.constant = nn.Parameter( 24 | torch.tensor(init_value, dtype=torch.float32) 25 | ) 26 | 27 | def forward(self): 28 | return self.constant 29 | 30 | """ 31 | Extend and repast the tensor along axis and repeat it 32 | """ 33 | def extend_and_repeat(tensor, dim, repeat): 34 | ones_shape = [1 for _ in range(tensor.ndim + 1)] 35 | ones_shape[dim] = repeat 36 | return torch.unsqueeze(tensor, dim) * tensor.new_ones(ones_shape) 37 | 38 | """ 39 | Forward the q function with multiple actions on each state, to be used as a decorator 40 | """ 41 | def multiple_action_q_function(forward): 42 | def wrapped(self, observations, actions, **kwargs): 43 | multiple_actions = False 44 | batch_size = observations.shape[0] 45 | if actions.ndim == 3 and observations.ndim == 2: 46 | multiple_actions = True 47 | observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1]) 48 | actions = actions.reshape(-1, actions.shape[-1]) 49 | q_values = forward(self, observations, actions, **kwargs) 50 | if multiple_actions: 51 | q_values = q_values.reshape(batch_size, -1) 52 | return q_values 53 | return wrapped 54 | 55 | 56 | """ 57 | Fully connected feedforward neural network. 
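The hidden layout is set by the arch string: '256-256' builds two hidden layers
of 256 units with ReLU activations, followed by a linear output layer.
A minimal usage sketch (the dimensions here are illustrative only):

    net = FullyConnectedNetwork(input_dim=12, output_dim=1, arch='256-256')
    q_values = net(torch.zeros(32, 12))   # -> shape (32, 1)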
58 | """ 59 | class FullyConnectedNetwork(nn.Module): 60 | 61 | def __init__(self, input_dim, output_dim, arch='256-256'): 62 | super().__init__() 63 | 64 | # get the parameters 65 | self.input_dim = input_dim 66 | self.output_dim = output_dim 67 | self.arch = arch 68 | 69 | d = input_dim 70 | modules = [] 71 | hidden_sizes = [int(h) for h in arch.split('-')] 72 | 73 | # add linear layers to the network 74 | for hidden_size in hidden_sizes: 75 | fc = nn.Linear(d, hidden_size) 76 | modules.append(fc) 77 | modules.append(nn.ReLU()) 78 | d = hidden_size 79 | 80 | # add the output layer 81 | last_fc = nn.Linear(d, output_dim) 82 | modules.append(last_fc) 83 | 84 | # construct the network 85 | self.network = nn.Sequential(*modules) 86 | 87 | def forward(self, input_tensor): 88 | return self.network(input_tensor) 89 | 90 | """ 91 | Fully connected Q function approximator. 92 | """ 93 | class FullyConnectedQFunction(nn.Module): 94 | 95 | def __init__(self, observation_dim, action_dim, arch='256-256'): 96 | super().__init__() 97 | 98 | # get the parameters 99 | self.observation_dim = observation_dim 100 | self.action_dim = action_dim 101 | self.arch = arch 102 | 103 | # initialise the network 104 | self.network = FullyConnectedNetwork( 105 | observation_dim + action_dim, 1 106 | ) 107 | 108 | @multiple_action_q_function 109 | def forward(self, observations, actions): 110 | 111 | # concatentate the tensors and feed unto network 112 | input_tensor = torch.cat([observations, actions], dim=-1) 113 | return torch.squeeze(self.network(input_tensor), dim=-1) 114 | 115 | """ 116 | Reparamterised Policy 117 | """ 118 | class ReparameterizedTanhGaussian(nn.Module): 119 | 120 | def __init__(self, log_std_min=-20.0, log_std_max=2.0): 121 | super().__init__() 122 | 123 | # get the parameters 124 | self.log_std_min = log_std_min 125 | self.log_std_max = log_std_max 126 | 127 | def log_prob(self, mean, log_std, sample): 128 | 129 | # restrict log probability and then calculate exponential 130 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 131 | std = torch.exp(log_std) 132 | 133 | # construct the action distribution 134 | action_distribution = torch.distributions.transformed_distribution.TransformedDistribution( 135 | torch.distributions.Normal(mean, std), torch.distributions.transforms.TanhTransform(cache_size=1) 136 | ) 137 | 138 | return torch.sum(action_distribution.log_prob(sample), dim=-1) 139 | 140 | def forward(self, mean, log_std, deterministic=False): 141 | 142 | # restrict log probability and then calculate exponential 143 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 144 | std = torch.exp(log_std) 145 | 146 | # construct the action distribution 147 | action_distribution = torch.distributions.transformed_distribution.TransformedDistribution( 148 | torch.distributions.Normal(mean, std), torch.distributions.transforms.TanhTransform(cache_size=1) 149 | ) 150 | 151 | # sample from the action distribution 152 | if deterministic: action_sample = torch.tanh(mean) 153 | else: action_sample = action_distribution.rsample() 154 | 155 | log_prob = torch.sum( 156 | action_distribution.log_prob(action_sample), dim=-1 157 | ) 158 | 159 | return action_sample, log_prob 160 | 161 | """ 162 | Tanh Gaussian Policy 163 | """ 164 | class TanhGaussianPolicy(nn.Module): 165 | 166 | def __init__(self, observation_dim, action_dim, arch='256-256', 167 | log_std_multiplier=1.0, log_std_offset=-1.0): 168 | 169 | super().__init__() 170 | 171 | # get the parameters 172 | 
self.observation_dim = observation_dim 173 | self.action_dim = action_dim 174 | self.arch = arch 175 | 176 | # initialise the base network 177 | self.base_network = FullyConnectedNetwork( 178 | observation_dim, 2 * action_dim, arch 179 | ) 180 | 181 | # initiailse the reparameterized tanh gaussian 182 | self.log_std_multiplier = Scalar(log_std_multiplier) 183 | self.log_std_offset = Scalar(log_std_offset) 184 | self.tanh_gaussian = ReparameterizedTanhGaussian() 185 | 186 | def log_prob(self, observations, actions): 187 | 188 | # change the dimensions of the observation to match the action 189 | if actions.ndim == 3: 190 | observations = extend_and_repeat(observations, 1, actions.shape[1]) 191 | 192 | # prepare the parameters 193 | base_network_output = self.base_network(observations) 194 | mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) 195 | log_std = self.log_std_multiplier() * log_std + self.log_std_offset() 196 | 197 | # get the log probability 198 | return self.tanh_gaussian.log_prob(mean, log_std, actions) 199 | 200 | def forward(self, observations, deterministic=False, repeat=None): 201 | 202 | # change the dimensions of the observation to match the action 203 | if repeat is not None: 204 | observations = extend_and_repeat(observations, 1, repeat) 205 | 206 | # prepare the parameters 207 | base_network_output = self.base_network(observations) 208 | mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) 209 | log_std = self.log_std_multiplier() * log_std + self.log_std_offset() 210 | 211 | # get the action sample and log prob 212 | return self.tanh_gaussian(mean, log_std, deterministic) 213 | 214 | 215 | class cql: 216 | 217 | def __init__(self, init_seed, patient_params, params): 218 | 219 | # ENVIRONMENT 220 | self.params = params 221 | self.env_name = patient_params["env_name"] 222 | self.folder_name = patient_params["folder_name"] 223 | self.replay_name = patient_params["replay_name"] 224 | self.bas = patient_params["u2ss"] * (patient_params["BW"] / 6000) * 3 225 | self.env = gym.make(self.env_name) 226 | self.action_size, self.state_size = 1, 11 227 | self.params["state_size"] = self.state_size 228 | self.sequence_length = 80 229 | self.data_processing = "condensed" 230 | self.device = params["device"] 231 | 232 | # HYPERPARAMETERS 233 | self.batch_size = 256 234 | self.policy_arch = '256-256' 235 | self.qf_arch = '256-256' 236 | self.policy_log_std_multiplier = 1.0 237 | self.policy_log_std_offset = -1.0 238 | self.discount = 0.99 239 | self.alpha_multiplier = 1.0 240 | self.target_entropy = 0.0 241 | self.policy_lr = 3e-4 242 | self.qf_lr = 3e-4 243 | self.soft_target_update_rate = 5e-3 244 | self.target_update_period = 1 245 | self.cql_n_actions = 10 246 | self.cql_temp = 1.0 247 | self.cql_min_q_weight = 5.0 248 | self.cql_clip_diff_min = -np.inf 249 | self.cql_clip_diff_max = np.inf 250 | 251 | # DISPLAY 252 | self.training_timesteps = params["training_timesteps"] 253 | self.training_progress_freq = int(self.training_timesteps // 10) 254 | 255 | # SEEDING 256 | self.train_seed = init_seed # use seeds 1, 2, 3 257 | self.env.seed(self.train_seed) 258 | np.random.seed(self.train_seed) 259 | torch.manual_seed(self.train_seed) 260 | random.seed(self.train_seed) 261 | 262 | # MEMORY 263 | self.memory_size = self.training_timesteps 264 | self.memory = deque(maxlen=self.memory_size) 265 | 266 | """ 267 | Initalise the neural networks. 
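    Builds the tanh-Gaussian policy, the twin Q functions with their target
    copies, the learnable log_alpha temperature and an Adam optimiser for each.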
268 | """ 269 | def init_model(self): 270 | 271 | # policy network 272 | self.policy = TanhGaussianPolicy(self.state_size, self.action_size, arch=self.policy_arch, 273 | log_std_multiplier=self.policy_log_std_multiplier, 274 | log_std_offset=self.policy_log_std_offset).to(self.device) 275 | self.log_alpha = Scalar(0.0).to(self.device) 276 | 277 | # Q networks and Target networks 278 | self.qf1 = FullyConnectedQFunction(self.state_size, self.action_size, arch=self.qf_arch).to(self.device) 279 | self.qf2 = FullyConnectedQFunction(self.state_size, self.action_size, arch=self.qf_arch).to(self.device) 280 | self.target_qf1 = copy.deepcopy(self.qf1) 281 | self.target_qf2 = copy.deepcopy(self.qf2) 282 | 283 | # Initialise the optimisers 284 | self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.policy_lr) 285 | self.qf_optimizer = torch.optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.qf_lr) 286 | self.alpha_optimizer = torch.optim.Adam(self.log_alpha.parameters(), lr=self.policy_lr) 287 | 288 | """ 289 | Save the learned models. 290 | """ 291 | def save_model(self): 292 | torch.save(self.policy.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'CQL_weights_policy') 293 | torch.save(self.qf1.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'CQL_weights_qf1') 294 | torch.save(self.qf2.state_dict(), './Models/' + str(self.env_name) + str(self.train_seed) + 'CQL_weights_qf2') 295 | 296 | """ 297 | Load pre-trained weights for testing. 298 | """ 299 | def load_model(self, name): 300 | 301 | # load he policy 302 | self.policy.load_state_dict(torch.load(name + '_policy')) 303 | self.policy.eval() 304 | 305 | # load qf1 and target 306 | self.qf1.load_state_dict(torch.load(name + '_qf1')) 307 | self.target_qf1 = copy.deepcopy(self.qf1) 308 | self.qf1.eval() 309 | 310 | # load qf2 and target 311 | self.qf2.load_state_dict(torch.load(name + '_qf2')) 312 | self.target_qf2 = copy.deepcopy(self.qf2) 313 | self.qf2.eval() 314 | 315 | """ 316 | Determine the action based on the state. 317 | """ 318 | def select_action(self, state, action, timestep, prev_reward): 319 | state = torch.tensor(state, dtype=torch.float32, device=self.device) 320 | action, _ = self.policy(state, deterministic=True) 321 | 322 | return action.cpu().data.numpy().flatten() 323 | 324 | """ 325 | Train the model on a pre-collected sample of training data. 
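    On top of the standard SAC-style updates (automatic temperature tuning and
    a tanh-Gaussian policy trained against min(Q1, Q2)), each critic receives a
    conservative penalty: Q values of random, current-policy and next-policy
    actions are pooled with a logsumexp (corrected by their log densities) and
    the Q value of the logged action is subtracted, roughly

        cql_min_q_weight * ( cql_temp * logsumexp_a( Q_i(s, a) / cql_temp ) - Q_i(s, a_data) )

    which pushes Q estimates for out-of-distribution actions below those of
    actions actually present in the replay.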
326 | """ 327 | def train_model(self): 328 | 329 | # load the replay buffer 330 | with open("./Replays/" + self.replay_name + ".txt", "rb") as file: 331 | trajectories = pickle.load(file) 332 | 333 | # Process the replay -------------------------------------------------- 334 | 335 | # unpackage the replay 336 | self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay( 337 | trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length 338 | ) 339 | 340 | # update the parameters 341 | self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas) 342 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 343 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 344 | self.init_model() 345 | 346 | print('Processing Complete.') 347 | # ------------------------------------------------------------------------ 348 | 349 | for t in range(1, self.training_timesteps + 1): 350 | 351 | # Get the batch ------------------------------------------------ 352 | 353 | # unpackage the samples and split 354 | state, action, reward, next_state, done, _, _, _, _, _ = get_batch( 355 | replay=self.memory, batch_size=self.batch_size, 356 | data_processing=self.data_processing, 357 | sequence_length=self.sequence_length, device=self.device, 358 | params=self.params 359 | ) 360 | 361 | # Perform the training update ----------------------------------------- 362 | 363 | # get the action predictions 364 | new_actions, log_pi = self.policy(state) 365 | 366 | # update the alpha loss 367 | alpha_loss = -(self.log_alpha() * (log_pi + self.target_entropy).detach()).mean() 368 | alpha = self.log_alpha().exp() * self.alpha_multiplier 369 | 370 | # Compute the policy loss -------------------------------- 371 | q_new_actions = torch.min( 372 | self.qf1(state, new_actions), 373 | self.qf2(state, new_actions), 374 | ) 375 | policy_loss = (alpha * log_pi - q_new_actions).mean() 376 | 377 | # Compute the Q function loss -------------------------------- 378 | q1_pred = self.qf1(state, action) 379 | q2_pred = self.qf2(state, action) 380 | 381 | new_next_actions, next_log_pi = self.policy(next_state) 382 | target_q_values = torch.min( 383 | self.target_qf1(next_state, new_next_actions), 384 | self.target_qf2(next_state, new_next_actions), 385 | ) 386 | 387 | td_target = reward.reshape(-1) + (1. 
- done).reshape(-1) * self.discount * target_q_values 388 | qf1_loss = F.mse_loss(q1_pred, td_target.detach()) 389 | qf2_loss = F.mse_loss(q2_pred, td_target.detach()) 390 | 391 | # CQL -> incorporate conservativism into Q function loss -------------------------------- 392 | 393 | # create an array of unitiialised values of size below between -1 and 1 394 | cql_random_actions = action.new_empty((self.batch_size, self.cql_n_actions, self.action_size), requires_grad=False).uniform_(-1, 1) 395 | 396 | # get the current policy predictions 397 | cql_current_actions, cql_current_log_pis = self.policy(state, repeat=self.cql_n_actions) 398 | cql_next_actions, cql_next_log_pis = self.policy(next_state, repeat=self.cql_n_actions) 399 | 400 | # detach the values from the graph 401 | cql_current_actions, cql_current_log_pis = cql_current_actions.detach(), cql_current_log_pis.detach() 402 | cql_next_actions, cql_next_log_pis = cql_next_actions.detach(), cql_next_log_pis.detach() 403 | 404 | # get the network predictions for random, current and next actions 405 | cql_q1_rand = self.qf1(state, cql_random_actions) 406 | cql_q2_rand = self.qf2(state, cql_random_actions) 407 | cql_q1_current_actions = self.qf1(state, cql_current_actions) 408 | cql_q2_current_actions = self.qf2(state, cql_current_actions) 409 | cql_q1_next_actions = self.qf1(state, cql_next_actions) 410 | cql_q2_next_actions = self.qf2(state, cql_next_actions) 411 | 412 | # concatenate the results and calculate the standard deviation 413 | cql_cat_q1 = torch.cat([cql_q1_rand, torch.unsqueeze(q1_pred, 1), cql_q1_next_actions, cql_q1_current_actions], dim=1) 414 | cql_cat_q2 = torch.cat([cql_q2_rand, torch.unsqueeze(q2_pred, 1), cql_q2_next_actions, cql_q2_current_actions], dim=1) 415 | cql_std_q1 = torch.std(cql_cat_q1, dim=1) 416 | cql_std_q2 = torch.std(cql_cat_q2, dim=1) 417 | 418 | # Subtract density function from the Q function predictions 419 | random_density = np.log(0.5 ** self.action_size) 420 | cql_cat_q1 = torch.cat( 421 | [cql_q1_rand - random_density, 422 | cql_q1_next_actions - cql_next_log_pis.detach(), 423 | cql_q1_current_actions - cql_current_log_pis.detach()], 424 | dim=1 425 | ) 426 | cql_cat_q2 = torch.cat( 427 | [cql_q2_rand - random_density, 428 | cql_q2_next_actions - cql_next_log_pis.detach(), 429 | cql_q2_current_actions - cql_current_log_pis.detach()], 430 | dim=1 431 | ) 432 | 433 | # Check if the predictions are out of the distribution (OOD) 434 | cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.cql_temp, dim=1) * self.cql_temp 435 | cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.cql_temp, dim=1) * self.cql_temp 436 | 437 | # Subtract the log likelihood of data 438 | cql_qf1_diff = torch.clamp(cql_qf1_ood - q1_pred, self.cql_clip_diff_min, self.cql_clip_diff_max).mean() 439 | cql_qf2_diff = torch.clamp(cql_qf2_ood - q2_pred, self.cql_clip_diff_min, self.cql_clip_diff_max).mean() 440 | 441 | # calculate the conservative loss 442 | cql_min_qf1_loss = cql_qf1_diff * self.cql_min_q_weight 443 | cql_min_qf2_loss = cql_qf2_diff * self.cql_min_q_weight 444 | 445 | # Returns a new tensor of the dimensions of the state 446 | alpha_prime_loss = state.new_tensor(0.0).to(self.device) 447 | alpha_prime = state.new_tensor(0.0).to(self.device) 448 | 449 | # get the combined conservative loss 450 | qf_loss = qf1_loss + qf2_loss + cql_min_qf1_loss + cql_min_qf2_loss 451 | 452 | # Backpropagation -------------------------------------------------- 453 | 454 | # backpropagate and gradient step 455 | self.alpha_optimizer.zero_grad() 
456 |         alpha_loss.backward()
457 |         self.alpha_optimizer.step()
458 |         self.policy_optimizer.zero_grad()
459 |         policy_loss.backward()
460 |         self.policy_optimizer.step()
461 |         self.qf_optimizer.zero_grad()
462 |         qf_loss.backward()
463 |         self.qf_optimizer.step()
464 | 
465 |         # Target Update -----------------------------------------------------
466 | 
467 |         # update the target networks
468 |         if t % self.target_update_period == 0:
469 | 
470 |             for param, target_param in zip(self.qf1.parameters(), self.target_qf1.parameters()):
471 |                 target_param.data.copy_(self.soft_target_update_rate * param.data + (1 - self.soft_target_update_rate) * target_param.data)
472 | 
473 |             for param, target_param in zip(self.qf2.parameters(), self.target_qf2.parameters()):
474 |                 target_param.data.copy_(self.soft_target_update_rate * param.data + (1 - self.soft_target_update_rate) * target_param.data)
475 | 
476 |         # Show progress
477 |         if t % self.training_progress_freq == 0:
478 | 
479 |             # show the updated loss
480 |             print('Timesteps {} - Policy Loss {} - Q function Loss {}'.format(t, policy_loss, qf_loss))
481 |             self.save_model()
482 | 
483 | 
484 |     """
485 |     Test the learned weights against the PID controller.
486 |     """
487 |     def test_model(self, input_seed=0, input_max_timesteps=4800):
488 | 
489 |         # TESTING --------------------------------------------------------------------------------------------
490 | 
491 |         # initialise the environment
492 |         env = gym.make(self.env_name)
493 | 
494 |         # load the replay buffer
495 |         with open("./Replays/" + self.replay_name + ".txt", "rb") as file:
496 |             trajectories = pickle.load(file)
497 | 
498 |         # Process the replay --------------------------------------------------
499 | 
500 |         # unpackage the replay
501 |         self.memory, self.state_mean, self.state_std, self.action_mean, self.action_std, _, _ = unpackage_replay(
502 |             trajectories=trajectories, empty_replay=self.memory, data_processing=self.data_processing, sequence_length=self.sequence_length
503 |         )
504 | 
505 |         # update the parameters
506 |         self.action_std = 1.75 * self.bas * 0.25 / (self.action_std / self.bas)
507 |         self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std
508 |         self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std
509 |         self.init_model()
510 | 
511 |         # load the learned model
512 |         self.load_model('./Models/' + self.folder_name + "/" + "Seed" + str(self.train_seed) + "/" + 'CQL_weights')
513 |         test_seed, max_timesteps = input_seed, input_max_timesteps
514 | 
515 |         # test the algorithm's performance vs pid algorithm
516 |         rl_reward, rl_bg, rl_action, rl_insulin, rl_meals, pid_reward, pid_bg, pid_action = test_algorithm(
517 |             env=env, agent_action=self.select_action, seed=test_seed, max_timesteps=max_timesteps,
518 |             sequence_length=self.sequence_length, data_processing=self.data_processing,
519 |             pid_run=False, params=self.params
520 |         )
521 | 
522 |         # display the results
523 |         create_graph(
524 |             rl_reward=rl_reward, rl_blood_glucose=rl_bg, rl_action=rl_action, rl_insulin=rl_insulin,
525 |             rl_meals=rl_meals, pid_reward=pid_reward, pid_blood_glucose=pid_bg,
526 |             pid_action=pid_action, params=self.params
527 |         )
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Apr 2 13:28:19 2022
5 | """
6 | 
7 | """
8 | Functions for evaluating algorithmic performance
and displaying it to the user. 9 | """ 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch, random 14 | 15 | from .general import PID_action, calculate_bolus, calculate_risk, is_in_range 16 | 17 | 18 | """ 19 | Test the learned policy of an agent against the PID algorithm over a 20 | specified length of time. 21 | """ 22 | def test_algorithm(env, agent_action, seed=0, max_timesteps=480, sequence_length=80, 23 | data_processing="condensed", pid_run=False, lstm=False, params=None): 24 | 25 | # Diabetes 26 | basal_default = params.get("basal_default") 27 | target_blood_glucose = params.get("target_blood_glucose", 144) 28 | 29 | # Bolus 30 | cr = params.get("carbohydrate_ratio") 31 | cf = params.get("correction_factor") 32 | bolus_overestimate = params.get("bolus_overestimate", 0.0) 33 | meal_announce = params.get("meal_announce", 0.0) 34 | 35 | # PID 36 | kp = params.get("kp") 37 | ki = params.get("ki") 38 | kd = params.get("kd") 39 | 40 | # Means and Stds 41 | state_mean = params.get("state_mean") 42 | state_std = params.get("state_std") 43 | action_mean = params.get("action_mean") 44 | action_std = params.get("action_std") 45 | 46 | # Device 47 | device = params.get("device") 48 | missing_data_prob = params.get("missing_data_prob", 0.0) 49 | compression_prob = params.get("compression_prob", 0.0) 50 | 51 | # Network 52 | model_dim = params.get("model_dim", 256) 53 | 54 | 55 | # initialise the arrays for data collection 56 | rl_reward, rl_blood_glucose, rl_action = 0, [], [] 57 | pid_reward, pid_blood_glucose, pid_action = 0, [], [] 58 | rl_insulin, rl_meals = [], [] 59 | 60 | # select the number of iterations 61 | if not pid_run: runs = 2 62 | else: runs = 1 63 | 64 | for ep in range(runs): 65 | 66 | # set the seed for the environment 67 | env.seed(seed) 68 | np.random.seed(seed) 69 | torch.manual_seed(seed) 70 | random.seed(seed) 71 | 72 | # Initialise the environment -------------------------------------- 73 | 74 | # get the state 75 | insulin_dose = 1/3 * basal_default 76 | meal, done, bg_val = 0, False, env.reset() 77 | time = ((env.env.time.hour * 60) / 3 + env.env.time.minute / 3) / 479 78 | state = np.array([bg_val[0], meal, insulin_dose, time], dtype=np.float32) 79 | last_action = insulin_dose 80 | 81 | # get a suitable input 82 | state_stack = np.tile(state, (sequence_length + 1, 1)) 83 | state_stack[:, 3] = (state_stack[:, 3] - np.arange(((sequence_length + 1) / 479), 0, -(1 / 479))[:sequence_length + 1]) * 479 84 | state_stack[:, 3] = (np.around(state_stack[:, 3], 0) % 480) / 479 85 | 86 | # get the action and reward stack 87 | action_stack = np.tile(np.array([insulin_dose], dtype=np.float32), (sequence_length + 1, 1)) 88 | reward_stack = np.tile(-calculate_risk(bg_val), (sequence_length + 1, 1)) 89 | 90 | # get the meal history 91 | meal_history = np.zeros(int((3 * 60) / 3), dtype=np.float32) 92 | 93 | # intiialise pid parameters 94 | integrated_state = 0 95 | previous_error = 0 96 | timesteps = 0 97 | reward = 0 98 | 99 | # count missing data period 100 | missing_period = 0 101 | 102 | # add compression error 103 | compression_period = 0 104 | compression_size = 0 105 | 106 | # init the hidden_layer 107 | if params["rnn"] == "gru": 108 | hidden_in = torch.zeros([1, 1, model_dim], dtype=torch.float).to(device) 109 | 110 | else: 111 | hidden_in = (torch.zeros([1, 1, model_dim], dtype=torch.float).to(device), 112 | torch.zeros([1, 1, model_dim], dtype=torch.float).to(device)) 113 | 114 | while not done and timesteps < max_timesteps: 115 | 116 | # Run 
the RL algorithm ------------------------------------------------------ 117 | if ep == 0: 118 | 119 | # condense the state 120 | if data_processing == "condensed": 121 | 122 | # Unpack the state 123 | bg_vals, meal_vals, insulin_vals = state_stack[:, 0][::10], state_stack[:, 1], state_stack[:, 2] 124 | 125 | # calculate insulin and meals on board 126 | decay_factor = np.arange(1 / (sequence_length + 2), 1, 1 / (sequence_length + 2)) 127 | meals_on_board, insulin_on_board = np.sum(meal_vals * decay_factor), np.sum(insulin_vals * decay_factor) 128 | 129 | # create the state 130 | state = np.concatenate([bg_vals, meals_on_board.reshape(1), insulin_on_board.reshape(1)]) 131 | prev_action = last_action 132 | 133 | # TOOD: replace with explicity state and action size 134 | 135 | # get the state a sequence of specified length 136 | elif data_processing == "sequence": 137 | state = state_stack[1:, :3].reshape(1, sequence_length, 3) 138 | prev_action = action_stack[1:, :].reshape(1, sequence_length) 139 | 140 | # Normalise the current state 141 | state = (state - state_mean) / state_std 142 | prev_action = (prev_action - action_mean) / action_std 143 | 144 | # get the action prediction from the model 145 | if lstm: 146 | action, hidden_in = agent_action(state, prev_action, timestep=timesteps, hidden_in=hidden_in, prev_reward=reward) 147 | else: 148 | action = agent_action(state, prev_action, timestep=timesteps, prev_reward=reward) 149 | 150 | # Unnormalise action output 151 | action_pred = (action * action_std + action_mean)[0] 152 | 153 | # to stop subtracting from bolus when -ve 154 | action_pred = max(0, action_pred) 155 | player_action = action_pred 156 | 157 | 158 | # Run the pid algorithm ------------------------------------------------------ 159 | else: 160 | player_action, previous_error, integrated_state = PID_action( 161 | blood_glucose=bg_val, previous_error=previous_error, 162 | integrated_state=integrated_state, 163 | target_blood_glucose=target_blood_glucose, 164 | kp=kp, ki=ki, kd=kd, basal_default=basal_default 165 | ) 166 | 167 | 168 | # update the chosen action 169 | chosen_action = np.copy(player_action) 170 | 171 | # Get the meal bolus -------------------------------------------- 172 | 173 | # take meal bolus 174 | if meal > 0: 175 | 176 | # add a bias to the bolus estimation 177 | adjusted_meal = meal 178 | adjusted_meal += bolus_overestimate * meal 179 | adjusted_meal = max(0, adjusted_meal) 180 | 181 | bolus_action = calculate_bolus( 182 | blood_glucose=bg_val, meal_history=meal_history, 183 | current_meal=adjusted_meal, carbohyrdate_ratio=cr, 184 | correction_factor=cf, 185 | target_blood_glucose=target_blood_glucose 186 | ) 187 | 188 | chosen_action = float(chosen_action) + bolus_action 189 | 190 | # Step the environment ------------------------------------------ 191 | 192 | # append the basal and bolus action 193 | action_stack = np.delete(action_stack, 0, 0) 194 | action_stack = np.vstack([action_stack, player_action]) 195 | 196 | # step the simulator 197 | next_bg_val, _, done, info = env.step(chosen_action) 198 | reward = -calculate_risk(next_bg_val) 199 | 200 | # announce a meal ------------------------------------------- 201 | 202 | # meal announcement 203 | meal_input = meal 204 | if meal_announce != 0.0: 205 | 206 | # get times + meal schedule 207 | current_time = env.env.time.hour * 60 + env.env.time.minute 208 | future_time = current_time + meal_announce - 1 209 | meal_scenario = env.env.scenario.scenario["meal"] 210 | future_meal = 0 211 | 212 | # check for 
future meal 213 | if future_time in meal_scenario["time"]: 214 | index = meal_scenario["time"].index(future_time) 215 | future_meal = meal_scenario["amount"][index] 216 | 217 | meal_input = future_meal/3 218 | 219 | # add missing data to the dataset 220 | rand = np.random.rand() 221 | if (rand < missing_data_prob) or (missing_period > 0): 222 | if missing_period < 1: 223 | prev_bg = 144 # bg_val[0] 224 | missing_period = np.random.randint(10) 225 | next_bg_val = [next_bg_val[0]] 226 | next_bg_val[0] = 144 # prev_bg 227 | 228 | # add compression error 229 | rand = np.random.rand() 230 | if (rand < compression_prob) or (compression_period > 0): 231 | if compression_period < 1: 232 | compression_period = np.random.randint(10) 233 | compression_size = np.random.randint(30) 234 | next_bg_val = [next_bg_val[0]] 235 | next_bg_val[0] -= compression_size 236 | 237 | # step forward in time 238 | missing_period -= 1 239 | compression_period -= 1 240 | 241 | # get the rnn array format for state 242 | time = ((env.env.time.hour * 60) / 3 + env.env.time.minute / 3)/479 243 | next_state = np.array([float(next_bg_val[0]), float(meal_input), float(chosen_action), time], dtype=np.float32) 244 | 245 | # update the state stacks 246 | next_state_stack = np.delete(state_stack, 0, 0) 247 | next_state_stack = np.vstack([next_state_stack, next_state]) 248 | reward_stack = np.delete(reward_stack, 0, 0) 249 | reward_stack = np.vstack([reward_stack, np.array([reward], dtype=np.float32)]) 250 | 251 | # add a termination penalty 252 | if done: 253 | reward = -1e5 254 | break 255 | 256 | # Save the testing results -------------------------------------- 257 | 258 | # for RL agent 259 | if ep == 0: 260 | rl_blood_glucose.append(next_bg_val[0]) 261 | rl_action.append(player_action) 262 | rl_insulin.append(chosen_action) 263 | rl_reward += reward 264 | rl_meals.append(info['meal']) 265 | 266 | # for pid agent 267 | else: 268 | pid_blood_glucose.append(next_bg_val[0]) 269 | pid_action.append(player_action) 270 | pid_reward += reward 271 | 272 | # Update the state --------------------------------------------- 273 | 274 | # update the meal history 275 | meal_history = np.append(meal_history, meal) 276 | meal_history = np.delete(meal_history, 0) 277 | 278 | # update the state stacks 279 | state_stack = next_state_stack 280 | 281 | # update the state 282 | bg_val, state, meal = next_bg_val, next_state, info['meal'] 283 | last_action = player_action 284 | timesteps += 1 285 | 286 | return rl_reward, rl_blood_glucose, rl_action, rl_insulin, rl_meals, pid_reward, pid_blood_glucose, pid_action 287 | 288 | 289 | """ 290 | Plot a four-tiered graph comparing the blood glucose control of a PID 291 | and RL algorithm, showing the blood glucose, insulin doses and meal 292 | carbohyrdates. 
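Also prints a summary table comparing the two controllers on total reward,
time in range (TIR), time above/below range (TAR/TBR) with their moderate and
significant splits (TMAR/TSAR, TMBR/TSBR), mean, standard deviation and
coefficient of variation of blood glucose, and mean hyper-/hypoglycaemic
episode length, and plots a histogram of the two blood glucose distributions.

A typical call mirrors the agents' test_model methods (agent here stands for
any of the agent classes in this repository):

    outputs = test_algorithm(env=env, agent_action=agent.select_action, seed=0,
                             max_timesteps=4800, params=params)
    rl_r, rl_bg, rl_a, rl_i, rl_m, pid_r, pid_bg, pid_a = outputs
    create_graph(rl_reward=rl_r, rl_blood_glucose=rl_bg, rl_action=rl_a,
                 rl_insulin=rl_i, rl_meals=rl_m, pid_reward=pid_r,
                 pid_blood_glucose=pid_bg, pid_action=pid_a, params=params)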
293 | """
294 | def create_graph(rl_reward, rl_blood_glucose, rl_action, rl_insulin, rl_meals,
295 |                  pid_reward, pid_blood_glucose, pid_action, params):
296 | 
297 |     # Unpack the params
298 | 
299 |     # Diabetes
300 |     basal_default = params.get("basal_default")
301 |     hyper_threshold = params.get("hyper_threshold", 180)
302 |     sig_hyper_threshold = params.get("sig_hyper_threshold", 250)
303 |     hypo_threshold = params.get("hypo_threshold", 70)
304 |     sig_hypo_threshold = params.get("sig_hypo_threshold", 54)
305 | 
306 |     # Display the evaluation metrics
307 | 
308 |     # TIR Metrics ----------------------------------------------
309 | 
310 |     # PID algorithm
311 |     pid_in_range, pid_above_range, pid_below_range, pid_total = 0, 0, 0, len(pid_blood_glucose)
312 |     pid_sig_above_range, pid_sig_below_range = 0, 0
313 |     for pid_bg in pid_blood_glucose:
314 | 
315 |         # classify the blood glucose value
316 |         classification = is_in_range(pid_bg, hypo_threshold, hyper_threshold, sig_hypo_threshold, sig_hyper_threshold)
317 | 
318 |         # in range
319 |         if classification == 0:
320 |             pid_in_range += 1
321 | 
322 |         # hyperglycaemia
323 |         elif classification > 0:
324 |             pid_above_range += 1
325 |             if classification > 1:
326 |                 pid_sig_above_range += 1
327 | 
328 |         # hypoglycaemia
329 |         else:
330 |             pid_below_range += 1
331 |             if classification < -1:
332 |                 pid_sig_below_range += 1
333 | 
334 |     # RL algorithm
335 |     rl_in_range, rl_above_range, rl_below_range, rl_total = 0, 0, 0, len(rl_blood_glucose)
336 |     rl_sig_above_range, rl_sig_below_range = 0, 0
337 |     for rl_bg in rl_blood_glucose:
338 | 
339 |         # classify the blood glucose value
340 |         classification = is_in_range(rl_bg, hypo_threshold, hyper_threshold, sig_hypo_threshold, sig_hyper_threshold)
341 | 
342 |         # in range
343 |         if classification == 0:
344 |             rl_in_range += 1
345 | 
346 |         # hyperglycaemia
347 |         elif classification > 0:
348 |             rl_above_range += 1
349 |             if classification > 1:
350 |                 rl_sig_above_range += 1
351 | 
352 |         # hypoglycaemia
353 |         else:
354 |             rl_below_range += 1
355 |             if classification < -1:
356 |                 rl_sig_below_range += 1
357 | 
358 |     # Statistical Metrics -----------------------------------------
359 | 
360 |     pid_mean, pid_std = np.mean(pid_blood_glucose), np.std(pid_blood_glucose)
361 |     rl_mean, rl_std = np.mean(rl_blood_glucose), np.std(rl_blood_glucose)
362 |     pid_cv, rl_cv = (pid_std / pid_mean), (rl_std / rl_mean)
363 | 
364 |     # Diabetes Metrics ---------------------------------------------
365 | 
366 |     # get the average hypo/hyper length for PID
367 |     pid_hypo_length, pid_hyper_length = [], []
368 |     prev_classification, hypo_count, hyper_count = 0, 0, 0
369 | 
370 |     for pid_bg in pid_blood_glucose:
371 | 
372 |         # classify the blood glucose value
373 |         classification = is_in_range(pid_bg, hypo_threshold, hyper_threshold, sig_hypo_threshold, sig_hyper_threshold)
374 | 
375 |         # if continued hyper
376 |         if classification > 0:
377 | 
378 |             # add to the hyper count
379 |             if prev_classification > 0:
380 |                 hyper_count += 1
381 | 
382 |             # reset the count
383 |             else:
384 |                 pid_hyper_length.append(hyper_count * 3)
385 |                 hyper_count = 0
386 | 
387 |         # if continued hypo
388 |         if classification < 0:
389 | 
390 |             # add to the hypo count
391 |             if prev_classification < 0:
392 |                 hypo_count += 1
393 | 
394 |             # reset the count
395 |             else:
396 |                 pid_hypo_length.append(hypo_count * 3)
397 |                 hypo_count = 0
398 | 
399 |         prev_classification = classification
400 | 
401 |     # get the average hypo/hyper length for RL
402 |     rl_hypo_length, rl_hyper_length = [], []
403 |     prev_classification, hypo_count,
hyper_count = 0, 0, 0 404 | 405 | for rl_bg in rl_blood_glucose: 406 | 407 | # classify the blood glucose_value 408 | classification = is_in_range(rl_bg, hypo_threshold, hyper_threshold, sig_hypo_threshold, sig_hyper_threshold) 409 | 410 | # if continued hyper 411 | if classification > 0: 412 | 413 | # add to the hyper count 414 | if prev_classification > 0: 415 | hyper_count += 1 416 | 417 | # reset the count 418 | else: 419 | rl_hyper_length.append(hyper_count * 3) 420 | hyper_count = 0 421 | 422 | # if continued hypo 423 | if classification < 0: 424 | 425 | # add to the hypo count 426 | if prev_classification < 0: 427 | hypo_count += 1 428 | 429 | # reset the count 430 | else: 431 | rl_hypo_length.append(hypo_count * 3) 432 | hypo_count = 0 433 | 434 | prev_classification = classification 435 | 436 | mean_pid_hypo_length = sum(pid_hypo_length) / max(len(pid_hypo_length), 1) 437 | mean_pid_hyper_length = sum(pid_hyper_length) / max(len(pid_hyper_length), 1) 438 | mean_rl_hypo_length = sum(rl_hypo_length) / max(len(rl_hypo_length), 1) 439 | mean_rl_hyper_length = sum(rl_hyper_length) / max(len(rl_hyper_length), 1) 440 | 441 | print('\n-----------------------------------------------------------') 442 | print(' | {: ^016} | {: ^016} |'.format("PID", "RL")) 443 | print('-----------------------------------------------------------') 444 | print('Reward | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_reward, rl_reward)) 445 | print('TIR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_in_range / pid_total * 100, rl_in_range / rl_total * 100)) 446 | print('TAR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_above_range / pid_total * 100, rl_above_range / rl_total * 100)) 447 | print('TBR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_below_range / pid_total * 100, rl_below_range / rl_total * 100)) 448 | print('Mean (mg/dl) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_mean, rl_mean)) 449 | print('STD (mg/dl) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_std, rl_std)) 450 | print('CoV | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_cv, rl_cv)) 451 | print('Hyper Length (mins) | {: ^#016.2f} | {: ^#016.2f} |'.format(mean_pid_hyper_length, mean_rl_hyper_length)) 452 | print('Hypo Length (mins) | {: ^#016.2f} | {: ^#016.2f} |'.format(mean_pid_hypo_length, mean_rl_hypo_length)) 453 | print('TMBR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format((pid_below_range - pid_sig_below_range) / pid_total * 100, (rl_below_range - rl_sig_below_range) / rl_total * 100)) 454 | print('TSBR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_sig_below_range / pid_total * 100, rl_sig_below_range / rl_total * 100)) 455 | print('TMAR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format((pid_above_range - pid_sig_above_range) / pid_total * 100, (rl_above_range - rl_sig_above_range) / rl_total * 100)) 456 | print('TSAR (%) | {: ^#016.2f} | {: ^#016.2f} |'.format(pid_sig_above_range / pid_total * 100, rl_sig_above_range / rl_total * 100)) 457 | print('-----------------------------------------------------------') 458 | 459 | # Produce the glucose display graph ----------------------------------------------- 460 | 461 | # Check that the rl algorithm completed the full episode 462 | if len(pid_blood_glucose) == len(rl_blood_glucose): 463 | 464 | # Plot insulin actions alongside blood glucose ------------------------------ 465 | 466 | # get the x-axis 467 | x = list(range(len(pid_blood_glucose))) 468 | 469 | # Initialise the plot and specify the title 470 | fig = plt.figure(dpi=160) 471 | gs = fig.add_gridspec(4, hspace=0.0) 472 | axs = 
gs.subplots(sharex=True, sharey=False) 473 | fig.suptitle('Blood Glucose Control Algorithm Comparison') 474 | 475 | # define the hypo, eu and hyper regions 476 | axs[0].axhspan(180, 500, color='lightcoral', alpha=0.6, lw=0) 477 | axs[0].axhspan(70, 180, color='#c1efc1', alpha=1.0, lw=0) 478 | axs[0].axhspan(0, 70, color='lightcoral', alpha=0.6, lw=0) 479 | 480 | # plot the blood glucose values 481 | axs[0].plot(x, pid_blood_glucose, label='pid', color='orange') 482 | axs[0].plot(x, rl_blood_glucose, label='rl', color='dodgerblue') 483 | axs[0].legend(bbox_to_anchor=(1.0, 1.0)) 484 | 485 | # specify the limits and the axis lables 486 | axs[0].axis(ymin=50, ymax=500) 487 | axs[0].axis(xmin=0.0, xmax=len(pid_blood_glucose)) 488 | axs[0].set_ylabel("BG \n(mg/dL)") 489 | axs[0].set_xlabel("Time \n(mins)") 490 | 491 | # show the basal doses 492 | axs[1].plot(x, pid_action, label='pid', color='orange') 493 | axs[1].plot(x, rl_action, label='rl', color='dodgerblue') 494 | axs[1].axis(ymin=0.0, ymax=(basal_default * 1.4)) 495 | axs[1].set_ylabel("Basal \n(U/min)") 496 | 497 | # show the bolus doses 498 | axs[2].plot(x, rl_insulin) 499 | axs[2].axis(ymin=0.01, ymax=0.99) 500 | axs[2].set_ylabel("Bolus \n(U/min)") 501 | 502 | # show the scheduled meals 503 | axs[3].plot(x, rl_meals) 504 | axs[3].axis(ymin=0, ymax=29.9) 505 | axs[3].set_ylabel("CHO \n(g/min)") 506 | 507 | # Hide x labels and tick labels for all but bottom plot. 508 | for ax in axs: 509 | ax.label_outer() 510 | 511 | plt.show() 512 | 513 | # Plot the distribution of states ------------------------------ 514 | 515 | fig2 = plt.figure(dpi=160) 516 | 517 | bins = np.linspace(10, 1000, 100) 518 | 519 | # plot the bins and the legend 520 | plt.hist(pid_blood_glucose, bins, alpha=0.5, label='pid', color='orange') 521 | plt.hist(rl_blood_glucose, bins, alpha=0.5, label='rl', color='dodgerblue') 522 | plt.legend(loc='upper right') 523 | 524 | # mark the target range 525 | plt.axvline(hyper_threshold, color='k', linestyle='dashed', linewidth=1) 526 | plt.axvline(hypo_threshold, color='k', linestyle='dashed', linewidth=1) 527 | 528 | # set the axis labels 529 | plt.xlabel("Blood glucose (mg/dl)") 530 | plt.ylabel("Frequency") 531 | plt.title("Blood glucose distribution") 532 | 533 | plt.show() 534 | 535 | # specify the timesteps before termination 536 | else: print('Terminated after: {} timesteps.'.format(len(rl_blood_glucose))) 537 | -------------------------------------------------------------------------------- /SAC_RNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:35:53 2022 5 | 6 | """ 7 | 8 | import numpy as np 9 | import copy, random, torch, gym 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.distributions import Normal 13 | from collections import deque 14 | 15 | from utils import get_batch, test_algorithm, create_graph, calculate_bolus, calculate_risk 16 | 17 | """ 18 | Recurrent neural network for the Actor. 
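The state is processed by a feed-forward branch and (state, last_action) by an
LSTM branch whose hidden state is carried between calls; the two feature sets
are concatenated and mapped to the mean and clamped log-std of a Gaussian over
actions. Inputs are (batch, sequence, feature) tensors.
A minimal shape check (sizes are illustrative only):

    actor = Actor(state_dim=3, action_dim=1, hidden_dim=128)
    hidden = (torch.zeros(1, 1, 128), torch.zeros(1, 1, 128))
    mu, log_std, hidden_out = actor(torch.zeros(1, 80, 3), torch.zeros(1, 80, 1), hidden)
    # mu and log_std both have shape (1, 80, 1)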
19 | """ 20 | class Actor(nn.Module): 21 | def __init__(self, state_dim, action_dim, hidden_dim, min_log_std=-20, max_log_std=2): 22 | super(Actor, self).__init__() 23 | hidden_dim = hidden_dim 24 | 25 | # linear branch 26 | self.fc1 = nn.Linear(state_dim, hidden_dim) 27 | self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim) 28 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 29 | 30 | # lstm branch 31 | self.lstm1 = nn.Linear(state_dim + action_dim, hidden_dim) 32 | self.lstm2 = nn.LSTM(hidden_dim, hidden_dim) 33 | self.mu_head = nn.Linear(hidden_dim, action_dim) 34 | self.log_std_head = nn.Linear(hidden_dim, action_dim) 35 | self.min_log_std = min_log_std 36 | self.max_log_std = max_log_std 37 | 38 | def forward(self, state, last_action, hidden_in): 39 | 40 | state = state.permute(1, 0, 2) 41 | last_action = last_action.permute(1, 0, 2) 42 | 43 | # branch 1 44 | x1 = F.relu(self.fc1(state)) 45 | 46 | # branch 2 47 | x2 = torch.cat([state, last_action], -1) 48 | x2 = F.relu(self.lstm1(x2)) 49 | x2, hidden_out = self.lstm2(x2, hidden_in) 50 | 51 | # merging 52 | x = torch.cat([x1, x2], -1) 53 | x = F.relu(self.fc2(x)) 54 | x = F.relu(self.fc3(x)) 55 | x = x.permute(1, 0, 2) 56 | 57 | # mean and std 58 | mu = self.mu_head(x) 59 | log_std_head = F.relu(self.log_std_head(x)) 60 | log_std_head = torch.clamp(log_std_head, self.min_log_std, self.max_log_std) 61 | 62 | return mu, log_std_head, hidden_out 63 | 64 | """ 65 | Recurrent neural network for the Critic. 66 | """ 67 | class Q(nn.Module): 68 | def __init__(self, state_dim, action_dim, hidden_dim): 69 | super(Q, self).__init__() 70 | hidden_dim = hidden_dim 71 | 72 | self.state_dim, self.action_dim = state_dim, action_dim 73 | self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim) 74 | self.fc2 = nn.Linear(state_dim + action_dim, hidden_dim) 75 | self.lstm1 = nn.LSTM(hidden_dim, hidden_dim) 76 | self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim) 77 | self.fc4 = nn.Linear(hidden_dim, 1) 78 | 79 | def forward(self, state, action, last_action, hidden_in): 80 | 81 | state = state.permute(1, 0, 2) 82 | action = action.permute(1, 0, 2) 83 | last_action = last_action.permute(1, 0, 2) 84 | 85 | # branch 1 86 | x1 = torch.cat((state, action), -1) 87 | x1 = F.relu(self.fc1(x1)) 88 | 89 | # branch 2 90 | x2 = torch.cat((state, last_action), -1) 91 | x2 = F.relu(self.fc2(x2)) 92 | x2, hidden_out = self.lstm1(x2, hidden_in) 93 | 94 | # merged 95 | x = torch.cat([x1, x2], -1) 96 | x = F.relu(self.fc3(x)) 97 | x = self.fc4(x) 98 | x = x.permute(1,0,2) 99 | 100 | return x, hidden_out 101 | 102 | 103 | class sac_rnn(object): 104 | 105 | def __init__(self, init_seed, patient_params, params): 106 | 107 | # ENVIRONMENT 108 | self.params = params 109 | self.env_name = patient_params["env_name"] 110 | self.folder_name = patient_params["folder_name"] 111 | self.bas = patient_params["u2ss"] * (patient_params["BW"] / 6000) * 3 112 | self.env = gym.make(self.env_name) 113 | self.action_size, self.state_size = 1, 3 114 | self.params["state_size"] = self.state_size 115 | self.sequence_length = 80 116 | self.data_processing = "sequence" 117 | self.device = params["device"] 118 | 119 | # HYPERPARAMETERS 120 | self.tau = 0.01 121 | self.gamma = 0.99 122 | self.ac_learning_rate = 3e-4 123 | self.ct_learning_rate = 3e-4 124 | self.ap_learning_rate = 3e-4 125 | self.batch_size = 3 126 | self.target_entropy = -1.0 127 | self.starting_timesteps = 80 * (self.batch_size + 1) # 4801 128 | self.entropy = True 129 | self.hidden_dim = 128 130 | 131 | # DISPLAY 132 | 
self.training_timesteps = params["training_timesteps"] 133 | self.training_progress_freq = int(self.training_timesteps // 10) 134 | self.max_timesteps = 480 * 10 135 | 136 | # SEEDING 137 | self.train_seed = init_seed 138 | self.env.seed(self.train_seed) 139 | np.random.seed(self.train_seed) 140 | torch.manual_seed(self.train_seed) 141 | random.seed(self.train_seed) 142 | 143 | # MEMORY 144 | self.memory_size = self.training_timesteps 145 | self.memory = deque(maxlen=self.memory_size) 146 | 147 | # NORMALISATION 148 | self.state_mean = np.array([10.0, 0.0, 0.0], dtype=np.float32) 149 | self.state_std = np.array([990.0, 35, 0.5], dtype=np.float32) 150 | self.action_mean, self.action_std = np.ones(1) * patient_params["max_dose"] * self.bas, np.ones(1) * patient_params["max_dose"] * self.bas 151 | self.params["state_mean"], self.params["state_std"] = self.state_mean, self.state_std 152 | self.params["action_mean"], self.params["action_std"] = self.action_mean, self.action_std 153 | self.unnormed_max_action = self.action_mean * 2 154 | self.action_range = 1 155 | 156 | """ 157 | Initalise the neural networks. 158 | """ 159 | def init_model(self): 160 | 161 | # policy network 162 | self.policy_net = Actor(self.state_size, self.action_size, self.hidden_dim).to(self.device) 163 | self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.ac_learning_rate) 164 | 165 | # Q network 166 | self.target_soft_q_net1 = Q(self.state_size, self.action_size, self.hidden_dim).to(self.device) 167 | self.soft_q_net1 = Q(self.state_size, self.action_size, self.hidden_dim).to(self.device) 168 | self.soft_q_net2 = Q(self.state_size, self.action_size, self.hidden_dim).to(self.device) 169 | self.target_soft_q_net2 = Q(self.state_size, self.action_size, self.hidden_dim).to(self.device) 170 | 171 | self.soft_q_criterion1 = nn.MSELoss() 172 | self.soft_q_criterion2 = nn.MSELoss() 173 | 174 | for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()): 175 | target_param.data.copy_(param.data) 176 | for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()): 177 | target_param.data.copy_(param.data) 178 | 179 | self.soft_q_optimizer1 = torch.optim.Adam(self.soft_q_net1.parameters(), lr=self.ct_learning_rate) 180 | self.soft_q_optimizer2 = torch.optim.Adam(self.soft_q_net2.parameters(), lr=self.ct_learning_rate) 181 | 182 | self.log_alpha = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=self.device) 183 | self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.ap_learning_rate) 184 | 185 | """ 186 | Save the learned models. 187 | """ 188 | def save_model(self): 189 | 190 | torch.save(self.policy_net.state_dict(), './Models/'+ str(self.env_name) + "_" + str(self.train_seed) + "_" +'SAC_RNN_online_weights_actor') 191 | torch.save(self.soft_q_net1.state_dict(), './Models/' + str(self.env_name) + "_" + str(self.train_seed) + "_" +'SAC_RNN_online_weights_q1') 192 | torch.save(self.soft_q_net2.state_dict(), './Models/'+ str(self.env_name) + "_" + str(self.train_seed) + "_" + 'SAC_RNN_online_weights_q2') 193 | 194 | """ 195 | Load pre-trained weights for testing. 
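    Expects name to be the path prefix used by save_model; the suffixes
    '_actor', '_q1' and '_q2' are appended and the loaded networks are put
    into evaluation mode.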
196 | """ 197 | def load_model(self, name): 198 | 199 | # load actor 200 | self.policy_net.load_state_dict(torch.load(name + '_actor')) 201 | self.policy_net.eval() 202 | 203 | # load q1 204 | self.soft_q_net1.load_state_dict(torch.load(name + '_q1')) 205 | self.soft_q_net1_target = copy.deepcopy(self.soft_q_net1) 206 | self.soft_q_net1.eval() 207 | 208 | # load q2 209 | self.soft_q_net2.load_state_dict(torch.load(name + '_q2')) 210 | self.soft_q_net2_target = copy.deepcopy(self.soft_q_net2) 211 | self.soft_q_net2.eval() 212 | 213 | """ 214 | Determine the action based on the state. 215 | """ 216 | def select_action(self, state, last_action, hidden_in, timestep, prev_reward, deterministic=True): 217 | 218 | with torch.no_grad(): 219 | state = torch.FloatTensor(state[:, -1].reshape(1, 1, -1)).to(self.device) 220 | last_action = torch.FloatTensor(last_action[:, -1].reshape(1, 1, -1)).to(self.device) 221 | mean, log_std, hidden_out = self.policy_net(state, last_action, hidden_in) 222 | std = log_std.exp() 223 | 224 | normal = Normal(0, 1) 225 | z = normal.sample(mean.shape).to(self.device) 226 | action = self.action_range * torch.tanh(mean + std * z) 227 | action = self.action_range * torch.tanh(mean).detach() if deterministic else action 228 | 229 | return action[0][0].detach().cpu().numpy(), hidden_out 230 | 231 | """ 232 | Get the action, log proabilities, ect. from a from state. 233 | """ 234 | def evaluate(self, state, last_action, hidden_in, epsilon=1e-6): 235 | 236 | mean, log_std, hidden_out = self.policy_net(state, last_action, hidden_in) 237 | std = log_std.exp() 238 | 239 | normal = Normal(0, 1) 240 | z = normal.sample(mean.shape) 241 | action_0 = torch.tanh(mean + std * z.to(self.device)) 242 | action = self.action_range * action_0 243 | log_prob = Normal(mean, std).log_prob(mean + std * z.to(self.device)) - torch.log(1. - action_0.pow(2) + epsilon) - np.log(self.action_range) 244 | log_prob = log_prob.sum(dim=-1, keepdim=True) 245 | return action, log_prob, z, mean, log_std, hidden_out 246 | 247 | """ 248 | Train the model on a pre-collected sample of training data. 
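    Unlike the offline agents, this SAC baseline is trained online: episodes
    are rolled out in the simulator, stored in the replay as fixed-length
    sequences together with the corresponding LSTM hidden states, and gradient
    updates begin once starting_timesteps transitions have been collected.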
249 | """ 250 | def train_model(self): 251 | 252 | # initialise the environment and set max timesteps 253 | env = gym.make(self.env_name) 254 | 255 | env.seed(self.train_seed) 256 | total_timesteps = 0 257 | 258 | # initialise the model 259 | self.init_model() 260 | 261 | while total_timesteps < self.training_timesteps: 262 | 263 | # Reset all the parameters ---------------------------------------------------------- 264 | total_rewards = 0 265 | 266 | # get the state 267 | insulin_dose = 1/3 * self.bas 268 | meal, done, bg_val = 0, False, env.reset() 269 | time = ((env.env.time.hour * 60) / 3 + env.env.time.minute / 3) / 479 270 | state = np.array([bg_val[0], meal, insulin_dose, time] , dtype=np.float32) 271 | 272 | # get a suitable input 273 | state_stack = np.tile(state, (self.sequence_length + 1, 1)) 274 | 275 | # ensure that the time is correct 276 | state_stack[:, 3] = (state_stack[:, 3] - np.arange(((self.sequence_length + 1) / 479), 0, -(1 / 479))[:self.sequence_length + 1]) * 479 277 | state_stack[:, 3] = (np.around(state_stack[:, 3], 0) % 480) / 479 278 | 279 | # get the action and reward stack 280 | action_stack = np.tile(np.array([insulin_dose], dtype=np.float32), (self.sequence_length + 1, 1)) 281 | reward_stack = np.tile(-calculate_risk(bg_val), (self.sequence_length + 1, 1)) 282 | done_stack = np.tile(np.array([False]), (self.sequence_length + 1, 1)) 283 | 284 | # get the meal history 285 | meal_history = np.zeros(int((3 * 60) / 3), dtype=np.float32) 286 | 287 | # initialise the hidden layer 288 | hidden_in = (torch.zeros([1, 1, 128], dtype=torch.float).to(self.device), 289 | torch.zeros([1, 1, 128], dtype=torch.float).to(self.device)) 290 | hidden_layers = [hidden_in] 291 | 292 | # initialise data and time tracking 293 | timesteps, counter = 0, 0 294 | 295 | while not done and timesteps < self.max_timesteps: 296 | 297 | # Get the player action ---------------------------------------------------- 298 | 299 | state = state_stack[1:, :3].reshape(1, self.sequence_length, 3) 300 | prev_action = action_stack[1:, :].reshape(1, self.sequence_length, 1) 301 | 302 | # Feed state into model 303 | state = (state - self.state_mean) / self.state_std 304 | prev_action = (prev_action - self.action_mean) / self.action_std 305 | 306 | action, hidden_out = self.select_action(state, prev_action, timestep=None, prev_reward=None, hidden_in=hidden_in, deterministic=False) 307 | output_action = np.maximum(np.minimum(action, np.ones(1)), -np.ones(1)) * self.action_std + self.action_mean 308 | 309 | # add the hidden layer 310 | hidden_layers.append(hidden_out) 311 | 312 | # Unnormalise action output and add gaussian noise 313 | action_pred = (output_action).clip(0, self.unnormed_max_action)[0] 314 | player_action = action_pred 315 | 316 | # Step the environment ---------------------------------------------------- 317 | 318 | # update the chosen action 319 | chosen_action = np.copy(player_action) 320 | 321 | # take meal bolus 322 | if meal > 0: 323 | chosen_action += calculate_bolus( 324 | bg_val, meal_history, meal, self.params['carbohydrate_ratio'], 325 | self.params['correction_factor'], self.params['target_blood_glucose'] 326 | ) 327 | 328 | # append the basal and bolus action 329 | action_stack = np.delete(action_stack, 0, 0) 330 | action_stack = np.vstack([action_stack, player_action]) 331 | 332 | # step the simulator 333 | next_bg_val, _, done, info = env.step(chosen_action) 334 | reward = -calculate_risk(next_bg_val) 335 | 336 | # get the rnn array format for state 337 | time = 
((env.env.time.hour * 60) / 3 + env.env.time.minute / 3) / 479 338 | next_state = np.array([float(next_bg_val[0]), float(info['meal']), float(chosen_action), time], dtype=np.float32) 339 | 340 | # update the state stacks 341 | next_state_stack = np.delete(state_stack, 0, 0) 342 | next_state_stack = np.vstack([next_state_stack, next_state]) 343 | reward_stack = np.delete(reward_stack, 0, 0) 344 | reward_stack = np.vstack([reward_stack, np.array([reward], dtype=np.float32)]) 345 | done_stack = np.delete(done_stack, 0, 0) 346 | done_stack = np.vstack([done_stack, np.array([done], dtype=np.float32)]) 347 | 348 | # add a termination penalty 349 | if done: reward = -1e5 350 | 351 | # update the memory --------------------------------------------------- 352 | 353 | counter += 1 354 | if counter % self.sequence_length == 0 or done or timesteps == self.max_timesteps - 1: 355 | 356 | # get the states in the correct form 357 | state_inp = next_state_stack[:-1, :3].reshape(1, self.sequence_length, 3) 358 | next_state_inp = next_state_stack[1:, :3].reshape(1, self.sequence_length, 3) 359 | reward_inp = reward_stack[1:, :].reshape(1, self.sequence_length) 360 | last_action_inp = action_stack[:-1, :].reshape(1, self.sequence_length) 361 | action_inp = action_stack[1:, :].reshape(1, self.sequence_length) 362 | done_inp = done_stack[:-1, :].reshape(1, self.sequence_length) 363 | 364 | # reset the counter and upload the data 365 | counter = 0 366 | self.memory.append(( 367 | state_inp, action_inp, reward_inp, 368 | next_state_inp, done_inp, None, None, 369 | last_action_inp, 370 | hidden_layers[-self.sequence_length], 371 | hidden_layers[-self.sequence_length + 1]) 372 | ) 373 | 374 | # update the states --------------------------------------------------- 375 | 376 | # update the meal history 377 | meal_history = np.append(meal_history, meal) 378 | meal_history = np.delete(meal_history, 0) 379 | 380 | # update the state stacks 381 | state_stack = next_state_stack 382 | 383 | # update the state 384 | bg_val = next_bg_val 385 | state = next_state 386 | meal = info['meal'] 387 | timesteps += 1 388 | total_timesteps += 1 389 | hidden_in = hidden_out 390 | total_rewards += reward 391 | 392 | # break the loop if terminated 393 | if done: break 394 | 395 | # Sample a batch of data ------------------------------------------------ 396 | 397 | if total_timesteps >= self.starting_timesteps: 398 | 399 | # unpackage the samples and split 400 | state_array, action_array, reward_array, next_state_array, done_array, _, _, last_action_array, hidden_in_array, hidden_out_array = get_batch( 401 | replay=self.memory, batch_size=self.batch_size, 402 | data_processing=self.data_processing, 403 | sequence_length=self.sequence_length, device=self.device, 404 | params=self.params 405 | ) 406 | 407 | # Training --------------------------------------------------------- 408 | 409 | reward_array = (reward_array - torch.mean(reward_array).to(self.device)) / (torch.std(reward_array) + 1e-6).to(self.device) 410 | 411 | # get q values 412 | predicted_q_value1, _ = self.soft_q_net1(state_array, action_array, last_action_array, hidden_in_array) 413 | predicted_q_value2, _ = self.soft_q_net2(state_array, action_array, last_action_array, hidden_in_array) 414 | 415 | # predict actions 416 | new_action, log_prob, z, mean, log_std, _ = self.evaluate(state_array, last_action_array, hidden_in_array) 417 | new_next_action, next_log_prob, _, _, _, _ = self.evaluate(next_state_array, action_array, hidden_out_array) 418 | 419 | 420 | if self.entropy: 
421 | alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean() 422 | self.alpha_optimizer.zero_grad() 423 | alpha_loss.backward() 424 | self.alpha_optimizer.step() 425 | self.alpha = self.log_alpha.exp() 426 | else: 427 | self.alpha = 1. 428 | alpha_loss = 0 429 | 430 | # calculate the q function loss 431 | predict_target_q1, _ = self.target_soft_q_net1(next_state_array, new_next_action, action_array, hidden_out_array) 432 | predict_target_q2, _ = self.target_soft_q_net2(next_state_array, new_next_action, action_array, hidden_out_array) 433 | target_q_min = torch.min(predict_target_q1, predict_target_q2) - self.alpha * next_log_prob 434 | target_q_value = reward_array + done_array * self.gamma * target_q_min 435 | 436 | q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach()) 437 | q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach()) 438 | 439 | # step the optimisers 440 | self.soft_q_optimizer1.zero_grad() 441 | q_value_loss1.backward() 442 | self.soft_q_optimizer1.step() 443 | self.soft_q_optimizer2.zero_grad() 444 | q_value_loss2.backward() 445 | self.soft_q_optimizer2.step() 446 | 447 | # calculate the policy loss 448 | predict_q1, _= self.soft_q_net1(state_array, new_action, last_action_array, hidden_in_array) 449 | predict_q2, _ = self.soft_q_net2(state_array, new_action, last_action_array, hidden_in_array) 450 | predicted_new_q_value = torch.min(predict_q1, predict_q2) 451 | policy_loss = (self.alpha * log_prob - predicted_new_q_value).mean() 452 | 453 | # step the policy optimiser 454 | self.policy_optimizer.zero_grad() 455 | policy_loss.backward() 456 | self.policy_optimizer.step() 457 | 458 | # update the target networks 459 | for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()): 460 | target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau) 461 | for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()): 462 | target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau) 463 | 464 | # Testing ---------------------------------------------------------- 465 | 466 | # show the progress 467 | if total_timesteps % self.training_progress_freq == 0: 468 | 469 | # show the updated loss 470 | print('Timesteps {} - Actor Loss {} - Q1 Loss {} - Q2 Loss {}'.format(total_timesteps, policy_loss, q_value_loss1, q_value_loss2)) 471 | self.save_model() 472 | 473 | print('Episode score {} - Episode Timesteps {}'.format(total_rewards, timesteps)) 474 | 475 | """ 476 | Test the learned weights against the PID controller. 
477 | """ 478 | def test_model(self, input_seed=0, input_max_timesteps=4800): 479 | 480 | # initialise the environment 481 | env = gym.make(self.env_name) 482 | 483 | # initialise the model 484 | self.init_model() 485 | self.load_model('./Models/' + self.folder_name + "/" + "Seed" + str(self.train_seed) + "/" + 'SAC_RNN_online_weights') 486 | test_seed, max_timesteps = input_seed, input_max_timesteps 487 | 488 | # test the algorithm's performance vs pid algorithm 489 | rl_reward, rl_bg, rl_action, rl_insulin, rl_meals, pid_reward, pid_bg, pid_action = test_algorithm( 490 | env=env, agent_action=self.select_action, seed=test_seed, max_timesteps=max_timesteps, 491 | sequence_length=self.sequence_length, data_processing=self.data_processing, 492 | pid_run=False, lstm=True, params=self.params 493 | ) 494 | 495 | # display the results 496 | create_graph( 497 | rl_reward=rl_reward, rl_blood_glucose=rl_bg, rl_action=rl_action, rl_insulin=rl_insulin, 498 | rl_meals=rl_meals, pid_reward=pid_reward, pid_blood_glucose=pid_bg, 499 | pid_action=pid_action, params=self.params 500 | ) 501 | -------------------------------------------------------------------------------- /utils/parameters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Apr 2 13:32:17 2022 5 | """ 6 | 7 | """ 8 | Functions for handling the environmental parameters 9 | """ 10 | 11 | from gym.envs.registration import register 12 | 13 | """ 14 | Register all the potential training environments. 15 | """ 16 | def create_env(schedule): 17 | 18 | register( 19 | id='simglucose-child1-v0', 20 | entry_point='simglucose.envs:T1DSimEnv', 21 | kwargs={'patient_name': 'child#001', 22 | "schedule":schedule 23 | } 24 | ) 25 | 26 | register( 27 | id='simglucose-child2-v0', 28 | entry_point='simglucose.envs:T1DSimEnv', 29 | kwargs={'patient_name': 'child#002', 30 | "schedule":schedule 31 | } 32 | ) 33 | 34 | register( 35 | id='simglucose-child3-v0', 36 | entry_point='simglucose.envs:T1DSimEnv', 37 | kwargs={'patient_name': 'child#003', 38 | "schedule":schedule 39 | } 40 | ) 41 | 42 | register( 43 | id='simglucose-child4-v0', 44 | entry_point='simglucose.envs:T1DSimEnv', 45 | kwargs={'patient_name': 'child#004', 46 | "schedule":schedule 47 | } 48 | ) 49 | 50 | register( 51 | id='simglucose-child5-v0', 52 | entry_point='simglucose.envs:T1DSimEnv', 53 | kwargs={'patient_name': 'child#005', 54 | "schedule":schedule 55 | } 56 | ) 57 | 58 | register( 59 | id='simglucose-child6-v0', 60 | entry_point='simglucose.envs:T1DSimEnv', 61 | kwargs={'patient_name': 'child#006', 62 | "schedule":schedule 63 | } 64 | ) 65 | 66 | register( 67 | id='simglucose-child7-v0', 68 | entry_point='simglucose.envs:T1DSimEnv', 69 | kwargs={'patient_name': 'child#007', 70 | "schedule":schedule 71 | } 72 | ) 73 | 74 | register( 75 | id='simglucose-child8-v0', 76 | entry_point='simglucose.envs:T1DSimEnv', 77 | kwargs={'patient_name': 'child#008', 78 | "schedule":schedule 79 | } 80 | ) 81 | 82 | register( 83 | id='simglucose-child9-v0', 84 | entry_point='simglucose.envs:T1DSimEnv', 85 | kwargs={'patient_name': 'child#009', 86 | "schedule":schedule 87 | } 88 | ) 89 | 90 | register( 91 | id='simglucose-child10-v0', 92 | entry_point='simglucose.envs:T1DSimEnv', 93 | kwargs={'patient_name': 'child#010', 94 | "schedule":schedule 95 | } 96 | ) 97 | 98 | # ADOLESCENTS ####################################### 99 | 100 | register( 101 | id='simglucose-adolescent1-v0', 102 | 
entry_point='simglucose.envs:T1DSimEnv', 103 | kwargs={'patient_name': 'adolescent#001', 104 | "schedule":schedule 105 | } 106 | ) 107 | 108 | register( 109 | id='simglucose-adolescent2-v0', 110 | entry_point='simglucose.envs:T1DSimEnv', 111 | kwargs={'patient_name': 'adolescent#002', 112 | "schedule":schedule 113 | } 114 | ) 115 | 116 | register( 117 | id='simglucose-adolescent3-v0', 118 | entry_point='simglucose.envs:T1DSimEnv', 119 | kwargs={'patient_name': 'adolescent#003', 120 | "schedule":schedule 121 | } 122 | ) 123 | 124 | register( 125 | id='simglucose-adolescent4-v0', 126 | entry_point='simglucose.envs:T1DSimEnv', 127 | kwargs={'patient_name': 'adolescent#004', 128 | "schedule":schedule 129 | } 130 | ) 131 | 132 | register( 133 | id='simglucose-adolescent5-v0', 134 | entry_point='simglucose.envs:T1DSimEnv', 135 | kwargs={'patient_name': 'adolescent#005', 136 | "schedule":schedule 137 | } 138 | ) 139 | 140 | register( 141 | id='simglucose-adolescent6-v0', 142 | entry_point='simglucose.envs:T1DSimEnv', 143 | kwargs={'patient_name': 'adolescent#006', 144 | "schedule":schedule 145 | } 146 | ) 147 | 148 | register( 149 | id='simglucose-adolescent7-v0', 150 | entry_point='simglucose.envs:T1DSimEnv', 151 | kwargs={'patient_name': 'adolescent#007', 152 | "schedule":schedule 153 | } 154 | ) 155 | 156 | register( 157 | id='simglucose-adolescent8-v0', 158 | entry_point='simglucose.envs:T1DSimEnv', 159 | kwargs={'patient_name': 'adolescent#008', 160 | "schedule":schedule 161 | } 162 | ) 163 | 164 | register( 165 | id='simglucose-adolescent9-v0', 166 | entry_point='simglucose.envs:T1DSimEnv', 167 | kwargs={'patient_name': 'adolescent#009', 168 | "schedule":schedule 169 | } 170 | ) 171 | 172 | register( 173 | id='simglucose-adolescent10-v0', 174 | entry_point='simglucose.envs:T1DSimEnv', 175 | kwargs={'patient_name': 'adolescent#010', 176 | "schedule":schedule 177 | } 178 | ) 179 | 180 | # ADULTS ################################################## 181 | 182 | register( 183 | id='simglucose-adult1-v0', 184 | entry_point='simglucose.envs:T1DSimEnv', 185 | kwargs={'patient_name': 'adult#001', 186 | "schedule":schedule 187 | } 188 | ) 189 | 190 | register( 191 | id='simglucose-adult2-v0', 192 | entry_point='simglucose.envs:T1DSimEnv', 193 | kwargs={'patient_name': 'adult#002', 194 | "schedule":schedule 195 | } 196 | ) 197 | 198 | register( 199 | id='simglucose-adult3-v0', 200 | entry_point='simglucose.envs:T1DSimEnv', 201 | kwargs={'patient_name': 'adult#003', 202 | "schedule":schedule 203 | } 204 | ) 205 | 206 | register( 207 | id='simglucose-adult4-v0', 208 | entry_point='simglucose.envs:T1DSimEnv', 209 | kwargs={'patient_name': 'adult#004', 210 | "schedule":schedule 211 | } 212 | ) 213 | 214 | register( 215 | id='simglucose-adult5-v0', 216 | entry_point='simglucose.envs:T1DSimEnv', 217 | kwargs={'patient_name': 'adult#005', 218 | "schedule":schedule 219 | } 220 | ) 221 | 222 | register( 223 | id='simglucose-adult6-v0', 224 | entry_point='simglucose.envs:T1DSimEnv', 225 | kwargs={'patient_name': 'adult#006', 226 | "schedule":schedule 227 | } 228 | ) 229 | 230 | register( 231 | id='simglucose-adult7-v0', 232 | entry_point='simglucose.envs:T1DSimEnv', 233 | kwargs={'patient_name': 'adult#007', 234 | "schedule":schedule 235 | } 236 | ) 237 | 238 | register( 239 | id='simglucose-adult8-v0', 240 | entry_point='simglucose.envs:T1DSimEnv', 241 | kwargs={'patient_name': 'adult#008', 242 | "schedule":schedule 243 | } 244 | ) 245 | 246 | register( 247 | id='simglucose-adult9-v0', 248 | 
entry_point='simglucose.envs:T1DSimEnv', 249 | kwargs={'patient_name': 'adult#009', 250 | "schedule":schedule 251 | } 252 | ) 253 | 254 | register( 255 | id='simglucose-adult10-v0', 256 | entry_point='simglucose.envs:T1DSimEnv', 257 | kwargs={'patient_name': 'adult#010', 258 | "schedule":schedule 259 | } 260 | ) 261 | 262 | 263 | 264 | """ 265 | Get the patient parameters. 266 | """ 267 | def get_params(): 268 | 269 | params = { 270 | 271 | # CHILDREN ############################################## 272 | 273 | "child#1": { 274 | 275 | "folder_name": "child#1", 276 | "env_name" : 'simglucose-child1-v0', 277 | "u2ss" : 1.14220356012, 278 | "BW" : 34.55648182, 279 | "carbohydrate_ratio": 28.6156949676669, 280 | "correction_factor": 103.016501883601, 281 | "kp": -1.00E-05, 282 | "ki": -1.00E-09, 283 | "kd": -1.00E-03, 284 | "max_dose": 0.6, 285 | "replay_name" : "Child#1-1e5" 286 | 287 | }, 288 | 289 | "child#1-10": { 290 | 291 | "folder_name": "child#1", 292 | "env_name" : 'simglucose-child1-v0', 293 | "u2ss" : 1.14220356012, 294 | "BW" : 34.55648182, 295 | "carbohydrate_ratio": 28.6156949676669, 296 | "correction_factor": 103.016501883601, 297 | "kp": -1.00E-05, 298 | "ki": -1.00E-08, 299 | "kd": -1.00E-03, 300 | "max_dose": 0.6, 301 | "replay_name" : "Child#1-1e5-10" 302 | 303 | }, 304 | 305 | "child#1-20": { 306 | 307 | "folder_name": "child#1", 308 | "env_name" : 'simglucose-child1-v0', 309 | "u2ss" : 1.14220356012, 310 | "BW" : 34.55648182, 311 | "carbohydrate_ratio": 28.6156949676669, 312 | "correction_factor": 103.016501883601, 313 | "kp": -1.00E-04, 314 | "ki": -1.00E-07, 315 | "kd": -1.00E-03, 316 | "max_dose": 0.6, 317 | "replay_name" : "Child#1-1e5-20" 318 | 319 | }, 320 | 321 | "child#2": { 322 | 323 | "folder_name": "child#2", 324 | "env_name" : 'simglucose-child2-v0', 325 | "u2ss" : 1.38470169593, 326 | "BW" : 28.53257352, 327 | "carbohydrate_ratio": 27.5060230377229, 328 | "correction_factor": 99.0216829358025, 329 | "kp": -1.00E-05, 330 | "ki": -1.00E-08, 331 | "kd": -1.00E-02, 332 | "max_dose": 1.5, 333 | "replay_name" : "Child#2-1e5" 334 | 335 | }, 336 | 337 | "child#3": { 338 | 339 | "folder_name": "child#3", 340 | "env_name" : 'simglucose-child3-v0', 341 | "u2ss" : 0.70038560703, 342 | "BW" : 41.23304017, 343 | "carbohydrate_ratio": 31.2073322051186, 344 | "correction_factor": 112.346395938427, 345 | "kp": -1.00E-05, 346 | "ki": -1.00E-08, 347 | "kd": -1.00E-03, 348 | "max_dose": 0.7, 349 | "replay_name" : "Child#3-1e5" 350 | 351 | }, 352 | 353 | "child#4": { 354 | 355 | "folder_name": "child#4", 356 | "env_name" : 'simglucose-child4-v0', 357 | "u2ss" : 1.38610897835, 358 | "BW" : 35.5165043, 359 | "carbohydrate_ratio": 25.23323213020198, 360 | "correction_factor": 90.83963566872713, 361 | "kp": -1e-05, 362 | "ki": -1e-11, 363 | "kd": -1e-03, 364 | "max_dose": 0.5, 365 | "replay_name" : "Child#4-1e5" 366 | 367 | }, 368 | 369 | "child#5": { 370 | 371 | "folder_name": "child#5", 372 | "env_name" : 'simglucose-child5-v0', 373 | "u2ss" : 1.36318862639, 374 | "BW" : 37.78855797, 375 | "carbohydrate_ratio": 12.21462592173681, 376 | "correction_factor": 43.97265331825251, 377 | "kp": -1e-04, 378 | "ki": -1e-07, 379 | "kd": -1e-02, 380 | "max_dose": 1.5, 381 | "replay_name" : "Child#5-1e5" 382 | 383 | }, 384 | 385 | "child#6": { 386 | 387 | "folder_name": "child#6", 388 | "env_name" : 'simglucose-child6-v0', 389 | "u2ss" : 0.985487128297, 390 | "BW" : 41.00214896, 391 | "carbohydrate_ratio": 24.723079998277314, 392 | "correction_factor": 89.00308799379833, 393 | "kp": -1e-05, 394 | 
"ki": -1e-08, 395 | "kd": -1e-03, 396 | "max_dose": 0.5, 397 | "replay_name" : "Child#6-1e5" 398 | 399 | }, 400 | 401 | "child#7": { 402 | 403 | "folder_name": "child#7", 404 | "env_name" : 'simglucose-child7-v0', 405 | "u2ss" : 1.02592147609, 406 | "BW" : 45.5397665, 407 | "carbohydrate_ratio": 13.807252026084589, 408 | "correction_factor": 49.706107293904516, 409 | "kp": -1e-07, 410 | "ki": -1e-08, 411 | "kd": -1e-03, 412 | "max_dose": 0.5, 413 | "replay_name" : "Child#7-1e5" 414 | 415 | }, 416 | 417 | "child#8": { 418 | 419 | "folder_name": "child#8", 420 | "env_name" : 'simglucose-child8-v0', 421 | "u2ss" : 1.43273282863, 422 | "BW" : 23.73405728, 423 | "carbohydrate_ratio": 23.261842061321445, 424 | "correction_factor": 83.7426314207572, 425 | "kp": -1e-07, 426 | "ki": -1e-11, 427 | "kd": -1e-02, 428 | "max_dose": 5.0, 429 | "replay_name" : "Child#8-1e5" 430 | 431 | }, 432 | 433 | "child#9": { 434 | 435 | "folder_name": "child#9", 436 | "env_name" : 'simglucose-child9-v0', 437 | "u2ss" : 1.10155422738, 438 | "BW" : 35.53392558, 439 | "carbohydrate_ratio": 28.74519570209282, 440 | "correction_factor": 103.48270452753414, 441 | "kp": -1e-07, 442 | "ki": -1e-07, 443 | "kd": -1e-07, 444 | "max_dose": 2.5, 445 | "replay_name" : "Child#9-1e5" 446 | 447 | }, 448 | 449 | "child#10": { 450 | 451 | "folder_name": "child#10", 452 | "env_name" : 'simglucose-child10-v0', 453 | "u2ss" : 1.12891185261, 454 | "BW" : 35.21305847, 455 | "carbohydrate_ratio": 24.21108601288932, 456 | "correction_factor": 87.15990964640156, 457 | "kp": -1e-05, 458 | "ki": -1e-08, 459 | "kd": -1e-03, 460 | "max_dose": 0.5, 461 | "replay_name" : "Child#10-1e5" 462 | 463 | }, 464 | 465 | # ADOLESCENTS ############################################# 466 | 467 | "adolescent#1": { 468 | 469 | "folder_name": "adolescent#1", 470 | "env_name" : 'simglucose-adolescent1-v0', 471 | "u2ss" : 1.21697571391, 472 | "BW" : 68.706, 473 | "carbohydrate_ratio": 13.6113998281669, 474 | "correction_factor": 49.0010393814008, 475 | "kp": -1.00E-04, 476 | "ki": -1.00E-07, 477 | "kd": -1.00E-02, 478 | "max_dose": 1.5, 479 | "replay_name" : "Adolescent#1-1e5" 480 | 481 | }, 482 | 483 | "adolescent#1-10": { 484 | 485 | "folder_name": "adolescent#1", 486 | "env_name" : 'simglucose-adolescent1-v0', 487 | "u2ss" : 1.21697571391, 488 | "BW" : 68.706, 489 | "carbohydrate_ratio": 13.6113998281669, 490 | "correction_factor": 49.0010393814008, 491 | "kp": -1.00E-06, 492 | "ki": -1.00E-08, 493 | "kd": -1.00E-02, 494 | "max_dose": 1.5, 495 | "replay_name" : "Adolescent#1-1e5-10" 496 | 497 | }, 498 | 499 | "adolescent#1-20": { 500 | 501 | "folder_name": "adolescent#1", 502 | "env_name" : 'simglucose-adolescent1-v0', 503 | "u2ss" : 1.21697571391, 504 | "BW" : 68.706, 505 | "carbohydrate_ratio": 13.6113998281669, 506 | "correction_factor": 49.0010393814008, 507 | "kp": -1.00E-07, 508 | "ki": -1.00E-11, 509 | "kd": -1.00E-02, 510 | "max_dose": 1.5, 511 | "replay_name" : "Adolescent#1-1e5-20" 512 | 513 | }, 514 | 515 | "adolescent#2": { 516 | 517 | "folder_name": "adolescent#2", 518 | "env_name" : 'simglucose-adolescent2-v0', 519 | "u2ss" : 1.79829979626, 520 | "BW" : 51.046, 521 | "carbohydrate_ratio": 8.06048033285474, 522 | "correction_factor": 29.0177291982771, 523 | "kp": -1.00E-04, 524 | "ki": -1.00E-07, 525 | "kd": -1.00E-02, 526 | "max_dose": 1.5, 527 | "replay_name" : "Adolescent#2-1e5" 528 | 529 | }, 530 | 531 | "adolescent#3": { 532 | 533 | "folder_name": "adolescent#3", 534 | "env_name" : 'simglucose-adolescent3-v0', 535 | "u2ss" : 1.4462660088, 536 | 
"BW" : 44.791, 537 | "carbohydrate_ratio": 20.6246970212749, 538 | "correction_factor": 74.2489092765897, 539 | "kp": -1.00E-04, 540 | "ki": -1.00E-07, 541 | "kd": -1.00E-02, 542 | "max_dose": 1.2, 543 | "replay_name" : "Adolescent#3-1e5" 544 | 545 | }, 546 | 547 | "adolescent#4": { 548 | 549 | "folder_name": "adolescent#4", 550 | "env_name" : 'simglucose-adolescent4-v0', 551 | "u2ss" : 1.76263284642, 552 | "BW" : 49.564, 553 | "carbohydrate_ratio": 14.18324377702899, 554 | "correction_factor": 51.05967759730436, 555 | "kp": -1e-04, 556 | "ki": -1e-07, 557 | "kd": -1e-02, 558 | "max_dose": 1.5, 559 | "replay_name" : "Adolescent#4-1e5" 560 | 561 | }, 562 | 563 | "adolescent#5": { 564 | 565 | "folder_name": "adolescent#5", 566 | "env_name" : 'simglucose-adolescent5-v0', 567 | "u2ss" : 1.5346452819, 568 | "BW" : 47.074, 569 | "carbohydrate_ratio": 14.703840790944376, 570 | "correction_factor": 52.93382684739976, 571 | "kp": -1e-04, 572 | "ki": -1e-07, 573 | "kd": -1e-02, 574 | "max_dose": 1.5, 575 | "replay_name" : "Adolescent#5-1e5" 576 | 577 | }, 578 | "adolescent#6": { 579 | 580 | "folder_name": "adolescent#6", 581 | "env_name" : 'simglucose-adolescent6-v0', 582 | "u2ss" : 1.92787834743, 583 | "BW" : 45.408, 584 | "carbohydrate_ratio": 10.084448671441356, 585 | "correction_factor": 36.30401521718888, 586 | "kp": -1e-04, 587 | "ki": -1e-07, 588 | "kd": -1e-02, 589 | "max_dose": 1.5, 590 | "replay_name" : "Adolescent#6-1e5" 591 | 592 | }, 593 | 594 | "adolescent#7": { 595 | 596 | "folder_name": "adolescent#7", 597 | "env_name" : 'simglucose-adolescent7-v0', 598 | "u2ss" : 2.04914771228, 599 | "BW" : 37.898, 600 | "carbohydrate_ratio": 11.457886857675446, 601 | "correction_factor": 41.24839268763161, 602 | "kp": -1e-04, 603 | "ki": -1e-07, 604 | "kd": -1e-02, 605 | "max_dose": 1.75, 606 | "replay_name" : "Adolescent#7-1e5" 607 | 608 | }, 609 | "adolescent#8": { 610 | 611 | "folder_name": "adolescent#8", 612 | "env_name" : 'simglucose-adolescent8-v0', 613 | "u2ss" : 1.35324144985, 614 | "BW" : 41.218, 615 | "carbohydrate_ratio": 7.888090404486432, 616 | "correction_factor": 28.397125456151155, 617 | "kp": -1e-04, 618 | "ki": -1e-07, 619 | "kd": -1e-02, 620 | "max_dose": 2.5, 621 | "replay_name" : "Adolescent#8-1e5" 622 | 623 | }, 624 | 625 | "adolescent#9": { 626 | 627 | "folder_name": "adolescent#9", 628 | "env_name" : 'simglucose-adolescent9-v0', 629 | "u2ss" : 1.38186522046, 630 | "BW" : 43.885, 631 | "carbohydrate_ratio": 20.76570050945875, 632 | "correction_factor": 74.7565218340515, 633 | "kp": -1e-07, 634 | "ki": -1e-07, 635 | "kd": -1e-02, 636 | "max_dose": 1.5, 637 | "replay_name" : "Adolescent#9-1e5" 638 | 639 | }, 640 | 641 | "adolescent#10": { 642 | 643 | "folder_name": "adolescent#10", 644 | "env_name" : 'simglucose-adolescent10-v0', 645 | "u2ss" : 1.66109036262, 646 | "BW" : 47.378, 647 | "carbohydrate_ratio": 15.07226804643741, 648 | "correction_factor": 54.260164967174674, 649 | "kp": -1e-04, 650 | "ki": -1e-07, 651 | "kd": -1e-02, 652 | "max_dose": 1.5, 653 | "replay_name" : "Adolescent#10-1e5" 654 | }, 655 | 656 | # ADULTS ############################################### 657 | 658 | "adult#1": { 659 | 660 | "folder_name": "adult#1", 661 | "env_name" : 'simglucose-adult1-v0', 662 | "u2ss" : 1.2386244136, 663 | "BW" : 102.32, 664 | "carbohydrate_ratio": 9.9173582569505, 665 | "correction_factor": 35.7024897250218, 666 | "kp": -1.00E-04, 667 | "ki": -1.00E-07, 668 | "kd": -1.00E-02, 669 | "max_dose": 0.75, 670 | "replay_name" : "Adult#1-1e5" 671 | 672 | }, 673 | 674 | 
"adult#1-10": { 675 | 676 | "folder_name": "adult#1", 677 | "env_name" : 'simglucose-adult1-v0', 678 | "u2ss" : 1.2386244136, 679 | "BW" : 102.32, 680 | "carbohydrate_ratio": 9.9173582569505, 681 | "correction_factor": 35.7024897250218, 682 | "kp": -1.00E-06, 683 | "ki": -1.00E-08, 684 | "kd": -1.00E-02, 685 | "max_dose": 0.75, 686 | "replay_name" : "Adult#1-1e5-10" 687 | 688 | }, 689 | 690 | "adult#1-20": { 691 | 692 | "folder_name": "adult#1", 693 | "env_name" : 'simglucose-adult1-v0', 694 | "u2ss" : 1.2386244136, 695 | "BW" : 102.32, 696 | "carbohydrate_ratio": 9.9173582569505, 697 | "correction_factor": 35.7024897250218, 698 | "kp": -1.00E-07, 699 | "ki": -1.00E-11, 700 | "kd": -1.00E-02, 701 | "max_dose": 0.75, 702 | "replay_name" : "Adult#1-1e5-20" 703 | 704 | }, 705 | 706 | "adult#2": { 707 | 708 | "folder_name": "adult#2", 709 | "env_name" : 'simglucose-adult2-v0', 710 | "u2ss" : 1.23270240324, 711 | "BW" : 111.1, 712 | "carbohydrate_ratio": 8.64023791338857, 713 | "correction_factor": 31.1048564881989, 714 | "kp": -1.00E-04, 715 | "ki": -1.00E-07, 716 | "kd": -1.00E-02, 717 | "max_dose": 0.7, 718 | "replay_name" : "Adult#2-1e5" 719 | 720 | }, 721 | 722 | "adult#3": { 723 | 724 | "folder_name": "adult#3", 725 | "env_name" : 'simglucose-adult3-v0', 726 | "u2ss" : 1.74604298612, 727 | "BW" : 81.631, 728 | "carbohydrate_ratio": 8.86057935797141, 729 | "correction_factor": 31.8980856886971, 730 | "kp": -1.00E-04, 731 | "ki": -1.00E-07, 732 | "kd": -1.00E-02, 733 | "max_dose": 0.9, 734 | "replay_name" : "Adult#3-1e5" 735 | 736 | }, 737 | 738 | "adult#4": { 739 | 740 | "folder_name": "adult#4", 741 | "env_name" : 'simglucose-adult4-v0', 742 | "u2ss" : 1.40925544793, 743 | "BW" : 63.0, 744 | "carbohydrate_ratio": 14.789424168083986, 745 | "correction_factor": 53.24192700510235, 746 | "kp": -1e-05, 747 | "ki": -1e-07, 748 | "kd": -1e-02, 749 | "max_dose": 1.5, 750 | "replay_name" : "Adult#4-1e5" 751 | 752 | }, 753 | 754 | "adult#5": { 755 | 756 | "folder_name": "adult#5", 757 | "env_name" : 'simglucose-adult5-v0', 758 | "u2ss" : 1.25415109169, 759 | "BW" : 94.074, 760 | "carbohydrate_ratio": 7.318937998432252, 761 | "correction_factor": 26.348176794356107, 762 | "kp": -1e-03, 763 | "ki": -1e-07, 764 | "kd": -1e-02, 765 | "max_dose": 1.0, 766 | "replay_name" : "Adult#5-1e5" 767 | 768 | }, 769 | 770 | "adult#6": { 771 | 772 | "folder_name": "adult#6", 773 | "env_name" : 'simglucose-adult6-v0', 774 | "u2ss" : 2.60909529933, 775 | "BW" : 66.097, 776 | "carbohydrate_ratio": 8.144806942246657, 777 | "correction_factor": 29.321304992087967, 778 | "kp": -1e-04, 779 | "ki": -1e-07, 780 | "kd": -1e-02, 781 | "max_dose": 1.5, 782 | "replay_name" : "Adult#6-1e5" 783 | 784 | }, 785 | 786 | "adult#7": { 787 | 788 | "folder_name": "adult#7", 789 | "env_name" : 'simglucose-adult7-v0', 790 | "u2ss" : 1.50334589878, 791 | "BW" : 91.229, 792 | "carbohydrate_ratio": 11.902889350456292, 793 | "correction_factor": 42.85040166164265, 794 | "kp": -1e-07, 795 | "ki": -1e-07, 796 | "kd": -1e-02, 797 | "max_dose": 1.0, 798 | "replay_name" : "Adult#7-1e5" 799 | 800 | }, 801 | 802 | "adult#8": { 803 | 804 | "folder_name": "adult#8", 805 | "env_name" : 'simglucose-adult8-v0', 806 | "u2ss" : 1.11044245549, 807 | "BW" : 102.79, 808 | "carbohydrate_ratio": 11.68803605523481, 809 | "correction_factor": 42.07692979884532, 810 | "kp": -1e-04, 811 | "ki": -1e-07, 812 | "kd": -1e-02, 813 | "max_dose": 1.0, 814 | "replay_name" : "Adult#8-1e5" 815 | 816 | }, 817 | 818 | "adult#9": { 819 | 820 | "folder_name": "adult#9", 821 | 
"env_name" : 'simglucose-adult9-v0', 822 | "u2ss" : 1.51977345451, 823 | "BW" : 74.604, 824 | "carbohydrate_ratio": 7.439205003922471, 825 | "correction_factor": 26.781138014120895, 826 | "kp": -1e-04, 827 | "ki": -1e-07, 828 | "kd": -1e-02, 829 | "max_dose": 1.5, 830 | "replay_name" : "Adult#9-1e5" 831 | 832 | }, 833 | 834 | "adult#10": { 835 | 836 | "folder_name": "adult#10", 837 | "env_name" : 'simglucose-adult10-v0', 838 | "u2ss" : 1.37923535927, 839 | "BW" : 73.859, 840 | "carbohydrate_ratio": 7.758126846037283, 841 | "correction_factor": 27.92925664573422, 842 | "kp": -1e-04, 843 | "ki": -1e-07, 844 | "kd": -1e-02, 845 | "max_dose": 1.5, 846 | "replay_name" : "Adult#10-1e5" 847 | 848 | }, 849 | 850 | 851 | 852 | } 853 | 854 | return params 855 | --------------------------------------------------------------------------------