├── custom_gym
│   └── envs
│       ├── custom_env_dir
│       │   ├── __init__.py
│       │   ├── episode9.pth
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   └── custom_env.cpython-36.pyc
│       │   └── custom_env.py
│       ├── __pycache__
│       │   └── __init__.cpython-36.pyc
│       └── __init__.py
├── plots
├── updates.png
├── 0updates.png
├── 10updates.png
├── 11updates.png
├── 12updates.png
├── 13updates.png
├── 14updates.png
├── 15updates.png
├── 16updates.png
├── 17updates.png
├── 18updates.png
├── 19updates.png
├── 1updates.png
├── 20updates.png
├── 21updates.png
├── 22updates.png
├── 23updates.png
├── 24updates.png
├── 25updates.png
├── 26updates.png
├── 27updates.png
├── 28updates.png
├── 29updates.png
├── 2updates.png
├── 30updates.png
├── 31updates.png
├── 32updates.png
├── 33updates.png
├── 34updates.png
├── 35updates.png
├── 36updates.png
├── 37updates.png
├── 38updates.png
├── 39updates.png
├── 3updates.png
├── 40updates.png
├── 41updates.png
├── 42updates.png
├── 43updates.png
├── 4updates.png
├── 5updates.png
├── 6updates.png
├── 7updates.png
├── 8updates.png
└── 9updates.png
├── model_store
├── episode0.pth
├── episode1.pth
├── episode10.pth
├── episode11.pth
├── episode12.pth
├── episode13.pth
├── episode14.pth
├── episode15.pth
├── episode16.pth
├── episode17.pth
├── episode18.pth
├── episode19.pth
├── episode2.pth
├── episode20.pth
├── episode21.pth
├── episode22.pth
├── episode23.pth
├── episode24.pth
├── episode25.pth
├── episode26.pth
├── episode27.pth
├── episode28.pth
├── episode3.pth
├── episode4.pth
├── episode5.pth
├── episode6.pth
├── episode7.pth
├── episode8.pth
└── episode9.pth
├── __pycache__
├── DDPG.cpython-36.pyc
├── model.cpython-36.pyc
└── utils.cpython-36.pyc
├── README.md
├── model.py
├── main.py
├── utils.py
└── DDPG.py
/custom_gym/envs/custom_env_dir/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_env import CustomEnv
2 |
--------------------------------------------------------------------------------
/plots/updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/updates.png
--------------------------------------------------------------------------------
/plots/0updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/0updates.png
--------------------------------------------------------------------------------
/plots/10updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/10updates.png
--------------------------------------------------------------------------------
/plots/11updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/11updates.png
--------------------------------------------------------------------------------
/plots/12updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/12updates.png
--------------------------------------------------------------------------------
/plots/13updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/13updates.png
--------------------------------------------------------------------------------
/plots/14updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/14updates.png
--------------------------------------------------------------------------------
/plots/15updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/15updates.png
--------------------------------------------------------------------------------
/plots/16updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/16updates.png
--------------------------------------------------------------------------------
/plots/17updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/17updates.png
--------------------------------------------------------------------------------
/plots/18updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/18updates.png
--------------------------------------------------------------------------------
/plots/19updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/19updates.png
--------------------------------------------------------------------------------
/plots/1updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/1updates.png
--------------------------------------------------------------------------------
/plots/20updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/20updates.png
--------------------------------------------------------------------------------
/plots/21updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/21updates.png
--------------------------------------------------------------------------------
/plots/22updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/22updates.png
--------------------------------------------------------------------------------
/plots/23updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/23updates.png
--------------------------------------------------------------------------------
/plots/24updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/24updates.png
--------------------------------------------------------------------------------
/plots/25updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/25updates.png
--------------------------------------------------------------------------------
/plots/26updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/26updates.png
--------------------------------------------------------------------------------
/plots/27updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/27updates.png
--------------------------------------------------------------------------------
/plots/28updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/28updates.png
--------------------------------------------------------------------------------
/plots/29updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/29updates.png
--------------------------------------------------------------------------------
/plots/2updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/2updates.png
--------------------------------------------------------------------------------
/plots/30updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/30updates.png
--------------------------------------------------------------------------------
/plots/31updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/31updates.png
--------------------------------------------------------------------------------
/plots/32updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/32updates.png
--------------------------------------------------------------------------------
/plots/33updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/33updates.png
--------------------------------------------------------------------------------
/plots/34updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/34updates.png
--------------------------------------------------------------------------------
/plots/35updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/35updates.png
--------------------------------------------------------------------------------
/plots/36updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/36updates.png
--------------------------------------------------------------------------------
/plots/37updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/37updates.png
--------------------------------------------------------------------------------
/plots/38updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/38updates.png
--------------------------------------------------------------------------------
/plots/39updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/39updates.png
--------------------------------------------------------------------------------
/plots/3updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/3updates.png
--------------------------------------------------------------------------------
/plots/40updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/40updates.png
--------------------------------------------------------------------------------
/plots/41updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/41updates.png
--------------------------------------------------------------------------------
/plots/42updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/42updates.png
--------------------------------------------------------------------------------
/plots/43updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/43updates.png
--------------------------------------------------------------------------------
/plots/4updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/4updates.png
--------------------------------------------------------------------------------
/plots/5updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/5updates.png
--------------------------------------------------------------------------------
/plots/6updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/6updates.png
--------------------------------------------------------------------------------
/plots/7updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/7updates.png
--------------------------------------------------------------------------------
/plots/8updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/8updates.png
--------------------------------------------------------------------------------
/plots/9updates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/9updates.png
--------------------------------------------------------------------------------
/model_store/episode0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode0.pth
--------------------------------------------------------------------------------
/model_store/episode1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode1.pth
--------------------------------------------------------------------------------
/model_store/episode10.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode10.pth
--------------------------------------------------------------------------------
/model_store/episode11.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode11.pth
--------------------------------------------------------------------------------
/model_store/episode12.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode12.pth
--------------------------------------------------------------------------------
/model_store/episode13.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode13.pth
--------------------------------------------------------------------------------
/model_store/episode14.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode14.pth
--------------------------------------------------------------------------------
/model_store/episode15.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode15.pth
--------------------------------------------------------------------------------
/model_store/episode16.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode16.pth
--------------------------------------------------------------------------------
/model_store/episode17.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode17.pth
--------------------------------------------------------------------------------
/model_store/episode18.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode18.pth
--------------------------------------------------------------------------------
/model_store/episode19.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode19.pth
--------------------------------------------------------------------------------
/model_store/episode2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode2.pth
--------------------------------------------------------------------------------
/model_store/episode20.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode20.pth
--------------------------------------------------------------------------------
/model_store/episode21.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode21.pth
--------------------------------------------------------------------------------
/model_store/episode22.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode22.pth
--------------------------------------------------------------------------------
/model_store/episode23.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode23.pth
--------------------------------------------------------------------------------
/model_store/episode24.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode24.pth
--------------------------------------------------------------------------------
/model_store/episode25.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode25.pth
--------------------------------------------------------------------------------
/model_store/episode26.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode26.pth
--------------------------------------------------------------------------------
/model_store/episode27.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode27.pth
--------------------------------------------------------------------------------
/model_store/episode28.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode28.pth
--------------------------------------------------------------------------------
/model_store/episode3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode3.pth
--------------------------------------------------------------------------------
/model_store/episode4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode4.pth
--------------------------------------------------------------------------------
/model_store/episode5.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode5.pth
--------------------------------------------------------------------------------
/model_store/episode6.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode6.pth
--------------------------------------------------------------------------------
/model_store/episode7.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode7.pth
--------------------------------------------------------------------------------
/model_store/episode8.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode8.pth
--------------------------------------------------------------------------------
/model_store/episode9.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode9.pth
--------------------------------------------------------------------------------
/__pycache__/DDPG.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/DDPG.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/custom_gym/envs/custom_env_dir/episode9.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/episode9.pth
--------------------------------------------------------------------------------
/custom_gym/envs/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/custom_gym/envs/custom_env_dir/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/custom_gym/envs/custom_env_dir/__pycache__/custom_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/__pycache__/custom_env.cpython-36.pyc
--------------------------------------------------------------------------------
/custom_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
# Optional guard against re-registering the same id when this module is
# imported more than once in a process (currently disabled):
# env_name = 'CustomEnv-v0'
# if env_name in gym.envs.registry.env_specs:
#     del gym.envs.registry.env_specs[env_name]
# NOTE(review): enabling the guard would also require `import gym` here.

