├── custom_gym └── envs │ ├── custom_env_dir │ ├── __init__.py │ ├── episode9.pth │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── custom_env.cpython-36.pyc │ └── custom_env.py │ ├── __pycache__ │ └── __init__.cpython-36.pyc │ └── __init__.py ├── plots ├── updates.png ├── 0updates.png ├── 10updates.png ├── 11updates.png ├── 12updates.png ├── 13updates.png ├── 14updates.png ├── 15updates.png ├── 16updates.png ├── 17updates.png ├── 18updates.png ├── 19updates.png ├── 1updates.png ├── 20updates.png ├── 21updates.png ├── 22updates.png ├── 23updates.png ├── 24updates.png ├── 25updates.png ├── 26updates.png ├── 27updates.png ├── 28updates.png ├── 29updates.png ├── 2updates.png ├── 30updates.png ├── 31updates.png ├── 32updates.png ├── 33updates.png ├── 34updates.png ├── 35updates.png ├── 36updates.png ├── 37updates.png ├── 38updates.png ├── 39updates.png ├── 3updates.png ├── 40updates.png ├── 41updates.png ├── 42updates.png ├── 43updates.png ├── 4updates.png ├── 5updates.png ├── 6updates.png ├── 7updates.png ├── 8updates.png └── 9updates.png ├── model_store ├── episode0.pth ├── episode1.pth ├── episode10.pth ├── episode11.pth ├── episode12.pth ├── episode13.pth ├── episode14.pth ├── episode15.pth ├── episode16.pth ├── episode17.pth ├── episode18.pth ├── episode19.pth ├── episode2.pth ├── episode20.pth ├── episode21.pth ├── episode22.pth ├── episode23.pth ├── episode24.pth ├── episode25.pth ├── episode26.pth ├── episode27.pth ├── episode28.pth ├── episode3.pth ├── episode4.pth ├── episode5.pth ├── episode6.pth ├── episode7.pth ├── episode8.pth └── episode9.pth ├── __pycache__ ├── DDPG.cpython-36.pyc ├── model.cpython-36.pyc └── utils.cpython-36.pyc ├── README.md ├── model.py ├── main.py ├── utils.py └── DDPG.py /custom_gym/envs/custom_env_dir/__init__.py: -------------------------------------------------------------------------------- 1 | from custom_env import CustomEnv 2 | -------------------------------------------------------------------------------- /plots/updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/updates.png -------------------------------------------------------------------------------- /plots/0updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/0updates.png -------------------------------------------------------------------------------- /plots/10updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/10updates.png -------------------------------------------------------------------------------- /plots/11updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/11updates.png -------------------------------------------------------------------------------- /plots/12updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/12updates.png -------------------------------------------------------------------------------- /plots/13updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/13updates.png -------------------------------------------------------------------------------- /plots/14updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/14updates.png -------------------------------------------------------------------------------- /plots/15updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/15updates.png -------------------------------------------------------------------------------- /plots/16updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/16updates.png -------------------------------------------------------------------------------- /plots/17updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/17updates.png -------------------------------------------------------------------------------- /plots/18updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/18updates.png -------------------------------------------------------------------------------- /plots/19updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/19updates.png -------------------------------------------------------------------------------- /plots/1updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/1updates.png -------------------------------------------------------------------------------- /plots/20updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/20updates.png -------------------------------------------------------------------------------- /plots/21updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/21updates.png -------------------------------------------------------------------------------- /plots/22updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/22updates.png -------------------------------------------------------------------------------- /plots/23updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/23updates.png -------------------------------------------------------------------------------- /plots/24updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/24updates.png -------------------------------------------------------------------------------- /plots/25updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/25updates.png -------------------------------------------------------------------------------- /plots/26updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/26updates.png -------------------------------------------------------------------------------- /plots/27updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/27updates.png -------------------------------------------------------------------------------- /plots/28updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/28updates.png -------------------------------------------------------------------------------- /plots/29updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/29updates.png -------------------------------------------------------------------------------- /plots/2updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/2updates.png -------------------------------------------------------------------------------- /plots/30updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/30updates.png -------------------------------------------------------------------------------- /plots/31updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/31updates.png -------------------------------------------------------------------------------- /plots/32updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/32updates.png -------------------------------------------------------------------------------- /plots/33updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/33updates.png -------------------------------------------------------------------------------- /plots/34updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/34updates.png -------------------------------------------------------------------------------- /plots/35updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/35updates.png -------------------------------------------------------------------------------- /plots/36updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/36updates.png -------------------------------------------------------------------------------- /plots/37updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/37updates.png -------------------------------------------------------------------------------- /plots/38updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/38updates.png -------------------------------------------------------------------------------- /plots/39updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/39updates.png -------------------------------------------------------------------------------- /plots/3updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/3updates.png -------------------------------------------------------------------------------- /plots/40updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/40updates.png -------------------------------------------------------------------------------- /plots/41updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/41updates.png -------------------------------------------------------------------------------- /plots/42updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/42updates.png -------------------------------------------------------------------------------- /plots/43updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/43updates.png -------------------------------------------------------------------------------- /plots/4updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/4updates.png -------------------------------------------------------------------------------- /plots/5updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/5updates.png -------------------------------------------------------------------------------- /plots/6updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/6updates.png -------------------------------------------------------------------------------- /plots/7updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/7updates.png -------------------------------------------------------------------------------- /plots/8updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/8updates.png -------------------------------------------------------------------------------- /plots/9updates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/plots/9updates.png -------------------------------------------------------------------------------- /model_store/episode0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode0.pth -------------------------------------------------------------------------------- /model_store/episode1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode1.pth -------------------------------------------------------------------------------- /model_store/episode10.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode10.pth -------------------------------------------------------------------------------- /model_store/episode11.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode11.pth -------------------------------------------------------------------------------- /model_store/episode12.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode12.pth -------------------------------------------------------------------------------- /model_store/episode13.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode13.pth -------------------------------------------------------------------------------- /model_store/episode14.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode14.pth -------------------------------------------------------------------------------- /model_store/episode15.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode15.pth -------------------------------------------------------------------------------- /model_store/episode16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode16.pth -------------------------------------------------------------------------------- /model_store/episode17.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode17.pth -------------------------------------------------------------------------------- /model_store/episode18.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode18.pth -------------------------------------------------------------------------------- /model_store/episode19.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode19.pth -------------------------------------------------------------------------------- /model_store/episode2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode2.pth -------------------------------------------------------------------------------- /model_store/episode20.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode20.pth -------------------------------------------------------------------------------- /model_store/episode21.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode21.pth -------------------------------------------------------------------------------- /model_store/episode22.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode22.pth -------------------------------------------------------------------------------- /model_store/episode23.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode23.pth -------------------------------------------------------------------------------- /model_store/episode24.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode24.pth -------------------------------------------------------------------------------- /model_store/episode25.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode25.pth -------------------------------------------------------------------------------- /model_store/episode26.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode26.pth -------------------------------------------------------------------------------- /model_store/episode27.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode27.pth -------------------------------------------------------------------------------- /model_store/episode28.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode28.pth -------------------------------------------------------------------------------- /model_store/episode3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode3.pth -------------------------------------------------------------------------------- /model_store/episode4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode4.pth -------------------------------------------------------------------------------- /model_store/episode5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode5.pth -------------------------------------------------------------------------------- /model_store/episode6.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode6.pth -------------------------------------------------------------------------------- /model_store/episode7.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode7.pth -------------------------------------------------------------------------------- /model_store/episode8.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode8.pth -------------------------------------------------------------------------------- /model_store/episode9.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/model_store/episode9.pth -------------------------------------------------------------------------------- /__pycache__/DDPG.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/DDPG.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /custom_gym/envs/custom_env_dir/episode9.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/episode9.pth -------------------------------------------------------------------------------- /custom_gym/envs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /custom_gym/envs/custom_env_dir/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /custom_gym/envs/custom_env_dir/__pycache__/custom_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arunajit/drl/HEAD/custom_gym/envs/custom_env_dir/__pycache__/custom_env.cpython-36.pyc -------------------------------------------------------------------------------- /custom_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | #env_name = 'CustomEnv-v0' 3 | #if env_name in gym.envs.registry.env_specs: 4 | # del gym.envs.registry.env_specs[env_name] 5 | 6 | register(id='CustomEnv-v0', 7 | entry_point='custom_env:CustomEnv' 8 | ) 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # drl 2 | Deep Deterministic Policy Gradients (DDPG) for learning the thermal control policy.
3 | 4 | 5 | Deep Deterministic Policy Gradient:
6 | https://spinningup.openai.com/en/latest/algorithms/ddpg.html
7 | 8 | Research Paper:
9 | https://arxiv.org/abs/1901.04693
10 | 11 | PyTorch based Library:
12 | https://github.com/vitchyr/rlkit
13 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.autograd 6 | from torch.autograd import Variable 7 | import pudb 8 | 9 | class Critic(nn.Module): 10 | def __init__(self, input_size, hidden_size, output_size): 11 | super(Critic, self).__init__() 12 | self.linear1 = nn.Linear(input_size, hidden_size) 13 | self.linear2 = nn.Linear(hidden_size, hidden_size) 14 | self.linear3 = nn.Linear(hidden_size, output_size) 15 | 16 | def forward(self, state, action): 17 | """ 18 | Params state and actions are torch tensors 19 | """ 20 | x = torch.cat([state, action], 1) 21 | x = F.relu(self.linear1(x)) 22 | x = F.relu(self.linear2(x)) 23 | x = self.linear3(x) 24 | 25 | return x 26 | 27 | class Actor(nn.Module): 28 | def __init__(self, input_size, hidden_size, output_size, learning_rate = 3e-4): 29 | super(Actor, self).__init__() 30 | self.linear1 = nn.Linear(input_size, hidden_size) 31 | self.linear2 = nn.Linear(hidden_size, hidden_size) 32 | self.linear3 = nn.Linear(hidden_size, output_size) 33 | 34 | def forward(self, state): 35 | """ 36 | Param state is a torch tensor 37 | """ 38 | # pu.db 39 | # print(state) 40 | x = F.relu(self.linear1(state)) 41 | x = F.relu(self.linear2(x)) 42 | x = torch.tanh(self.linear3(x)) 43 | 44 | return x 45 | 46 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gym 3 | import envs 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | from DDPG import DDPGagent 8 | from utils import * 9 | import pudb 10 | from tqdm import tqdm 11 | import torch 12 | 13 | from custom_gym import * 14 | 15 | data=pd.read_csv('hvac_data.csv') 16 | 17 | x=data.dropna(axis=0,how='any') 18 | m=x.columns 19 | p=x[x[m[3]]==max(x[m[3]])].index.values 20 | len(p) 21 | max(x[m[3]]) 22 | x=x.drop(p) 23 | 24 | env = gym.make('CustomEnv-v0') 25 | env.pass_df(x) 26 | agent = DDPGagent(env) 27 | noise = OUNoise(env.action_space) 28 | batch_size = 128 29 | rewards = [] 30 | avg_rewards = [] 31 | 32 | for episode in tqdm(range(500)): 33 | 34 | state = env.reset() 35 | noise.reset() 36 | episode_reward = 0 37 | 38 | for step in range(48): 39 | action = agent.get_action(state) 40 | action = noise.get_action(action, step) 41 | new_state, reward, done, _ = env.step(action) 42 | agent.memory.push(state, action, reward, new_state, done) 43 | 44 | if len(agent.memory) > batch_size: 45 | agent.update(batch_size) 46 | 47 | state = new_state 48 | episode_reward += reward 49 | 50 | if step == 499: 51 | sys.stdout.write("episode: {}, reward: {}, average _reward: {} \n".format(episode, np.round(episode_reward, decimals=2), np.mean(rewards[-10:]))) 52 | break 53 | 54 | rewards.append(episode_reward) 55 | avg_rewards.append(np.mean(rewards[-10:])) 56 | # except: 57 | # pass 58 | 59 | plt.plot(rewards) 60 | plt.plot(avg_rewards) 61 | plt.plot() 62 | plt.xlabel('Episode') 63 | plt.ylabel('Reward') 64 | plt.savefig("rewards.png") 65 | torch.save(agent.get_model().state_dict(), "models/"+str(episode)+".pth") 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from collections import deque 4 | import random 5 | 6 | # Ornstein-Ulhenbeck Process 7 | # Credits: https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py 8 | class OUNoise(object): 9 | def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000): 10 | self.mu = mu 11 | self.theta = theta 12 | self.sigma = max_sigma 13 | self.max_sigma = max_sigma 14 | self.min_sigma = min_sigma 15 | self.decay_period = decay_period 16 | self.action_dim = action_space.shape[0] 17 | self.reset() 18 | 19 | def reset(self): 20 | self.state = np.ones(self.action_dim) * self.mu 21 | 22 | def evolve_state(self): 23 | x = self.state 24 | dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim) 25 | self.state = x + dx 26 | return self.state 27 | 28 | def get_action(self, action, t=0): 29 | ou_state = self.evolve_state() 30 | self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period) 31 | return np.clip(action + ou_state) 32 | 33 | 34 | # Credits: https://github.com/openai/gym/blob/master/gym/core.py 35 | class NormalizedEnv(gym.ActionWrapper): 36 | """ Wrap action """ 37 | 38 | def _action(self, action): 39 | act_k = 1./2. 40 | act_b = 3/2. 41 | return act_k * action + act_b 42 | 43 | def _reverse_action(self, action): 44 | act_k_inv = 2./1. 45 | act_b = 3/ 2. 46 | return act_k_inv * (action - act_b) 47 | 48 | 49 | class Memory: 50 | def __init__(self, max_size): 51 | self.max_size = max_size 52 | self.buffer = deque(maxlen=max_size) 53 | 54 | def push(self, state, action, reward, next_state, done): 55 | experience = (state, action, np.array([reward]), next_state, done) 56 | self.buffer.append(experience) 57 | 58 | def sample(self, batch_size): 59 | state_batch = [] 60 | action_batch = [] 61 | reward_batch = [] 62 | next_state_batch = [] 63 | done_batch = [] 64 | 65 | batch = random.sample(self.buffer, batch_size) 66 | 67 | for experience in batch: 68 | state, action, reward, next_state, done = experience 69 | state_batch.append(state) 70 | action_batch.append(action) 71 | reward_batch.append(reward) 72 | next_state_batch.append(next_state) 73 | done_batch.append(done) 74 | 75 | return state_batch, action_batch, reward_batch, next_state_batch, done_batch 76 | 77 | def __len__(self): 78 | return len(self.buffer) 79 | -------------------------------------------------------------------------------- /custom_gym/envs/custom_env_dir/custom_env.py: -------------------------------------------------------------------------------- 1 | import random 2 | import gym 3 | from gym import spaces 4 | import numpy as np 5 | import pandas as pd 6 | from gym.utils import seeding 7 | import math 8 | import io 9 | import torch 10 | from copy import deepcopy 11 | import scipy.io 12 | import pudb 13 | import json 14 | 15 | 16 | 17 | max_temp_in=60 18 | max_comfort=6 19 | min_comfort=0 20 | max_slots=3 21 | max_steps=500 22 | max_humid=100 23 | min_humid=0 24 | initial_temp=22 25 | 26 | class CustomEnv(gym.Env): 27 | def __init__(self): 28 | self.seed() 29 | super(CustomEnv,self).__init__() 30 | self.reward_range=(-10,10) 31 | self.action_space= np.array([ 0, 1, 2]) 32 | self.observation_space=spaces.Box(shape=(4,4),low=0,high=100,dtype=np.float32) 33 | 34 | def pass_df(self,df): 35 | self.df=df 36 | print(self.df) 37 | 38 | def _next_observation(self): 39 | frame = np.array([np.array(self.df.loc[self.current_step: self.current_step + 3,'Air_temp'].values,dtype=np.float32), 40 | np.array(self.df.loc[self.current_step: self.current_step + 3,'Relative_humidity'].values,dtype=np.float32), 41 | np.array(self.df.loc[self.current_step: self.current_step + 3,'Outdoor_temp'].values,dtype=np.float32), 42 | np.array(self.df.loc[self.current_step: self.current_step + 3,'Thermal_comfort'].values,dtype=np.float32)],dtype=np.float32) 43 | #obs=np.append(frame,[np.ndarray([int(self.temp)]),np.ndarray([self.humid]),np.ndarray([self.comfort]),np.ndarray([self.out_temp])],axis=0) 44 | obs=frame 45 | return obs 46 | 47 | def _take_action(self,action): 48 | current_temp=random.uniform(self.df.loc[self.current_step,'Air_temp'],self.df.loc[self.current_step,'Air_temp']+5) 49 | action_type=action[0] 50 | temperature=action[1] 51 | 52 | if action_type <1: 53 | self.temp +=5 54 | self.humid *= (self.temp/100) 55 | 56 | elif action_type <2: 57 | self.temp -=5 58 | self.humid *= (self.temp/100) 59 | 60 | self.comfort -=(self.temp*self.humid)/1000*6 61 | #self.out_temp=self.out_temp 62 | 63 | 64 | def step(self,action): 65 | self._take_action(action) 66 | self.current_step+=1 67 | if self.current_step > len(self.df.loc[:, 'Air_temp'].values) - 4: 68 | self.current_step = 0 69 | if self.comfort >4 and self.comfort <6: 70 | reward= self.comfort 71 | else: 72 | reward=-self.comfort 73 | done=self.comfort>4.5 74 | obs=self._next_observation() 75 | return obs, reward, done, {} 76 | 77 | 78 | def reset(self): 79 | self.temp=random.uniform(25,30) 80 | self.humid=random.uniform(30,40) 81 | self.comfort=random.uniform(2,5) 82 | #self.out_temp=random.uniform(28,35) 83 | self.current_step = random.randint(0, len(self.df.loc[:, 'Air_temp'].values) - 6) 84 | return self._next_observation() 85 | 86 | def seed(self, seed=None): 87 | self.np_random, seed = seeding.np_random(seed) 88 | return [seed] 89 | 90 | def _render(self, mode='human', close=False): 91 | pass 92 | -------------------------------------------------------------------------------- /DDPG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | from model import * 6 | from utils import * 7 | 8 | class DDPGagent: 9 | def __init__(self, env, hidden_size=256, actor_learning_rate=1e-4, critic_learning_rate=1e-3, gamma=0.99, tau=1e-2, max_memory_size=50000): 10 | # Params 11 | self.num_states = env.observation_space.shape[0] 12 | self.num_actions = env.action_space.shape[0] 13 | self.gamma = gamma 14 | self.tau = tau 15 | 16 | # Networks 17 | self.actor = Actor(self.num_states, hidden_size, self.num_actions) 18 | self.actor_target = Actor(self.num_states, hidden_size, self.num_actions) 19 | self.critic = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions) 20 | self.critic_target = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions) 21 | 22 | for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()): 23 | target_param.data.copy_(param.data) 24 | 25 | for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()): 26 | target_param.data.copy_(param.data) 27 | 28 | # Training 29 | self.memory = Memory(max_memory_size) 30 | self.critic_criterion = nn.MSELoss() 31 | self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate) 32 | self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate) 33 | 34 | def get_action(self, state): 35 | state = Variable(torch.from_numpy(state).float().unsqueeze(0)) 36 | action = self.actor.forward(state) 37 | action = action.detach().numpy()[0,0] 38 | return action 39 | 40 | def update(self, batch_size): 41 | states, actions, rewards, next_states, _ = self.memory.sample(batch_size) 42 | states = torch.FloatTensor(states) 43 | actions = torch.FloatTensor(actions) 44 | rewards = torch.FloatTensor(rewards) 45 | next_states = torch.FloatTensor(next_states) 46 | 47 | # Critic loss 48 | Qvals = self.critic.forward(states, actions) 49 | next_actions = self.actor_target.forward(next_states) 50 | next_Q = self.critic_target.forward(next_states, next_actions.detach()) 51 | Qprime = rewards + self.gamma * next_Q 52 | critic_loss = self.critic_criterion(Qvals, Qprime) 53 | 54 | # Actor loss 55 | policy_loss = -self.critic.forward(states, self.actor.forward(states)).mean() 56 | 57 | # update networks 58 | self.actor_optimizer.zero_grad() 59 | policy_loss.backward() 60 | self.actor_optimizer.step() 61 | 62 | self.critic_optimizer.zero_grad() 63 | critic_loss.backward() 64 | self.critic_optimizer.step() 65 | 66 | # update target networks 67 | for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()): 68 | target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau)) 69 | 70 | for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()): 71 | target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau)) 72 | 73 | def get_model(self): 74 | return self.actor 75 | --------------------------------------------------------------------------------