├── README.md
├── code
│   ├── dqn_algo.py
│   ├── model.py
│   └── run_scale.py
├── fig1.jpg
├── fig2.jpg
├── fig3.jpg
└── fig4.jpg

/README.md:
--------------------------------------------------------------------------------
# Deep Reinforcement Learning for Furniture Layout Simulation in Indoor Graphics Scenes

## Installation
Please install the dependencies via conda:
* PyTorch >= 1.0.0
* networkx
* numpy
* Python >= 3.6

## Introduction

In the industrial interior design process, professional designers plan the sizes and positions of furniture in a room to achieve a design that is satisfactory for sale. In this repository, we explore the interior graphics scene design task as a Markov decision process (MDP) solved by deep reinforcement learning. The goal is to produce an accurate furniture layout in the simulated indoor graphics scene. In particular, we first formulate the furniture layout task as an MDP by defining the state, action, and reward function. We then design the simulated environment and deploy a reinforcement learning (RL) agent that interacts with it to learn the optimal layout for the MDP.

## Numerical results

We conduct our experiments on a large-scale real-world interior layout dataset that contains industrial designs from professional designers. Our numerical results demonstrate that the proposed model yields higher-quality layouts than state-of-the-art models.

The figure below illustrates several layouts produced by state-of-the-art models for a bedroom and a tatami room. The ground-truth layouts in the simulator and the real-time renders are shown in the first row. The layouts produced by the state-of-the-art models are shown in the second and third rows; they place furniture with inaccurate positions and sizes.

![Size & Position](fig1.jpg)

We formulate the planning of furniture layout in the simulated indoor graphics scene as a Markov decision process (MDP) augmented with a goal state $G$ that we would like the agent to reach. We develop a simulator for interior graphics scenes and, as shown in the figure below, define the action, reward, policy, and environment for learning the 2D furniture layout.

![MDP Formulation](fig2.jpg)

Given the sizes and positions of the walls, windows, doors, and furniture in a real room, the simulator transfers the real indoor scene into a simulated graphics scene, with each type of component shown in a different color.

![Simulation Environment](fig3.jpg)

Given a bathroom with randomly placed furniture, the trained RL agent produces a good layout for the bathroom scene. In the figure below, the first row shows the ground-truth layout for a bathroom in the simulation and its corresponding render. The second row shows the bathroom with random furniture positions. The third row shows the final layouts produced by our method, and the fourth row shows the corresponding renders.

![Results](fig4.jpg)

Code and more results will be released soon. Please contact deepearthgo@gmail.com if you have any questions.
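
Below is a minimal sketch of how the pieces fit together, following the `DQNAlgo` interface in `code/dqn_algo.py` and the usage pattern in `code/run_scale.py`. The scene JSON, output directory, and checkpoint paths are placeholders, the import assumes the script is run from inside `code/`, and the environment and agent modules imported by `dqn_algo.py` (`net.iccv1.*`) are assumed to be available on the Python path.

```python
# Hypothetical quick-start sketch; replace the placeholder paths with your own data.
import matplotlib.pyplot as plt
from dqn_algo import DQNAlgo  # assumes the working directory is code/

json_path = "./data/bedroom_example.json"  # placeholder: real room description (walls, doors, windows, furniture)
root_path = "./output/bedroom_example/"    # placeholder: directory the simulator writes layouts to

# Dueling DQN with the 2D convolutional backbone from code/model.py
dqn = DQNAlgo(filepath=json_path, netname='conv2d', rootpath=root_path, algoname='duel')
rewards, total_rewards = dqn.train(max_iteration=1e5, save_pth='./checkpoints/model.pth')

# Reload the trained weights and roll out the greedy policy
val_rewards, val_total_rewards = dqn.valid(max_iteration=1e4, load_pth='./checkpoints/model.pth')

# Per-step reward (left) and cumulative reward (right)
fig = plt.figure(figsize=(15, 5))
fig.add_subplot(1, 2, 1)
plt.plot(rewards)
plt.xlabel('iteration')
plt.ylabel('reward')
fig.add_subplot(1, 2, 2)
plt.plot(total_rewards[1:])
plt.xlabel('iteration')
plt.ylabel('total reward')
plt.show()
```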

--------------------------------------------------------------------------------
/code/dqn_algo.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import gym.wrappers
from net.iccv1.stock.dqn_agent import DQNAgent
from net.iccv1.env import SingleAgentEnv
from collections import deque


class DQNAlgo(object):

    def __init__(self, filepath, netname, rootpath, algoname='duel', action_space=4,
                 eps_start=1.0, eps_end=0.1, eps_steps=1e6, discount=0.99,
                 buffer_size=100000, batch_size=1, copy_times=1e10):
        """
        DQNAlgo --- implementation of a simple deep Q-learning training loop.

        Parameters:
            filepath -- the scene description file used to create the layout environment
            rootpath -- the root directory passed to the simulated environment
            netname  -- 'fc' for a fully connected network, 'conv2d' for a 2D convolutional network
            algoname -- '' for the plain DQN, 'duel' for the dueling DQN
        """
        if netname == 'conv2d':
            env = SingleAgentEnv(filepath, rootpath)
            # env = SingleAgentEnv.from_dir([filepath], state_1d=True)
        else:
            env = SingleAgentEnv(filepath, rootpath)
        self.env = gym.wrappers.TimeLimit(env, max_episode_steps=int(1e10))
        if algoname == 'duel':
            netname = 'duel_' + netname
        self.agent = DQNAgent(netname, self.env.observation_space.shape, action_space, eps_start, discount)
        self.eps_start = eps_start
        self.eps_steps = eps_steps
        self.eps_end = eps_end
        self.batch_size = batch_size
        self.copy_times = copy_times
        self.replay_buffer = deque(maxlen=buffer_size)
        self.iteration_n = 1
        self.rewards_record = []
        self.total_rewards_record = [0]

    def _get_batch(self):
        # Sample transitions uniformly from the replay buffer.
        # With the default batch_size=1 this returns a single transition.
        state_batch, action_batch, reward_batch, next_state_batch = [], [], [], []
        idxs = np.random.choice(len(self.replay_buffer), self.batch_size)
        for idx in idxs:
            s, a, r, next_s = self.replay_buffer[idx]
            state_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_state_batch.append(next_s)
        return np.array(state_batch[0]), np.array(action_batch[0]), np.array(reward_batch[0]), np.array(next_state_batch[0])

    def train_episode(self, max_iteration):
        state = self.env.reset()
        done = False
        while not done:
            action = self.agent.epsilon_greedy_action(state)
            next_s, reward, done, _ = self.env.step(action)
            self.replay_buffer.append((state, action, reward, next_s))
            if len(self.replay_buffer) >= self.batch_size:
                state_batch, action_batch, reward_batch, next_state_batch = self._get_batch()
                self.agent.update_eval_net(state_batch, reward_batch, action_batch, next_state_batch)
                if self.iteration_n % self.copy_times == 0:
                    self.agent.copy_to_target_net()

            self.rewards_record.append(reward)
            self.total_rewards_record.append(reward + self.total_rewards_record[-1])
            state = next_s
            # Linearly anneal epsilon from eps_start down to eps_end over eps_steps iterations.
            epsilon = max(self.eps_end,
                          self.eps_start - (self.eps_start - self.eps_end) * self.iteration_n / self.eps_steps)
            self.agent.set_epsilon(epsilon)
            self.iteration_n += 1
            if self.iteration_n > max_iteration:
                break

    def train(self, max_iteration=1e5, save_pth='/ai/51/dixinhan/demo2/net/iccv1/checkpoints/single1/model.pth'):
        self.iteration_n = 1
        self.rewards_record = []
        self.total_rewards_record = [0]
        while self.iteration_n < max_iteration:
            self.train_episode(max_iteration)
        self.agent.save_model(save_pth)
        return self.rewards_record, self.total_rewards_record

    def valid(self, max_iteration=1e4, load_pth='/ai/51/dixinhan/demo2/net/iccv1/checkpoints/single1/model.pth'):
        self.iteration_n = 1
        self.rewards_record = []
        self.total_rewards_record = [0]
        self.agent.load_model(load_pth)
        while self.iteration_n <= max_iteration:
            state = self.env.reset()
            done = False
            while not done:
                action = self.agent.greedy_action(state)
                next_s, reward, done, _ = self.env.step(action)
                self.rewards_record.append(reward)
                self.total_rewards_record.append(reward + self.total_rewards_record[-1])
                state = next_s
                self.iteration_n += 1
                if self.iteration_n > max_iteration:
                    break

        return self.rewards_record, self.total_rewards_record
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FC(nn.Module):
    """
    3-layer fully connected network.
    """
    def __init__(self, input_dim, output_dim):
        super(FC, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def forward(self, x):
        out = self.fc(x)
        return out


class Duel_FC(nn.Module):
    """
    Two-branch 3-layer fully connected network for the dueling architecture.
    """
    def __init__(self, input_dim, output_dim):
        super(Duel_FC, self).__init__()
        self.value_net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.adv_net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def forward(self, x):
        value = self.value_net(x)
        adv = self.adv_net(x)
        # Dueling combination: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
        out = value + (adv - adv.mean(dim=1, keepdim=True))
        return out


class Conv_2D(nn.Module):
    """
    4-layer network: two 2D convolutional layers followed by two fully connected layers.
    """
    def __init__(self, input_shape, output_dim):
        super(Conv_2D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2),
            nn.ReLU(),
        )
        conv_out_dim = self._get_conv_out_dim(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_dim, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def _get_conv_out_dim(self, input_shape):
        # Run a dummy forward pass to infer the flattened size of the convolutional features.
        out = self.conv(torch.zeros(1, *input_shape))
        return int(np.prod(out.shape))

    def forward(self, x):
        batch_size = x.shape[0]
        conv_out = self.conv(x).view(batch_size, -1)
        out = self.fc(conv_out)
        return out


class Duel_Conv_2D(nn.Module):
    """
    Two-branch 2D convolutional network for the dueling architecture;
    the three 2D convolutional layers are shared by the value and advantage branches.
    """
    def __init__(self, input_shape, output_dim):
        super(Duel_Conv_2D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2),
            nn.ReLU(),
        )
        conv_out_dim = self._get_conv_out_dim(input_shape)
        self.value_net = nn.Sequential(
            nn.Linear(conv_out_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.adv_net = nn.Sequential(
            nn.Linear(conv_out_dim, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def _get_conv_out_dim(self, input_shape):
        # Run a dummy forward pass to infer the flattened size of the convolutional features.
        out = self.conv(torch.zeros(1, *input_shape))
        return int(np.prod(out.shape))

    def forward(self, x):
        batch_size = x.shape[0]
        conv_out = self.conv(x).view(batch_size, -1)
        value = self.value_net(conv_out)
        adv = self.adv_net(conv_out)
        # Dueling combination: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
        out = value + (adv - adv.mean(dim=1, keepdim=True))
        return out
--------------------------------------------------------------------------------
/code/run_scale.py:
--------------------------------------------------------------------------------
from net.iccv1.stock.dqn_algo_scale import *
import matplotlib.pyplot as plt

root_path = "/ai/51/dixinhan/layout3/bedroom1/srl1_scale_r3/"
json_path = "/ai/51/dixinhan/layout3/bedroom1/27705_4.json"


def plot_rewards(rewards, total_rewards, xlabel='iteration'):
    # Plot the per-step reward (left) and the cumulative reward (right).
    fig = plt.figure(figsize=(15, 5))
    fig.add_subplot(1, 2, 1)
    plt.plot(list(range(len(rewards))), rewards)
    plt.xlabel(xlabel)
    plt.ylabel('reward')
    fig.add_subplot(1, 2, 2)
    plt.plot(list(range(len(rewards))), total_rewards[1:])
    plt.xlabel(xlabel)
    plt.ylabel('total reward')
    plt.show()


# 1. Train a dueling DQN with the 2D convolutional backbone on the bedroom scene.
#1.1
dqn = DQNAlgo(filepath=json_path, netname='conv2d', rootpath=root_path, algoname='duel')
#1.2
rewards11, total_rewards11 = dqn.train(max_iteration=1e5, save_pth='/ai/51/dixinhan/demo2/net/iccv1/checkpoints/single4/model.pth')
#1.3
plot_rewards(rewards11, total_rewards11)

# 2. Validate a pretrained dueling DQN with the fully connected backbone.
#2.1
valid_dqn = DQNAlgo(filepath='./data/16.csv', netname='fc', algoname='duel')
#2.2
rewards12, total_rewards12 = valid_dqn.valid(max_iteration=1e4, load_pth='./modelduelfc.pth')
#2.3
plot_rewards(rewards12, total_rewards12)

# The cells below train and validate the remaining agent variants for comparison.
#%%

dqn2 = DQNAlgo(filepath='./data/15.csv', netname='fc', algoname='')

#%%

rewards21, total_rewards21 = dqn2.train(max_iteration=1e6, save_pth='./modelfc.pth')

#%%

plot_rewards(rewards21, total_rewards21)

#%%

valid_dqn2 = DQNAlgo(filepath='./data/16.csv', netname='fc', algoname='')

#%%

rewards22, total_rewards22 = valid_dqn2.valid(max_iteration=1e6, load_pth='./modelfc.pth')

#%%

plot_rewards(rewards22, total_rewards22)

#%%

dqn3 = DQNAlgo(filepath='./data/15.csv', netname='conv1d', algoname='duel')

#%%

rewards31, total_rewards31 = dqn3.train(max_iteration=1e6, save_pth='./modelduelconv.pth')

#%%

plot_rewards(rewards31, total_rewards31)

#%%

valid_dqn3 = DQNAlgo(filepath='./data/16.csv', netname='conv1d', algoname='duel')

#%%

rewards32, total_rewards32 = valid_dqn3.valid(max_iteration=1e6, load_pth='./modelduelconv.pth')

#%%

plot_rewards(rewards32, total_rewards32)

#%%

dqn4 = DQNAlgo(filepath='./data/15.csv', netname='conv1d', algoname='')

#%%

rewards41, total_rewards41 = dqn4.train(max_iteration=1e6, save_pth='./modelconv.pth')

#%%

plot_rewards(rewards41, total_rewards41)

#%%

valid_dqn4 = DQNAlgo(filepath='./data/16.csv', netname='conv1d', algoname='')

#%%

rewards42, total_rewards42 = valid_dqn4.valid(max_iteration=1e6, load_pth='./modelconv.pth')

#%%

plot_rewards(rewards42, total_rewards42)

#%%

# Compare the validation curves of the four agent variants.
fig = plt.figure(figsize=(15, 5))
fig.add_subplot(1, 2, 1)
plt.plot(list(range(len(rewards12))), rewards12, label='dueling dqn + fc')
plt.plot(list(range(len(rewards22))), rewards22, label='dqn + fc')
plt.plot(list(range(len(rewards32))), rewards32, label='dueling dqn + conv')
plt.plot(list(range(len(rewards42))), rewards42, label='dqn + conv')
plt.xlabel('validation iteration')
plt.ylabel('reward')
fig.add_subplot(1, 2, 2)
plt.plot(list(range(len(rewards12))), total_rewards12[1:], label='dueling dqn + fc')
plt.plot(list(range(len(rewards22))), total_rewards22[1:], label='dqn + fc')
plt.plot(list(range(len(rewards32))), total_rewards32[1:], label='dueling dqn + conv')
plt.plot(list(range(len(rewards42))), total_rewards42[1:], label='dqn + conv')
plt.xlabel('validation iteration')
plt.ylabel('total reward')
plt.legend()
plt.show()
--------------------------------------------------------------------------------
/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CODE-SUBMIT/simulator1/8c98522ee275747258f55ff001d38ca791fed56c/fig1.jpg
--------------------------------------------------------------------------------
/fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CODE-SUBMIT/simulator1/8c98522ee275747258f55ff001d38ca791fed56c/fig2.jpg
--------------------------------------------------------------------------------
/fig3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CODE-SUBMIT/simulator1/8c98522ee275747258f55ff001d38ca791fed56c/fig3.jpg
--------------------------------------------------------------------------------
/fig4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CODE-SUBMIT/simulator1/8c98522ee275747258f55ff001d38ca791fed56c/fig4.jpg
--------------------------------------------------------------------------------