# Expose the environment as gym.make('CustomEnv-v0'); entry_point is a
# 'module:attribute' string naming CustomEnv inside the custom_env module.
register(
    id='CustomEnv-v0',
    entry_point='custom_env:CustomEnv',
)
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # drl
2 | Deep Deterministic Policy Gradients (DDPG) for learning the thermal control policy.
3 |
4 |
5 | Deep Deterministic Policy Gradient:
6 | https://spinningup.openai.com/en/latest/algorithms/ddpg.html
7 |
8 | Research Paper:
9 | https://arxiv.org/abs/1901.04693
10 |
11 | PyTorch based Library:
12 | https://github.com/vitchyr/rlkit
13 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.autograd
6 | from torch.autograd import Variable
7 | import pudb
8 |
class Critic(nn.Module):
    """Q-value network: scores a (state, action) pair.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear, with an
    unbounded (linear) output layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the three fully connected layers.

        Args:
            input_size: width of the concatenated state+action vector.
            hidden_size: width of both hidden layers.
            output_size: number of outputs (presumably 1 for a scalar Q-value).
        """
        super().__init__()
        # Attribute names are the keys of state_dict(); keep them stable so
        # previously saved checkpoints still load.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, state, action):
        """Return the critic's estimate for each (state, action) row.

        Args:
            state: tensor of states; concatenated with `action` along dim 1.
            action: tensor of actions with the same leading (batch) dimension.
        """
        joint = torch.cat((state, action), dim=1)
        hidden = F.relu(self.linear1(joint))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)
