├── RED ├── __init__.py ├── configs │ ├── model │ │ ├── KerasFittedQAgent.yaml │ │ ├── DRPG_agent.yaml │ │ └── RT3D_agent.yaml │ ├── example │ │ ├── Figure_2_FQ_chemostat.yaml │ │ ├── RT3D_gene_transcription.yaml │ │ ├── Figure_S2_FQ_gene_transcription.yaml │ │ ├── Figure_3_RT3D_chemostat.yaml │ │ └── Figure_4_RT3D_chemostat.yaml │ ├── train.yaml │ └── environment │ │ ├── gene_transcription.yaml │ │ └── chemostat.yaml ├── utils │ ├── data.py │ ├── visualization.py │ └── network.py ├── environments │ ├── gene_transcription │ │ └── xdot_gene_transcription.py │ ├── chemostat │ │ └── xdot_chemostat.py │ └── OED_env.py ├── README.md ├── run_RED.py └── agents │ ├── continuous_agents │ ├── drpg.py │ └── rt3d.py │ └── fitted_Q_agents.py ├── LICENSE ├── README.md └── examples ├── Figure_2_FQ_chemostat ├── OSAO_param_inf.py ├── MPC_param_inf.py └── train_FQ.py ├── RT3D_gene_transcription ├── MPC_param_inf.py └── train_RT3D.py ├── Figure_S2_FQ_gene_transcription └── train_FQ.py ├── Figure_3_RT3D_chemostat └── train_RT3D.py └── Figure_4_RT3D_chemostat └── train_RT3D.py /RED/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RED/configs/model/KerasFittedQAgent.yaml: -------------------------------------------------------------------------------- 1 | _target_: RED.agents.fitted_Q_agents.KerasFittedQAgent 2 | layer_sizes: [2, 20, 20, 4] -------------------------------------------------------------------------------- /RED/configs/model/DRPG_agent.yaml: -------------------------------------------------------------------------------- 1 | _target_: RED.agents.continuous_agents.DRPG_agent 2 | layer_sizes: ??? 3 | learning_rate: 0.001 4 | critic: True -------------------------------------------------------------------------------- /RED/configs/example/Figure_2_FQ_chemostat.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: chemostat 3 | - /model: KerasFittedQAgent 4 | - _self_ 5 | 6 | hidden_layer_sizes: [150, 150, 150] 7 | use_old_state: false # true as default in original rl-oed 8 | old_state_normaliser: [1e3, 1e1, 1e-3, 1e-4, 1e3, 1e3, 1e3, 1e3, 1e3, 1e3, 1e2] 9 | init_explore_rate: 1 10 | init_alpha: 1 11 | save_path: results/ -------------------------------------------------------------------------------- /RED/configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: chemostat 3 | - /model: RT3D_agent 4 | - _self_ 5 | 6 | model: 7 | val_learning_rate: 0.0001 8 | pol_learning_rate: 0.00005 9 | policy_act: sigmoid 10 | noise_bounds: [-0.25, 0.25] 11 | action_bounds: [0, 1] 12 | 13 | hidden_layer_size: [[64, 64], [128, 128]] 14 | policy_delay: 2 15 | max_std: 1 16 | explore_rate: "${max_std}" 17 | save_path: results/ -------------------------------------------------------------------------------- /RED/configs/environment/gene_transcription.yaml: -------------------------------------------------------------------------------- 1 | xdot_path: RED/environments/gene_transcription/xdot_gene_transcription.py 2 | n_episodes: 20000 3 | skip: 10 4 | y0: [0.000001, 0.000001] 5 | actual_params: [20, 5e5, 1.09e9, 2.57e-4, 4.0] 6 | input_bounds: [[-3, 3]] 7 | n_controlled_inputs: 1 8 | num_inputs: 12 9 | dt: 0.01 10 | lb: [1.0, 2e3, 4.2e5, 7.7e-5, 1.0] 11 | ub: [30.0, 1e6, 5.93e10, 7.7e-4, 10.0] 12 | N_control_intervals: 6 13 | control_interval_time: 
100 14 | n_observed_variables: 2 15 | prior: false 16 | normaliser: [1e3, 1e5, 1e1] -------------------------------------------------------------------------------- /RED/configs/environment/chemostat.yaml: -------------------------------------------------------------------------------- 1 | xdot_path: RED/environments/chemostat/xdot_chemostat.py 2 | n_episodes: 17500 3 | n_parallel_experiments: 10 4 | skip_first_n_experiments: 1000 5 | y0: [200000, 0, 1] 6 | actual_params: [1, 0.00048776, 0.00006845928] 7 | input_bounds: [[0.01, 1],[0.01, 1]] 8 | n_controlled_inputs: 2 9 | num_inputs: 10 10 | dt: 0.00025 11 | lb: [0.5, 0.0001, 0.00001] 12 | ub: [2, 0.001, 0.0001] 13 | N_control_intervals: 10 14 | control_interval_time: 2 15 | n_observed_variables: 1 16 | prior: false 17 | normaliser: [1e3, 1e1] -------------------------------------------------------------------------------- /RED/utils/data.py: -------------------------------------------------------------------------------- 1 | class RT3DDataset(Dataset): 2 | def __init__(self, data): 3 | self.data = data 4 | 5 | def __len__(self): 6 | return self.data.size() 7 | 8 | def __getitem__(self, idx): 9 | return self.data[idx] 10 | 11 | class DRPGDataset(Dataset): 12 | def __init__(self, data, labels): 13 | self.data = data 14 | self.labels = labels 15 | 16 | def __len__(self): 17 | return self.data.size() 18 | 19 | def __getitem__(self, idx): 20 | return self.data[idx], self.labels[idx] 21 | -------------------------------------------------------------------------------- /RED/configs/example/RT3D_gene_transcription.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: gene_transcription 3 | - /model: RT3D_agent 4 | - _self_ 5 | 6 | model: 7 | val_learning_rate: 0.0001 8 | pol_learning_rate: 0.00005 9 | policy_act: sigmoid 10 | noise_bounds: [-0.25, 0.25] 11 | action_bounds: [0, 1] 12 | mem_size: 500_000_000 13 | std: 0.1 14 | 15 | environment: 16 | prior: false 17 | 18 | hidden_layer_size: [[64, 64], [128, 128]] 19 | policy_delay: 2 20 | max_std: 0 # for exploring 21 | explore_rate: "${.max_std}" 22 | recurrent: True 23 | test_episode: False 24 | load_agent_network: False 25 | agent_network_path: "/Users/neythen/Desktop/Projects/RL_OED/results/rt3d_gene_transcription_230822/repeat6" 26 | save_path: results/ -------------------------------------------------------------------------------- /RED/configs/example/Figure_S2_FQ_gene_transcription.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: gene_transcription 3 | - /model: KerasFittedQAgent 4 | - _self_ 5 | 6 | model: 7 | layer_sizes: [23, 150, 150, 150, 12] 8 | 9 | environment: 10 | n_episodes: 20_000 11 | actual_params: [20, 5e5, 1.09e9, 2.57e-4, 4.] 
12 | input_bounds: [[-3., 3.]] 13 | n_system_variables: 2 14 | num_inputs: 12 15 | y0: [0.000001, 0.000001] 16 | dt: 0.01 17 | normaliser: [1e3, 1e4, 1e2, 1e6, 1e10, 1e-3, 1e1, 1e9, 1e9, 1e9, 1e9, 1, 1e9, 1e9, 1e9, 1, 1e9, 1e9, 1, 1e9, 1, 1e7,10] 18 | N_control_intervals: 6 19 | control_interval_time: 100 20 | n_observed_variables: 2 21 | n_controlled_inputs: 1 22 | 23 | init_explore_rate: 1 24 | save_path: results/ -------------------------------------------------------------------------------- /RED/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def plot_returns(returns, explore_rates=None, show=True, save_to_dir=None, conv_window=25): 8 | fig, ax1 = plt.subplots(figsize=(12, 8)) 9 | ax1.set_title("Return over time") 10 | 11 | if explore_rates is not None: 12 | ax2 = ax1.twinx() 13 | ax2.plot(np.repeat(explore_rates, len(returns) // len(explore_rates)), color="black", alpha=0.5, label="Explore Rate") 14 | ax2.set_ylabel("Explore Rate") 15 | ax2.legend(loc=1) 16 | 17 | ax1.plot(np.convolve(returns, np.ones(conv_window)/conv_window, mode="valid"), label="Return") 18 | ax1.set_xlabel("Episode") 19 | ax1.set_ylabel("Return") 20 | ax1.legend(loc=2) 21 | 22 | if save_to_dir is not None: 23 | os.makedirs(save_to_dir, exist_ok=True) 24 | plt.savefig(os.path.join(save_to_dir, "returns.png")) 25 | 26 | if show: 27 | plt.show() 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Neythen Treloar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RED 2 | [Deep Reinforcement Learning for Optimal Experimental Design in Biology](https://www.biorxiv.org/content/10.1101/2022.05.09.491138.abstract) 3 | 4 | ### Installation 5 | 6 | RED does not need to be installed to run the examples 7 | 8 | To use the package within python scripts, `RED` must be in PYTHONPATH. 
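One way to make `RED` importable without touching the shell configuration is to extend `sys.path` at runtime, which is what the bundled example scripts do. The sketch below assumes the script lives in a subdirectory of `examples/`, as the examples themselves do:

```python
import os
import sys

# add the repository root to the import path (three levels up from this file,
# matching the layout of the scripts in examples/<example_name>/)
IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(IMPORT_PATH)

from RED.environments.OED_env import OED_env  # RED is now importable
```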
9 | 10 | To add to PYTHONPATH on a bash system add the following to the ~/.bashrc file 11 | 12 | ```console 13 | export PYTHONPATH="${PYTHONPATH}:" 14 | ``` 15 | 16 | ### Dependencies 17 | Standard python dependencies are required: `numpy`, `scipy`, `matplotlib`. `TensorFlow` and `hydra-core` are required). Instructions for installing 'TensorFlow' can be found here: 18 | https://www.tensorflow.org/install/ 19 | 20 | ### User Instructions 21 | Code files can be imported into scripts, ensure the RED directory is in PYTHONPATH and simply import the required RED classes. See examples. 22 | 23 | To run examples found in RED_master/examples from the command line, e.g.: 24 | 25 | ```console 26 | $ python train_RT3D_prior.py 27 | ``` 28 | 29 | The examples will automatically save some results in the directory: 30 | 31 | 32 | The main classes are the continuous_agents and OED_env, see examples for how to use these: 33 | 34 | ### continuous_agents 35 | The continuous_agents.py file can be imported and used on any RL task. 36 | ```console 37 | from RED.agents.continuous_agents import RT3D_agent 38 | ``` 39 | 40 | ### OED_env 41 | Contains the environments used for RL for OED. Can be imported and initialised with any system goverened by a set of DEs 42 | 43 | ```console 44 | from RED.environments.OED_env import OED_env 45 | ``` 46 | -------------------------------------------------------------------------------- /RED/environments/gene_transcription/xdot_gene_transcription.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from casadi import * 3 | 4 | def xdot(sym_y, sym_theta, sym_u): 5 | a, Kt, Krt, d, b = [sym_theta[i] for i in range(sym_theta.size()[0])] #intrinsic parameters 6 | #a = 20min^-1 7 | Kr = 40 # practically unidentifiable 8 | Km = 750 9 | #Kt = 5e5 10 | #Krt = 1.09e9 11 | #d = 2.57e-4 um^-3min^-1 12 | #b = 4 min-1 13 | #Km = 750 um^-3 14 | 15 | u = sym_u[0] # for now just choose u 16 | lam = 0.03465735902799726 #min^-1 GROWTH RATE 17 | #lam = 0.006931471805599453 18 | #lam = sym_u[1] 19 | 20 | C = 40 21 | D = 20 22 | V0 = 0.28 23 | V = V0*np.exp((C+D)*lam) #eq 2 24 | G =1/(lam*C) *(np.exp((C+D)*lam) - np.exp(D*lam)) #eq 3 25 | 26 | l_ori = 0.26 # chose this so that g matched values for table 2 for both growth rates as couldnt find it defined in paper 27 | 28 | g = np.exp( (C+D-l_ori*C)*lam)#eq 4 29 | 30 | rho = 0.55 31 | k_pr = -6.47 32 | TH_pr0 = 0.65 33 | 34 | k_p = 0.3 35 | TH_p0 = 0.0074 36 | m_rnap = 6.3e-7 37 | 38 | k_a = -9.3 39 | TH_a0 = 0.59 40 | 41 | Pa = rho*V0/m_rnap *(k_a*lam + TH_a0) * (k_p*lam + TH_p0) * (k_pr*lam + TH_pr0) *np.exp((C+D)*lam) #eq 10 42 | 43 | k_r = 5.48 44 | TH_r0 = 0.03 45 | m_rib = 1.57e-6 46 | Rtot = (k_r*lam + TH_r0) * (k_pr*lam + TH_pr0)*(rho*V0*np.exp((C+D)*lam))/m_rib 47 | 48 | TH_f = 0.1 49 | Rf = TH_f*Rtot #eq 17 50 | n = 5e6 51 | eta = 900 # um^-3min^-1 52 | 53 | 54 | rna, prot = sym_y[0], sym_y[1] 55 | 56 | rna_dot = a*(g/V)*( (Pa/(n*G)*Kr + (Pa*Krt*u)/(n*G)**2) / (1 + (Pa/n*G)*Kr + (Kt/(n*G) + Pa*Krt/(n*G)**2) *u )) - d*eta*rna/V 57 | 58 | prot_dot = ((b*Rf/V) / (Km + Rf/V)) * rna/V - lam*prot/V 59 | 60 | xdot = SX.sym('xdot', 2) 61 | 62 | xdot[0] = rna_dot 63 | xdot[1] = prot_dot 64 | 65 | return xdot -------------------------------------------------------------------------------- /RED/configs/model/RT3D_agent.yaml: -------------------------------------------------------------------------------- 1 | # note: using ${eval:'...'} requires OmegaConf.register_new_resolver("eval", eval) 
in code 2 | _target_: RED.agents.continuous_agents.rt3d.RT3D_agent 3 | val_learning_rate: 0.001 4 | pol_learning_rate: 0.001 5 | batch_size: 256 6 | action_bounds: [0, 1] 7 | noise_bounds: [-0.25, 0.25] 8 | noise_std: 0.1 9 | gamma: 1 10 | polyak: 0.995 11 | max_length: 11 12 | mem_size: 500000000 13 | pol_module_specs: 14 | # - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 15 | - input_size: 4 16 | layers: 17 | - layer_type: GRU 18 | hidden_size: 64 19 | num_layers: 1 20 | - layer_type: GRU 21 | hidden_size: 64 22 | num_layers: 1 23 | # - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${..[0].layers[1].hidden_size}'} 24 | - input_size: ${..[0].layers[1].hidden_size} 25 | layers: 26 | - layer_type: Linear 27 | output_size: 128 28 | activation: 29 | _target_: torch.nn.ReLU 30 | - layer_type: Linear 31 | output_size: 128 32 | activation: 33 | _target_: torch.nn.ReLU 34 | - layer_type: Linear 35 | # output_size: ${example.environment.n_controlled_inputs} 36 | output_size: 2 # action space 37 | activation: 38 | _target_: torch.nn.Identity 39 | # - layer_type: Lambda 40 | # lambda_expression: "lambda x: x * 1" # scaling policy outputs 41 | val_module_specs: 42 | # - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 43 | - input_size: 4 44 | layers: 45 | - layer_type: GRU 46 | hidden_size: 64 47 | num_layers: 1 48 | - layer_type: GRU 49 | hidden_size: 64 50 | num_layers: 1 51 | # - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs} + ${..[0].layers[1].hidden_size}'} 52 | - input_size: 66 # 64 + 2 (hidden size + action space) 53 | layers: 54 | - layer_type: Linear 55 | output_size: 128 56 | activation: 57 | _target_: torch.nn.ReLU 58 | - layer_type: Linear 59 | output_size: 128 60 | activation: 61 | _target_: torch.nn.ReLU 62 | - layer_type: Linear 63 | output_size: 1 # predicted state-action value 64 | activation: 65 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /RED/configs/example/Figure_3_RT3D_chemostat.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: chemostat 3 | - /model: RT3D_agent 4 | - _self_ 5 | 6 | policy_delay: 2 7 | initial_explore_rate: 1 8 | explore_rate_mul: 1 9 | test_episode: False 10 | save_path: ${hydra:run.dir} 11 | ckpt_freq: 50 12 | load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training 13 | 14 | model: 15 | batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'} 16 | val_learning_rate: 0.0001 17 | pol_learning_rate: 0.00005 18 | action_bounds: [0, 1] 19 | noise_bounds: [-0.25, 0.25] 20 | noise_std: 0.1 21 | gamma: 1 22 | polyak: 0.995 23 | max_length: 11 24 | mem_size: 500_000_000 25 | pol_module_specs: 26 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 27 | layers: 28 | - layer_type: GRU 29 | hidden_size: 64 30 | num_layers: 1 31 | - layer_type: GRU 32 | hidden_size: 64 33 | num_layers: 1 34 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${..[0].layers[1].hidden_size}'} 35 | layers: 36 | - layer_type: Linear 37 | output_size: 128 38 | activation: 39 | _target_: torch.nn.ReLU 40 | 
- layer_type: Linear 41 | output_size: 128 42 | activation: 43 | _target_: torch.nn.ReLU 44 | - layer_type: Linear 45 | output_size: ${example.environment.n_controlled_inputs} 46 | activation: 47 | _target_: torch.nn.Sigmoid 48 | val_module_specs: 49 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 50 | layers: 51 | - layer_type: GRU 52 | hidden_size: 64 53 | num_layers: 1 54 | - layer_type: GRU 55 | hidden_size: 64 56 | num_layers: 1 57 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs} + ${..[0].layers[1].hidden_size}'} 58 | layers: 59 | - layer_type: Linear 60 | output_size: 128 61 | activation: 62 | _target_: torch.nn.ReLU 63 | - layer_type: Linear 64 | output_size: 128 65 | activation: 66 | _target_: torch.nn.ReLU 67 | - layer_type: Linear 68 | output_size: 1 69 | activation: 70 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /RED/configs/example/Figure_4_RT3D_chemostat.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /environment: chemostat 3 | - /model: RT3D_agent 4 | - _self_ 5 | 6 | policy_delay: 2 7 | initial_explore_rate: 1 8 | explore_rate_mul: 1 9 | test_episode: False 10 | save_path: ${hydra:run.dir} 11 | ckpt_freq: 50 12 | load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training 13 | 14 | model: 15 | batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'} 16 | val_learning_rate: 0.0001 17 | pol_learning_rate: 0.00005 18 | action_bounds: [0, 1] 19 | noise_bounds: [-0.25, 0.25] 20 | noise_std: 0.1 21 | gamma: 1 22 | polyak: 0.995 23 | max_length: 11 24 | mem_size: 500_000_000 25 | pol_module_specs: 26 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 27 | layers: 28 | - layer_type: GRU 29 | hidden_size: 64 30 | num_layers: 1 31 | - layer_type: GRU 32 | hidden_size: 64 33 | num_layers: 1 34 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${..[0].layers[1].hidden_size}'} 35 | layers: 36 | - layer_type: Linear 37 | output_size: 128 38 | activation: 39 | _target_: torch.nn.ReLU 40 | - layer_type: Linear 41 | output_size: 128 42 | activation: 43 | _target_: torch.nn.ReLU 44 | - layer_type: Linear 45 | output_size: ${example.environment.n_controlled_inputs} 46 | activation: 47 | _target_: torch.nn.Sigmoid 48 | val_module_specs: 49 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs}'} 50 | layers: 51 | - layer_type: GRU 52 | hidden_size: 64 53 | num_layers: 1 54 | - layer_type: GRU 55 | hidden_size: 64 56 | num_layers: 1 57 | - input_size: ${eval:'${example.environment.n_observed_variables} + 1 + ${example.environment.n_controlled_inputs} + ${..[0].layers[1].hidden_size}'} 58 | layers: 59 | - layer_type: Linear 60 | output_size: 128 61 | activation: 62 | _target_: torch.nn.ReLU 63 | - layer_type: Linear 64 | output_size: 128 65 | activation: 66 | _target_: torch.nn.ReLU 67 | - layer_type: Linear 68 | output_size: 1 69 | activation: 70 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /examples/Figure_2_FQ_chemostat/OSAO_param_inf.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | 7 | from casadi import * 8 | import numpy as np 9 | import matplotlib as mpl 10 | mpl.use('tkagg') 11 | import matplotlib.pyplot as plt 12 | import hydra 13 | from omegaconf import DictConfig 14 | 15 | from RED.environments.OED_env import OED_env 16 | from RED.environments.chemostat.xdot_chemostat import xdot 17 | import json 18 | 19 | def disablePrint(): 20 | sys.stdout = open(os.devnull, 'w') 21 | 22 | def enablePrint(): 23 | sys.stdout = sys.__stdout__ 24 | 25 | 26 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_2_FQ_chemostat") 27 | def OSAO_param_inf(cfg : DictConfig): 28 | cfg = cfg.example 29 | os.makedirs(cfg.save_path, exist_ok=True) 30 | 31 | #setup 32 | actual_params = DM(cfg.environment.actual_params) 33 | normaliser = np.array(cfg.environment.normaliser) 34 | n_params = actual_params.size()[0] 35 | n_system_variables = len(cfg.environment.y0) 36 | n_FIM_elements = sum(range(n_params + 1)) 37 | n_tot = n_system_variables + n_params * n_system_variables + n_FIM_elements 38 | param_guesses = DM((np.array(cfg.environment.ub) + np.array(cfg.environment.lb))/2) 39 | env = OED_env(cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 40 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 41 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser) 42 | input_bounds = np.array(cfg.environment.input_bounds) 43 | u0 = (10.0**input_bounds[:,1] + 10.0**input_bounds[:,0])/2 44 | env.u0 = DM(u0) 45 | e_rewards = [] 46 | 47 | 48 | #run optimisation 49 | for e in range(0, cfg.environment.N_control_intervals): 50 | next_state, reward, done, _ = env.step() 51 | 52 | 53 | if e == cfg.environment.N_control_intervals - 1: 54 | next_state = [None]*24 55 | done = True 56 | 57 | e_rewards.append(reward) 58 | state = next_state 59 | 60 | # save results and plot 61 | np.save(os.path.join(cfg.save_path, 'trajectories.npy'), np.array(env.true_trajectory)) 62 | 63 | np.save(os.path.join(cfg.save_path, 'true_trajectory.npy'), env.true_trajectory) 64 | np.save(os.path.join(cfg.save_path, 'us.npy'), np.array(env.us)) 65 | 66 | t = np.arange(cfg.environment.N_control_intervals) * int(cfg.environment.control_interval_time) 67 | 68 | plt.plot(env.true_trajectory[0, :].elements(), label='true') 69 | plt.legend() 70 | plt.ylabel('bacteria') 71 | plt.xlabel('time (mins)') 72 | plt.savefig(os.path.join(cfg.save_path, 'bacteria_trajectories.pdf')) 73 | 74 | plt.figure() 75 | plt.plot(env.true_trajectory[1, :].elements(), label='true') 76 | plt.legend() 77 | plt.ylabel('C') 78 | plt.xlabel('time (mins)') 79 | plt.savefig(os.path.join(cfg.save_path, 'c_trajectories.pdf')) 80 | 81 | plt.figure() 82 | plt.plot(env.true_trajectory[2, :].elements(), label='true') 83 | plt.legend() 84 | plt.ylabel('C0') 85 | plt.xlabel('time (mins)') 86 | plt.savefig(os.path.join(cfg.save_path, 'c0_trajectories.pdf')) 87 | 88 | plt.figure() 89 | plt.ylim(bottom=0) 90 | plt.ylabel('u') 91 | plt.xlabel('Timestep') 92 | plt.savefig(os.path.join(cfg.save_path, 'log_us.pdf')) 93 | plt.show() 94 | 95 | 96 | if __name__ == '__main__': 97 | OSAO_param_inf() 98 | -------------------------------------------------------------------------------- 
/examples/Figure_2_FQ_chemostat/MPC_param_inf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | 7 | 8 | from casadi import * 9 | import numpy as np 10 | import matplotlib as mpl 11 | mpl.use('tkagg') 12 | import matplotlib.pyplot as plt 13 | import hydra 14 | from omegaconf import DictConfig 15 | 16 | from RED.environments.OED_env import OED_env 17 | from RED.environments.chemostat.xdot_chemostat import xdot 18 | import json 19 | 20 | def disablePrint(): 21 | sys.stdout = open(os.devnull, 'w') 22 | 23 | def enablePrint(): 24 | sys.stdout = sys.__stdout__ 25 | 26 | SMALL_SIZE = 11 27 | MEDIUM_SIZE = 14 28 | BIGGER_SIZE = 17 29 | 30 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 31 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 32 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 33 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 34 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 35 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 36 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 37 | 38 | 39 | 40 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_2_FQ_chemostat") 41 | def MPC_param_inf(cfg : DictConfig): 42 | # setup 43 | cfg = cfg.example 44 | 45 | actual_params = DM(cfg.environment.actual_params) 46 | normaliser = np.array(cfg.environment.normaliser) 47 | os.makedirs(cfg.save_path, exist_ok=True) 48 | 49 | param_guesses = DM((np.array(cfg.environment.ub) + np.array(cfg.environment.lb))/2) 50 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 51 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 52 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser 53 | 54 | env = OED_env(*args) 55 | input_bounds = np.array(cfg.environment.input_bounds) 56 | u0 = (input_bounds[:,1] + input_bounds[:,0])/2 57 | env.u0 = DM(u0) 58 | 59 | 60 | def get_full_u_solver(): 61 | ''' 62 | creates and return the solver which will optimise a full exepiments inputs wrt the FI 63 | :return: solver 64 | ''' 65 | us = SX.sym('us', cfg.environment.N_control_intervals * cfg.environment.n_controlled_inputs) 66 | trajectory_solver = env.get_sampled_trajectory_solver(cfg.environment.N_control_intervals, cfg.environment.control_interval_time, cfg.environment.dt) 67 | est_trajectory = trajectory_solver(env.initial_Y, param_guesses, reshape(us , (cfg.environment.n_controlled_inputs, cfg.environment.N_control_intervals))) 68 | 69 | FIM = env.get_FIM(est_trajectory) 70 | 71 | q, r = qr(FIM) 72 | 73 | obj = -trace(log(r)) 74 | nlp = {'x': us, 'f': obj} 75 | solver = env.gauss_newton(obj, nlp, us, limited_mem = True) # for some reason limited mem works better for the MPC 76 | return solver 77 | 78 | 79 | # run optimisation 80 | u0 = (input_bounds[:,1] + input_bounds[:,0])/2 81 | u_solver = get_full_u_solver() 82 | sol = u_solver( 83 | x0=u0, 84 | lbx = [input_bounds[0][0]] * cfg.environment.n_controlled_inputs * cfg.environment.N_control_intervals, 85 | ubx = [input_bounds[0][1]] * cfg.environment.n_controlled_inputs * cfg.environment.N_control_intervals 86 | ) 87 | us = sol['x'] 88 | 89 | # save results and plot 90 | 
np.save(os.path.join(cfg.save_path, 'us.npy'), np.array(env.us)) 91 | 92 | 93 | t = np.arange(cfg.environment.N_control_intervals) * int(cfg.environment.control_interval_time) 94 | 95 | 96 | if __name__ == '__main__': 97 | MPC_param_inf() 98 | -------------------------------------------------------------------------------- /RED/README.md: -------------------------------------------------------------------------------- 1 | ## Configuration structure 2 | 3 | We use [Hydra](https://hydra.cc) for configuration, which allows for **hierarchical configuration structure composable from multiple sources** and dynamic command line overrides. The folder structure is as follows: 4 | ``` 5 | |── configs 6 | | |── environment # environment-specific parameters 7 | | | |── chemostat.yaml 8 | | | |── gene_transcription.yaml 9 | | | └── ... 10 | | |── example # configurations for examples in the root folder `examples` 11 | | | |── FQ_chemostat.yaml 12 | | | |── FQ_gene_transcription.yaml 13 | | | └── ... 14 | | └── model # configuration files for different models and agents 15 | | |── RT3D_agent.yaml 16 | | └── ... 17 | └── train.yaml # main training configuration 18 | ``` 19 | 20 | ## How to use 21 | The file `configs/train.yaml` contains the main training configuration that can be altered to fit any experiments. It is an entry point that composes smaller configuration pieces such as the environment (chemostat, gene transcription, ...), the model/agent (RT3D, DRPG, ...), and higher-level settings such as the path where to store the results. As can be seen below, one can also easily override selected parameters in the above-mentioned config sources (see the section `model:`). 22 | 23 |
24 | Training config (configs/train.yaml) 25 | 26 | ```yaml 27 | defaults: 28 | - /environment: chemostat 29 | - /model: RT3D_agent 30 | - _self_ 31 | 32 | model: 33 | val_learning_rate: 0.0001 34 | pol_learning_rate: 0.00005 35 | policy_act: sigmoid 36 | noise_bounds: [-0.25, 0.25] 37 | action_bounds: [0, 1] 38 | 39 | hidden_layer_size: [[64, 64], [128, 128]] 40 | policy_delay: 2 41 | max_std: 1 42 | explore_rate: "${max_std}" 43 | save_path: results/ 44 | ``` 45 |
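Because the experiment-level pieces are pulled in through the `defaults` list, swapping the environment or the agent is just a matter of pointing the corresponding config group at another file under `configs/environment` or `configs/model`. For example, a variant of the config above that uses the gene transcription environment instead of the chemostat would only change the first entry (illustrative sketch):

```yaml
defaults:
  - /environment: gene_transcription  # any file from configs/environment
  - /model: RT3D_agent                # any file from configs/model
  - _self_
```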
46 | 47 | ### Using in Python scripts 48 | The following Python script example loads the training config defined in `./RED/configs/train.yaml` and prints it. In the function `main`, one can then work with the `config` as with a standard Python dictionary, although it is an instance of OmegaConf's DictConfig. You can read more about it in the [Hydra documentation](https://hydra.cc/docs/tutorials/basic/your_first_app/using_config/) or in the [OmegaConf documentation](https://omegaconf.readthedocs.io/en/latest/usage.html#access-and-manipulation). Another example can be found in `examples/Figure_4_RT3D_chemostat/train_RT3D.py`. 49 | 50 | ```python 51 | import hydra 52 | from omegaconf import DictConfig 53 | 54 | @hydra.main(version_base=None, config_path="./RED/configs", config_name="train") 55 | def main(config: DictConfig) -> None: 56 | print(config) 57 | 58 | if __name__ == "__main__": 59 | main() 60 | ``` 61 | 62 | ### Using in Jupyter Notebooks 63 | Including the following at the beginning of a Jupyter notebook will initialize Hydra, load the training config, and print it. 64 | 65 | ```python 66 | from hydra import compose, initialize 67 | from omegaconf import OmegaConf 68 | 69 | initialize(version_base=None, config_path="./RED/configs") 70 | config = compose(config_name="train") 71 | print(OmegaConf.to_yaml(config)) 72 | ``` 73 | 74 | When composing the config it is possible to override any of the default assignments. 75 | Here is an example of overriding the save path and the policy learning rate: 76 | 77 | ```python 78 | from hydra import compose, initialize 79 | from omegaconf import OmegaConf 80 | 81 | initialize(version_base=None, config_path="./RED/configs") 82 | config = compose(config_name="train", overrides=["save_path=experiment_2_results/", "model.pol_learning_rate=0.0001"]) 83 | print(OmegaConf.to_yaml(config)) 84 | ``` 85 | 86 | The following link to the Hydra documentation provides more information on the override syntax:
87 | https://hydra.cc/docs/advanced/override_grammar/basic/
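The same dotted syntax works as command-line overrides when running a script decorated with `@hydra.main`; for instance (illustrative only — the exact keys depend on the config being composed):

```console
$ python RED/run_RED.py model.pol_learning_rate=0.0001 environment.n_episodes=1000
```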
88 | 89 | For more information regarding hydra initialization in jupyter see the following link: 90 | https://github.com/facebookresearch/hydra/blob/main/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb 91 | -------------------------------------------------------------------------------- /examples/RT3D_gene_transcription/MPC_param_inf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | 7 | from casadi import * 8 | import numpy as np 9 | import matplotlib as mpl 10 | mpl.use('tkagg') 11 | import matplotlib.pyplot as plt 12 | import hydra 13 | from hydra.utils import instantiate 14 | from omegaconf import DictConfig 15 | 16 | import time 17 | import tensorflow as tf 18 | from RED.environments.OED_env import OED_env 19 | from RED.environments.gene_transcription.xdot_gene_transcription import xdot 20 | from RED.agents.continuous_agents import RT3D_agent 21 | import multiprocessing 22 | import json 23 | import math 24 | 25 | def disablePrint(): 26 | sys.stdout = open(os.devnull, 'w') 27 | 28 | def enablePrint(): 29 | sys.stdout = sys.__stdout__ 30 | 31 | SMALL_SIZE = 11 32 | MEDIUM_SIZE = 14 33 | BIGGER_SIZE = 17 34 | 35 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 36 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 37 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 38 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 39 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 40 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 41 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 42 | 43 | 44 | 45 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/RT3D_gene_transcription") 46 | def MPC_param_inf(cfg: DictConfig): 47 | ''' 48 | {'f': DM(-73.9751), 'g': DM([]), 'lam_g': DM([]), 'lam_p': DM([]), 'lam_x': DM([0.0544032, -0.000621907, -2.55992e-06, -0.000481557, -0.000545576, -0.000732909]), 'x': DM([2.99997, -2.93105, 1.48927, -2.90577, -2.91904, -2.94369])} 49 | 50 | ''' 51 | # setup 52 | cfg = cfg.example 53 | 54 | actual_params = DM(cfg.environment.actual_params) 55 | normaliser = np.array(cfg.environment.normaliser) 56 | # save_path = os.path.join('.', 'results') 57 | os.makedirs(cfg.save_path, exist_ok=True) 58 | 59 | param_guesses = DM(actual_params) # for non prior 60 | #param_guesses = DM((np.array(ub) + np.array(lb))/2) # for prior 61 | 62 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 63 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 64 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser 65 | env = OED_env(*args) 66 | input_bounds = np.array(cfg.environment.input_bounds) 67 | u0 = (input_bounds[:,1] + input_bounds[:,0])/2 68 | env.u0 = DM(u0) 69 | 70 | 71 | def get_full_u_solver(): 72 | ''' 73 | creates and return the solver which will optimise a full exepiments inputs wrt the FI 74 | :return: solver 75 | ''' 76 | us = SX.sym('us', cfg.environment.N_control_intervals * cfg.environment.n_controlled_inputs) 77 | trajectory_solver = env.get_sampled_trajectory_solver(cfg.environment.N_control_intervals, cfg.environment.control_interval_time, cfg.environment.dt) 78 | est_trajectory = 
trajectory_solver(env.initial_Y, param_guesses, reshape(10.**us , (cfg.environment.n_controlled_inputs, cfg.environment.N_control_intervals))) 79 | 80 | FIM = env.get_FIM(est_trajectory) 81 | 82 | q, r = qr(FIM) 83 | 84 | obj = -trace(log(r)) 85 | nlp = {'x': us, 'f': obj} 86 | solver = env.gauss_newton(obj, nlp, us, limited_mem =False) 87 | return solver 88 | 89 | 90 | # run optimisation 91 | u0 = (input_bounds[:,1] + input_bounds[:,0])/2 92 | u_solver = get_full_u_solver() 93 | sol = u_solver( 94 | x0=u0, 95 | lbx = [input_bounds[0][0]]*cfg.environment.n_controlled_inputs*cfg.environment.N_control_intervals, 96 | ubx = [input_bounds[0][1]]*cfg.environment.n_controlled_inputs*cfg.environment.N_control_intervals 97 | ) 98 | us = sol['x'] 99 | print(sol) 100 | print(us) 101 | 102 | # save results 103 | np.save(os.path.join(cfg.save_path, 'us.npy'), np.array(env.us)) 104 | 105 | 106 | if __name__ == '__main__': 107 | MPC_param_inf() 108 | -------------------------------------------------------------------------------- /RED/environments/chemostat/xdot_chemostat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from casadi import * 3 | 4 | 5 | def monod(C, C0, umax, Km, Km0): 6 | ''' 7 | Calculates the growth rate based on the monod equation 8 | 9 | Parameters: 10 | C: the concetrations of the auxotrophic nutrients for each bacterial 11 | population 12 | C0: concentration of the common carbon source 13 | Rmax: array of the maximum growth rates for each bacteria 14 | Km: array of the saturation constants for each auxotrophic nutrient 15 | Km0: array of the saturation constant for the common carbon source for 16 | each bacterial species 17 | ''' 18 | 19 | # convert to numpy 20 | 21 | growth_rate = ((umax * C) / (Km + C)) * (C0 / (Km0 + C0)) 22 | 23 | return growth_rate 24 | 25 | 26 | def xdot(sym_y, sym_theta, sym_u): 27 | ''' 28 | Calculates and returns derivatives for the numerical solver odeint 29 | 30 | Parameters: 31 | S: current state 32 | t: current time 33 | Cin: array of the concentrations of the auxotrophic nutrients and the 34 | common carbon source 35 | params: list parameters for all the exquations 36 | num_species: the number of bacterial populations 37 | Returns: 38 | dsol: array of the derivatives for all state variables 39 | ''' 40 | 41 | print(sym_y) 42 | 43 | #q = sym_u[0] 44 | Cin = sym_u[0] 45 | C0in = sym_u[1] 46 | 47 | q = 0.5 48 | 49 | #y, y0, umax, Km, Km0 = [sym_theta[2*i:2*(i+1)] for i in range(len(sym_theta.elements())//2)] 50 | #y, y0, umax, Km, Km0 = [sym_theta[i] for i in range(len(sym_theta.elements()))] 51 | #y0, umax, Km, Km0 = [sym_theta[i] for i in range(len(sym_theta.elements()))] 52 | 53 | umax, Km, Km0 = [sym_theta[i] for i in range(3)] 54 | 55 | 56 | y = np.array([480000.]) 57 | y0 = np.array([520000.]) 58 | 59 | 60 | num_species = Km.size()[0] 61 | 62 | # extract variables 63 | N = sym_y[0] 64 | C = sym_y[1] 65 | C0 = sym_y[2] 66 | 67 | R = monod(C, C0, umax, Km, Km0) 68 | 69 | # calculate derivatives 70 | 71 | dN = N * (R - q) # q term takes account of the dilution 72 | dC = q * (Cin - C) - (1 / y) * R * N # sometimes dC.shape is (2,2) 73 | dC0 = q * (C0in - C0) - sum(1 / y0[i] * R[i] * N[i] for i in range(num_species)) 74 | 75 | # consstruct derivative vector for odeint 76 | 77 | xdot = SX.sym('xdot', 2*num_species + 1) 78 | 79 | 80 | xdot[0] = dN 81 | xdot[1] = dC 82 | xdot[2] = dC0 83 | 84 | 85 | return xdot 86 | 87 | def xdot_scaled(sym_y, sym_theta, sym_u): 88 | ''' 89 | Calculates and 
returns derivatives for the numerical solver odeint 90 | 91 | Parameters: 92 | S: current state 93 | t: current time 94 | Cin: array of the concentrations of the auxotrophic nutrients and the 95 | common carbon source 96 | params: list parameters for all the exquations 97 | num_species: the number of bacterial populations 98 | Returns: 99 | dsol: array of the derivatives for all state variables 100 | ''' 101 | 102 | #q = sym_u[0] 103 | Cin = sym_u[0] 104 | C0in = sym_u[1] 105 | 106 | q = 0.5 107 | 108 | #y, y0, umax, Km, Km0 = [sym_theta[2*i:2*(i+1)] for i in range(len(sym_theta.elements())//2)] 109 | #y, y0, umax, Km, Km0 = [sym_theta[i] for i in range(len(sym_theta.elements()))] 110 | #y0, umax, Km, Km0 = [sym_theta[i] for i in range(len(sym_theta.elements()))] 111 | 112 | umax, Km, Km0 = [sym_theta[i] for i in range(3)] 113 | 114 | 115 | y = np.array([4.8]) 116 | y0 = np.array([5.2]) 117 | 118 | print('params:', y, y0, umax, Km, Km0 ) 119 | num_species = Km.size()[0] 120 | print('num species:', num_species) 121 | 122 | # extract variables 123 | N = sym_y[0] 124 | C = sym_y[1] 125 | C0 = sym_y[2] 126 | print(N.shape, C.shape, C0.shape) 127 | R = monod(C, C0, umax, Km, Km0) 128 | print(R.shape) 129 | # calculate derivatives 130 | 131 | dN = N * (R - q) # q term takes account of the dilution 132 | dC = q * (Cin - C) - (1 / y) * R * N # sometimes dC.shape is (2,2) 133 | dC0 = q * (C0in - C0) - sum(1 / y0[i] * R[i] * N[i] for i in range(num_species)) 134 | 135 | print(dN.shape, dC.shape, dC0.shape) 136 | if dC.shape == (2, 2): 137 | print(q, Cin.shape, C0, C, y, R, N) # C0in 138 | 139 | # consstruct derivative vector for odeint 140 | 141 | xdot = SX.sym('xdot', 2*num_species + 1) 142 | 143 | 144 | xdot[0] = dN 145 | xdot[1] = dC 146 | xdot[2] = dC0 147 | 148 | 149 | return xdot 150 | 151 | -------------------------------------------------------------------------------- /RED/utils/network.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | import torch 4 | from torch import nn, optim 5 | 6 | class LambdaLayer(nn.Module): 7 | def __init__(self, lambda_expression): 8 | super(LambdaLayer, self).__init__() 9 | self.lambda_expression = lambda_expression 10 | def forward(self, x): 11 | return self.lambda_expression(x) 12 | 13 | class NeuralNetwork(nn.Module): 14 | def __init__( 15 | self, 16 | input_size, 17 | layer_specs: list, 18 | init_optimizer=False, 19 | learning_rate=0.01, 20 | optimizer=optim.Adam, 21 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | ): 23 | ''' 24 | :param input_size: the size of the input data 25 | :param layer_specs: a list of dictionaries, each containing "layer_type" + other key-value pairs, depending on the layer type: 26 | - "GRU" - "hidden_size", "num_layers" 27 | - "Linear" - "output_size" 28 | - "Lambda" - "lambda_expression" 29 | - "lambda_expression" must contain a python (lambda) function which takes a tensor as input and returns a tensor as output 30 | - examples (both are valid - string is converted to lambda expression): 31 | "lambda_expression": lambda x: x * 1 32 | "lambda_expression": "lambda x: x * 1" 33 | "lambda_expression": """ 34 | def f(x): 35 | return x * 1 36 | """ 37 | - Additional key-value pairs: 38 | - "activation" for the activation function which should be applied after the layer 39 | :param init_optimizer: if True, the optimizer is initialized with the given learning rate 40 | :param learning_rate: the learning rate for the optimizer 41 | :param optimizer: the 
optimizer to be used 42 | ''' 43 | super().__init__() 44 | 45 | self.layers = nn.ModuleList() 46 | self.input_size = input_size 47 | 48 | for layer_spec in layer_specs: 49 | assert "layer_type" in layer_spec, "Each layer spec should contain a key 'layer_type'." 50 | 51 | ### layer initialization 52 | if layer_spec["layer_type"] == "GRU": 53 | assert "hidden_size" in layer_spec, "GRU layer spec should contain a key 'hidden_size'." 54 | assert "num_layers" in layer_spec, "GRU layer spec should contain a key 'num_layers'." 55 | self.layers.append(nn.GRU( 56 | input_size=input_size, 57 | hidden_size=layer_spec["hidden_size"], 58 | num_layers=layer_spec["num_layers"], 59 | batch_first=True, 60 | )) 61 | input_size = layer_spec["hidden_size"] 62 | elif layer_spec["layer_type"] == "Linear": 63 | assert "output_size" in layer_spec, "Linear layer spec should contain a key 'output_size'." 64 | self.layers.append(nn.Linear( 65 | in_features=input_size, 66 | out_features=layer_spec["output_size"], 67 | )) 68 | input_size = layer_spec["output_size"] 69 | elif layer_spec["layer_type"] == "Lambda": 70 | assert "lambda_expression" in layer_spec, "Lambda layer spec should contain a key 'lambda_expression'." 71 | lambda_expr = layer_spec["lambda_expression"] 72 | if type(lambda_expr) == str: 73 | try: 74 | # checks if the string contains a valid python code 75 | ast.parse(lambda_expr) 76 | lambda_expr = eval(lambda_expr) 77 | except SyntaxError: 78 | raise SyntaxError("Lambda expression is not valid.") 79 | self.layers.append(LambdaLayer(lambda_expr)) 80 | else: 81 | raise ValueError("Unknown layer type: " + layer_spec["layer_type"]) 82 | 83 | ### activation function 84 | if "activation" in layer_spec: 85 | self.layers.append(layer_spec["activation"]) 86 | 87 | self.output_size = input_size 88 | self.device = device 89 | self.to(self.device) 90 | 91 | self.learning_rate = learning_rate if init_optimizer else None 92 | self.optimizer = optimizer(params=self.parameters(), lr=learning_rate) if init_optimizer else None 93 | 94 | def forward(self, input_data): 95 | for layer in self.layers: 96 | if type(layer) == nn.modules.rnn.GRU: 97 | input_data = layer(input_data)[0] 98 | else: 99 | input_data = layer(input_data) 100 | 101 | return input_data 102 | -------------------------------------------------------------------------------- /examples/Figure_S2_FQ_gene_transcription/train_FQ.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | 7 | from casadi import * 8 | import numpy as np 9 | import matplotlib as mpl 10 | mpl.use('tkagg') 11 | import matplotlib.pyplot as plt 12 | import hydra 13 | from hydra.utils import instantiate 14 | from omegaconf import DictConfig 15 | 16 | import time 17 | import tensorflow as tf 18 | from RED.agents.fitted_Q_agents import KerasFittedQAgent 19 | from RED.environments.OED_env import OED_env 20 | from RED.environments.gene_transcription.xdot_gene_transcription import xdot 21 | 22 | def disablePrint(): 23 | sys.stdout = open(os.devnull, 'w') 24 | 25 | def enablePrint(): 26 | sys.stdout = sys.__stdout__ 27 | 28 | def action_scaling(u): 29 | return 10**u 30 | 31 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_S2_FQ_gene_transcription") 32 | def train_FQ(cfg : DictConfig): 33 | cfg = cfg.example 34 | 35 | physical_devices = 
tf.config.list_physical_devices('GPU') 36 | try: 37 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 38 | except: 39 | pass 40 | 41 | #setup 42 | agent = instantiate(cfg.model) 43 | actual_params = DM(cfg.environment.actual_params) 44 | n_params = actual_params.size()[0] 45 | n_FIM_elements = sum(range(n_params + 1)) 46 | # ??? vvv 47 | param_guesses = DM([22, 6e5, 1.2e9, 3e-4, 3.5]) 48 | param_guesses = actual_params 49 | # ??? ^^^ 50 | normaliser = np.array(cfg.environment.normaliser) 51 | env = OED_env(cfg.environment.y0, xdot, param_guesses, actual_params, \ 52 | cfg.environment.n_observed_variables, cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, \ 53 | cfg.environment.input_bounds, cfg.environment.dt, cfg.environment.control_interval_time, normaliser) 54 | 55 | explore_rate = cfg.init_explore_rate 56 | all_returns = [] 57 | 58 | for episode in range(cfg.environment.n_episodes): #training loop 59 | 60 | env.reset() 61 | state = env.get_initial_RL_state(use_old_state=True) 62 | e_return = 0 63 | e_actions =[] 64 | e_rewards = [] 65 | trajectory = [] 66 | 67 | for e in range(0, cfg.environment.N_control_intervals): # run episode 68 | t = time.time() 69 | action = agent.get_action(state.reshape(-1, 23), explore_rate) 70 | next_state, reward, done, _ = env.step(action, use_old_state = True, scaling = action_scaling) 71 | if e == cfg.environment.N_control_intervals - 1: 72 | next_state = [None]*23 73 | done = True 74 | transition = (state, action, reward, next_state, done) 75 | trajectory.append(transition) 76 | e_actions.append(action) 77 | e_rewards.append(reward) 78 | state = next_state 79 | e_return += reward 80 | 81 | agent.memory.append(trajectory) 82 | #train the agent 83 | skip = 200 84 | if episode % skip == 0 or episode == cfg.environment.n_episodes - 2: #train agent 85 | explore_rate = agent.get_rate(episode, 0, 1, cfg.environment.n_episodes / 10) 86 | if explore_rate == 1: 87 | n_iters = 0 88 | elif len(agent.memory[0]) * len(agent.memory) < 40000: 89 | n_iters = 1 90 | else: 91 | n_iters = 2 92 | 93 | for iter in range(n_iters): 94 | agent.fitted_Q_update() 95 | 96 | all_returns.append(e_return) 97 | 98 | if episode %skip == 0 or episode == cfg.environment.n_episodes -1: 99 | print() 100 | print('EPISODE: ', episode) 101 | print('explore rate: ', explore_rate) 102 | print('return: ', e_return) 103 | print('av return: ', np.mean(all_returns[-skip:])) 104 | 105 | # save and plot 106 | agent.save_network(cfg.save_path) 107 | np.save(os.path.join(cfg.save_path, 'trajectories.npy'), np.array(env.true_trajectory)) 108 | np.save(os.path.join(cfg.save_path, 'true_trajectory.npy'), env.true_trajectory) 109 | np.save(os.path.join(cfg.save_path, 'us.npy'), np.array(env.us)) 110 | np.save(os.path.join(cfg.save_path, 'all_returns.npy'), np.array(all_returns)) 111 | np.save(os.path.join(cfg.save_path,'actions.npy'), np.array(agent.actions)) 112 | np.save(os.path.join(cfg.save_path,'values.npy'), np.array(agent.values)) 113 | t = np.arange(cfg.environment.N_control_intervals) * int(cfg.environment.control_interval_time) 114 | plt.plot(env.true_trajectory[0, :].elements(), label = 'true') 115 | plt.legend() 116 | plt.ylabel('rna') 117 | plt.xlabel('time (mins)') 118 | plt.savefig(os.path.join(cfg.save_path,'rna_trajectories.pdf')) 119 | plt.figure() 120 | plt.plot( env.true_trajectory[1, :].elements(), label = 'true') 121 | plt.legend() 122 | plt.ylabel( 'protein') 123 | plt.xlabel('time (mins)') 124 | plt.savefig(os.path.join(cfg.save_path, 
'prot_trajectories.pdf')) 125 | plt.ylim(bottom=0) 126 | plt.ylabel('u') 127 | plt.xlabel('Timestep') 128 | plt.savefig(os.path.join(cfg.save_path, 'log_us.pdf')) 129 | plt.figure() 130 | plt.plot(all_returns) 131 | plt.ylabel('Return') 132 | plt.xlabel('Episode') 133 | plt.savefig(os.path.join(cfg.save_path, 'return.pdf')) 134 | plt.show() 135 | 136 | 137 | if __name__ == '__main__': 138 | train_FQ() 139 | -------------------------------------------------------------------------------- /examples/Figure_2_FQ_chemostat/train_FQ.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | 7 | 8 | 9 | import math 10 | from casadi import * 11 | import numpy as np 12 | import matplotlib as mpl 13 | mpl.use('tkagg') 14 | import matplotlib.pyplot as plt 15 | import hydra 16 | from hydra.utils import instantiate 17 | from omegaconf import DictConfig 18 | 19 | import time 20 | 21 | from RED.agents.fitted_Q_agents import KerasFittedQAgent 22 | from RED.environments.OED_env import OED_env 23 | from RED.environments.chemostat.xdot_chemostat import xdot 24 | import json 25 | 26 | import multiprocessing 27 | 28 | 29 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_2_FQ_chemostat") 30 | def train_FQ(cfg : DictConfig): 31 | #run setup 32 | cfg = cfg.example 33 | 34 | n_cores = multiprocessing.cpu_count() 35 | 36 | actual_params = DM(cfg.environment.actual_params) 37 | if cfg.use_old_state: 38 | normaliser = np.array(cfg.old_state_normaliser) 39 | else: 40 | normaliser = np.array(cfg.environment.normaliser) 41 | n_params = actual_params.size()[0] 42 | n_system_variables = len(cfg.environment.y0) 43 | n_FIM_elements = sum(range(n_params + 1)) 44 | n_tot = cfg.environment.n_observed_variables + n_params * n_system_variables + n_FIM_elements 45 | param_guesses = actual_params 46 | 47 | agent = instantiate( 48 | cfg.model, 49 | layer_sizes=[ 50 | cfg.environment.n_observed_variables + n_params + n_FIM_elements + 1, 51 | *cfg.hidden_layer_sizes, 52 | cfg.environment.num_inputs ** cfg.environment.n_controlled_inputs 53 | ] 54 | ) 55 | 56 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 57 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 58 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser 59 | env = OED_env(*args) 60 | env.param_guesses = DM(actual_params) 61 | actual_params = np.random.uniform(low=[1, 0.00048776, 0.00006845928], high=[1, 0.00048776, 0.00006845928], 62 | size=(cfg.environment.skip, 3)) 63 | env.mapped_trajectory_solver = env.CI_solver.map(cfg.environment.skip, "thread", n_cores) 64 | explore_rate = cfg.init_explore_rate 65 | alpha = cfg.init_alpha 66 | t = time.time() 67 | all_returns = [] 68 | 69 | 70 | for episode in range(int(cfg.environment.n_episodes//cfg.environment.skip)): # training loop 71 | 72 | states = [env.get_initial_RL_state_parallel(use_old_state = cfg.use_old_state, i=i) for i in range(cfg.environment.skip)] 73 | 74 | e_returns = [0 for _ in range(cfg.environment.skip)] 75 | e_actions = [] 76 | e_rewards = [[] for _ in range(cfg.environment.skip)] 77 | trajectories = [[] for _ in range(cfg.environment.skip)] 78 | 79 | env.reset() 80 | env.param_guesses = DM(actual_params) 81 | env.logdetFIMs = [[] for _ in 
range(cfg.environment.skip)] 82 | env.detFIMs = [[] for _ in range(cfg.environment.skip)] 83 | 84 | for e in range(0, cfg.environment.N_control_intervals): # run an episode 85 | 86 | actions = agent.get_actions(states, explore_rate) 87 | e_actions.append(actions) 88 | outputs = env.map_parallel_step(np.array(actions).T, actual_params, use_old_state = cfg.use_old_state) 89 | next_states = [] 90 | 91 | for i,o in enumerate(outputs): # extract outputs from episodes that have been run in parallel 92 | next_state, reward, done, _, u = o 93 | next_states.append(next_state) 94 | state = states[i] 95 | action = actions[i] 96 | 97 | if e == cfg.environment.N_control_intervals - 1 or np.all(np.abs(next_state) >= 1) or math.isnan(np.sum(next_state)): 98 | next_state = [None]*agent.layer_sizes[0] 99 | done = True 100 | 101 | transition = (state, action, reward, next_state, done) 102 | trajectories[i].append(transition) 103 | if reward != -1: # dont include the unstable trajectories as they override the true return 104 | e_rewards[i].append(reward) 105 | e_returns[i] += reward 106 | 107 | state = next_state 108 | 109 | states = next_states 110 | 111 | for j, trajectory in enumerate(trajectories): # add trajectory to memory 112 | if np.all( [np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))] ) and not math.isnan(np.sum(trajectory[-1][0])): # check for instability 113 | 114 | agent.memory.append(trajectory) 115 | all_returns.append(e_returns[j]) 116 | 117 | 118 | if episode != 0: # train agent 119 | explore_rate = agent.get_rate(episode, 0, 1, cfg.environment.n_episodes / (11*cfg.environment.skip)) 120 | alpha = agent.get_rate(episode, 0, 1, cfg.environment.n_episodes / (10*cfg.environment.skip)) 121 | 122 | if explore_rate == 1: 123 | n_iters = 0 124 | else: 125 | n_iters = 1 126 | 127 | 128 | for iter in range(n_iters): 129 | history = agent.fitted_Q_update(alpha = alpha) 130 | 131 | print() 132 | print('EPISODE: ', episode * cfg.environment.skip) 133 | print('explore rate: ', explore_rate) 134 | 135 | print('av return: ', np.mean(all_returns[-cfg.environment.skip:])) 136 | 137 | #save results and plot 138 | agent.save_network(cfg.save_path) 139 | np.save(os.path.join(cfg.save_path, 'all_returns.npy'), np.array(all_returns)) 140 | 141 | np.save(os.path.join(cfg.save_path, 'actions.npy'), np.array(agent.actions)) 142 | np.save(os.path.join(cfg.save_path, 'values.npy'), np.array(agent.values)) 143 | 144 | t = np.arange(cfg.environment.N_control_intervals) * int(cfg.environment.control_interval_time) 145 | 146 | 147 | 148 | plt.figure() 149 | plt.plot(all_returns) 150 | plt.ylabel('Return') 151 | plt.xlabel('Episode') 152 | plt.savefig(os.path.join(cfg.save_path,'return.pdf')) 153 | 154 | 155 | 156 | plt.show() 157 | 158 | 159 | if __name__ == '__main__': 160 | train_FQ() 161 | -------------------------------------------------------------------------------- /RED/run_RED.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | import importlib.util 5 | import hydra 6 | from hydra.utils import instantiate 7 | from omegaconf import DictConfig 8 | 9 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(IMPORT_PATH) 11 | 12 | 13 | 14 | import math 15 | from casadi import * 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | from RED.environments.OED_env import OED_env 21 | 22 | import time 23 | 24 | 25 | import tensorflow as tf 26 | 27 | import multiprocessing 
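# Configs that use the ${eval:'...'} resolver (e.g. RED/configs/model/RT3D_agent.yaml
# and the Figure_3 / Figure_4 example configs) expect the resolver to be registered
# before the config is composed, for example:
#   from omegaconf import OmegaConf
#   OmegaConf.register_new_resolver("eval", eval)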
28 | 29 | 30 | 31 | 32 | @hydra.main(version_base=None, config_path="configs", config_name="train") 33 | def run_RT3D(cfg : DictConfig): 34 | # setup 35 | n_cores = multiprocessing.cpu_count() 36 | _, n_episodes, skip, y0, actual_params, input_bounds, n_controlled_inputs, num_inputs, dt, lb, ub, N_control_intervals, control_interval_time, n_observed_variables, prior, normaliser = \ 37 | [cfg.environment[k] for k in cfg.environment.keys()] 38 | actual_params = DM(actual_params) 39 | normaliser = np.array(normaliser) 40 | n_params = actual_params.size()[0] 41 | n_system_variables = len(y0) 42 | n_FIM_elements = sum(range(n_params + 1)) 43 | n_tot = n_system_variables + n_params * n_system_variables + n_FIM_elements 44 | param_guesses = actual_params 45 | physical_devices = tf.config.list_physical_devices('GPU') 46 | try: 47 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 48 | except: 49 | pass 50 | save_path = cfg.save_path 51 | os.makedirs(save_path, exist_ok=True) 52 | 53 | # agent setup 54 | pol_layer_sizes = [n_observed_variables + 1, n_observed_variables + 1 + n_controlled_inputs, cfg.hidden_layer_size[0], 55 | cfg.hidden_layer_size[1], n_controlled_inputs] 56 | val_layer_sizes = [n_observed_variables + 1 + n_controlled_inputs, n_observed_variables + 1 + n_controlled_inputs, 57 | cfg.hidden_layer_size[0], cfg.hidden_layer_size[1], 1] 58 | agent = instantiate(cfg.model, pol_layer_sizes=pol_layer_sizes, val_layer_sizes=val_layer_sizes, batch_size=int(N_control_intervals * skip)) 59 | 60 | update_count = 0 61 | explore_rate = cfg.explore_rate 62 | all_returns = [] 63 | all_test_returns = [] 64 | 65 | # env setup 66 | spec = importlib.util.spec_from_file_location('xdot', hydra.utils.to_absolute_path(cfg.environment.xdot_path)) 67 | xdot_mod = importlib.util.module_from_spec(spec) 68 | sys.modules['xdot'] = xdot_mod 69 | spec.loader.exec_module(xdot_mod) 70 | 71 | args = y0, xdot_mod.xdot, param_guesses, actual_params, n_observed_variables, n_controlled_inputs, num_inputs, input_bounds, dt, control_interval_time, normaliser 72 | env = OED_env(*args) 73 | env.mapped_trajectory_solver = env.CI_solver.map(skip, "thread", n_cores) 74 | 75 | for episode in range(int(n_episodes // skip)): # training loop 76 | actual_params = np.random.uniform(low=lb, high=ub, size=(skip, len(cfg.environment.actual_params))) # sample from uniform distribution 77 | env.param_guesses = DM(actual_params) 78 | states = [env.get_initial_RL_state_parallel() for i in range(skip)] 79 | e_returns = [0 for _ in range(skip)] 80 | e_actions = [] 81 | e_exploit_flags = [] 82 | e_rewards = [[] for _ in range(skip)] 83 | e_us = [[] for _ in range(skip)] 84 | trajectories = [[] for _ in range(skip)] 85 | sequences = [[[0] * pol_layer_sizes[1]] for _ in range(skip)] 86 | env.reset() 87 | env.param_guesses = DM(actual_params) 88 | env.logdetFIMs = [[] for _ in range(skip)] 89 | env.detFIMs = [[] for _ in range(skip)] 90 | 91 | for e in range(0, N_control_intervals): # run an episode 92 | inputs = [states, sequences] 93 | if episode < 1000 // skip: 94 | actions = agent.get_actions(inputs, explore_rate=1, test_episode=True, recurrent=True) 95 | else: 96 | actions = agent.get_actions(inputs, explore_rate=explore_rate, test_episode=True, recurrent=True) 97 | 98 | e_actions.append(actions) 99 | outputs = env.map_parallel_step(np.array(actions).T, actual_params, continuous=True) 100 | next_states = [] 101 | 102 | for i, o in enumerate(outputs): # extract outputs from parallel experiments 103 | next_state, reward, 
done, _, u = o
104 |                 e_us[i].append(u)
105 |                 next_states.append(next_state)
106 |                 state = states[i]
107 |                 action = actions[i]
108 | 
109 |                 if e == N_control_intervals - 1 or np.all(np.abs(next_state) >= 1) or math.isnan(np.sum(next_state)):
110 |                     done = True
111 | 
112 |                 transition = (state, action, reward, next_state, done)
113 |                 trajectories[i].append(transition)
114 |                 sequences[i].append(np.concatenate((state, action)))
115 |                 if reward != -1:  # don't include unstable trajectories, as they would override the true return
116 |                     e_rewards[i].append(reward)
117 |                     e_returns[i] += reward
118 |             states = next_states
119 | 
120 |         for trajectory in trajectories:
121 |             if np.all([np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))]) and not math.isnan(
122 |                     np.sum(trajectory[-1][0])):  # check for instability
123 |                 agent.memory.append(trajectory)
124 | 
125 |         if episode > 1000 // skip:  # train agent
126 |             print('training', update_count)
127 |             t = time.time()
128 |             for _ in range(skip):
129 |                 update_count += 1
130 |                 policy = update_count % cfg.policy_delay == 0
131 | 
132 |                 agent.Q_update(policy=policy, fitted=False, recurrent=True)
133 |             print('fitting time', time.time() - t)
134 | 
135 |         explore_rate = agent.get_rate(episode, 0, 1, n_episodes / (11 * skip)) * cfg.max_std
136 | 
137 |         all_returns.extend(e_returns)
138 |         print()
139 |         print('EPISODE: ', episode, episode * skip)
140 | 
141 |         print('av return: ', np.mean(all_returns[-skip:]))
142 |         print()
143 | 
144 |     # plot and save results
145 |     np.save(os.path.join(save_path, 'all_returns.npy'), np.array(all_returns))
146 |     np.save(os.path.join(save_path, 'actions.npy'), np.array(agent.actions))
147 |     agent.save_network(save_path)
148 | 
149 |     t = np.arange(N_control_intervals) * int(control_interval_time)
150 | 
151 |     plt.plot(all_test_returns)
152 |     plt.figure()
153 |     plt.plot(all_returns)
154 |     plt.show()
155 | 
156 | 
157 | if __name__ == '__main__':
158 |     run_RT3D()
159 | 
--------------------------------------------------------------------------------
/RED/agents/continuous_agents/drpg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.optim import Adam
4 | from torch.utils.data import DataLoader
5 | from torch.nn import L1Loss
6 | from torch.nn.utils.rnn import pad_sequence  # torch provides pad_sequence, not pad_sequences
7 | import numpy as np
8 | import copy, os
9 | 
10 | from RED.utils.network import NeuralNetwork
11 | from RED.utils.data import DRPGDataset
12 | 
13 | class DRPG_agent():
14 |     def __init__(self, layer_sizes, learning_rate = 0.001, critic=True):
15 |         self.memory = []
16 |         self.layer_sizes = layer_sizes
17 |         self.gamma = 1.
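        # The agent optimises a REINFORCE-style objective: log_probability() below evaluates
        # the diagonal-Gaussian log-density of an action a under the policy outputs (mu, log_std),
        #     log pi(a|s) = -0.5 * sum_i [ ((a_i - mu_i) / sigma_i)^2 + 2 * log_std_i + log(2*pi) ],
        # and policy_update() weights it by the Monte Carlo return computed in
        # get_inputs_targets(), minus the critic's value estimate when a critic baseline is used.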
18 | 
19 |         self.critic = critic
20 |         if critic:
21 |             self.critic_network = self.initialise_network(layer_sizes)
22 |             self.critic_network_opt = Adam(self.critic_network.parameters(), lr=learning_rate)
23 |             self.critic_network_loss = L1Loss()
24 | 
25 |         self.actor_network = self.initialise_network(layer_sizes)
26 |         self.actor_network_opt = Adam(self.actor_network.parameters(), lr=learning_rate)
27 |         self.actor_network_loss = L1Loss()
28 | 
29 |         self.values = []
30 |         self.actions = []
31 | 
32 |         self.states = []
33 |         self.next_states = []
34 |         self.rewards = []
35 |         self.dones = []
36 |         self.sequences = []
37 |         self.next_sequences = []
38 |         self.all_values = []
39 | 
40 | 
41 |     def initialise_network(self, layer_sizes, critic_nw = False):
42 |         '''
43 |         Creates the fully connected network used for the actor or, when critic_nw is True, the critic
44 |         '''
45 |         network = NeuralNetwork(layer_sizes, critic_nw)
46 | 
47 |         return network
48 | 
49 | 
50 |     def get_actions(self, inputs):
51 | 
52 |         states, sequences = inputs
53 | 
54 |         # pad the variable-length input sequences to the longest sequence in the batch
55 |         sequences = pad_sequence([torch.as_tensor(s, dtype=torch.float64) for s in sequences], batch_first=True)
56 | 
57 |         mu, log_std = self.actor_network.predict([np.array(states), sequences])
58 | 
59 |         print('mu log_std', mu[0], log_std[0])
60 | 
61 |         # sample from the diagonal Gaussian policy
62 |         actions = mu + torch.randn_like(mu) * torch.exp(log_std)
63 | 
64 |         return actions
65 | 
66 |     def loss(self, inputs, actions, returns):
67 |         # obtain mu and log_std from the actor network
68 |         mu, log_std = self.actor_network(inputs)
69 | 
70 |         # compute the log probability of the taken actions
71 |         log_probability = self.log_probability(actions, mu, log_std)
72 |         print('log probability', log_probability.shape)
73 |         print('returns:', returns.shape)
74 |         # compute the return-weighted actor loss
75 |         returns_x_logprob = torch.mul(returns, log_probability)
76 |         loss_actor = self.actor_network_loss(returns_x_logprob, torch.zeros_like(returns_x_logprob))
77 |         print('loss actor', loss_actor.shape)
78 |         return loss_actor
79 | 
80 |     def log_probability(self, actions, mu, log_std):
81 | 
82 |         EPS = 1e-8
83 |         pre_sum = -0.5 * (((actions - mu) / (torch.exp(log_std) + EPS)) ** 2 + 2 * log_std + np.log(2 * np.pi))
84 | 
85 |         print('pre sum', pre_sum.shape)
86 |         return torch.sum(pre_sum, axis=1)
87 | 
88 |     def train(self, model, train_dataloader, optimizer, criterion, epochs):
89 |         for _ in range(epochs):
90 |             # go through all the batches generated by the dataloader
91 |             for i, (inputs, targets) in enumerate(train_dataloader):
92 |                 # clear the gradients
93 |                 optimizer.zero_grad()
94 |                 # compute the model output
95 |                 yhat = model(inputs)
96 |                 # calculate the loss (targets are returns, so keep them as floats)
97 |                 loss = criterion(yhat, targets.float())
98 |                 # credit assignment
99 |                 loss.backward()
100 |                 # update model weights
101 |                 optimizer.step()
102 | 
103 |         return model
104 | 
105 |     def policy_update(self):
106 |         inputs, actions, returns = self.get_inputs_targets()
107 | 
108 |         print(returns.shape)
109 |         if self.critic:
110 |             # use the critic's value estimate as a baseline for the returns
111 |             expected_returns = self.critic_network(inputs)
112 |             returns -= expected_returns.reshape(-1).detach()
113 |             print(expected_returns.reshape(-1).shape)
114 |             train_dataset = DRPGDataset(inputs, returns)
115 |             train_dataloader = DataLoader(train_dataset)
116 |             self.critic_network = self.train(self.critic_network, train_dataloader, self.critic_network_opt, self.critic_network_loss, epochs=1)
117 | 
118 |         self.actor_network_opt.zero_grad()
119 |         loss = self.loss(inputs, actions, returns)
120 |         loss.backward()
121 |         self.actor_network_opt.step()
122 | 
123 |     def get_inputs_targets(self):
124 |         '''
125 |         collects states, input sequences, actions and Monte Carlo returns from memory for a policy-gradient update
126 |         '''
127 | 
128 |         # iterate over all experience in memory and compute the Monte Carlo returns
129 |         for i, trajectory in enumerate(self.memory):
130 | 
131 |             e_rewards = []
132 |             sequence = [[0]*self.layer_sizes[1]]
133 |             for j, transition in enumerate(trajectory):
134 |                 self.sequences.append(copy.deepcopy(sequence))
135 |                 state, action, reward, next_state, done, u = transition
136 |                 sequence.append(np.concatenate((state, u/1)))
137 |                 #one_hot_a = np.array([int(i == action) for i in range(self.layer_sizes[-1])])/10
138 |                 self.next_sequences.append(copy.deepcopy(sequence))
139 |                 self.states.append(state)
140 |                 self.next_states.append(next_state)
141 |                 self.actions.append(action)
142 |                 self.rewards.append(reward)
143 |                 e_rewards.append(reward)
144 |                 self.dones.append(done)
145 | 
146 |             # discounted returns, computed backwards from the final reward
147 |             e_values = [e_rewards[-1]]
148 | 
149 |             for i in range(2, len(e_rewards) + 1):
150 |                 e_values.insert(0, e_rewards[-i] + e_values[0] * self.gamma)
151 |             self.all_values.extend(e_values)
152 | 
153 |         # pad the variable-length sequences to the longest sequence in the batch
154 |         padded = pad_sequence([torch.as_tensor(s, dtype=torch.float64) for s in self.sequences], batch_first=True)
155 |         states = np.array(self.states)
156 |         actions = np.array(self.actions)
157 |         all_values = np.array(self.all_values)
158 | 
159 |         self.sequences = []
160 |         self.states = []
161 |         self.actions = []
162 |         self.all_values = []
163 |         self.memory = []  # reset memory after this information has been extracted
164 | 
165 |         # shuffle so the training data is approximately IID
166 |         randomize = np.arange(len(states))
167 |         np.random.shuffle(randomize)
168 | 
169 |         states = states[randomize]
170 |         actions = actions[randomize]
171 | 
172 |         padded = padded[randomize]
173 |         all_values = all_values[randomize]
174 | 
175 |         inputs = [states, padded]
176 |         print('inputs, actions, all_values', inputs[0].shape, inputs[1].shape, actions.shape, all_values.shape)
177 |         return inputs, actions, all_values
178 | 
179 | 
180 |     def save_network(self, save_path):
181 |         torch.save(self.actor_network.state_dict(), os.path.join(save_path, 'saved_network.pth'))
182 | 
183 | 
184 |     def load_network(self, load_path):
185 |         self.actor_network.load_state_dict(torch.load(os.path.join(load_path, 'saved_network.pth')))
186 | 
--------------------------------------------------------------------------------
/examples/Figure_3_RT3D_chemostat/train_RT3D.py:
--------------------------------------------------------------------------------
1 | 
2 | import json
3 | import math
4 | import os
5 | import sys
6 | 
7 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | sys.path.append(IMPORT_PATH)
9 | 
10 | import multiprocessing
11 | 
12 | import hydra
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | from casadi import *
16 | from hydra.utils import instantiate
17 | from omegaconf import DictConfig, OmegaConf
18 | 
19 | from RED.agents.continuous_agents.rt3d import RT3D_agent
20 | from RED.environments.chemostat.xdot_chemostat import xdot
21 | from RED.environments.OED_env import OED_env
22 | from RED.utils.visualization import plot_returns
23 | 
24 | # https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html#how-to-perform-arithmetic-using-eval-as-a-resolver
25 | OmegaConf.register_new_resolver("eval", eval)
26 | 
27 | 
28 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_3_RT3D_chemostat")
29 | def train_RT3D(cfg : 
DictConfig): 30 | ### config setup 31 | cfg = cfg.example 32 | print( 33 | "--- Configuration ---", 34 | OmegaConf.to_yaml(cfg, resolve=True), 35 | "--- End of configuration ---", 36 | sep="\n\n" 37 | ) 38 | 39 | ### prepare save path 40 | os.makedirs(cfg.save_path, exist_ok=True) 41 | print("Results will be saved in: ", cfg.save_path) 42 | 43 | ### agent setup 44 | agent = instantiate(cfg.model) 45 | explore_rate = cfg.initial_explore_rate 46 | seq_dim = cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs 47 | 48 | ### env setup 49 | env, n_params = setup_env(cfg) 50 | total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments 51 | skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments 52 | starting_episode = 0 53 | 54 | history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]} 55 | 56 | ### load ckpt 57 | if cfg.load_ckpt_dir_path is not None: 58 | print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}") 59 | # load the agent 60 | agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt") 61 | print(f"Loading agent from: {agent_path}") 62 | additional_info = agent.load_ckpt( 63 | load_path=agent_path, 64 | load_target_networks=True, 65 | )["additional_info"] 66 | # load history 67 | history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json") 68 | if os.path.exists(history_path): 69 | print(f"Loading history from: {history_path}") 70 | with open(history_path, "r") as f: 71 | history = json.load(f) 72 | # load explore rate 73 | if "explore_rate" in history and len(history["explore_rate"]) > 0: 74 | explore_rate = history["explore_rate"][-1] 75 | # load starting episode 76 | if "episode" in additional_info: 77 | starting_episode = additional_info["episode"] + 1 78 | 79 | ### training loop 80 | for episode in range(starting_episode, total_episodes): 81 | actual_params = np.random.uniform( 82 | low=cfg.environment.actual_params, 83 | high=cfg.environment.actual_params, 84 | size=(cfg.environment.n_parallel_experiments, n_params) 85 | ) 86 | env.param_guesses = DM(actual_params) 87 | 88 | ### episode buffers for agent 89 | states = [env.get_initial_RL_state_parallel() for i in range(cfg.environment.n_parallel_experiments)] 90 | trajectories = [[] for _ in range(cfg.environment.n_parallel_experiments)] 91 | sequences = [[[0] * seq_dim] for _ in range(cfg.environment.n_parallel_experiments)] 92 | 93 | ### episode logging buffers 94 | e_returns = [0 for _ in range(cfg.environment.n_parallel_experiments)] 95 | e_actions = [] 96 | e_rewards = [[] for _ in range(cfg.environment.n_parallel_experiments)] 97 | e_us = [[] for _ in range(cfg.environment.n_parallel_experiments)] 98 | 99 | ### reset env between episodes 100 | env.reset() 101 | env.param_guesses = DM(actual_params) 102 | env.logdetFIMs = [[] for _ in range(cfg.environment.n_parallel_experiments)] 103 | env.detFIMs = [[] for _ in range(cfg.environment.n_parallel_experiments)] 104 | 105 | ### run an episode 106 | for control_interval in range(0, cfg.environment.N_control_intervals): 107 | inputs = [states, sequences] 108 | 109 | ### get agent's actions 110 | if episode < skip_first_n_episodes: 111 | actions = agent.get_actions(inputs, explore_rate=1, test_episode=cfg.test_episode, recurrent=True) 112 | else: 113 | actions = agent.get_actions(inputs, explore_rate=explore_rate, test_episode=cfg.test_episode, recurrent=True) 114 | e_actions.append(actions) 115 | 116 | ### step env 117 | 
outputs = env.map_parallel_step(actions.T, actual_params, continuous=True) 118 | next_states = [] 119 | for i, obs in enumerate(outputs): 120 | state, action = states[i], actions[i] 121 | next_state, reward, done, _, u = obs 122 | 123 | ### set done flag 124 | if control_interval == cfg.environment.N_control_intervals - 1 \ 125 | or np.all(np.abs(next_state) >= 1) \ 126 | or math.isnan(np.sum(next_state)): 127 | done = True 128 | 129 | ### memorize transition 130 | transition = (state, action, reward, next_state, done) 131 | trajectories[i].append(transition) 132 | sequences[i].append(np.concatenate((state, action))) 133 | 134 | ### log episode data 135 | e_us[i].append(u.tolist()) 136 | next_states.append(next_state) 137 | e_rewards[i].append(reward) 138 | e_returns[i] += reward 139 | states = next_states 140 | 141 | ### do not memorize the test trajectory (the last one) 142 | if cfg.test_episode: 143 | trajectories = trajectories[:-1] 144 | 145 | ### append trajectories to memory 146 | for trajectory in trajectories: 147 | # check for instability 148 | if np.all([np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))]) \ 149 | and not math.isnan(np.sum(trajectory[-1][0])): 150 | agent.memory.append(trajectory) 151 | 152 | ### train agent 153 | if episode > skip_first_n_episodes: 154 | for _ in range(cfg.environment.n_parallel_experiments): 155 | history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1) 156 | update_policy = history["update_count"][-1] % cfg.policy_delay == 0 157 | agent.Q_update(policy=update_policy, recurrent=True) 158 | else: 159 | history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0) 160 | 161 | ### update explore rate 162 | explore_rate = cfg.explore_rate_mul * agent.get_rate( 163 | episode=episode, 164 | min_rate=0, 165 | max_rate=1, 166 | denominator=cfg.environment.n_episodes / (11 * cfg.environment.n_parallel_experiments) 167 | ) 168 | 169 | ### log results 170 | history["returns"].extend(e_returns) 171 | history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist()) 172 | history["rewards"].extend(e_rewards) 173 | history["us"].extend(e_us) 174 | history["explore_rate"].append(explore_rate) 175 | 176 | print( 177 | f"\nEPISODE: [{episode}/{total_episodes}] ({episode * cfg.environment.n_parallel_experiments} experiments)", 178 | f"explore rate:\t{explore_rate:.2f}", 179 | f"average return:\t{np.mean(e_returns):.5f}", 180 | sep="\n", 181 | ) 182 | 183 | if cfg.test_episode: 184 | print( 185 | f"test actions:\n{np.array(e_actions)[:, -1]}", 186 | f"test rewards:\n{np.array(e_rewards)[-1, :]}", 187 | f"test return:\n{np.sum(np.array(e_rewards)[-1, :])}", 188 | sep="\n", 189 | ) 190 | 191 | ### checkpoint 192 | if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \ 193 | or episode == total_episodes - 1: 194 | ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}") 195 | os.makedirs(ckpt_dir, exist_ok=True) 196 | agent.save_ckpt( 197 | save_path=os.path.join(ckpt_dir, "agent.pt"), 198 | additional_info={ 199 | "episode": episode, 200 | } 201 | ) 202 | with open(os.path.join(ckpt_dir, "history.json"), "w") as f: 203 | json.dump(history, f) 204 | 205 | ### plot 206 | plot_returns( 207 | returns=history["returns"], 208 | explore_rates=history["explore_rate"], 209 | show=False, 210 | save_to_dir=cfg.save_path, 211 | conv_window=25, 212 | ) 213 | 214 | 215 | def setup_env(cfg): 216 | n_cores = multiprocessing.cpu_count() 217 | 
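    # OED_env builds a CasADi integrator for one control interval (env.CI_solver); mapping it
    # below with CI_solver.map(n_parallel_experiments, "thread", n_cores) evaluates all of the
    # parallel experiment rollouts for a control interval concurrently across CPU threads.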
actual_params = DM(cfg.environment.actual_params) 218 | normaliser = np.array(cfg.environment.normaliser) 219 | n_params = actual_params.size()[0] 220 | param_guesses = actual_params 221 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 222 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 223 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser 224 | env = OED_env(*args) 225 | env.mapped_trajectory_solver = env.CI_solver.map(cfg.environment.n_parallel_experiments, "thread", n_cores) 226 | return env, n_params 227 | 228 | 229 | if __name__ == '__main__': 230 | train_RT3D() 231 | -------------------------------------------------------------------------------- /examples/Figure_4_RT3D_chemostat/train_RT3D.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import math 4 | import os 5 | import sys 6 | 7 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | sys.path.append(IMPORT_PATH) 9 | 10 | import multiprocessing 11 | 12 | import hydra 13 | import numpy as np 14 | from casadi import * 15 | from hydra.utils import instantiate 16 | from omegaconf import DictConfig, OmegaConf 17 | 18 | from RED.agents.continuous_agents.rt3d import RT3D_agent 19 | from RED.environments.chemostat.xdot_chemostat import xdot 20 | from RED.environments.OED_env import OED_env 21 | from RED.utils.visualization import plot_returns 22 | 23 | # https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html#how-to-perform-arithmetic-using-eval-as-a-resolver 24 | OmegaConf.register_new_resolver("eval", eval) 25 | 26 | 27 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_4_RT3D_chemostat") 28 | def train_RT3D(cfg : DictConfig): 29 | ### config setup 30 | cfg = cfg.example 31 | print( 32 | "--- Configuration ---", 33 | OmegaConf.to_yaml(cfg, resolve=True), 34 | "--- End of configuration ---", 35 | sep="\n\n" 36 | ) 37 | 38 | ### prepare save path 39 | os.makedirs(cfg.save_path, exist_ok=True) 40 | print("Results will be saved in: ", cfg.save_path) 41 | 42 | ### agent setup 43 | agent = instantiate(cfg.model) 44 | explore_rate = cfg.initial_explore_rate 45 | seq_dim = cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs 46 | 47 | ### env setup 48 | env, n_params = setup_env(cfg) 49 | total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments 50 | skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments 51 | starting_episode = 0 52 | 53 | history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]} 54 | 55 | ### load ckpt 56 | if cfg.load_ckpt_dir_path is not None: 57 | print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}") 58 | # load the agent 59 | agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt") 60 | print(f"Loading agent from: {agent_path}") 61 | additional_info = agent.load_ckpt( 62 | load_path=agent_path, 63 | load_target_networks=True, 64 | )["additional_info"] 65 | # load history 66 | history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json") 67 | if os.path.exists(history_path): 68 | print(f"Loading history from: {history_path}") 69 | with open(history_path, "r") as f: 70 | history = json.load(f) 71 | # load explore rate 72 | if "explore_rate" in history and len(history["explore_rate"]) 
> 0: 73 | explore_rate = history["explore_rate"][-1] 74 | # load starting episode 75 | if "episode" in additional_info: 76 | starting_episode = additional_info["episode"] + 1 77 | 78 | ### training loop 79 | for episode in range(starting_episode, total_episodes): 80 | # sample params from uniform distribution 81 | actual_params = np.random.uniform( 82 | low=cfg.environment.lb, 83 | high=cfg.environment.ub, 84 | size=(cfg.environment.n_parallel_experiments, 3) 85 | ) 86 | env.param_guesses = DM(actual_params) 87 | 88 | ### episode buffers for agent 89 | states = [env.get_initial_RL_state_parallel() for i in range(cfg.environment.n_parallel_experiments)] 90 | trajectories = [[] for _ in range(cfg.environment.n_parallel_experiments)] 91 | sequences = [[[0] * seq_dim] for _ in range(cfg.environment.n_parallel_experiments)] 92 | 93 | ### episode logging buffers 94 | e_returns = [0 for _ in range(cfg.environment.n_parallel_experiments)] 95 | e_actions = [] 96 | e_rewards = [[] for _ in range(cfg.environment.n_parallel_experiments)] 97 | e_us = [[] for _ in range(cfg.environment.n_parallel_experiments)] 98 | 99 | ### reset env between episodes 100 | env.reset() 101 | env.param_guesses = DM(actual_params) 102 | env.logdetFIMs = [[] for _ in range(cfg.environment.n_parallel_experiments)] 103 | env.detFIMs = [[] for _ in range(cfg.environment.n_parallel_experiments)] 104 | 105 | ### run an episode 106 | for control_interval in range(0, cfg.environment.N_control_intervals): 107 | inputs = [states, sequences] 108 | 109 | ### get agent's actions 110 | if episode < skip_first_n_episodes: 111 | actions = agent.get_actions(inputs, explore_rate=1, test_episode=cfg.test_episode, recurrent=True) 112 | else: 113 | actions = agent.get_actions(inputs, explore_rate=explore_rate, test_episode=cfg.test_episode, recurrent=True) 114 | e_actions.append(actions) 115 | 116 | ### step env 117 | outputs = env.map_parallel_step(actions.T, actual_params, continuous=True) 118 | next_states = [] 119 | for i, obs in enumerate(outputs): 120 | state, action = states[i], actions[i] 121 | next_state, reward, done, _, u = obs 122 | 123 | ### set done flag 124 | if control_interval == cfg.environment.N_control_intervals - 1 \ 125 | or np.all(np.abs(next_state) >= 1) \ 126 | or math.isnan(np.sum(next_state)): 127 | done = True 128 | 129 | ### memorize transition 130 | transition = (state, action, reward, next_state, done) 131 | trajectories[i].append(transition) 132 | sequences[i].append(np.concatenate((state, action))) 133 | 134 | ### log episode data 135 | e_us[i].append(u.tolist()) 136 | next_states.append(next_state) 137 | e_rewards[i].append(reward) 138 | e_returns[i] += reward 139 | states = next_states 140 | 141 | ### do not memorize the test trajectory (the last one) 142 | if cfg.test_episode: 143 | trajectories = trajectories[:-1] 144 | 145 | ### append trajectories to memory 146 | for trajectory in trajectories: 147 | # check for instability 148 | if np.all([np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))]) \ 149 | and not math.isnan(np.sum(trajectory[-1][0])): 150 | agent.memory.append(trajectory) 151 | 152 | ### train agent 153 | if episode > skip_first_n_episodes: 154 | for _ in range(cfg.environment.n_parallel_experiments): 155 | history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1) 156 | update_policy = history["update_count"][-1] % cfg.policy_delay == 0 157 | agent.Q_update(policy=update_policy, recurrent=True) 158 | else: 159 | 
history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0) 160 | 161 | ### update explore rate 162 | explore_rate = cfg.explore_rate_mul * agent.get_rate( 163 | episode=episode, 164 | min_rate=0, 165 | max_rate=1, 166 | denominator=cfg.environment.n_episodes / (11 * cfg.environment.n_parallel_experiments) 167 | ) 168 | 169 | ### log results 170 | history["returns"].extend(e_returns) 171 | history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist()) 172 | history["rewards"].extend(e_rewards) 173 | history["us"].extend(e_us) 174 | history["explore_rate"].append(explore_rate) 175 | 176 | print( 177 | f"\nEPISODE: [{episode}/{total_episodes}] ({episode * cfg.environment.n_parallel_experiments} experiments)", 178 | f"explore rate:\t{explore_rate:.2f}", 179 | f"average return:\t{np.mean(e_returns):.5f}", 180 | sep="\n", 181 | ) 182 | 183 | if cfg.test_episode: 184 | print( 185 | f"test actions:\n{np.array(e_actions)[:, -1]}", 186 | f"test rewards:\n{np.array(e_rewards)[-1, :]}", 187 | f"test return:\n{np.sum(np.array(e_rewards)[-1, :])}", 188 | sep="\n", 189 | ) 190 | 191 | ### checkpoint 192 | if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \ 193 | or episode == total_episodes - 1: 194 | ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}") 195 | os.makedirs(ckpt_dir, exist_ok=True) 196 | agent.save_ckpt( 197 | save_path=os.path.join(ckpt_dir, "agent.pt"), 198 | additional_info={ 199 | "episode": episode, 200 | } 201 | ) 202 | with open(os.path.join(ckpt_dir, "history.json"), "w") as f: 203 | json.dump(history, f) 204 | 205 | ### plot 206 | plot_returns( 207 | returns=history["returns"], 208 | explore_rates=history["explore_rate"], 209 | show=False, 210 | save_to_dir=cfg.save_path, 211 | conv_window=25, 212 | ) 213 | 214 | 215 | def setup_env(cfg): 216 | n_cores = multiprocessing.cpu_count() 217 | actual_params = DM(cfg.environment.actual_params) 218 | normaliser = np.array(cfg.environment.normaliser) 219 | n_params = actual_params.size()[0] 220 | param_guesses = actual_params 221 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \ 222 | cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \ 223 | cfg.environment.dt, cfg.environment.control_interval_time, normaliser 224 | env = OED_env(*args) 225 | env.mapped_trajectory_solver = env.CI_solver.map(cfg.environment.n_parallel_experiments, "thread", n_cores) 226 | return env, n_params 227 | 228 | 229 | if __name__ == '__main__': 230 | train_RT3D() 231 | -------------------------------------------------------------------------------- /examples/RT3D_gene_transcription/train_RT3D.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | IMPORT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | sys.path.append(IMPORT_PATH) 6 | print(IMPORT_PATH) 7 | from casadi import * 8 | import numpy as np 9 | import matplotlib as mpl 10 | mpl.use('tkagg') 11 | import matplotlib.pyplot as plt 12 | import hydra 13 | from hydra.utils import instantiate 14 | from omegaconf import DictConfig 15 | 16 | import time 17 | import tensorflow as tf 18 | from RED.environments.OED_env import OED_env 19 | from RED.environments.gene_transcription.xdot_gene_transcription import xdot 20 | from RED.agents.continuous_agents import RT3D_agent 21 | import multiprocessing 22 | import json 23 | import math 24 | 25 | 26 
| def action_scaling(u): 27 | return 10**u 28 | 29 | @hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/RT3D_gene_transcription") 30 | def train_RT3D(cfg : DictConfig): 31 | cfg = cfg.example 32 | 33 | # print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU'))) 34 | n_cores = multiprocessing.cpu_count() 35 | print('Num CPU cores:', n_cores) 36 | 37 | actual_params = DM(cfg.environment.actual_params) 38 | n_params = actual_params.size()[0] 39 | n_system_variables = len(cfg.environment.y0) 40 | n_FIM_elements = sum(range(n_params + 1)) 41 | n_tot = n_system_variables + n_params * n_system_variables + n_FIM_elements 42 | param_guesses = actual_params 43 | 44 | normaliser = np.array(cfg.environment.normaliser) 45 | 46 | os.makedirs(cfg.save_path, exist_ok=True) 47 | 48 | if cfg.recurrent: 49 | # pol_layer_sizes = [cfg.environment.n_observed_variables + 1, cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs, [32, 32], [64,64,64], cfg.environment.n_controlled_inputs] 50 | pol_layer_sizes = [cfg.environment.n_observed_variables + 1, cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs, 51 | cfg.hidden_layer_size[0], cfg.hidden_layer_size[1], cfg.environment.n_controlled_inputs] 52 | val_layer_sizes = [cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs, 53 | cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs, cfg.hidden_layer_size[0], cfg.hidden_layer_size[1], 54 | 1] 55 | # agent = DQN_agent(layer_sizes=[cfg.environment.n_observed_variables + n_params + n_FIM_elements + 2, 100, 100, cfg.environment.num_inputs ** cfg.environment.n_controlled_inputs]) 56 | else: 57 | pol_layer_sizes = [cfg.environment.n_observed_variables + 1, 0, [], [128, 128], cfg.environment.n_controlled_inputs] 58 | val_layer_sizes = [cfg.environment.n_observed_variables + 1 + cfg.environment.n_controlled_inputs, 0, [], [128, 128], 1] 59 | 60 | # agent = DRPG_agent(layer_sizes=layer_sizes, learning_rate = 0.0004, critic = True) 61 | agent = instantiate( 62 | cfg.model, 63 | pol_layer_sizes=pol_layer_sizes, 64 | val_layer_sizes=val_layer_sizes, 65 | batch_size=int(cfg.environment.N_control_intervals * cfg.environment.skip), 66 | max_length=cfg.environment.N_control_intervals + 1, 67 | ) 68 | 69 | args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, cfg.environment.dt, cfg.environment.control_interval_time, normaliser 70 | env = OED_env(*args) 71 | 72 | explore_rate = cfg.explore_rate 73 | alpha = 1 74 | env.mapped_trajectory_solver = env.CI_solver.map(cfg.environment.skip, "thread", n_cores) 75 | total_t = time.time() 76 | 77 | unstable = 0 78 | n_unstables = [] 79 | all_returns = [] 80 | all_test_returns = [] 81 | update_count = 0 82 | fitted = False 83 | 84 | if cfg.load_agent_network: 85 | agent.load_network(cfg.agent_network_path) 86 | print('time:', cfg.environment.control_interval_time) 87 | 88 | for episode in range(int(cfg.environment.n_episodes // cfg.environment.skip)): 89 | print(episode) 90 | 91 | if cfg.environment.prior: 92 | actual_params = np.random.uniform(low=cfg.environment.lb, high=cfg.environment.ub, size=(cfg.environment.skip, 5)) 93 | else: 94 | actual_params = np.random.uniform(low=[20, 500000, 1.09e+09, 0.000257, 4], high=[20, 500000, 1.09e+09, 0.000257, 4], size=(cfg.environment.skip, 5)) 95 | 
env.param_guesses = DM(actual_params) 96 | 97 | states = [env.get_initial_RL_state_parallel() for i in range(cfg.environment.skip)] 98 | 99 | e_returns = [0 for _ in range(cfg.environment.skip)] 100 | 101 | e_actions = [] 102 | 103 | e_exploit_flags = [] 104 | e_rewards = [[] for _ in range(cfg.environment.skip)] 105 | e_us = [[] for _ in range(cfg.environment.skip)] 106 | trajectories = [[] for _ in range(cfg.environment.skip)] 107 | 108 | sequences = [[[0] * pol_layer_sizes[1]] for _ in range(cfg.environment.skip)] 109 | 110 | env.reset() 111 | env.param_guesses = DM(actual_params) 112 | env.logdetFIMs = [[] for _ in range(cfg.environment.skip)] 113 | env.detFIMs = [[] for _ in range(cfg.environment.skip)] 114 | 115 | for e in range(0, cfg.environment.N_control_intervals): 116 | 117 | if cfg.recurrent: 118 | inputs = [states, sequences] 119 | else: 120 | inputs = [states] 121 | 122 | if episode < 1000 // cfg.environment.skip: 123 | actions = agent.get_actions(inputs, explore_rate=0, test_episode=cfg.test_episode) 124 | else: 125 | actions = agent.get_actions(inputs, explore_rate=cfg.explore_rate, test_episode=cfg.test_episode) 126 | 127 | e_actions.append(actions) 128 | 129 | outputs = env.map_parallel_step(np.array(actions).T, actual_params, continuous=True, scaling=action_scaling) 130 | next_states = [] 131 | print(actions) 132 | 133 | for i, o in enumerate(outputs): 134 | next_state, reward, done, _, u = o 135 | e_us[i].append(u) 136 | next_states.append(next_state) 137 | state = states[i] 138 | 139 | action = actions[i] 140 | 141 | if e == cfg.environment.N_control_intervals - 1 or np.all(np.abs(next_state) >= 1) or math.isnan(np.sum(next_state)): 142 | # next_state = [0]*pol_layer_sizes[0] # maybe dont need this 143 | done = True 144 | 145 | transition = (state, action, reward, next_state, done) 146 | trajectories[i].append(transition) 147 | sequences[i].append(np.concatenate((state, action))) 148 | if reward != -1: # dont include the unstable trajectories as they override the true return 149 | e_rewards[i].append(reward) 150 | e_returns[i] += reward 151 | 152 | 153 | # print('sequences', np.array(sequences).shape) 154 | # print('sequences', sequences[0]) 155 | states = next_states 156 | 157 | if cfg.test_episode: 158 | trajectories = trajectories[:-1] 159 | 160 | for trajectory in trajectories: 161 | if np.all([np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))]) and not math.isnan( 162 | np.sum(trajectory[-1][0])): # check for instability 163 | agent.memory.append(trajectory) # monte carlo, fitted 164 | 165 | else: 166 | unstable += 1 167 | print('UNSTABLE!!!') 168 | print((trajectory[-1][0])) 169 | 170 | 171 | if episode > 1000 // cfg.environment.skip: 172 | print('training', update_count) 173 | t = time.time() 174 | for hello in range(cfg.environment.skip): 175 | # print(e, episode, hello, update_count) 176 | update_count += 1 177 | policy = update_count % cfg.policy_delay == 0 178 | 179 | agent.Q_update(policy=policy, fitted=fitted, recurrent=cfg.recurrent, low_mem = True) 180 | print('fitting time', time.time() - t) 181 | 182 | explore_rate = agent.get_rate(episode, 0, 1, cfg.environment.n_episodes / (11 * cfg.environment.skip)) * cfg.max_std 183 | ''' 184 | if episode > 1000//cfg.environment.skip: 185 | update_count += 1 186 | agent.Q_update( policy=update_count%cfg.policy_delay == 0, fitted=True) 187 | ''' 188 | 189 | print('n unstable ', unstable) 190 | n_unstables.append(unstable) 191 | 192 | if cfg.test_episode: 193 | all_returns.extend(e_returns[:-1]) 
194 | all_test_returns.append(np.sum(np.array(e_rewards)[-1, :])) 195 | else: 196 | all_returns.extend(e_returns) 197 | 198 | print() 199 | print('EPISODE: ', episode, episode * cfg.environment.skip) 200 | 201 | print('moving av return:', np.mean(all_returns[-10 * cfg.environment.skip:])) 202 | print('explore rate: ', explore_rate) 203 | print('alpha:', alpha) 204 | print('av return: ', np.mean(all_returns[-cfg.environment.skip:])) 205 | print() 206 | 207 | # print('us:', np.array(e_us)[0, :]) 208 | 209 | 210 | 211 | if cfg.test_episode: 212 | print('test actions:', np.array(e_actions)[:, -1]) 213 | print('test rewards:', np.array(e_rewards)[-1, :]) 214 | print('test return:', np.sum(np.array(e_rewards)[-1, :])) 215 | print() 216 | 217 | print('time:', time.time() - total_t) 218 | print(env.detFIMs[-1]) 219 | print(env.logdetFIMs[-1]) 220 | np.save(cfg.save_path + '/all_returns.npy', np.array(all_returns)) 221 | if cfg.test_episode: 222 | np.save(cfg.save_path + '/all_test_returns.npy', np.array(all_test_returns)) 223 | 224 | np.save(cfg.save_path + '/n_unstables.npy', np.array(n_unstables)) 225 | np.save(cfg.save_path + '/actions.npy', np.array(agent.actions)) 226 | agent.save_network(cfg.save_path) 227 | 228 | # np.save(cfg.save_path + 'values.npy', np.array(agent.values)) 229 | t = np.arange(cfg.environment.N_control_intervals) * int(cfg.environment.control_interval_time) 230 | 231 | plt.plot(all_test_returns) 232 | plt.figure() 233 | plt.plot(all_returns) 234 | plt.show() 235 | 236 | 237 | if __name__ == '__main__': 238 | train_RT3D() 239 | -------------------------------------------------------------------------------- /RED/agents/fitted_Q_agents.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | import math 6 | import random 7 | import time 8 | from tensorflow import keras 9 | 10 | 11 | import matplotlib.pyplot as plt 12 | import gc 13 | class FittedQAgent(): 14 | 15 | ''' 16 | abstract class for the Torch and Keras implimentations, dont use directly 17 | 18 | ''' 19 | 20 | def get_action(self, state, explore_rate): 21 | ''' 22 | Choses action based on enivormental state, explore rate and current value estimates 23 | 24 | Parameters: 25 | state: environmental state 26 | explore_rate 27 | Returns: 28 | action 29 | ''' 30 | 31 | if np.random.random() < explore_rate: 32 | action = np.random.choice(range(self.layer_sizes[-1])) 33 | 34 | else: 35 | values = self.predict(state) 36 | if np.isnan(values).any(): 37 | print('NAN IN VALUES!') 38 | print('state that gave nan:', state) 39 | self.values.append(values) 40 | action = np.argmax(values) 41 | self.actions.append(action) 42 | 43 | assert action < self.n_actions, 'Invalid action' 44 | return action 45 | 46 | def get_actions(self, states, explore_rate): 47 | ''' 48 | PARALLEL version of get action 49 | Choses action based on enivormental state, explore rate and current value estimates 50 | 51 | Parameters: 52 | state: environmental state 53 | explore_rate 54 | Returns: 55 | action 56 | ''' 57 | rng = np.random.random(len(states)) 58 | 59 | explore_inds = np.where(rng < explore_rate)[0] 60 | 61 | exploit_inds = np.where(rng >= explore_rate)[0] 62 | 63 | explore_actions = np.random.choice(range(self.layer_sizes[-1]), len(explore_inds)) 64 | actions = np.zeros((len(states)), dtype=np.int32) 65 | 66 | if len(exploit_inds) > 0: 67 | values = self.predict(np.array(states)[exploit_inds]) 68 | 69 | 70 | if np.isnan(values).any(): 
71 | print('NAN IN VALUES!') 72 | print('states that gave nan:', states) 73 | self.values.extend(values) 74 | 75 | 76 | exploit_actions = np.argmax(values, axis = 1) 77 | actions[exploit_inds] = exploit_actions 78 | 79 | 80 | actions[explore_inds] = explore_actions 81 | self.actions.extend(actions) 82 | return actions 83 | 84 | 85 | 86 | def get_inputs_targets(self, alpha = 1): 87 | ''' 88 | gets fitted Q inputs and calculates targets for training the Q-network for episodic training 89 | ''' 90 | targets = [] 91 | states = [] 92 | next_states = [] 93 | actions = [] 94 | rewards = [] 95 | dones = [] 96 | 97 | # iterate over all exprienc in memory and create fitted Q targets 98 | for trajectory in self.memory: 99 | 100 | for transition in trajectory: 101 | state, action, reward, next_state, done = transition 102 | 103 | states.append(state) 104 | next_states.append(next_state) 105 | 106 | actions.append(action) 107 | rewards.append(reward) 108 | dones.append(done) 109 | 110 | 111 | 112 | states = np.array(states) 113 | next_states = np.array(next_states, dtype=np.float64) 114 | actions = np.array(actions) 115 | rewards = np.array(rewards) 116 | 117 | # construct target 118 | values = self.predict(states) 119 | next_values = self.predict(next_states) 120 | 121 | #update the value for the taken action using cost function and current Q 122 | for i in range(len(next_states)): 123 | # print(actions[i], rewards[i]) 124 | if dones[i]: 125 | 126 | values[i, actions[i]] = rewards[i] 127 | else: 128 | values[i, actions[i]] = (1-alpha)*values[i, actions[i]] + alpha*(rewards[i] + self.gamma * np.max(next_values[i])) # q learning 129 | #values[i, actions[i]] = rewards[i] + self.gamma * next_values[i, actions[i]] # sarsa 130 | 131 | # shuffle inputs and target for IID 132 | inputs, targets = np.array(states), np.array(values) 133 | 134 | 135 | randomize = np.arange(len(inputs)) 136 | np.random.shuffle(randomize) 137 | inputs = inputs[randomize] 138 | targets = targets[randomize] 139 | 140 | if np.isnan(targets).any(): 141 | print('NAN IN TARGETS!') 142 | 143 | return inputs, targets 144 | 145 | def get_inputs_targets_MC(self): 146 | ''' 147 | gets fitted Q inputs and calculates targets for training the Q-network for episodic training 148 | ''' 149 | targets = [] 150 | states = [] 151 | next_states = [] 152 | actions = [] 153 | rewards = [] 154 | dones = [] 155 | all_values = [] 156 | 157 | 158 | 159 | # iterate over all exprienc in memory and create fitted Q targets 160 | for trajectory in self.memory: 161 | 162 | e_rewards = [] 163 | for transition in trajectory: 164 | state, action, reward, next_state, done = transition 165 | 166 | states.append(state) 167 | next_states.append(next_state) 168 | actions.append(action) 169 | rewards.append(reward) 170 | e_rewards.append(reward) 171 | dones.append(done) 172 | 173 | 174 | e_values = [e_rewards[-1]] 175 | 176 | for i in range(2, len(e_rewards) + 1): 177 | e_values.insert(0, e_rewards[-i] + e_values[0] * self.gamma) 178 | all_values.extend(e_values) 179 | 180 | 181 | states = np.array(states) 182 | next_states = np.array(next_states, dtype=np.float64) 183 | actions = np.array(actions) 184 | rewards = np.array(rewards) 185 | 186 | # construct target 187 | values = self.predict(states) 188 | next_values = self.predict(next_states) 189 | 190 | #update the value for the taken action using cost function and current Q 191 | for i in range(len(next_states)): 192 | # print(actions[i], rewards[i]) 193 | 194 | values[i, actions[i]] = all_values[i] 195 | 196 | # shuffle 
inputs and target for IID 197 | inputs, targets = np.array(states), np.array(values) 198 | 199 | 200 | randomize = np.arange(len(inputs)) 201 | np.random.shuffle(randomize) 202 | inputs = inputs[randomize] 203 | targets = targets[randomize] 204 | 205 | if np.isnan(targets).any(): 206 | print('NAN IN TARGETS!') 207 | 208 | 209 | return inputs, targets 210 | 211 | 212 | def fitted_Q_update(self, inputs = None, targets = None, alpha = 1): 213 | ''' 214 | Uses a set of inputs and targets to update the Q network 215 | ''' 216 | 217 | if inputs is None and targets is None: 218 | t = time.time() 219 | inputs, targets = self.get_inputs_targets(alpha) 220 | 221 | t = time.time() 222 | self.reset_weights() 223 | 224 | t = time.time() 225 | history = self.fit(inputs, targets) 226 | 227 | 228 | return history 229 | 230 | def run_episode(self, env, explore_rate, tmax, train = True, remember = True): 231 | ''' 232 | Runs one fitted Q episode 233 | 234 | Parameters: 235 | env: the environment to train on and control 236 | explore_rate: explore rate for this episodes 237 | tmax: number of timesteps in the episode 238 | train: does the agent learn? 239 | remember: does the agent store eperience in its memory? 240 | 241 | Returns: 242 | env.sSol: time evolution of environmental states 243 | episode reward: total reward for this episode 244 | ''' 245 | # run trajectory with current policy and add to memory 246 | trajectory = [] 247 | actions = [] 248 | 249 | state = env.get_state() 250 | episode_reward = 0 251 | self.single_ep_reward = [] 252 | for i in range(tmax): 253 | 254 | action = self.get_action(state, explore_rate) 255 | 256 | actions.append(action) 257 | 258 | next_state, reward, done, info = env.step(action) 259 | done = False 260 | 261 | assert len(next_state) == self.state_size, 'env return state of wrong size' 262 | 263 | self.single_ep_reward.append(reward) 264 | if done: 265 | print(reward) 266 | 267 | # scale populations 268 | 269 | transition = (state, action, reward, next_state, done) 270 | state = next_state 271 | trajectory.append(transition) 272 | episode_reward += reward 273 | 274 | if done: break 275 | 276 | 277 | if remember: # store this episode 278 | self.memory.append(trajectory) 279 | 280 | if train: # train the agent 281 | 282 | self.actions = actions 283 | self.episode_lengths.append(i) 284 | self.episode_rewards.append(episode_reward) 285 | 286 | 287 | if len(self.memory[0]) * len(self.memory) < 100: 288 | n_iters = 4 289 | elif len(self.memory[0]) * len(self.memory) < 200: 290 | n_iters = 5 291 | else: 292 | n_iters = 10 293 | 294 | for _ in range(n_iters): 295 | 296 | self.fitted_Q_update() 297 | 298 | return env.sSol, episode_reward 299 | 300 | def neural_fitted_Q(self, env, n_episodes, tmax): 301 | ''' 302 | runs a whole neural fitted Q experiment 303 | 304 | Parameters: 305 | env: environment to train on 306 | n_episodes: number of episodes 307 | tmax: timesteps in each episode 308 | ''' 309 | 310 | times = [] 311 | for i in range(n_episodes): 312 | print() 313 | print('EPISODE', i) 314 | 315 | explore_rate = self.get_rate(i, 0, 1, 2.5) 316 | 317 | print('explore_rate:', explore_rate) 318 | env.reset() 319 | trajectory, reward = self.run_episode(env, explore_rate, tmax) 320 | 321 | time = len(trajectory) 322 | print('Time: ', time) 323 | times.append(time) 324 | 325 | print(times) 326 | 327 | def plot_rewards(self): 328 | ''' 329 | Plots the total reward gained in each episode on a matplotlib figure 330 | ''' 331 | plt.figure(figsize = (16.0,12.0)) 332 | 333 | 
plt.plot(self.episode_rewards) 334 | 335 | def save_results(self, save_path): 336 | ''' 337 | saves numpy arrays of results of training 338 | ''' 339 | np.save(save_path + '/survival_times', self.episode_lengths) 340 | np.save(save_path + '/episode_rewards', self.episode_rewards) 341 | 342 | def get_rate(self, episode, MIN_LEARNING_RATE, MAX_LEARNING_RATE, denominator): 343 | ''' 344 | Calculates the logarithmically decreasing explore or learning rate 345 | 346 | Parameters: 347 | episode: the current episode 348 | MIN_LEARNING_RATE: the minimum possible step size 349 | MAX_LEARNING_RATE: maximum step size 350 | denominator: controls the rate of decay of the step size 351 | Returns: 352 | step_size: the Q-learning step size 353 | ''' 354 | 355 | # input validation 356 | if not 0 <= MIN_LEARNING_RATE <= 1: 357 | raise ValueError("MIN_LEARNING_RATE needs to be between 0 and 1") 358 | 359 | if not 0 <= MAX_LEARNING_RATE <= 1: 360 | raise ValueError("MAX_LEARNING_RATE needs to be between 0 and 1") 361 | 362 | if not 0 < denominator: 363 | raise ValueError("denominator needs to be above 0") 364 | 365 | rate = max(MIN_LEARNING_RATE, min(MAX_LEARNING_RATE, 1.0 - math.log10((episode+1)/denominator))) 366 | 367 | return rate 368 | 369 | 370 | class KerasFittedQAgent(FittedQAgent): 371 | ''' 372 | Implementation of the neural network using keras 373 | ''' 374 | def __init__(self, layer_sizes = [2,20,20,4]): 375 | 376 | self.memory = [] 377 | self.layer_sizes = layer_sizes 378 | self.network = self.initialise_network(layer_sizes) 379 | self.gamma = 1. 380 | self.state_size = layer_sizes[0] 381 | self.n_actions = layer_sizes[-1] 382 | self.episode_lengths = [] 383 | self.episode_rewards = [] 384 | self.single_ep_reward = [] 385 | self.total_loss = 0 386 | self.values = [] 387 | self.actions = [] 388 | 389 | def initialise_network(self, layer_sizes): 390 | 391 | ''' 392 | Creates Q network for value function approximation 393 | ''' 394 | 395 | tf.keras.backend.clear_session() 396 | initialiser = keras.initializers.RandomUniform(minval = -0.5, maxval = 0.5, seed = None) 397 | positive_initialiser = keras.initializers.RandomUniform(minval = 0., maxval = 0.35, seed = None) 398 | regulariser = keras.regularizers.l1_l2(l1=0, l2=1e-6) 399 | network = keras.Sequential() 400 | network.add(keras.layers.InputLayer([layer_sizes[0]])) 401 | 402 | for l in layer_sizes[1:-1]: 403 | network.add(keras.layers.Dense(l, activation = tf.nn.relu)) 404 | network.add(keras.layers.Dense(layer_sizes[-1])) # linear output layer 405 | 406 | opt = keras.optimizers.Adam() 407 | network.compile(optimizer = opt, loss = 'mean_squared_error') # TRY DIFFERENT OPTIMISERS 408 | #try clipnorm=1 409 | return network 410 | 411 | def predict(self, state): 412 | ''' 413 | Predicts value estimates for each action base on currrent states 414 | ''' 415 | 416 | return self.network.predict(state) 417 | 418 | def fit(self, inputs, targets): 419 | ''' 420 | trains the Q network on a set of inputs and targets 421 | ''' 422 | 423 | #history = self.network.fit(inputs, targets, epochs = 500, batch_size = 256, verbose = False) used for nates system 424 | #history = self.network.fit(inputs, targets, epochs=200, batch_size=256, verbose=False) # used for single chemostat before time units error corrected 425 | history = self.network.fit(inputs, targets, validation_split = 0.01, epochs=20, batch_size=256, verbose = False) 426 | return history 427 | 428 | def reset_weights(self): 429 | ''' 430 | Reinitialises weights to random values 431 | ''' 432 | #sess = 
tf.keras.backend.get_session() 433 | #sess.run(tf.global_variables_initializer()) 434 | del self.network 435 | gc.collect() 436 | tf.keras.backend.clear_session() 437 | tf.compat.v1.reset_default_graph() 438 | self.network = self.initialise_network(self.layer_sizes) 439 | 440 | def save_network(self, save_path): 441 | ''' 442 | Saves current network weights 443 | ''' 444 | self.network.save(os.path.join(save_path, 'saved_network.h5')) 445 | 446 | def save_network_tensorflow(self, save_path): 447 | ''' 448 | Saves current network weights using pure tensorflow, kerassaver seems to crash sometimes 449 | ''' 450 | saver = tf.train.Saver() 451 | sess = tf.keras.backend.get_session() 452 | path = saver.save(sess, save_path + "/saved/model.cpkt") 453 | 454 | def load_network_tensorflow(self, save_path): 455 | ''' 456 | Loads network weights from file using pure tensorflow, kerassaver seems to crash sometimes 457 | ''' 458 | 459 | saver = tf.train.Saver() 460 | 461 | sess = tf.keras.backend.get_session() 462 | saver.restore(sess, save_path +"/saved/model.cpkt") 463 | 464 | def load_network(self, load_path): 465 | ''' 466 | Loads network weights from file 467 | ''' 468 | 469 | try: 470 | self.network = keras.models.load_model(os.path.join(load_path, 'saved_network.h5')) # sometimes this crashes, apparently a bug in keras 471 | except: 472 | print('EXCEPTION IN LOAD NETWORK') 473 | 474 | self.network.load_weights(os.path.join(load_path, 'saved_network.h5')) # this requires model to be initialised exactly the same 475 | -------------------------------------------------------------------------------- /RED/environments/OED_env.py: -------------------------------------------------------------------------------- 1 | from casadi import * 2 | import numpy as np 3 | 4 | import math 5 | import time 6 | 7 | 8 | def disablePrint(): 9 | sys.stdout = open(os.devnull, 'w') 10 | 11 | def enablePrint(): 12 | sys.stdout = sys.__stdout__ 13 | 14 | class OED_env(): 15 | ''' 16 | Class for OED for time course experiments on systems governed by differential equations 17 | ''' 18 | def __init__(self, x0, xdot, param_guesses, actual_params, n_observed_variables, n_controlled_inputs, num_inputs, input_bounds, dt, control_interval_time, normaliser): 19 | ''' 20 | initialises the environment 21 | :param x0: initial system state 22 | :param xdot: the governig differential equations 23 | :param param_guesses: inital parameter guesses 24 | :param actual_params: the actual system params 25 | :param n_observed_variables: number of variables that can be measured 26 | :param n_controlled_inputs: number of inputs that can be controlled 27 | :param num_inputs: number of discrete inputs for FQ 28 | :param input_bounds: the minimum and maximmum inputs 29 | :param dt: timestep for RK4 simulation 30 | :param control_interval_time: time between each control input 31 | :param normaliser: the normaliser for the RL observation 32 | ''' 33 | 34 | # build the reinforcement learning state 35 | 36 | self.n_system_variables = len(x0) 37 | self.FIMs = [] 38 | self.detFIMs = [] 39 | self.logdetFIMs = [] # so we dont have to multiply large eignvalues 40 | self.n_sensitivities = [] 41 | 42 | self.dt = dt 43 | self.control_interval_time = control_interval_time 44 | self.n_observed_variables = n_observed_variables 45 | self.initial_params = param_guesses 46 | self.param_guesses = param_guesses 47 | self.n_params = len(self.param_guesses.elements()) 48 | self.n_sensitivities = self.n_observed_variables * self.n_params 49 | self.n_FIM_elements = 
sum(range(self.n_params+1)) 50 | self.n_tot = self.n_system_variables + self.n_sensitivities + self.n_FIM_elements 51 | print(self.n_params, self.n_sensitivities, self.n_FIM_elements) 52 | print('n fim: ', self.n_FIM_elements) 53 | print('n_tot: ', self.n_tot) 54 | print('n_sense: ', self.n_sensitivities) 55 | self.x0 = x0 56 | self.n_controlled_inputs = n_controlled_inputs 57 | self.normaliser = normaliser 58 | self.initial_Y = DM([0] * (self.n_tot)) 59 | self.initial_Y[0:len(x0)] = x0 60 | self.Y = self.initial_Y 61 | 62 | #TODO: remove t his as too much memory 63 | self.Ys = [self.initial_Y.elements()] 64 | self.xdot = xdot # f(x, u, params) 65 | self.all_param_guesses = [] 66 | self.all_RL_states = [] 67 | self.us = [] 68 | self.actual_params = actual_params 69 | self.num_inputs = num_inputs 70 | self.input_bounds = np.array(input_bounds) 71 | self.current_tstep = 0 # to keep track of time in parallel 72 | self.CI_solver = self.get_control_interval_solver(control_interval_time, dt) 73 | 74 | def reset(self, partial = False): 75 | 76 | ''' 77 | resets the environment between episodes 78 | :param partial: only reset the FIM elements 79 | :return: 80 | ''' 81 | self.param_guesses = self.initial_params 82 | if partial: 83 | for i in range(self.Y[self.n_system_variables:, :].size()[1]): 84 | self.Y[self.n_system_variables:, i] = self.initial_Y[self.n_system_variables:] 85 | else: 86 | self.Y = self.initial_Y 87 | self.FIMs = [] 88 | self.detFIMs = [] 89 | self.logdetFIMs =[] 90 | self.us = [] 91 | self.true_trajectory = [] 92 | self.est_trajectory = [] 93 | self.current_tstep = 0 94 | 95 | def G(self, Y, theta, u): 96 | ''' 97 | Uses the system equations to setup the full derivatives of system variables plus the FIM 98 | :param Y: system state 99 | :param theta: parameters 100 | :param u: inputs 101 | :return RHS: the full system of derivatives 102 | ''' 103 | 104 | 105 | RHS = SX.sym('RHS', len(Y.elements())) 106 | 107 | # xdot = (sym_theta[0] * sym_u/(sym_theta[1] + sym_u))*sym_Y[0] 108 | 109 | dx = self.xdot(Y, theta,u) 110 | 111 | 112 | sensitivities_dot = jacobian(dx[0:self.n_observed_variables], theta) + mtimes(jacobian(dx[0:self.n_observed_variables], Y[0:self.n_observed_variables]), jacobian(Y[0:self.n_observed_variables], theta)) 113 | 114 | #TODO: dont need this as parameters not dimensioned and helps FIM not become nan 115 | 116 | for i in range(sensitivities_dot.size()[0]): # logarithmic sensitivities 117 | sensitivities_dot[i, :] *= (fabs(theta.T)+1e-5) # absolute value becuase we have negative params 118 | 119 | 120 | 121 | std = 0.05 * Y[0:self.n_observed_variables] # to stop divde by zero when conc = 0 122 | 123 | 124 | inv_sigma = SX.sym('sig', self.n_observed_variables, self.n_observed_variables) # sigma matrix in Nates paper 125 | 126 | for i in range(self.n_observed_variables): 127 | for j in range(self.n_observed_variables): 128 | 129 | if i == j: 130 | inv_sigma[i, j] = 1/(std[i] * Y[i]) 131 | else: 132 | inv_sigma[i, j] = 0 133 | 134 | sensitivities = reshape(Y[self.n_system_variables:self.n_system_variables + self.n_params *self.n_observed_variables], 135 | (self.n_observed_variables, self.n_params)) 136 | FIM_dot = mtimes(transpose(sensitivities), mtimes(inv_sigma, sensitivities)) 137 | FIM_dot = self.get_unique_elements(FIM_dot) 138 | 139 | RHS[0:self.n_system_variables] = dx 140 | sensitivities_dot = reshape(sensitivities_dot, (sensitivities_dot.size(1) * sensitivities_dot.size(2), 1)) 141 | RHS[self.n_system_variables:self.n_system_variables + 
self.n_sensitivities] = sensitivities_dot 142 | 143 | RHS[self.n_system_variables + self.n_sensitivities:] = FIM_dot 144 | 145 | return RHS 146 | 147 | def get_one_step_RK(self, theta, u, dt, mode = 'OED'): 148 | ''' 149 | create the function that performs one step of RK4 150 | :param theta: parameters 151 | :param u: inputs 152 | :param dt: timestep 153 | :param mode: switch between OED and just simulating the system with no FIM 154 | :return G_1: the casadi function that performs one step of RK4 155 | ''' 156 | 157 | if mode == 'OED': 158 | Y = SX.sym('Y', self.n_tot) 159 | RHS = self.G(Y, theta, u) 160 | else: 161 | Y = SX.sym('Y', self.n_system_variables) 162 | RHS = self.xdot(Y, theta,u) 163 | 164 | 165 | g = Function('g', [Y, theta, u], [RHS]) 166 | 167 | Y_input = SX.sym('Y_input', RHS.shape[0]) 168 | 169 | k1 = g(Y_input, theta, u) 170 | 171 | 172 | k2 = g(Y_input + dt / 2.0 * k1, theta, u) 173 | k3 = g(Y_input + dt / 2.0 * k2, theta, u) 174 | k4 = g(Y_input + dt * k3, theta, u) 175 | 176 | Y_output = Y_input + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4) 177 | 178 | G_1 = Function('G_1', [Y_input, theta, u], [Y_output]) 179 | return G_1 180 | 181 | def get_control_interval_solver(self, control_interval_time, dt, mode = 'OED'): 182 | ''' 183 | creates the function that performs simulation of one control interval 184 | :param control_interval_time: 185 | :param dt: finite different timestep 186 | :param mode: switch between OED and just simulation without the FIM 187 | :return G: function that simulates a single control interval 188 | ''' 189 | 190 | # TODO: try mapaccum in here to reduce memory usage 191 | 192 | #theta = SX.sym('theta', len(self.actual_params.elements())) used for the chemostat OED 193 | 194 | theta = SX.sym('theta', self.actual_params.size()) 195 | 196 | u = SX.sym('u', self.n_controlled_inputs) 197 | 198 | G_1 = self.get_one_step_RK(theta, u, dt, mode = mode) # pass theta and u in just in case# 199 | 200 | 201 | if mode == 'OED': 202 | Y_0 = SX.sym('Y_0', self.n_tot) 203 | else: 204 | Y_0 = SX.sym('Y_0', self.n_system_variables) 205 | Y_iter = Y_0 206 | 207 | 208 | for i in range(int(control_interval_time / dt)): 209 | Y_iter = G_1(Y_iter, theta, u) 210 | 211 | G = Function('G', [Y_0, theta, u], [Y_iter]) 212 | 213 | #G = G_1.mapaccum('control_interval', int(control_interval_time / dt)) # should use less memory than the for loop.. 
This messes up the shape of action inputs
214 | return G
215 | 
216 | def get_sampled_trajectory_solver(self, N_control_intervals, control_interval_time, dt, mode = 'OED'):
217 | '''
218 | simulates a whole experiment and returns the observed measurements at the control intervals
219 | :param N_control_intervals: number of control inputs in the experiment
220 | :param control_interval_time: time between control inputs
221 | :param dt: finite difference timestep
222 | :param mode: switch between OED and just simulation without the FIM
223 | :return trajectory_solver: the casadi function that performs the simulation
224 | '''
225 | #CI_solver = self.get_control_interval_solver(control_interval_time, dt, mode = mode)
226 | 
227 | #opt = {'base':1}
228 | trajectory_solver = self.CI_solver.mapaccum('trajectory', N_control_intervals)
229 | 
230 | return trajectory_solver
231 | 
232 | def get_full_trajectory_solver(self, N_control_intervals, control_interval_time, dt):
233 | '''
234 | simulates a whole experiment and returns the full trajectory
235 | :param N_control_intervals: number of control inputs in the experiment
236 | :param control_interval_time: time between control inputs
237 | :param dt: finite difference timestep
238 | :param mode: switch between OED and just simulation without the FIM
239 | :return trajectory_solver: the casadi function that performs the simulation
240 | '''
241 | # need to expand the us before putting into this solver
242 | theta = SX.sym('theta', len(self.actual_params.elements()))
243 | u = SX.sym('u', self.n_controlled_inputs)
244 | G = self.get_one_step_RK(theta, u, dt)
245 | 
246 | trajectory_solver = G.mapaccum('trajectory', int(N_control_intervals * control_interval_time / dt))
247 | return trajectory_solver
248 | 
249 | def gauss_newton(self, e, nlp, V, max_iter = 3000, limited_mem = False):
250 | '''
251 | creates a Gauss-Newton solver
252 | :param e: objective to minimise
253 | :param nlp: the non-linear program
254 | :param V: the inputs to optimise wrt
255 | :param max_iter: maximum number of iterations
256 | :param limited_mem: an approximation to reduce memory usage
257 | :return: solver
258 | '''
259 | J = jacobian(e,V)
260 | print('jacobian init')
261 | H = triu(mtimes(J.T, J))
262 | print('hessian init')
263 | sigma = SX.sym("sigma")
264 | hessLag = Function('nlp_hess_l',{'x':V,'lam_f':sigma, 'hess_gamma_x_x':sigma*H},
265 | ['x','p','lam_f','lam_g'], ['hess_gamma_x_x'],
266 | dict(jit=False, compiler='clang', verbose = False))
267 | print('hesslag init')
268 | 
269 | #IPOPT options https://coin-or.github.io/Ipopt/OPTIONS.html
270 | #return nlpsol("solver","ipopt", nlp, dict(ipopt={'max_iter':20}, hess_lag=hessLag, jit=False, compiler='clang', verbose_init = False, verbose = False))
271 | 
272 | # using the limited memory hessian approximation for ipopt seems to make it unstable
273 | ipopt_opt = {'max_iter': max_iter}
274 | if limited_mem:
275 | ipopt_opt['hessian_approximation'] = 'limited-memory'
276 | return nlpsol("solver","ipopt", nlp, dict(ipopt = ipopt_opt, hess_lag=hessLag, jit=False, compiler='clang', verbose_init=False, verbose=False))
277 | #'acceptable_tol':10, 'acceptable_iter':30,'s_max':1e10, 'obj_scaling_factor': 1e5
278 | #return nlpsol("solver","ipopt", nlp, dict(ipopt={'hessian_approximation':'limited_memory'}))
279 | 
280 | 
281 | 
282 | def get_u_solver(self):
283 | '''
284 | optimises the next input to maximise the det(FIM)
285 | :return: solver
286 | '''
287 | 
288 | 
289 | 
290 | u = SX.sym('u', self.n_controlled_inputs)
291 | trajectory_solver = 
self.get_sampled_trajectory_solver(len(self.us) + 1, self.control_interval_time, self.dt) 292 | # self.past_trajectory_solver = self.get_trajectory_solver(self.xdot, len(self.us)) 293 | 294 | all_us = SX.sym('all_us', (len(self.us) + 1, self.n_controlled_inputs)) 295 | 296 | print('all', all_us.shape) 297 | print('us',self.us) 298 | all_us[0: len(self.us), :] = np.array(self.us).reshape(-1, self.n_controlled_inputs) 299 | all_us[-1, :] = u 300 | 301 | est_trajectory = trajectory_solver(self.initial_Y, self.param_guesses, transpose(all_us)) 302 | 303 | FIM = self.get_FIM(est_trajectory) 304 | 305 | # past_trajectory = self.past_trajectory_solver(self.initial_Y, self.us, self.param_guesses) 306 | # current_FIM = self.get_FIM(past_trajectory) 307 | 308 | q,r = qr(FIM) 309 | 310 | obj = -trace(log(r)) 311 | #obj = -log(det(FIM)) 312 | nlp = {'x': u, 'f': obj} 313 | solver = self.gauss_newton(obj, nlp, u) 314 | #solver.print_options() 315 | #sys.exit() 316 | 317 | return solver # , current_FIM 318 | 319 | def get_param_solver(self, trajectory_solver, test_trajectory=None, initial_Y = None): 320 | ''' 321 | creates the solver to fit the params 322 | :param trajectory_solver: the solver for the trajectory given the params 323 | :param test_trajectory: the true trajectory 324 | :return: parameter solver 325 | ''' 326 | sym_theta = SX.sym('theta', len(self.param_guesses.elements())) 327 | 328 | if initial_Y is None: 329 | initial_Y = self.initial_Y 330 | 331 | if test_trajectory is None: 332 | trajectory = trajectory_solver(DM(initial_Y), self.actual_params, np.array(self.us).T, mode = 'param') 333 | 334 | print('p did:', trajectory.shape) 335 | else: 336 | trajectory = test_trajectory 337 | print('p did:', trajectory.shape) 338 | 339 | est_trajectory_sym = trajectory_solver(DM(initial_Y), sym_theta, np.array(self.us).T) 340 | print('sym trajectory initialised') 341 | print('sym traj:', est_trajectory_sym.shape) 342 | print('traj:', trajectory.shape) 343 | 344 | e = (trajectory[0:self.n_observed_variables, :].T - est_trajectory_sym[0:self.n_observed_variables, :].T)/(0.05 * trajectory[0:self.n_observed_variables, :].T + 0.00000001) 345 | print('e shape:', e.shape) 346 | print(dot(e, e).shape) 347 | 348 | nlp = {'x': sym_theta, 'f': 0.5 * dot(e , 349 | e)} # weighted least squares 350 | print('nlp initialised') 351 | #solver = self.gauss_newton(e, nlp, sym_theta, max_iter = 100000) 352 | solver = self.gauss_newton(e, nlp, sym_theta) 353 | print('solver initialised') 354 | 355 | 356 | return solver 357 | 358 | def step(self, action = None, continuous = True, use_old_state = False, scaling = None): 359 | ''' 360 | performs one RL ste 361 | :param action: 362 | :param continuous: 363 | :param use_old_state: 364 | :return: state, action, reward, done 365 | :param scaling: scaling function for the actions from the RL agent e.g. 
convert from log scale 366 | ''' 367 | 368 | self.current_tstep += 1 369 | if action is None: # Traditional OED step 370 | u_solver = self.get_u_solver() 371 | #u = u_solver(x0=10**self.u0, lbx = 10**self.input_bounds[0], ubx = 10**self.input_bounds[1])['x'] 372 | u = u_solver(x0=self.u0, lbx = self.input_bounds[:,0], ubx = self.input_bounds[:,1])['x'] 373 | self.us.append(u.elements()) 374 | else: #RL step 375 | if not continuous: 376 | u = self.actions_to_inputs(action) 377 | else: 378 | u = self.input_bounds[:, 0].reshape(-1, 1) + ( 379 | self.input_bounds[:, 1] - self.input_bounds[:, 0]).reshape(-1, 1) * action 380 | 381 | if scaling is not None: 382 | u = scaling(u) 383 | self.us.append(u[0]) 384 | 385 | N_control_intervals = len(self.us) 386 | #N_control_intervals = 12 387 | sampled_trajectory_solver = self.get_sampled_trajectory_solver(N_control_intervals, self.control_interval_time, self.dt) # the sampled trajectory seen by the agent 388 | 389 | 390 | #trajectory_solver = self.get_full_trajectory_solver(N_control_intervals, control_interval_time, self.dt) # the true trajectory of the system 391 | #trajectory_solver = trajectory_solver(N_control_intervals, control_interval_time, dt ) #this si the symbolic trajectory 392 | t = time.time() 393 | self.true_trajectory = sampled_trajectory_solver(self.initial_Y, self.actual_params, np.array(self.us).T) 394 | #self.est_trajectory = sampled_trajectory_solver(self.initial_Y, self.param_guesses, self.us ) 395 | 396 | #param_solver = self.get_param_solver(sampled_trajectory_solver) 397 | # estimate params based on whole trajectory so far 398 | #disablePrint() 399 | #self.param_guesses = param_solver(x0=self.param_guesses, lbx = 0)['x'] 400 | #enablePrint() 401 | #self.all_param_guesses.append(self.param_guesses.elements()) 402 | 403 | #reward = self.get_reward(self.est_trajectory) 404 | 405 | reward = self.get_reward(self.true_trajectory) 406 | 407 | done = False 408 | 409 | #state = self.get_RL_state(self.true_trajectory, self.est_trajectory) 410 | 411 | state = self.get_RL_state(self.true_trajectory, self.true_trajectory, use_old_state = use_old_state) 412 | 413 | 414 | 415 | self.all_RL_states.append(state) 416 | return state, reward, done, None 417 | 418 | 419 | def get_reward(self, est_trajectory): 420 | ''' 421 | calculates the reward for an RL agent 422 | :param est_trajectory: 423 | :return: reward 424 | ''' 425 | FIM = self.get_FIM(est_trajectory) 426 | 427 | #use this method to remove the small negatvie eigenvalues 428 | 429 | # casadi QR seems better,gives same results as np but some -ves in different places and never gives -ve determinant 430 | q, r = qr(FIM) 431 | 432 | det_FIM = np.prod(diag(r).elements()) 433 | 434 | logdet_FIM = trace(log(r)).elements()[0] # do it like this to protect from numerical errors from multiplying large EVs 435 | 436 | if det_FIM <= 0: 437 | print('----------------------------------------smaller than 0') 438 | eigs = np.real(np.linalg.eig(FIM)[0]) 439 | eigs[eigs<0] = 0.00000000000000000000000001 440 | det_FIM = np.prod(eigs) 441 | logdet_FIM = np.log(det_FIM) 442 | 443 | self.FIMs.append(FIM) 444 | self.detFIMs.append(det_FIM) 445 | self.logdetFIMs.append(logdet_FIM) 446 | 447 | try: 448 | #reward = np.log(det_FIM-self.detFIMs[-2]) 449 | reward = logdet_FIM - self.logdetFIMs[-2] 450 | #print('det adfa: ', det_FIM) 451 | #print(det_FIM - self.detFIMs[-2]) 452 | except: 453 | print('return') 454 | reward = logdet_FIM 455 | 456 | if math.isnan(reward): 457 | pass 458 | print() 459 | print('nan reward, 
FIM might have negative determinant !!!!') 460 | 461 | reward = -100 462 | return reward/100 463 | 464 | 465 | def action_to_input(self,action): 466 | ''' 467 | Takes a discrete action index and returns the corresponding continuous state 468 | vector 469 | 470 | :param action: the descrete action 471 | :returns:action 472 | ''' 473 | 474 | # calculate which bucket each eaction belongs in 475 | 476 | buckets = np.unravel_index(action, [self.num_inputs] *self.n_controlled_inputs) 477 | 478 | # convert each bucket to a continuous state variable 479 | Cin = [] 480 | for r in buckets: 481 | Cin.append(self.input_bounds[0] + r*(self.input_bounds[1]-self.input_bounds[0])/(self.num_inputs-1)) 482 | 483 | Cin = np.array(Cin).reshape(-1,1) 484 | 485 | return np.clip(Cin, self.input_bounds[0], self.input_bounds[1]) 486 | 487 | 488 | def get_FIM(self, trajectory): 489 | ''' 490 | assembles the FIM from an experimental trajectory 491 | :param trajectory: 492 | :return: FIM 493 | ''' 494 | 495 | # Tested on 2x2 and 5x5 matrices 496 | FIM_start = self.n_system_variables + self.n_sensitivities 497 | 498 | FIM_end = FIM_start + self.n_FIM_elements 499 | 500 | FIM_elements = trajectory[FIM_start:FIM_end, -1] 501 | 502 | start = 0 503 | end = self.n_params 504 | # FIM_elements = np.array([11,12,13,14,15,22,23,24,25,33,34,35,44,45,55]) for testing 505 | FIM = reshape(FIM_elements[start:end], (1, self.n_params)) # the first row 506 | 507 | for i in range(1, self.n_params): # for each row 508 | start = end 509 | end = start + self.n_params - i 510 | 511 | # get the first n_params - i elements 512 | row = FIM_elements[start:end] 513 | 514 | # get the other i elements 515 | 516 | for j in range(i - 1, -1, -1): 517 | row = horzcat(FIM[j, i], reshape(row, (1, -1))) 518 | 519 | reshape(row, (1, self.n_params)) # turn to row ector 520 | 521 | FIM = vertcat(FIM, row) 522 | 523 | #sys.exit() 524 | return FIM 525 | 526 | def get_unique_elements(self, FIM): 527 | ''' 528 | gets the unique elements of the FIM 529 | :param FIM: 530 | :return: unique elements 531 | ''' 532 | 533 | n_unique_els = sum(range(self.n_params + 1)) 534 | 535 | UE = SX.sym('UE', n_unique_els) 536 | start = 0 537 | end = self.n_params 538 | for i in range(self.n_params): 539 | UE[start:end] = transpose(FIM[i, i:]) 540 | start = end 541 | end += self.n_params - i - 1 542 | 543 | return UE 544 | 545 | def normalise_RL_state(self, state): 546 | #print(state) 547 | 548 | return state / self.normaliser 549 | 550 | def get_RL_state(self, true_trajectory, est_trajectory, use_old_state = False, use_time = True): 551 | ''' 552 | from a trajectory assemble the RL state 553 | :param true_trajectory: experimetnal trajectory 554 | :param est_trajectory: estimated trajectory, if not doing iterative inference this should be true_trajectory 555 | :param use_old_state: 556 | :param use_time: 557 | :return: RL state 558 | ''' 559 | 560 | # get the current measured system state 561 | sys_state = true_trajectory[:self.n_observed_variables, -1] # TODO: measurement noise 562 | 563 | if use_old_state: 564 | state = sys_state 565 | else: 566 | state = np.sqrt(sys_state) 567 | # get current fim elements 568 | FIM_start = self.n_system_variables + self.n_sensitivities 569 | 570 | FIM_end = FIM_start + self.n_FIM_elements 571 | 572 | # FIM_elements = true_trajectory[FIM_start:FIM_end] 573 | FIM_elements = est_trajectory[FIM_start:FIM_end, -1] 574 | 575 | FIM_signs = np.sign(FIM_elements) 576 | FIM_elements = FIM_signs * sqrt(fabs(FIM_elements)) 577 | 578 | if use_old_state: 
579 | state = np.append(sys_state, np.append(self.param_guesses, FIM_elements))
580 | 
581 | if use_time:
582 | state = np.append(state, self.current_tstep)
583 | else:
584 | state = np.append(state, 0)
585 | 
586 | #state = np.append(state, self.logdetFIMs[-1])
587 | 
588 | return self.normalise_RL_state(state)
589 | 
590 | def get_initial_RL_state(self, use_old_state = False):
591 | '''
592 | create the initial RL state for the beginning of an episode
593 | :param use_old_state:
594 | :return: initial state
595 | '''
596 | if use_old_state:
597 | state = np.array(list(np.sqrt(self.x0[0:self.n_observed_variables])) + self.param_guesses.elements() + [0] * self.n_FIM_elements)
598 | else:
599 | state = np.array(list(np.sqrt(self.x0[0:self.n_observed_variables])))
600 | state = np.append(state, 0) #time
601 | #state = np.append(state, 0) #logdetFIM
602 | 
603 | return self.normalise_RL_state(state)
604 | 
605 | def get_initial_RL_state_parallel(self, o0 = None, use_old_state = False, i = 0):
606 | '''
607 | create the initial RL state for the beginning of an episode
608 | :param o0: initial state of the system observables
609 | :param use_old_state:
610 | :param i: the index of the parallel experiment
611 | :return: initial RL state
612 | '''
613 | 
614 | #state = np.array(list(np.sqrt(self.x0[0:self.n_observed_variables])) + self.param_guesses[i,:].elements() + [0] * self.n_FIM_elements)
615 | if o0 is None:
616 | o0 = self.x0
617 | 
618 | 
619 | if use_old_state:
620 | state = np.array(list(np.sqrt(o0[0:self.n_observed_variables])) + self.param_guesses[i,:].elements() + [
621 | 0] * self.n_FIM_elements)
622 | else:
623 | state = np.array(list(np.sqrt(o0[0:self.n_observed_variables])))
624 | 
625 | 
626 | state = np.append(state, 0) #time
627 | #state = np.append(state, 0) #logdetFIM
628 | 
629 | return self.normalise_RL_state(state)
630 | 
631 | def map_parallel_step(self, actions, actual_params, continuous = False, Ds = False, use_old_state=False, use_time=True, scaling = None):
632 | '''
633 | runs the step in parallel using the casadi map function
634 | :param actions: actions for all the parallel experiments
635 | :param actual_params: the parameters for each experiment
636 | :param continuous:
637 | :param Ds: use Ds design
638 | :param use_old_state:
639 | :param use_time: add time to the state
640 | :param scaling: scaling function for the actions from the RL agent e.g. convert from log scale
641 | :return: the transitions of all parallel experiments
642 | '''
643 | self.current_tstep += 1
644 | # actions, actual_params = args
645 | 
646 | # all_us = []
647 | # for As in actions:
648 | # us = [self.action_to_input(action) for action in actions]
649 | if not continuous:
650 | us = self.actions_to_inputs(actions)
651 | else:
652 | us = self.input_bounds[:, 0].reshape(-1, 1) + (self.input_bounds[:,1] - self.input_bounds[:, 0]).reshape(-1, 1)*actions
653 | 
654 | if scaling is not None:
655 | us = scaling(us)
656 | 
657 | actual_params = DM(actual_params)
658 | 
659 | N_control_intervals = len(us)
660 | 
661 | # set sampled trajectory solver in script to ensure thread safety
662 | true_trajectories = self.mapped_trajectory_solver(self.Y, actual_params.T, np.array(us))
663 | transitions = []
664 | 
665 | for i in range(true_trajectories.shape[1]):
666 | true_trajectory = true_trajectories[:, i]
667 | reward = self.get_reward_parallel(true_trajectory, i, Ds = Ds)
668 | 
669 | done = False
670 | 
671 | # state = self.get_RL_state(self.true_trajectory, self.est_trajectory)
672 | state = self.get_RL_state_parallel(true_trajectory, true_trajectory, i, use_old_state = use_old_state, use_time=use_time)
673 | 
674 | 
675 | transitions.append((state, reward, done, None, us[:,i]))
676 | 
677 | self.Y = true_trajectories
678 | self.Ys.append(self.Y.elements())
679 | return transitions
680 | 
681 | def actions_to_inputs(self, actions):
682 | '''
683 | PARALLEL action to input
684 | 
685 | Takes a discrete action index and returns the corresponding continuous state
686 | vector
687 | 
688 | :param action: the discrete action
689 | :returns: action
690 | '''
691 | 
692 | # calculate which bucket each action belongs in
693 | 
694 | buckets = np.unravel_index(actions, [self.num_inputs] * self.n_controlled_inputs)
695 | buckets = np.array(buckets)
696 | # convert each bucket to a continuous state variable
697 | # TODO: make this work with multiple different input bounds
698 | Cin = self.input_bounds[0][0] + buckets * (self.input_bounds[0][1] - self.input_bounds[0][0]) / (self.num_inputs - 1)
699 | 
700 | return np.clip(Cin, self.input_bounds[0][0], self.input_bounds[0][1])
701 | 
702 | def get_reward_parallel(self, est_trajectory, i, Ds = False):
703 | '''
704 | parallel get reward
705 | :param est_trajectory:
706 | :param i: parallel index
707 | :param Ds: use Ds optimality
708 | :return: reward
709 | '''
710 | FIM = self.get_FIM(est_trajectory)
711 | if Ds: # partition FIM and get determinant of the params we are interested in (the elements of LV matrix)
712 | M11 = FIM[0:-4, 0:-4]
713 | q11, r11 = qr(M11)
714 | 
715 | logdet_M11 = trace(log(r11)).elements()[0] # do it like this to protect from numerical errors from multiplying large EVs
716 | q, r = qr(FIM)
717 | det_FIM = np.prod(diag(r).elements())
718 | logdet_FIM = trace(log(r)).elements()[0] # do it like this to protect from numerical errors from multiplying large EVs
719 | logdet_FIM -= logdet_M11
720 | else:
721 | # use this method to remove the small negative eigenvalues
722 | 
723 | # casadi QR seems better, gives same results as np but some -ves in different places and never gives -ve determinant
724 | q, r = qr(FIM)
725 | det_FIM = np.prod(diag(r).elements())
726 | logdet_FIM = trace(log(r)).elements()[0] # do it like this to protect from numerical errors from multiplying large EVs
727 | if det_FIM <= 0:
728 | print('----------------------------------------smaller than 0')
729 | eigs = np.real(np.linalg.eig(FIM)[0])
730 | eigs[eigs 
< 0] = 0.00000000000000000000000001 731 | det_FIM = np.prod(eigs) 732 | logdet_FIM = np.log(det_FIM) 733 | 734 | print(det_FIM) 735 | 736 | self.detFIMs[i].append(det_FIM) 737 | self.logdetFIMs[i].append(logdet_FIM) 738 | 739 | try: 740 | reward = logdet_FIM - self.logdetFIMs[i][-2] 741 | 742 | except: 743 | 744 | reward = logdet_FIM 745 | 746 | if math.isnan(reward): 747 | pass 748 | print() 749 | print('nan reward, FIM might have negative determinant !!!!') 750 | print(logdet_FIM, self.logdetFIMs) 751 | 752 | reward = -100 753 | 754 | return reward/100 755 | 756 | def get_RL_state_parallel(self, true_trajectory, est_trajectory,i, use_old_state = False, use_time = True): 757 | ''' 758 | 759 | parallel get Rl state 760 | :param true_trajectory: 761 | :param est_trajectory: 762 | :param i: parallel index 763 | :param use_old_state: 764 | :param use_time: 765 | :return: state 766 | ''' 767 | 768 | # get the current measured system state 769 | 770 | 771 | sys_state = true_trajectory[:self.n_observed_variables, -1] # TODO: measurement noise 772 | 773 | 774 | if use_old_state: 775 | state = np.sqrt(sys_state) 776 | else: 777 | state = np.sqrt(sys_state) 778 | 779 | 780 | # get current fim elements 781 | FIM_start = self.n_system_variables + self.n_sensitivities 782 | 783 | FIM_end = FIM_start + self.n_FIM_elements 784 | 785 | # FIM_elements = true_trajectory[FIM_start:FIM_end] 786 | FIM_elements = est_trajectory[FIM_start:FIM_end, -1] 787 | 788 | FIM_signs = np.sign(FIM_elements) 789 | FIM_elements = FIM_signs * sqrt(fabs(FIM_elements)) 790 | 791 | 792 | 793 | if use_old_state: 794 | state = np.append(state, np.append(self.param_guesses[i,:], FIM_elements)) 795 | 796 | 797 | if use_time: 798 | state = np.append(state, self.current_tstep) 799 | else: 800 | state = np.append(state, 0) 801 | 802 | 803 | #state = np.append(state, self.logdetFIMs[i][-1]) 804 | 805 | 806 | return self.normalise_RL_state(state) 807 | 808 | -------------------------------------------------------------------------------- /RED/agents/continuous_agents/rt3d.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gc 3 | import math 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.optim import Adam 11 | 12 | from RED.utils.network import NeuralNetwork 13 | 14 | 15 | class RT3D_agent(): 16 | ''' 17 | Class that implements the Recurrent Twin Delayed Deep Deterministic Policy Gradient agent (RT3D). 18 | ''' 19 | def __init__( 20 | self, 21 | val_module_specs, 22 | pol_module_specs, 23 | val_learning_rate=0.001, 24 | pol_learning_rate=0.001, 25 | batch_size=256, 26 | action_bounds=[0, 1], 27 | noise_bounds=[-0.25, 0.25], 28 | noise_std=0.1, 29 | gamma=1, 30 | polyak=0.995, 31 | mem_size=500_000_000, 32 | max_length=11, 33 | device="cuda" if torch.cuda.is_available() else "cpu", 34 | ): 35 | ''' 36 | Initialises the RT3D agent. 
37 | :param val_module_specs: module specifications for the Q networks 38 | - last module gets the concatenated output of all the previous modules 39 | - list of dictionaries, each dictionary containing keys "input_size" and "layers" 40 | - "input_size" is the input size of the module 41 | - "layers" is a list of dictionaries, each containing "layer_type" and other key-value pairs, depending on the layer type: 42 | - "GRU" - "hidden_size", "num_layers" 43 | - "Linear" - "output_size" 44 | - "Lambda" - "lambda_expression" 45 | - Additional key-value pairs: 46 | - "activation" for the activation function which should be applied after the layer 47 | :param pol_module_specs: module specifications for the policy networks 48 | - last module gets the concatenated output of all the previous modules 49 | - list of dictionaries, each dictionary containing keys "input_size" and "layers" 50 | - "input_size" is the input size of the module 51 | - "layers" is a list of dictionaries, each containing "layer_type" and other key-value pairs, depending on the layer type: 52 | - "GRU" - "hidden_size", "num_layers" 53 | - "Linear" - "output_size" 54 | - "Lambda" - "lambda_expression" 55 | - Additional key-value pairs: 56 | - "activation" for the activation function which should be applied after the layer 57 | :param val_learning_rate: learning rate for the value networks 58 | :param pol_learning_rate: learning rate for the policy network 59 | :param batch_size: batch size for training the networks 60 | :param action_bounds: bounds for the actions 61 | :param noise_bounds: bounds for the noise added to the actions 62 | :param noise_std: standard deviation of the noise added to the actions 63 | :param gamma: discount rate 64 | :param polyak: polyak averaging rate (see method update_target_networks) 65 | :param mem_size: size of the replay buffer 66 | :param max_length: max sequence length 67 | :param device: device to use for pytorch operations 68 | ''' 69 | self.val_module_specs = val_module_specs 70 | self.pol_module_specs = pol_module_specs 71 | self.val_learning_rate = val_learning_rate 72 | self.pol_learning_rate = pol_learning_rate 73 | self.batch_size = batch_size 74 | self.action_bounds= action_bounds 75 | self.noise_bounds = noise_bounds 76 | self.noise_std = noise_std 77 | self.gamma = gamma 78 | self.polyak = polyak 79 | self.mem_size = mem_size 80 | self.max_length = max_length 81 | self.device = device 82 | 83 | ### initialise policy networks 84 | # policy network (base) 85 | self.policy_network = nn.ModuleList([ 86 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 87 | for mod_specs in pol_module_specs 88 | ]) 89 | self.seq_size = self.policy_network[0].input_size 90 | self.policy_out_size = self.policy_network[-1].output_size 91 | self.policy_network_opt = Adam(self.policy_network.parameters(), lr=self.pol_learning_rate) 92 | self.policy_network_loss = lambda predicted_action_values: -torch.mean(predicted_action_values) 93 | # policy network (target) 94 | self.policy_target = nn.ModuleList([ 95 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 96 | for mod_specs in pol_module_specs 97 | ]) 98 | 99 | ### initialise value networks 100 | # Q-value network 1 (base) 101 | self.Q1_network = nn.ModuleList([ 102 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 103 | for mod_specs in val_module_specs 104 | 
]) 105 | self.Q1_network_opt = Adam(self.Q1_network.parameters(), lr=val_learning_rate) 106 | self.Q1_network_loss = nn.MSELoss() 107 | # Q-value network 1 (target) 108 | self.Q1_target = nn.ModuleList([ 109 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 110 | for mod_specs in val_module_specs 111 | ]) 112 | 113 | # Q-value network 2 (base) 114 | self.Q2_network = nn.ModuleList([ 115 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 116 | for mod_specs in val_module_specs 117 | ]) 118 | self.Q2_network_opt = Adam(self.Q2_network.parameters(), lr=val_learning_rate) 119 | self.Q2_network_loss = nn.MSELoss() 120 | # Q-value network 2 (target) 121 | self.Q2_target = nn.ModuleList([ 122 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 123 | for mod_specs in val_module_specs 124 | ]) 125 | 126 | ### initialise buffers 127 | self.memory = [] 128 | self.values = [] 129 | self.states = [] 130 | self.next_states = [] 131 | self.actions = [] 132 | self.rewards = [] 133 | self.dones = [] 134 | self.sequences = [] 135 | self.next_sequences = [] 136 | self.all_returns = [] 137 | 138 | def initialise_network(self, input_size, layer_specs, init_optimizer=False): 139 | ''' 140 | Initialises a neural network. 141 | :param input_size: input size of the network 142 | :param layer_specs: a list of dictionaries, each containing "layer_type" + other key-value pairs, depending on the layer type: 143 | - "GRU" - "hidden_size", "num_layers" 144 | - "Linear" - "output_size" 145 | - "Lambda" - "lambda_expression" 146 | - "lambda_expression" must contain a python (lambda) function which takes a tensor as input and returns a tensor as output 147 | - examples (both are valid - string is converted to lambda expression): 148 | "lambda_expression": lambda x: x * 1 149 | "lambda_expression": "lambda x: x * 1" 150 | "lambda_expression": """ 151 | def f(x): 152 | return x * 1 153 | """ 154 | - Additional key-value pairs: 155 | - "activation" for the activation function which should be applied after the layer 156 | :return: initialised neural network 157 | ''' 158 | 159 | network = NeuralNetwork( 160 | input_size=input_size, 161 | layer_specs=layer_specs, 162 | init_optimizer=init_optimizer, 163 | device=self.device, 164 | ) 165 | return network 166 | 167 | def get_actions_dist(self, inputs, explore_rate, test_episode=False, recurrent=True): 168 | ''' 169 | Gets actions by adding random noise to the actions. 
170 | :param inputs: inputs for policy network 171 | :param explore_rate: probability of taking a random action 172 | :param test_episode: whether the episode is a test episode 173 | :param recurrent: whether to use recurrent network 174 | :return: actions 175 | ''' 176 | 177 | actions = self.forward_policy_net( 178 | policy_net=self.policy_network, 179 | inputs=inputs, 180 | recurrent=recurrent 181 | ) 182 | 183 | actions = actions.detach().cpu().numpy() 184 | 185 | if test_episode: 186 | actions[:-1] += np.random.normal(0, explore_rate, size=actions[:-1].shape) 187 | else: 188 | actions += np.random.normal(0, explore_rate, size=actions.shape) 189 | 190 | actions = np.clip(actions, self.action_bounds[0], self.action_bounds[1]) 191 | 192 | return actions 193 | 194 | def get_actions(self, inputs, explore_rate, test_episode=False, recurrent=True, return_exploit_flags=False): 195 | ''' 196 | Gets a mix of explore/exploit actions between the min and max bounds (exploration with probablilty explore_rate uniformly distributed). 197 | :param inputs: inputs for policy network 198 | :param explore_rate: probability of taking a random action 199 | :param test_episode: whether the episode is a test episode 200 | :param recurrent: whether to use recurrent network 201 | :return: actions 202 | ''' 203 | 204 | if recurrent: 205 | states, sequences = inputs # states [batch, features], sequences [batch, sequence, features] 206 | else: 207 | states = inputs[0] 208 | 209 | ### choose between explore/exploit 210 | if test_episode: 211 | rng = np.random.random(len(states) - 1) 212 | else: 213 | rng = np.random.random(len(states)) 214 | 215 | explore_inds = np.where(rng < explore_rate)[0] 216 | exploit_inds = np.where(rng >= explore_rate)[0] 217 | 218 | if test_episode: 219 | exploit_inds = np.append(exploit_inds, len(states) - 1) 220 | 221 | if return_exploit_flags: 222 | exploit_flags = np.zeros((len(states)), dtype=np.int32) 223 | exploit_flags[exploit_inds] = 1 224 | 225 | ### get actions 226 | actions = torch.zeros((len(states), self.policy_out_size), dtype=torch.float32, device=self.device) 227 | 228 | # explore actions (uniformly distributed between the action bounds) 229 | explore_actions = (self.action_bounds[1] - self.action_bounds[0]) \ 230 | * torch.rand((len(explore_inds), self.policy_out_size), dtype=torch.float32, device=self.device) \ 231 | + self.action_bounds[0] 232 | actions[explore_inds] = explore_actions 233 | 234 | # exploit actions (policy network) 235 | if len(exploit_inds) > 0: 236 | # prepare inputs 237 | if recurrent: 238 | policy_net_inputs = [np.array(states)[exploit_inds], np.array(sequences)[exploit_inds]] 239 | else: 240 | policy_net_inputs = [np.array(states)[exploit_inds]] 241 | # run through the policy network 242 | exploit_actions = self.forward_policy_net(policy_net=self.policy_network, inputs=policy_net_inputs, recurrent=recurrent) 243 | # add noise 244 | exploit_actions += torch.normal(0, explore_rate * self.noise_std * 2, size=exploit_actions.shape, device=self.device) 245 | actions[exploit_inds] = exploit_actions 246 | 247 | # clip 248 | actions = torch.clamp(actions, self.action_bounds[0], self.action_bounds[1]) 249 | 250 | actions = actions.cpu().detach().numpy() 251 | 252 | return actions if not return_exploit_flags else (actions, exploit_flags) 253 | 254 | def forward_policy_net(self, policy_net: nn.Module, inputs, recurrent=True): 255 | ''' 256 | Forward pass of the policy net given inputs. 
257 | :param inputs: inputs for policy network 258 | :param recurrent: whether to use the recurrent network 259 | :return: outputs of the given policy net 260 | ''' 261 | 262 | if recurrent: 263 | ### prepare inputs for pytorch's models 264 | states, sequences = inputs 265 | # states [batch, features] 266 | if type(states) != torch.Tensor: 267 | states = torch.tensor(states, dtype=torch.float32, device=self.device) 268 | # sequences [batch, sequence, features] 269 | if type(sequences) in (list, tuple): 270 | sequences = pad_sequence( 271 | [torch.tensor(seq[-self.max_length:], dtype=torch.float32) for seq in sequences], 272 | batch_first=True 273 | ).to(self.device) 274 | elif type(sequences) == np.ndarray: 275 | sequences = torch.tensor(sequences[:,-self.max_length:], dtype=torch.float32, device=self.device) 276 | elif type(sequences) != torch.Tensor: 277 | raise ValueError("Sequences must be a list, tuple, or numpy array") 278 | 279 | ### run through the policy network 280 | recurrent_out = policy_net[0](sequences)[:,-1,:] # last output of the sequence 281 | head_inps = torch.cat((states, recurrent_out), dim=1) 282 | policy_net_out = policy_net[1](head_inps) 283 | else: 284 | head_inps = torch.tensor(inputs[0], dtype=torch.float32, device=self.device) 285 | 286 | policy_net_out = policy_net[1](head_inps) 287 | 288 | return policy_net_out 289 | 290 | def forward_q_net(self, q_net: nn.Module, inputs, recurrent=True): 291 | ''' 292 | Forward pass of the given Q net with the inputs. 293 | :param inputs: inputs for Q network 294 | :param recurrent: whether to use the recurrent network 295 | :return: outputs of the given Q net 296 | ''' 297 | 298 | if recurrent: 299 | ### prepare inputs for pytorch's models 300 | state_actions, sequences = inputs 301 | # state-action pairs [batch, features] 302 | if type(state_actions) != torch.Tensor: 303 | state_actions = torch.tensor(state_actions, dtype=torch.float32, device=self.device) 304 | # sequences [batch, sequence, features] 305 | if type(sequences) in (list, tuple): 306 | sequences = pad_sequence( 307 | [torch.tensor(seq[-self.max_length:], dtype=torch.float32) for seq in sequences], 308 | batch_first=True 309 | ).to(self.device) 310 | elif type(sequences) == np.ndarray: 311 | sequences = torch.tensor(sequences[:,-self.max_length:], dtype=torch.float32, device=self.device) 312 | elif type(sequences) != torch.Tensor: 313 | raise ValueError("Sequences must be a list, tuple, or numpy array") 314 | 315 | ### run through the Q network 316 | recurrent_out = q_net[0](sequences)[:,-1,:] # last output of the sequence 317 | head_inps = torch.cat((state_actions, recurrent_out), dim=1) 318 | else: 319 | state_actions = inputs[0] 320 | head_inps = torch.tensor(state_actions, dtype=torch.float32, device=self.device) 321 | 322 | q_net_out = q_net[1](head_inps) 323 | 324 | return q_net_out 325 | 326 | def get_inputs_targets(self, recurrent=True, monte_carlo=False): 327 | ''' 328 | --- Use get_inputs_targets_low_mem() instead, this is *not* more efficient --- 329 | --- TODO: fix or remove --- 330 | assembles the Q learning inputs and trgets from agents memory 331 | :param recurrent: 332 | :param monte_carlo: 333 | :return: 334 | ''' 335 | 336 | ### collect the data 337 | sample_size = int(self.batch_size * 10) 338 | sample_idxs = np.random.randint(0, min(self.mem_size, len(self.memory)), size=(sample_size)) 339 | for i, trajectory in enumerate(self.memory): 340 | e_rewards = [] 341 | sequence = [[0]*self.seq_size] 342 | for j, transition in enumerate(trajectory): 
343 | state, action, reward, next_state, done = transition 344 | 345 | self.sequences.append(copy.deepcopy(sequence)) 346 | sequence.append(np.concatenate((state, action))) 347 | self.next_sequences.append(copy.deepcopy(sequence)) 348 | self.states.append(state) 349 | self.next_states.append(next_state) 350 | self.actions.append(action) 351 | self.rewards.append(reward) 352 | self.dones.append(done) 353 | e_rewards.append(reward) 354 | 355 | if monte_carlo: 356 | e_values = [e_rewards[-1]] 357 | 358 | for i in range(2, len(e_rewards) + 1): 359 | e_values.insert(0, e_rewards[-i] + e_values[0] * self.gamma) 360 | self.all_returns.extend(e_values) 361 | 362 | # remove items if agents memory is full 363 | if len(self.states) > self.mem_size: 364 | del self.sequences[:len(self.states)-self.mem_size] 365 | del self.next_sequences[:len(self.states)-self.mem_size] 366 | del self.next_states[:len(self.states)-self.mem_size] 367 | del self.actions[:len(self.states)-self.mem_size] 368 | del self.rewards[:len(self.states)-self.mem_size] 369 | del self.dones[:len(self.states)-self.mem_size] 370 | del self.states[:len(self.states) - self.mem_size] 371 | 372 | # TODO: this is really memory inefficient, take random sample before initialising arrays 373 | next_states = np.array(self.next_states, dtype=np.float64)[:self.mem_size] 374 | rewards = np.array(self.rewards).reshape(-1, 1)[:self.mem_size] 375 | dones = np.array(self.dones).reshape(-1, 1)[:self.mem_size] 376 | states = np.array(self.states)[:self.mem_size] 377 | actions = np.array(self.actions)[:self.mem_size] 378 | all_returns = np.array(self.all_returns)[:self.mem_size] 379 | sequences = self.sequences[:self.mem_size] 380 | next_sequences = self.next_sequences[:self.mem_size] 381 | 382 | self.memory = self.memory[-self.mem_size:] # reset memory after this information has been extracted 383 | 384 | if monte_carlo : # only take last experiences 385 | ''' 386 | batch_size = self.batch_size 387 | if states.shape[0] > batch_size: 388 | states = states[-batch_size:] 389 | padded = padded[-batch_size:] 390 | next_padded = next_padded[-batch_size:] 391 | next_states = next_states[-batch_size:] 392 | actions = actions[-batch_size:] 393 | rewards = rewards[-batch_size:] 394 | dones = dones[-batch_size:] 395 | all_returns = all_returns[-batch_size:] 396 | ''' 397 | 398 | targets = all_returns 399 | pass 400 | else: 401 | ### take random sample 402 | sample_size = int(self.batch_size * 10) 403 | sample_idxs = np.random.randint(max(0, states.shape[0] - self.mem_size), states.shape[0], size=(sample_size)) 404 | 405 | states = states[sample_idxs] 406 | next_states = next_states[sample_idxs] 407 | actions = actions[sample_idxs] 408 | rewards = rewards[sample_idxs] 409 | dones = dones[sample_idxs] 410 | sequences = [sequences[i] for i in sample_idxs] 411 | next_sequences = [next_sequences[i] for i in sample_idxs] 412 | 413 | ### get next actions from target policy 414 | next_actions = self.forward_policy_net( 415 | policy_net=self.policy_target, 416 | inputs=[next_states, next_sequences], 417 | recurrent=recurrent, 418 | ).cpu().detach().numpy() 419 | 420 | # target policy smoothing 421 | noise = np.clip(np.random.normal( 0, self.noise_std, next_actions.shape), self.noise_bounds[0], self.noise_bounds[1]) 422 | next_actions = np.clip(next_actions + noise, self.action_bounds[0], self.action_bounds[1]) 423 | 424 | ### get next values from target Q networks 425 | self.Q1_target.eval() 426 | with torch.no_grad(): 427 | Q1 = self.forward_q_net( 428 | 
q_net=self.Q1_target, 429 | inputs=[np.concatenate((next_states, next_actions), axis=1), next_sequences], 430 | recurrent=recurrent 431 | ).cpu().detach().numpy() 432 | 433 | self.Q2_target.eval() 434 | with torch.no_grad(): 435 | Q2 = self.forward_q_net( 436 | q_net=self.Q2_target, 437 | inputs=[np.concatenate((next_states, next_actions), axis=1), next_sequences], 438 | recurrent=recurrent 439 | ).cpu().detach().numpy() 440 | 441 | next_values = np.minimum(Q1, Q2) 442 | targets = rewards + self.gamma * (1 - dones) * next_values 443 | 444 | ### shuffle the data and construct the inputs and targets 445 | randomize = np.arange(len(states)) 446 | np.random.shuffle(randomize) 447 | states = states[randomize] 448 | actions = actions[randomize] 449 | sequences = [sequences[i] for i in randomize] 450 | targets = targets[randomize] 451 | inputs = [states, sequences] 452 | targets = targets[randomize] 453 | 454 | gc.collect() # clear old stuff from memory 455 | return inputs, actions, targets 456 | 457 | def get_inputs_targets_low_mem(self, recurrent=True, monte_carlo=False): 458 | ''' 459 | Assembles the Q learning inputs and targets from agent's memory, uses less memory but is slower. 460 | :param recurrent: whether to use the recurrent networks 461 | :param monte_carlo: 462 | :return: inputs, actions, targets 463 | ''' 464 | 465 | # TODO: enable all the options here 466 | self.memory = self.memory[-self.mem_size:] 467 | sequences = [] 468 | next_sequences = [] 469 | states = [] 470 | next_states = [] 471 | actions = [] 472 | rewards = [] 473 | dones = [] 474 | 475 | ### collect the data 476 | sample_size = int(self.batch_size * 10) 477 | sample_idxs = np.random.randint(0, min(self.mem_size, len(self.memory)), size=(sample_size)) 478 | for trajectory_idx in sample_idxs: 479 | trajectory = self.memory[trajectory_idx] 480 | sequence = [np.array([0] * self.seq_size)] 481 | for transition in trajectory: 482 | state, action, reward, next_state, done = transition 483 | 484 | sequences.append(copy.deepcopy(np.array(sequence))) 485 | sequence.append(np.concatenate((state, action))) 486 | next_sequences.append(copy.deepcopy(np.array(sequence))) 487 | states.append(state) 488 | next_states.append(next_state) 489 | actions.append(action) 490 | rewards.append(reward) 491 | dones.append(done) 492 | 493 | next_states = np.array(next_states, dtype=np.float64) 494 | rewards = np.array(rewards).reshape(-1, 1) 495 | dones = np.array(dones).reshape(-1, 1) 496 | states = np.array(states) 497 | actions = np.array(actions) 498 | 499 | ### get next actions from target policy 500 | next_actions = self.forward_policy_net( 501 | policy_net=self.policy_target, 502 | inputs=[next_states, next_sequences], 503 | recurrent=recurrent, 504 | ).cpu().detach().numpy() 505 | 506 | # target policy smoothing 507 | noise = np.clip(np.random.normal(0, self.noise_std, next_actions.shape), self.noise_bounds[0], self.noise_bounds[1]) 508 | next_actions = np.clip(next_actions + noise, self.action_bounds[0], self.action_bounds[1]) 509 | 510 | ### get next values from target Q networks 511 | self.Q1_target.eval() 512 | with torch.no_grad(): 513 | Q1 = self.forward_q_net( 514 | q_net=self.Q1_target, 515 | inputs=[np.concatenate((next_states, next_actions), axis=1), next_sequences], 516 | recurrent=recurrent 517 | ).cpu().detach().numpy() 518 | 519 | self.Q2_target.eval() 520 | with torch.no_grad(): 521 | Q2 = self.forward_q_net( 522 | q_net=self.Q2_target, 523 | inputs=[np.concatenate((next_states, next_actions), axis=1), next_sequences], 524 
| recurrent=recurrent 525 | ).cpu().detach().numpy() 526 | 527 | next_values = np.minimum(Q1, Q2) 528 | targets = rewards + self.gamma * (1 - dones) * next_values 529 | 530 | ### shuffle the data and construct the inputs and targets 531 | randomize = np.arange(len(states)) 532 | np.random.shuffle(randomize) 533 | states = states[randomize] 534 | actions = actions[randomize] 535 | sequences = [sequences[i] for i in randomize] 536 | targets = targets[randomize] 537 | inputs = [states, sequences] 538 | 539 | return inputs, actions, targets 540 | 541 | def get_rate(self, episode, min_rate, max_rate, denominator): 542 | ''' 543 | Calculates the logarithmically decreasing explore or learning rate. 544 | :param episode: the current episode 545 | :param min_rate: the minimum possible rate size 546 | :param max_rate: maximum rate size 547 | :param denominator: controls the rate of decay of the rate size 548 | :returns: the new rate size 549 | ''' 550 | 551 | # input validation 552 | if not 0 <= min_rate <= 1: 553 | raise ValueError("MIN_LEARNING_RATE needs to be bewteen 0 and 1") 554 | 555 | if not 0 <= max_rate <= 1: 556 | raise ValueError("MAX_LEARNING_RATE needs to be bewteen 0 and 1") 557 | 558 | if not 0 < denominator: 559 | raise ValueError("denominator needs to be above 0") 560 | 561 | rate = max(min_rate, min(max_rate, 1.0 - math.log10((episode + 1) / denominator))) 562 | return rate 563 | 564 | def train_q_net(self, q_net, inputs, targets, optimizer, criterion, epochs, batch_size=256, recurrent=True): 565 | ''' 566 | Trains the Q network on the given inputs and targets. 567 | :param q_net: the Q network to train 568 | :param inputs: the inputs to the Q network 569 | :param targets: the targets for the Q network 570 | :param optimizer: the optimizer to use 571 | :param criterion: the loss function to use 572 | :param epochs: the number of epochs to train for 573 | :param batch_size: the batch size to use 574 | :param recurrent: whether to use the recurrent networks 575 | :return: the trained Q network 576 | ''' 577 | 578 | if type(targets) in (list, tuple, np.ndarray): 579 | targets = torch.tensor(targets, dtype=torch.float32, device=self.device) 580 | 581 | batch_idxs = math.ceil(len(inputs[0]) / batch_size) 582 | 583 | for _ in range(epochs): 584 | # go through all the batches 585 | for batch_i in range(batch_idxs): 586 | start_idx = batch_i * batch_size 587 | end_idx = start_idx + batch_size 588 | 589 | # clear the gradients 590 | optimizer.zero_grad() 591 | 592 | # compute the model output 593 | pred_values = self.forward_q_net(q_net=q_net, inputs=[inps[start_idx:end_idx] for inps in inputs], recurrent=recurrent) 594 | 595 | # calculate loss + credit assignment 596 | loss = criterion(pred_values, targets[start_idx:end_idx]) 597 | loss.backward() 598 | 599 | # update model weights 600 | optimizer.step() 601 | 602 | return q_net 603 | 604 | def train_policy(self, inputs, epochs, recurrent=True): 605 | ''' 606 | Train the policy network on the given inputs. 
607 | :param inputs: the inputs to the policy network 608 | :param epochs: the number of epochs to train for 609 | :param recurrent: whether to use the recurrent networks 610 | :return: the trained policy network 611 | ''' 612 | 613 | if recurrent: 614 | states, sequences = inputs 615 | else: 616 | states = inputs[0] 617 | 618 | batch_idxs = math.ceil(states.shape[0] / self.batch_size) 619 | epoch_losses = [] 620 | for epoch in range(epochs): 621 | batch_losses = [] 622 | for batch_i in range(batch_idxs): 623 | start_idx = batch_i * self.batch_size 624 | end_idx = start_idx + self.batch_size 625 | 626 | self.policy_network_opt.zero_grad() 627 | 628 | ### run: policy network -> state-action pairs -> Q values 629 | pred_actions = self.forward_policy_net( 630 | policy_net=self.policy_network, 631 | inputs=[states[start_idx:end_idx], sequences[start_idx:end_idx]], 632 | recurrent=recurrent, 633 | ) 634 | # gradients flows from final Q-values through pred_actions to policy network's parameters 635 | q_net_inputs = [ 636 | torch.cat(( 637 | torch.tensor(states[start_idx:end_idx], dtype=torch.float32, device=self.device), 638 | pred_actions 639 | ), dim=1), 640 | sequences[start_idx:end_idx] 641 | ] 642 | pred_values = self.forward_q_net( 643 | q_net=self.Q1_network, 644 | inputs=q_net_inputs, 645 | recurrent=recurrent 646 | ) 647 | 648 | ### calculate loss + credit assignment 649 | loss = self.policy_network_loss(pred_values) 650 | loss.backward() 651 | self.policy_network_opt.step() 652 | 653 | batch_losses.append(loss.item()) 654 | 655 | epoch_losses.append(np.mean(batch_losses)) 656 | 657 | # clear the gradients from both networks 658 | self.policy_network_opt.zero_grad() 659 | self.Q1_network_opt.zero_grad() 660 | 661 | return epoch_losses 662 | 663 | def validate_on_train(model, dataloader, loss): 664 | model.eval() 665 | loss_total = 0 666 | 667 | with torch.no_grad(): 668 | for data in dataloader: 669 | input = data[0] 670 | label = data[1] 671 | 672 | output = model(input.view(input.shape[0], -1)) 673 | loss = loss(output, label) 674 | loss_total += loss.item() 675 | 676 | return loss_total / len(dataloader) 677 | 678 | def Q_update(self, recurrent=True, monte_carlo=False, policy=True, verbose=False, low_mem=True, epochs=1): 679 | ''' 680 | Updates the Q network parameters. 
681 | :param recurrent: whether to use the recurrent networks 682 | :param monte_carlo: 683 | :param policy: whether to update the policy network afterwards 684 | :param fitted: 685 | :param verbose: 686 | :param low_mem: whether to use low memory mode 687 | ''' 688 | 689 | if low_mem: 690 | inputs, actions, targets = self.get_inputs_targets_low_mem( 691 | recurrent=recurrent, monte_carlo=monte_carlo) 692 | else: 693 | inputs, actions, targets = self.get_inputs_targets( 694 | recurrent=recurrent, monte_carlo=monte_carlo) 695 | 696 | if recurrent: 697 | states, sequences = inputs 698 | else: 699 | states = inputs[0] 700 | 701 | q1_net_inputs = [np.concatenate((states, actions), axis=1), sequences] if recurrent else [np.concatenate((states, actions), axis=1)] 702 | self.Q1_network = self.train_q_net( 703 | q_net=self.Q1_network, 704 | inputs=q1_net_inputs, 705 | targets=targets, 706 | optimizer=self.Q1_network_opt, 707 | criterion=self.Q1_network_loss, 708 | epochs=epochs, 709 | batch_size=self.batch_size, 710 | recurrent=recurrent, 711 | ) 712 | q2_net_inputs = [np.concatenate((states, actions), axis=1), sequences] if recurrent else [np.concatenate((states, actions), axis=1)] 713 | self.Q2_network = self.train_q_net( 714 | q_net=self.Q2_network, 715 | inputs=q2_net_inputs, 716 | targets=targets, 717 | optimizer=self.Q2_network_opt, 718 | criterion=self.Q2_network_loss, 719 | epochs=epochs, 720 | batch_size=self.batch_size, 721 | recurrent=recurrent, 722 | ) 723 | 724 | if policy: 725 | epoch_losses = self.train_policy( 726 | inputs=[states, sequences] if recurrent else [states], 727 | epochs=epochs, 728 | recurrent=recurrent, 729 | ) 730 | 731 | ### update target networks when we update the policy 732 | if policy and not monte_carlo: 733 | self.update_target_network(source=self.Q1_network, target=self.Q1_target, tau=self.polyak) 734 | self.update_target_network(source=self.Q2_network, target=self.Q2_target, tau=self.polyak) 735 | self.update_target_network(source=self.policy_network, target=self.policy_target, tau=self.polyak) 736 | 737 | def save_ckpt(self, save_path, additional_info=None): 738 | ''' 739 | Creates a full checkpoint (networks, optimizers, memory buffers) and saves it to the specified path. 740 | :param save_path: path to save the checkpoint to 741 | :param additional_info: additional information to save (Python dictionary) 742 | ''' 743 | ckpt = { 744 | "policy_network": self.policy_network.state_dict(), 745 | "Q1_network": self.Q1_network.state_dict(), 746 | "Q2_network": self.Q2_network.state_dict(), 747 | "policy_target": self.policy_target.state_dict(), 748 | "Q1_target": self.Q1_target.state_dict(), 749 | "Q2_target": self.Q2_target.state_dict(), 750 | "policy_network_opt": self.policy_network_opt.state_dict(), 751 | "Q1_network_opt": self.Q1_network_opt.state_dict(), 752 | "Q2_network_opt": self.Q2_network_opt.state_dict(), 753 | "additional_info": additional_info if additional_info is not None else {}, 754 | } 755 | 756 | ### save buffers 757 | for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones", 758 | "sequences", "next_sequences", "all_returns"): 759 | ckpt[buffer] = getattr(self, buffer) 760 | 761 | ### save the checkpoint 762 | torch.save(ckpt, save_path) 763 | 764 | def load_ckpt(self, load_path, load_target_networks=True): 765 | ''' 766 | Loads a full checkpoint (networks, optimizers, memory buffers) from the specified path. 
767 | :param load_path: path to load the checkpoint from 768 | :param load_target_networks: whether to load the target networks as well 769 | ''' 770 | ckpt = torch.load(load_path) 771 | 772 | ### load networks 773 | self.policy_network.load_state_dict(ckpt["policy_network"]) 774 | self.Q1_network.load_state_dict(ckpt["Q1_network"]) 775 | self.Q2_network.load_state_dict(ckpt["Q2_network"]) 776 | 777 | ### load target networks 778 | if load_target_networks: 779 | self.policy_target.load_state_dict(ckpt["policy_target"]) 780 | self.Q1_target.load_state_dict(ckpt["Q1_target"]) 781 | self.Q2_target.load_state_dict(ckpt["Q2_target"]) 782 | 783 | ### load optimizers 784 | self.policy_network_opt.load_state_dict(ckpt["policy_network_opt"]) 785 | self.Q1_network_opt.load_state_dict(ckpt["Q1_network_opt"]) 786 | self.Q2_network_opt.load_state_dict(ckpt["Q2_network_opt"]) 787 | 788 | ### load buffers 789 | for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones", 790 | "sequences", "next_sequences", "all_returns"): 791 | setattr(self, buffer, ckpt[buffer]) 792 | 793 | return ckpt 794 | 795 | def reset_weights(self, policy=True): 796 | ''' 797 | Reinitialises weights to random values. 798 | :param policy: whether to reinitialise policy network 799 | ''' 800 | del self.Q1_network 801 | del self.Q2_network 802 | if policy: 803 | del self.policy_network 804 | gc.collect() 805 | 806 | self.Q1_network = nn.ModuleList([ 807 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 808 | for mod_specs in self.val_module_specs 809 | ]) 810 | self.Q1_network_opt = Adam(self.Q1_network.parameters(), lr=self.val_learning_rate) 811 | self.Q1_network_loss = nn.MSELoss() 812 | 813 | self.Q2_network = nn.ModuleList([ 814 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 815 | for mod_specs in self.val_module_specs 816 | ]) 817 | self.Q2_network_opt = Adam(self.Q2_network.parameters(), lr=self.val_learning_rate) 818 | self.Q2_network_loss = nn.MSELoss() 819 | 820 | if policy: 821 | ### initialise policy networks 822 | self.policy_network = nn.ModuleList([ 823 | self.initialise_network(input_size=mod_specs["input_size"], layer_specs=mod_specs["layers"], init_optimizer=False) 824 | for mod_specs in self.pol_module_specs 825 | ]) 826 | self.policy_network_opt = Adam(self.policy_network.parameters(), lr=self.pol_learning_rate) 827 | self.policy_network_loss = lambda predicted_action_values: -torch.mean(predicted_action_values) 828 | 829 | def update_target_network(self, source, target, tau): 830 | ''' 831 | Updates the target network from the source network using Polyak averaging. 832 | :param source: source network 833 | :param target: target network 834 | :param tau: Polyak averaging parameter 835 | :return: updated target network 836 | ''' 837 | for target_param, source_param in zip(target.parameters(), source.parameters()): 838 | target_param.data.copy_(source_param.data * tau + target_param.data * (1.0 - tau)) 839 | return target 840 | --------------------------------------------------------------------------------
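
For orientation, here is a minimal sketch of how the RT3D_agent defined above might be driven from a training script. The module specifications, dimensions (STATE_DIM, ACTION_DIM, GRU_HIDDEN), episode counts, the random placeholder environment, and the checkpoint filename are assumptions made for illustration only; a real script would instead take states and rewards from OED_env (e.g. via map_parallel_step) and read network sizes from the project's configs. Layer-spec keys follow the RT3D_agent docstring, but the exact behaviour depends on RED.utils.network.NeuralNetwork, which is not shown here.

import numpy as np

from RED.agents.continuous_agents.rt3d import RT3D_agent

STATE_DIM, ACTION_DIM, GRU_HIDDEN = 3, 1, 32  # hypothetical dimensions for this sketch

# Two modules per network: a GRU that reads the (state, action) history, and a head that reads
# the current state (plus the action, for the Q networks) concatenated with the GRU output.
pol_module_specs = [
    {"input_size": STATE_DIM + ACTION_DIM,
     "layers": [{"layer_type": "GRU", "hidden_size": GRU_HIDDEN, "num_layers": 1}]},
    {"input_size": STATE_DIM + GRU_HIDDEN,
     "layers": [{"layer_type": "Linear", "output_size": ACTION_DIM}]},
]
val_module_specs = [
    {"input_size": STATE_DIM + ACTION_DIM,
     "layers": [{"layer_type": "GRU", "hidden_size": GRU_HIDDEN, "num_layers": 1}]},
    {"input_size": STATE_DIM + ACTION_DIM + GRU_HIDDEN,
     "layers": [{"layer_type": "Linear", "output_size": 1}]},
]

agent = RT3D_agent(val_module_specs=val_module_specs, pol_module_specs=pol_module_specs)

N_EPISODES, N_STEPS = 20, 6  # illustrative only
explore_rate = 1.0

for episode in range(N_EPISODES):
    state = np.zeros(STATE_DIM, dtype=np.float32)   # placeholder initial observation
    sequence = [np.zeros(STATE_DIM + ACTION_DIM)]   # history of concatenated (state, action) pairs
    trajectory = []
    for t in range(N_STEPS):
        # get_actions expects batched inputs: [states, sequences]
        action = agent.get_actions([[state], [np.array(sequence)]], explore_rate)[0]

        # a real script would step OED_env here; random placeholders keep the sketch self-contained
        next_state = np.random.rand(STATE_DIM).astype(np.float32)
        reward, done = float(np.random.rand()), t == N_STEPS - 1

        trajectory.append((state, action, reward, next_state, done))
        sequence.append(np.concatenate((state, action)))
        state = next_state

    # the replay memory is a list of trajectories of (state, action, reward, next_state, done) tuples
    agent.memory.append(trajectory)

    # logarithmically decaying exploration, twin-critic update, delayed policy update (TD3-style)
    explore_rate = agent.get_rate(episode, min_rate=0.1, max_rate=1.0, denominator=N_EPISODES / 10)
    agent.Q_update(policy=(episode % 2 == 0), recurrent=True, low_mem=True)

agent.save_ckpt("rt3d_checkpoint.pt")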