├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── config_djss.py ├── config_djss_attention.py ├── config_djss_attention_paper.py ├── ddqn_agent_attention.py ├── ddqn_agent_attention_paper.py ├── ddqn_agent_attention_paper1.py ├── dqn_agent.py ├── dqn_agent_est.py ├── main.py ├── main_djss_attention.py ├── main_djss_attention_actionPercently.py ├── main_djss_attention_paper.py ├── main_djss_attention_paper1.py ├── main_djss_stable_baseline.py ├── main_est.py ├── model ├── ESTModel.py ├── FullyNetwork.py ├── NetworkModel.py ├── NetworkModel_attention.py ├── NetworkModel_attention_paper.py ├── NetworkModel_attention_paper1.py └── dqn-ft06-aciton-12.pth ├── per_agent.py ├── requirements.txt ├── simulation_env ├── action_map.py ├── env_for_job_shop_v7_attention.py ├── env_for_job_shop_v7_attention1.py ├── env_for_job_shop_v7_attention1_ft06.py ├── env_for_job_shop_v7_attention_test_dynamic_arvl_rate.py ├── env_for_job_shop_v7_attention_test_rule.py ├── env_jobshop_v0.py ├── env_jobshop_v1.py ├── env_jobshop_v1_est.py └── input_data │ └── job_info.xlsx ├── sweep.yaml ├── test_djss_attention.py ├── test_djss_attention_paper.py ├── testing_data_analysis.ipynb └── utils ├── GanttPlot.py ├── MemeryBuffer.py ├── PERMemory.py ├── SumTree.py ├── action_map.py └── dispatch_logic.py /.gitignore: -------------------------------------------------------------------------------- 1 | # file: gitignore 2 | 3 | #.gitignore 4 | .DS_Store 5 | 6 | __pycache__/ 7 | .ipynb_checkpoints/ 8 | test_log/ 9 | 1log_djss_attention/ 10 | log_djss_attention/ 11 | log/ 12 | wandb/ 13 | model/djss_attentioin/ 14 | 15 | *.csv 16 | *.xlsx 17 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Colin-ZL-Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Simulation-based Deep Reinforcement Learning for Job Shop Scheduling Problem 2 | 3 | ### DRL-SimPy-JobShop 4 | 5 | An integration of deep reinforcement learning and discrete-event simulation for job shop scheduling problem. 6 | Using Self Attention model for the DQN agent. 
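A minimal programmatic sketch of how the pieces fit together, mirroring the wiring in `main.py` (hyper-parameters come from `config.py`; the checkpoint path is an assumption and should point at an existing `.pth` file):

```python
# Quick-start sketch following main.py; constants are read from config.py.
import os
from types import SimpleNamespace

import config
from simulation_env.env_jobshop_v1 import Factory
from dqn_agent import DQN

# DQN only reads these fields from its `args` object (see dqn_agent.DQN.__init__).
args = SimpleNamespace(
    device=config.DEVICE, capacity=config.CAPACITY, batch_size=config.BATCH_SIZE,
    lr=config.LEARNING_R, gamma=config.GAMMA, freq=config.FREQ,
    target_freq=config.TARGET_FREQ,
)
file_path = os.path.join(os.getcwd(), 'simulation_env', 'input_data', config.FILE_NAME)

env = Factory(config.NUM_JOB, config.NUM_MACHINE, file_path,
              config.OPT_MAKESPAN, log=False)
agent = DQN(env.dim_observations, env.dim_actions, args)

agent.load(config.MODEL)                  # assumes this checkpoint exists on disk
state = env.reset()
action = agent.select_best_action(state)  # greedy dispatching decision for one step
```

For full training and evaluation (epsilon-greedy schedule, replay warm-up, TensorBoard logging), run `main.py` or one of the `main_djss_attention*.py` entry points.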
7 | 8 | Please refer for more details: 9 | [Simulation-based Deep Reinforcement Learning for Job Shop Scheduling Problem](https://hdl.handle.net/11296/tt5jk9) 10 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ## main ## 4 | FILE_NAME = 'job_info.xlsx' 5 | OPT_MAKESPAN = 55 6 | NUM_MACHINE = 6 7 | NUM_JOB = 6 8 | DIM_ACTION = 12 #10 9 | 10 | ## arguments ## 11 | DEVICE = 'cuda' 12 | MODEL = 'model/dqn-ft06-aciton-12-spt.pth' 13 | LOG_DIR = 'log/dqn' 14 | # train 15 | EPISODE = 40000 16 | CAPACITY = int(5000) 17 | WARMUP = CAPACITY / 2 18 | BATCH_SIZE = 64 #128 19 | LEARNING_R = .0005 20 | GAMMA = .99 21 | EPS_DECAY = .99982 22 | EPS_MAX = 1 23 | EPS_MIN = .1 24 | EPS_PERIOD = EPISODE / 4 25 | FREQ = 4 26 | TARGET_FREQ = 500 27 | RENDER_EPISODE = 900 28 | # test 29 | SEED = 2021111 30 | TEST_EPSILON = .001 -------------------------------------------------------------------------------- /config_djss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | # timestramp 5 | time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 6 | timestramp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 7 | 8 | ## main ## 9 | FILE_NAME = 'ft06.txt' 10 | OPT_MAKESPAN = 55 11 | NUM_MACHINE = 6 12 | NUM_JOB = 6 13 | DIM_ACTION = 12 #10 14 | DIM_ACTION = 10 15 | 16 | RLS_RULE = 'SPT' 17 | 18 | ## arguments ## 19 | DEVICE = 'cuda' 20 | MODEL = f'model/djss_attention/ddqn-4x500x6-normalize-{timestramp}.pth' 21 | # LOG_DIR = 'log/dqn' 22 | LOG_DIR = 'log_djss_attention' 23 | # WRITTER = f'log/DQN-3x6x6-ep99995-{time.time()}' 24 | # WRITTER = f'log_djss/DDQN-20x100x6-{time_start}' 25 | WRITTER = f'log_djss_attention/DDQN-4x500x6-{time_start}' 26 | # train 27 | EPISODE = 50000 28 | PRIORI_PERIOD = 100 # 0# DIM_ACTION * 10 29 | CAPACITY = int(10000) 30 | # CAPACITY = int(5000) 31 | WARMUP = CAPACITY #/ 2 32 | BATCH_SIZE = 8#128 33 | LEARNING_R = .00005 34 | GAMMA = .99 35 | EPS_DECAY = 0.999999525 #.99982 36 | # EPS_DECAY = .999826#82 37 | EPS_MIN = .15 #.2 #.1 38 | FREQ = 16 #1 #4 39 | # TARGET_FREQ = 1000 #300 #600 #500 40 | TARGET_FREQ = 500 #500 41 | RENDER_EPISODE = 900 42 | # test 43 | SEED = 999 #2021111 44 | SEED = 20211211 45 | TEST_EPSILON = .001 46 | -------------------------------------------------------------------------------- /config_djss_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | # timestramp 5 | time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 6 | timestramp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 7 | 8 | ## main ## 9 | FILE_NAME = 'ft06.txt' 10 | OPT_MAKESPAN = 55 11 | NUM_MACHINE = 6 12 | NUM_JOB = 6 13 | DIM_ACTION = 12 #10 14 | DIM_ACTION = 10 15 | 16 | RLS_RULE = 'SPT' 17 | 18 | ## arguments ## 19 | DEVICE = 'cuda' 20 | MODEL = f'model/djss_attention/ddqn-attention-4m300-h4-stateNormalized-{timestramp}.pth' 21 | # LOG_DIR = 'log/dqn' 22 | LOG_DIR = 'log_djss_attention' 23 | # WRITTER = f'log/DQN-3x6x6-ep99995-{time.time()}' 24 | # WRITTER = f'log_djss/DDQN-20x100x6-{time_start}' 25 | WRITTER = f'log_djss_attention/DDQN-Attention-4m-200n-300-stateNormalized-load95-{time_start}' 26 | # train 27 | EPISODE = 50000 28 | PRIORI_PERIOD = 0#50 #100 # 0# DIM_ACTION * 10 29 | CAPACITY = int(10000) 30 | # CAPACITY = int(5000) 31 | WARMUP = CAPACITY #/ 2 32 | BATCH_SIZE 
= 16#128 33 | LEARNING_R = .00005 34 | GAMMA = .99 35 | EPS_DECAY = 0.99999525 #.99982 36 | # EPS_DECAY = .999826#82 37 | EPS_MIN = .15 #.2 #.1 38 | FREQ = 16 #1 #4 39 | # TARGET_FREQ = 1000 #300 #600 #500 40 | TARGET_FREQ = 500 #500 41 | RENDER_EPISODE = 900 42 | # test 43 | SEED = 999 #2021111 44 | # SEED = 20211228 45 | SEED = 19951027 46 | TEST_EPSILON = .001 47 | -------------------------------------------------------------------------------- /config_djss_attention_paper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | # timestramp 5 | time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 6 | timestramp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 7 | 8 | ## main ## 9 | FILE_NAME = 'ft06.txt' 10 | OPT_MAKESPAN = 55 11 | NUM_MACHINE = 6 12 | NUM_JOB = 6 13 | DIM_ACTION = 12 #10 14 | DIM_ACTION = 10 15 | 16 | RLS_RULE = 'SPT' 17 | 18 | ## arguments ## 19 | DEVICE = 'cuda' 20 | MODEL = f'model/djss_attention/paper/ddqn-attention-4m300-h4-stateNormalized-{timestramp}.pth' 21 | # LOG_DIR = 'log/dqn' 22 | LOG_DIR = 'log_djss_attention' 23 | # WRITTER = f'log/DQN-3x6x6-ep99995-{time.time()}' 24 | # WRITTER = f'log_djss/DDQN-20x100x6-{time_start}' 25 | WRITTER = f'log_djss_attention/paper-DDQN-Attention-4m-200n-300-stateNormalized-load95-{time_start}' 26 | # train 27 | EPISODE = 50000 28 | PRIORI_PERIOD = 0#50 #100 # 0# DIM_ACTION * 10 29 | CAPACITY = int(10000) 30 | # CAPACITY = int(5000) 31 | WARMUP = CAPACITY #/ 2 32 | BATCH_SIZE = 16#128 33 | LEARNING_R = .00005 34 | GAMMA = .99 35 | EPS_DECAY = 0.99999525 #.99982 36 | # EPS_DECAY = .999826#82 37 | EPS_MIN = .15 #.2 #.1 38 | FREQ = 16 #1 #4 39 | # TARGET_FREQ = 1000 #300 #600 #500 40 | TARGET_FREQ = 500 #500 41 | RENDER_EPISODE = 900 42 | # test 43 | SEED = 999 #2021111 44 | # SEED = 20211228 45 | SEED = 19951006#19951027 46 | TEST_EPSILON = .001 47 | -------------------------------------------------------------------------------- /ddqn_agent_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | 22 | from model.NetworkModel_attention import MultiHeadRelationalModule as Net 23 | # from model.FullyNetwork import Net 24 | from utils.MemeryBuffer import ReplayMemory 25 | 26 | import pdb 27 | import config_djss_attention as config 28 | seed = config.SEED 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | 37 | class DDQN: 38 | def __init__(self, dim_state, dim_action, args): 39 | ## config ## 40 | self.device = torch.device(args.device) 41 | self.batch_size = args.batch_size 42 | self.gamma = args.gamma 43 | self.freq = args.freq 44 | self.target_freq = args.target_freq 45 | 46 | self._behavior_net = Net(dim_state, dim_action, args.device).to(self.device) 47 | self._target_net = Net(dim_state, dim_action, args.device).to(self.device) 48 | # 
------------------------------------------- 49 | # initialize target network 50 | # ------------------------------------------- 51 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 52 | 53 | # self._optimizer = torch.optim.RMSprop( 54 | self._optimizer = torch.optim.Adam( 55 | self._behavior_net.parameters(), 56 | lr=args.lr 57 | ) 58 | self._criteria = nn.MSELoss() 59 | # self._criteria = nn.SmoothL1Loss() 60 | # memory 61 | self._memory = ReplayMemory(capacity=args.capacity) 62 | 63 | def select_best_action(self, state): 64 | ''' 65 | - state: (state_dim, ) 66 | ''' 67 | during_train = self._behavior_net.training 68 | if during_train: 69 | self.eval() 70 | state = torch.Tensor(state).to(self.device) 71 | state = DDQN.reshape_input_state(state) 72 | with torch.no_grad(): 73 | qvars = self._behavior_net(state) # (1, act_dim) 74 | action = torch.argmax(qvars, dim=-1) # (1, ) 75 | 76 | if during_train: 77 | self.train() 78 | 79 | return action.item() 80 | 81 | def select_action(self, state, epsilon, action_space): 82 | ''' 83 | epsilon-greedy based on behavior network 84 | 85 | -state = (state_dim, ) 86 | ''' 87 | if random.random() < epsilon: 88 | return action_space.sample() 89 | else: 90 | return self.select_best_action(state) 91 | 92 | def append(self, state, action, reward, next_state, done): 93 | self._memory.append( 94 | state, 95 | [action], 96 | [reward],# / 10], 97 | next_state, 98 | [1 - int(done)] 99 | ) 100 | 101 | def update(self, total_steps): 102 | if total_steps % self.freq == 0: 103 | return self._update_behavior_network(self.gamma) 104 | if total_steps % self.target_freq == 0: 105 | return self._update_target_network() 106 | 107 | def _update_behavior_network(self, gamma): 108 | # sample a minibatch of transitions 109 | ret = self._memory.sample(self.batch_size, self.device) 110 | state, action, reward, next_state, tocont = ret 111 | 112 | q_values = self._behavior_net(state) # (N, act_dim) 113 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 114 | with torch.no_grad(): 115 | ## Where DDQN is different from DQN 116 | qs_next_behavior_net = self._behavior_net(next_state) # (N, act_dim) 117 | indx_behavior_actions = torch.argmax(qs_next_behavior_net, dim=1).unsqueeze(dim=1) 118 | qs_next = self._target_net(next_state) # (N, act_dim) 119 | 120 | # compute V*(next_states) using predicted next q-values 121 | q_next = qs_next.gather(dim=1, index=indx_behavior_actions) # (N, 1) 122 | q_target = gamma*tocont*q_next.detach() + reward.detach() 123 | 124 | loss = self._criteria(q_value, q_target.detach()) 125 | 126 | # optimize 127 | self._optimizer.zero_grad() 128 | loss.backward() 129 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 100)#5)#10) 130 | self._optimizer.step() 131 | 132 | return loss.item() 133 | 134 | def _update_target_network(self): 135 | ''' 136 | update target network by copying from behavior network 137 | ''' 138 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 139 | return None 140 | 141 | def save(self, model_path, checkpoint=False): 142 | if checkpoint: 143 | torch.save( 144 | { 145 | 'behavior_net': self._behavior_net.state_dict(), 146 | 'target_net': self._target_net.state_dict(), 147 | 'optimizer': self._optimizer.state_dict(), 148 | }, model_path) 149 | else: 150 | torch.save({ 151 | 'behavior_net': self._behavior_net.state_dict(), 152 | }, model_path) 153 | 154 | def load(self, model_path, checkpoint=False): 155 | model = torch.load(model_path, 
map_location=self.device) 156 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 157 | if checkpoint: 158 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 159 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 160 | 161 | def train(self): 162 | self._behavior_net.train() 163 | self._target_net.eval() 164 | 165 | def eval(self): 166 | self._behavior_net.eval() 167 | self._target_net.eval() 168 | 169 | @staticmethod 170 | def reshape_input_state(state): 171 | state_shape = len(state.shape) 172 | state = state.unsqueeze(0) 173 | 174 | return state 175 | -------------------------------------------------------------------------------- /ddqn_agent_attention_paper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | 22 | from model.NetworkModel_attention_paper import MultiHeadRelationalModule as Net 23 | # from model.FullyNetwork import Net 24 | from utils.MemeryBuffer import ReplayMemory 25 | 26 | import pdb 27 | import config_djss_attention as config 28 | seed = config.SEED 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | 37 | class DDQN: 38 | def __init__(self, dim_state, dim_action, args): 39 | ## config ## 40 | self.device = torch.device(args.device) 41 | self.batch_size = args.batch_size 42 | self.gamma = args.gamma 43 | self.freq = args.freq 44 | self.target_freq = args.target_freq 45 | 46 | self._behavior_net = Net(dim_state, dim_action, args.device).to(self.device) 47 | self._target_net = Net(dim_state, dim_action, args.device).to(self.device) 48 | # ------------------------------------------- 49 | # initialize target network 50 | # ------------------------------------------- 51 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 52 | 53 | # self._optimizer = torch.optim.RMSprop( 54 | self._optimizer = torch.optim.Adam( 55 | self._behavior_net.parameters(), 56 | lr=args.lr 57 | ) 58 | self._criteria = nn.MSELoss() 59 | # self._criteria = nn.SmoothL1Loss() 60 | # memory 61 | self._memory = ReplayMemory(capacity=args.capacity) 62 | 63 | def select_best_action(self, state): 64 | ''' 65 | - state: (state_dim, ) 66 | ''' 67 | during_train = self._behavior_net.training 68 | if during_train: 69 | self.eval() 70 | state = torch.Tensor(state).to(self.device) 71 | state = DDQN.reshape_input_state(state) 72 | with torch.no_grad(): 73 | qvars = self._behavior_net(state) # (1, act_dim) 74 | action = torch.argmax(qvars, dim=-1) # (1, ) 75 | 76 | if during_train: 77 | self.train() 78 | 79 | return action.item() 80 | 81 | def select_action(self, state, epsilon, action_space): 82 | ''' 83 | epsilon-greedy based on behavior network 84 | 85 | -state = (state_dim, ) 86 | ''' 87 | if random.random() < epsilon: 88 | return action_space.sample() 89 | else: 90 | return 
self.select_best_action(state) 91 | 92 | def append(self, state, action, reward, next_state, done): 93 | self._memory.append( 94 | state, 95 | [action], 96 | [reward],# / 10], 97 | next_state, 98 | [1 - int(done)] 99 | ) 100 | 101 | def update(self, total_steps): 102 | if total_steps % self.freq == 0: 103 | return self._update_behavior_network(self.gamma) 104 | if total_steps % self.target_freq == 0: 105 | return self._update_target_network() 106 | 107 | def _update_behavior_network(self, gamma): 108 | # sample a minibatch of transitions 109 | ret = self._memory.sample(self.batch_size, self.device) 110 | state, action, reward, next_state, tocont = ret 111 | 112 | q_values = self._behavior_net(state) # (N, act_dim) 113 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 114 | with torch.no_grad(): 115 | ## Where DDQN is different from DQN 116 | qs_next_behavior_net = self._behavior_net(next_state) # (N, act_dim) 117 | indx_behavior_actions = torch.argmax(qs_next_behavior_net, dim=1).unsqueeze(dim=1) 118 | qs_next = self._target_net(next_state) # (N, act_dim) 119 | 120 | # compute V*(next_states) using predicted next q-values 121 | q_next = qs_next.gather(dim=1, index=indx_behavior_actions) # (N, 1) 122 | q_target = gamma*tocont*q_next.detach() + reward.detach() 123 | 124 | loss = self._criteria(q_value, q_target.detach()) 125 | 126 | # optimize 127 | self._optimizer.zero_grad() 128 | loss.backward() 129 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 100)#5)#10) 130 | self._optimizer.step() 131 | 132 | return loss.item() 133 | 134 | def _update_target_network(self): 135 | ''' 136 | update target network by copying from behavior network 137 | ''' 138 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 139 | return None 140 | 141 | def save(self, model_path, checkpoint=False): 142 | if checkpoint: 143 | torch.save( 144 | { 145 | 'behavior_net': self._behavior_net.state_dict(), 146 | 'target_net': self._target_net.state_dict(), 147 | 'optimizer': self._optimizer.state_dict(), 148 | }, model_path) 149 | else: 150 | torch.save({ 151 | 'behavior_net': self._behavior_net.state_dict(), 152 | }, model_path) 153 | 154 | def load(self, model_path, checkpoint=False): 155 | model = torch.load(model_path, map_location=self.device) 156 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 157 | if checkpoint: 158 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 159 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 160 | 161 | def train(self): 162 | self._behavior_net.train() 163 | self._target_net.eval() 164 | 165 | def eval(self): 166 | self._behavior_net.eval() 167 | self._target_net.eval() 168 | 169 | @staticmethod 170 | def reshape_input_state(state): 171 | state_shape = len(state.shape) 172 | state = state.unsqueeze(0) 173 | 174 | return state 175 | -------------------------------------------------------------------------------- /ddqn_agent_attention_paper1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from 
tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | 22 | from model.NetworkModel_attention_paper1 import MultiHeadRelationalModule as Net 23 | # from model.FullyNetwork import Net 24 | from utils.MemeryBuffer import ReplayMemory 25 | 26 | import pdb 27 | import config_djss_attention as config 28 | seed = config.SEED 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | 37 | class DDQN: 38 | def __init__(self, dim_state, dim_action, args): 39 | ## config ## 40 | self.device = torch.device(args.device) 41 | self.batch_size = args.batch_size 42 | self.gamma = args.gamma 43 | self.freq = args.freq 44 | self.target_freq = args.target_freq 45 | 46 | self._behavior_net = Net(dim_state, dim_action, args.device).to(self.device) 47 | self._target_net = Net(dim_state, dim_action, args.device).to(self.device) 48 | # ------------------------------------------- 49 | # initialize target network 50 | # ------------------------------------------- 51 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 52 | 53 | # self._optimizer = torch.optim.RMSprop( 54 | self._optimizer = torch.optim.Adam( 55 | self._behavior_net.parameters(), 56 | lr=args.lr 57 | ) 58 | self._criteria = nn.MSELoss() 59 | # self._criteria = nn.SmoothL1Loss() 60 | # memory 61 | self._memory = ReplayMemory(capacity=args.capacity) 62 | 63 | def select_best_action(self, state): 64 | ''' 65 | - state: (state_dim, ) 66 | ''' 67 | during_train = self._behavior_net.training 68 | if during_train: 69 | self.eval() 70 | state = torch.Tensor(state).to(self.device) 71 | state = DDQN.reshape_input_state(state) 72 | with torch.no_grad(): 73 | qvars = self._behavior_net(state) # (1, act_dim) 74 | action = torch.argmax(qvars, dim=-1) # (1, ) 75 | 76 | if during_train: 77 | self.train() 78 | 79 | return action.item() 80 | 81 | def select_action(self, state, epsilon, action_space): 82 | ''' 83 | epsilon-greedy based on behavior network 84 | 85 | -state = (state_dim, ) 86 | ''' 87 | if random.random() < epsilon: 88 | return action_space.sample() 89 | else: 90 | return self.select_best_action(state) 91 | 92 | def append(self, state, action, reward, next_state, done): 93 | self._memory.append( 94 | state, 95 | [action], 96 | [reward],# / 10], 97 | next_state, 98 | [1 - int(done)] 99 | ) 100 | 101 | def update(self, total_steps): 102 | if total_steps % self.freq == 0: 103 | return self._update_behavior_network(self.gamma) 104 | if total_steps % self.target_freq == 0: 105 | return self._update_target_network() 106 | 107 | def _update_behavior_network(self, gamma): 108 | # sample a minibatch of transitions 109 | ret = self._memory.sample(self.batch_size, self.device) 110 | state, action, reward, next_state, tocont = ret 111 | 112 | q_values = self._behavior_net(state) # (N, act_dim) 113 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 114 | with torch.no_grad(): 115 | ## Where DDQN is different from DQN 116 | qs_next_behavior_net = self._behavior_net(next_state) # (N, act_dim) 117 | indx_behavior_actions = torch.argmax(qs_next_behavior_net, dim=1).unsqueeze(dim=1) 118 | qs_next = self._target_net(next_state) # (N, act_dim) 119 | 120 | # compute V*(next_states) using predicted next q-values 121 | q_next = qs_next.gather(dim=1, index=indx_behavior_actions) # (N, 1) 122 | q_target = 
gamma*tocont*q_next.detach() + reward.detach() 123 | 124 | loss = self._criteria(q_value, q_target.detach()) 125 | 126 | # optimize 127 | self._optimizer.zero_grad() 128 | loss.backward() 129 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 100)#5)#10) 130 | self._optimizer.step() 131 | 132 | return loss.item() 133 | 134 | def _update_target_network(self): 135 | ''' 136 | update target network by copying from behavior network 137 | ''' 138 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 139 | return None 140 | 141 | def save(self, model_path, checkpoint=False): 142 | if checkpoint: 143 | torch.save( 144 | { 145 | 'behavior_net': self._behavior_net.state_dict(), 146 | 'target_net': self._target_net.state_dict(), 147 | 'optimizer': self._optimizer.state_dict(), 148 | }, model_path) 149 | else: 150 | torch.save({ 151 | 'behavior_net': self._behavior_net.state_dict(), 152 | }, model_path) 153 | 154 | def load(self, model_path, checkpoint=False): 155 | model = torch.load(model_path, map_location=self.device) 156 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 157 | if checkpoint: 158 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 159 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 160 | 161 | def train(self): 162 | self._behavior_net.train() 163 | self._target_net.eval() 164 | 165 | def eval(self): 166 | self._behavior_net.eval() 167 | self._target_net.eval() 168 | 169 | @staticmethod 170 | def reshape_input_state(state): 171 | state_shape = len(state.shape) 172 | state = state.unsqueeze(0) 173 | 174 | return state 175 | -------------------------------------------------------------------------------- /dqn_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | 22 | from model.NetworkModel import Net 23 | # from model.FullyNetwork import Net 24 | from utils.MemeryBuffer import ReplayMemory 25 | 26 | import pdb 27 | 28 | logging.basicConfig(level=logging.DEBUG) 29 | # ----------------------------------------------- 30 | 31 | class DQN: 32 | def __init__(self, dim_state, dim_action, args): 33 | ## config ## 34 | self.device = torch.device(args.device) 35 | self.batch_size = args.batch_size 36 | self.gamma = args.gamma 37 | self.freq = args.freq 38 | self.target_freq = args.target_freq 39 | 40 | self._behavior_net = Net(dim_state, dim_action).to(self.device) 41 | self._target_net = Net(dim_state, dim_action).to(self.device) 42 | # ------------------------------------------- 43 | # initialize target network 44 | # ------------------------------------------- 45 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 46 | 47 | self._optimizer = torch.optim.RMSprop( 48 | self._behavior_net.parameters(), 49 | lr=args.lr 50 | ) 51 | self._criteria = nn.MSELoss() 52 | # memory 53 | self._memory = ReplayMemory(capacity=args.capacity) 54 | 55 | def select_best_action(self, state): 56 | ''' 57 | - state: 
(state_dim, ) 58 | ''' 59 | during_train = self._behavior_net.training 60 | if during_train: 61 | self.eval() 62 | state = torch.Tensor(state).to(self.device) 63 | state = DQN.reshape_input_state(state) 64 | with torch.no_grad(): 65 | qvars = self._behavior_net(state) # (1, act_dim) 66 | action = torch.argmax(qvars, dim=-1) # (1, ) 67 | 68 | if during_train: 69 | self.train() 70 | 71 | return action.item() 72 | 73 | def select_action(self, state, epsilon, action_space): 74 | ''' 75 | epsilon-greedy based on behavior network 76 | 77 | -state = (state_dim, ) 78 | ''' 79 | if random.random() < epsilon: 80 | return action_space.sample() 81 | else: 82 | return self.select_best_action(state) 83 | 84 | def append(self, state, action, reward, next_state, done): 85 | self._memory.append( 86 | state, 87 | [action], 88 | [reward],# / 10], 89 | next_state, 90 | [1 - int(done)] 91 | ) 92 | 93 | def update(self, total_steps): 94 | if total_steps % self.freq == 0: 95 | return self._update_behavior_network(self.gamma) 96 | if total_steps % self.target_freq == 0: 97 | return self._update_target_network() 98 | 99 | def _update_behavior_network(self, gamma): 100 | # sample a minibatch of transitions 101 | ret = self._memory.sample(self.batch_size, self.device) 102 | state, action, reward, next_state, tocont = ret 103 | 104 | q_values = self._behavior_net(state) # (N, act_dim) 105 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 106 | with torch.no_grad(): 107 | qs_next = self._target_net(next_state) # (N, act_dim) 108 | q_next, act = torch.max(qs_next, dim=-1, keepdim=True) # (N, 1) 109 | q_target = gamma*q_next*tocont + reward 110 | 111 | loss = self._criteria(q_value, q_target) 112 | 113 | # optimize 114 | self._optimizer.zero_grad() 115 | loss.backward() 116 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 5) 117 | self._optimizer.step() 118 | 119 | return loss.item() 120 | 121 | def _update_target_network(self): 122 | ''' 123 | update target network by copying from behavior network 124 | ''' 125 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 126 | return None 127 | 128 | def save(self, model_path, checkpoint=False): 129 | if checkpoint: 130 | torch.save( 131 | { 132 | 'behavior_net': self._behavior_net.state_dict(), 133 | 'target_net': self._target_net.state_dict(), 134 | 'optimizer': self._optimizer.state_dict(), 135 | }, model_path) 136 | else: 137 | torch.save({ 138 | 'behavior_net': self._behavior_net.state_dict(), 139 | }, model_path) 140 | 141 | def load(self, model_path, checkpoint=False): 142 | model = torch.load(model_path, map_location=self.device) 143 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 144 | if checkpoint: 145 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 146 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 147 | 148 | def train(self): 149 | self._behavior_net.train() 150 | self._target_net.eval() 151 | 152 | def eval(self): 153 | self._behavior_net.eval() 154 | self._target_net.eval() 155 | 156 | @staticmethod 157 | def reshape_input_state(state): 158 | state_shape = len(state.shape) 159 | state = state.unsqueeze(0) 160 | 161 | return state 162 | -------------------------------------------------------------------------------- /dqn_agent_est.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] 
DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | 22 | # from model.NetworkModel import Net 23 | from model.ESTModel import Net 24 | # from model.FullyNetwork import Net 25 | from utils.MemeryBuffer import ReplayMemory 26 | 27 | import pdb 28 | 29 | logging.basicConfig(level=logging.DEBUG) 30 | # ----------------------------------------------- 31 | 32 | class DQN: 33 | def __init__(self, dim_state, dim_action, args): 34 | ## config ## 35 | self.device = torch.device(args.device) 36 | self.batch_size = args.batch_size 37 | self.gamma = args.gamma 38 | self.freq = args.freq 39 | self.target_freq = args.target_freq 40 | 41 | self._behavior_net = Net(dim_state, dim_action).to(self.device) 42 | self._target_net = Net(dim_state, dim_action).to(self.device) 43 | # ------------------------------------------- 44 | # initialize target network 45 | # ------------------------------------------- 46 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 47 | 48 | self._optimizer = torch.optim.RMSprop( 49 | self._behavior_net.parameters(), 50 | lr=args.lr 51 | ) 52 | self._criteria = nn.MSELoss() 53 | # memory 54 | self._memory = ReplayMemory(capacity=args.capacity) 55 | 56 | def select_best_action(self, state): 57 | ''' 58 | - state: (state_dim, ) 59 | ''' 60 | during_train = self._behavior_net.training 61 | if during_train: 62 | self.eval() 63 | state = torch.Tensor(state).to(self.device) 64 | state = DQN.reshape_input_state(state) 65 | with torch.no_grad(): 66 | qvars = self._behavior_net(state) # (1, act_dim) 67 | action = torch.argmax(qvars, dim=-1) # (1, ) 68 | 69 | if during_train: 70 | self.train() 71 | 72 | return action.item() 73 | 74 | def select_action(self, state, epsilon, action_space): 75 | ''' 76 | epsilon-greedy based on behavior network 77 | 78 | -state = (state_dim, ) 79 | ''' 80 | if random.random() < epsilon: 81 | return action_space.sample() 82 | else: 83 | return self.select_best_action(state) 84 | 85 | def append(self, state, action, reward, next_state, done): 86 | self._memory.append( 87 | state, 88 | [action], 89 | [reward],# / 10], 90 | next_state, 91 | [1 - int(done)] 92 | ) 93 | 94 | def update(self, total_steps): 95 | if total_steps % self.freq == 0: 96 | return self._update_behavior_network(self.gamma) 97 | if total_steps % self.target_freq == 0: 98 | return self._update_target_network() 99 | 100 | def _update_behavior_network(self, gamma): 101 | # sample a minibatch of transitions 102 | ret = self._memory.sample(self.batch_size, self.device) 103 | state, action, reward, next_state, tocont = ret 104 | 105 | q_values = self._behavior_net(state) # (N, act_dim) 106 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 107 | with torch.no_grad(): 108 | qs_next = self._target_net(next_state) # (N, act_dim) 109 | q_next, act = torch.max(qs_next, dim=-1, keepdim=True) # (N, 1) 110 | q_target = gamma*q_next*tocont + reward 111 | 112 | loss = self._criteria(q_value, q_target) 113 | 114 | # optimize 115 | self._optimizer.zero_grad() 116 | loss.backward() 117 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 5) 118 | self._optimizer.step() 119 | 120 | return 
loss.item() 121 | 122 | def _update_target_network(self): 123 | ''' 124 | update target network by copying from behavior network 125 | ''' 126 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 127 | return None 128 | 129 | def save(self, model_path, checkpoint=False): 130 | if checkpoint: 131 | torch.save( 132 | { 133 | 'behavior_net': self._behavior_net.state_dict(), 134 | 'target_net': self._target_net.state_dict(), 135 | 'optimizer': self._optimizer.state_dict(), 136 | }, model_path) 137 | else: 138 | torch.save({ 139 | 'behavior_net': self._behavior_net.state_dict(), 140 | }, model_path) 141 | 142 | def load(self, model_path, checkpoint=False): 143 | model = torch.load(model_path, map_location=self.device) 144 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 145 | if checkpoint: 146 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 147 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 148 | 149 | def train(self): 150 | self._behavior_net.train() 151 | self._target_net.eval() 152 | 153 | def eval(self): 154 | self._behavior_net.eval() 155 | self._target_net.eval() 156 | 157 | @staticmethod 158 | def reshape_input_state(state): 159 | state_shape = len(state.shape) 160 | state = state.unsqueeze(0) 161 | 162 | return state 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import torch.nn as nn 21 | import matplotlib.pyplot as plt 22 | 23 | from tensorboardX import SummaryWriter 24 | from datetime import datetime as dt 25 | from tqdm import tqdm 26 | 27 | from simulation_env.env_jobshop_v1 import Factory 28 | from dqn_agent import DQN 29 | 30 | import pdb 31 | 32 | logging.basicConfig(level=logging.DEBUG) 33 | plt.set_loglevel('WARNING') 34 | # ----------------------------------------------- 35 | 36 | def train(args, _env, agent, writer): 37 | logging.info('* Start Training') 38 | 39 | env = _env 40 | action_space = env.action_space 41 | 42 | total_step, epsilon, ewma_reward = 0, 1., 0. 
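    # Bookkeeping used by the loop below:
    #  - total_step counts environment transitions across all episodes; learning
    #    only starts once it reaches args.warmup, after which agent.update() does
    #    a gradient step on steps divisible by args.freq and otherwise copies the
    #    behavior network into the target network on args.target_freq boundaries.
    #  - epsilon is the exploration rate for epsilon-greedy action selection; it
    #    starts at 1.0 and is adjusted once per episode after the inner loop.
    #  - ewma_reward is an exponentially weighted moving average of the episodic
    #    return, ewma = 0.05 * total_reward + 0.95 * ewma, used only for logging.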
43 | 44 | # Switch to train mode 45 | agent.train() 46 | 47 | # Training until episode-condition 48 | for episode in range(args.episode): 49 | total_reward = 0 50 | done = False 51 | state = env.reset() 52 | 53 | # While not terminate 54 | for t in itertools.count(start=1): 55 | # if args.render and episode > 700: 56 | # env.render(done) 57 | # time.sleep(0.0082) 58 | 59 | # select action 60 | if total_step < args.warmup: 61 | action = action_space.sample() 62 | else: 63 | action = agent.select_action(state, epsilon, action_space) 64 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 65 | 66 | # execute action 67 | next_state, reward, done, _ = env.step(action) 68 | 69 | # store transition 70 | agent.append(state, action, reward, next_state, done) 71 | 72 | # optimize the model 73 | loss = None 74 | if total_step >= args.warmup: 75 | loss = agent.update(total_step) 76 | 77 | # transit next_state --> current_state 78 | state = next_state 79 | total_reward += reward 80 | total_step += 1 81 | 82 | if args.render and episode > args.render_episode: 83 | env.render(done) 84 | # time.sleep(0.0082) 85 | 86 | # Break & Record the performance at the end each episode 87 | if done: 88 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 89 | writer.add_scalar('Train-Episode/Reward', total_reward, 90 | episode) 91 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 92 | episode) 93 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 94 | episode) 95 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 96 | total_step) 97 | if loss is not None: 98 | writer.add_scalar('Train-Step/Loss', loss, 99 | total_step) 100 | logging.info( 101 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tEpsilon: {:.3f}' 102 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, 103 | epsilon)) 104 | 105 | # Check the scheduling result 106 | fig = env.gantt_plot.draw_gantt(env.makespan) 107 | writer.add_figure('Train-Episode/Gantt_Chart', fig, episode) 108 | break 109 | 110 | delta_eps = (args.eps_max - args.eps_min) / args.eps_period 111 | epsilon = max(epsilon + delta_eps, args.eps_min) 112 | env.close() 113 | 114 | 115 | def test(args, _env, agent, writer): 116 | logging.info('\n* Start Testing') 117 | env = _env 118 | 119 | action_space = env.action_space 120 | epsilon = args.test_epsilon 121 | seeds = [args.seed + i for i in range(10)] 122 | rewards = [] 123 | makespans = [] 124 | 125 | n_episode = 0 126 | seed_loader = tqdm(seeds) 127 | for seed in seed_loader: 128 | n_episode += 1 129 | total_reward = 0 130 | # env.seed(seed) 131 | state = env.reset() 132 | for t in itertools.count(start=1): 133 | 134 | #action = agent.select_action(state, epsilon, action_space) 135 | action = agent.select_best_action(state) 136 | 137 | # execute action 138 | next_state, reward, done, _ = env.step(action) 139 | 140 | state = next_state 141 | total_reward += reward 142 | 143 | # env.render(done) 144 | env.render(terminal=done) 145 | 146 | if done: 147 | writer.add_scalar('Test/Episode_Reward', total_reward, n_episode) 148 | rewards.append(total_reward) 149 | makespans.append(env.makespan) 150 | 151 | # Check the scheduling result 152 | fig = env.gantt_plot.draw_gantt(env.makespan) 153 | writer.add_figure('Test/Gantt_Chart', fig, n_episode) 154 | break 155 | 156 | env.close() 157 | 158 | logging.info(f' - Average Reward = {np.mean(rewards)}') 159 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 160 | 161 | 162 | def 
main(): 163 | import config 164 | ## arguments ## 165 | parser = argparse.ArgumentParser(description=__doc__) 166 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 167 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 168 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 169 | # train 170 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 171 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 172 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 173 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 174 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 175 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 176 | parser.add_argument('--eps_max' , default=config.EPS_MAX , type=float) 177 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 178 | parser.add_argument('--eps_period' , default=config.EPS_PERIOD , type=float) 179 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 180 | parser.add_argument('--freq' , default=config.FREQ , type=int) 181 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 182 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 183 | # test 184 | parser.add_argument('--test_only' , action='store_true') 185 | parser.add_argument('--render' , action='store_true') 186 | parser.add_argument('--seed' , default=config.SEED , type=int) 187 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 188 | args = parser.parse_args() 189 | 190 | ## main ## 191 | file_name = config.FILE_NAME 192 | file_dir = os.getcwd() + '/simulation_env/input_data' 193 | file_path = os.path.join(file_dir, file_name) 194 | 195 | opt_makespan = config.OPT_MAKESPAN 196 | num_machine = config.NUM_MACHINE 197 | num_job = config.NUM_JOB 198 | 199 | # Agent & Environment 200 | env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 201 | agent = DQN(env.dim_observations, env.dim_actions, args) 202 | 203 | # Tensorboard to trace the learning process 204 | writer = SummaryWriter(f'log/DQN-{time.time()}') 205 | 206 | ## Train ## 207 | if not args.test_only: 208 | train(args, env, agent, writer) 209 | agent.save(args.model) 210 | 211 | ## Test ## To test the pre-trained model 212 | agent.load(args.model) 213 | test(args, env, agent, writer) 214 | 215 | writer.close() 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /main_djss_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import torch.nn as nn 21 | import matplotlib.pyplot as plt 22 | 23 | from tensorboardX import SummaryWriter 24 | from datetime import datetime as dt 25 | from tqdm import tqdm 26 | 27 | # from simulation_env.env_jobshop_v1 import Factory 28 | from simulation_env.env_for_job_shop_v7_attention import Factory 29 | # from 
dqn_agent_djss import DQN as DDQN 30 | from ddqn_agent_attention import DDQN 31 | 32 | import pdb 33 | import wandb 34 | 35 | # seed 36 | import config_djss_attention as config 37 | seed = config.SEED #999 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | logging.basicConfig(level=logging.DEBUG) 44 | plt.set_loglevel('WARNING') 45 | # ----------------------------------------------- 46 | 47 | def train(args, _env, agent, writer): 48 | logging.info('* Start Training') 49 | 50 | env = _env 51 | action_space = env.action_space 52 | 53 | total_step, epsilon, ewma_reward = 0, 1., 0. 54 | 55 | # Switch to train mode 56 | agent.train() 57 | 58 | # Training until episode-condition 59 | for episode in range(args.episode): 60 | total_reward = 0 61 | done = False 62 | state = env.reset() 63 | 64 | # While not terminate 65 | for t in itertools.count(start=1): 66 | # if args.render and episode > 700: 67 | # env.render(done) 68 | # time.sleep(0.0082) 69 | 70 | # select action 71 | if episode < args.priori_period: 72 | action = episode % env.dim_actions 73 | # action = action_space.sample() 74 | elif total_step < args.warmup: 75 | action = action_space.sample() 76 | else: 77 | action = agent.select_action(state, epsilon, action_space) 78 | epsilon = max(epsilon * args.eps_decay, args.eps_min) 79 | 80 | # execute action 81 | next_state, reward, done, _ = env.step(action) 82 | 83 | # store transition 84 | agent.append(state, action, reward, next_state, done) 85 | if done or reward > 10: 86 | for _ in range(20): 87 | agent.append(state, action, reward, next_state, done) 88 | 89 | # optimize the model 90 | loss = None 91 | if total_step >= args.warmup: 92 | loss = agent.update(total_step) 93 | if loss is not None: 94 | writer.add_scalar('Train-Step/Loss', loss, 95 | total_step) 96 | wandb.log({ 97 | 'loss': loss 98 | }) 99 | 100 | # transit next_state --> current_state 101 | state = next_state 102 | total_reward += reward 103 | total_step += 1 104 | 105 | if args.render and episode > args.render_episode: 106 | env.render(done) 107 | # time.sleep(0.0082) 108 | 109 | # Break & Record the performance at the end each episode 110 | if done: 111 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 112 | writer.add_scalar('Train-Episode/Reward', total_reward, 113 | episode) 114 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 115 | episode) 116 | writer.add_scalar('Train-Episode/MeanFT', env.mean_flow_time, 117 | episode) 118 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 119 | episode) 120 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 121 | total_step) 122 | wandb.log({ 123 | 'Reward': total_reward, 124 | 'Makespan': env.makespan, 125 | 'MeanFT': env.mean_flow_time, 126 | 'Epsilon': epsilon 127 | }) 128 | wandb.log({ 129 | 'Ewma_Reward': ewma_reward, 130 | }) 131 | if loss is not None: 132 | writer.add_scalar('Train-Step/Loss', loss, 133 | total_step) 134 | logging.info( 135 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tMeanFT: {:.2f}\tEpsilon: {:.3f}' 136 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, env.mean_flow_time, 137 | epsilon)) 138 | break 139 | 140 | # if episode > 1000: 141 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 142 | env.close() 143 | ## Train ## 144 | if episode % 1000 == 0 and episode > 0: 145 | agent.save(f'{args.model}-ck-{episode}.pth') 146 | 147 | 148 | def test(args, 
_env, agent, writer): 149 | logging.info('\n* Start Testing') 150 | env = _env 151 | 152 | action_space = env.action_space 153 | epsilon = args.test_epsilon 154 | seeds = [args.seed + i for i in range(100)] 155 | rewards = [] 156 | makespans = [] 157 | lst_mean_ft = [] 158 | 159 | n_episode = 0 160 | seed_loader = tqdm(seeds) 161 | for seed in seed_loader: 162 | n_episode += 1 163 | total_reward = 0 164 | # env.seed(seed) 165 | state = env.reset() 166 | for t in itertools.count(start=1): 167 | 168 | #action = agent.select_action(state, epsilon, action_space) 169 | action = agent.select_best_action(state) 170 | 171 | # execute action 172 | next_state, reward, done, _ = env.step(action) 173 | 174 | state = next_state 175 | total_reward += reward 176 | 177 | # env.render(done) 178 | #env.render(terminal=done) 179 | 180 | if done: 181 | writer.add_scalar('Test/Episode_Reward' , total_reward, n_episode) 182 | writer.add_scalar('Test/Episode_Makespan', env.makespan, n_episode) 183 | writer.add_scalar('Test/Episode_MeanFT', env.mean_flow_time, n_episode) 184 | wandb.log({ 185 | 'Test_Reward': total_reward, 186 | 'Test_Makespan': env.makespan, 187 | 'Test_MeanFT': env.mean_flow_time 188 | }) 189 | rewards.append(total_reward) 190 | makespans.append(env.makespan) 191 | lst_mean_ft.append(env.mean_flow_time) 192 | 193 | # # Check the scheduling result 194 | # fig = env.gantt_plot.draw_gantt(env.makespan) 195 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 196 | break 197 | 198 | env.close() 199 | 200 | logging.info(f' - Average Reward = {np.mean(rewards)}') 201 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 202 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 203 | 204 | 205 | def main(): 206 | import wandb 207 | import config_djss_attention as config 208 | ## arguments ## 209 | parser = argparse.ArgumentParser(description=__doc__) 210 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 211 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 212 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 213 | # train 214 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 215 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 216 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 217 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 218 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 219 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 220 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 221 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 222 | parser.add_argument('--freq' , default=config.FREQ , type=int) 223 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 224 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 225 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 226 | # test 227 | parser.add_argument('--test_only' , action='store_true') 228 | parser.add_argument('--render' , action='store_true') 229 | parser.add_argument('--seed' , default=config.SEED , type=int) 230 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 231 | args = parser.parse_args() 232 | 233 | wandb.init(project='DRL-SimPy-JSS', config=args) 234 | 235 | ## main ## 236 | file_name = config.FILE_NAME 237 | file_dir = os.getcwd() + 
'/simulation_env/instance' 238 | file_path = os.path.join(file_dir, file_name) 239 | 240 | opt_makespan = config.OPT_MAKESPAN 241 | num_machine = config.NUM_MACHINE 242 | num_job = config.NUM_JOB 243 | 244 | # rls_rule = config.RLS_RULE 245 | 246 | # Agent & Environment 247 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 248 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 249 | # agent = DQN(env.dim_observations, env.dim_actions, args) 250 | agent = DDQN(env.dim_observations, env.dim_actions, args) 251 | 252 | # Tensorboard to trace the learning process 253 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 254 | # writer = SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 255 | writter_name = config.WRITTER 256 | writer = SummaryWriter(f'{writter_name}') 257 | # writer = SummaryWriter(f'log/DQN-{time.time()}') 258 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 259 | 260 | ## Train ## 261 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-125019.pth-ck-1000.pth') 262 | if not args.test_only: 263 | train(args, env, agent, writer) 264 | agent.save(args.model) 265 | 266 | ## Test ## To test the pre-trained model 267 | agent.load(args.model) 268 | test(args, env, agent, writer) 269 | 270 | writer.close() 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /main_djss_attention_actionPercently.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import torch.nn as nn 21 | import matplotlib.pyplot as plt 22 | 23 | from tensorboardX import SummaryWriter 24 | from datetime import datetime as dt 25 | from tqdm import tqdm 26 | 27 | # from simulation_env.env_jobshop_v1 import Factory 28 | from simulation_env.env_for_job_shop_v7_attention import Factory 29 | # from dqn_agent_djss import DQN as DDQN 30 | from ddqn_agent_attention import DDQN 31 | 32 | import pdb 33 | import wandb 34 | 35 | # seed 36 | import config_djss_attention as config 37 | seed = config.SEED #999 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | logging.basicConfig(level=logging.DEBUG) 44 | plt.set_loglevel('WARNING') 45 | # ----------------------------------------------- 46 | 47 | def train(args, _env, agent, writer): 48 | logging.info('* Start Training') 49 | 50 | env = _env 51 | action_space = env.action_space 52 | 53 | total_step, epsilon, ewma_reward = 0, 1., 0. 
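    # Bookkeeping, as in main_djss_attention.py: total_step gates learning (updates
    # begin once it reaches args.warmup) and epsilon decays multiplicatively on every
    # agent-controlled step (past warm-up and priori_period) via
    # epsilon = max(epsilon * args.eps_decay, args.eps_min).
    # In addition, this trainer tallies which dispatching rule (action) is chosen at
    # each decision point: per-episode counts go into episode_selection, per-episode
    # frequencies into episode_percentage, and both are dumped after training to
    # training_action_result.csv / training_action_percentage.csv. Those dumps (and
    # the ones in test()) call pd.DataFrame, so pandas must be importable as `pd`
    # in this module.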
54 | 55 | # Switch to train mode 56 | agent.train() 57 | 58 | #################### Record Action ################### 59 | # counter = 0 60 | # df_action = pd.DataFrame(columns = [ 61 | # "epidose", \ 62 | # "rule", \ 63 | # "reward", \ 64 | # "avg_utilization" 65 | # ]) 66 | episode_percentage, episode_selection = [], [] 67 | ###################################################### 68 | 69 | # Training until episode-condition 70 | for episode in range(args.episode): 71 | total_reward = 0 72 | done = False 73 | state = env.reset() 74 | 75 | #################### Record Action ################### 76 | action_selection = [0] * env.dim_actions 77 | ###################################################### 78 | 79 | # While not terminate 80 | for t in itertools.count(start=1): 81 | # if args.render and episode > 700: 82 | # env.render(done) 83 | # time.sleep(0.0082) 84 | 85 | # select action 86 | if episode < args.priori_period: 87 | action = episode % env.dim_actions 88 | # action = action_space.sample() 89 | elif total_step < args.warmup: 90 | action = action_space.sample() 91 | else: 92 | action = agent.select_action(state, epsilon, action_space) 93 | epsilon = max(epsilon * args.eps_decay, args.eps_min) 94 | 95 | # execute action 96 | next_state, reward, done, _ = env.step(action) 97 | 98 | #################### Record Action ################### 99 | action_selection[action] += 1 100 | # avg_utils = env.get_utilization() 101 | # df_action.loc[counter] = \ 102 | # [episode, rule, total_reward, avg_utils] 103 | # counter += 1 104 | ###################################################### 105 | 106 | # store transition 107 | agent.append(state, action, reward, next_state, done) 108 | if done or reward > 10: 109 | for _ in range(20): 110 | agent.append(state, action, reward, next_state, done) 111 | 112 | # optimize the model 113 | loss = None 114 | if total_step >= args.warmup: 115 | loss = agent.update(total_step) 116 | if loss is not None: 117 | writer.add_scalar('Train-Step/Loss', loss, 118 | total_step) 119 | # transit next_state --> current_state 120 | state = next_state 121 | total_reward += reward 122 | total_step += 1 123 | 124 | if args.render and episode > args.render_episode: 125 | env.render(done) 126 | # time.sleep(0.0082) 127 | 128 | # Break & Record the performance at the end each episode 129 | if done: 130 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 131 | writer.add_scalar('Train-Episode/Reward', total_reward, 132 | episode) 133 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 134 | episode) 135 | writer.add_scalar('Train-Episode/MeanFT', env.mean_flow_time, 136 | episode) 137 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 138 | episode) 139 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 140 | total_step) 141 | wandb.log({ 142 | 'Reward': total_reward, 143 | 'Makespan': env.makespan, 144 | 'MeanFT': env.mean_flow_time, 145 | 'Epsilon': epsilon, 146 | 'Ewma_Reward': ewma_reward, 147 | }) 148 | if loss is not None: 149 | writer.add_scalar('Train-Step/Loss', loss, 150 | total_step) 151 | wandb.log({ 152 | 'loss': loss 153 | }) 154 | 155 | logging.info( 156 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tMeanFT: {:.2f}\tEpsilon: {:.3f}' 157 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, env.mean_flow_time, 158 | epsilon)) 159 | break 160 | 161 | # if episode > 1000: 162 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 163 | env.close() 164 | ## Train 
## 165 | if episode % 1000 == 0 and episode > 0: 166 | agent.save(f'{args.model}-ck-{episode}.pth') 167 | #################### Record Action ################### 168 | # statistic of the selection of action 169 | action_percentage = [0] * len(action_selection) 170 | for act in range(len(action_selection)): 171 | action_percentage[act] = action_selection[act] / t 172 | episode_percentage.append(action_percentage) 173 | episode_selection.append(action_selection) 174 | action_list = [ 175 | 'FIFO' , 'LIFO' , 'SPT' , 'LPT', 176 | 'LWKR' , 'MWKR' , 'SSO' , 'LSO', 177 | 'SPT+SSO', 'LPT+LSO' 178 | ] 179 | for act in range(len(action_list)): 180 | wandb.log({ 181 | action_list[act]: action_percentage[act], 182 | f'num_{action_list[act]}': action_selection[act] 183 | }) 184 | ###################################################### 185 | df_act_res = pd.DataFrame(episode_selection) 186 | df_act_per = pd.DataFrame(episode_percentage) 187 | df_act_res.to_csv('training_action_result.csv') 188 | df_act_per.to_csv('training_action_percentage.csv') 189 | 190 | 191 | def test(args, _env, agent, writer): 192 | logging.info('\n* Start Testing') 193 | env = _env 194 | 195 | action_space = env.action_space 196 | epsilon = args.test_epsilon 197 | seeds = [args.seed + i for i in range(100)] 198 | rewards = [] 199 | makespans = [] 200 | lst_mean_ft = [] 201 | 202 | n_episode = 0 203 | seed_loader = tqdm(seeds) 204 | ##################### Record Action ################### 205 | episode_percentage, episode_selection = [], [] 206 | ###################################################### 207 | 208 | for seed in seed_loader: 209 | #################### Record Action ################### 210 | action_selection = [0] * env.dim_actions 211 | ###################################################### 212 | n_episode += 1 213 | total_reward = 0 214 | # env.seed(seed) 215 | state = env.reset() 216 | for t in itertools.count(start=1): 217 | 218 | #action = agent.select_action(state, epsilon, action_space) 219 | action = agent.select_best_action(state) 220 | 221 | # execute action 222 | next_state, reward, done, _ = env.step(action) 223 | #################### Record Action ################### 224 | action_selection[action] += 1 225 | ###################################################### 226 | 227 | state = next_state 228 | total_reward += reward 229 | 230 | # env.render(done) 231 | #env.render(terminal=done) 232 | 233 | if done: 234 | writer.add_scalar('Test/Episode_Reward' , total_reward, n_episode) 235 | writer.add_scalar('Test/Episode_Makespan', env.makespan, n_episode) 236 | writer.add_scalar('Test/Episode_MeanFT', env.mean_flow_time, n_episode) 237 | wandb.log({ 238 | 'Test_Reward': total_reward, 239 | 'Test_Makespan': env.makespan, 240 | 'Test_MeanFT': env.mean_flow_time 241 | }) 242 | rewards.append(total_reward) 243 | makespans.append(env.makespan) 244 | lst_mean_ft.append(env.mean_flow_time) 245 | 246 | # # Check the scheduling result 247 | # fig = env.gantt_plot.draw_gantt(env.makespan) 248 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 249 | break 250 | 251 | env.close() 252 | #################### Record Action ################### 253 | # statistic of the selection of action 254 | action_percentage = [0] * len(action_selection) 255 | for act in range(len(action_selection)): 256 | action_percentage[act] = action_selection[act] / t 257 | episode_selection.append(action_selection) 258 | episode_percentage.append(action_percentage) 259 | ###################################################### 260 | df_act_res = 
pd.DataFrame(episode_selection) 261 | df_act_per = pd.DataFrame(episode_percentage) 262 | df_act_res.to_csv('testing_action_result.csv') 263 | df_act_per.to_csv('testing_action_percentage.csv') 264 | 265 | logging.info(f' - Average Reward = {np.mean(rewards)}') 266 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 267 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 268 | 269 | 270 | def main(): 271 | import wandb 272 | import config_djss_attention as config 273 | ## arguments ## 274 | parser = argparse.ArgumentParser(description=__doc__) 275 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 276 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 277 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 278 | # train 279 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 280 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 281 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 282 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 283 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 284 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 285 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 286 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 287 | parser.add_argument('--freq' , default=config.FREQ , type=int) 288 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 289 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 290 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 291 | # test 292 | parser.add_argument('--test_only' , action='store_true') 293 | parser.add_argument('--render' , action='store_true') 294 | parser.add_argument('--seed' , default=config.SEED , type=int) 295 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 296 | args = parser.parse_args() 297 | 298 | wandb.init(project='DRL-SimPy-JSS', config=args) 299 | 300 | ## main ## 301 | file_name = config.FILE_NAME 302 | file_dir = os.getcwd() + '/simulation_env/instance' 303 | file_path = os.path.join(file_dir, file_name) 304 | 305 | opt_makespan = config.OPT_MAKESPAN 306 | num_machine = config.NUM_MACHINE 307 | num_job = config.NUM_JOB 308 | 309 | # rls_rule = config.RLS_RULE 310 | 311 | # Agent & Environment 312 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 313 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 314 | # agent = DQN(env.dim_observations, env.dim_actions, args) 315 | agent = DDQN(env.dim_observations, env.dim_actions, args) 316 | 317 | # Tensorboard to trace the learning process 318 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 319 | # writer = SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 320 | writter_name = config.WRITTER 321 | writer = SummaryWriter(f'{writter_name}') 322 | # writer = SummaryWriter(f'log/DQN-{time.time()}') 323 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 324 | 325 | ## Train ## 326 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-125019.pth-ck-1000.pth') 327 | if not args.test_only: 328 | train(args, env, agent, writer) 329 | agent.save(args.model) 330 | 331 | ## Test ## To test the pre-trained model 332 | agent.load(args.model) 333 | test(args, env, agent, writer) 334 | 335 | writer.close() 
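    # ---- Added note: post-run summary of action usage (hedged sketch, not in the original script) ----
    # train() and test() above dump per-episode counts/shares of each dispatching
    # rule to 'training_action_*.csv' / 'testing_action_*.csv'.  The snippet below
    # is only an illustrative way to sanity-check those files after a run; it
    # reuses os / pd / logging already imported in this module and skips any file
    # that was not produced (e.g. under --test_only).
    for _csv in ('training_action_percentage.csv',
                 'testing_action_percentage.csv'):
        if os.path.exists(_csv):
            _df = pd.read_csv(_csv, index_col=0)
            logging.info(f' - {_csv}: mean share per action index\n{_df.mean()}')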
336 | 337 | 338 | if __name__ == '__main__': 339 | main() 340 | -------------------------------------------------------------------------------- /main_djss_attention_paper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch.nn as nn 22 | import matplotlib.pyplot as plt 23 | 24 | from tensorboardX import SummaryWriter 25 | from datetime import datetime as dt 26 | from tqdm import tqdm 27 | 28 | # from simulation_env.env_jobshop_v1 import Factory 29 | from simulation_env.env_for_job_shop_v7_attention import Factory 30 | # from dqn_agent_djss import DQN as DDQN 31 | from ddqn_agent_attention_paper import DDQN 32 | 33 | import pdb 34 | import wandb 35 | 36 | # seed 37 | import config_djss_attention_paper as config 38 | seed = config.SEED #999 39 | random.seed(seed) 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | torch.backends.cudnn.deterministic = True 43 | 44 | logging.basicConfig(level=logging.DEBUG) 45 | plt.set_loglevel('WARNING') 46 | # ----------------------------------------------- 47 | 48 | def train(args, _env, agent, writer): 49 | logging.info('* Start Training') 50 | 51 | env = _env 52 | action_space = env.action_space 53 | 54 | total_step, epsilon, ewma_reward = 0, 1., 0. 55 | 56 | # Switch to train mode 57 | agent.train() 58 | 59 | #################### Record Action ################### 60 | episode_percentage, episode_selection = [], [] 61 | ###################################################### 62 | 63 | # Training until episode-condition 64 | for episode in range(args.episode): 65 | total_reward = 0 66 | done = False 67 | state = env.reset() 68 | #################### Record Action ################### 69 | action_selection = [0] * env.dim_actions 70 | ###################################################### 71 | 72 | # While not terminate 73 | for t in itertools.count(start=1): 74 | # if args.render and episode > 700: 75 | # env.render(done) 76 | # time.sleep(0.0082) 77 | 78 | # select action 79 | if episode < args.priori_period: 80 | action = episode % env.dim_actions 81 | # action = action_space.sample() 82 | elif total_step < args.warmup: 83 | action = action_space.sample() 84 | else: 85 | action = agent.select_action(state, epsilon, action_space) 86 | epsilon = max(epsilon * args.eps_decay, args.eps_min) 87 | 88 | # execute action 89 | next_state, reward, done, _ = env.step(action) 90 | 91 | #################### Record Action ################### 92 | action_selection[action] += 1 93 | ###################################################### 94 | 95 | # store transition 96 | agent.append(state, action, reward, next_state, done) 97 | if done or reward > 10: 98 | for _ in range(20): 99 | agent.append(state, action, reward, next_state, done) 100 | 101 | # optimize the model 102 | loss = None 103 | if total_step >= args.warmup: 104 | loss = agent.update(total_step) 105 | if loss is not None: 106 | writer.add_scalar('Train-Step/Loss', loss, 107 | total_step) 108 | 109 | # transit next_state --> current_state 110 | state = next_state 111 | total_reward += reward 112 
| total_step += 1 113 | 114 | if args.render and episode > args.render_episode: 115 | env.render(done) 116 | # time.sleep(0.0082) 117 | 118 | # Break & Record the performance at the end each episode 119 | if done: 120 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 121 | writer.add_scalar('Train-Episode/Reward', total_reward, 122 | episode) 123 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 124 | episode) 125 | writer.add_scalar('Train-Episode/MeanFT', env.mean_flow_time, 126 | episode) 127 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 128 | episode) 129 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 130 | total_step) 131 | wandb.log({ 132 | 'Reward': total_reward, 133 | 'Makespan': env.makespan, 134 | 'MeanFT': env.mean_flow_time, 135 | 'Epsilon': epsilon, 136 | 'Ewma_Reward': ewma_reward 137 | }) 138 | 139 | if loss is not None: 140 | writer.add_scalar('Train-Step/Loss', loss, 141 | total_step) 142 | wandb.log({ 143 | 'loss': loss 144 | }) 145 | 146 | logging.info( 147 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tMeanFT: {:.2f}\tEpsilon: {:.3f}' 148 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, env.mean_flow_time, 149 | epsilon)) 150 | 151 | # # Check the scheduling result 152 | # if episode % 50 == 0: 153 | # fig = env.gantt_plot.draw_gantt(env.makespan) 154 | # writer.add_figure('Train-Episode/Gantt_Chart', fig, episode) 155 | break 156 | 157 | # if episode > 1000: 158 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 159 | env.close() 160 | ## Train ## 161 | if episode % 1000 == 0 and episode > 0: 162 | agent.save(f'{args.model}-ck-{episode}.pth') 163 | #################### Record Action ################### 164 | # statistic of the selection of action 165 | action_percentage = [0] * len(action_selection) 166 | for act in range(len(action_selection)): 167 | action_percentage[act] = action_selection[act] / t 168 | episode_percentage.append(action_percentage) 169 | episode_selection.append(action_selection) 170 | action_list = [ 171 | 'FIFO' , 'LIFO' , 'SPT' , 'LPT', 172 | 'LWKR' , 'MWKR' , 'SSO' , 'LSO', 173 | 'SPT+SSO', 'LPT+LSO' 174 | ] 175 | for act in range(len(action_list)): 176 | wandb.log({ 177 | action_list[act]: action_percentage[act], 178 | f'num_{action_list[act]}': action_selection[act] 179 | }) 180 | ###################################################### 181 | df_act_res = pd.DataFrame(episode_selection) 182 | df_act_per = pd.DataFrame(episode_percentage) 183 | df_act_res.to_csv('paper_training_action_result.csv') 184 | df_act_per.to_csv('paper_training_action_percentage.csv') 185 | 186 | 187 | def test(args, _env, agent, writer): 188 | logging.info('\n* Start Testing') 189 | env = _env 190 | 191 | action_space = env.action_space 192 | epsilon = args.test_epsilon 193 | seeds = [args.seed + i for i in range(100)] 194 | rewards = [] 195 | makespans = [] 196 | lst_mean_ft = [] 197 | 198 | n_episode = 0 199 | seed_loader = tqdm(seeds) 200 | ##################### Record Action ################### 201 | episode_percentage, episode_selection = [], [] 202 | ###################################################### 203 | 204 | for seed in seed_loader: 205 | #################### Record Action ################### 206 | action_selection = [0] * env.dim_actions 207 | ###################################################### 208 | n_episode += 1 209 | total_reward = 0 210 | # env.seed(seed) 211 | state = env.reset() 212 | for t in 
itertools.count(start=1): 213 | 214 | #action = agent.select_action(state, epsilon, action_space) 215 | action = agent.select_best_action(state) 216 | 217 | # execute action 218 | next_state, reward, done, _ = env.step(action) 219 | #################### Record Action ################### 220 | action_selection[action] += 1 221 | ###################################################### 222 | 223 | state = next_state 224 | total_reward += reward 225 | 226 | # env.render(done) 227 | #env.render(terminal=done) 228 | 229 | if done: 230 | writer.add_scalar('Test/Episode_Reward' , total_reward, n_episode) 231 | writer.add_scalar('Test/Episode_Makespan', env.makespan, n_episode) 232 | writer.add_scalar('Test/Episode_MeanFT', env.mean_flow_time, n_episode) 233 | wandb.log({ 234 | 'Test_Reward': total_reward, 235 | 'Test_Makespan': env.makespan, 236 | 'Test_MeanFT': env.mean_flow_time 237 | }) 238 | 239 | rewards.append(total_reward) 240 | makespans.append(env.makespan) 241 | lst_mean_ft.append(env.mean_flow_time) 242 | 243 | # # Check the scheduling result 244 | # fig = env.gantt_plot.draw_gantt(env.makespan) 245 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 246 | break 247 | 248 | env.close() 249 | #################### Record Action ################### 250 | # statistic of the selection of action 251 | action_percentage = [0] * len(action_selection) 252 | for act in range(len(action_selection)): 253 | action_percentage[act] = action_selection[act] / total_step 254 | episode_percentage.append(action_percentage) 255 | episode_selection.append(action_selection) 256 | ###################################################### 257 | df_act_res = pd.DataFrame(episode_selection) 258 | df_act_per = pd.DataFrame(episode_percentage) 259 | df_act_res.to_csv('paper_testing_action_result.csv') 260 | df_act_per.to_csv('paper_testing_action_percentage.csv') 261 | 262 | logging.info(f' - Average Reward = {np.mean(rewards)}') 263 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 264 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 265 | 266 | 267 | def main(): 268 | import wandb 269 | import config_djss_attention_paper as config 270 | ## arguments ## 271 | parser = argparse.ArgumentParser(description=__doc__) 272 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 273 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 274 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 275 | # train 276 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 277 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 278 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 279 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 280 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 281 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 282 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 283 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 284 | parser.add_argument('--freq' , default=config.FREQ , type=int) 285 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 286 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 287 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 288 | # test 289 | parser.add_argument('--test_only' , action='store_true') 290 | parser.add_argument('--render' 
, action='store_true') 291 | parser.add_argument('--seed' , default=config.SEED , type=int) 292 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 293 | args = parser.parse_args() 294 | 295 | wandb.init(project='DRL-SimPy-JSS', config=args) 296 | 297 | ## main ## 298 | file_name = config.FILE_NAME 299 | file_dir = os.getcwd() + '/simulation_env/instance' 300 | file_path = os.path.join(file_dir, file_name) 301 | 302 | opt_makespan = config.OPT_MAKESPAN 303 | num_machine = config.NUM_MACHINE 304 | num_job = config.NUM_JOB 305 | 306 | # rls_rule = config.RLS_RULE 307 | 308 | # Agent & Environment 309 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 310 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 311 | # agent = DQN(env.dim_observations, env.dim_actions, args) 312 | agent = DDQN(env.dim_observations, env.dim_actions, args) 313 | 314 | # Tensorboard to trace the learning process 315 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 316 | # writer = SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 317 | writter_name = config.WRITTER 318 | writer = SummaryWriter(f'{writter_name}') 319 | # writer = SummaryWriter(f'log/DQN-{time.time()}') 320 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 321 | 322 | ## Train ## 323 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-125019.pth-ck-1000.pth') 324 | if not args.test_only: 325 | train(args, env, agent, writer) 326 | agent.save(args.model) 327 | 328 | ## Test ## To test the pre-trained model 329 | agent.load(args.model) 330 | test(args, env, agent, writer) 331 | 332 | writer.close() 333 | 334 | 335 | if __name__ == '__main__': 336 | main() 337 | -------------------------------------------------------------------------------- /main_djss_attention_paper1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch.nn as nn 22 | import matplotlib.pyplot as plt 23 | 24 | from tensorboardX import SummaryWriter 25 | from datetime import datetime as dt 26 | from tqdm import tqdm 27 | 28 | # from simulation_env.env_jobshop_v1 import Factory 29 | from simulation_env.env_for_job_shop_v7_attention1 import Factory 30 | # from dqn_agent_djss import DQN as DDQN 31 | from ddqn_agent_attention_paper1 import DDQN 32 | 33 | import pdb 34 | import wandb 35 | 36 | # seed 37 | import config_djss_attention_paper as config 38 | seed = config.SEED #999 39 | random.seed(seed) 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | torch.backends.cudnn.deterministic = True 43 | 44 | logging.basicConfig(level=logging.DEBUG) 45 | plt.set_loglevel('WARNING') 46 | # ----------------------------------------------- 47 | 48 | def train(args, _env, agent, writer): 49 | logging.info('* Start Training') 50 | 51 | env = _env 52 | action_space = env.action_space 53 | 54 | total_step, epsilon, ewma_reward = 0, 1., 0. 
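    # ---- Added note on the exploration schedule (commentary, not part of the original trainer) ----
    # After the warm-up phase, the loop below decays epsilon multiplicatively once
    # per environment step: epsilon = max(epsilon * args.eps_decay, args.eps_min).
    # The value after k decayed steps is therefore roughly max(eps_decay ** k, eps_min).
    # The helper below is only a sketch of that schedule for quick inspection; it
    # is not used by the training loop itself.
    def _epsilon_after(k, start=1.0, decay=args.eps_decay, floor=args.eps_min):
        """Sketch: epsilon value after k multiplicative decay steps."""
        return max(start * (decay ** k), floor)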
55 | 56 | # Switch to train mode 57 | agent.train() 58 | 59 | #################### Record Action ################### 60 | episode_percentage, episode_selection = [], [] 61 | ###################################################### 62 | 63 | # Training until episode-condition 64 | for episode in range(args.episode): 65 | total_reward = 0 66 | done = False 67 | state = env.reset() 68 | #################### Record Action ################### 69 | action_selection = [0] * env.dim_actions 70 | ###################################################### 71 | 72 | # While not terminate 73 | for t in itertools.count(start=1): 74 | # if args.render and episode > 700: 75 | # env.render(done) 76 | # time.sleep(0.0082) 77 | 78 | # select action 79 | if episode < args.priori_period: 80 | action = episode % env.dim_actions 81 | # action = action_space.sample() 82 | elif total_step < args.warmup: 83 | action = action_space.sample() 84 | else: 85 | action = agent.select_action(state, epsilon, action_space) 86 | epsilon = max(epsilon * args.eps_decay, args.eps_min) 87 | 88 | # execute action 89 | next_state, reward, done, _ = env.step(action) 90 | 91 | 92 | #################### Record Action ################### 93 | action_selection[action] += 1 94 | ###################################################### 95 | 96 | # store transition 97 | agent.append(state, action, reward, next_state, done) 98 | if done or reward > 10: 99 | for _ in range(20): 100 | agent.append(state, action, reward, next_state, done) 101 | 102 | # optimize the model 103 | loss = None 104 | if total_step >= args.warmup: 105 | loss = agent.update(total_step) 106 | if loss is not None: 107 | writer.add_scalar('Train-Step/Loss', loss, 108 | total_step) 109 | 110 | # transit next_state --> current_state 111 | state = next_state 112 | total_reward += reward 113 | total_step += 1 114 | 115 | if args.render and episode > args.render_episode: 116 | env.render(done) 117 | # time.sleep(0.0082) 118 | 119 | # Break & Record the performance at the end each episode 120 | if done: 121 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 122 | writer.add_scalar('Train-Episode/Reward', total_reward, 123 | episode) 124 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 125 | episode) 126 | writer.add_scalar('Train-Episode/MeanFT', env.mean_flow_time, 127 | episode) 128 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 129 | episode) 130 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 131 | total_step) 132 | wandb.log({ 133 | 'Reward': total_reward, 134 | 'Makespan': env.makespan, 135 | 'MeanFT': env.mean_flow_time, 136 | 'Epsilon': epsilon, 137 | 'Ewma_Reward': ewma_reward, 138 | 'Episode': episode 139 | }) 140 | 141 | if loss is not None: 142 | writer.add_scalar('Train-Step/Loss', loss, 143 | total_step) 144 | wandb.log({ 145 | 'loss': loss 146 | }) 147 | logging.info( 148 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tMeanFT: {:.2f}\tEpsilon: {:.3f}' 149 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, env.mean_flow_time, 150 | epsilon)) 151 | 152 | # # Check the scheduling result 153 | # if episode % 50 == 0: 154 | # fig = env.gantt_plot.draw_gantt(env.makespan) 155 | # writer.add_figure('Train-Episode/Gantt_Chart', fig, episode) 156 | break 157 | 158 | # if episode > 1000: 159 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 160 | env.close() 161 | ## Train ## 162 | if episode % 1000 == 0 and episode > 0: 163 | 
agent.save(f'{args.model}-1ck-{episode}.pth') 164 | #################### Record Action ################### 165 | # statistic of the selection of action 166 | action_percentage = [0] * len(action_selection) 167 | for act in range(len(action_selection)): 168 | action_percentage[act] = action_selection[act] / t 169 | episode_percentage.append(action_percentage) 170 | episode_selection.append(action_selection) 171 | action_list = [ 172 | 'FIFO' , 'LIFO' , 'SPT' , 'LPT', 173 | 'LWKR' , 'MWKR' , 'SSO' , 'LSO', 174 | 'SPT+SSO', 'LPT+LSO' 175 | ] 176 | for act in range(len(action_list)): 177 | wandb.log({ 178 | action_list[act]: action_percentage[act], 179 | f'num_{action_list[act]}': action_selection[act] 180 | }) 181 | ###################################################### 182 | df_act_res = pd.DataFrame(episode_selection) 183 | df_act_per = pd.DataFrame(episode_percentage) 184 | df_act_res.to_csv('paper_training_action_result1.csv') 185 | df_act_per.to_csv('paper_training_action_percentage1.csv') 186 | 187 | 188 | def test(args, _env, agent, writer): 189 | logging.info('\n* Start Testing') 190 | env = _env 191 | 192 | action_space = env.action_space 193 | epsilon = args.test_epsilon 194 | seeds = [args.seed + i for i in range(100)] 195 | rewards = [] 196 | makespans = [] 197 | lst_mean_ft = [] 198 | 199 | n_episode = 0 200 | seed_loader = tqdm(seeds) 201 | ##################### Record Action ################### 202 | episode_percentage, episode_selection = [], [] 203 | ###################################################### 204 | 205 | for seed in seed_loader: 206 | #################### Record Action ################### 207 | action_selection = [0] * env.dim_actions 208 | ###################################################### 209 | n_episode += 1 210 | total_reward = 0 211 | # env.seed(seed) 212 | state = env.reset() 213 | for t in itertools.count(start=1): 214 | 215 | #action = agent.select_action(state, epsilon, action_space) 216 | action = agent.select_best_action(state) 217 | 218 | # execute action 219 | next_state, reward, done, _ = env.step(action) 220 | #################### Record Action ################### 221 | action_selection[action] += 1 222 | ###################################################### 223 | 224 | state = next_state 225 | total_reward += reward 226 | 227 | # env.render(done) 228 | #env.render(terminal=done) 229 | 230 | if done: 231 | writer.add_scalar('Test/Episode_Reward' , total_reward, n_episode) 232 | writer.add_scalar('Test/Episode_Makespan', env.makespan, n_episode) 233 | writer.add_scalar('Test/Episode_MeanFT', env.mean_flow_time, n_episode) 234 | wandb.log({ 235 | 'Test_Reward': total_reward, 236 | 'Test_Makespan': env.makespan, 237 | 'Test_MeanFT': env.mean_flow_time 238 | }) 239 | 240 | rewards.append(total_reward) 241 | makespans.append(env.makespan) 242 | lst_mean_ft.append(env.mean_flow_time) 243 | 244 | # # Check the scheduling result 245 | # fig = env.gantt_plot.draw_gantt(env.makespan) 246 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 247 | break 248 | 249 | env.close() 250 | #################### Record Action ################### 251 | # statistic of the selection of action 252 | action_percentage = [0] * len(action_selection) 253 | for act in range(len(action_selection)): 254 | action_percentage[act] = action_selection[act] / t 255 | episode_percentage.append(action_percentage) 256 | episode_selection.append(action_selection) 257 | ###################################################### 258 | df_act_res = pd.DataFrame(episode_selection) 259 | 
df_act_per = pd.DataFrame(episode_percentage) 260 | df_act_res.to_csv('paper_testing_action_result1.csv') 261 | df_act_per.to_csv('paper_testing_action_percentage1.csv') 262 | 263 | logging.info(f' - Average Reward = {np.mean(rewards)}') 264 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 265 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 266 | 267 | 268 | def main(): 269 | import wandb 270 | import config_djss_attention_paper as config 271 | ## arguments ## 272 | parser = argparse.ArgumentParser(description=__doc__) 273 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 274 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 275 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 276 | # train 277 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 278 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 279 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 280 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 281 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 282 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 283 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 284 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 285 | parser.add_argument('--freq' , default=config.FREQ , type=int) 286 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 287 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 288 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 289 | # test 290 | parser.add_argument('--test_only' , action='store_true') 291 | parser.add_argument('--render' , action='store_true') 292 | parser.add_argument('--seed' , default=config.SEED , type=int) 293 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 294 | args = parser.parse_args() 295 | args.batch_size = 32 296 | 297 | wandb.init(project='DRL-SimPy-JSS', config=args) 298 | 299 | ## main ## 300 | file_name = config.FILE_NAME 301 | file_dir = os.getcwd() + '/simulation_env/instance' 302 | file_path = os.path.join(file_dir, file_name) 303 | 304 | opt_makespan = config.OPT_MAKESPAN 305 | num_machine = config.NUM_MACHINE 306 | num_job = config.NUM_JOB 307 | 308 | # rls_rule = config.RLS_RULE 309 | 310 | # Agent & Environment 311 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 312 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 313 | # agent = DQN(env.dim_observations, env.dim_actions, args) 314 | agent = DDQN(env.dim_observations, env.dim_actions, args) 315 | 316 | # Tensorboard to trace the learning process 317 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 318 | # writer = SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 319 | writter_name = config.WRITTER 320 | writer = SummaryWriter('1' + f'{writter_name}') 321 | # writer = SummaryWriter(f'log/DQN-{time.time()}') 322 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 323 | 324 | ## Train ## 325 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-125019.pth-ck-1000.pth') 326 | if not args.test_only: 327 | train(args, env, agent, writer) 328 | agent.save('1'+args.model) 329 | 330 | ## Test ## To test the pre-trained model 331 | agent.load(args.model) 332 | test(args, env, agent, writer) 333 | 
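    # ---- Added note on checkpoint paths (commentary, not in the original) ----
    # train() saved the freshly trained weights to '1' + args.model (and its
    # periodic checkpoints to f'{args.model}-1ck-{episode}.pth'), while the
    # agent.load(args.model) call above reloads the un-prefixed path.  If the
    # goal is to evaluate the weights produced by this run, a sketch of the
    # corresponding reload would be:
    #
    #     agent.load('1' + args.model)
    #     test(args, env, agent, writer)
    #
    # It is left commented out here so the original control flow is unchanged.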
334 | writer.close() 335 | 336 | 337 | if __name__ == '__main__': 338 | main() 339 | -------------------------------------------------------------------------------- /main_djss_stable_baseline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import torch.nn as nn 21 | import matplotlib.pyplot as plt 22 | 23 | from tensorboardX import SummaryWriter 24 | from datetime import datetime as dt 25 | from tqdm import tqdm 26 | 27 | # from simulation_env.env_jobshop_v1 import Factory 28 | from simulation_env.env_for_job_shop_v7_attention import Factory 29 | # from dqn_agent_djss import DQN as DDQN 30 | from ddqn_agent_attention import DDQN 31 | 32 | from stable_baselines.common.policies import MlpPolicy 33 | from stable_baselines import PPO2 34 | import pdb 35 | 36 | # seed 37 | import config_djss_attention as config 38 | seed = config.SEED #999 39 | random.seed(seed) 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | torch.backends.cudnn.deterministic = True 43 | 44 | logging.basicConfig(level=logging.DEBUG) 45 | plt.set_loglevel('WARNING') 46 | # ----------------------------------------------- 47 | 48 | def train(args, _env, agent, writer): 49 | logging.info('* Start Training') 50 | 51 | env = _env 52 | action_space = env.action_space 53 | 54 | total_step, epsilon, ewma_reward = 0, 1., 0. 
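    # ---- Added note on observation flattening (commentary, not part of the original trainer) ----
    # Unlike the attention trainers, this variant flattens every observation
    # before it reaches the agent (torch.flatten(state) below).  torch.flatten
    # only accepts tensors, so if the Factory environment happened to return a
    # NumPy array that call would raise a TypeError.  The helper below is a
    # defensive sketch under that assumption and is not wired into the loop.
    def _flat_obs(obs):
        """Sketch: coerce an observation to a flat float32 tensor."""
        return torch.as_tensor(obs, dtype=torch.float32).flatten()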
55 | 56 | # Switch to train mode 57 | agent.train() 58 | 59 | # Training until episode-condition 60 | for episode in range(args.episode): 61 | total_reward = 0 62 | done = False 63 | state = env.reset() 64 | state = torch.flatten(state) 65 | 66 | # While not terminate 67 | for t in itertools.count(start=1): 68 | # if args.render and episode > 700: 69 | # env.render(done) 70 | # time.sleep(0.0082) 71 | 72 | # select action 73 | if episode < args.priori_period: 74 | action = episode % env.dim_actions 75 | # action = action_space.sample() 76 | elif total_step < args.warmup: 77 | action = action_space.sample() 78 | else: 79 | action = agent.select_action(state, epsilon, action_space) 80 | epsilon = max(epsilon * args.eps_decay, args.eps_min) 81 | 82 | # execute action 83 | next_state, reward, done, _ = env.step(action) 84 | next_state = torch.flatten(next_state) 85 | 86 | # store transition 87 | agent.append(state, action, reward, next_state, done) 88 | if done or reward > 10: 89 | for _ in range(20): 90 | agent.append(state, action, reward, next_state, done) 91 | 92 | # optimize the model 93 | loss = None 94 | if total_step >= args.warmup: 95 | loss = agent.update(total_step) 96 | if loss is not None: 97 | writer.add_scalar('Train-Step/Loss', loss, 98 | total_step) 99 | 100 | # transit next_state --> current_state 101 | state = next_state 102 | total_reward += reward 103 | total_step += 1 104 | 105 | if args.render and episode > args.render_episode: 106 | env.render(done) 107 | # time.sleep(0.0082) 108 | 109 | # Break & Record the performance at the end each episode 110 | if done: 111 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 112 | writer.add_scalar('Train-Episode/Reward', total_reward, 113 | episode) 114 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 115 | episode) 116 | writer.add_scalar('Train-Episode/MeanFT', env.mean_flow_time, 117 | episode) 118 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 119 | episode) 120 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 121 | total_step) 122 | if loss is not None: 123 | writer.add_scalar('Train-Step/Loss', loss, 124 | total_step) 125 | logging.info( 126 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tMeanFT: {:.2f}\tEpsilon: {:.3f}' 127 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, env.mean_flow_time, 128 | epsilon)) 129 | 130 | # # Check the scheduling result 131 | # if episode % 50 == 0: 132 | # fig = env.gantt_plot.draw_gantt(env.makespan) 133 | # writer.add_figure('Train-Episode/Gantt_Chart', fig, episode) 134 | break 135 | 136 | # if episode > 1000: 137 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 138 | env.close() 139 | ## Train ## 140 | if episode % 1000 == 0 and episode > 0: 141 | agent.save(f'{args.model}-ck-{episode}.pth') 142 | 143 | 144 | def test(args, _env, agent, writer): 145 | logging.info('\n* Start Testing') 146 | env = _env 147 | 148 | action_space = env.action_space 149 | epsilon = args.test_epsilon 150 | seeds = [args.seed + i for i in range(10)] 151 | rewards = [] 152 | makespans = [] 153 | lst_mean_ft = [] 154 | 155 | n_episode = 0 156 | seed_loader = tqdm(seeds) 157 | for seed in seed_loader: 158 | n_episode += 1 159 | total_reward = 0 160 | # env.seed(seed) 161 | state = env.reset() 162 | for t in itertools.count(start=1): 163 | 164 | #action = agent.select_action(state, epsilon, action_space) 165 | action = agent.select_best_action(state) 166 | 167 | # execute action 168 | 
next_state, reward, done, _ = env.step(action) 169 | 170 | state = next_state 171 | total_reward += reward 172 | 173 | # env.render(done) 174 | #env.render(terminal=done) 175 | 176 | if done: 177 | writer.add_scalar('Test/Episode_Reward' , total_reward, n_episode) 178 | writer.add_scalar('Test/Episode_Makespan', env.makespan, n_episode) 179 | writer.add_scalar('Test/Episode_MeanFT', env.mean_flow_time, n_episode) 180 | rewards.append(total_reward) 181 | makespans.append(env.makespan) 182 | lst_mean_ft.append(env.mean_flow_time) 183 | 184 | # # Check the scheduling result 185 | # fig = env.gantt_plot.draw_gantt(env.makespan) 186 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 187 | break 188 | 189 | env.close() 190 | 191 | logging.info(f' - Average Reward = {np.mean(rewards)}') 192 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 193 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 194 | 195 | 196 | def main(): 197 | import config_djss_attention as config 198 | ## arguments ## 199 | parser = argparse.ArgumentParser(description=__doc__) 200 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 201 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 202 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 203 | # train 204 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 205 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 206 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 207 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 208 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 209 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 210 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 211 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 212 | parser.add_argument('--freq' , default=config.FREQ , type=int) 213 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 214 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 215 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 216 | # test 217 | parser.add_argument('--test_only' , action='store_true') 218 | parser.add_argument('--render' , action='store_true') 219 | parser.add_argument('--seed' , default=config.SEED , type=int) 220 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 221 | args = parser.parse_args() 222 | 223 | ## main ## 224 | file_name = config.FILE_NAME 225 | file_dir = os.getcwd() + '/simulation_env/instance' 226 | file_path = os.path.join(file_dir, file_name) 227 | 228 | opt_makespan = config.OPT_MAKESPAN 229 | num_machine = config.NUM_MACHINE 230 | num_job = config.NUM_JOB 231 | 232 | # rls_rule = config.RLS_RULE 233 | 234 | # Agent & Environment 235 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 236 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 237 | # agent = DQN(env.dim_observations, env.dim_actions, args) 238 | agent = DDQN(env.dim_observations, env.dim_actions, args) 239 | 240 | # Tensorboard to trace the learning process 241 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 242 | writer = SummaryWriter(f'log_djss_stablebaseline/PPO-3x300x4-{time_start}') 243 | # writter_name = config.WRITTER 244 | writer = SummaryWriter(f'{writter_name}') 245 | # 
writer = SummaryWriter(f'log/DQN-{time.time()}') 246 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 247 | 248 | args.model = f'model/djss_stablebaseline/PPO-3x300x4-{time_start}' 249 | ## Train ## 250 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-125019.pth-ck-1000.pth') 251 | if not args.test_only: 252 | train(args, env, agent, writer) 253 | agent.save(args.model) 254 | 255 | ## Test ## To test the pre-trained model 256 | agent.load(args.model) 257 | test(args, env, agent, writer) 258 | 259 | writer.close() 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /main_est.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import torch.nn as nn 21 | import matplotlib.pyplot as plt 22 | 23 | from tensorboardX import SummaryWriter 24 | from datetime import datetime as dt 25 | from tqdm import tqdm 26 | 27 | from simulation_env.env_jobshop_v1_est import Factory 28 | from dqn_agent import DQN 29 | 30 | import pdb 31 | 32 | logging.basicConfig(level=logging.DEBUG) 33 | plt.set_loglevel('WARNING') 34 | # ----------------------------------------------- 35 | 36 | def train(args, _env, agent, writer): 37 | logging.info('* Start Training') 38 | 39 | env = _env 40 | action_space = env.action_space 41 | 42 | total_step, epsilon, ewma_reward = 0, 1., 0. 
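    # ---- Added note on the linear epsilon schedule (commentary, not part of the original trainer) ----
    # At the end of every episode the loop below does
    #     delta_eps = (args.eps_max - args.eps_min) / args.eps_period
    #     epsilon   = max(epsilon + delta_eps, args.eps_min)
    # Because delta_eps is positive, this drifts epsilon upward rather than
    # annealing it from eps_max down to eps_min.  A sketch of the usual linear
    # decay would subtract the step instead:
    #     epsilon = max(epsilon - delta_eps, args.eps_min)
    # The note is left as a comment so the original behaviour is unchanged.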
43 | 44 | # Switch to train mode 45 | agent.train() 46 | 47 | # Training until episode-condition 48 | for episode in range(args.episode): 49 | total_reward = 0 50 | done = False 51 | state = env.reset() 52 | 53 | # While not terminate 54 | for t in itertools.count(start=1): 55 | # if args.render and episode > 700: 56 | # env.render(done) 57 | # time.sleep(0.0082) 58 | 59 | # select action 60 | if total_step < args.warmup: 61 | action = action_space.sample() 62 | else: 63 | action = agent.select_action(state, epsilon, action_space) 64 | # epsilon = max(epsilon * args.eps_decay, args.eps_min) 65 | 66 | # execute action 67 | next_state, reward, done, _ = env.step(action) 68 | 69 | # store transition 70 | agent.append(state, action, reward, next_state, done) 71 | 72 | # optimize the model 73 | loss = None 74 | if total_step >= args.warmup: 75 | loss = agent.update(total_step) 76 | 77 | # transit next_state --> current_state 78 | state = next_state 79 | total_reward += reward 80 | total_step += 1 81 | 82 | if args.render and episode > args.render_episode: 83 | env.render(done) 84 | # time.sleep(0.0082) 85 | 86 | # Break & Record the performance at the end each episode 87 | if done: 88 | ewma_reward = 0.05 * total_reward + (1 - 0.05) * ewma_reward 89 | writer.add_scalar('Train-Episode/Reward', total_reward, 90 | episode) 91 | writer.add_scalar('Train-Episode/Makespan', env.makespan, 92 | episode) 93 | writer.add_scalar('Train-Episode/Epsilon', epsilon, 94 | episode) 95 | writer.add_scalar('Train-Step/Ewma_Reward', ewma_reward, 96 | total_step) 97 | if loss is not None: 98 | writer.add_scalar('Train-Step/Loss', loss, 99 | total_step) 100 | logging.info( 101 | ' - Step: {}\tEpisode: {}\tLength: {:3d}\tTotal reward: {:.2f}\tEwma reward: {:.2f}\tMakespan: {:.2f}\tEpsilon: {:.3f}' 102 | .format(total_step, episode, t, total_reward, ewma_reward, env.makespan, 103 | epsilon)) 104 | 105 | # Check the scheduling result 106 | fig = env.gantt_plot.draw_gantt(env.makespan) 107 | writer.add_figure('Train-Episode/Gantt_Chart', fig, episode) 108 | break 109 | 110 | delta_eps = (args.eps_max - args.eps_min) / args.eps_period 111 | epsilon = max(epsilon + delta_eps, args.eps_min) 112 | env.close() 113 | 114 | 115 | def test(args, _env, agent, writer): 116 | logging.info('\n* Start Testing') 117 | env = _env 118 | 119 | action_space = env.action_space 120 | epsilon = args.test_epsilon 121 | seeds = [args.seed + i for i in range(10)] 122 | rewards = [] 123 | makespans = [] 124 | 125 | n_episode = 0 126 | seed_loader = tqdm(seeds) 127 | for seed in seed_loader: 128 | n_episode += 1 129 | total_reward = 0 130 | # env.seed(seed) 131 | state = env.reset() 132 | for t in itertools.count(start=1): 133 | 134 | #action = agent.select_action(state, epsilon, action_space) 135 | action = agent.select_best_action(state) 136 | 137 | # execute action 138 | next_state, reward, done, _ = env.step(action) 139 | 140 | state = next_state 141 | total_reward += reward 142 | 143 | # env.render(done) 144 | env.render(terminal=done) 145 | 146 | if done: 147 | writer.add_scalar('Test/Episode_Reward', total_reward, n_episode) 148 | rewards.append(total_reward) 149 | makespans.append(env.makespan) 150 | 151 | # Check the scheduling result 152 | fig = env.gantt_plot.draw_gantt(env.makespan) 153 | writer.add_figure('Test/Gantt_Chart', fig, n_episode) 154 | break 155 | 156 | env.close() 157 | 158 | logging.info(f' - Average Reward = {np.mean(rewards)}') 159 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 160 | 161 | 162 | def 
main(): 163 | import config 164 | ## arguments ## 165 | parser = argparse.ArgumentParser(description=__doc__) 166 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 167 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 168 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 169 | # train 170 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 171 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 172 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 173 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 174 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 175 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 176 | parser.add_argument('--eps_max' , default=config.EPS_MAX , type=float) 177 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 178 | parser.add_argument('--eps_period' , default=config.EPS_PERIOD , type=float) 179 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 180 | parser.add_argument('--freq' , default=config.FREQ , type=int) 181 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 182 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 183 | # test 184 | parser.add_argument('--test_only' , action='store_true') 185 | parser.add_argument('--render' , action='store_true') 186 | parser.add_argument('--seed' , default=config.SEED , type=int) 187 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 188 | args = parser.parse_args() 189 | 190 | ## main ## 191 | file_name = config.FILE_NAME 192 | file_dir = os.getcwd() + '/simulation_env/input_data' 193 | file_path = os.path.join(file_dir, file_name) 194 | 195 | opt_makespan = config.OPT_MAKESPAN 196 | num_machine = config.NUM_MACHINE 197 | num_job = config.NUM_JOB 198 | 199 | # Agent & Environment 200 | env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 201 | agent = DQN([env.dim_observations[0].shape, env.dim_observations[1].shape], env.dim_actions, args) 202 | 203 | # Tensorboard to trace the learning process 204 | writer = SummaryWriter(f'log/DQN-{time.time()}') 205 | 206 | ## Train ## 207 | if not args.test_only: 208 | train(args, env, agent, writer) 209 | agent.save(args.model) 210 | 211 | ## Test ## To test the pre-trained model 212 | agent.load(args.model) 213 | test(args, env, agent, writer) 214 | 215 | writer.close() 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /model/ESTModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import math 9 | import torch 10 | import random 11 | import logging 12 | 13 | import numpy as np 14 | import torch.nn as nn 15 | 16 | from collections import deque 17 | from datetime import datetime as dt 18 | 19 | import pdb 20 | 21 | logging.basicConfig(level=logging.DEBUG) 22 | # ----------------------------------------------- 23 | def __reset_param_impl__(cnn_net): 24 | """ 25 | """ 26 | # --- do init --- 27 | conv = cnn_net.conv 28 | n1 = conv.kernel_size[0] * conv.kernel_size[1] * 
conv.out_channels 29 | conv.weight.data.normal_(0, math.sqrt(2. / n1)) 30 | 31 | 32 | class ConvBlock(nn.Module): 33 | def __init__(self, cin, cout): 34 | super().__init__() # necessary 35 | self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1) 36 | self.bn = nn.BatchNorm2d(cout) 37 | self.relu = nn.ReLU() 38 | 39 | def reset_param(self): 40 | #normalize the para of cnn network 41 | __reset_param_impl__(self) 42 | 43 | def forward(self, x): 44 | x = self.conv(x) 45 | x = self.bn(x) 46 | x = self.relu(x) 47 | return x 48 | 49 | 50 | class Net(nn.Module): 51 | def __init__(self, state_dim=[(5, 5), (7, 7)], action_dim=4, hidden_dim=128): 52 | super(Net, self).__init__() 53 | 54 | # convolution layers 55 | self.conv_1 = nn.Conv2d(3 , 32, kernel_size=3, stride=1) 56 | self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, stride=1) 57 | self.bn_32 = nn.BatchNorm2d(32) 58 | self.bn_64 = nn.BatchNorm2d(64) 59 | self.relu = nn.ReLU() 60 | self.flatten = nn.Flatten() 61 | 62 | #normalize the para of cnn network 63 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 64 | n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 65 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 66 | self.conv_2.weight.data.normal_(0, math.sqrt(2. / n2)) 67 | 68 | self.cnn = nn.Sequential( 69 | self.conv_1, self.bn_32, self.relu, #, 70 | self.conv_2, self.bn_64, self.relu 71 | ) 72 | 73 | self.conv_1_est = nn.Conv2d(1 , 32, kernel_size=3, stride=1) 74 | self.cnn_est = nn.Sequential( 75 | self.conv_1_est, self.bn_32, self.relu, #, 76 | self.conv_2 , self.bn_64, self.relu 77 | ) 78 | 79 | # check the output of cnn, which is [fc1_dims] 80 | self.cnn_outputs_len_1 = self.cnn_out_dim_1(state_dim[0]) 81 | self.cnn_outputs_len_2 = self.cnn_out_dim_2(state_dim[1]) 82 | self.fcn_inputs_length = self.cnn_outputs_len_1 + self.cnn_outputs_len_2 83 | 84 | # fully connected layers 85 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 86 | self.fc2 = nn.Linear(hidden_dim , action_dim) 87 | self.fc1.weight.data.normal_(0, 0.1) 88 | self.fc2.weight.data.normal_(0, 0.1) 89 | 90 | self.fcn = nn.Sequential( 91 | self.fc1, self.relu, 92 | self.fc2 93 | ) 94 | 95 | def forward(self, x): 96 | ''' 97 | - x : tensor in shape of (N, state_dim) 98 | ''' 99 | 100 | cnn_out_1 = self.cnn(x[0]) 101 | cnn_out_2 = self.cnn_est(x[1]) 102 | cnn_out_1 = cnn_out_1.reshape(-1, self.cnn_outputs_len_1) 103 | cnn_out_2 = cnn_out_2.reshape(-1, self.cnn_outputs_len_2) 104 | fcn_input_1 = self.flatten(cnn_out_1) 105 | fcn_input_2 = self.flatten(cnn_out_2) 106 | fcn_input = torch.cat((fcn_input_1, fcn_input_2), 0) 107 | actions = self.fcn(fcn_input) 108 | return actions 109 | 110 | def cnn_out_dim_1(self, input_dims): 111 | return self.cnn(torch.zeros(1, *input_dims) 112 | ).flatten().shape[0] 113 | 114 | def cnn_out_dim_2(self, input_dims): 115 | return self.cnn_est(torch.zeros(1, *input_dims) 116 | ).flatten().shape[0] 117 | 118 | -------------------------------------------------------------------------------- /model/FullyNetwork.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import math 9 | import torch 10 | import random 11 | import logging 12 | 13 | import numpy as np 14 | import torch.nn as nn 15 | 16 | from collections 
import deque 17 | from datetime import datetime as dt 18 | 19 | import pdb 20 | 21 | logging.basicConfig(level=logging.DEBUG) 22 | # ----------------------------------------------- 23 | 24 | 25 | class Net(nn.Module): 26 | def __init__(self, state_dim=8, action_dim=4, hidden_dim=128): 27 | super(Net, self).__init__() 28 | 29 | self.relu = nn.ReLU() 30 | self.flatten = nn.Flatten() 31 | 32 | # check the output of cnn, which is [fc1_dims] 33 | self.fcn_inputs_length = torch.zeros(1, *state_dim).flatten().shape[0] 34 | 35 | # fully connected layers 36 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 37 | self.fc2 = nn.Linear(hidden_dim , action_dim) 38 | self.fc1.weight.data.normal_(0, 0.1) 39 | self.fc2.weight.data.normal_(0, 0.1) 40 | 41 | self.fcn = nn.Sequential( 42 | self.fc1, self.relu, 43 | self.fc2 44 | ) 45 | 46 | def forward(self, x): 47 | ''' 48 | - x : tensor in shape of (N, state_dim) 49 | ''' 50 | 51 | cnn_out = x.reshape(-1, self.fcn_inputs_length) 52 | fcn_input = self.flatten(cnn_out) 53 | actions = self.fcn(fcn_input) 54 | return actions 55 | 56 | def cnn_out_dim(self, input_dims): 57 | return self.cnn(torch.zeros(1, *input_dims) 58 | ).flatten().shape[0] 59 | 60 | -------------------------------------------------------------------------------- /model/NetworkModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import math 9 | import torch 10 | import random 11 | import logging 12 | 13 | import numpy as np 14 | import torch.nn as nn 15 | 16 | from collections import deque 17 | from datetime import datetime as dt 18 | 19 | import pdb 20 | 21 | logging.basicConfig(level=logging.DEBUG) 22 | # ----------------------------------------------- 23 | def __reset_param_impl__(cnn_net): 24 | """ 25 | """ 26 | # --- do init --- 27 | conv = cnn_net.conv 28 | n1 = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels 29 | conv.weight.data.normal_(0, math.sqrt(2. / n1)) 30 | 31 | 32 | class ConvBlock(nn.Module): 33 | def __init__(self, cin, cout): 34 | super().__init__() # necessary 35 | self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1) 36 | self.bn = nn.BatchNorm2d(cout) 37 | self.relu = nn.ReLU() 38 | 39 | def reset_param(self): 40 | #normalize the para of cnn network 41 | __reset_param_impl__(self) 42 | 43 | def forward(self, x): 44 | x = self.conv(x) 45 | x = self.bn(x) 46 | x = self.relu(x) 47 | return x 48 | 49 | 50 | class Net(nn.Module): 51 | def __init__(self, state_dim=8, action_dim=4, hidden_dim=128): 52 | super(Net, self).__init__() 53 | 54 | # convolution layers 55 | self.conv_1 = nn.Conv2d(3 , 32, kernel_size=3, stride=1) 56 | self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, stride=1) 57 | self.bn_32 = nn.BatchNorm2d(32) 58 | self.bn_64 = nn.BatchNorm2d(64) 59 | self.relu = nn.ReLU() 60 | self.flatten = nn.Flatten() 61 | 62 | #normalize the para of cnn network 63 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 64 | n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 65 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 66 | self.conv_2.weight.data.normal_(0, math.sqrt(2. 
/ n2)) 67 | 68 | self.cnn = nn.Sequential( 69 | self.conv_1, self.bn_32, self.relu, #, 70 | self.conv_2, self.bn_64, self.relu 71 | ) 72 | 73 | # check the output of cnn, which is [fc1_dims] 74 | self.fcn_inputs_length = self.cnn_out_dim(state_dim) 75 | 76 | # fully connected layers 77 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 78 | self.fc2 = nn.Linear(hidden_dim , action_dim) 79 | self.fc1.weight.data.normal_(0, 0.1) 80 | self.fc2.weight.data.normal_(0, 0.1) 81 | 82 | self.fcn = nn.Sequential( 83 | self.fc1, self.relu, 84 | self.fc2 85 | ) 86 | 87 | def forward(self, x): 88 | ''' 89 | - x : tensor in shape of (N, state_dim) 90 | ''' 91 | 92 | cnn_out = self.cnn(x) 93 | cnn_out = cnn_out.reshape(-1, self.fcn_inputs_length) 94 | fcn_input = self.flatten(cnn_out) 95 | actions = self.fcn(fcn_input) 96 | return actions 97 | 98 | def cnn_out_dim(self, input_dims): 99 | return self.cnn(torch.zeros(1, *input_dims) 100 | ).flatten().shape[0] 101 | 102 | -------------------------------------------------------------------------------- /model/NetworkModel_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import sys 9 | if '..' not in sys.path: 10 | sys.path.append('..') 11 | 12 | import math 13 | import torch 14 | import random 15 | import logging 16 | 17 | import numpy as np 18 | import torch.nn as nn 19 | 20 | 21 | from collections import deque 22 | from einops import rearrange 23 | from datetime import datetime as dt 24 | 25 | import pdb 26 | import config_djss_attention as config 27 | 28 | seed = config.SEED #999 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | def __reset_param_impl__(cnn_net): 37 | """ 38 | """ 39 | # --- do init --- 40 | conv = cnn_net.conv 41 | n1 = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels 42 | conv.weight.data.normal_(0, math.sqrt(2. 
/ n1)) 43 | 44 | 45 | class ConvBlock(nn.Module): 46 | def __init__(self, cin, cout): 47 | super().__init__() # necessary 48 | self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1) 49 | self.bn = nn.BatchNorm2d(cout) 50 | self.relu = nn.ReLU() 51 | 52 | def reset_param(self): 53 | #normalize the para of cnn network 54 | __reset_param_impl__(self) 55 | 56 | def forward(self, x): 57 | x = self.conv(x) 58 | x = self.bn(x) 59 | x = self.relu(x) 60 | return x 61 | 62 | 63 | class MultiHeadRelationalModule(torch.nn.Module): 64 | def __init__(self, input_dim=(4, 100, 6), output_dim=10 ,device='cuda'): 65 | super(MultiHeadRelationalModule, self).__init__() 66 | self.device = torch.device(device) 67 | # self.conv1_ch = 64 68 | # self.conv2_ch = 32 69 | # self.conv3_ch = 24 70 | # self.conv4_ch = 30 71 | self.ch_in = input_dim[0] 72 | self.n_heads = 4 73 | self.node_size = 64 # dimension of nodes after passing the relational module 74 | self.sp_coord_dim = 2 75 | self.input_height = input_dim[1] 76 | self.input_width = input_dim[2] 77 | self.n_cout_pixel = int(self.input_height * self.input_width) #number of nodes (pixels num after passing the cnn) 78 | self.out_dim = output_dim 79 | # self.lin_hid = 100 80 | self.conv2_ch = self.ch_in 81 | 82 | self.proj_shape = ( 83 | (self.conv2_ch + self.sp_coord_dim), 84 | (self.n_heads * self.node_size) 85 | ) 86 | self.node_shape = ( 87 | self.n_heads, 88 | self.n_cout_pixel, 89 | self.node_size 90 | ) 91 | 92 | # self.conv_1 = nn.Conv2d(self.ch_in , self.conv1_ch, kernel_size=(1,1), padding=0) #A 93 | # self.conv_2 = nn.Conv2d(self.conv1_ch, self.conv2_ch, kernel_size=(1,1), padding=0) 94 | # #normalize the para of cnn network 95 | # n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 96 | # n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 97 | # self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 98 | # self.conv_2.weight.data.normal_(0, math.sqrt(2. 
/ n2)) 99 | 100 | self.k_proj = nn.Linear(*self.proj_shape) 101 | self.q_proj = nn.Linear(*self.proj_shape) 102 | self.v_proj = nn.Linear(*self.proj_shape) 103 | 104 | self.k_lin = nn.Linear(self.node_size , self.n_cout_pixel) #B 105 | self.q_lin = nn.Linear(self.node_size , self.n_cout_pixel) 106 | self.a_lin = nn.Linear(self.n_cout_pixel, self.n_cout_pixel) 107 | 108 | self.k_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 109 | self.q_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 110 | self.v_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 111 | 112 | self.linear1 = nn.Linear((self.n_heads * self.node_size), self.node_size) 113 | self.norm1 = nn.LayerNorm([self.n_cout_pixel, self.node_size], elementwise_affine=False) 114 | self.linear2 = nn.Linear(self.node_size, self.out_dim) 115 | 116 | self.relu = nn.ReLU() 117 | 118 | # self.cnn = nn.Sequential( 119 | # self.conv_1, self.relu, 120 | # self.conv_2, self.relu 121 | # ) 122 | 123 | def forward(self, x): 124 | N, Cin, H, W = x.shape 125 | # x = self.cnn(x) 126 | # # for visualization 127 | # with torch.no_grad(): 128 | # self.conv_map = x.clone() #C 129 | 130 | # Appends the (x, y) coordinates of each node to its feature vector and normalized to [0,1] 131 | _, _, cH, cW = x.shape 132 | xcoords = torch.arange(cW).repeat(cH, 1).float() / cW 133 | ycoords = torch.arange(cH).repeat(cW, 1).transpose(1, 0).float() / cH 134 | spatial_coords = torch.stack([xcoords, ycoords], dim=0) 135 | spatial_coords = spatial_coords.unsqueeze(dim=0) 136 | spatial_coords = spatial_coords.repeat(N, 1, 1, 1) 137 | 138 | x = torch.cat([x, spatial_coords.to(self.device)], dim=1) 139 | x = x.permute(0, 2, 3, 1) 140 | x = x.flatten(1, 2) 141 | 142 | # K, Q, V 143 | K = rearrange(self.k_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 144 | K = self.k_norm(K) 145 | Q = rearrange(self.q_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 146 | Q = self.q_norm(Q) 147 | V = rearrange(self.v_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 148 | V = self.v_norm(V) 149 | 150 | del x 151 | 152 | # Compatibility function 153 | A = torch.nn.functional.elu(self.q_lin(Q) + self.k_lin(K)) #D 154 | A = self.a_lin(A) 155 | A = torch.nn.functional.softmax(A, dim=3) 156 | # # for visualization 157 | # with torch.no_grad(): 158 | # self.att_map = A.clone() #E 159 | 160 | del Q; del K 161 | 162 | # Multi-head attention 163 | E = torch.einsum('bhfc, bhcd->bhfd', A, V) #F 164 | E = rearrange(E, 'b head n d -> b n (head d)') 165 | 166 | del A; del V 167 | 168 | # Linear forward 169 | E = self.linear1(E) 170 | E = self.relu(E) 171 | E = self.norm1(E) 172 | E = E.max(dim=1)[0] 173 | y = self.linear2(E) 174 | return torch.nn.functional.elu(y) 175 | # y = torch.nn.functional.elu(y) 176 | # return y 177 | 178 | 179 | class Net(nn.Module): 180 | def __init__(self, state_dim=(3, 8, 8), action_dim=4, hidden_dim=128): 181 | super(Net, self).__init__() 182 | cin = state_dim[0] 183 | # convolution layers 184 | self.conv_1 = nn.Conv2d(cin , 64, kernel_size=(1,3), padding=1, stride=1) 185 | self.conv_2 = nn.Conv2d(64, 32, kernel_size=3, stride=1) 186 | self.conv_3 = nn.Conv2d(32, 16, kernel_size=2, stride=1) 187 | self.bn_16 = nn.BatchNorm2d(16) 188 | self.bn_32 = nn.BatchNorm2d(32) 189 | self.bn_64 = nn.BatchNorm2d(64) 190 | self.relu = nn.LeakyReLU() 191 | self.flatten = nn.Flatten() 192 | 193 | #normalize the para of cnn network 194 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 195 | 
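# Reference note on the manual initialisation here (the n1/n2/n3 values and the .normal_
# calls around this point): drawing each conv weight from a zero-mean normal with
# std = sqrt(2 / n), where n = kernel_h * kernel_w * out_channels, is the fan-out form of
# He/Kaiming initialisation. A hedged, equivalent sketch using the built-in initialiser
# (shown only for comparison; strictly, this Net uses LeakyReLU, for which He init would
# use a slightly different gain) would be:
#     nn.init.kaiming_normal_(self.conv_1.weight, mode='fan_out', nonlinearity='relu')
#     nn.init.kaiming_normal_(self.conv_2.weight, mode='fan_out', nonlinearity='relu')
#     nn.init.kaiming_normal_(self.conv_3.weight, mode='fan_out', nonlinearity='relu')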
n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 196 | n3 = self.conv_3.kernel_size[0] * self.conv_3.kernel_size[1] * self.conv_3.out_channels 197 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 198 | self.conv_2.weight.data.normal_(0, math.sqrt(2. / n2)) 199 | self.conv_3.weight.data.normal_(0, math.sqrt(2. / n3)) 200 | 201 | self.cnn = nn.Sequential( 202 | self.conv_1, self.bn_64, self.relu, 203 | self.conv_2, self.bn_32, self.relu, 204 | self.conv_3, self.bn_16, self.relu 205 | ) 206 | 207 | # check the output of cnn, which is [fc1_dims] 208 | self.fcn_inputs_length = self.cnn_out_dim(state_dim) 209 | 210 | # fully connected layers 211 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 212 | self.fc2 = nn.Linear(hidden_dim , action_dim) 213 | self.fc1.weight.data.normal_(0, 0.1) 214 | self.fc2.weight.data.normal_(0, 0.1) 215 | 216 | self.bn_fc1 = nn.BatchNorm1d(hidden_dim) 217 | # self.bn_fc2 = nn.BatchNorm1d(64) 218 | 219 | self.fcn = nn.Sequential( 220 | self.fc1, self.bn_fc1, self.relu, 221 | self.fc2 222 | ) 223 | 224 | def forward(self, x): 225 | ''' 226 | - x : tensor in shape of (N, state_dim) 227 | ''' 228 | 229 | cnn_out = self.cnn(x) 230 | cnn_out = cnn_out.reshape(-1, self.fcn_inputs_length) 231 | fcn_input = self.flatten(cnn_out) 232 | actions = self.fcn(fcn_input) 233 | return actions 234 | 235 | def cnn_out_dim(self, input_dims): 236 | return self.cnn(torch.zeros(1, *input_dims) 237 | ).flatten().shape[0] 238 | 239 | -------------------------------------------------------------------------------- /model/NetworkModel_attention_paper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import sys 9 | if '..' not in sys.path: 10 | sys.path.append('..') 11 | 12 | import math 13 | import torch 14 | import random 15 | import logging 16 | 17 | import numpy as np 18 | import torch.nn as nn 19 | 20 | 21 | from collections import deque 22 | from einops import rearrange 23 | from datetime import datetime as dt 24 | 25 | import pdb 26 | import config_djss_attention as config 27 | 28 | seed = config.SEED #999 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | def __reset_param_impl__(cnn_net): 37 | """ 38 | """ 39 | # --- do init --- 40 | conv = cnn_net.conv 41 | n1 = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels 42 | conv.weight.data.normal_(0, math.sqrt(2. 
/ n1)) 43 | 44 | 45 | class ConvBlock(nn.Module): 46 | def __init__(self, cin, cout): 47 | super().__init__() # necessary 48 | self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1) 49 | self.bn = nn.BatchNorm2d(cout) 50 | self.relu = nn.ReLU() 51 | 52 | def reset_param(self): 53 | #normalize the para of cnn network 54 | __reset_param_impl__(self) 55 | 56 | def forward(self, x): 57 | x = self.conv(x) 58 | x = self.bn(x) 59 | x = self.relu(x) 60 | return x 61 | 62 | 63 | class MultiHeadRelationalModule(torch.nn.Module): 64 | def __init__(self, input_dim=(4, 100, 6), output_dim=10 ,device='cuda'): 65 | super(MultiHeadRelationalModule, self).__init__() 66 | self.device = torch.device(device) 67 | self.conv1_ch = 16#64 68 | self.conv2_ch = 32 69 | # self.conv3_ch = 24 70 | # self.conv4_ch = 30 71 | self.ch_in = input_dim[0] 72 | self.n_heads = 4 73 | self.node_size = 64 # dimension of nodes after passing the relational module 74 | self.sp_coord_dim = 2 75 | self.input_height = input_dim[1] 76 | self.input_width = input_dim[2] 77 | self.n_cout_pixel = 596#int(self.input_height * self.input_width) #number of nodes (pixels num after passing the cnn) 78 | self.out_dim = output_dim 79 | # self.lin_hid = 100 80 | # self.conv2_ch = self.ch_in 81 | self.conv2_ch = 32 82 | 83 | self.proj_shape = ( 84 | (self.conv2_ch + self.sp_coord_dim), 85 | (self.n_heads * self.node_size) 86 | ) 87 | self.node_shape = ( 88 | self.n_heads, 89 | self.n_cout_pixel, 90 | self.node_size 91 | ) 92 | 93 | self.conv_1 = nn.Conv2d(self.ch_in , self.conv1_ch, kernel_size=(2,2), padding=0) #A 94 | self.conv_2 = nn.Conv2d(self.conv1_ch, self.conv2_ch, kernel_size=(2,2), padding=0) 95 | #normalize the para of cnn network 96 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 97 | n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 98 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 99 | self.conv_2.weight.data.normal_(0, math.sqrt(2. 
/ n2)) 100 | 101 | self.k_proj = nn.Linear(*self.proj_shape) 102 | self.q_proj = nn.Linear(*self.proj_shape) 103 | self.v_proj = nn.Linear(*self.proj_shape) 104 | 105 | self.k_lin = nn.Linear(self.node_size , self.n_cout_pixel) #B 106 | self.q_lin = nn.Linear(self.node_size , self.n_cout_pixel) 107 | self.a_lin = nn.Linear(self.n_cout_pixel, self.n_cout_pixel) 108 | 109 | self.k_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 110 | self.q_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 111 | self.v_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 112 | 113 | self.linear1 = nn.Linear((self.n_heads * self.node_size), self.node_size) 114 | self.norm1 = nn.LayerNorm([self.n_cout_pixel, self.node_size], elementwise_affine=False) 115 | self.linear2 = nn.Linear(self.node_size, self.out_dim) 116 | 117 | self.relu = nn.ReLU() 118 | 119 | self.cnn = nn.Sequential( 120 | self.conv_1, self.relu, 121 | self.conv_2, self.relu 122 | ) 123 | 124 | def forward(self, x): 125 | N, Cin, H, W = x.shape 126 | x = self.cnn(x) 127 | # # for visualization 128 | # with torch.no_grad(): 129 | # self.conv_map = x.clone() #C 130 | 131 | # Appends the (x, y) coordinates of each node to its feature vector and normalized to [0,1] 132 | _, _, cH, cW = x.shape 133 | xcoords = torch.arange(cW).repeat(cH, 1).float() / cW 134 | ycoords = torch.arange(cH).repeat(cW, 1).transpose(1, 0).float() / cH 135 | spatial_coords = torch.stack([xcoords, ycoords], dim=0) 136 | spatial_coords = spatial_coords.unsqueeze(dim=0) 137 | spatial_coords = spatial_coords.repeat(N, 1, 1, 1) 138 | 139 | x = torch.cat([x, spatial_coords.to(self.device)], dim=1) 140 | x = x.permute(0, 2, 3, 1) 141 | x = x.flatten(1, 2) 142 | 143 | # K, Q, V 144 | K = rearrange(self.k_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 145 | K = self.k_norm(K) 146 | Q = rearrange(self.q_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 147 | Q = self.q_norm(Q) 148 | V = rearrange(self.v_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 149 | V = self.v_norm(V) 150 | 151 | del x 152 | 153 | # Compatibility function 154 | A = torch.nn.functional.elu(self.q_lin(Q) + self.k_lin(K)) #D 155 | A = self.a_lin(A) 156 | A = torch.nn.functional.softmax(A, dim=3) 157 | # # for visualization 158 | # with torch.no_grad(): 159 | # self.att_map = A.clone() #E 160 | 161 | del Q; del K 162 | 163 | # Multi-head attention 164 | E = torch.einsum('bhfc, bhcd->bhfd', A, V) #F 165 | E = rearrange(E, 'b head n d -> b n (head d)') 166 | 167 | del A; del V 168 | 169 | # Linear forward 170 | E = self.linear1(E) 171 | E = self.relu(E) 172 | E = self.norm1(E) 173 | E = E.max(dim=1)[0] 174 | y = self.linear2(E) 175 | return torch.nn.functional.elu(y) 176 | # y = torch.nn.functional.elu(y) 177 | # return y 178 | 179 | 180 | class Net(nn.Module): 181 | def __init__(self, state_dim=(3, 8, 8), action_dim=4, hidden_dim=128): 182 | super(Net, self).__init__() 183 | cin = state_dim[0] 184 | # convolution layers 185 | self.conv_1 = nn.Conv2d(cin , 64, kernel_size=(1,3), padding=1, stride=1) 186 | self.conv_2 = nn.Conv2d(64, 32, kernel_size=3, stride=1) 187 | self.conv_3 = nn.Conv2d(32, 16, kernel_size=2, stride=1) 188 | self.bn_16 = nn.BatchNorm2d(16) 189 | self.bn_32 = nn.BatchNorm2d(32) 190 | self.bn_64 = nn.BatchNorm2d(64) 191 | self.relu = nn.LeakyReLU() 192 | self.flatten = nn.Flatten() 193 | 194 | #normalize the para of cnn network 195 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 196 | n2 = 
self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 197 | n3 = self.conv_3.kernel_size[0] * self.conv_3.kernel_size[1] * self.conv_3.out_channels 198 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 199 | self.conv_2.weight.data.normal_(0, math.sqrt(2. / n2)) 200 | self.conv_3.weight.data.normal_(0, math.sqrt(2. / n3)) 201 | 202 | self.cnn = nn.Sequential( 203 | self.conv_1, self.bn_64, self.relu, 204 | self.conv_2, self.bn_32, self.relu, 205 | self.conv_3, self.bn_16, self.relu 206 | ) 207 | 208 | # check the output of cnn, which is [fc1_dims] 209 | self.fcn_inputs_length = self.cnn_out_dim(state_dim) 210 | 211 | # fully connected layers 212 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 213 | self.fc2 = nn.Linear(hidden_dim , action_dim) 214 | self.fc1.weight.data.normal_(0, 0.1) 215 | self.fc2.weight.data.normal_(0, 0.1) 216 | 217 | self.bn_fc1 = nn.BatchNorm1d(hidden_dim) 218 | # self.bn_fc2 = nn.BatchNorm1d(64) 219 | 220 | self.fcn = nn.Sequential( 221 | self.fc1, self.bn_fc1, self.relu, 222 | self.fc2 223 | ) 224 | 225 | def forward(self, x): 226 | ''' 227 | - x : tensor in shape of (N, state_dim) 228 | ''' 229 | 230 | cnn_out = self.cnn(x) 231 | cnn_out = cnn_out.reshape(-1, self.fcn_inputs_length) 232 | fcn_input = self.flatten(cnn_out) 233 | actions = self.fcn(fcn_input) 234 | return actions 235 | 236 | def cnn_out_dim(self, input_dims): 237 | return self.cnn(torch.zeros(1, *input_dims) 238 | ).flatten().shape[0] 239 | 240 | -------------------------------------------------------------------------------- /model/NetworkModel_attention_paper1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Model for Action-value Prediction ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------------------- 8 | import sys 9 | if '..' not in sys.path: 10 | sys.path.append('..') 11 | 12 | import math 13 | import torch 14 | import random 15 | import logging 16 | 17 | import numpy as np 18 | import torch.nn as nn 19 | 20 | 21 | from collections import deque 22 | from einops import rearrange 23 | from datetime import datetime as dt 24 | 25 | import pdb 26 | import config_djss_attention as config 27 | 28 | seed = config.SEED #999 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | # ----------------------------------------------- 36 | def __reset_param_impl__(cnn_net): 37 | """ 38 | """ 39 | # --- do init --- 40 | conv = cnn_net.conv 41 | n1 = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels 42 | conv.weight.data.normal_(0, math.sqrt(2. 
/ n1)) 43 | 44 | 45 | class ConvBlock(nn.Module): 46 | def __init__(self, cin, cout): 47 | super().__init__() # necessary 48 | self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1) 49 | self.bn = nn.BatchNorm2d(cout) 50 | self.relu = nn.ReLU() 51 | 52 | def reset_param(self): 53 | #normalize the para of cnn network 54 | __reset_param_impl__(self) 55 | 56 | def forward(self, x): 57 | x = self.conv(x) 58 | x = self.bn(x) 59 | x = self.relu(x) 60 | return x 61 | 62 | 63 | class MultiHeadRelationalModule(torch.nn.Module): 64 | def __init__(self, input_dim=(4, 100, 6), output_dim=10 ,device='cuda'): 65 | super(MultiHeadRelationalModule, self).__init__() 66 | self.device = torch.device(device) 67 | self.conv1_ch = 16#64 68 | self.conv2_ch = 32 69 | # self.conv3_ch = 24 70 | # self.conv4_ch = 30 71 | self.ch_in = input_dim[0] 72 | self.n_heads = 4 73 | self.node_size = 64 # dimension of nodes after passing the relational module 74 | self.sp_coord_dim = 2 75 | self.input_height = input_dim[1] 76 | self.input_width = input_dim[2] 77 | # self.n_cout_pixel = 400#int(self.input_height * self.input_width) #number of nodes (pixels num after passing the cnn) 78 | self.n_cout_pixel = int(self.input_height * self.input_width) #number of nodes (pixels num after passing the cnn) 79 | self.out_dim = output_dim 80 | # self.lin_hid = 100 81 | self.conv2_ch = self.ch_in 82 | # self.conv2_ch = 32 83 | 84 | self.proj_shape = ( 85 | (self.conv2_ch + self.sp_coord_dim), 86 | (self.n_heads * self.node_size) 87 | ) 88 | self.node_shape = ( 89 | self.n_heads, 90 | self.n_cout_pixel, 91 | self.node_size 92 | ) 93 | 94 | # self.conv_1 = nn.Conv2d(self.ch_in , self.conv1_ch, kernel_size=(1,2), padding=0) #A 95 | # self.conv_2 = nn.Conv2d(self.conv1_ch, self.conv2_ch, kernel_size=(1,2), padding=0) 96 | #normalize the para of cnn network 97 | # n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 98 | # n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 99 | # self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 100 | # self.conv_2.weight.data.normal_(0, math.sqrt(2. 
/ n2)) 101 | 102 | self.k_proj = nn.Linear(*self.proj_shape) 103 | self.q_proj = nn.Linear(*self.proj_shape) 104 | self.v_proj = nn.Linear(*self.proj_shape) 105 | 106 | self.k_lin = nn.Linear(self.node_size , self.n_cout_pixel) #B 107 | self.q_lin = nn.Linear(self.node_size , self.n_cout_pixel) 108 | self.a_lin = nn.Linear(self.n_cout_pixel, self.n_cout_pixel) 109 | 110 | self.k_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 111 | self.q_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 112 | self.v_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True) 113 | 114 | self.linear1 = nn.Linear((self.n_heads * self.node_size), self.node_size) 115 | self.norm1 = nn.LayerNorm([self.n_cout_pixel, self.node_size], elementwise_affine=False) 116 | self.linear2 = nn.Linear(self.node_size, self.out_dim) 117 | 118 | self.relu = nn.ReLU() 119 | 120 | # self.cnn = nn.Sequential( 121 | # self.conv_1, self.relu, 122 | # self.conv_2, self.relu 123 | # ) 124 | 125 | def forward(self, x): 126 | N, Cin, H, W = x.shape 127 | # x = self.cnn(x) 128 | # # for visualization 129 | # with torch.no_grad(): 130 | # self.conv_map = x.clone() #C 131 | 132 | # Appends the (x, y) coordinates of each node to its feature vector and normalized to [0,1] 133 | _, _, cH, cW = x.shape 134 | xcoords = torch.arange(cW).repeat(cH, 1).float() / cW 135 | ycoords = torch.arange(cH).repeat(cW, 1).transpose(1, 0).float() / cH 136 | spatial_coords = torch.stack([xcoords, ycoords], dim=0) 137 | spatial_coords = spatial_coords.unsqueeze(dim=0) 138 | spatial_coords = spatial_coords.repeat(N, 1, 1, 1) 139 | 140 | x = torch.cat([x, spatial_coords.to(self.device)], dim=1) 141 | x = x.permute(0, 2, 3, 1) 142 | x = x.flatten(1, 2) 143 | 144 | # K, Q, V 145 | K = rearrange(self.k_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 146 | K = self.k_norm(K) 147 | Q = rearrange(self.q_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 148 | Q = self.q_norm(Q) 149 | V = rearrange(self.v_proj(x), "b n (head d) -> b head n d", head=self.n_heads) 150 | V = self.v_norm(V) 151 | 152 | del x 153 | 154 | # Compatibility function 155 | A = torch.nn.functional.elu(self.q_lin(Q) + self.k_lin(K)) #D 156 | A = self.a_lin(A) 157 | A = torch.nn.functional.softmax(A, dim=3) 158 | # # for visualization 159 | # with torch.no_grad(): 160 | # self.att_map = A.clone() #E 161 | 162 | del Q; del K 163 | 164 | # Multi-head attention 165 | E = torch.einsum('bhfc, bhcd->bhfd', A, V) #F 166 | E = rearrange(E, 'b head n d -> b n (head d)') 167 | 168 | del A; del V 169 | 170 | # Linear forward 171 | E = self.linear1(E) 172 | E = self.relu(E) 173 | E = self.norm1(E) 174 | E = E.max(dim=1)[0] 175 | y = self.linear2(E) 176 | return torch.nn.functional.elu(y) 177 | # y = torch.nn.functional.elu(y) 178 | # return y 179 | 180 | 181 | class Net(nn.Module): 182 | def __init__(self, state_dim=(3, 8, 8), action_dim=4, hidden_dim=128): 183 | super(Net, self).__init__() 184 | cin = state_dim[0] 185 | # convolution layers 186 | self.conv_1 = nn.Conv2d(cin , 64, kernel_size=(1,3), padding=1, stride=1) 187 | self.conv_2 = nn.Conv2d(64, 32, kernel_size=3, stride=1) 188 | self.conv_3 = nn.Conv2d(32, 16, kernel_size=2, stride=1) 189 | self.bn_16 = nn.BatchNorm2d(16) 190 | self.bn_32 = nn.BatchNorm2d(32) 191 | self.bn_64 = nn.BatchNorm2d(64) 192 | self.relu = nn.LeakyReLU() 193 | self.flatten = nn.Flatten() 194 | 195 | #normalize the para of cnn network 196 | n1 = self.conv_1.kernel_size[0] * self.conv_1.kernel_size[1] * self.conv_1.out_channels 197 | 
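# Shape walk-through of the MultiHeadRelationalModule defined above in this file (a sketch
# assuming the default input_dim=(4, 100, 6), n_heads=4, node_size=64 and a batch of N states;
# with the conv stack commented out, n_cout_pixel = 100 * 6 = 600 nodes):
#     x: (N, 4, 100, 6) -> concat 2 coordinate channels -> (N, 6, 100, 6)
#        -> permute + flatten -> (N, 600, 6)
#     K, Q, V projections + rearrange: (N, 4, 600, 64) each
#     A = softmax(a_lin(elu(q_lin(Q) + k_lin(K))), dim=3): (N, 4, 600, 600), attention over the 600 nodes
#     E = einsum('bhfc,bhcd->bhfd', A, V): (N, 4, 600, 64) -> rearrange -> (N, 600, 256)
#     linear1 + ReLU + LayerNorm: (N, 600, 64) -> max over nodes -> (N, 64)
#     linear2 -> (N, 10), passed through ELU (one value per action / dispatching rule)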
n2 = self.conv_2.kernel_size[0] * self.conv_2.kernel_size[1] * self.conv_2.out_channels 198 | n3 = self.conv_3.kernel_size[0] * self.conv_3.kernel_size[1] * self.conv_3.out_channels 199 | self.conv_1.weight.data.normal_(0, math.sqrt(2. / n1)) 200 | self.conv_2.weight.data.normal_(0, math.sqrt(2. / n2)) 201 | self.conv_3.weight.data.normal_(0, math.sqrt(2. / n3)) 202 | 203 | self.cnn = nn.Sequential( 204 | self.conv_1, self.bn_64, self.relu, 205 | self.conv_2, self.bn_32, self.relu, 206 | self.conv_3, self.bn_16, self.relu 207 | ) 208 | 209 | # check the output of cnn, which is [fc1_dims] 210 | self.fcn_inputs_length = self.cnn_out_dim(state_dim) 211 | 212 | # fully connected layers 213 | self.fc1 = nn.Linear(self.fcn_inputs_length, hidden_dim) 214 | self.fc2 = nn.Linear(hidden_dim , action_dim) 215 | self.fc1.weight.data.normal_(0, 0.1) 216 | self.fc2.weight.data.normal_(0, 0.1) 217 | 218 | self.bn_fc1 = nn.BatchNorm1d(hidden_dim) 219 | # self.bn_fc2 = nn.BatchNorm1d(64) 220 | 221 | self.fcn = nn.Sequential( 222 | self.fc1, self.bn_fc1, self.relu, 223 | self.fc2 224 | ) 225 | 226 | def forward(self, x): 227 | ''' 228 | - x : tensor in shape of (N, state_dim) 229 | ''' 230 | 231 | cnn_out = self.cnn(x) 232 | cnn_out = cnn_out.reshape(-1, self.fcn_inputs_length) 233 | fcn_input = self.flatten(cnn_out) 234 | actions = self.fcn(fcn_input) 235 | return actions 236 | 237 | def cnn_out_dim(self, input_dims): 238 | return self.cnn(torch.zeros(1, *input_dims) 239 | ).flatten().shape[0] 240 | 241 | -------------------------------------------------------------------------------- /model/dqn-ft06-aciton-12.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colinlee0924/DRL-SimPy-JobShop/9e00ef98f99c4518a99a30b80589cce32786bcf4/model/dqn-ft06-aciton-12.pth -------------------------------------------------------------------------------- /per_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Algorithm] DQN Implementation ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import time 10 | import torch 11 | import random 12 | import logging 13 | import argparse 14 | import itertools 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | from tensorboardX import SummaryWriter 20 | from datetime import datetime as dt 21 | from torch import Variable 22 | 23 | # from model.NetworkModel import Net 24 | from model.FullyNetwork import Net 25 | from utils.MemeryBuffer import ReplayMemory 26 | 27 | import pdb 28 | 29 | logging.basicConfig(level=logging.DEBUG) 30 | # ----------------------------------------------- 31 | 32 | class DQN: 33 | def __init__(self, dim_state, dim_action, args): 34 | ## config ## 35 | self.device = torch.device(args.device) 36 | self.batch_size = args.batch_size 37 | self.gamma = args.gamma 38 | self.freq = args.freq 39 | self.target_freq = args.target_freq 40 | 41 | self._behavior_net = Net(dim_state, dim_action).to(self.device) 42 | self._target_net = Net(dim_state, dim_action).to(self.device) 43 | # ------------------------------------------- 44 | # initialize target network 45 | # ------------------------------------------- 46 | self._target_net.load_state_dict(self._behavior_net.state_dict())#, map_location=self.device) 47 | 48 | self._optimizer = torch.optim.RMSprop( 49 | self._behavior_net.parameters(), 50 | lr=args.lr 51 
| ) 52 | self._criteria = nn.MSELoss() 53 | # memory 54 | self._memory = ReplayMemory(capacity=args.capacity) 55 | 56 | def select_best_action(self, state): 57 | ''' 58 | - state: (state_dim, ) 59 | ''' 60 | during_train = self._behavior_net.training 61 | if during_train: 62 | self.eval() 63 | state = torch.Tensor(state).to(self.device) 64 | state = DQN.reshape_input_state(state) 65 | with torch.no_grad(): 66 | qvars = self._behavior_net(state) # (1, act_dim) 67 | action = torch.argmax(qvars, dim=-1) # (1, ) 68 | 69 | if during_train: 70 | self.train() 71 | 72 | return action.item() 73 | 74 | def select_action(self, state, epsilon, action_space): 75 | ''' 76 | epsilon-greedy based on behavior network 77 | 78 | -state = (state_dim, ) 79 | ''' 80 | if random.random() < epsilon: 81 | return action_space.sample() 82 | else: 83 | return self.select_best_action(state) 84 | 85 | def append(self, state, action, reward, next_state, done): 86 | 87 | q_values = self._behavior_net(state) # (N, act_dim) 88 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 89 | with torch.no_grad(): 90 | qs_next = self._target_net(next_state) # (N, act_dim) 91 | q_next, act = torch.max(qs_next, dim=-1, keepdim=True) # (N, 1) 92 | q_target = self.gamma*q_next*tocont + reward 93 | 94 | loss = self._criteria(q_value, q_target) 95 | 96 | self._memory.append( 97 | state, 98 | [action], 99 | [reward],# / 10], 100 | next_state, 101 | [1 - int(done)] 102 | ) 103 | 104 | def update(self, total_steps): 105 | if total_steps % self.freq == 0: 106 | return self._update_behavior_network(self.gamma) 107 | if total_steps % self.target_freq == 0: 108 | return self._update_target_network() 109 | 110 | def _update_behavior_network(self, gamma): 111 | # sample a minibatch of transitions 112 | ret = self._memory.sample(self.batch_size, self.device) 113 | state, action, reward, next_state, tocont = ret 114 | 115 | q_values = self._behavior_net(state) # (N, act_dim) 116 | q_value = torch.gather(input=q_values, dim=-1, index=action.long()) # (N, 1) 117 | with torch.no_grad(): 118 | qs_next = self._target_net(next_state) # (N, act_dim) 119 | q_next, act = torch.max(qs_next, dim=-1, keepdim=True) # (N, 1) 120 | q_target = gamma*q_next*tocont + reward 121 | 122 | loss = self._criteria(q_value, q_target) 123 | 124 | # optimize 125 | self._optimizer.zero_grad() 126 | loss.backward() 127 | nn.utils.clip_grad_norm_(self._behavior_net.parameters(), 5) 128 | self._optimizer.step() 129 | 130 | return loss.item() 131 | 132 | def _update_target_network(self): 133 | ''' 134 | update target network by copying from behavior network 135 | ''' 136 | self._target_net.load_state_dict(self._behavior_net.state_dict()) 137 | return None 138 | 139 | def save(self, model_path, checkpoint=False): 140 | if checkpoint: 141 | torch.save( 142 | { 143 | 'behavior_net': self._behavior_net.state_dict(), 144 | 'target_net': self._target_net.state_dict(), 145 | 'optimizer': self._optimizer.state_dict(), 146 | }, model_path) 147 | else: 148 | torch.save({ 149 | 'behavior_net': self._behavior_net.state_dict(), 150 | }, model_path) 151 | 152 | def load(self, model_path, checkpoint=False): 153 | model = torch.load(model_path, map_location=self.device) 154 | self._behavior_net.load_state_dict(model['behavior_net'])#, map_location=self.device) 155 | if checkpoint: 156 | self._target_net.load_state_dict(model['target_net'])#, map_location=self.device) 157 | self._optimizer.load_state_dict(model['optimizer'])# , map_location=self.device) 158 | 159 | def 
train(self): 160 | self._behavior_net.train() 161 | self._target_net.eval() 162 | 163 | def eval(self): 164 | self._behavior_net.eval() 165 | self._target_net.eval() 166 | 167 | @staticmethod 168 | def reshape_input_state(state): 169 | state_shape = len(state.shape) 170 | state = state.unsqueeze(0) 171 | 172 | return state 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | aiohttp==3.6.2 3 | async-timeout==3.0.1 4 | attrs==20.2.0 5 | box2d-py==2.3.8 6 | cachetools==4.1.1 7 | certifi==2020.6.20 8 | chardet==3.0.4 9 | cloudpickle==1.3.0 10 | cycler==0.10.0 11 | et-xmlfile==1.1.0 12 | future==0.18.2 13 | google-auth==1.22.0 14 | google-auth-oauthlib==0.4.1 15 | grpcio==1.32.0 16 | gym==0.17.2 17 | idna==2.10 18 | importlib-metadata==2.0.0 19 | kiwisolver==1.2.0 20 | Markdown==3.2.2 21 | matplotlib==3.3.2 22 | multidict==4.7.6 23 | numpy==1.19.2 24 | oauthlib==3.1.0 25 | openpyxl==3.0.7 26 | Pillow==7.2.0 27 | protobuf==3.13.0 28 | pyasn1==0.4.8 29 | pyasn1-modules==0.2.8 30 | pyglet==1.5.0 31 | pyparsing==2.4.7 32 | python-dateutil==2.8.1 33 | pytz==2021.1 34 | requests==2.24.0 35 | requests-oauthlib==1.3.0 36 | rsa==4.6 37 | scipy==1.5.2 38 | simpy==4.0.1 39 | six==1.15.0 40 | tensorboard==2.3.0 41 | tensorboard-plugin-wit==1.7.0 42 | tensorboardX==2.1 43 | tqdm==4.49.0 44 | typing-extensions==3.7.4.3 45 | urllib3==1.25.10 46 | Werkzeug==1.0.1 47 | yarl==1.6.0 48 | zipp==3.2.0 49 | -------------------------------------------------------------------------------- /simulation_env/action_map.py: -------------------------------------------------------------------------------- 1 | # file name : action_map.py 2 | import sys 3 | if '..' not in sys.path: 4 | sys.path.append('..') 5 | 6 | import config 7 | 8 | ACTION_MAP = {} 9 | dim_actions = config.DIM_ACTION 10 | 11 | for action in range(dim_actions): 12 | if action == 0: 13 | dspch_rule = 'FIFO' 14 | elif action == 1: 15 | dspch_rule = 'LIFO' 16 | elif action == 2: 17 | dspch_rule = 'SPT' 18 | elif action == 3: 19 | dspch_rule = 'LPT' 20 | elif action == 4: 21 | dspch_rule = 'LWKR' 22 | elif action == 5: 23 | dspch_rule = 'MWKR' 24 | elif action == 6: 25 | dspch_rule = 'SSO' 26 | elif action == 7: 27 | dspch_rule = 'LSO' 28 | elif action == 8: 29 | dspch_rule = 'SPT+SSO' 30 | elif action == 9: 31 | dspch_rule = 'LPT+LSO' 32 | elif action == 10: 33 | dspch_rule = 'STPT' 34 | elif action == 11: 35 | dspch_rule = 'LTPT' 36 | 37 | ACTION_MAP[action] = dspch_rule -------------------------------------------------------------------------------- /simulation_env/env_jobshop_v0.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- Job shop SL for RL Environment ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##------------------------------------------- 8 | # 9 | import os 10 | import sys 11 | if '..' 
not in sys.path: 12 | sys.path.append('..') 13 | 14 | import time 15 | import simpy 16 | import numpy as np 17 | import pandas as pd 18 | import matplotlib.pyplot as plt 19 | import utils.dispatch_logic as dp_rule 20 | 21 | from matplotlib.animation import FuncAnimation 22 | from utils.GanttPlot import Gantt 23 | 24 | import config 25 | 26 | INFINITY = float('inf') 27 | OPTIMAL_L = config.OPT_MAKESPAN 28 | DIM_ACTION = config.DIM_ACTION 29 | 30 | np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning) 31 | 32 | #entity 33 | class Order: 34 | def __init__(self, id, routing, prc_time, rls_time, intvl_arr): 35 | # *prc_time: process time, *rls_time: release time, *intvl_arr: arrival_interval 36 | self.id = id 37 | self.routing = routing 38 | self.prc_time = prc_time 39 | self.rls_time = rls_time 40 | self.intvl_arr = intvl_arr 41 | 42 | self.progress = 0 43 | 44 | def __str__(self): 45 | return f'Order #{self.id}' 46 | 47 | 48 | #resource in factory 49 | class Source: 50 | def __init__(self, fac, order_info = None): 51 | #reference 52 | self.fac = fac 53 | #attribute 54 | self.order_info = order_info 55 | self.orders_list = [] 56 | #statistics 57 | self.num_generated = 0 58 | 59 | def set_port(self): 60 | #reference 61 | self.env = self.fac.env 62 | self.queues = self.fac.queues 63 | 64 | #initial process 65 | self.process = self.env.process(self._generate_order()) 66 | 67 | def _generate_order(self): 68 | #get the data of orders 69 | self.order_info = self.order_info.sort_values(by='release_time') 70 | num_order = self.order_info.shape[0] 71 | 72 | _orders_list = [] 73 | 74 | #generate instances of Order from order information 75 | for num in range(num_order): 76 | id = self.order_info.loc[num, "id"] 77 | routing = self.order_info.loc[num, "routing"].split(',') 78 | prc_time = [int(i) for i in self.order_info.loc[num, "process_time"].split(',')] 79 | rls_time = self.order_info.loc[num, "release_time"] 80 | intvl_arr = self.order_info.loc[num, "arrival_interval"] 81 | 82 | _order = Order(id, routing, prc_time, rls_time, intvl_arr) 83 | 84 | _orders_list.append(_order) 85 | 86 | #wait for inter-arrival time 87 | for num in range(num_order): 88 | #To decide which order arrive first 89 | indx = np.argmax([sum(o.prc_time) for o in _orders_list]) 90 | order = _orders_list[indx] 91 | order = _orders_list[0] 92 | intvl_time = order.intvl_arr 93 | _orders_list.remove(order) 94 | 95 | yield self.env.timeout(intvl_time) 96 | #log for tracing 97 | if self.fac.log: 98 | print(" ({}) order {} release.".format(self.env.now, order.id)) 99 | #send order to queue 100 | target = int(order.routing[order.progress]) 101 | self.queues[target].order_arrive(order) 102 | #update statistics 103 | self.num_generated += 1 104 | 105 | 106 | ################# 107 | # TO DO: 108 | ################# 109 | #An actor to execute the action 110 | class Dispatcher: 111 | def __init__(self, fac): 112 | self.fac = fac 113 | self.env = fac.env 114 | 115 | # reference for mapping acton number to dispatching rule 116 | from simulation_env.action_map import ACTION_MAP 117 | self.action_map = ACTION_MAP 118 | 119 | def dispatch_by(self, action): 120 | # set the dispatch rule of each queue 121 | for _, queue in self.fac.queues.items(): 122 | queue.dspch_rule = self.action_map[action] 123 | 124 | 125 | class Queue: 126 | def __init__(self, fac, id): 127 | #reference 128 | self.fac = fac 129 | self.env = self.fac.env 130 | #attribute 131 | self.id = id 132 | self.space = [] 133 | self.dspch_rule = None 134 | 135 | def 
__str__(self): 136 | return f'queue #{self.id}' 137 | 138 | def set_port(self): 139 | #reference 140 | self.machine = self.fac.machines[self.id] 141 | 142 | def order_arrive(self, order): 143 | machine_idle = self.machine.status == 'idle' 144 | none_in_space = len(self.space) == 0 145 | if machine_idle and none_in_space: 146 | self.machine.process_order(order) 147 | else: 148 | self.space.append(order) 149 | 150 | ################# 151 | # TO DO: 152 | ################# 153 | #A way to pause simulation and resume after the actor take an action 154 | def check_obs_point(self): 155 | machine_idle = self.machine.status == 'idle' 156 | if machine_idle and len(self.space) > 0: 157 | 158 | self.fac.obs_point.succeed() 159 | self.fac.mc_need_dspch.add(self.id) 160 | 161 | self.fac.obs_point = self.env.event() 162 | 163 | self.fac.dict_dspch_evt[self.id] = self.env.event() 164 | self.env.process(self.get_order()) 165 | 166 | def get_order(self): 167 | if len(self.space) > 0: 168 | yield self.fac.dict_dspch_evt[self.id] 169 | ################# 170 | # TO DO: 171 | ################# 172 | #set an event for resuming after receive an action 173 | self.fac.dict_dspch_evt[self.id] = self.env.event() 174 | #get oder in queue 175 | order = dp_rule.get_order_from(self.space, self.dspch_rule) 176 | #send order to machine 177 | self.machine.process_order(order) 178 | #remove order form queue 179 | self.space.remove(order) 180 | else: 181 | order = self.space[0] 182 | #send order to machine 183 | self.machine.process_order(order) 184 | #remove order form queue 185 | self.space.remove(order) 186 | 187 | 188 | class Machine: 189 | def __init__(self, fac, id, num = 1): 190 | self.fac = fac 191 | self.id = id 192 | self.num = num 193 | self.processing = None 194 | #reference 195 | self.env = self.fac.env 196 | #attributes 197 | self.status = "idle" 198 | #statistics 199 | self.using_time = 0 200 | 201 | def __str__(self): 202 | return f'machine #{self.id}' 203 | 204 | def set_port(self): 205 | #reference 206 | self.queues = self.fac.queues 207 | self.sink = self.fac.sink 208 | 209 | ################# 210 | # TO DO: 211 | ################# 212 | #update the information which is state-related 213 | def process_order(self, order): 214 | #change status 215 | self.status = order 216 | #process order 217 | if self.fac.log: 218 | print(" ({}) machine {} start processing order {} - {} progress".format(self.env.now, self.id, order.id, order.progress)) 219 | 220 | ################# 221 | # TO DO: 222 | ################# 223 | #update the table of process status(progess matrix) 224 | self.fac.tb_proc_status[order.id - 1][order.progress] = 1 225 | 226 | #[Gantt plot preparing] udate the start/finish processing time of machine 227 | prc_time = order.prc_time[order.progress] 228 | self.fac.gantt_plot.update_gantt(self.id, self.env.now, prc_time, order.id) 229 | 230 | #do the action about processing 231 | self.process = self.env.process(self._process_order_callback()) 232 | 233 | ################# 234 | # TO DO: 235 | ################# 236 | #check observation point if there is any idle machine 237 | def _process_order_callback(self): 238 | #processing order for prc_time mins 239 | order = self.status 240 | prc_time = order.prc_time[order.progress] 241 | #[Gantt plot preparing] udate the start/finish processing time of machine 242 | self.fac.gantt_plot.update_gantt(self.id, self.env.now, prc_time, order.id) 243 | 244 | yield self.env.timeout(prc_time) 245 | if self.fac.log: 246 | print(" ({}) machine {} finish processing order 
{} - {} progress".format(self.env.now, self.id, order.id, order.progress)) 247 | #change order status 248 | order.progress += 1 249 | #send order to next station 250 | if order.progress < len(order.routing): 251 | target = int(order.routing[order.progress]) 252 | self.queues[target].order_arrive(order) 253 | else: 254 | self.sink.complete_order(order) 255 | 256 | #change status 257 | self.using_time += prc_time 258 | self.status = "idle" 259 | 260 | ################# 261 | # TO DO: 262 | ################# 263 | #get next order in queue 264 | terminal_is_true = self.fac.terminal.triggered 265 | if terminal_is_true: 266 | self.fac.obs_point.succeed() 267 | else: 268 | for queue in self.queues.values(): 269 | queue.check_obs_point() 270 | 271 | 272 | class Sink: 273 | def __init__(self, fac): 274 | self.fac = fac 275 | 276 | def set_port(self): 277 | #reference 278 | self.env = self.fac.env 279 | #attribute 280 | #statistics 281 | self.order_statistic = pd.DataFrame(columns = 282 | ["id", "release_time", "complete_time", "flow_time"] 283 | ) 284 | 285 | ################# 286 | # TO DO: 287 | ################# 288 | #define the terminal condition 289 | def complete_order(self, order): 290 | #update factory statistic 291 | self.fac.throughput += 1 292 | #update order statistic 293 | self.update_order_statistic(order) 294 | 295 | ################# 296 | # TO DO: 297 | ################# 298 | #ternimal condition 299 | num_order = self.fac.order_info.shape[0] 300 | if self.fac.throughput >= num_order: 301 | self.fac.terminal.succeed() 302 | self.fac.makespan = self.env.now 303 | 304 | def update_order_statistic(self, order): 305 | id = order.id 306 | rls_time = order.rls_time 307 | complete_time = self.env.now 308 | flow_time = complete_time - rls_time 309 | 310 | self.order_statistic.loc[id] = \ 311 | [id, rls_time, complete_time, flow_time] 312 | 313 | 314 | #factory 315 | class Factory: 316 | ''' 317 | A SimPy job-shop simulation which present a familiar OpenAI Gym like interface 318 | for Reinforcement Learning. 319 | 320 | Any environment needs: 321 | * A way to pause simulation and resume after the actor take an action 322 | * A state space 323 | * A reward function 324 | * An initialize (reset) method that returns the initial observations 325 | * A choice of actions 326 | * A islegal method to make sure the action is possible and legal 327 | * A step method that passes an action into the environment and returns: 328 | 1. new observations as state 329 | 2. reward 330 | 3. whether state is terminal 331 | 4. additional information 332 | * A render method to refresh and display the environment. 333 | * A way to recognize and return a terminal state (end of episode) 334 | 335 | ----------------- 336 | Internal methods: 337 | ----------------- 338 | __init__: 339 | Constructor method. 340 | _get_observation: 341 | Gets current state observations 342 | _islegal: 343 | Checks whether requested action is legal 344 | _pass_action: 345 | Execute the action to change the status of system 346 | _get_reward: 347 | Calculates reward based on empty beds or beds without patient 348 | 349 | 350 | Interfacing methods: 351 | -------------------- 352 | render: 353 | Display state 354 | reset: 355 | Initialise environment 356 | Return first state observations 357 | step: 358 | Take an action. Update state. 
Return obs, reward, terminal, info 359 | 360 | ''' 361 | def __init__(self, num_job, num_machine, file_name, opt_makespan, log=False): 362 | self.log = log 363 | 364 | #statistics 365 | self.throughput = 0 366 | self.last_util = 0 367 | self.makespan = INFINITY 368 | 369 | #system config 370 | self.num_machine = num_machine 371 | self.num_job = num_job 372 | self.order_info = pd.read_excel(file_name) 373 | self.df_machine_no = pd.read_excel(file_name, sheet_name='machine_no') 374 | self.df_proc_times = pd.read_excel(file_name, sheet_name='proc_time') 375 | self.tb_proc_status = np.zeros((num_job, num_machine)) 376 | self.opt_makespan = opt_makespan 377 | 378 | #[RL] attributes for the Environment of RL 379 | self.dim_actions = DIM_ACTION 380 | self.dim_observations = (3, self.num_job, self.num_job) 381 | self.observations = np.zeros(self.dim_observations) 382 | self.actions = np.arange(self.dim_actions) 383 | 384 | from gym import spaces 385 | self.action_space = spaces.Discrete(self.dim_actions) 386 | 387 | self.observations[0] = self.df_machine_no.values 388 | self.observations[1] = self.df_proc_times.values 389 | self.observations[2] = self.tb_proc_status 390 | 391 | #display 392 | self._render_his = [] 393 | 394 | def build(self): 395 | #build 396 | self.env = simpy.Environment() 397 | self.source = Source(self, self.order_info) 398 | self.dispatcher = Dispatcher(self) 399 | self.queues = {} 400 | self.machines = {} 401 | self.sink = Sink(self) 402 | for num in range(self.num_machine): 403 | self.queues[num] = Queue(self, num) 404 | self.machines[num] = Machine(self, num) 405 | #make connection 406 | self.source.set_port() 407 | for num, queue in self.queues.items(): 408 | queue.set_port() 409 | for num, machine in self.machines.items(): 410 | machine.set_port() 411 | self.sink.set_port() 412 | 413 | #dispatch event which would be successed when the mc finished dispatching 414 | self.dict_dspch_evt = {} 415 | self.mc_need_dspch = set() 416 | for num in range(self.num_machine): 417 | self.dict_dspch_evt[num] = self.env.event() 418 | 419 | #terminal event 420 | self.terminal = self.env.event() 421 | 422 | #[Gantt] 423 | self.gantt_plot = Gantt() 424 | 425 | def get_utilization(self): 426 | #compute average utiliztion of machines 427 | total_using_time = 0 428 | for _, machine in self.machines.items(): 429 | total_using_time += machine.using_time 430 | 431 | avg_using_time = total_using_time / self.num_machine 432 | return avg_using_time / self.env.now 433 | 434 | def _islegal(self, action): 435 | """ 436 | Check action is in list of allowed actions. If not, raise an exception. 
437 | """ 438 | 439 | if action not in self.actions: 440 | _msg = f'Requested action -->{action},not in the action space' 441 | raise ValueError(_msg) 442 | 443 | def _pass_action(self, action): 444 | #to execute the action 445 | self.dispatcher.dispatch_by(action) 446 | for mc in self.mc_need_dspch: 447 | if mc != None: 448 | self.dict_dspch_evt[mc].succeed() 449 | self.mc_need_dspch = set() 450 | 451 | def _get_observations(self): 452 | self.observations[0] = self.df_machine_no.values 453 | self.observations[1] = self.df_proc_times.values 454 | self.observations[2] = self.tb_proc_status 455 | return self.observations.copy() 456 | 457 | def _get_reward(self): 458 | current_util = self.get_utilization() 459 | last_util = self.last_util 460 | reward = (current_util - last_util) 461 | 462 | #final state 463 | if self.terminal: 464 | makespan = self.makespan 465 | optimal = self.opt_makespan 466 | if makespan == optimal: 467 | reward += 100 468 | else: 469 | reward += 100 / (makespan - optimal) 470 | 471 | #record current utilization as last utilization 472 | self.last_util = current_util 473 | return reward 474 | 475 | def render(self, terminal=False, use_mode=False, motion_speed=0.1): 476 | plt.set_loglevel('WARNING') 477 | if use_mode: 478 | if len(self._render_his) % 3 == 0: 479 | plt.ioff() 480 | plt.close('all') 481 | # queues_status = {} 482 | # for id, queue in self.queues.items(): 483 | # queues_status[id] = [str(order) for order in queue.space] 484 | # print(f'\n (time: {self.env.now}) - Status of queues:\n {queues_status}') 485 | plt.ion() 486 | plt.pause(motion_speed) 487 | fig = self.gantt_plot.draw_gantt(self.env.now) 488 | self._render_his.append(fig) 489 | 490 | if terminal: 491 | plt.ioff() 492 | trm_frame = [plt.close(fig) for fig in self._render_his[:-1]] 493 | plt.show() 494 | else: 495 | if terminal: 496 | plt.ion() 497 | fig = self.gantt_plot.draw_gantt(self.env.now) 498 | plt.pause(motion_speed) 499 | # plt.ioff() 500 | # plt.show() 501 | 502 | def close(self): 503 | plt.ioff() 504 | plt.close('all') 505 | 506 | def reset(self): 507 | #re-build factory include re-setting the events 508 | self.build() 509 | 510 | #display 511 | self._render_his = [] 512 | 513 | #reset statistics 514 | self.throughput = 0 515 | self.last_util = 0 516 | self.makespan = INFINITY 517 | 518 | #get initial observations 519 | observations = self._get_observations() 520 | 521 | return observations 522 | 523 | def step(self, action): 524 | 525 | #execute the action 526 | self._islegal(action) 527 | self._pass_action(action) 528 | 529 | #Resume to simulate until next observation_point(pause) 530 | self.obs_point = self.env.event() 531 | self.env.run(until = self.obs_point) 532 | 533 | #get new observations 534 | observations = self._get_observations() 535 | 536 | #get reward 537 | reward = self._get_reward() 538 | 539 | #ternimal condition 540 | terminal = self.terminal.triggered 541 | 542 | info = {} 543 | 544 | return observations, reward, terminal, info 545 | 546 | 547 | if __name__ == '__main__': 548 | import time 549 | pd.set_option('display.max_columns', None) 550 | pd.set_option('display.max_rows' , None) 551 | pd.set_option('display.width' , 300) 552 | pd.set_option('max_colwidth' , 100) 553 | 554 | human_control = str(input('\n* Human control [y/n]? 
')) 555 | usr_interaction = False 556 | if human_control == 'y': 557 | usr_interaction = True 558 | 559 | replication = 10 560 | 561 | # read problem config 562 | opt_makespan = 55 563 | file_name = 'job_info.xlsx' 564 | file_dir = os.getcwd() + '/input_data' 565 | file_path = os.path.join(file_dir, file_name) 566 | 567 | for rep in range(replication): 568 | # make environment 569 | fac = Factory(6, 6, file_path, opt_makespan, log=False) 570 | print('') 571 | print(f'Rep #{rep}') 572 | print('* Order Information: ') 573 | print(fac.order_info) 574 | print('') 575 | 576 | state = fac.reset() #include the bulid function 577 | done = False 578 | # print(f'({fac.env.now})') 579 | # print(f'[{state[-1]}]') 580 | # print() 581 | fac.render(done) 582 | while not done: 583 | 584 | if not usr_interaction: 585 | action = rep % fac.dim_actions #+ 1 586 | else: 587 | action = int(input(' * Choose an action from {0,1}: ')) 588 | 589 | next_state, reward, done, _ = fac.step(action) 590 | # print(f'({fac.env.now})') 591 | # print(f'[{state[-1]},\n {action}, {reward},\n {next_state[-1]}]') 592 | # print() 593 | state = next_state 594 | 595 | fac.render(terminal=done, use_mode=usr_interaction) 596 | 597 | 598 | # # fac.render(done) 599 | # time.sleep(0.5) 600 | # fac.close() 601 | 602 | # print(fac.event_record) 603 | print("============================================================") 604 | print(fac.sink.order_statistic.sort_values(by = "id")) 605 | print("============================================================") 606 | print("Average flow time: {}".format(np.mean(fac.sink.order_statistic['flow_time']))) 607 | print("Makespan = {}".format(fac.makespan)) 608 | -------------------------------------------------------------------------------- /simulation_env/input_data/job_info.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colinlee0924/DRL-SimPy-JobShop/9e00ef98f99c4518a99a30b80589cce32786bcf4/simulation_env/input_data/job_info.xlsx -------------------------------------------------------------------------------- /sweep.yaml: -------------------------------------------------------------------------------- 1 | project: "DRL-SimPy-JSS" 2 | program: main_djss_attention_paper.py 3 | method: bayes 4 | metric: 5 | name: MeanFT 6 | goal: minimize 7 | parameters: 8 | batch_size: 9 | min: 8 10 | max: 32 11 | lr: 12 | min: 0.00001 13 | max: 0.001 14 | capacity: 15 | min: 1000 16 | max: 10000 17 | episode: 18 | min: 1000 19 | max: 50000 20 | -------------------------------------------------------------------------------- /test_djss_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch.nn as nn 22 | import matplotlib.pyplot as plt 23 | 24 | from tensorboardX import SummaryWriter 25 | from datetime import datetime as dt 26 | from tqdm import tqdm 27 | 28 | # from simulation_env.env_jobshop_v1 import Factory 29 | from simulation_env.env_for_job_shop_v7_attention import Factory 30 | # from dqn_agent_djss import DQN as DDQN 31 | 
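# The import below pulls in the attention-based DDQN agent, and the block after it seeds
# Python's random, NumPy and PyTorch from config.SEED and sets
# torch.backends.cudnn.deterministic = True so the stochastic parts of a test run start from
# a fixed state. A hedged aside: full run-to-run determinism on GPU may additionally require
# cudnn benchmarking to stay disabled (torch.backends.cudnn.benchmark = False, its default)
# and, on newer PyTorch releases, torch.use_deterministic_algorithms(True); neither is set here.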
from ddqn_agent_attention import DDQN 32 | 33 | import pdb 34 | 35 | # seed 36 | import config_djss_attention as config 37 | seed = config.SEED #999 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | logging.basicConfig(level=logging.DEBUG) 44 | plt.set_loglevel('WARNING') 45 | 46 | 47 | def test(args, _env, agent, writer, CHECK_PTS): 48 | logging.info('\n* Start Testing') 49 | env = _env 50 | 51 | action_space = env.action_space 52 | epsilon = args.test_epsilon 53 | seeds = [args.seed + i for i in range(100)] 54 | # seeds = [args.seed + i for i in range(30)] 55 | rewards = [] 56 | makespans = [] 57 | lst_mean_ft = [] 58 | 59 | n_episode = 0 60 | seed_loader = tqdm(seeds) 61 | ##################### Record Action ################### 62 | episode_percentage, episode_selection = [], [] 63 | ###################################################### 64 | 65 | for seed in seed_loader: 66 | #################### Record Action ################### 67 | action_selection = [0] * env.dim_actions 68 | ###################################################### 69 | n_episode += 1 70 | total_reward = 0 71 | # env.seed(seed) 72 | state = env.reset() 73 | for t in itertools.count(start=1): 74 | 75 | #action = agent.select_action(state, epsilon, action_space) 76 | action = agent.select_best_action(state) 77 | 78 | # execute action 79 | next_state, reward, done, _ = env.step(action) 80 | #################### Record Action ################### 81 | action_selection[action] += 1 82 | ###################################################### 83 | 84 | state = next_state 85 | total_reward += reward 86 | 87 | # env.render(done) 88 | #env.render(terminal=done) 89 | 90 | if done: 91 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_Reward' , total_reward, n_episode) 92 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_Makespan', env.makespan, n_episode) 93 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_MeanFT', env.mean_flow_time, n_episode) 94 | rewards.append(total_reward) 95 | makespans.append(env.makespan) 96 | lst_mean_ft.append(env.mean_flow_time) 97 | 98 | # # Check the scheduling result 99 | # fig = env.gantt_plot.draw_gantt(env.makespan) 100 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 101 | break 102 | 103 | env.close() 104 | #################### Record Action ################### 105 | # statistic of the selection of action 106 | action_percentage = [0] * len(action_selection) 107 | for act in range(len(action_selection)): 108 | action_percentage[act] = action_selection[act] / t 109 | episode_selection.append(action_selection) 110 | episode_percentage.append(action_percentage) 111 | ###################################################### 112 | df_act_res = pd.DataFrame(episode_selection) 113 | df_act_per = pd.DataFrame(episode_percentage) 114 | df_act_res.to_csv('testing_action_result.csv') 115 | df_act_per.to_csv('testing_action_percentage.csv') 116 | 117 | logging.info(f' - Average Reward = {np.mean(rewards)}') 118 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 119 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 120 | 121 | df = pd.DataFrame(lst_mean_ft) 122 | pd.set_option('display.max_rows', df.shape[0] + 1) 123 | print(df) 124 | df.to_csv('testing_result_09_4m200n.csv') 125 | 126 | 127 | def main(): 128 | import config_djss_attention as config 129 | ## arguments ## 130 | parser = argparse.ArgumentParser(description=__doc__) 131 | parser.add_argument('-d', '--device', default=config.DEVICE) 
#'cuda') 132 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 133 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 134 | # train 135 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 136 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 137 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 138 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 139 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 140 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 141 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 142 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 143 | parser.add_argument('--freq' , default=config.FREQ , type=int) 144 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 145 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 146 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 147 | # test 148 | parser.add_argument('--test_only' , action='store_true') 149 | parser.add_argument('--render' , action='store_true') 150 | parser.add_argument('--seed' , default=config.SEED , type=int) 151 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 152 | args = parser.parse_args() 153 | 154 | ## main ## 155 | file_name = config.FILE_NAME 156 | file_dir = os.getcwd() + '/simulation_env/instance' 157 | file_path = os.path.join(file_dir, file_name) 158 | 159 | opt_makespan = config.OPT_MAKESPAN 160 | num_machine = config.NUM_MACHINE 161 | num_job = config.NUM_JOB 162 | 163 | # rls_rule = config.RLS_RULE 164 | 165 | # Agent & Environment 166 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 167 | # env = Factory(file_path, default_rule='FIFO', util=0.85, log=False)#True) 168 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 169 | # agent = DQN(env.dim_observations, env.dim_actions, args) 170 | agent = DDQN(env.dim_observations, env.dim_actions, args) 171 | 172 | # Tensorboard to trace the learning process 173 | ## ----------------------------------------------- 174 | MODEL_VERSION = "20211224-145648" 175 | CHECK_PTS = 19000 176 | # CHECK_PTS = 50000 177 | ## ----------------------------------------------- 178 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 179 | # writer = SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 180 | writter_name = config.WRITTER 181 | # writer = SummaryWriter(f'{writter_name}') 182 | # writer = SummaryWriter(f'test_log/DDQN-attention-h4-load085-{MODEL_VERSION}') 183 | writer = SummaryWriter(f'test_log/DDQN-attention-h4-load09-{MODEL_VERSION}') 184 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 185 | ## Test ## To test the pre-trained model 186 | # agent.load(args.model) 187 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-145648.pth-ck-19000.pth') 188 | agent.load(f'model/djss_attention/ddqn-attention-h4-{MODEL_VERSION}.pth-ck-{CHECK_PTS}.pth') 189 | # agent.load(f'model/djss_attention/ddqn-attention-h4-20211224-145648.pth') 190 | test(args, env, agent, writer, CHECK_PTS) 191 | 192 | writer.close() 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | 198 | -------------------------------------------------------------------------------- /test_djss_attention_paper.py: 
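test_djss_attention_paper.py (below) mirrors test_djss_attention.py: it runs 100 fixed-seed test episodes with agent.select_best_action and logs episode reward, makespan and mean flow time, but it swaps in the paper-variant agent (ddqn_agent_attention_paper.DDQN) and config (config_djss_attention_paper), and writes its action statistics and results to the paper_*.csv files. As with the script above, the checkpoint to evaluate appears to be selected by the hard-coded MODEL_VERSION / CHECK_PTS constants in main() rather than by the --model argument.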
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Trainer] Main program of Learning ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##----------------------------------------------- 8 | # 9 | # from config import EPISODE 10 | import os 11 | import sys 12 | import time 13 | import torch 14 | import random 15 | import logging 16 | import argparse 17 | import itertools 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch.nn as nn 22 | import matplotlib.pyplot as plt 23 | 24 | from tensorboardX import SummaryWriter 25 | from datetime import datetime as dt 26 | from tqdm import tqdm 27 | 28 | # from simulation_env.env_jobshop_v1 import Factory 29 | from simulation_env.env_for_job_shop_v7_attention import Factory 30 | # from dqn_agent_djss import DQN as DDQN 31 | from ddqn_agent_attention_paper import DDQN 32 | 33 | import pdb 34 | 35 | # seed 36 | import config_djss_attention_paper as config 37 | seed = config.SEED #999 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | logging.basicConfig(level=logging.DEBUG) 44 | plt.set_loglevel('WARNING') 45 | 46 | 47 | def test(args, _env, agent, writer, CHECK_PTS): 48 | logging.info('\n* Start Testing') 49 | env = _env 50 | 51 | action_space = env.action_space 52 | epsilon = args.test_epsilon 53 | seeds = [args.seed + i for i in range(100)] 54 | # seeds = [args.seed + i for i in range(30)] 55 | rewards = [] 56 | makespans = [] 57 | lst_mean_ft = [] 58 | 59 | n_episode = 0 60 | seed_loader = tqdm(seeds) 61 | ##################### Record Action ################### 62 | episode_percentage, episode_selection = [], [] 63 | ###################################################### 64 | 65 | for seed in seed_loader: 66 | #################### Record Action ################### 67 | action_selection = [0] * env.dim_actions 68 | ###################################################### 69 | n_episode += 1 70 | total_reward = 0 71 | # env.seed(seed) 72 | state = env.reset() 73 | for t in itertools.count(start=1): 74 | 75 | #action = agent.select_action(state, epsilon, action_space) 76 | action = agent.select_best_action(state) 77 | 78 | # execute action 79 | next_state, reward, done, _ = env.step(action) 80 | #################### Record Action ################### 81 | action_selection[action] += 1 82 | ###################################################### 83 | 84 | state = next_state 85 | total_reward += reward 86 | 87 | # env.render(done) 88 | #env.render(terminal=done) 89 | 90 | if done: 91 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_Reward' , total_reward, n_episode) 92 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_Makespan', env.makespan, n_episode) 93 | writer.add_scalar(f'Test_{CHECK_PTS}/Episode_MeanFT', env.mean_flow_time, n_episode) 94 | rewards.append(total_reward) 95 | makespans.append(env.makespan) 96 | lst_mean_ft.append(env.mean_flow_time) 97 | 98 | # # Check the scheduling result 99 | # fig = env.gantt_plot.draw_gantt(env.makespan) 100 | # writer.add_figure('Test/Gantt_Chart', fig, n_episode) 101 | break 102 | 103 | env.close() 104 | #################### Record Action ################### 105 | # statistic of the selection of action 106 | action_percentage = [0] * len(action_selection) 107 | for act in range(len(action_selection)): 108 | action_percentage[act] = action_selection[act] / t 109 | 
episode_selection.append(action_selection) 110 | episode_percentage.append(action_percentage) 111 | ###################################################### 112 | df_act_res = pd.DataFrame(episode_selection) 113 | df_act_per = pd.DataFrame(episode_percentage) 114 | df_act_res.to_csv('paper_testing_action_result.csv') 115 | df_act_per.to_csv('paper_testing_action_percentage.csv') 116 | 117 | logging.info(f' - Average Reward = {np.mean(rewards)}') 118 | logging.info(f' - Average Makespan = {np.mean(makespans)}') 119 | logging.info(f' - Average MeanFT = {np.mean(lst_mean_ft)}') 120 | 121 | df = pd.DataFrame(lst_mean_ft) 122 | pd.set_option('display.max_rows', df.shape[0] + 1) 123 | print(df) 124 | df.to_csv('paper_testing_result_09_4m200n.csv') 125 | 126 | 127 | def main(): 128 | import config_djss_attention_paper as config 129 | ## arguments ## 130 | parser = argparse.ArgumentParser(description=__doc__) 131 | parser.add_argument('-d', '--device', default=config.DEVICE) #'cuda') 132 | parser.add_argument('-m', '--model' , default=config.MODEL) #'model/dqn.pth') 133 | parser.add_argument('--logdir' , default=config.LOG_DIR) #'log/dqn') 134 | # train 135 | parser.add_argument('--warmup' , default=config.WARMUP , type=int) 136 | parser.add_argument('--episode' , default=config.EPISODE , type=int) 137 | parser.add_argument('--capacity' , default=config.CAPACITY , type=int) 138 | parser.add_argument('--batch_size' , default=config.BATCH_SIZE , type=int) 139 | parser.add_argument('--lr' , default=config.LEARNING_R , type=float) 140 | parser.add_argument('--eps_decay' , default=config.EPS_DECAY , type=float) 141 | parser.add_argument('--eps_min' , default=config.EPS_MIN , type=float) 142 | parser.add_argument('--gamma' , default=config.GAMMA , type=float) 143 | parser.add_argument('--freq' , default=config.FREQ , type=int) 144 | parser.add_argument('--target_freq' , default=config.TARGET_FREQ , type=int) 145 | parser.add_argument('--render_episode', default=config.RENDER_EPISODE, type=int) 146 | parser.add_argument('--priori_period' , default=config.PRIORI_PERIOD , type=int) 147 | # test 148 | parser.add_argument('--test_only' , action='store_true') 149 | parser.add_argument('--render' , action='store_true') 150 | parser.add_argument('--seed' , default=config.SEED , type=int) 151 | parser.add_argument('--test_epsilon', default=config.TEST_EPSILON, type=float) 152 | args = parser.parse_args() 153 | 154 | ## main ## 155 | file_name = config.FILE_NAME 156 | file_dir = os.getcwd() + '/simulation_env/instance' 157 | file_path = os.path.join(file_dir, file_name) 158 | 159 | opt_makespan = config.OPT_MAKESPAN 160 | num_machine = config.NUM_MACHINE 161 | num_job = config.NUM_JOB 162 | 163 | # rls_rule = config.RLS_RULE 164 | 165 | # Agent & Environment 166 | # env = Factory(num_job, num_machine, file_path, opt_makespan, log=False) 167 | # env = Factory(file_path, default_rule='FIFO', util=0.85, log=False)#True) 168 | env = Factory(file_path, default_rule='FIFO', util=0.9, log=False)#True) 169 | # agent = DQN(env.dim_observations, env.dim_actions, args) 170 | agent = DDQN(env.dim_observations, env.dim_actions, args) 171 | 172 | # Tensorboard to trace the learning process 173 | ## ----------------------------------------------- 174 | MODEL_VERSION = "20220106-144253" 175 | CHECK_PTS = 40000 176 | # CHECK_PTS = 50000 177 | ## ----------------------------------------------- 178 | # time_start = time.strftime("%Y%m%d-%H-%M-%S", time.localtime()) 179 | # writer = 
SummaryWriter(f'log_djss_attention/DDQN-4x500x6-{time_start}') 180 | writter_name = config.WRITTER 181 | # writer = SummaryWriter(f'{writter_name}') 182 | # writer = SummaryWriter(f'test_log/DDQN-attention-h4-load085-{MODEL_VERSION}') 183 | writer = SummaryWriter(f'test_log/paper-DDQN-attention-h4-load09-{MODEL_VERSION}') 184 | # writer = SummaryWriter(f'log/DDQN-{time.time()}') 185 | ## Test ## To test the pre-trained model 186 | # agent.load(args.model) 187 | # agent.load('model/djss_attention/ddqn-attention-h4-20211224-145648.pth-ck-19000.pth') 188 | agent.load(f'model/djss_attention/paper/ddqn-attention-4m300-h4-stateNormalized-{MODEL_VERSION}.pth-ck-{CHECK_PTS}.pth') 189 | # agent.load(f'model/djss_attention/paper/ddqn-attention-4m300-h4-stateNormalized-{MODEL_VERSION}.pth') 190 | test(args, env, agent, writer, CHECK_PTS) 191 | 192 | writer.close() 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | 198 | -------------------------------------------------------------------------------- /utils/GanttPlot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##------- [Tool] Gantt chart plotting tool -------- 5 | # * Author: CIMLab 6 | # * Date: Apr 30th, 2020 7 | # * Description: 8 | # This program is a tool to help you guys 9 | # to draw the gantt chart of your scheduling 10 | # result. Feel free to modify it as u like. 11 | ##------------------------------------------------- 12 | # 13 | import time 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | import matplotlib.colors as mcolors 17 | 18 | class Gantt: 19 | def __init__(self): 20 | self.gantt_data = {"MC": [], 21 | "Order" : [], 22 | "Start time" : [], 23 | "Process time" : []} 24 | 25 | def update_gantt(self, MC, ST, PT, job): 26 | self.gantt_data['MC'].append("M{}".format(MC)) 27 | self.gantt_data['Order'].append(job) 28 | self.gantt_data['Start time'].append(ST) 29 | self.gantt_data['Process time'].append(PT) 30 | 31 | def draw_gantt(self, T_NOW, save_name = None): 32 | #set color list 33 | # fig = plt.figure(figsize=(12, 6)) 34 | fig, axes = plt.subplots(1, 1, figsize=(16, 6)) 35 | ax = axes 36 | colors = list(mcolors.TABLEAU_COLORS.keys()) 37 | #draw gantt bar 38 | y = self.gantt_data['MC'] 39 | width = self.gantt_data['Process time'] 40 | left = self.gantt_data['Start time'] 41 | color = [] 42 | for j in self.gantt_data['Order']: 43 | color.append(colors[int(j)]) 44 | 45 | # plt.barh(y = y, width = width, height = 0.5, color=color,left = left, align = 'center',alpha = 0.6) 46 | ax.barh(y = y , width = width , height = 0.5 ,\ 47 | color = color, left = left , align = 'center',\ 48 | alpha = 0.6 , edgecolor ='black', linewidth = 1) 49 | #add text 50 | for i in range(len(self.gantt_data['MC'])): 51 | text_x = self.gantt_data['Start time'][i] + self.gantt_data['Process time'][i]/2 52 | text_y = self.gantt_data['MC'][i] 53 | text = self.gantt_data['Order'][i] 54 | # plt.text(text_x, text_y, 'J'+str(text), verticalalignment='center', horizontalalignment='center', fontsize=8) 55 | ax.text(text_x, text_y, 'J'+str(text), verticalalignment='center', horizontalalignment='center', fontsize=8) 56 | #figure setting 57 | ax.set_xlabel("time") 58 | ax.set_ylabel("Machine") 59 | ax.set_title("Gantt Chart") 60 | if T_NOW >= 20: 61 | ax.set_xticks(np.arange(0, T_NOW+1, 5)) 62 | else: 63 | ax.set_xticks(np.arange(0, T_NOW+1, 1)) 64 | # plt.grid(True) 65 | 66 | if save_name != None: 67 | plt.savefig(save_name) 68 | 69 | 
return fig 70 | 71 | -------------------------------------------------------------------------------- /utils/MemeryBuffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | ##----- [Utils] Memory Buffer ------ 5 | # * Author: Colin, Lee 6 | # * Date: Aug 16th, 2021 7 | ##---------------------------------- 8 | import math 9 | import torch 10 | import random 11 | import logging 12 | 13 | import numpy as np 14 | import torch.nn as nn 15 | 16 | from collections import deque 17 | from datetime import datetime as dt 18 | 19 | import pdb 20 | 21 | logging.basicConfig(level=logging.DEBUG) 22 | # ----------------------------------------------- 23 | 24 | class ReplayMemory: 25 | __slots__ = ['buffer'] 26 | 27 | def __init__(self, capacity): 28 | self.buffer = deque(maxlen=capacity) 29 | 30 | def __len__(self): 31 | return len(self.buffer) 32 | 33 | def append(self, *transition): 34 | # (state, action, reward, next_state, not_done) 35 | self.buffer.append(tuple(map(tuple, transition))) 36 | 37 | def sample(self, batch_size, device): 38 | '''sample a batch of transition tensors''' 39 | transitions = random.sample(self.buffer, batch_size) 40 | return (torch.tensor(np.array(x), dtype=torch.float, device=device) 41 | for x in zip(*transitions)) -------------------------------------------------------------------------------- /utils/PERMemory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from SumTree import SumTree 4 | 5 | class Memory: # stored as ( s, a, r, s_ ) in SumTree 6 | e = 0.01 7 | a = 0.6 8 | beta = 0.4 9 | beta_increment_per_sampling = 0.001 10 | 11 | def __init__(self, capacity): 12 | self.tree = SumTree(capacity) 13 | self.capacity = capacity 14 | 15 | def __len__(self): 16 | return self.capacity 17 | 18 | def _get_priority(self, error): 19 | return (np.abs(error) + self.e) ** self.a 20 | 21 | def append(self, error, sample): 22 | p = self._get_priority(error) 23 | self.tree.add(p, sample) 24 | 25 | def sample(self, n): 26 | batch = [] 27 | idxs = [] 28 | segment = self.tree.total() / n 29 | priorities = [] 30 | 31 | self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) 32 | 33 | for i in range(n): 34 | a = segment * i 35 | b = segment * (i + 1) 36 | 37 | s = random.uniform(a, b) 38 | (idx, p, data) = self.tree.get(s) 39 | priorities.append(p) 40 | batch.append(data) 41 | idxs.append(idx) 42 | 43 | sampling_probabilities = priorities / self.tree.total() 44 | is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta) 45 | is_weight /= is_weight.max() 46 | 47 | return batch, idxs, is_weight 48 | 49 | def update(self, idx, error): 50 | p = self._get_priority(error) 51 | self.tree.update(idx, p) 52 | -------------------------------------------------------------------------------- /utils/SumTree.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | 4 | # SumTree 5 | # a binary tree data structure where the parent’s value is the sum of its children 6 | class SumTree: 7 | write = 0 8 | 9 | def __init__(self, capacity): 10 | self.capacity = capacity 11 | self.tree = numpy.zeros(2 * capacity - 1) 12 | self.data = numpy.zeros(capacity, dtype=object) 13 | self.n_entries = 0 14 | 15 | # update to the root node 16 | def _propagate(self, idx, change): 17 | parent = (idx - 1) // 2 18 | 19 | self.tree[parent] += change 20 | 21 | if 
parent != 0:
22 |             self._propagate(parent, change)
23 |
24 |     # find sample on leaf node
25 |     def _retrieve(self, idx, s):
26 |         left = 2 * idx + 1
27 |         right = left + 1
28 |
29 |         if left >= len(self.tree):
30 |             return idx
31 |
32 |         if s <= self.tree[left]:
33 |             return self._retrieve(left, s)
34 |         else:
35 |             return self._retrieve(right, s - self.tree[left])
36 |
37 |     def total(self):
38 |         return self.tree[0]
39 |
40 |     # store priority and sample
41 |     def add(self, p, data):
42 |         idx = self.write + self.capacity - 1
43 |
44 |         self.data[self.write] = data
45 |         self.update(idx, p)
46 |
47 |         self.write += 1
48 |         if self.write >= self.capacity:
49 |             self.write = 0
50 |
51 |         if self.n_entries < self.capacity:
52 |             self.n_entries += 1
53 |
54 |     # update priority
55 |     def update(self, idx, p):
56 |         change = p - self.tree[idx]
57 |
58 |         self.tree[idx] = p
59 |         self._propagate(idx, change)
60 |
61 |     # get priority and sample
62 |     def get(self, s):
63 |         idx = self._retrieve(0, s)
64 |         dataIdx = idx - self.capacity + 1
65 |
66 |         return (idx, self.tree[idx], self.data[dataIdx])
--------------------------------------------------------------------------------
/utils/action_map.py:
--------------------------------------------------------------------------------
1 | # file name : action_map.py
2 | import sys
3 | if '..' not in sys.path:
4 |     sys.path.append('..')
5 |
6 | import config
7 |
8 | ACTION_MAP = {}
9 | dim_actions = config.DIM_ACTION
10 |
11 | for action in range(dim_actions):
12 |     if action == 0:
13 |         dspch_rule = 'FIFO'
14 |     elif action == 1:
15 |         dspch_rule = 'LIFO'
16 |     elif action == 2:
17 |         dspch_rule = 'SPT'
18 |     elif action == 3:
19 |         dspch_rule = 'LPT'
20 |     elif action == 4:
21 |         dspch_rule = 'LWKR'
22 |     elif action == 5:
23 |         dspch_rule = 'MWKR'
24 |     elif action == 6:
25 |         dspch_rule = 'SSO'
26 |     elif action == 7:
27 |         dspch_rule = 'LSO'
28 |     elif action == 8:
29 |         dspch_rule = 'SPT+SSO'
30 |     elif action == 9:
31 |         dspch_rule = 'LPT+LSO'
32 |     elif action == 10:
33 |         dspch_rule = 'STPT'
34 |     elif action == 11:
35 |         dspch_rule = 'LTPT'
36 |
37 |     ACTION_MAP[action] = dspch_rule
--------------------------------------------------------------------------------
/utils/dispatch_logic.py:
--------------------------------------------------------------------------------
1 | # file name: dispatch_logic.py
2 |
3 | import numpy as np
4 |
5 | def sort_order_by(queue_space, dspch_rule):
6 |     if dspch_rule == 'FIFO':
7 |         order = queue_space[0]
8 |     elif dspch_rule == 'LIFO':
9 |         order = queue_space[-1]
10 |     else:
11 |         if dspch_rule == 'SPT':
12 |             indx = np.argmin([order.prc_time[order.progress] for order in queue_space])
13 |         elif dspch_rule == 'LPT':
14 |             indx = np.argmax([order.prc_time[order.progress] for order in queue_space])
15 |         elif dspch_rule == 'LWKR':
16 |             indx = np.argmin([sum(ord.prc_time[ord.progress:]) for ord in queue_space])
17 |         elif dspch_rule == 'MWKR':
18 |             indx = np.argmax([sum(ord.prc_time[ord.progress:]) for ord in queue_space])
19 |         elif dspch_rule == 'SSO':
20 |             indx = np.argmin(_get_subsequence_prc_times(queue_space))
21 |         elif dspch_rule == 'LSO':
22 |             indx = np.argmax(_get_subsequence_prc_times(queue_space))
23 |         elif dspch_rule == 'SPT+SSO':
24 |             indx = np.argmin(_get_cur_subsequence_prc_times(queue_space))
25 |         elif dspch_rule == 'LPT+LSO':
26 |             indx = np.argmax(_get_cur_subsequence_prc_times(queue_space))
27 |         elif dspch_rule == 'STPT':
28 |             indx = np.argmin([sum(order.prc_time) for order in queue_space])
29 |         elif dspch_rule == 'LTPT':
30 |             indx = np.argmax([sum(order.prc_time) for order in queue_space])
31 |         else:
32 |             print(f'[ERROR #1] {dspch_rule} is not in the dispatching rule set')
33 |             raise NotImplementedError
34 |
35 |         order = queue_space[indx]
36 |     return order
37 |
38 | def get_order_from(queue_space, dspch_rule):
39 |     # get order in queue
40 |     if dspch_rule == 'FIFO':
41 |         order = queue_space[0]
42 |     elif dspch_rule == 'LIFO':
43 |         order = queue_space[-1]
44 |         # indx = np.argmax([order.arr_time for order in queue_space])
45 |         # order = queue_space[indx]
46 |     else:
47 |         if dspch_rule == 'SPT':
48 |             indx = np.argmin([order.prc_time[order.progress] for order in queue_space])
49 |         elif dspch_rule == 'LPT':
50 |             indx = np.argmax([order.prc_time[order.progress] for order in queue_space])
51 |         elif dspch_rule == 'LWKR':
52 |             indx = np.argmin([sum(ord.prc_time[ord.progress:]) for ord in queue_space])
53 |         elif dspch_rule == 'MWKR':
54 |             indx = np.argmax([sum(ord.prc_time[ord.progress:]) for ord in queue_space])
55 |         elif dspch_rule == 'SSO':
56 |             indx = np.argmin(_get_subsequence_prc_times(queue_space))
57 |         elif dspch_rule == 'LSO':
58 |             indx = np.argmax(_get_subsequence_prc_times(queue_space))
59 |         elif dspch_rule == 'SPT+SSO':
60 |             indx = np.argmin(_get_cur_subsequence_prc_times(queue_space))
61 |         elif dspch_rule == 'LPT+LSO':
62 |             indx = np.argmax(_get_cur_subsequence_prc_times(queue_space))
63 |         elif dspch_rule == 'STPT':
64 |             indx = np.argmin([sum(order.prc_time) for order in queue_space])
65 |         elif dspch_rule == 'LTPT':
66 |             indx = np.argmax([sum(order.prc_time) for order in queue_space])
67 |         else:
68 |             print(f'[ERROR #1] {dspch_rule} is not in the dispatching rule set')
69 |             raise NotImplementedError
70 |
71 |         order = queue_space[indx]
72 |
73 |     return order
74 |
75 | def _get_subsequence_prc_times(queue_space):
76 |     lst_sub_prc_times = []
77 |     for order in queue_space:
78 |         if order.progress + 1 < len(order.prc_time):
79 |             lst_sub_prc_times.append(order.prc_time[order.progress + 1])
80 |         else:
81 |             lst_sub_prc_times.append(0)
82 |     return lst_sub_prc_times
83 |
84 | def _get_cur_subsequence_prc_times(queue_space):
85 |     lst_res_prc_times = []
86 |     for order in queue_space:
87 |         if order.progress + 2 < len(order.prc_time):
88 |             cur_subsq_times = sum(order.prc_time[order.progress:order.progress + 2])
89 |             lst_res_prc_times.append(cur_subsq_times)
90 |         else:
91 |             cur_subsq_times = sum(order.prc_time[order.progress:])
92 |             lst_res_prc_times.append(cur_subsq_times)
93 |     return lst_res_prc_times
--------------------------------------------------------------------------------
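Two short usage sketches follow. They are not files from the repository; they are minimal, self-contained illustrations of how the utilities above fit together, and every name that does not appear in the repo (StubOrder, the example processing times, the sys.path tweaks) is a hypothetical stand-in.

The first sketch exercises utils/dispatch_logic.get_order_from, the helper that pulls one order out of a machine queue once the agent's integer action has been mapped to a rule name by utils/action_map.py. Only the attributes that dispatch_logic.py actually reads (prc_time and progress, plus id for printing) are stubbed, and it assumes the repository root is on sys.path.

from dataclasses import dataclass
from typing import List

from utils.dispatch_logic import get_order_from  # assumes the repo root is on sys.path


@dataclass
class StubOrder:                  # hypothetical stand-in for the simulation's order objects
    id: int
    prc_time: List[int]           # processing time of each operation
    progress: int = 0             # index of the operation currently waiting in the queue


queue = [
    StubOrder(id=0, prc_time=[5, 3, 2]),
    StubOrder(id=1, prc_time=[2, 9, 4]),
    StubOrder(id=2, prc_time=[8, 1, 9]),
]

print(get_order_from(queue, 'FIFO').id)   # 0 -> first order in the queue
print(get_order_from(queue, 'SPT').id)    # 1 -> smallest current processing time (2)
print(get_order_from(queue, 'MWKR').id)   # 2 -> most work remaining (8 + 1 + 9 = 18)

The second sketch walks through the prioritized-replay flow exposed by utils/PERMemory.py and utils/SumTree.py: transitions are stored with an initial TD-error, sampled in proportion to priority together with tree indices and importance-sampling weights, and their priorities are refreshed after a learning step. Because PERMemory.py imports SumTree by its bare module name, the sketch assumes the utils directory itself is appended to sys.path.

import sys
sys.path.append('utils')          # assumption: executed from the repository root

import numpy as np
from PERMemory import Memory

memory = Memory(capacity=8)

# store (state, action, reward, next_state, done) tuples with an initial error;
# the error only sets the sampling priority, it is not part of the stored data
for i in range(8):
    transition = (np.zeros(4), i % 2, float(i), np.zeros(4), False)
    memory.append(error=1.0, sample=transition)

# draw a prioritized batch together with SumTree indices and IS weights
batch, idxs, is_weight = memory.sample(4)

# after computing new TD-errors for the sampled transitions, push them back
# so that future sampling reflects the updated priorities
for idx, new_error in zip(idxs, [0.5, 0.1, 0.9, 0.3]):
    memory.update(idx, new_error)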