26 |
class Actor(nn.Module):
    """Deterministic policy network mapping a state to an action.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> tanh, so
    every action component lies in [-1, 1].
    """

    def __init__(self, input_size, hidden_size, output_size, learning_rate=3e-4):
        """Build the three fully connected layers.

        Args:
            input_size: width of the state vector.
            hidden_size: width of both hidden layers.
            output_size: number of action components.
            learning_rate: accepted for interface compatibility but not used
                by this class (any optimizer must be configured elsewhere).
        """
        super().__init__()
        # Attribute names are the keys of state_dict(); keep them stable so
        # previously saved checkpoints still load.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, state):
        """Return the tanh-squashed action for `state` (a torch tensor)."""
        out = F.relu(self.linear1(state))
        out = F.relu(self.linear2(out))
        return torch.tanh(self.linear3(out))
45 |
46 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import gym
3 | import envs
4 | import numpy as np
5 | import pandas as pd
6 | import matplotlib.pyplot as plt
7 | from DDPG import DDPGagent
8 | from utils import *
9 | import pudb
10 | from tqdm import tqdm
11 | import torch
12 |
13 | from custom_gym import *
14 |
import os

# Load the HVAC dataset and clean it: drop rows with any missing value, then
# drop every row whose fourth column equals that column's maximum
# (presumably sentinel/outlier readings — TODO confirm intent).
data = pd.read_csv('hvac_data.csv')
x = data.dropna(axis=0, how='any')
filter_col = x.columns[3]
x = x[x[filter_col] != x[filter_col].max()]

# Training configuration.
num_episodes = 500
# Fixed episode length: the env's `done` flag is recorded in the replay
# buffer but does not end an episode early.
steps_per_episode = 48
batch_size = 128

env = gym.make('CustomEnv-v0')
env.pass_df(x)
agent = DDPGagent(env)
noise = OUNoise(env.action_space)
rewards = []
avg_rewards = []

for episode in tqdm(range(num_episodes)):
    state = env.reset()
    noise.reset()
    episode_reward = 0

    for step in range(steps_per_episode):
        action = agent.get_action(state)
        action = noise.get_action(action, step)
        new_state, reward, done, _ = env.step(action)
        agent.memory.push(state, action, reward, new_state, done)

        # Start learning only once the buffer can supply a full batch.
        if len(agent.memory) > batch_size:
            agent.update(batch_size)

        state = new_state
        episode_reward += reward

    rewards.append(episode_reward)
    # Moving average over the last (up to) 10 episodes, including this one.
    avg_rewards.append(np.mean(rewards[-10:]))
    sys.stdout.write("episode: {}, reward: {}, average _reward: {} \n".format(
        episode, np.round(episode_reward, decimals=2), avg_rewards[-1]))

