├── __init__.py ├── Common ├── __init__.py ├── __pycache__ │ ├── SumTree.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── ValueCaculator.cpython-36.pyc │ └── prioritized_memory.cpython-36.pyc ├── SumTree.py ├── prioritized_memory.py └── ValueCaculator.py ├── DQNfromDemo ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── DQfD.cpython-37.pyc │ └── __init__.cpython-37.pyc ├── Test │ └── CartPole.py └── DQfD.py ├── DQNwithNoisyNet ├── __init__.py ├── __init__.pyc ├── Test │ ├── CartPoleExpert.txt │ ├── GenerateDemoData.py │ ├── MountainCar.py │ └── CartPole.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── NoisyLayer.cpython-36.pyc │ ├── NoisyLayer.cpython-37.pyc │ ├── DQN_NoisyNet.cpython-36.pyc │ └── DQN_NoisyNet.cpython-37.pyc ├── NoisyLayer.py └── DQN_NoisyNet.py ├── .gitignore ├── __init__.pyc ├── __pycache__ └── __init__.cpython-36.pyc ├── .idea ├── vcs.xml ├── misc.xml ├── modules.xml ├── DRL-using-PyTorch.iml └── workspace.xml ├── test.py └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DQNfromDemo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DQNwithNoisyNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | *.txt 3 | *.swp 4 | test.p 5 | *.py 6 | /__pycache__ 7 | -------------------------------------------------------------------------------- /__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/__init__.pyc -------------------------------------------------------------------------------- /DQNfromDemo/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNfromDemo/__init__.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__init__.pyc -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/Test/CartPoleExpert.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/Test/CartPoleExpert.txt -------------------------------------------------------------------------------- /Common/__pycache__/SumTree.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/Common/__pycache__/SumTree.cpython-36.pyc -------------------------------------------------------------------------------- /Common/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/Common/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DQNfromDemo/__pycache__/DQfD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNfromDemo/__pycache__/DQfD.cpython-37.pyc -------------------------------------------------------------------------------- /Common/__pycache__/ValueCaculator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/Common/__pycache__/ValueCaculator.cpython-36.pyc -------------------------------------------------------------------------------- /DQNfromDemo/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNfromDemo/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /Common/__pycache__/prioritized_memory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/Common/__pycache__/prioritized_memory.cpython-36.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/NoisyLayer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/NoisyLayer.cpython-36.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/NoisyLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/NoisyLayer.cpython-37.pyc -------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/DQN_NoisyNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/DQN_NoisyNet.cpython-36.pyc 
-------------------------------------------------------------------------------- /DQNwithNoisyNet/__pycache__/DQN_NoisyNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LilTwo/DRL-using-PyTorch/HEAD/DQNwithNoisyNet/__pycache__/DQN_NoisyNet.cpython-37.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/DRL-using-PyTorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from struct import unpack 3 | import Qnet 4 | import torch 5 | from trainWithMcts import testResult 6 | 7 | ''' 8 | net = Qnet.Net3() 9 | net.load_state_dict(torch.load("ddz_optimal_34.txt")) 10 | 11 | data = net.toNumpy() 12 | for name,layer in data.items(): 13 | np.save(name,layer) 14 | 15 | sa=np.arange(75) 16 | s = torch.Tensor(sa[:60]) 17 | a = torch.Tensor(sa[60:]) 18 | print(sa) 19 | print(net.fc1_s(s)+net.fc1_a(a)) 20 | ''' 21 | 22 | 23 | test = [[[-122.45939656328747, 37.796690447896445], [-122.45859061899071, 37.785810199890264], [-122.44198816647757, 37.786535549757346], [-122.43578239539256, 37.789920515803715], [-122.42828711343275, 37.77444638530603]]] 24 | result = [str(coor).strip('[]') for coor in test[0]] 25 | result = " | ".join(result) 26 | print(result) -------------------------------------------------------------------------------- /DQNwithNoisyNet/Test/GenerateDemoData.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import sys 3 | parent_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__)))) 4 | if parent_dir not in sys.path: 5 | sys.path.append(parent_dir) 6 | 7 | import gym 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from DQNwithNoisyNet.NoisyLayer import NoisyLinear 12 | from DQNwithNoisyNet import DQN_NoisyNet 13 | import json 14 | 15 | 16 | class NoisyNet2(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.fc1 = NoisyLinear(4, 40) 20 | self.fc2 = NoisyLinear(40, 2) 21 | 22 | def forward(self, s): 23 | x = self.fc1(s) 24 | x = F.relu(x) 25 | x = self.fc2(x) 26 | return x 27 | 28 | def sample(self): 29 | for layer in self.children(): 30 | if hasattr(layer, "sample"): 31 | layer.sample() 32 | 33 | 34 | if __name__ == "__main__": 35 | env = gym.make('CartPole-v1') 36 | s = env.reset() 37 | A = [[0], [1]] 38 | dqn = DQN_NoisyNet.DeepQLv2(NoisyNet2, noisy=True, lr=0.002, gamma=1, actionFinder=lambda x: A) 39 | dqn.net.load_state_dict(torch.load("./CartPoleExpert.txt")) 40 | total = 0 41 | s = env.reset() 42 | epoch = 3 43 | step = 0 44 | demo = {} 45 | for e in range(epoch): 46 
| data=[] 47 | while True: 48 | a = dqn.act(s)[0] 49 | s_, r, done, _ = env.step(a) 50 | r = -1 if done and total < 500 else 0.002 51 | data.append([list(s),[int(a)],r,list(s_),done]) 52 | total += 1 53 | s=s_ 54 | if done: 55 | demo[e]=data 56 | s = env.reset() 57 | print(total) 58 | total = 0 59 | break 60 | with open("CartPoleDemo.txt","w") as file: 61 | file.write(json.dumps(demo)) 62 | 63 | env.close() 64 | -------------------------------------------------------------------------------- /Common/SumTree.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | 4 | # the following code is from 5 | # https://github.com/rlcode/per 6 | 7 | class SumTree: 8 | write = 0 9 | 10 | def __init__(self, capacity): 11 | self.capacity = capacity 12 | self.tree = numpy.zeros(2 * capacity - 1) 13 | self.data = numpy.zeros(capacity, dtype=object) 14 | self.n_entries = 0 15 | self.start = 0 16 | 17 | # update to the root node 18 | def _propagate(self, idx, change): 19 | parent = (idx - 1) // 2 20 | 21 | self.tree[parent] += change 22 | 23 | if parent != 0: 24 | self._propagate(parent, change) 25 | 26 | # find sample on leaf node 27 | def _retrieve(self, idx, s): 28 | left = 2 * idx + 1 29 | right = left + 1 30 | 31 | if left >= len(self.tree): 32 | return idx 33 | 34 | if s <= self.tree[left]: 35 | return self._retrieve(left, s) 36 | else: 37 | return self._retrieve(right, s - self.tree[left]) 38 | 39 | def total(self): 40 | return self.tree[0] 41 | 42 | # store priority and sample 43 | def add(self, p, data): 44 | idx = self.write + self.capacity - 1 45 | 46 | self.data[self.write] = data 47 | self.update(idx, p) 48 | 49 | self.write += 1 50 | if self.write >= self.capacity: 51 | self.write = self.start 52 | 53 | if self.n_entries < self.capacity: 54 | self.n_entries += 1 55 | 56 | # update priority 57 | def update(self, idx, p): 58 | change = p - self.tree[idx] 59 | 60 | self.tree[idx] = p 61 | self._propagate(idx, change) 62 | 63 | # get priority and sample 64 | def get(self, s): 65 | idx = self._retrieve(0, s) 66 | dataIdx = idx - self.capacity + 1 67 | 68 | return (idx, self.tree[idx], self.data[dataIdx]) 69 | 70 | 71 | if __name__ == "__main__": 72 | print("SumTree test") 73 | tree=SumTree(500) 74 | tree.add(0.1,[1,1]) 75 | tree.add(0.3,[2,2]) 76 | print(tree.get(0.4*0.99)) -------------------------------------------------------------------------------- /Common/prioritized_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | if __package__: 4 | from .SumTree import SumTree 5 | else: 6 | from Common.SumTree import SumTree 7 | import torch.nn as nn 8 | import torch 9 | 10 | 11 | #MSE with importance sampling 12 | class WeightedMSE(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def forward(self, inputs, targets, weights): 17 | l = torch.tensor(0.0) 18 | errors = [] 19 | for input, target, weight in zip(inputs, targets, weights): 20 | error = input - target 21 | l += error ** 2 * weight 22 | 23 | return l / weights.shape[0] 24 | 25 | 26 | #the following code is from 27 | #https://github.com/rlcode/per 28 | class Memory: # stored as ( s, a, r, s_ ) in SumTree 29 | e = 0.01 30 | a = 0.6 31 | beta = 0.0 32 | beta_increment_per_sampling = 0.001 33 | 34 | def __init__(self, capacity, epoch=150): 35 | self.tree = SumTree(capacity) 36 | self.capacity = capacity 37 | 38 | def _get_priority(self, error): 39 | return (error + self.e) ** self.a 40 | 41 | 
def add(self, sample,error=None): 42 | if error is None: 43 | p = self.tree.tree[0] #max priority for new data 44 | if p == 0: 45 | p = 0.1 46 | else: 47 | p = self.tree.get(p*0.9)[1] 48 | else: 49 | p = self._get_priority(error) 50 | self.tree.add(p, sample) 51 | 52 | def sample(self, n): 53 | batch = [] 54 | idxs = [] 55 | segment = self.tree.total() / n 56 | priorities = [] 57 | 58 | self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) 59 | 60 | for i in range(n): 61 | a = segment * i 62 | b = segment * (i + 1) 63 | 64 | s = random.uniform(a, b) 65 | (idx, p, data) = self.tree.get(s) 66 | priorities.append(p) 67 | batch.append(data) 68 | idxs.append(idx) 69 | 70 | sampling_probabilities = priorities / self.tree.total() 71 | is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta) 72 | is_weight /= is_weight.max() 73 | 74 | return batch, idxs, is_weight 75 | 76 | def update(self, idx, error): 77 | p = self._get_priority(error) 78 | self.tree.update(idx, p) 79 | 80 | 81 | if __name__ == "__main__": 82 | print("memory test") 83 | m = Memory(10, 100) 84 | m.add((3, 4), 1) 85 | m.add((5, 6), 10) 86 | m.add((7, 8), 100) 87 | data, idx, _ = m.sample(2) 88 | print(m.sample(5)) 89 | m.update(11, 0) 90 | print(m.sample(5)) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRL-using-PyTorch 2 | PyTorch implementations of Deep Reinforcement Learning algorithms 3 | 4 | All of the following DQN variants are derived from the DQNwithNoisyNet folder and already include DDQN, prioritized replay, and a fixed target network. 5 | 6 | ## DQN with NoisyNet: 7 | https://github.com/LilTwo/DRL-using-PyTorch/tree/master/DQNwithNoisyNet 8 | reference: https://arxiv.org/pdf/1706.10295.pdf 9 | 10 | NoisyNets add randomness to the parameters of the network. 11 | With noisy layers present, the network is able to learn a domain-specific exploration strategy, 12 | rather than using epsilon-greedy exploration and increasing epsilon manually during training. 13 | From my experience, a NoisyNet usually needs a smaller learning rate than normal nets to work well, 14 | and is very sensitive to the parameters' initial values. 15 | In the MountainCar environment, there is some chance that the car never reaches the top in the first episode. 16 | I'm not sure whether this is because I wrote something wrong. 17 | 18 | In the original paper, the authors suggest that the summation of "sigma" can be viewed as the stochasticity of the layer. 19 | This has been implemented in the "randomness" method of the "NoisyLinear" class with one modification: each "sigma" is normalized by its "mu" before the summation. 20 | 21 | ## DQN from Demonstrations (DQfD) 22 | https://github.com/LilTwo/DRL-using-PyTorch/tree/master/DQNfromDemo 23 | reference: https://arxiv.org/pdf/1704.03732.pdf 24 | 25 | If there are expert demonstrations produced by a human or another well-trained agent, one may expect these data to speed up training by saving the time otherwise spent on random exploration in a large state/action space. 26 | DQfD provides a method to leverage demonstration data by pre-training the model solely on the demonstration data before it starts to interact with the environment. 27 | 28 | ## Hindsight Experience Replay (HER) 29 | Code will be uploaded soon. 30 | reference: https://papers.nips.cc/paper/7090-hindsight-experience-replay.pdf 31 | 32 | Since model-free RL algorithms like DQN know nothing about the environment, they usually need a lot of exploration to find out what is good or bad at the beginning, especially when dealing with sparse rewards. 33 | In the first few epochs of training, an agent is likely to receive no positive reward during an entire episode. HER makes good use of these trajectories by storing each trajectory in the replay buffer again, but with different goals that are achieved by some states in the trajectory. 34 | This guarantees that some transitions with positive reward are stored in the replay buffer after every episode finishes. 35 | The key for HER to work is that these goals should be correlated in a reasonable way, so that learning to behave well on one of them also helps the agent behave well on another. For this reason I have reservations about the authors' claim that using HER requires less domain knowledge than defining a shaped reward.
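The two sketches below are illustrative additions to this README, not code that ships with the repository. The first condenses the DQfD pre-training flow already used in `DQNfromDemo/Test/CartPole.py` (`Net2` is the network class defined in that script); the second shows one way the goal relabelling described in the HER section could look, where `achieved_goal` and `compute_reward` are hypothetical, environment-specific helpers that do not exist in this repo.

```python
# DQfD: pre-train on demonstration data before touching the environment
# (condensed from DQNfromDemo/Test/CartPole.py; Net2 is defined there).
import json
from DQNfromDemo import DQfD

dqn = DQfD.DeepQL(Net2, lr=0.002, gamma=1.0, actionFinder=None, N=5000, n_step=1)

start = 0
with open("CartPoleDemo.txt", "r") as file:
    data = json.load(file)
for k, v in data.items():
    for s, a, r, s_, done in v:
        start += 1
        dqn.storeDemoTransition(s, a, r, s_, done, int(k))
dqn.replay.tree.start = start  # the write pointer wraps here, so demo samples are never overwritten

for i in range(500):  # supervised + TD pre-training updates on demo data only
    dqn.update()
```

```python
# HER "final" goal-relabelling strategy (sketch only; achieved_goal and
# compute_reward are assumed, environment-specific helpers).
def store_episode_with_her(memory, episode, achieved_goal, compute_reward):
    """episode: list of (s, a, r, s_, done, goal) transitions."""
    # 1) store the trajectory with its original goal
    for (s, a, r, s_, done, goal) in episode:
        memory.add((s, a, r, s_, done, goal))
    # 2) store it again, pretending the goal was the one actually achieved
    #    at the end of the trajectory
    new_goal = achieved_goal(episode[-1][3])  # achieved goal of the last next-state
    for (s, a, r, s_, done, goal) in episode:
        r_new = compute_reward(achieved_goal(s_), new_goal)
        memory.add((s, a, r_new, s_, done, new_goal))
```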
-------------------------------------------------------------------------------- /DQNwithNoisyNet/NoisyLayer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def f(input): 9 | sign = torch.sign(input) 10 | return sign * (torch.sqrt(torch.abs(input))) 11 | 12 | 13 | class NoisyLinear(nn.Module): 14 | def __init__(self, in_features, out_features, bias=True,sig0=0.5): 15 | super().__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) 19 | self.weight_sig = Parameter(torch.Tensor(out_features, in_features)) 20 | if bias: 21 | self.bias_mu = Parameter(torch.Tensor(out_features)) 22 | self.bias_sig = Parameter(torch.Tensor(out_features)) 23 | else: 24 | self.register_parameter('bias', None) 25 | self.bias_mu = None 26 | self.reset_parameters(sig0) 27 | self.dist = torch.distributions.Normal(0, 1) 28 | self.weight = None 29 | self.bias = None 30 | self.sample() 31 | 32 | def reset_parameters(self,sig0): 33 | stdv = 1.
/ math.sqrt(self.weight_mu.size(1)) 34 | self.weight_mu.data.uniform_(-stdv, stdv) 35 | self.weight_sig.data = self.weight_sig.data.zero_() + sig0 / self.weight_mu.shape[1] 36 | 37 | if self.bias_mu is not None: 38 | self.bias_mu.data.uniform_(-stdv, stdv) 39 | self.bias_sig.data.zero_() 40 | self.bias_sig.data = self.bias_sig.data.zero_() + sig0 / self.weight_mu.shape[1] 41 | 42 | def sample(self): 43 | size_in = self.in_features 44 | size_out = self.out_features 45 | noise_in = f(self.dist.sample((1, size_in))) 46 | noise_out = f(self.dist.sample((1, size_out))) 47 | self.weight = self.weight_mu + self.weight_sig * torch.mm(noise_out.t(), noise_in) 48 | if self.bias_mu is not None: 49 | self.bias = (self.bias_mu + self.bias_sig * noise_out).squeeze() 50 | 51 | def forward(self, input): 52 | if self.bias_mu is not None: 53 | return F.linear(input, self.weight, self.bias) 54 | else: 55 | return F.linear(input, self.weight) 56 | 57 | def randomness(self): 58 | size_in = self.in_features 59 | size_out = self.out_features 60 | return torch.abs(self.bias_sig.data/self.bias_mu.data).numpy().sum()/size_out#+torch.abs(self.weight_sig.data/self.weight_mu.data).numpy().sum()/(size_in*size_out) 61 | 62 | def extra_repr(self): 63 | return 'in_features={}, out_features={}, bias={}'.format( 64 | self.in_features, self.out_features, self.bias is not None 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | a = torch.Tensor([[1, -2, 3]]) 70 | b = torch.Tensor([1, 2, 3]) 71 | n = NoisyLinear(3, 100) 72 | 73 | print(n.bias_sig.data.zero_()) 74 | print(n.weight_sig.data.zero_()) 75 | print(n.randomness()) 76 | -------------------------------------------------------------------------------- /DQNwithNoisyNet/DQN_NoisyNet.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch import optim 3 | import torch 4 | import math 5 | import sys 6 | from os import path 7 | parent_dir = path.dirname(path.dirname(path.abspath(__file__))) 8 | if parent_dir not in sys.path: 9 | sys.path.append(parent_dir) 10 | from Common.prioritized_memory import Memory, WeightedMSE 11 | from Common.ValueCaculator import ValueCalculator1 as VC1 12 | from Common.ValueCaculator import ValueCalculator2 as VC2 13 | 14 | 15 | class DeepQL: 16 | def __init__(self, Net, actionFinder=None,eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=100, N=500, L2=0): 17 | self.eps = eps 18 | self.gamma = gamma 19 | self.mbsize = mbsize 20 | self.C = C # for target replacement 21 | self.c = 0 22 | self.replay = Memory(capacity=N) 23 | self.loss = WeightedMSE() 24 | self.actionFinder = actionFinder 25 | self.vc =VC1(Net,actionFinder) if actionFinder else VC2(Net) 26 | self.opt = optim.Adam(self.vc.predictNet.parameters(), lr=lr, weight_decay=L2) 27 | self.noisy = hasattr(self.vc.predictNet, "sample") 28 | # (state:tensor => Action :List[List]) 29 | 30 | def act(self, state): 31 | state = torch.Tensor(state) 32 | A = self.vc.sortedA(state) 33 | if self.noisy: 34 | self.vc.predictNet.sample() 35 | return A[0] 36 | r = random.random() 37 | a = A[0] if self.eps > r else random.sample(A, 1)[0] 38 | return a 39 | 40 | def sample(self): 41 | return self.replay.sample(self.mbsize) 42 | 43 | def store(self, data): 44 | self.replay.add(data) 45 | 46 | def storeTransition(self, s, a, r, s_, done): 47 | s = torch.Tensor(s) 48 | s_ = torch.Tensor(s_) 49 | self.store((s, a, r, s_, done)) 50 | 51 | def calcTD(self, samples): 52 | if self.noisy: 53 | self.vc.predictNet.sample() # for choosing action 54 | alls, alla, 
allr, alls_, alldone, *_ = zip(*samples) 55 | maxA = [self.vc.sortedA(s_)[0] for s_ in alls_] 56 | if self.noisy: 57 | self.vc.predictNet.sample() # for prediction 58 | self.vc.targetNet.sample() # for target 59 | 60 | Qtarget = torch.Tensor(allr) 61 | Qtarget[torch.tensor(alldone) != 1] += self.gamma * self.vc.calcQ(self.vc.targetNet, alls_, maxA)[ 62 | torch.tensor(alldone) != 1] 63 | Qpredict = self.vc.calcQ(self.vc.predictNet, alls, alla) 64 | return Qpredict, Qtarget 65 | 66 | def update(self): 67 | self.opt.zero_grad() 68 | samples, idxs, IS = self.sample() 69 | Qpredict, Qtarget = self.calcTD(samples) 70 | 71 | for i in range(self.mbsize): 72 | error = math.fabs(float(Qpredict[i] - Qtarget[i])) 73 | self.replay.update(idxs[i], error) 74 | 75 | J = self.loss(Qpredict, Qtarget, IS) 76 | J.backward() 77 | self.opt.step() 78 | 79 | if self.c >= self.C: 80 | self.c = 0 81 | self.vc.updateTargetNet() 82 | else: 83 | self.c += 1 -------------------------------------------------------------------------------- /Common/ValueCaculator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Net(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.fc1_s = nn.Linear(4, 40) 9 | self.fc1_a = nn.Linear(1, 40) 10 | self.fc2 = nn.Linear(40, 1) 11 | 12 | def forward(self, s, a): 13 | x = self.fc1_s(s) + self.fc1_a(a) 14 | x = F.relu(x) 15 | x = self.fc2(x) 16 | return x 17 | 18 | 19 | class Net2(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.fc1 = nn.Linear(4, 40) 23 | self.fc2 = nn.Linear(40, 3) 24 | 25 | def forward(self, s): 26 | x = self.fc1(s) 27 | x = F.relu(x) 28 | x = self.fc2(x) 29 | return x 30 | 31 | # (s,a) => Q(s,a) 32 | class ValueCalculator1: 33 | def __init__(self, Net, actionFinder): 34 | self.predictNet = Net() 35 | self.targetNet = Net() 36 | self.actionFinder = actionFinder 37 | self.updateTargetNet() 38 | # (state => Action :List[List]) 39 | 40 | def calcQ(self, net, s, A): 41 | # 1.single state, one or multiple actions 42 | # 2.muliplte states, one action per state, s must be a list of tensors 43 | if isinstance(s, torch.Tensor) and s.dim() == 1: # situation 1 44 | A = torch.Tensor(A) 45 | if A.dim() == 1: 46 | return net(s, A)[0] 47 | return torch.Tensor([net(s, a) for a in A]) 48 | 49 | if not isinstance(s, torch.Tensor): # situation 2 50 | s = torch.stack(s) 51 | a = torch.Tensor(A) 52 | return net(s, a).squeeze() 53 | # [[10.2],[5.3]] => [10.2,5.3] 54 | 55 | def sortedA(self, state): 56 | # return sorted action 57 | net = self.predictNet 58 | net.eval() 59 | A = self.actionFinder(state) 60 | Q = self.calcQ(net, state, A) 61 | A = [a for q,a in sorted(zip(Q, A),reverse=True,key=lambda x:x[0])] 62 | net.train() 63 | return A 64 | 65 | def updateTargetNet(self): 66 | self.targetNet.load_state_dict(self.predictNet.state_dict()) 67 | self.targetNet.eval() 68 | 69 | 70 | # s => Q(s,a1), Q(s,a2)... 
71 | class ValueCalculator2: 72 | def __init__(self, Net): 73 | self.predictNet = Net() 74 | self.targetNet = Net() 75 | *_, last = self.predictNet.children() 76 | self.A = list(range(last.out_features)) 77 | self.updateTargetNet() 78 | 79 | def calcQ(self, net, s, A): 80 | # 1.single state, one or multiple actions 81 | # 2.muliplte states, one action per state, s must be a list of tensors 82 | if isinstance(s, torch.Tensor) and s.dim() == 1: # situation 1 83 | return torch.Tensor([net(s)[a] for a in A]).squeeze() 84 | 85 | if not isinstance(s, torch.Tensor): # situation 2 86 | s = torch.stack(s) 87 | Q = net(s) 88 | A = [a[0] for a in A] 89 | return Q[[i for i in range(len(A))], A] 90 | 91 | def sortedA(self, state): 92 | net = self.predictNet 93 | net.eval() 94 | Q = self.calcQ(net, state, self.A) 95 | A = [[a] for q,a in sorted(zip(Q,self.A),reverse=True)] 96 | net.train() 97 | return A 98 | 99 | def updateTargetNet(self): 100 | self.targetNet.load_state_dict(self.predictNet.state_dict()) 101 | self.targetNet.eval() 102 | 103 | if __name__ == "__main__": 104 | actionFinder = lambda x:[[0],[1],[2]] 105 | v1 = ValueCalculator1(Net,actionFinder) 106 | v2 = ValueCalculator2(Net2) 107 | s = torch.Tensor([1,2,3,4]) 108 | print(v1.calcQ(v1.predictNet,s,actionFinder(s))) 109 | print(v1.sortedA(s)) 110 | print(v2.calcQ(v2.predictNet,s,v2.A)) 111 | print(v2.sortedA(s)) 112 | -------------------------------------------------------------------------------- /DQNwithNoisyNet/Test/MountainCar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os import path 3 | parent_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__)))) 4 | if parent_dir not in sys.path: 5 | sys.path.append(parent_dir) 6 | 7 | import gym 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import math 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from DQNwithNoisyNet.NoisyLayer import NoisyLinear 14 | from DQNwithNoisyNet import DQN_NoisyNet 15 | 16 | class NoisyNet(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.fc1_s = NoisyLinear(2, 40) 20 | self.fc1_a = NoisyLinear(1, 40) 21 | self.fc2 = NoisyLinear(40, 1) 22 | 23 | def forward(self, s, a): 24 | x = self.fc1_s(s) + self.fc1_a(a) 25 | x = F.relu(x) 26 | x = self.fc2(x) 27 | return x 28 | 29 | def sample(self): 30 | for layer in self.children(): 31 | if hasattr(layer, "sample"): 32 | layer.sample() 33 | 34 | 35 | class NoisyNet2(nn.Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.fc1 = NoisyLinear(2, 40) 39 | self.fc2 = NoisyLinear(40, 3) 40 | 41 | def forward(self, s): 42 | x = self.fc1(s) 43 | x = F.relu(x) 44 | x = self.fc2(x) 45 | return x 46 | 47 | def sample(self): 48 | for layer in self.children(): 49 | if hasattr(layer, "sample"): 50 | layer.sample() 51 | 52 | class Net(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.fc1_s = nn.Linear(2, 40) 56 | self.fc1_a = nn.Linear(1, 40) 57 | self.fc2 = nn.Linear(40, 1) 58 | 59 | def forward(self, s, a): 60 | x = self.fc1_s(s) + self.fc1_a(a) 61 | x = F.relu(x) 62 | x = self.fc2(x) 63 | return x 64 | 65 | 66 | class Net2(nn.Module): 67 | def __init__(self): 68 | super().__init__() 69 | self.fc1 = nn.Linear(2, 40) 70 | self.fc2 = nn.Linear(40, 3) 71 | 72 | def forward(self, s): 73 | x = self.fc1(s) 74 | x = F.relu(x) 75 | x = self.fc2(x) 76 | return x 77 | 78 | 79 | if __name__ == "__main__": 80 | env = gym.make('MountainCar-v0') 81 | env = env.unwrapped 82 | s = 
env.reset() 83 | s = torch.Tensor(s) 84 | A=[[0],[1],[2]] 85 | dqn = DQN_NoisyNet.DeepQL(NoisyNet, lr=0.001, gamma=0.9, N=10000, C=500, actionFinder=lambda x:A) 86 | #dqn = DQN_NoisyNet.DeepQLv2(NoisyNet2, lr=0.001, gamma=0.9, N=10000, C=500, actionFinder=lambda x:A) 87 | process = [] 88 | epoch = 80 89 | eps_start = 0.05 90 | eps_end = 0.95 91 | N = 1 - eps_start 92 | lam = -math.log((1 - eps_end) / N) / epoch 93 | total = 0 94 | dqn.replay.beta_increment_per_sampling = 0 95 | 96 | for i in range(epoch): 97 | dqn.eps = 1 - N * math.exp(-lam * i) 98 | dqn.replay.beta = dqn.replay.beta + 1.1/epoch if dqn.replay.beta < 1 else 1 99 | 100 | total = 0 101 | print(i, dqn.eps, dqn.replay.beta) 102 | while True: 103 | if total % 5000 == 0: 104 | print("trianing process total:", total) 105 | a = dqn.act(s) 106 | s_, r, done, _ = env.step(a[0]) 107 | total += 1 108 | r = 10 if done else -1 109 | dqn.storeTransition(s, a, r, s_, done) 110 | dqn.update() 111 | s=s_ 112 | 113 | if done: 114 | s = env.reset() 115 | print('finish total:', total) 116 | process.append(total) 117 | break 118 | 119 | plt.plot(process) 120 | plt.show() 121 | env.close() 122 | 123 | # torch.save(dqn.net.state_dict(),"./model.txt") 124 | # dqn.eps=1 125 | total = 0 126 | # dqn.net.load_state_dict(torch.load("./model.txt")) 127 | s = env.reset() 128 | s = torch.Tensor(s) 129 | while True: 130 | a = dqn.act(s)[0] 131 | s, r, done, _ = env.step(a) 132 | total += 1 133 | s = torch.Tensor(s) 134 | env.render() 135 | if done: 136 | s = env.reset() 137 | s = torch.Tensor(s) 138 | print(total) 139 | total = 0 140 | 141 | env.close() 142 | -------------------------------------------------------------------------------- /DQNwithNoisyNet/Test/CartPole.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import sys 3 | parent_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__)))) 4 | if parent_dir not in sys.path: 5 | sys.path.append(parent_dir) 6 | 7 | import gym 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import math 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from DQNwithNoisyNet.NoisyLayer import NoisyLinear 14 | from DQNwithNoisyNet import DQN_NoisyNet 15 | from operator import methodcaller 16 | 17 | 18 | class Net(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.fc1_s = nn.Linear(4, 40) 22 | self.fc1_a = nn.Linear(1, 40) 23 | self.fc2 = nn.Linear(40, 1) 24 | 25 | def forward(self, s, a): 26 | x = self.fc1_s(s) + self.fc1_a(a) 27 | x = F.relu(x) 28 | x = self.fc2(x) 29 | return x 30 | 31 | 32 | class Net2(nn.Module): 33 | def __init__(self): 34 | super().__init__() 35 | self.fc1 = nn.Linear(4, 40) 36 | self.fc2 = nn.Linear(40, 2) 37 | 38 | def forward(self, s): 39 | x = self.fc1(s) 40 | x = F.relu(x) 41 | x = self.fc2(x) 42 | return x 43 | 44 | 45 | class NoisyNet(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | self.fc1_s = NoisyLinear(4, 40,) 49 | self.fc1_a = NoisyLinear(1, 40,) 50 | self.fc2 = NoisyLinear(40, 1) 51 | 52 | def forward(self, s, a): 53 | x = self.fc1_s(s) + self.fc1_a(a) 54 | x = F.relu(x) 55 | x = self.fc2(x) 56 | return x 57 | 58 | def sample(self): 59 | for layer in self.children(): 60 | if hasattr(layer, "sample"): 61 | layer.sample() 62 | 63 | 64 | class NoisyNet2(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | self.fc1 = NoisyLinear(4, 40) 68 | self.fc2 = NoisyLinear(40, 2) 69 | 70 | def forward(self, s): 71 | x = self.fc1(s) 72 | x = 
F.relu(x) 73 | x = self.fc2(x) 74 | return x 75 | 76 | def sample(self): 77 | for layer in self.children(): 78 | if hasattr(layer, "sample"): 79 | layer.sample() 80 | 81 | 82 | if __name__ == "__main__": 83 | env = gym.make('CartPole-v1') 84 | s = env.reset() 85 | A = [[0], [1]] 86 | actionFinder = lambda x:A 87 | dqn = DQN_NoisyNet.DeepQL(NoisyNet,lr=0.002, gamma=1,actionFinder=actionFinder) 88 | #dqn = DQN_NoisyNet.DeepQLv2(NoisyNet2,lr=0.003, gamma=1) 89 | process = [] 90 | randomness = [] 91 | epoch = 200 92 | eps_start = 0.05 93 | eps_end = 0.95 94 | N = 1 - eps_start 95 | lam = -math.log((1 - eps_end) / N) / epoch 96 | total = 0 97 | count = 0 # successful count 98 | 99 | for i in range(epoch): 100 | print(i) 101 | if dqn.noisy: 102 | m = methodcaller("randomness") 103 | *_, rlast = map(m, dqn.vc.predictNet.children()) # only record the randomness of last layer 104 | randomness.append(rlast) 105 | dqn.eps = 1 - N * math.exp(-lam * i) 106 | count = count + 1 if total >= 500 else 0 107 | if count >= 2: 108 | dqn.eps = 1 109 | break 110 | total = 0 111 | while True: 112 | a = dqn.act(s) 113 | s_, r, done, _ = env.step(a[0]) 114 | total += r 115 | r = -1 if done and total < 500 else 0.002 116 | dqn.storeTransition(s, a, r, s_, done) 117 | dqn.update() 118 | s = s_ 119 | if done: 120 | s = env.reset() 121 | print('total:', total) 122 | process.append(total) 123 | break 124 | 125 | if dqn.noisy: 126 | plt.plot(randomness) 127 | plt.show() 128 | env.close() 129 | 130 | # torch.save(dqn.net.state_dict(),"./CartPoleExpert.txt") 131 | # dqn.eps=1 132 | total = 0 133 | # dqn.net.load_state_dict(torch.load("./CartPoleExpert.txt")) 134 | s = env.reset() 135 | s = torch.Tensor(s) 136 | while True: 137 | a = dqn.act(s)[0] 138 | s, r, done, _ = env.step(a) 139 | total += 1 140 | s = torch.Tensor(s) 141 | env.render() 142 | if done: 143 | s = env.reset() 144 | s = torch.Tensor(s) 145 | print(total) 146 | total = 0 147 | 148 | env.close() 149 | -------------------------------------------------------------------------------- /DQNfromDemo/Test/CartPole.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import sys 3 | 4 | local = path.abspath(__file__) 5 | root = path.dirname(path.dirname(path.dirname(local))) 6 | if root not in sys.path: 7 | sys.path.append(root) 8 | 9 | import gym 10 | import torch 11 | import matplotlib.pyplot as plt 12 | import math 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from DQNwithNoisyNet.NoisyLayer import NoisyLinear 16 | from DQNfromDemo import DQfD 17 | import json 18 | 19 | 20 | def plotJE(dqn,color): 21 | tree=dqn.replay.tree 22 | data = [[d] for d in tree.data[0:500]] 23 | JE = list(map(dqn.JE, data)) 24 | plt.plot(JE, color=color) 25 | 26 | class Net(nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.fc1_s = nn.Linear(4, 40) 30 | self.fc1_a = nn.Linear(1, 40) 31 | self.fc2 = nn.Linear(40, 1) 32 | 33 | def forward(self, s, a): 34 | x = self.fc1_s(s) + self.fc1_a(a) 35 | x = F.relu(x) 36 | x = self.fc2(x) 37 | return x 38 | 39 | 40 | class Net2(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | self.fc1 = nn.Linear(4, 40) 44 | self.fc2 = nn.Linear(40, 2) 45 | 46 | def forward(self, s): 47 | x = self.fc1(s) 48 | x = F.relu(x) 49 | x = self.fc2(x) 50 | return x 51 | 52 | 53 | class NoisyNet(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | self.fc1_s = NoisyLinear(4, 40) 57 | self.fc1_a = NoisyLinear(1, 40) 58 | self.fc2 = 
NoisyLinear(40, 1) 59 | 60 | def forward(self, s, a): 61 | x = self.fc1_s(s) + self.fc1_a(a) 62 | x = F.relu(x) 63 | x = self.fc2(x) 64 | return x 65 | 66 | def sample(self): 67 | for layer in self.children(): 68 | if hasattr(layer, "sample"): 69 | layer.sample() 70 | 71 | 72 | class NoisyNet2(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | self.fc1 = NoisyLinear(4, 40) 76 | self.fc2 = NoisyLinear(40, 2) 77 | 78 | def forward(self, s): 79 | x = self.fc1(s) 80 | x = F.relu(x) 81 | x = self.fc2(x) 82 | return x 83 | 84 | def sample(self): 85 | for layer in self.children(): 86 | if hasattr(layer, "sample"): 87 | layer.sample() 88 | 89 | 90 | if __name__ == "__main__": 91 | env = gym.make('CartPole-v1') 92 | s = env.reset() 93 | A = [[0], [1]] 94 | af = lambda x:A 95 | dqn = DQfD.DeepQL(Net2, lr=0.002, gamma=1.0, actionFinder=None, N=5000,n_step=1) 96 | process = [] 97 | randomness = [] 98 | epoch = 100 99 | eps_start = 0.05 100 | eps_end = 0.95 101 | N = 1 - eps_start 102 | lam = -math.log((1 - eps_end) / N) / epoch 103 | total = 0 104 | count = 0 # successful count 105 | start = 0 106 | with open("CartPoleDemo.txt", "r") as file: 107 | data = json.load(file) 108 | for k, v in data.items(): 109 | for s, a, r, s_, done in v: 110 | start += 1 111 | dqn.storeDemoTransition(s, a, r, s_, done, int(k)) 112 | 113 | dqn.replay.tree.start = start 114 | for i in range(500): 115 | if i % 100 == 0: 116 | print("pretraining:", i) 117 | dqn.update() 118 | 119 | for i in range(epoch): 120 | print(i) 121 | dqn.eps = 1 - N * math.exp(-lam * i) 122 | dqn.eps = 0.9 123 | count = count + 1 if total >= 500 else 0 124 | if count >= 2: 125 | dqn.eps = 1 126 | break 127 | total = 0 128 | while True: 129 | a = dqn.act(s) 130 | s_, r, done, _ = env.step(a[0]) 131 | total += r 132 | r = -1 if done and total < 500 else 0.002 133 | dqn.storeTransition(s, a, r, s_, done) 134 | dqn.update() 135 | s = s_ 136 | if done: 137 | s = env.reset() 138 | print('total:', total) 139 | process.append(total) 140 | break 141 | 142 | plt.show() 143 | total = 0 144 | s = env.reset() 145 | dqn.eps = 1 146 | while True: 147 | a = dqn.act(s)[0] 148 | s, r, done, _ = env.step(a) 149 | total += 1 150 | env.render() 151 | if done: 152 | s = env.reset() 153 | print(total) 154 | total = 0 155 | 156 | env.close() 157 | -------------------------------------------------------------------------------- /DQNfromDemo/DQfD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os import path 3 | 4 | local = path.abspath(__file__) 5 | root = path.dirname(path.dirname(local)) 6 | if root not in sys.path: 7 | sys.path.append(root) 8 | 9 | from Common.prioritized_memory import Memory, WeightedMSE 10 | from Common.ValueCaculator import ValueCalculator1 as VC1 11 | from Common.ValueCaculator import ValueCalculator2 as VC2 12 | import torch 13 | import math 14 | import random 15 | from torch import optim 16 | from collections import defaultdict as ddict 17 | from functools import reduce 18 | import numpy as np 19 | 20 | 21 | class DeepQL: 22 | def __init__(self, Net, actionFinder=None, eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=100, N=500, lambda1=1.0, 23 | lambda2=1.0, lambda3=1e-5, n_step=3): 24 | self.eps = eps # eps-greedy 25 | self.gamma = gamma # discount factor 26 | self.mbsize = mbsize # minibatch size 27 | self.C = C # frequenct of target replacement 28 | self.c = 0 # target replacement counter 29 | self.replay = Memory(capacity=N) 30 | self.loss = WeightedMSE() 31 | 
self.actionFinder = actionFinder # (state:tensor => Action :List[List]) 32 | self.vc = VC1(Net, actionFinder) if actionFinder else VC2(Net) 33 | self.opt = optim.Adam(self.vc.predictNet.parameters(), lr=lr, weight_decay=lambda3) 34 | self.noisy = hasattr(self.vc.predictNet, "sample") 35 | self.ed = 0.005 # bonus for demonstration 36 | self.ea = 0.001 37 | self.margin = 0.8 38 | self.n_step = n_step 39 | self.lambda1 = lambda1 # n-step return 40 | self.lambda2 = lambda2 # supervised loss 41 | self.lambda3 = lambda3 # L2 42 | self.replay.e = 0 43 | self.demoReplay = ddict(list) 44 | 45 | def act(self, state): 46 | state = torch.Tensor(state) 47 | A = self.vc.sortedA(state) 48 | if self.noisy: 49 | self.vc.predictNet.sample() 50 | return A[0] 51 | r = random.random() 52 | a = A[0] if self.eps > r else random.sample(A, 1)[0] 53 | return a 54 | 55 | def sample(self): 56 | return self.replay.sample(self.mbsize) 57 | 58 | def store(self, data): 59 | self.replay.add(data) 60 | 61 | def storeDemoTransition(self, s, a, r, s_, done, demoEpisode): 62 | s = torch.Tensor(s) 63 | s_ = torch.Tensor(s_) 64 | episodeReplay = self.demoReplay[demoEpisode] # replay of certain demo episode 65 | index = len(episodeReplay) 66 | data = (s, a, r, s_, done, (demoEpisode, index)) 67 | episodeReplay.append(data) 68 | self.store(data) 69 | 70 | def storeTransition(self, s, a, r, s_, done): 71 | s = torch.Tensor(s) 72 | s_ = torch.Tensor(s_) 73 | self.store((s, a, r, s_, done, None)) 74 | 75 | def calcTD(self, samples): 76 | if self.noisy: 77 | self.vc.predictNet.sample() # for choosing action 78 | alls, alla, allr, alls_, alldone, *_ = zip(*samples) 79 | maxA = [self.vc.sortedA(s_)[0] for s_ in alls_] 80 | if self.noisy: 81 | self.vc.predictNet.sample() # for prediction 82 | self.vc.targetNet.sample() # for target 83 | 84 | Qtarget = torch.Tensor(allr) 85 | Qtarget[torch.tensor(alldone) != 1] += self.gamma * self.vc.calcQ(self.vc.targetNet, alls_, maxA)[ 86 | torch.tensor(alldone) != 1] 87 | Qpredict = self.vc.calcQ(self.vc.predictNet, alls, alla) 88 | return Qpredict, Qtarget 89 | 90 | def JE(self, samples): 91 | loss = torch.tensor(0.0) 92 | count = 0 # number of demo 93 | for s, aE, *_, isdemo in samples: 94 | if isdemo is None: 95 | continue 96 | A = self.vc.sortedA(s) 97 | if len(A) == 1: 98 | continue 99 | QE = self.vc.calcQ(self.vc.predictNet, s, aE) 100 | A1, A2 = np.array(A)[:2] # action with largest and second largest Q 101 | maxA = A2 if (A1 == aE).all() else A1 102 | Q = self.vc.calcQ(self.vc.predictNet, s, maxA) 103 | if (Q + self.margin) < QE: 104 | continue 105 | else: 106 | loss += (Q - QE) 107 | count += 1 108 | return loss / count if count != 0 else loss 109 | 110 | def Jn(self, samples, Qpredict): 111 | # wait for refactoring, can't use with noisy layer 112 | loss = torch.tensor(0.0) 113 | count = 0 114 | for i,(s, a, r, s_, done, isdemo) in enumerate(samples): 115 | if isdemo is None: 116 | continue 117 | episode, idx = isdemo 118 | nidx = idx + self.n_step 119 | lepoch = len(self.demoReplay[episode]) 120 | if nidx > lepoch: 121 | continue 122 | count += 1 123 | ns, na, nr, ns_, ndone, _ = zip(*self.demoReplay[episode][idx:nidx]) 124 | ns, na, ns_, ndone = ns[-1], na[-1], ns_[-1], ndone[-1] 125 | discountedR = reduce(lambda x, y: (x[0] + self.gamma ** x[1] * y, x[1] + 1), nr, (0, 0))[0] 126 | maxA = self.vc.sortedA(ns_)[0] 127 | target = discountedR if ndone else discountedR + self.gamma ** self.n_step * self.vc.calcQ( 128 | self.vc.targetNet, ns_, 129 | maxA) 130 | predict = Qpredict[i] 131 | loss 
+= (target - predict) ** 2 132 | return loss / count 133 | 134 | def update(self): 135 | self.opt.zero_grad() 136 | samples, idxs, IS = self.sample() 137 | Qpredict, Qtarget = self.calcTD(samples) 138 | 139 | for i in range(self.mbsize): 140 | error = math.fabs(float(Qpredict[i] - Qtarget[i])) 141 | self.replay.update(idxs[i], error) 142 | 143 | Jtd = self.loss(Qpredict, Qtarget, IS*0+1) 144 | JE = self.JE(samples) 145 | Jn = self.Jn(samples,Qpredict) 146 | J = Jtd + self.lambda2 * JE + self.lambda1 * Jn 147 | J.backward() 148 | self.opt.step() 149 | 150 | if self.c >= self.C: 151 | self.c = 0 152 | self.vc.updateTargetNet() 153 | else: 154 | self.c += 1 155 | --------------------------------------------------------------------------------
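Note (added for clarity, not part of the repository): the `reduce` one-liner in `Jn` above computes the discounted n-step return over the next `n_step` demonstration rewards. An equivalent plain loop:

```python
# Equivalent to: reduce(lambda x, y: (x[0] + gamma ** x[1] * y, x[1] + 1), nr, (0, 0))[0]
def discounted_return(nr, gamma):
    total = 0.0
    for i, r in enumerate(nr):  # nr holds the next n rewards from the demo episode
        total += gamma ** i * r
    return total
```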