├── weights ├── bs │ ├── 1month_daily │ │ ├── ddpg_actor_3.h5 │ │ ├── ddpg_critic_Q_ex2_3.h5 │ │ └── ddpg_critic_Q_ex_3.h5 │ ├── 3month_daily │ │ ├── ddpg_actor_6.h5 │ │ ├── ddpg_critic_Q_ex2_6.h5 │ │ └── ddpg_critic_Q_ex_6.h5 │ ├── 1month_every2day │ │ ├── ddpg_actor_5.h5 │ │ ├── ddpg_critic_Q_ex2_5.h5 │ │ └── ddpg_critic_Q_ex_5.h5 │ ├── 3month_every3day │ │ ├── ddpg_actor_4.h5 │ │ ├── ddpg_critic_Q_ex2_4.h5 │ │ └── ddpg_critic_Q_ex_4.h5 │ ├── 3month_every5day │ │ ├── ddpg_actor_1.h5 │ │ ├── ddpg_critic_Q_ex2_1.h5 │ │ └── ddpg_critic_Q_ex_1.h5 │ ├── 1month_every3day │ │ ├── ddpg_actor_14.h5 │ │ ├── ddpg_critic_Q_ex2_14.h5 │ │ └── ddpg_critic_Q_ex_14.h5 │ ├── 1month_every5day │ │ ├── ddpg_actor_40.h5 │ │ ├── ddpg_critic_Q_ex2_40.h5 │ │ └── ddpg_critic_Q_ex_40.h5 │ └── 3month_every2day │ │ ├── ddpg_actor_25.h5 │ │ ├── ddpg_critic_Q_ex2_25.h5 │ │ └── ddpg_critic_Q_ex_25.h5 └── sabr │ ├── 1month_daily │ ├── ddpg_actor_5.h5 │ ├── ddpg_critic_Q_ex2_5.h5 │ └── ddpg_critic_Q_ex_5.h5 │ ├── 3month_daily │ ├── ddpg_actor_109.h5 │ ├── ddpg_critic_Q_ex2_109.h5 │ └── ddpg_critic_Q_ex_109.h5 │ ├── 1month_every2day │ ├── ddpg_actor_9.h5 │ ├── ddpg_critic_Q_ex_9.h5 │ └── ddpg_critic_Q_ex2_9.h5 │ ├── 1month_every3day │ ├── ddpg_actor_43.h5 │ ├── ddpg_critic_Q_ex2_43.h5 │ └── ddpg_critic_Q_ex_43.h5 │ ├── 1month_every5day │ ├── ddpg_actor_15.h5 │ ├── ddpg_critic_Q_ex2_15.h5 │ └── ddpg_critic_Q_ex_15.h5 │ ├── 3month_every2day │ ├── ddpg_actor_62.h5 │ ├── ddpg_critic_Q_ex2_62.h5 │ └── ddpg_critic_Q_ex_62.h5 │ ├── 3month_every3day │ ├── ddpg_actor_7.h5 │ ├── ddpg_critic_Q_ex_7.h5 │ └── ddpg_critic_Q_ex2_7.h5 │ └── 3month_every5day │ ├── ddpg_actor_6.h5 │ ├── ddpg_critic_Q_ex_6.h5 │ └── ddpg_critic_Q_ex2_6.h5 ├── ddpg_test.py ├── README.md ├── drl.py ├── schedules.py ├── segment_tree.py ├── envs.py ├── replay_buffer.py ├── utils.py └── ddpg_per.py /weights/bs/1month_daily/ddpg_actor_3.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_daily/ddpg_actor_3.h5 -------------------------------------------------------------------------------- /weights/bs/3month_daily/ddpg_actor_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_daily/ddpg_actor_6.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_daily/ddpg_actor_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_daily/ddpg_actor_5.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every2day/ddpg_actor_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every2day/ddpg_actor_5.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every3day/ddpg_actor_4.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every3day/ddpg_actor_4.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every5day/ddpg_actor_1.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every5day/ddpg_actor_1.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_daily/ddpg_actor_109.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_daily/ddpg_actor_109.h5 -------------------------------------------------------------------------------- /weights/bs/1month_daily/ddpg_critic_Q_ex2_3.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_daily/ddpg_critic_Q_ex2_3.h5 -------------------------------------------------------------------------------- /weights/bs/1month_daily/ddpg_critic_Q_ex_3.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_daily/ddpg_critic_Q_ex_3.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every3day/ddpg_actor_14.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every3day/ddpg_actor_14.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every5day/ddpg_actor_40.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every5day/ddpg_actor_40.h5 -------------------------------------------------------------------------------- /weights/bs/3month_daily/ddpg_critic_Q_ex2_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_daily/ddpg_critic_Q_ex2_6.h5 -------------------------------------------------------------------------------- /weights/bs/3month_daily/ddpg_critic_Q_ex_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_daily/ddpg_critic_Q_ex_6.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every2day/ddpg_actor_25.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every2day/ddpg_actor_25.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every2day/ddpg_actor_9.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every2day/ddpg_actor_9.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every3day/ddpg_actor_43.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every3day/ddpg_actor_43.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every5day/ddpg_actor_15.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every5day/ddpg_actor_15.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every2day/ddpg_actor_62.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every2day/ddpg_actor_62.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every3day/ddpg_actor_7.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every3day/ddpg_actor_7.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every5day/ddpg_actor_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every5day/ddpg_actor_6.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_daily/ddpg_critic_Q_ex2_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_daily/ddpg_critic_Q_ex2_5.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_daily/ddpg_critic_Q_ex_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_daily/ddpg_critic_Q_ex_5.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every2day/ddpg_critic_Q_ex2_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every2day/ddpg_critic_Q_ex2_5.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every2day/ddpg_critic_Q_ex_5.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every2day/ddpg_critic_Q_ex_5.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every3day/ddpg_critic_Q_ex2_14.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every3day/ddpg_critic_Q_ex2_14.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every3day/ddpg_critic_Q_ex_14.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every3day/ddpg_critic_Q_ex_14.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every5day/ddpg_critic_Q_ex2_40.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every5day/ddpg_critic_Q_ex2_40.h5 -------------------------------------------------------------------------------- /weights/bs/1month_every5day/ddpg_critic_Q_ex_40.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/1month_every5day/ddpg_critic_Q_ex_40.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every2day/ddpg_critic_Q_ex2_25.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every2day/ddpg_critic_Q_ex2_25.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every2day/ddpg_critic_Q_ex_25.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every2day/ddpg_critic_Q_ex_25.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every3day/ddpg_critic_Q_ex2_4.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every3day/ddpg_critic_Q_ex2_4.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every3day/ddpg_critic_Q_ex_4.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every3day/ddpg_critic_Q_ex_4.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every5day/ddpg_critic_Q_ex2_1.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every5day/ddpg_critic_Q_ex2_1.h5 -------------------------------------------------------------------------------- /weights/bs/3month_every5day/ddpg_critic_Q_ex_1.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/bs/3month_every5day/ddpg_critic_Q_ex_1.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every2day/ddpg_critic_Q_ex_9.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every2day/ddpg_critic_Q_ex_9.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_daily/ddpg_critic_Q_ex2_109.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_daily/ddpg_critic_Q_ex2_109.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_daily/ddpg_critic_Q_ex_109.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_daily/ddpg_critic_Q_ex_109.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every3day/ddpg_critic_Q_ex_7.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every3day/ddpg_critic_Q_ex_7.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every5day/ddpg_critic_Q_ex_6.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every5day/ddpg_critic_Q_ex_6.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every2day/ddpg_critic_Q_ex2_9.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every2day/ddpg_critic_Q_ex2_9.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every3day/ddpg_critic_Q_ex2_43.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every3day/ddpg_critic_Q_ex2_43.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every3day/ddpg_critic_Q_ex_43.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every3day/ddpg_critic_Q_ex_43.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every5day/ddpg_critic_Q_ex2_15.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every5day/ddpg_critic_Q_ex2_15.h5 -------------------------------------------------------------------------------- /weights/sabr/1month_every5day/ddpg_critic_Q_ex_15.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/1month_every5day/ddpg_critic_Q_ex_15.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every2day/ddpg_critic_Q_ex2_62.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every2day/ddpg_critic_Q_ex2_62.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every2day/ddpg_critic_Q_ex_62.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every2day/ddpg_critic_Q_ex_62.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every3day/ddpg_critic_Q_ex2_7.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every3day/ddpg_critic_Q_ex2_7.h5 -------------------------------------------------------------------------------- /weights/sabr/3month_every5day/ddpg_critic_Q_ex2_6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdmdal/rl-hedge-2019/HEAD/weights/sabr/3month_every5day/ddpg_critic_Q_ex2_6.h5 -------------------------------------------------------------------------------- /ddpg_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ddpg_per import DDPG 4 | from envs import TradingEnv 5 | 6 | if __name__ == "__main__": 7 | 8 | # disable GPU 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 10 | 11 | # specify what 
to test 12 | delta_action_test = False 13 | bartlett_action_test = False 14 | 15 | # specify which weights file to load 16 | tag = "49" 17 | 18 | # set init_ttm, spread, and other parameters to match the env in which the model was trained 19 | env_test = TradingEnv(continuous_action_flag=True, sabr_flag=True, dg_random_seed=2, spread=0.01, num_contract=1, init_ttm=20, trade_freq=1, num_sim=100001) 20 | ddpg_test = DDPG(env_test) 21 | 22 | print("\n\n***") 23 | if delta_action_test: 24 | print("Testing delta actions.") 25 | else: 26 | print("Testing agent actions.") 27 | if tag == "": 28 | print("Testing the model saved at the end of training.") 29 | else: 30 | print("Testing the model saved at " + tag + "K episodes.") 31 | ddpg_test.load(tag=tag) 32 | 33 | ddpg_test.test(100001, delta_flag=delta_action_test, bartlett_flag=bartlett_action_test) 34 | 35 | # for i in range(1, 51): 36 | # tag = str(i) 37 | # print("****** ", tag) 38 | # ddpg_test.load(tag=tag) 39 | # ddpg_test.test(3001, delta_flag=delta_action_test) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Hedging with Reinforcement Learning 2 | 3 | ## About 4 | 5 | This is the companion code for the paper *Deep Hedging of Derivatives Using Reinforcement Learning* by Jay Cao, Jacky Chen, John Hull, and Zissis Poulos. The paper is available at SSRN [here](https://ssrn.com/abstract=3514586). 6 | 7 | ## Requirements 8 | 9 | The code requires gym (0.12.1), tensorflow (1.13.1), and keras (2.3.1). 10 | 11 | ## Usage 12 | 13 | Run `python ddpg_per.py` to start training. Run `python ddpg_test.py` to test a trained model. 14 | 15 | To set up a trading scenario for training and testing, modify the trading environment instantiation parameter values in the code accordingly (`env = TradingEnv(...)` and `env_test = TradingEnv(...)`). 16 | 17 | ## Weight files 18 | 19 | Trained weights for all trading scenarios in the paper are provided in the `weights` folder. 20 | 21 | Each set of weights is obtained after 2 or 3 rounds of training. Later rounds start from the best weights of the previous round, together with manually fine-tuned hyper-parameter values (learning rate, target network soft-update rate, etc.; see comments in the code for details). 22 | 23 | ## Credits 24 | 25 | * The code structure is adapted from [@xiaochus](https://github.com/xiaochus)'s GitHub project [Deep-Reinforcement-Learning-Practice](https://github.com/xiaochus/Deep-Reinforcement-Learning-Practice). 26 | 27 | * The implementation of the prioritized experience replay buffer is taken from OpenAI [Baselines](https://github.com/openai/baselines). 28 | -------------------------------------------------------------------------------- /drl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | # put things common to different algorithms here 8 | class DRL: 9 | def __init__(self): 10 | if not os.path.exists('model'): 11 | os.mkdir('model') 12 | 13 | if not os.path.exists('history'): 14 | os.mkdir('history') 15 | 16 | def test(self, total_episode, delta_flag=False, bartlett_flag=False): 17 | """hedge with model. 
18 | """ 19 | print('testing...') 20 | 21 | self.epsilon = -1 22 | 23 | w_T_store = [] 24 | 25 | for i in range(total_episode): 26 | observation = self.env.reset() 27 | done = False 28 | action_store = [] 29 | reward_store = [] 30 | 31 | while not done: 32 | 33 | # prepare state 34 | x = np.array(observation).reshape(1, -1) 35 | 36 | if delta_flag: 37 | action = self.env.delta_path[i % self.env.num_path, self.env.t] * self.env.num_contract * 100 38 | elif bartlett_flag: 39 | action = self.env.bartlett_delta_path[i % self.env.num_path, self.env.t] * self.env.num_contract * 100 40 | else: 41 | # choose action from epsilon-greedy; epsilon has been set to -1 42 | action, _, _ = self.egreedy_action(x) 43 | 44 | # store action to take a look 45 | action_store.append(action) 46 | 47 | # a step 48 | observation, reward, done, info = self.env.step(action) 49 | reward_store.append(reward) 50 | 51 | # get final wealth at the end of episode, and store it. 52 | w_T = sum(reward_store) 53 | w_T_store.append(w_T) 54 | 55 | if i % 1000 == 0: 56 | w_T_mean = np.mean(w_T_store) 57 | w_T_var = np.var(w_T_store) 58 | path_row = info["path_row"] 59 | print(info) 60 | with np.printoptions(precision=2, suppress=True): 61 | print("episode: {} | final wealth: {:.2f}; so far mean and variance of final wealth was {} and {}".format(i, w_T, w_T_mean, w_T_var)) 62 | print("episode: {} | so far Y(0): {:.2f}".format(i, -w_T_mean + self.ra_c * np.sqrt(w_T_var))) 63 | print("episode: {} | rewards: {}".format(i, np.array(reward_store))) 64 | print("episode: {} | action taken: {}".format(i, np.array(action_store))) 65 | print("episode: {} | deltas {}".format(i, self.env.delta_path[path_row] * 100)) 66 | print("episode: {} | stock price {}".format(i, self.env.path[path_row])) 67 | print("episode: {} | option price {}\n".format(i, self.env.option_price_path[path_row] * 100)) 68 | 69 | def plot(self, history): 70 | pass 71 | 72 | def save_history(self, history, name): 73 | name = os.path.join('history', name) 74 | 75 | df = pd.DataFrame.from_dict(history) 76 | df.to_csv(name, index=False, encoding='utf-8') -------------------------------------------------------------------------------- /schedules.py: -------------------------------------------------------------------------------- 1 | # utility classes for various schedules 2 | # use the implementation from OpenAI's baselines 3 | # https://github.com/openai/baselines/blob/master/baselines/common/schedules.py 4 | 5 | """This file is used for specifying various schedules that evolve over 6 | time throughout the execution of the algorithm, such as: 7 | - learning rate for the optimizer 8 | - exploration epsilon for the epsilon greedy exploration strategy 9 | - beta parameter for beta parameter in prioritized replay 10 | 11 | Each schedule has a function `value(t)` which returns the current value 12 | of the parameter given the timestep t of the optimization procedure. 13 | """ 14 | 15 | 16 | class Schedule(object): 17 | def value(self, t): 18 | """Value of the schedule at time t""" 19 | raise NotImplementedError() 20 | 21 | 22 | class ConstantSchedule(object): 23 | def __init__(self, value): 24 | """Value remains constant over time. 
25 | 26 | Parameters 27 | ---------- 28 | value: float 29 | Constant value of the schedule 30 | """ 31 | self._v = value 32 | 33 | def value(self, t): 34 | """See Schedule.value""" 35 | return self._v 36 | 37 | 38 | def linear_interpolation(l, r, alpha): 39 | return l + alpha * (r - l) 40 | 41 | 42 | class PiecewiseSchedule(object): 43 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): 44 | """Piecewise schedule. 45 | 46 | endpoints: [(int, int)] 47 | list of pairs `(time, value)` meanining that schedule should output 48 | `value` when `t==time`. All the values for time must be sorted in 49 | an increasing order. When t is between two times, e.g. `(time_a, value_a)` 50 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs 51 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of 52 | time passed between `time_a` and `time_b` for time `t`. 53 | interpolation: lambda float, float, float: float 54 | a function that takes value to the left and to the right of t according 55 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to 56 | right endpoint that t has covered. See linear_interpolation for example. 57 | outside_value: float 58 | if the value is requested outside of all the intervals sepecified in 59 | `endpoints` this value is returned. If None then AssertionError is 60 | raised when outside value is requested. 61 | """ 62 | idxes = [e[0] for e in endpoints] 63 | assert idxes == sorted(idxes) 64 | self._interpolation = interpolation 65 | self._outside_value = outside_value 66 | self._endpoints = endpoints 67 | 68 | def value(self, t): 69 | """See Schedule.value""" 70 | for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): 71 | if l_t <= t and t < r_t: 72 | alpha = float(t - l_t) / (r_t - l_t) 73 | return self._interpolation(l, r, alpha) 74 | 75 | # t does not belong to any of the pieces, so doom. 76 | assert self._outside_value is not None 77 | return self._outside_value 78 | 79 | 80 | class LinearSchedule(object): 81 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 82 | """Linear interpolation between initial_p and final_p over 83 | schedule_timesteps. After this many timesteps pass final_p is 84 | returned. 85 | 86 | Parameters 87 | ---------- 88 | schedule_timesteps: int 89 | Number of timesteps for which to linearly anneal initial_p 90 | to final_p 91 | initial_p: float 92 | initial output value 93 | final_p: float 94 | final output value 95 | """ 96 | self.schedule_timesteps = schedule_timesteps 97 | self.final_p = final_p 98 | self.initial_p = initial_p 99 | 100 | def value(self, t): 101 | """See Schedule.value""" 102 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 103 | return self.initial_p + fraction * (self.final_p - self.initial_p) 104 | -------------------------------------------------------------------------------- /segment_tree.py: -------------------------------------------------------------------------------- 1 | # utility classes for prioritized experience replay (PER) 2 | # use the implementation from OpenAI's baselines 3 | # https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py 4 | 5 | import operator 6 | 7 | 8 | class SegmentTree(object): 9 | def __init__(self, capacity, operation, neutral_element): 10 | """Build a Segment Tree data structure. 
11 | 12 | https://en.wikipedia.org/wiki/Segment_tree 13 | 14 | Can be used as regular array, but with two 15 | important differences: 16 | 17 | a) setting item's value is slightly slower. 18 | It is O(lg capacity) instead of O(1). 19 | b) user has access to an efficient ( O(log segment size) ) 20 | `reduce` operation which reduces `operation` over 21 | a contiguous subsequence of items in the array. 22 | 23 | Paramters 24 | --------- 25 | capacity: int 26 | Total size of the array - must be a power of two. 27 | operation: lambda obj, obj -> obj 28 | and operation for combining elements (eg. sum, max) 29 | must form a mathematical group together with the set of 30 | possible values for array elements (i.e. be associative) 31 | neutral_element: obj 32 | neutral element for the operation above. eg. float('-inf') 33 | for max and 0 for sum. 34 | """ 35 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 36 | self._capacity = capacity 37 | self._value = [neutral_element for _ in range(2 * capacity)] 38 | self._operation = operation 39 | 40 | def _reduce_helper(self, start, end, node, node_start, node_end): 41 | if start == node_start and end == node_end: 42 | return self._value[node] 43 | mid = (node_start + node_end) // 2 44 | if end <= mid: 45 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 46 | else: 47 | if mid + 1 <= start: 48 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 49 | else: 50 | return self._operation( 51 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 52 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 53 | ) 54 | 55 | def reduce(self, start=0, end=None): 56 | """Returns result of applying `self.operation` 57 | to a contiguous subsequence of the array. 58 | 59 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 60 | 61 | Parameters 62 | ---------- 63 | start: int 64 | beginning of the subsequence 65 | end: int 66 | end of the subsequences 67 | 68 | Returns 69 | ------- 70 | reduced: obj 71 | result of reducing self.operation over the specified range of array elements. 72 | """ 73 | if end is None: 74 | end = self._capacity 75 | if end < 0: 76 | end += self._capacity 77 | end -= 1 78 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 79 | 80 | def __setitem__(self, idx, val): 81 | # index of the leaf 82 | idx += self._capacity 83 | self._value[idx] = val 84 | idx //= 2 85 | while idx >= 1: 86 | self._value[idx] = self._operation( 87 | self._value[2 * idx], 88 | self._value[2 * idx + 1] 89 | ) 90 | idx //= 2 91 | 92 | def __getitem__(self, idx): 93 | assert 0 <= idx < self._capacity 94 | return self._value[self._capacity + idx] 95 | 96 | 97 | class SumSegmentTree(SegmentTree): 98 | def __init__(self, capacity): 99 | super(SumSegmentTree, self).__init__( 100 | capacity=capacity, 101 | operation=operator.add, 102 | neutral_element=0.0 103 | ) 104 | 105 | def sum(self, start=0, end=None): 106 | """Returns arr[start] + ... + arr[end]""" 107 | return super(SumSegmentTree, self).reduce(start, end) 108 | 109 | def find_prefixsum_idx(self, prefixsum): 110 | """Find the highest index `i` in the array such that 111 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 112 | 113 | if array values are probabilities, this function 114 | allows to sample indexes according to the discrete 115 | probability efficiently. 
116 | 117 | Parameters 118 | ---------- 119 | perfixsum: float 120 | upperbound on the sum of array prefix 121 | 122 | Returns 123 | ------- 124 | idx: int 125 | highest index satisfying the prefixsum constraint 126 | """ 127 | assert 0 <= prefixsum <= self.sum() + 1e-5 128 | idx = 1 129 | while idx < self._capacity: # while non-leaf 130 | if self._value[2 * idx] > prefixsum: 131 | idx = 2 * idx 132 | else: 133 | prefixsum -= self._value[2 * idx] 134 | idx = 2 * idx + 1 135 | return idx - self._capacity 136 | 137 | 138 | class MinSegmentTree(SegmentTree): 139 | def __init__(self, capacity): 140 | super(MinSegmentTree, self).__init__( 141 | capacity=capacity, 142 | operation=min, 143 | neutral_element=float('inf') 144 | ) 145 | 146 | def min(self, start=0, end=None): 147 | """Returns min(arr[start], ..., arr[end])""" 148 | 149 | return super(MinSegmentTree, self).reduce(start, end) 150 | -------------------------------------------------------------------------------- /envs.py: -------------------------------------------------------------------------------- 1 | """A trading environment""" 2 | import gym 3 | from gym import spaces 4 | from gym.utils import seeding 5 | 6 | import numpy as np 7 | 8 | from utils import get_sim_path, get_sim_path_sabr 9 | 10 | 11 | class TradingEnv(gym.Env): 12 | """ 13 | trading environment; 14 | """ 15 | 16 | # trade_freq in unit of day, e.g 2: every 2 day; 0.5 twice a day; 17 | def __init__(self, cash_flow_flag=0, dg_random_seed=1, num_sim=500002, sabr_flag = False, 18 | continuous_action_flag=False, spread=0, init_ttm=5, trade_freq=1, num_contract=1): 19 | 20 | # simulated data: array of asset price, option price and delta paths (num_path x num_period) 21 | # generate data now 22 | if sabr_flag: 23 | self.path, self.option_price_path, self.delta_path, self.bartlett_delta_path = get_sim_path_sabr(M=init_ttm, freq=trade_freq, 24 | np_seed=dg_random_seed, num_sim=num_sim) 25 | else: 26 | self.path, self.option_price_path, self.delta_path = get_sim_path(M=init_ttm, freq=trade_freq, 27 | np_seed=dg_random_seed, num_sim=num_sim) 28 | 29 | # other attributes 30 | self.num_path = self.path.shape[0] 31 | 32 | # set num_period: initial time to maturity * daily trading freq + 1 (see get_sim_path() in utils.py) 33 | self.num_period = self.path.shape[1] 34 | # print("***", self.num_period) 35 | 36 | # time to maturity array 37 | self.ttm_array = np.arange(init_ttm, -trade_freq, -trade_freq) 38 | # print(self.ttm_array) 39 | 40 | # spread 41 | self.spread = spread 42 | 43 | # step function initialization depending on cash_flow_flag 44 | if cash_flow_flag == 1: 45 | self.step = self.step_cash_flow 46 | else: 47 | self.step = self.step_profit_loss 48 | 49 | self.num_contract = num_contract 50 | self.strike_price = 100 51 | 52 | # track the index of simulated path in use 53 | self.sim_episode = -1 54 | 55 | # track time step within an episode (it's step) 56 | self.t = None 57 | 58 | # action space 59 | if continuous_action_flag: 60 | self.action_space = spaces.Box(low=np.array([0]), high=np.array([num_contract * 100]), dtype=np.float32) 61 | else: 62 | self.num_action = num_contract * 100 + 1 63 | self.action_space = spaces.Discrete(self.num_action) 64 | 65 | self.num_state = 3 66 | 67 | self.state = [] 68 | 69 | # seed and start 70 | self.seed() 71 | # self.reset() 72 | 73 | def seed(self, seed=None): 74 | self.np_random, seed = seeding.np_random(seed) 75 | return [seed] 76 | 77 | def reset(self): 78 | # repeatedly go through available simulated paths (if needed) 79 | 
self.sim_episode = (self.sim_episode + 1) % self.num_path 80 | 81 | self.t = 0 82 | 83 | price = self.path[self.sim_episode, self.t] 84 | position = 0 85 | 86 | ttm = self.ttm_array[self.t] 87 | 88 | self.state = [price, position, ttm] 89 | 90 | return self.state 91 | 92 | def step_cash_flow(self, action): 93 | """ 94 | cash flow period reward 95 | """ 96 | 97 | # do it consistently as in the profit & loss case 98 | # current prices (at t) 99 | current_price = self.state[0] 100 | 101 | # current position 102 | current_position = self.state[1] 103 | 104 | # update time/period 105 | self.t = self.t + 1 106 | 107 | # get state for tomorrow 108 | price = self.path[self.sim_episode, self.t] 109 | position = action 110 | ttm = self.ttm_array[self.t] 111 | 112 | self.state = [price, position, ttm] 113 | 114 | # calculate period reward (part 1) 115 | cash_flow = -(position - current_position) * current_price - np.abs(position - current_position) * current_price * self.spread 116 | 117 | # if tomorrow is end of episode 118 | if self.t == self.num_period - 1: 119 | done = True 120 | # add (stock payoff + option payoff) to cash flow 121 | reward = cash_flow + price * position - max(price - self.strike_price, 0) * self.num_contract * 100 - position * price * self.spread 122 | else: 123 | done = False 124 | reward = cash_flow 125 | 126 | # for other info 127 | info = {"path_row": self.sim_episode} 128 | 129 | return self.state, reward, done, info 130 | 131 | def step_profit_loss(self, action): 132 | """ 133 | profit loss period reward 134 | """ 135 | 136 | # current prices (at t) 137 | current_price = self.state[0] 138 | current_option_price = self.option_price_path[self.sim_episode, self.t] 139 | 140 | # current position 141 | current_position = self.state[1] 142 | 143 | # update time 144 | self.t = self.t + 1 145 | 146 | # get state for tomorrow (at t + 1) 147 | price = self.path[self.sim_episode, self.t] 148 | option_price = self.option_price_path[self.sim_episode, self.t] 149 | position = action 150 | ttm = self.ttm_array[self.t] 151 | 152 | self.state = [price, position, ttm] 153 | 154 | # calculate period reward (part 1) 155 | reward = (price - current_price) * position - np.abs(current_position - position) * current_price * self.spread 156 | 157 | # if tomorrow is end of episode 158 | if self.t == self.num_period - 1: 159 | done = True 160 | reward = reward - (max(price - self.strike_price, 0) - current_option_price) * self.num_contract * 100 - position * price * self.spread 161 | else: 162 | done = False 163 | reward = reward - (option_price - current_option_price) * self.num_contract * 100 164 | 165 | # for other info later 166 | info = {"path_row": self.sim_episode} 167 | 168 | return self.state, reward, done, info 169 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | # prioritized experience replay (PER) buffer 2 | # use the implementation from OpenAI's baselines 3 | # https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py 4 | 5 | import numpy as np 6 | import random 7 | 8 | from segment_tree import SumSegmentTree, MinSegmentTree 9 | 10 | 11 | class ReplayBuffer(object): 12 | def __init__(self, size): 13 | """Create Replay buffer. 14 | 15 | Parameters 16 | ---------- 17 | size: int 18 | Max number of transitions to store in the buffer. When the buffer 19 | overflows the old memories are dropped. 
20 | """ 21 | self._storage = [] 22 | self._maxsize = size 23 | self._next_idx = 0 24 | 25 | def __len__(self): 26 | return len(self._storage) 27 | 28 | def add(self, obs_t, action, reward, obs_tp1, done): 29 | data = (obs_t, action, reward, obs_tp1, done) 30 | 31 | if self._next_idx >= len(self._storage): 32 | self._storage.append(data) 33 | else: 34 | self._storage[self._next_idx] = data 35 | self._next_idx = (self._next_idx + 1) % self._maxsize 36 | 37 | def _encode_sample(self, idxes): 38 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 39 | for i in idxes: 40 | data = self._storage[i] 41 | obs_t, action, reward, obs_tp1, done = data 42 | obses_t.append(np.array(obs_t, copy=False)) 43 | actions.append(np.array(action, copy=False)) 44 | rewards.append(reward) 45 | obses_tp1.append(np.array(obs_tp1, copy=False)) 46 | dones.append(done) 47 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 48 | 49 | def sample(self, batch_size): 50 | """Sample a batch of experiences. 51 | 52 | Parameters 53 | ---------- 54 | batch_size: int 55 | How many transitions to sample. 56 | 57 | Returns 58 | ------- 59 | obs_batch: np.array 60 | batch of observations 61 | act_batch: np.array 62 | batch of actions executed given obs_batch 63 | rew_batch: np.array 64 | rewards received as results of executing act_batch 65 | next_obs_batch: np.array 66 | next set of observations seen after executing act_batch 67 | done_mask: np.array 68 | done_mask[i] = 1 if executing act_batch[i] resulted in 69 | the end of an episode and 0 otherwise. 70 | """ 71 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 72 | return self._encode_sample(idxes) 73 | 74 | 75 | class PrioritizedReplayBuffer(ReplayBuffer): 76 | def __init__(self, size, alpha): 77 | """Create Prioritized Replay buffer. 78 | 79 | Parameters 80 | ---------- 81 | size: int 82 | Max number of transitions to store in the buffer. When the buffer 83 | overflows the old memories are dropped. 84 | alpha: float 85 | how much prioritization is used 86 | (0 - no prioritization, 1 - full prioritization) 87 | 88 | See Also 89 | -------- 90 | ReplayBuffer.__init__ 91 | """ 92 | super(PrioritizedReplayBuffer, self).__init__(size) 93 | assert alpha >= 0 94 | self._alpha = alpha 95 | 96 | it_capacity = 1 97 | while it_capacity < size: 98 | it_capacity *= 2 99 | 100 | self._it_sum = SumSegmentTree(it_capacity) 101 | self._it_min = MinSegmentTree(it_capacity) 102 | self._max_priority = 1.0 103 | 104 | def add(self, *args, **kwargs): 105 | """See ReplayBuffer.store_effect""" 106 | idx = self._next_idx 107 | super().add(*args, **kwargs) 108 | self._it_sum[idx] = self._max_priority ** self._alpha 109 | self._it_min[idx] = self._max_priority ** self._alpha 110 | 111 | def _sample_proportional(self, batch_size): 112 | res = [] 113 | p_total = self._it_sum.sum(0, len(self._storage) - 1) 114 | every_range_len = p_total / batch_size 115 | for i in range(batch_size): 116 | mass = random.random() * every_range_len + i * every_range_len 117 | idx = self._it_sum.find_prefixsum_idx(mass) 118 | res.append(idx) 119 | return res 120 | 121 | def sample(self, batch_size, beta): 122 | """Sample a batch of experiences. 123 | 124 | compared to ReplayBuffer.sample 125 | it also returns importance weights and idxes 126 | of sampled experiences. 127 | 128 | 129 | Parameters 130 | ---------- 131 | batch_size: int 132 | How many transitions to sample. 
133 | beta: float 134 | To what degree to use importance weights 135 | (0 - no corrections, 1 - full correction) 136 | 137 | Returns 138 | ------- 139 | obs_batch: np.array 140 | batch of observations 141 | act_batch: np.array 142 | batch of actions executed given obs_batch 143 | rew_batch: np.array 144 | rewards received as results of executing act_batch 145 | next_obs_batch: np.array 146 | next set of observations seen after executing act_batch 147 | done_mask: np.array 148 | done_mask[i] = 1 if executing act_batch[i] resulted in 149 | the end of an episode and 0 otherwise. 150 | weights: np.array 151 | Array of shape (batch_size,) and dtype np.float32 152 | denoting importance weight of each sampled transition 153 | idxes: np.array 154 | Array of shape (batch_size,) and dtype np.int32 155 | idexes in buffer of sampled experiences 156 | """ 157 | assert beta > 0 158 | 159 | idxes = self._sample_proportional(batch_size) 160 | 161 | weights = [] 162 | p_min = self._it_min.min() / self._it_sum.sum() 163 | max_weight = (p_min * len(self._storage)) ** (-beta) 164 | 165 | for idx in idxes: 166 | p_sample = self._it_sum[idx] / self._it_sum.sum() 167 | weight = (p_sample * len(self._storage)) ** (-beta) 168 | weights.append(weight / max_weight) 169 | weights = np.array(weights) 170 | encoded_sample = self._encode_sample(idxes) 171 | return tuple(list(encoded_sample) + [weights, idxes]) 172 | 173 | def update_priorities(self, idxes, priorities): 174 | """Update priorities of sampled transitions. 175 | 176 | sets priority of transition at index idxes[i] in buffer 177 | to priorities[i]. 178 | 179 | Parameters 180 | ---------- 181 | idxes: [int] 182 | List of idxes of sampled transitions 183 | priorities: [float] 184 | List of updated priorities corresponding to 185 | transitions at the sampled idxes denoted by 186 | variable `idxes`. 187 | """ 188 | assert len(idxes) == len(priorities) 189 | for idx, priority in zip(idxes, priorities): 190 | assert priority > 0 191 | assert 0 <= idx < len(self._storage) 192 | self._it_sum[idx] = priority ** self._alpha 193 | self._it_min[idx] = priority ** self._alpha 194 | 195 | self._max_priority = max(self._max_priority, priority) 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Utility Functions """ 2 | import random 3 | 4 | import numpy as np 5 | from scipy.stats import norm 6 | 7 | random.seed(1) 8 | 9 | def brownian_sim(num_path, num_period, mu, std, init_p, dt): 10 | z = np.random.normal(size=(num_path, num_period)) 11 | 12 | a_price = np.zeros((num_path, num_period)) 13 | a_price[:, 0] = init_p 14 | 15 | for t in range(num_period - 1): 16 | a_price[:, t + 1] = a_price[:, t] * np.exp( 17 | (mu - (std ** 2) / 2) * dt + std * np.sqrt(dt) * z[:, t] 18 | ) 19 | return a_price 20 | 21 | 22 | # BSM Call Option Pricing Formula & BS Delta formula 23 | # T here is time to maturity 24 | def bs_call(iv, T, S, K, r, q): 25 | d1 = (np.log(S / K) + (r - q + iv * iv / 2) * T) / (iv * np.sqrt(T)) 26 | d2 = d1 - iv * np.sqrt(T) 27 | bs_price = S * np.exp(-q * T) * norm.cdf(d1) - K * np.exp(-r * T) * norm.cdf(d2) 28 | bs_delta = np.exp(-q * T) * norm.cdf(d1) 29 | return bs_price, bs_delta 30 | 31 | 32 | def get_sim_path(M, freq, np_seed, num_sim): 33 | """ Return simulated data: a tuple of three arrays 34 | M: initial time to maturity 35 | freq: trading freq in unit of day, e.g. 
freq=2: every 2 day; freq=0.5 twice a day; 36 | np_seed: numpy random seed 37 | num_sim: number of simulation path 38 | 39 | 1) asset price paths (num_path x num_period) 40 | 2) option price paths (num_path x num_period) 41 | 3) delta (num_path x num_period) 42 | """ 43 | # set the np random seed 44 | np.random.seed(np_seed) 45 | 46 | # Trading Freq per day; passed from function parameter 47 | # freq = 2 48 | 49 | # Annual Trading Day 50 | T = 250 51 | 52 | # Simulation Time Step 53 | dt = 0.004 * freq 54 | 55 | # Option Day to Maturity; passed from function parameter 56 | # M = 60 57 | 58 | # Number of period 59 | num_period = int(M / freq) 60 | 61 | # Number of simulations; passed from function parameter 62 | # num_sim = 1000000 63 | 64 | # Annual Return 65 | mu = 0.05 66 | 67 | # Annual Volatility 68 | vol = 0.2 69 | 70 | # Initial Asset Value 71 | S = 100 72 | 73 | # Option Strike Price 74 | K = 100 75 | 76 | # Annual Risk Free Rate 77 | r = 0 78 | 79 | # Annual Dividend 80 | q = 0 81 | 82 | # asset price 2-d array 83 | print("1. generate asset price paths") 84 | a_price = brownian_sim(num_sim, num_period + 1, mu, vol, S, dt) 85 | 86 | # time to maturity "rank 1" array: e.g. [M, M-1, ..., 0] 87 | ttm = np.arange(M, -freq, -freq) 88 | 89 | # BS price 2-d array and bs delta 2-d array 90 | print("2. generate BS price and delta") 91 | bs_price, bs_delta = bs_call(vol, ttm / T, a_price, K, r, q) 92 | 93 | print("simulation done!") 94 | 95 | return a_price, bs_price, bs_delta 96 | 97 | 98 | def sabr_sim(num_path, num_period, mu, std, init_p, dt, rho, beta, volvol): 99 | qs = np.random.normal(size=(num_path, num_period)) 100 | qi = np.random.normal(size=(num_path, num_period)) 101 | qv = rho * qs + np.sqrt(1 - rho * rho) * qi 102 | 103 | vol = np.zeros((num_path, num_period)) 104 | vol[:, 0] = std 105 | 106 | a_price = np.zeros((num_path, num_period)) 107 | a_price[:, 0] = init_p 108 | 109 | for t in range(num_period - 1): 110 | gvol = vol[:, t] * (a_price[:, t] ** (beta - 1)) 111 | a_price[:, t + 1] = a_price[:, t] * np.exp( 112 | (mu - (gvol ** 2) / 2) * dt + gvol * np.sqrt(dt) * qs[:, t] 113 | ) 114 | vol[:, t + 1] = vol[:, t] * np.exp( 115 | -volvol * volvol * 0.5 * dt + volvol * qv[:, t] * np.sqrt(dt) 116 | ) 117 | 118 | return a_price, vol 119 | 120 | 121 | def sabr_implied_vol(vol, T, S, K, r, q, beta, volvol, rho): 122 | 123 | F = S * np.exp((r - q) * T) 124 | x = (F * K) ** ((1 - beta) / 2) 125 | y = (1 - beta) * np.log(F / K) 126 | A = vol / (x * (1 + y * y / 24 + y * y * y * y / 1920)) 127 | B = 1 + T * ( 128 | ((1 - beta) ** 2) * (vol * vol) / (24 * x * x) 129 | + rho * beta * volvol * vol / (4 * x) 130 | + volvol * volvol * (2 - 3 * rho * rho) / 24 131 | ) 132 | Phi = (volvol * x / vol) * np.log(F / K) 133 | Chi = np.log((np.sqrt(1 - 2 * rho * Phi + Phi * Phi) + Phi - rho) / (1 - rho)) 134 | 135 | SABRIV = np.where(F == K, vol * B / (F ** (1 - beta)), A * B * Phi / Chi) 136 | 137 | return SABRIV 138 | 139 | 140 | def bartlett(sigma, T, S, K, r, q, ds, beta, volvol, rho): 141 | 142 | dsigma = ds * volvol * rho / (S ** beta) 143 | 144 | vol1 = sabr_implied_vol(sigma, T, S, K, r, q, beta, volvol, rho) 145 | vol2 = sabr_implied_vol(sigma + dsigma, T, S + ds, K, r, q, beta, volvol, rho) 146 | 147 | bs_price1, _ = bs_call(vol1, T, S, K, r, q) 148 | bs_price2, _ = bs_call(vol2, T, S+ds, K, r, q) 149 | 150 | b_delta = (bs_price2 - bs_price1) / ds 151 | 152 | return b_delta 153 | 154 | def bs_call(iv, T, S, K, r, q): 155 | d1 = (np.log(S / K) + (r - q + iv * iv / 2) * T) / (iv * 
np.sqrt(T)) 156 | d2 = d1 - iv * np.sqrt(T) 157 | bs_price = S * np.exp(-q * T) * norm.cdf(d1) - K * np.exp(-r * T) * norm.cdf(d2) 158 | bs_delta = np.exp(-q * T) * norm.cdf(d1) 159 | return bs_price, bs_delta 160 | 161 | 162 | def get_sim_path_sabr(M, freq, np_seed, num_sim): 163 | """ Return simulated data: a tuple of four arrays 164 | M: initial time to maturity 165 | freq: trading freq in unit of day, e.g. freq=2: every 2 day; freq=0.5 twice a day; 166 | np_seed: numpy random seed 167 | num_sim: number of simulation path 168 | 169 | 1) asset price paths (num_path x num_period) 170 | 2) option price paths (num_path x num_period) 171 | 3) bs delta (num_path x num_period) 172 | 4) bartlett delta (num_path x num_period) 173 | """ 174 | # set the np random seed 175 | np.random.seed(np_seed) 176 | 177 | # Trading Freq per day; passed from function parameter 178 | # freq = 2 179 | 180 | # Annual Trading Day 181 | T = 250 182 | 183 | # Simulation Time Step 184 | dt = 0.004 * freq 185 | 186 | # Option Day to Maturity; passed from function parameter 187 | # M = 60 188 | 189 | # Number of period 190 | num_period = int(M / freq) 191 | 192 | # Number of simulations; passed from function parameter 193 | # num_sim = 1000000 194 | 195 | # Annual Return 196 | mu = 0.05 197 | 198 | # Annual Volatility 199 | vol = 0.2 200 | 201 | # Initial Asset Value 202 | S = 100 203 | 204 | # Option Strike Price 205 | K = 100 206 | 207 | # Annual Risk Free Rate 208 | r = 0 209 | 210 | # Annual Dividend 211 | q = 0 212 | 213 | # SABR parameters 214 | beta = 1 215 | rho = -0.4 216 | volvol = 0.6 217 | ds = 0.001 218 | 219 | # asset price 2-d array; sabr_vol 220 | print("1. generate asset price paths (sabr)") 221 | a_price, sabr_vol = sabr_sim( 222 | num_sim, num_period + 1, mu, vol, S, dt, rho, beta, volvol 223 | ) 224 | 225 | # time to maturity "rank 1" array: e.g. [M, M-1, ..., 0] 226 | ttm = np.arange(M, -freq, -freq) 227 | 228 | # BS price 2-d array and bs delta 2-d array 229 | print("2. generate BS price, BS delta, and Bartlett delta") 230 | 231 | # sabr implied vol 232 | implied_vol = sabr_implied_vol( 233 | sabr_vol, ttm / T, a_price, K, r, q, beta, volvol, rho 234 | ) 235 | 236 | bs_price, bs_delta = bs_call(implied_vol, ttm / T, a_price, K, r, q) 237 | 238 | bartlett_delta = bartlett(sabr_vol, ttm / T, a_price, K, r, q, ds, beta, volvol, rho) 239 | 240 | print("simulation done!") 241 | 242 | return a_price, bs_price, bs_delta, bartlett_delta 243 | -------------------------------------------------------------------------------- /ddpg_per.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import tensorflow as tf 5 | 6 | from keras.layers import Input, Dense, Lambda, concatenate, BatchNormalization, Activation 7 | from keras.models import Model 8 | from keras.optimizers import Adam 9 | import keras.backend as K 10 | 11 | from drl import DRL 12 | from envs import TradingEnv 13 | from replay_buffer import PrioritizedReplayBuffer 14 | from schedules import LinearSchedule 15 | 16 | 17 | class DDPG(DRL): 18 | """ 19 | Deep Deterministic Policy Gradient 20 | """ 21 | 22 | def __init__(self, env): 23 | super(DDPG, self).__init__() 24 | 25 | self.sess = K.get_session() 26 | 27 | self.env = env 28 | self.upper_bound = self.env.action_space.high[0] 29 | self.lower_bound = self.env.action_space.low[0] 30 | 31 | # update rate for target model. 
32 | # for 2nd round training, use 0.000001 33 | self.TAU = 0.00001 34 | 35 | # learning rate for actor and critic 36 | # for 2nd round training, use 1e-5 37 | self.actor_lr = 1e-4 38 | self.critic_lr = 1e-4 39 | 40 | # risk averse constant 41 | self.ra_c = 1.5 42 | 43 | # actor: policy function 44 | # critic: Q functions; Q_ex, Q_ex2, and Q 45 | self.actor = self._build_actor(learning_rate=self.actor_lr) 46 | self.critic_Q_ex, self.critic_Q_ex2, self.critic_Q = self._build_critic(learning_rate=self.critic_lr) 47 | 48 | self.critic_Q.summary() 49 | 50 | # target networks for actor and three critics 51 | self.actor_hat = self._build_actor(learning_rate=self.actor_lr) 52 | self.actor_hat.set_weights(self.actor.get_weights()) 53 | 54 | self.critic_Q_ex_hat, self.critic_Q_ex2_hat, self.critic_Q_hat = self._build_critic(learning_rate=self.critic_lr) 55 | self.critic_Q_ex_hat.set_weights(self.critic_Q_ex.get_weights()) 56 | self.critic_Q_ex2_hat.set_weights(self.critic_Q_ex2.get_weights()) 57 | 58 | # epsilon of epsilon-greedy 59 | self.epsilon = 1.0 60 | 61 | # discount rate for epsilon 62 | self.epsilon_decay = 0.99994 63 | # self.epsilon_decay = 0.9994 64 | 65 | # min epsilon of epsilon-greedy. 66 | self.epsilon_min = 0.1 67 | 68 | # memory buffer for experience replay 69 | buffer_size = 600000 70 | prioritized_replay_alpha = 0.6 71 | self.replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha) 72 | 73 | prioritized_replay_beta0 = 0.4 74 | 75 | # need not be the same as training episode (see schedules.py) 76 | prioritized_replay_beta_iters = 50001 77 | 78 | self.beta_schedule = LinearSchedule(prioritized_replay_beta_iters, 79 | initial_p=prioritized_replay_beta0, 80 | final_p=1.0) 81 | 82 | # for numerical stabiligy 83 | self.prioritized_replay_eps = 1e-6 84 | 85 | self.t = None 86 | 87 | # memory sample batch size 88 | self.batch_size = 128 89 | 90 | # may use for 2nd round training 91 | # self.policy_noise = 5 92 | # self.noise_clip = 5 93 | 94 | # gradient function 95 | self.get_critic_grad = self.critic_gradient() 96 | self.actor_optimizer() 97 | 98 | def load(self, tag=""): 99 | """load two Qs for test""" 100 | if tag == "": 101 | actor_file = "model/ddpg_actor.h5" 102 | critic_Q_ex_file = "model/ddpg_critic_Q_ex.h5" 103 | critic_Q_ex2_file = "model/ddpg_critic_Q_ex2.h5" 104 | else: 105 | actor_file = "model/ddpg_actor_" + tag + ".h5" 106 | critic_Q_ex_file = "model/ddpg_critic_Q_ex_" + tag + ".h5" 107 | critic_Q_ex2_file = "model/ddpg_critic_Q_ex2_" + tag + ".h5" 108 | 109 | if os.path.exists(actor_file): 110 | self.actor.load_weights(actor_file) 111 | self.actor_hat.load_weights(actor_file) 112 | if os.path.exists(critic_Q_ex_file): 113 | self.critic_Q_ex.load_weights(critic_Q_ex_file) 114 | self.critic_Q_ex_hat.load_weights(critic_Q_ex_file) 115 | if os.path.exists(critic_Q_ex2_file): 116 | self.critic_Q_ex2.load_weights(critic_Q_ex2_file) 117 | self.critic_Q_ex2_hat.load_weights(critic_Q_ex2_file) 118 | 119 | def _build_actor(self, learning_rate=1e-3): 120 | """basic NN model. 
121 | """ 122 | inputs = Input(shape=(self.env.num_state,)) 123 | 124 | # bn after input 125 | x = BatchNormalization()(inputs) 126 | 127 | # bn after activation 128 | x = Dense(32, activation="relu")(x) 129 | x = BatchNormalization()(x) 130 | 131 | x = Dense(64, activation="relu")(x) 132 | x = BatchNormalization()(x) 133 | 134 | # no bn for output layer 135 | x = Dense(1, activation="sigmoid")(x) 136 | 137 | output = Lambda(lambda x: x * self.env.num_contract * 100)(x) 138 | 139 | model = Model(inputs=inputs, outputs=output) 140 | 141 | # compile the model using mse loss, but won't use mse to train 142 | model.compile(loss="mse", optimizer=Adam(learning_rate)) 143 | 144 | return model 145 | 146 | def _build_critic(self, learning_rate=1e-3): 147 | """basic NN model. 148 | """ 149 | # inputs 150 | s_inputs = Input(shape=(self.env.num_state,)) 151 | a_inputs = Input(shape=(1,)) 152 | 153 | # combine inputs 154 | x = concatenate([s_inputs, a_inputs]) 155 | 156 | # bn after input 157 | x = BatchNormalization()(x) 158 | 159 | # Q_ex network 160 | 161 | # bn after activation 162 | x1 = Dense(32, activation="relu")(x) 163 | x1 = BatchNormalization()(x1) 164 | 165 | x1 = Dense(64, activation="relu")(x1) 166 | x1 = BatchNormalization()(x1) 167 | 168 | # no bn for output layer 169 | output1 = Dense(1, activation="linear")(x1) 170 | 171 | model_Q_ex = Model(inputs=[s_inputs, a_inputs], outputs=output1) 172 | model_Q_ex.compile(loss="mse", optimizer=Adam(learning_rate)) 173 | 174 | # Q_ex2 network 175 | 176 | # bn after activation 177 | x2 = Dense(32, activation="relu")(x) 178 | x2 = BatchNormalization()(x2) 179 | 180 | # bn after activation 181 | x2 = Dense(64, activation="relu")(x2) 182 | x2 = BatchNormalization()(x2) 183 | 184 | # no bn for output layer 185 | output2 = Dense(1, activation="linear")(x2) 186 | 187 | model_Q_ex2 = Model(inputs=[s_inputs, a_inputs], outputs=output2) 188 | model_Q_ex2.compile(loss="mse", optimizer=Adam(learning_rate)) 189 | 190 | # Q 191 | output3 = Lambda(lambda o: o[0] - self.ra_c * K.sqrt(K.max(o[1] - o[0] * o[0], 0)))([output1, output2]) 192 | model_Q = Model(inputs=[s_inputs, a_inputs], outputs=output3) 193 | model_Q.compile(loss="mse", optimizer=Adam(learning_rate)) 194 | 195 | return model_Q_ex, model_Q_ex2, model_Q 196 | 197 | def actor_optimizer(self): 198 | """actor_optimizer. 199 | Returns: 200 | function, opt function for actor. 201 | """ 202 | self.ainput = self.actor.input 203 | aoutput = self.actor.output 204 | trainable_weights = self.actor.trainable_weights 205 | self.action_gradient = tf.placeholder(tf.float32, shape=(None, 1)) 206 | 207 | # tf.gradients calculates dy/dx with a initial gradients for y 208 | # action_gradient is dq/da, so this is dq/da * da/dparams 209 | params_grad = tf.gradients(aoutput, trainable_weights, -self.action_gradient) 210 | grads = zip(params_grad, trainable_weights) 211 | self.opt = tf.train.AdamOptimizer(self.actor_lr).apply_gradients(grads) 212 | self.sess.run(tf.global_variables_initializer()) 213 | 214 | def critic_gradient(self): 215 | """get critic gradient function. 216 | Returns: 217 | function, gradient function for critic. 218 | """ 219 | cinput = self.critic_Q.input 220 | coutput = self.critic_Q.output 221 | 222 | # compute the gradient of the action with q value, dq/da. 223 | action_grads = K.gradients(coutput, cinput[1]) 224 | 225 | return K.function([cinput[0], cinput[1]], action_grads) 226 | 227 | def egreedy_action(self, X): 228 | """get actor action with ou noise. 229 | Arguments: 230 | X: state value. 
231 | """ 232 | # do the epsilon greedy way; not using OU 233 | if np.random.rand() <= self.epsilon: 234 | action = env.action_space.sample() 235 | 236 | # may use for 2nd round training 237 | # action = self.actor.predict(X)[0][0] 238 | # noise = np.clip(np.random.normal(0, self.policy_noise), -self.noise_clip, self.noise_clip) 239 | # action = np.clip(action + noise, 0, self.env.num_contract * 100) 240 | else: 241 | action = self.actor.predict(X)[0][0] 242 | 243 | return action, None, None 244 | 245 | def update_epsilon(self): 246 | """update epsilon 247 | """ 248 | if self.epsilon >= self.epsilon_min: 249 | self.epsilon *= self.epsilon_decay 250 | 251 | def remember(self, state, action, reward, next_state, done): 252 | """add data to experience replay. 253 | Arguments: 254 | state: observation 255 | action: action 256 | reward: reward 257 | next_state: next_observation 258 | done: if game is done. 259 | """ 260 | self.replay_buffer.add(state, action, reward, next_state, done) 261 | 262 | def process_batch(self, batch_size): 263 | """process batch data 264 | Arguments: 265 | batch: batch size 266 | Returns: 267 | states: batch of states 268 | actions: batch of actions 269 | target_q_ex, target_q_ex2: batch of targets; 270 | weights: priority weights 271 | """ 272 | # prioritized sample from experience replay buffer 273 | experience = self.replay_buffer.sample(batch_size, beta=self.beta_schedule.value(self.t)) 274 | (states, actions, rewards, next_states, dones, weights, batch_idxes) = experience 275 | 276 | actions = actions.reshape(-1, 1) 277 | rewards = rewards.reshape(-1, 1) 278 | dones = dones.reshape(-1, 1) 279 | 280 | # get next_actions 281 | next_actions = self.actor_hat.predict(next_states) 282 | 283 | # prepare targets for Q_ex and Q_ex2 training 284 | q_ex_next = self.critic_Q_ex_hat.predict([next_states, next_actions]) 285 | q_ex2_next = self.critic_Q_ex2_hat.predict([next_states, next_actions]) 286 | 287 | target_q_ex = rewards + (1 - dones) * q_ex_next 288 | target_q_ex2 = rewards ** 2 + (1 - dones) * (2 * rewards * q_ex_next + q_ex2_next) 289 | 290 | # use Q2 TD error as priority weight 291 | td_errors = self.critic_Q_ex2.predict([states, actions]) - target_q_ex2 292 | new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps).flatten() 293 | self.replay_buffer.update_priorities(batch_idxes, new_priorities) 294 | 295 | return states, actions, target_q_ex, target_q_ex2, weights 296 | 297 | def update_model(self, X1, X2, y1, y2, weights): 298 | """update ddpg model. 299 | Arguments: 300 | X1: states 301 | X2: actions 302 | y1: target for Q_ex 303 | y2: target for Q_ex2 304 | weights: priority weights 305 | Returns: 306 | loss_ex: critic Q_ex loss 307 | loss_ex2: critic Q_ex2 loss 308 | """ 309 | # flatten to prepare for training with weights 310 | weights = weights.flatten() 311 | 312 | # default batch size is 32 313 | loss_ex = self.critic_Q_ex.fit([X1, X2], y1, sample_weight=weights, verbose=0) 314 | loss_ex = np.mean(loss_ex.history['loss']) 315 | 316 | # default batch size is 32 317 | loss_ex2 = self.critic_Q_ex2.fit([X1, X2], y2, sample_weight=weights, verbose=0) 318 | loss_ex2 = np.mean(loss_ex2.history['loss']) 319 | 320 | X3 = self.actor.predict(X1) 321 | 322 | a_grads = np.array(self.get_critic_grad([X1, X3]))[0] 323 | self.sess.run(self.opt, feed_dict={ 324 | self.ainput: X1, 325 | self.action_gradient: a_grads 326 | }) 327 | 328 | return loss_ex, loss_ex2 329 | 330 | def update_target_model(self): 331 | """soft update target model. 
332 | """ 333 | critic_Q_ex_weights = self.critic_Q_ex.get_weights() 334 | critic_Q_ex2_weights = self.critic_Q_ex2.get_weights() 335 | actor_weights = self.actor.get_weights() 336 | 337 | critic_Q_ex_hat_weights = self.critic_Q_ex_hat.get_weights() 338 | critic_Q_ex2_hat_weights = self.critic_Q_ex2_hat.get_weights() 339 | actor_hat_weights = self.actor_hat.get_weights() 340 | 341 | for i in range(len(critic_Q_ex_weights)): 342 | critic_Q_ex_hat_weights[i] = self.TAU * critic_Q_ex_weights[i] + (1 - self.TAU) * critic_Q_ex_hat_weights[i] 343 | 344 | for i in range(len(critic_Q_ex2_weights)): 345 | critic_Q_ex2_hat_weights[i] = self.TAU * critic_Q_ex2_weights[i] + (1 - self.TAU) * critic_Q_ex2_hat_weights[i] 346 | 347 | for i in range(len(actor_weights)): 348 | actor_hat_weights[i] = self.TAU * actor_weights[i] + (1 - self.TAU) * actor_hat_weights[i] 349 | 350 | self.critic_Q_ex_hat.set_weights(critic_Q_ex_hat_weights) 351 | self.critic_Q_ex2_hat.set_weights(critic_Q_ex2_hat_weights) 352 | self.actor_hat.set_weights(actor_hat_weights) 353 | 354 | def train(self, episode): 355 | """training 356 | Arguments: 357 | episode: total episodes to run 358 | 359 | Returns: 360 | history: training history 361 | """ 362 | 363 | # some statistics 364 | history = {"episode": [], "episode_w_T": [], "loss_ex": [], "loss_ex2": []} 365 | 366 | for i in range(episode): 367 | observation = self.env.reset() 368 | done = False 369 | 370 | # for recording purpose 371 | y_action = np.empty(0, dtype=int) 372 | reward_store = np.empty(0) 373 | 374 | self.t = i 375 | 376 | # steps in an episode 377 | while not done: 378 | 379 | # prepare state 380 | x = np.array(observation).reshape(1, -1) 381 | 382 | # chocie action from epsilon-greedy. 383 | action, _, _ = self.egreedy_action(x) 384 | 385 | # one step 386 | observation, reward, done, info = self.env.step(action) 387 | 388 | # record action and reward 389 | y_action = np.append(y_action, action) 390 | reward_store = np.append(reward_store, reward) 391 | 392 | # store to memory 393 | self.remember(x[0], action, reward, observation, done) 394 | 395 | if len(self.replay_buffer) > self.batch_size: 396 | 397 | # draw from memory 398 | X1, X2, y_ex, y_ex2, weights = self.process_batch(self.batch_size) 399 | 400 | # update model 401 | loss_ex, loss_ex2 = self.update_model(X1, X2, y_ex, y_ex2, weights) 402 | 403 | # soft update target 404 | self.update_target_model() 405 | 406 | # reduce epsilon per episode 407 | self.update_epsilon() 408 | 409 | # print/store some statistics every 1000 episodes 410 | if i % 1000 == 0 and i != 0: 411 | 412 | # may want to print/store some statistics every 100 episodes 413 | # if i % 100 == 0 and i >= 1000: 414 | 415 | # get w_T for statistics 416 | w_T = np.sum(reward_store) 417 | 418 | history["episode"].append(i) 419 | history["episode_w_T"].append(w_T) 420 | history["loss_ex"].append(loss_ex) 421 | history["loss_ex2"].append(loss_ex2) 422 | 423 | path_row = info["path_row"] 424 | print(info) 425 | print( 426 | "episode: {} | episode final wealth: {:.3f} | loss_ex: {:.3f} | loss_ex2: {:.3f} | epsilon:{:.2f}".format( 427 | i, w_T, loss_ex, loss_ex2, self.epsilon 428 | ) 429 | ) 430 | 431 | with np.printoptions(precision=2, suppress=True): 432 | print("episode: {} | rewards {}".format(i, reward_store)) 433 | print("episode: {} | actions taken {}".format(i, y_action)) 434 | print("episode: {} | deltas {}".format(i, self.env.delta_path[path_row] * 100)) 435 | print("episode: {} | stock price {}".format(i, self.env.path[path_row])) 436 | 
print("episode: {} | option price {}\n".format(i, self.env.option_price_path[path_row] * 100)) 437 | 438 | # may want to save model every 100 episode 439 | # if i % 100 == 0: 440 | # self.actor.save_weights("model/ddpg_actor_" + str(int(i/100)) + ".h5") 441 | # self.critic_Q_ex.save_weights("model/ddpg_critic_Q_ex_" + str(int(i/100)) + ".h5") 442 | # self.critic_Q_ex2.save_weights("model/ddpg_critic_Q_ex2_" + str(int(i/100)) + ".h5") 443 | self.actor.save_weights("model/ddpg_actor_" + str(int(i/1000)) + ".h5") 444 | self.critic_Q_ex.save_weights("model/ddpg_critic_Q_ex_" + str(int(i/1000)) + ".h5") 445 | self.critic_Q_ex2.save_weights("model/ddpg_critic_Q_ex2_" + str(int(i/1000)) + ".h5") 446 | 447 | # save weights once training is done 448 | self.actor.save_weights("model/ddpg_actor.h5") 449 | self.critic_Q_ex.save_weights("model/ddpg_critic_Q_ex.h5") 450 | self.critic_Q_ex2.save_weights("model/ddpg_critic_Q_ex2.h5") 451 | 452 | return history 453 | 454 | if __name__ == "__main__": 455 | 456 | # disable GPU 457 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 458 | 459 | # setup for training 460 | # use init_ttm, spread and other arguments to train for different scenarios 461 | env = TradingEnv(continuous_action_flag=True, sabr_flag=True, dg_random_seed=1, init_ttm=20, trade_freq=1, spread=0.01, num_contract=1, num_sim=50002) 462 | ddpg = DDPG(env) 463 | 464 | # for second round training, specify the tag of weights to load 465 | # ddpg.load(tag="50") 466 | 467 | # for second round training, may want to start with a specific value of epsilon 468 | # ddpg.epsilon = 0.1 469 | 470 | # episode for training: 0 to 50000 inclusive 471 | # cycle through available data paths if number of episode for training > number of sim paths 472 | history = ddpg.train(50001) 473 | ddpg.save_history(history, "ddpg.csv") 474 | 475 | # setup for testing; use another instance for testing 476 | env_test = TradingEnv(continuous_action_flag=True, sabr_flag=True, dg_random_seed=2, init_ttm=20, trade_freq=1, spread=0.01, num_contract=1, num_sim=100001) 477 | ddpg_test = DDPG(env_test) 478 | ddpg_test.load() 479 | 480 | # episode for testing: 0 to 100000 inclusive 481 | ddpg_test.test(100001) --------------------------------------------------------------------------------