plt.plot(rewards, label='reward')
plt.plot(avg_rewards, label='10-episode average')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.legend()
plt.savefig("rewards.png")

# The output directory is not guaranteed to exist; create it so the final
# checkpoint save does not fail after a full training run.
os.makedirs("models", exist_ok=True)
torch.save(agent.get_model().state_dict(), os.path.join("models", str(episode) + ".pth"))
66 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | from collections import deque
4 | import random
5 |
# Ornstein-Uhlenbeck Process
# Credits: https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py
class OUNoise(object):
    """Temporally correlated Ornstein-Uhlenbeck exploration noise.

    Each call evolves the internal state by
    dx = theta * (mu - x) + sigma * N(0, 1) and adds it to the given action.
    sigma decays linearly from max_sigma to min_sigma over `decay_period`
    values of `t` (with the defaults max_sigma == min_sigma, so it is constant).
    """

    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000):
        """
        Args:
            action_space: object with `.shape`; if it also exposes `.low` /
                `.high` (e.g. a gym Box), noisy actions are clipped to them.
            mu: long-run mean of the process.
            theta: mean-reversion rate.
            max_sigma, min_sigma: start/end noise scale.
            decay_period: number of `t` steps over which sigma decays.
        """
        self.mu = mu
        self.theta = theta
        self.sigma = max_sigma
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.decay_period = decay_period
        self.action_dim = action_space.shape[0]
        # Clipping bounds. A plain array used as an action space (as in
        # CustomEnv) has no low/high, in which case no clipping is applied.
        self.low = getattr(action_space, 'low', None)
        self.high = getattr(action_space, 'high', None)
        self.reset()

    def reset(self):
        """Restart the process at its mean."""
        self.state = np.ones(self.action_dim) * self.mu

    def evolve_state(self):
        """Advance the OU process one step and return the new state."""
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
        self.state = x + dx
        return self.state

    def get_action(self, action, t=0):
        """Return `action` plus OU noise, clipped to the space bounds if known.

        Args:
            action: array of length action_dim.
            t: step counter driving the sigma decay schedule.
        """
        ou_state = self.evolve_state()
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)
        noisy = action + ou_state
        # np.clip without bounds raises on NumPy < 2.1, so only clip when at
        # least one bound is available.
        if self.low is None and self.high is None:
            return noisy
        return np.clip(noisy, self.low, self.high)
32 |
33 |
# Credits: https://github.com/openai/gym/blob/master/gym/core.py
class NormalizedEnv(gym.ActionWrapper):
    """Affinely rescale actions between the agent's range and the env's range.

    Forward map:  a = 0.5 * u + 1.5  (takes u in [-1, 1] to a in [1, 2]).
    Reverse map:  u = 2 * (a - 1.5).
    The constants are hard-coded; the wrapped env's action_space is not read.
    """

    def action(self, action):
        """Map an agent action into the environment's action range."""
        act_k = 1. / 2.
        act_b = 3 / 2.
        return act_k * action + act_b

    def reverse_action(self, action):
        """Map an environment action back into the agent's range."""
        act_k_inv = 2. / 1.
        act_b = 3 / 2.
        return act_k_inv * (action - act_b)

    # Legacy method names used by older gym releases; newer ActionWrapper
    # implementations call `action` / `reverse_action` instead.
    def _action(self, action):
        return self.action(action)

    def _reverse_action(self, action):
        return self.reverse_action(action)
