├── utility.py ├── LICENSE ├── rank_based_test.py ├── binary_heap_test.py ├── README.md ├── rank_based.py └── binary_heap.py /utility.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding=utf-8 -*- 3 | # author: Ian 4 | # e-mail: stmayue@gmail.com 5 | # description: 6 | 7 | 8 | def list_to_dict(in_list): 9 | return dict((i, in_list[i]) for i in range(0, len(in_list))) 10 | 11 | 12 | def exchange_key_value(in_dict): 13 | return dict((in_dict[i], i) for i in in_dict) 14 | 15 | 16 | def main(): 17 | pass 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Damcy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rank_based_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding=utf-8 -*- 3 | # author: Ian 4 | # e-mail: stmayue@gmail.com 5 | # description: 6 | 7 | import rank_based 8 | 9 | 10 | def test(): 11 | conf = {'size': 50, 12 | 'learn_start': 10, 13 | 'partition_num': 5, 14 | 'total_step': 100, 15 | 'batch_size': 4} 16 | experience = rank_based.Experience(conf) 17 | 18 | # insert to experience 19 | print('test insert experience') 20 | for i in range(1, 51): 21 | # tuple, like(state_t, a, r, state_t_1, t) 22 | to_insert = (i, 1, 1, i, 1) 23 | experience.store(to_insert) 24 | print(experience.priority_queue) 25 | print(experience._experience[1]) 26 | print(experience._experience[2]) 27 | print('test replace') 28 | to_insert = (51, 1, 1, 51, 1) 29 | experience.store(to_insert) 30 | print(experience.priority_queue) 31 | print(experience._experience[1]) 32 | print(experience._experience[2]) 33 | 34 | # sample 35 | print('test sample') 36 | sample, w, e_id = experience.sample(51) 37 | print(sample) 38 | print(w) 39 | print(e_id) 40 | 41 | # update delta to priority 42 | print('test update delta') 43 | delta = [v for v in range(1, 5)] 44 | experience.update_priority(e_id, delta) 45 | print(experience.priority_queue) 46 | sample, w, e_id = experience.sample(51) 47 | print(sample) 48 | print(w) 49 | print(e_id) 50 | 51 | # rebalance 52 | print('test rebalance') 53 | experience.rebalance() 54 | print(experience.priority_queue) 55 | 56 | 57 | def main(): 58 | test() 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | 64 | -------------------------------------------------------------------------------- /binary_heap_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding=utf-8 -*- 3 | # author: Ian 4 | # e-mail: stmayue@gmail.com 5 | # description: 6 | 7 | 8 | import binary_heap 9 | 10 | 11 | def test(): 12 | test_data = [(1, 0), (0.9, 1), (1.1, 2), (1.1, 3), (3.3, 4), (0, 5), (0.93, 6)] 13 | BH = binary_heap.BinaryHeap(7) 14 | 15 | # empty queue 16 | print('empty queue') 17 | print(BH) 18 | print('max priority') 19 | print(BH.get_max_priority()) 20 | 21 | # insert 22 | print('\ntest insert') 23 | for p, i in test_data: 24 | BH.update(p, i) 25 | print(BH) 26 | print(BH.p2e) 27 | print(BH.e2p) 28 | 29 | # update 30 | print('\ntest update') 31 | BH.update(9.9, 0) 32 | print(BH) 33 | print(BH.p2e) 34 | print(BH.e2p) 35 | 36 | # get max priority 37 | print('\nmax priority') 38 | print(BH.get_max_priority()) 39 | 40 | # re balance 41 | print('\ntest re_balance') 42 | BH.balance_tree() 43 | print(BH) 44 | print(BH.p2e) 45 | print(BH.e2p) 46 | 47 | # full insert 48 | print('\ntest full insert') 49 | BH.update(9.2, 7) 50 | print(BH) 51 | 52 | # pop 53 | print('\ntest pop') 54 | p, i = BH.pop() 55 | print('pop out: ', p, i) 56 | print(BH) 57 | print(BH.p2e) 58 | print(BH.e2p) 59 | 60 | # get priority 61 | print('\ntest get priority') 62 | print(BH.get_priority()) 63 | 64 | # get e id 65 | print('\ntest e id') 66 | print(BH.get_e_id()) 67 | 68 | # p id to e id 69 | print('\ntest p id to e id') 70 | print(BH.priority_to_experience([2, 3, 6])) 71 | 72 | 73 | def main(): 74 | test() 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prioritized Experience Replay 2 | 3 | ### Usage 4 | 1. in rank_base.py Experience.stroe give a simple description of store replay memory, or you can also refer rank_base_test.py 5 | 2. It's more convenient to store replay as format (state_1, action_1, reward, state_2, terminal). If we use this method, all replay memory in Experience are legal and can be sampled as we like. 6 | 3. run it with python3/python2.7 7 | 8 | ### Rank-based 9 | use binary heap tree as priority queue, and build an Experience class to store and retrieve the sample 10 | 11 | Interface: 12 | * All interfaces are in rank_based.py 13 | * init conf, please read Experience.__init__ for more detail, all parameters can be set by input conf 14 | * replay sample store: Experience.store 15 | params: [in] experience, sample to store 16 | returns: bools, True for success, False for failed 17 | * replay sample sample: Experience.sample 18 | params: [in] global_step, used for cal beta 19 | returns: 20 | experience, list of samples 21 | w, list of weight 22 | rank_e_id, list of experience's id, used for update priority value 23 | * update priority value: Experience.update 24 | params: 25 | [in] indices, rank_e_ids 26 | [in] delta, new TD-error 27 | 28 | ### Proportional 29 | you can find the implementation here: [proportional](https://github.com/takoika/PrioritizedExperienceReplay) 30 | 31 | ### Reference 32 | 1. "Prioritized Experience Replay" http://arxiv.org/abs/1511.05952 33 | 2. [Atari](https://github.com/Kaixhin/Atari) by @Kaixhin, Atari uses torch to implement rank-based algorithm. 34 | 35 | ### Application 36 | 1. TEST1 PASSED: These code has been applied to my own NLP DQN experiment, it significantly improves performance. See [here](https://github.com/Damcy/cascadeLSTMDRL) for more detail. 37 | -------------------------------------------------------------------------------- /rank_based.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding=utf-8 -*- 3 | # author: Ian 4 | # e-mail: stmayue@gmail.com 5 | # description: 6 | 7 | import sys 8 | import math 9 | import random 10 | import numpy as np 11 | 12 | import binary_heap 13 | 14 | 15 | class Experience(object): 16 | 17 | def __init__(self, conf): 18 | self.size = conf['size'] 19 | self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True 20 | self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size 21 | 22 | self.alpha = conf['alpha'] if 'alpha' in conf else 0.7 23 | self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5 24 | self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32 25 | self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000 26 | self.total_steps = conf['steps'] if 'steps' in conf else 100000 27 | # partition number N, split total size to N part 28 | self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100 29 | 30 | self.index = 0 31 | self.record_size = 0 32 | self.isFull = False 33 | 34 | self._experience = {} 35 | self.priority_queue = binary_heap.BinaryHeap(self.priority_size) 36 | self.distributions = self.build_distributions() 37 | 38 | self.beta_grad = (1 - self.beta_zero) / float(self.total_steps - self.learn_start) 39 | 40 | def build_distributions(self): 41 | """ 42 | preprocess pow of rank 43 | (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha)) 44 | :return: distributions, dict 45 | """ 46 | res = {} 47 | n_partitions = self.partition_num 48 | partition_num = 1 49 | # each part size 50 | partition_size = int(math.floor(self.size / n_partitions)) 51 | 52 | for n in range(partition_size, self.size + 1, partition_size): 53 | if self.learn_start <= n <= self.priority_size: 54 | distribution = {} 55 | # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha)) 56 | pdf = list( 57 | map(lambda x: math.pow(x, -self.alpha), range(1, n + 1)) 58 | ) 59 | pdf_sum = math.fsum(pdf) 60 | distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf)) 61 | # split to k segment, and than uniform sample in each k 62 | # set k = batch_size, each segment has total probability is 1 / batch_size 63 | # strata_ends keep each segment start pos and end pos 64 | cdf = np.cumsum(distribution['pdf']) 65 | strata_ends = {1: 0, self.batch_size + 1: n} 66 | step = 1 / float(self.batch_size) 67 | index = 1 68 | for s in range(2, self.batch_size + 1): 69 | while cdf[index] < step: 70 | index += 1 71 | strata_ends[s] = index 72 | step += 1 / float(self.batch_size) 73 | 74 | distribution['strata_ends'] = strata_ends 75 | 76 | res[partition_num] = distribution 77 | 78 | partition_num += 1 79 | 80 | return res 81 | 82 | def fix_index(self): 83 | """ 84 | get next insert index 85 | :return: index, int 86 | """ 87 | if self.record_size <= self.size: 88 | self.record_size += 1 89 | if self.index % self.size == 0: 90 | self.isFull = True if len(self._experience) == self.size else False 91 | if self.replace_flag: 92 | self.index = 1 93 | return self.index 94 | else: 95 | sys.stderr.write('Experience replay buff is full and replace is set to FALSE!\n') 96 | return -1 97 | else: 98 | self.index += 1 99 | return self.index 100 | 101 | def store(self, experience): 102 | """ 103 | store experience, suggest that experience is a tuple of (s1, a, r, s2, t) 104 | so each experience is valid 105 | :param experience: maybe a tuple, or list 106 | :return: bool, indicate insert status 107 | """ 108 | insert_index = self.fix_index() 109 | if insert_index > 0: 110 | if insert_index in self._experience: 111 | del self._experience[insert_index] 112 | self._experience[insert_index] = experience 113 | # add to priority queue 114 | priority = self.priority_queue.get_max_priority() 115 | self.priority_queue.update(priority, insert_index) 116 | return True 117 | else: 118 | sys.stderr.write('Insert failed\n') 119 | return False 120 | 121 | def retrieve(self, indices): 122 | """ 123 | get experience from indices 124 | :param indices: list of experience id 125 | :return: experience replay sample 126 | """ 127 | return [self._experience[v] for v in indices] 128 | 129 | def rebalance(self): 130 | """ 131 | rebalance priority queue 132 | :return: None 133 | """ 134 | self.priority_queue.balance_tree() 135 | 136 | def update_priority(self, indices, delta): 137 | """ 138 | update priority according indices and deltas 139 | :param indices: list of experience id 140 | :param delta: list of delta, order correspond to indices 141 | :return: None 142 | """ 143 | for i in range(0, len(indices)): 144 | self.priority_queue.update(math.fabs(delta[i]), indices[i]) 145 | 146 | def sample(self, global_step): 147 | """ 148 | sample a mini batch from experience replay 149 | :param global_step: now training step 150 | :return: experience, list, samples 151 | :return: w, list, weights 152 | :return: rank_e_id, list, samples id, used for update priority 153 | """ 154 | if self.record_size < self.learn_start: 155 | sys.stderr.write('Record size less than learn start! Sample failed\n') 156 | return False, False, False 157 | 158 | dist_index = math.floor(self.record_size / self.size * self.partition_num) 159 | # issue 1 by @camigord 160 | partition_size = math.floor(self.size / self.partition_num) 161 | partition_max = dist_index * partition_size 162 | distribution = self.distributions[dist_index] 163 | rank_list = [] 164 | # sample from k segments 165 | for n in range(1, self.batch_size + 1): 166 | index = random.randint(distribution['strata_ends'][n] + 1, 167 | distribution['strata_ends'][n + 1]) 168 | rank_list.append(index) 169 | 170 | # beta, increase by global_step, max 1 171 | beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1) 172 | # find all alpha pow, notice that pdf is a list, start from 0 173 | alpha_pow = [distribution['pdf'][v - 1] for v in rank_list] 174 | # w = (N * P(i)) ^ (-beta) / max w 175 | w = np.power(np.array(alpha_pow) * partition_max, -beta) 176 | w_max = max(w) 177 | w = np.divide(w, w_max) 178 | # rank list is priority id 179 | # convert to experience id 180 | rank_e_id = self.priority_queue.priority_to_experience(rank_list) 181 | # get experience id according rank_e_id 182 | experience = self.retrieve(rank_e_id) 183 | return experience, w, rank_e_id 184 | -------------------------------------------------------------------------------- /binary_heap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding=utf-8 -*- 3 | # author: Ian 4 | # e-mail: stmayue@gmail.com 5 | # description: 6 | 7 | import sys 8 | import math 9 | 10 | import utility 11 | 12 | 13 | class BinaryHeap(object): 14 | 15 | def __init__(self, priority_size=100, priority_init=None, replace=True): 16 | self.e2p = {} 17 | self.p2e = {} 18 | self.replace = replace 19 | 20 | if priority_init is None: 21 | self.priority_queue = {} 22 | self.size = 0 23 | self.max_size = priority_size 24 | else: 25 | # not yet test 26 | self.priority_queue = priority_init 27 | self.size = len(self.priority_queue) 28 | self.max_size = None or self.size 29 | 30 | experience_list = list(map(lambda x: self.priority_queue[x], self.priority_queue)) 31 | self.p2e = utility.list_to_dict(experience_list) 32 | self.e2p = utility.exchange_key_value(self.p2e) 33 | for i in range(int(self.size / 2), -1, -1): 34 | self.down_heap(i) 35 | 36 | def __repr__(self): 37 | """ 38 | :return: string of the priority queue, with level info 39 | """ 40 | if self.size == 0: 41 | return 'No element in heap!' 42 | to_string = '' 43 | level = -1 44 | max_level = int(math.floor(math.log(self.size, 2))) 45 | 46 | for i in range(1, self.size + 1): 47 | now_level = int(math.floor(math.log(i, 2))) 48 | if level != now_level: 49 | to_string = to_string + ('\n' if level != -1 else '') \ 50 | + ' ' * (max_level - now_level) 51 | level = now_level 52 | 53 | to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level) 54 | 55 | return to_string 56 | 57 | def check_full(self): 58 | return self.size > self.max_size 59 | 60 | def _insert(self, priority, e_id): 61 | """ 62 | insert new experience id with priority 63 | (maybe don't need get_max_priority and implement it in this function) 64 | :param priority: priority value 65 | :param e_id: experience id 66 | :return: bool 67 | """ 68 | self.size += 1 69 | 70 | if self.check_full() and not self.replace: 71 | sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority)) 72 | return False 73 | else: 74 | self.size = min(self.size, self.max_size) 75 | 76 | self.priority_queue[self.size] = (priority, e_id) 77 | self.p2e[self.size] = e_id 78 | self.e2p[e_id] = self.size 79 | 80 | self.up_heap(self.size) 81 | return True 82 | 83 | def update(self, priority, e_id): 84 | """ 85 | update priority value according its experience id 86 | :param priority: new priority value 87 | :param e_id: experience id 88 | :return: bool 89 | """ 90 | if e_id in self.e2p: 91 | p_id = self.e2p[e_id] 92 | self.priority_queue[p_id] = (priority, e_id) 93 | self.p2e[p_id] = e_id 94 | 95 | self.down_heap(p_id) 96 | self.up_heap(p_id) 97 | return True 98 | else: 99 | # this e id is new, do insert 100 | return self._insert(priority, e_id) 101 | 102 | def get_max_priority(self): 103 | """ 104 | get max priority, if no experience, return 1 105 | :return: max priority if size > 0 else 1 106 | """ 107 | if self.size > 0: 108 | return self.priority_queue[1][0] 109 | else: 110 | return 1 111 | 112 | def pop(self): 113 | """ 114 | pop out the max priority value with its experience id 115 | :return: priority value & experience id 116 | """ 117 | if self.size == 0: 118 | sys.stderr.write('Error: no value in heap, pop failed\n') 119 | return False, False 120 | 121 | pop_priority, pop_e_id = self.priority_queue[1] 122 | self.e2p[pop_e_id] = -1 123 | # replace first 124 | last_priority, last_e_id = self.priority_queue[self.size] 125 | self.priority_queue[1] = (last_priority, last_e_id) 126 | self.size -= 1 127 | self.e2p[last_e_id] = 1 128 | self.p2e[1] = last_e_id 129 | 130 | self.down_heap(1) 131 | 132 | return pop_priority, pop_e_id 133 | 134 | def up_heap(self, i): 135 | """ 136 | upward balance 137 | :param i: tree node i 138 | :return: None 139 | """ 140 | if i > 1: 141 | parent = math.floor(i / 2) 142 | if self.priority_queue[parent][0] < self.priority_queue[i][0]: 143 | tmp = self.priority_queue[i] 144 | self.priority_queue[i] = self.priority_queue[parent] 145 | self.priority_queue[parent] = tmp 146 | # change e2p & p2e 147 | self.e2p[self.priority_queue[i][1]] = i 148 | self.e2p[self.priority_queue[parent][1]] = parent 149 | self.p2e[i] = self.priority_queue[i][1] 150 | self.p2e[parent] = self.priority_queue[parent][1] 151 | # up heap parent 152 | self.up_heap(parent) 153 | 154 | def down_heap(self, i): 155 | """ 156 | downward balance 157 | :param i: tree node i 158 | :return: None 159 | """ 160 | if i < self.size: 161 | greatest = i 162 | left, right = i * 2, i * 2 + 1 163 | if left < self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]: 164 | greatest = left 165 | if right < self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]: 166 | greatest = right 167 | 168 | if greatest != i: 169 | tmp = self.priority_queue[i] 170 | self.priority_queue[i] = self.priority_queue[greatest] 171 | self.priority_queue[greatest] = tmp 172 | # change e2p & p2e 173 | self.e2p[self.priority_queue[i][1]] = i 174 | self.e2p[self.priority_queue[greatest][1]] = greatest 175 | self.p2e[i] = self.priority_queue[i][1] 176 | self.p2e[greatest] = self.priority_queue[greatest][1] 177 | # down heap greatest 178 | self.down_heap(greatest) 179 | 180 | def get_priority(self): 181 | """ 182 | get all priority value 183 | :return: list of priority 184 | """ 185 | return list(map(lambda x: x[0], self.priority_queue.values()))[0:self.size] 186 | 187 | def get_e_id(self): 188 | """ 189 | get all experience id in priority queue 190 | :return: list of experience ids order by their priority 191 | """ 192 | return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size] 193 | 194 | def balance_tree(self): 195 | """ 196 | rebalance priority queue 197 | :return: None 198 | """ 199 | sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True) 200 | # reconstruct priority_queue 201 | self.priority_queue.clear() 202 | self.p2e.clear() 203 | self.e2p.clear() 204 | cnt = 1 205 | while cnt <= self.size: 206 | priority, e_id = sort_array[cnt - 1] 207 | self.priority_queue[cnt] = (priority, e_id) 208 | self.p2e[cnt] = e_id 209 | self.e2p[e_id] = cnt 210 | cnt += 1 211 | # sort the heap 212 | for i in range(int(math.floor(self.size / 2)), 1, -1): 213 | self.down_heap(i) 214 | 215 | def priority_to_experience(self, priority_ids): 216 | """ 217 | retrieve experience ids by priority ids 218 | :param priority_ids: list of priority id 219 | :return: list of experience id 220 | """ 221 | return [self.p2e[i] for i in priority_ids] 222 | --------------------------------------------------------------------------------