47 |
48 |
class Memory:
    """Fixed-capacity FIFO replay buffer of (s, a, r, s', done) transitions."""

    def __init__(self, max_size):
        self.max_size = max_size
        # A deque with maxlen evicts the oldest transition once full.
        self.buffer = deque(maxlen=max_size)

    def push(self, state, action, reward, next_state, done):
        """Store one transition; the reward is wrapped in a length-1 array."""
        self.buffer.append((state, action, np.array([reward]), next_state, done))

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            Five lists (states, actions, rewards, next_states, dones), aligned
            by position.

        Raises:
            ValueError: if batch_size exceeds the number of stored transitions
                (raised by random.sample).
        """
        picked = random.sample(self.buffer, batch_size)
        states = [item[0] for item in picked]
        actions = [item[1] for item in picked]
        rewards = [item[2] for item in picked]
        next_states = [item[3] for item in picked]
        dones = [item[4] for item in picked]
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
79 |
--------------------------------------------------------------------------------
/custom_gym/envs/custom_env_dir/custom_env.py:
--------------------------------------------------------------------------------
1 | import random
2 | import gym
3 | from gym import spaces
4 | import numpy as np
5 | import pandas as pd
6 | from gym.utils import seeding
7 | import math
8 | import io
9 | import torch
10 | from copy import deepcopy
11 | import scipy.io
12 | import pudb
13 | import json
14 |
15 |
16 |
# Module-level environment constants. None of them is read by the CustomEnv
# class in this file (its reset() uses hard-coded ranges instead); presumably
# kept for reference or external use — TODO confirm before removing.
max_temp_in = 60
max_comfort, min_comfort = 6, 0
max_slots = 3
max_steps = 500
max_humid, min_humid = 100, 0
initial_temp = 22
25 |
class CustomEnv(gym.Env):
    """Toy HVAC thermal-comfort environment driven by a pandas DataFrame.

    Call pass_df() with the data before reset()/step(). Observations are a
    (4, 4) float32 array: one row per feature column (Air_temp,
    Relative_humidity, Outdoor_temp, Thermal_comfort) and one column per
    time step of a 4-row window starting at current_step.
    """

    # Columns read for observations, in row order of the observation array.
    _OBS_COLUMNS = ['Air_temp', 'Relative_humidity', 'Outdoor_temp', 'Thermal_comfort']
    # Number of consecutive data rows in one observation window.
    _WINDOW = 4

    def __init__(self):
        super(CustomEnv, self).__init__()
        self.seed()
        self.reward_range = (-10, 10)
        # NOTE(review): a plain array rather than a gym Space; the visible
        # callers (DDPGagent, OUNoise) read only .shape[0] (== 3). Kept as-is
        # for compatibility.
        self.action_space = np.array([0, 1, 2])
        self.observation_space = spaces.Box(shape=(4, 4), low=0, high=100, dtype=np.float32)

    def pass_df(self, df):
        """Attach the data source and print it.

        The index is reset to 0..n-1 so the integer step counter is a valid
        row position even when rows were dropped upstream (e.g. by dropna),
        which would otherwise leave gaps in the index.
        """
        self.df = df.reset_index(drop=True)
        print(self.df)

    def _next_observation(self):
        """Return the (4, 4) float32 feature window starting at current_step."""
        start = self.current_step
        window = self.df.iloc[start:start + self._WINDOW][self._OBS_COLUMNS]
        # Transpose so each row holds one feature over time.
        return np.ascontiguousarray(np.asarray(window.values, dtype=np.float32).T)

    def _take_action(self, action):
        """Update the internal temp/humidity/comfort state from `action`.

        Only action[0] is used: < 1 raises temp by 5, otherwise < 2 lowers it
        by 5; in both cases humidity is scaled by temp / 100. Any other value
        leaves temp and humidity unchanged. Comfort always decreases by
        temp * humid / 1000 * 6.
        """
        action_type = action[0]

        if action_type < 1:
            self.temp += 5
            self.humid *= (self.temp / 100)
        elif action_type < 2:
            self.temp -= 5
            self.humid *= (self.temp / 100)

        self.comfort -= (self.temp * self.humid) / 1000 * 6

    def step(self, action):
        """Apply `action` and advance one data row.

        Returns:
            (obs, reward, done, info): reward is +comfort when
            4 < comfort < 6 and -comfort otherwise; done is comfort > 4.5.
        """
        self._take_action(action)
        self.current_step += 1
        # Wrap to the start once a full observation window no longer fits.
        if self.current_step > len(self.df) - self._WINDOW:
            self.current_step = 0
        if 4 < self.comfort < 6:
            reward = self.comfort
        else:
            reward = -self.comfort
        done = self.comfort > 4.5
        obs = self._next_observation()
        return obs, reward, done, {}

    def reset(self):
        """Randomize the internal state and the start row; return the first observation."""
        self.temp = random.uniform(25, 30)
        self.humid = random.uniform(30, 40)
        self.comfort = random.uniform(2, 5)
        self.current_step = random.randint(0, len(self.df) - 6)
        return self._next_observation()

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def render(self, mode='human', close=False):
        """No-op: this environment has no visualisation."""
        pass

    # Legacy name used by older gym releases.
    _render = render
92 |
--------------------------------------------------------------------------------
/DDPG.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | from model import *
6 | from utils import *
7 |
class DDPGagent:
    """DDPG agent: actor/critic networks, target copies and a replay memory.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` size the networks.
        hidden_size: hidden width passed to ``Actor`` and ``Critic``.
        actor_learning_rate: Adam learning rate for the actor.
        critic_learning_rate: Adam learning rate for the critic.
        gamma: discount applied to the bootstrapped critic target.
        tau: interpolation factor for the soft target-network updates.
        max_memory_size: capacity passed to ``Memory``.
    """

    def __init__(self, env, hidden_size=256, actor_learning_rate=1e-4,
                 critic_learning_rate=1e-3, gamma=0.99, tau=1e-2,
                 max_memory_size=50000):
        # Params
        self.num_states = env.observation_space.shape[0]
        self.num_actions = env.action_space.shape[0]
        self.gamma = gamma
        self.tau = tau

        # Networks; the targets start as exact copies of the online networks.
        self.actor = Actor(self.num_states, hidden_size, self.num_actions)
        self.actor_target = Actor(self.num_states, hidden_size, self.num_actions)
        self.critic = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions)
        self.critic_target = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions)
        self._hard_update(self.actor_target, self.actor)
        self._hard_update(self.critic_target, self.critic)

        # Training
        self.memory = Memory(max_memory_size)
        self.critic_criterion = nn.MSELoss()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)

    @staticmethod
    def _hard_update(target, source):
        """Copy every parameter of ``source`` into ``target`` in place."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)

    @staticmethod
    def _soft_update(target, source, tau):
        """Blend ``target`` toward ``source``: target <- tau*source + (1-tau)*target."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data * tau + target_param.data * (1.0 - tau))

    def get_action(self, state):
        """Run the actor on one NumPy ``state`` and return ``output[0, 0]`` as NumPy.

        A batch dimension is added before the forward pass; no autograd graph
        is recorded.
        NOTE(review): whether ``[0, 0]`` yields a scalar or a vector depends on
        the Actor's output rank for this state shape — confirm against model.py.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state)
        return action.numpy()[0, 0]

    def update(self, batch_size):
        """Do one actor and one critic gradient step, then soft-update targets.

        The critic regresses toward
        ``r + gamma * (1 - done) * Q_target(s', actor_target(s'))``, so
        terminal transitions do not bootstrap from the next state.
        """
        states, actions, rewards, next_states, dones = self.memory.sample(batch_size)
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.tensor([float(d) for d in dones], dtype=torch.float32)
        # Assumes one scalar reward/done per transition (shape (batch,)) —
        # TODO confirm. Make them (batch, 1) so they broadcast across the
        # critic's output columns instead of against the batch axis.
        if rewards.dim() == 1:
            rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)

        # Critic loss; the bootstrapped target is a constant (no graph, and
        # no gradients accumulate in the target networks).
        Qvals = self.critic(states, actions)
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            next_Q = self.critic_target(next_states, next_actions)
            Qprime = rewards + self.gamma * (1.0 - dones) * next_Q
        critic_loss = self.critic_criterion(Qvals, Qprime)

        # Actor loss: maximize the critic's value of the actor's own actions.
        policy_loss = -self.critic(states, self.actor(states)).mean()

        # Update networks. policy_loss.backward() also writes gradients into
        # the critic; critic_optimizer.zero_grad() clears them before the
        # critic's own backward pass.
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update target networks
        self._soft_update(self.actor_target, self.actor, self.tau)
        self._soft_update(self.critic_target, self.critic, self.tau)

    def get_model(self):
        """Return the online actor network."""
        return self.actor
75 |
--------------------------------------------------------------------------------