├── README.md
├── features.py
├── image_helper.py
├── main.py
├── metrics.py
├── parse_xml_annotations.py
├── reinforcement.py
├── test.ipynb
└── train.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch_Deep_RL_1
2 | A PyTorch implementation of [Hierarchical Object Detection with Deep Reinforcement Learning](https://arxiv.org/abs/1611.03718)
3 |
--------------------------------------------------------------------------------
/features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torch.autograd import Variable
5 |
6 |
7 | transform = transforms.Compose([
8 |     transforms.ToPILImage(),
9 |     transforms.Scale((224,224)),  # Scale is called Resize in newer torchvision
10 |     transforms.ToTensor(),
11 |     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))  # TODO: VGG-16 was trained with ImageNet statistics (mean (0.485, 0.456, 0.406), std (0.229, 0.224, 0.225))
12 | ])
13 |
14 |
15 | def getVGG_16bn(path_vgg):
16 |     # if the pre-trained VGG-16 model is not in path_vgg, download it from the URL below
17 |     state_dict = torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',path_vgg)
18 |     model = torchvision.models.vgg16_bn()
19 |     model.load_state_dict(state_dict)
20 |     # keep only the convolutional part and ignore the classifier
21 |     model_2 = list(model.children())[0]
22 |     return model_2
23 |
24 | # dtype determines whether the features are computed on CPU or GPU
25 | def get_conv_feature_for_image(image, model, dtype=torch.cuda.FloatTensor):
26 |     im = transform(image)
27 |     im = im.view(1,*im.shape)
28 |     feature = model(Variable(im).type(dtype))
29 |     return feature.data
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/image_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import ImageDraw
3 | from PIL import Image
4 |
5 |
6 | def load_images_labels_in_data_set(data_set_name, path_voc):
7 |     file_path = path_voc + '/ImageSets/Main/' + data_set_name + '.txt'
8 |     f = open(file_path)
9 |     images_names = f.readlines()
10 |     images_names = [x.split(None, 1)[1] for x in images_names]
11 |     images_names = [x.strip('\n').strip(None) for x in images_names]
12 |     return images_names
13 |
14 |
15 | def load_image_data(path_voc, class_object):
16 |     print("load images " + path_voc)
17 |     image_names = np.array(load_images_names_in_data_set('aeroplane_trainval', path_voc))
18 |     labels = load_images_labels_in_data_set('aeroplane_trainval', path_voc)
19 |     image_names_class = []
20 |     for i in range(len(image_names)):
21 |         if labels[i] == class_object:
22 |             image_names_class.append(image_names[i])
23 |     image_names = image_names_class
24 |     images = get_all_images(image_names, path_voc)
25 |     print("total images: %d" % len(image_names))
26 |     return image_names, images
27 |
28 |
29 | def load_images_names_in_data_set(data_set_name, path_voc):
30 |     file_path = path_voc + '/ImageSets/Main/' + data_set_name + '.txt'
31 |     f = open(file_path)
32 |     image_names = f.readlines()
33 |     image_names = [x.strip('\n') for x in image_names]
34 |     return [x.split(None, 1)[0] for x in image_names]
35 |
36 |
37 | def get_all_images(image_names, path_voc):
38 |     images = []
39 |     for j in range(np.size(image_names)):
40 |         image_name = image_names[j]
41 |         string = path_voc + '/JPEGImages/' + image_name + '.jpg'
42 |         img = Image.open(string)
43 |         images.append(np.array(img))
44 |     return images
--------------------------------------------------------------------------------
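A minimal sketch of how the helpers above fit together (the VOC path is an assumption taken from the training scripts below; in the VOC ImageSets files the label '1' marks images that contain the class):

    # Sketch: load the aeroplane subset of VOC2007 trainval (paths are assumptions)
    from image_helper import load_image_data

    path_voc = "../datas/VOCdevkit/VOC2007"                # assumed location of the VOC devkit
    image_names, images = load_image_data(path_voc, '1')   # keeps only images labelled '1' (aeroplane present)
    print(len(image_names), images[0].shape)               # number of kept images and the first image's H x W x C shape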
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.optim as optim
4 | from torch.autograd import Variable
5 | from image_helper import *
6 | from parse_xml_annotations import *
7 | from features import *
8 | from reinforcement import *
9 | from metrics import *
10 | from collections import namedtuple
11 | import time
12 | import os
13 | import numpy as np
14 | import random
15 |
16 |
17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
18 | path_voc = "../datas/VOCdevkit/VOC2007"
19 |
20 | # load the models
21 | print("load models")
22 |
23 | model_vgg = getVGG_16bn("../models")
24 | model_vgg = model_vgg.cuda()
25 | model = get_q_network()
26 | model = model.cuda()
27 |
28 | # define the optimizer and the loss for the Q-network
29 | optimizer = optim.Adam(model.parameters(),lr=1e-6)
30 | criterion = nn.MSELoss().cuda()
31 |
32 | # load the image data; class '1' corresponds to aeroplane
33 | path_voc_1 = "../datas/VOCdevkit/VOC2007"
34 | path_voc_2 = "../datas/VOCdevkit/VOC2012"
35 | class_object = '1'
36 | image_names_1, images_1 = load_image_data(path_voc_1, class_object)
37 | image_names_2, images_2 = load_image_data(path_voc_2, class_object)
38 | image_names = image_names_1 + image_names_2
39 | images = images_1 + images_2
40 |
41 | print("aeroplane_trainval images: %d" % len(image_names))
42 |
43 | # define the PyTorch tensor types (GPU if available)
44 | use_cuda = torch.cuda.is_available()
45 | FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
46 | LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
47 | ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
48 | Tensor = FloatTensor
49 |
50 | # define the hyperparameters
51 | epsilon = 1.0
52 | BATCH_SIZE = 100
53 | GAMMA = 0.90
54 | CLASS_OBJECT = 1
55 | steps = 10
56 | epochs = 50
57 | memory = ReplayMemory(1000)
58 |
59 | def select_action(state):  # epsilon-greedy over the six actions
60 |     if random.random() < epsilon:
61 |         action = np.random.randint(1,7)
62 |     else:
63 |         qval = model(Variable(state))
64 |         _, predicted = torch.max(qval.data,1)
65 |         action = predicted[0] + 1
66 |     return action
67 |
68 |
69 | Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
70 | def optimizer_model():
71 |     if len(memory) < BATCH_SIZE:
72 |         return
73 |     transitions = memory.sample(BATCH_SIZE)
74 |     batch = Transition(*zip(*transitions))
75 |
76 |     non_final_mask = ByteTensor(tuple(map(lambda s: s is not None, batch.next_state)))
77 |     next_states = [s for s in batch.next_state if s is not None]
78 |     non_final_next_states = Variable(torch.cat(next_states),
79 |                                      volatile=True).type(Tensor)
80 |     state_batch = Variable(torch.cat(batch.state)).type(Tensor)
81 |     action_batch = Variable(torch.LongTensor(batch.action).view(-1,1)).type(LongTensor)
82 |     reward_batch = Variable(torch.FloatTensor(batch.reward).view(-1,1)).type(Tensor)
83 |
84 |     # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
85 |     # columns of actions taken
86 |     state_action_values = model(state_batch).gather(1, action_batch)
87 |
88 |     # Compute V(s_{t+1}) for all next states.
89 |     next_state_values = Variable(torch.zeros(BATCH_SIZE, 1).type(Tensor))
90 |     next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
91 |
92 |     # Now, we don't want to mess up the loss with a volatile flag, so let's
93 |     # clear it. After this, we'll just end up with a Variable that has
94 |     # requires_grad=False
95 |     next_state_values.volatile = False
96 |
97 |     # Compute the expected Q values
98 |     expected_state_action_values = (next_state_values * GAMMA) + reward_batch
99 |
100 |     # Compute loss
101 |     loss = criterion(state_action_values, expected_state_action_values)
102 |
103 |     # Optimize the model
104 |     optimizer.zero_grad()
105 |     loss.backward()
106 |     optimizer.step()
107 |
108 | # training procedure
109 | print('train the Q-network')
110 | for epoch in range(epochs):
111 |     print('epoch: %d' % epoch)
112 |     now = time.time()
113 |     for i in range(len(image_names)):
114 |         # the image part
115 |         image_name = image_names[i]
116 |         image = images[i]
117 |         if i < len(image_names_1):
118 |             annotation = get_bb_of_gt_from_pascal_xml_annotation(image_name, path_voc_1)
119 |         else:
120 |             annotation = get_bb_of_gt_from_pascal_xml_annotation(image_name, path_voc_2)
121 |         classes_gt_objects = get_ids_objects_from_annotation(annotation)
122 |         gt_masks = generate_bounding_box_from_annotation(annotation, image.shape)
123 |
124 |         # the IoU part
125 |         original_shape = (image.shape[0], image.shape[1])
126 |         region_mask = np.ones((image.shape[0], image.shape[1]))
127 |         # choose the ground-truth bounding box with the maximum IoU
128 |         iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)
129 |
130 |         # the initial part
131 |         region_image = image
132 |         size_mask = original_shape
133 |         offset = (0, 0)
134 |         history_vector = torch.zeros((4,6))
135 |         state = get_state(region_image, history_vector, model_vgg)
136 |         done = False
137 |         for step in range(steps):
138 |             # Select an action; following the paper, the terminal action is forced once the current IoU is above 0.5
139 |             if iou > 0.5:
140 |                 action = 6
141 |             else:
142 |                 action = select_action(state)
143 |
144 |             # Perform the action and observe the new state
145 |             if action == 6:
146 |                 next_state = None
147 |                 reward = get_reward_trigger(iou)
148 |                 done = True
149 |             else:
150 |                 offset, region_image, size_mask, region_mask = get_crop_image_and_mask(original_shape, offset,
151 |                                                                                        region_image, size_mask, action)
152 |                 # update the history vector and get the next state
153 |                 history_vector = update_history_vector(history_vector, action)
154 |                 next_state = get_state(region_image, history_vector, model_vgg)
155 |
156 |                 # find the max-IoU ground-truth box for the new region
157 |                 new_iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)
158 |                 reward = get_reward_movement(iou, new_iou)
159 |                 iou = new_iou
160 |             print('epoch: %d, image: %d, step: %d, reward: %d' % (epoch, i, step, reward))
161 |             # Store the transition in memory
162 |             memory.push(state, action-1, next_state, reward)
163 |
164 |             # Move to the next state
165 |             state = next_state
166 |
167 |             # Perform one optimization step on the Q-network (there is no separate target network here)
168 |             optimizer_model()
169 |             if done:
170 |                 break
171 |     if epsilon > 0.1:
172 |         epsilon -= 0.1
173 |     time_cost = time.time() - now
174 |     print('epoch = %d, time_cost = %.4f' % (epoch, time_cost))
175 |
176 | # save the whole model
177 | Q_NETWORK_PATH = '../models/' + 'voc2012_2007_model'
178 | torch.save(model, Q_NETWORK_PATH)
179 | print('Complete')
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | scale_subregion = float(3) / 4
5 | scale_mask = float(1)/(scale_subregion*4)
6 | def calculate_iou(img_mask, gt_mask):
7 |     gt_mask *= 1.0
8 |     img_and = cv2.bitwise_and(img_mask, gt_mask)
9 |     img_or = cv2.bitwise_or(img_mask, gt_mask)
10 |     j = np.count_nonzero(img_and)
11 |     i = np.count_nonzero(img_or)
12 |     iou = float(float(j)/float(i))
13 |     return iou
14 |
15 |
16 | def calculate_overlapping(img_mask, gt_mask):
17 |     gt_mask *= 1.0
18 |     img_and = cv2.bitwise_and(img_mask, gt_mask)
19 |     j = np.count_nonzero(img_and)
20 |     i = np.count_nonzero(gt_mask)
21 |     overlap = float(float(j)/float(i))
22 |     return overlap
23 |
24 |
25 | def follow_iou(gt_masks, region_mask, classes_gt_objects, class_object, last_matrix):
26 |     results = np.zeros([np.size(classes_gt_objects), 1])
27 |     for k in range(np.size(classes_gt_objects)):
28 |         if classes_gt_objects[k] == class_object:
29 |             gt_mask = gt_masks[:, :, k]
30 |             iou = calculate_iou(region_mask, gt_mask)
31 |             results[k] = iou
32 |     index = np.argmax(results)
33 |     new_iou = results[index]
34 |     iou = last_matrix[index]
35 |     return iou, new_iou, results, index
36 |
37 | # Find the ground-truth box of the target class with the maximum IoU against the region mask
38 | def find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, class_object):
39 |     _, _, n = gt_masks.shape
40 |     max_iou = 0.0
41 |     for k in range(n):
42 |         if classes_gt_objects[k] != class_object:
43 |             continue
44 |         gt_mask = gt_masks[:,:,k]
45 |         iou = calculate_iou(region_mask, gt_mask)
46 |         if max_iou < iou:
47 |             max_iou = iou
48 |     return max_iou
49 |
50 | def get_crop_image_and_mask(original_shape, offset, region_image, size_mask, action):
51 |     r"""Crop the image according to the given action.
52 |
53 |     Args:
54 |         original_shape: shape of the original image (H x W)
55 |         offset: top-left coordinate of the current region within the original image
56 |         region_image: the image to be cropped
57 |         size_mask: the size of region_image
58 |         action: the movement action chosen by the agent; one of 1-5 (top-left, top-right, bottom-left, bottom-right, center)
59 |
60 |     Returns:
61 |         offset: top-left coordinate of the cropped region within the original image
62 |         region_image: the cropped image
63 |         size_mask: the size of the cropped image
64 |         region_mask: a mask of the cropped region with the same size as the original image
65 |
66 |     """
67 |
68 |
69 |     region_mask = np.zeros(original_shape)  # mask in original-image coordinates
70 |     size_mask = (int(size_mask[0] * scale_subregion), int(size_mask[1] * scale_subregion))  # the size of the cropped image
71 |     if action == 1:
72 |         offset_aux = (0, 0)
73 |     elif action == 2:
74 |         offset_aux = (0, int(size_mask[1] * scale_mask))
75 |         offset = (offset[0], offset[1] + int(size_mask[1] * scale_mask))
76 |     elif action == 3:
77 |         offset_aux = (int(size_mask[0] * scale_mask), 0)
78 |         offset = (offset[0] + int(size_mask[0] * scale_mask), offset[1])
79 |     elif action == 4:
80 |         offset_aux = (int(size_mask[0] * scale_mask),
81 |                       int(size_mask[1] * scale_mask))
82 |         offset = (offset[0] + int(size_mask[0] * scale_mask),
83 |                   offset[1] + int(size_mask[1] * scale_mask))
84 |     elif action == 5:
85 |         offset_aux = (int(size_mask[0] * scale_mask / 2),
86 |                       int(size_mask[1] * scale_mask / 2))  # use the width for the column offset (was size_mask[0])
87 |         offset = (offset[0] + int(size_mask[0] * scale_mask / 2),
88 |                   offset[1] + int(size_mask[1] * scale_mask / 2))
89 |     region_image = region_image[offset_aux[0]:offset_aux[0] + size_mask[0],
90 |                                 offset_aux[1]:offset_aux[1] + size_mask[1]]
91 |     region_mask[offset[0]:offset[0] + size_mask[0], offset[1]:offset[1] + size_mask[1]] = 1
92 |     return offset, region_image, size_mask, region_mask
93 |
--------------------------------------------------------------------------------
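To make the crop geometry concrete, a small worked example (the 240x240 region size is hypothetical): with scale_subregion = 3/4 and scale_mask = 1/3, each movement action keeps a window of 3/4 the current size and shifts it by 1/3 of the new size, i.e. a 180x180 window shifted by 60 pixels inside a 240x240 region.

    # Sketch: where each movement action places the sub-window inside a 240x240 region
    import numpy as np
    from metrics import get_crop_image_and_mask

    region = np.zeros((240, 240, 3), dtype=np.uint8)       # hypothetical region image
    for action in range(1, 6):
        offset, sub, size, mask = get_crop_image_and_mask((240, 240), (0, 0), region, (240, 240), action)
        print(action, offset, size)
    # expected offsets: 1 -> (0, 0), 2 -> (0, 60), 3 -> (60, 0), 4 -> (60, 60), 5 -> (30, 30); size is always (180, 180)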
/parse_xml_annotations.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree as ET
2 | import numpy as np
3 |
4 |
5 | def get_bb_of_gt_from_pascal_xml_annotation(xml_name, voc_path):
6 |     string = voc_path + '/Annotations/' + xml_name + '.xml'
7 |     tree = ET.parse(string)
8 |     root = tree.getroot()
9 |     names = []
10 |     x_min = []
11 |     x_max = []
12 |     y_min = []
13 |     y_max = []
14 |     for child in root:
15 |         if child.tag == 'object':
16 |             for child2 in child:
17 |                 if child2.tag == 'name':
18 |                     names.append(child2.text)
19 |                 elif child2.tag == 'bndbox':
20 |                     for child3 in child2:
21 |                         if child3.tag == 'xmin':
22 |                             x_min.append(child3.text)
23 |                         elif child3.tag == 'xmax':
24 |                             x_max.append(child3.text)
25 |                         elif child3.tag == 'ymin':
26 |                             y_min.append(child3.text)
27 |                         elif child3.tag == 'ymax':
28 |                             y_max.append(child3.text)
29 |     category_and_bb = np.zeros([np.size(names), 5])
30 |     for i in range(np.size(names)):
31 |         category_and_bb[i][0] = get_id_of_class_name(names[i])
32 |         category_and_bb[i][1] = x_min[i]
33 |         category_and_bb[i][2] = x_max[i]
34 |         category_and_bb[i][3] = y_min[i]
35 |         category_and_bb[i][4] = y_max[i]
36 |     return category_and_bb
37 |
38 |
39 | def get_all_annotations(image_names, voc_path):
40 |     annotations = []
41 |     for i in range(np.size(image_names)):
42 |         image_name = image_names[i]
43 |         annotations.append(get_bb_of_gt_from_pascal_xml_annotation(image_name, voc_path))
44 |     return annotations
45 |
46 |
47 | def generate_bounding_box_from_annotation(annotation, image_shape):
48 |     length_annotation = annotation.shape[0]
49 |     masks = np.zeros([image_shape[0], image_shape[1], length_annotation])
50 |     for i in range(0, length_annotation):
51 |         masks[int(annotation[i, 3]):int(annotation[i, 4]), int(annotation[i, 1]):int(annotation[i, 2]), i] = 1
52 |     return masks
53 |
54 |
55 | def get_ids_objects_from_annotation(annotation):
56 |     return annotation[:, 0]
57 |
58 |
59 | def get_id_of_class_name(class_name):
60 |     if class_name == 'aeroplane':
61 |         return 1
62 |     elif class_name == 'bicycle':
63 |         return 2
64 |     elif class_name == 'bird':
65 |         return 3
66 |     elif class_name == 'boat':
67 |         return 4
68 |     elif class_name == 'bottle':
69 |         return 5
70 |     elif class_name == 'bus':
71 |         return 6
72 |     elif class_name == 'car':
73 |         return 7
74 |     elif class_name == 'cat':
75 |         return 8
76 |     elif class_name == 'chair':
77 |         return 9
78 |     elif class_name == 'cow':
79 |         return 10
80 |     elif class_name == 'diningtable':
81 |         return 11
82 |     elif class_name == 'dog':
83 |         return 12
84 |     elif class_name == 'horse':
85 |         return 13
86 |     elif class_name == 'motorbike':
87 |         return 14
88 |     elif class_name == 'person':
89 |         return 15
90 |     elif class_name == 'pottedplant':
91 |         return 16
92 |     elif class_name == 'sheep':
93 |         return 17
94 |     elif class_name == 'sofa':
95 |         return 18
96 |     elif class_name == 'train':
97 |         return 19
98 |     elif class_name == 'tvmonitor':
99 |         return 20
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/reinforcement.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn.init as init
4 | import torch.nn as nn
5 | from collections import namedtuple
6 | from features import *
7 | import random  # needed by ReplayMemory.sample
8 |
9 | # Different actions that the agent can do
10 | number_of_actions = 6
11 | # Actions captured in the history vector
12 | actions_of_history = 4
13 | # Visual descriptor size
14 | visual_descriptor_size = 25088
15 | # Reward for a movement action
16 | reward_movement_action = 1
17 | # Reward for the terminal (trigger) action
18 | reward_terminal_action = 3
19 | # IoU required to consider a detection positive
20 | iou_threshold = 0.5
21 |
22 |
23 | Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
24 | class ReplayMemory(object):
25 |
26 |     def __init__(self, capacity):
27 |         self.capacity = capacity
28 |         self.memory = []
29 |         self.position = 0
30 |
31 |     def push(self, *args):
32 |         """Saves a transition."""
33 |         if len(self.memory) < self.capacity:
34 |             self.memory.append(None)
35 |         self.memory[self.position] = Transition(*args)
36 |         self.position = (self.position + 1) % self.capacity
37 |
38 |     def sample(self, batch_size):
39 |         return random.sample(self.memory, batch_size)
40 |
41 |     def __len__(self):
42 |         return len(self.memory)
43 |
44 |
45 | def get_state(image, history_vector, model_vgg, dtype=torch.cuda.FloatTensor):
46 |     image_feature = get_conv_feature_for_image(image, model_vgg, dtype)
47 |     image_feature = image_feature.view(1,-1)
48 |     history_vector_flatten = history_vector.view(1,-1).type(dtype)
49 |     state = torch.cat((image_feature, history_vector_flatten), 1)
50 |     return state
51 |
52 |
53 | # FIFO history of the last 4 actions, one-hot encoded
54 | def update_history_vector(history_vector, action):
55 |     action_vector = torch.zeros(number_of_actions)
56 |     action_vector[action-1] = 1
57 |     size_history_vector = len(torch.nonzero(history_vector))
58 |     if size_history_vector < actions_of_history:
59 |         history_vector[size_history_vector][action-1] = 1
60 |     else:
61 |         for i in range(actions_of_history-1,0,-1):
62 |             history_vector[i][:] = history_vector[i-1][:]
63 |         history_vector[0][:] = action_vector[:]
64 |     return history_vector
65 |
66 |
67 | def get_q_network(weights_path="0"):
68 |     model = nn.Sequential(
69 |         nn.Linear(25112, 1024),
70 |         nn.ReLU(),
71 |         nn.Dropout(0.2),
72 |         nn.Linear(1024, 1024),
73 |         nn.ReLU(),
74 |         nn.Dropout(0.2),
75 |         nn.Linear(1024, 6),
76 |     )
77 |     # init the weights with xavier_normal; this may differ from the author's implementation
78 |     def weights_init(m):
79 |         if isinstance(m, nn.Linear):
80 |             init.xavier_normal(m.weight.data)
81 |     model.apply(weights_init)
82 |     if weights_path != "0":
83 |         model.load_state_dict(torch.load(weights_path))  # loaded weights override the random init
84 |     return model
85 |
86 |
87 | def get_reward_movement(iou, new_iou):
88 |     if new_iou > iou:
89 |         reward = reward_movement_action
90 |     else:
91 |         reward = - reward_movement_action
92 |     return reward
93 |
94 |
95 | def get_reward_trigger(new_iou):
96 |     if new_iou > iou_threshold:
97 |         reward = reward_terminal_action
98 |     else:
99 |         reward = - reward_terminal_action
100 |     return reward
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "import torchvision.datasets as datasets\n", 17 | "import torch.utils.data as data\n", 18 | "import torch.optim as optim\n", 19 | "from torch.autograd import Variable\n", 20 | "import random\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "%matplotlib inline\n", 23 | "from image_helper import *\n", 24 | "from parse_xml_annotations import *\n", 25 | "from features import *\n", 26 | "from reinforcement import *\n", 27 | "from metrics import *\n", 28 | "import logging\n", 29 | "import time\n", 30 | "import os\n", 31 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Load Image and Model" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "collapsed": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "path_voc = \"../datas/VOCdevkit/VOC2007\"\n", 50 | "\n", 51 | "# get models \n", 52 | "print(\"load models\")\n", 53 | "\n", 54 | "model_vgg = getVGG_16bn(\"../models\")\n", 55 | "model_vgg = model_vgg.cuda()\n", 56 | "model = get_q_network()\n", 57 | "model = model.cuda()\n", 58 | "\n", 59 | "# define optimizers for each model\n", 60 | "optimizer = optim.Adam(model.parameters(),lr=1e-6)\n", 61 | "criterion = nn.MSELoss().cuda() \n", 62 | "\n", 63 | "# get image datas\n", 64 | "print(\"load images\")\n", 65 | "\n", 66 | "path_voc = \"../datas/VOCdevkit/VOC2007\"\n", 67 | "image_names = np.array(load_images_names_in_data_set('aeroplane_trainval', path_voc))\n", 68 | "labels = load_images_labels_in_data_set('aeroplane_trainval', path_voc)\n", 69 | "image_names_aero = []\n", 70 | "for i in range(len(image_names)):\n", 71 | " if labels[i] == '1':\n", 72 | " image_names_aero.append(image_names[i])\n", 73 | "image_names = image_names_aero\n", 74 | "images = get_all_images(image_names, path_voc)\n", 75 | "\n", 76 | "print(\"aeroplane_trainval image:%d\" % len(image_names))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### the replay part should be added in replay.py" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata":
{ 90 | "collapsed": true 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "from collections import namedtuple\n", 95 | "Transition = namedtuple('Transition',\n", 96 | " ('state', 'action', 'next_state', 'reward'))\n", 97 | "\n", 98 | "use_cuda = torch.cuda.is_available()\n", 99 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 100 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 101 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 102 | "Tensor = FloatTensor\n", 103 | "class ReplayMemory(object):\n", 104 | "\n", 105 | " def __init__(self, capacity):\n", 106 | " self.capacity = capacity\n", 107 | " self.memory = []\n", 108 | " self.position = 0\n", 109 | "\n", 110 | " def push(self, *args):\n", 111 | " \"\"\"Saves a transition.\"\"\"\n", 112 | " if len(self.memory) < self.capacity:\n", 113 | " self.memory.append(None)\n", 114 | " self.memory[self.position] = Transition(*args)\n", 115 | " self.position = (self.position + 1) % self.capacity\n", 116 | "\n", 117 | " def sample(self, batch_size):\n", 118 | " return random.sample(self.memory, batch_size)\n", 119 | "\n", 120 | " def __len__(self):\n", 121 | " return len(self.memory)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": { 128 | "collapsed": true 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "epsilon = 0.9\n", 133 | "BATCH_SIZE = 100\n", 134 | "GAMMA = 0.90\n", 135 | "CLASS_OBJECT = 1\n", 136 | "steps = 10\n", 137 | "epochs = 25\n", 138 | "memory = ReplayMemory(1000)\n", 139 | "\n", 140 | "def select_action(state):\n", 141 | " if random.random() < epsilon:\n", 142 | " action = np.random.randint(1,7)\n", 143 | " else:\n", 144 | " qval = model(Variable(state))\n", 145 | " _, predicted = torch.max(qval.data,1)\n", 146 | " action = predicted[0] + 1\n", 147 | " return action\n", 148 | "\n", 149 | "def optimizer_model():\n", 150 | " if len(memory) < BATCH_SIZE:\n", 151 | " return\n", 152 | " transitions = memory.sample(BATCH_SIZE)\n", 153 | " batch = Transition(*zip(*transitions))\n", 154 | " \n", 155 | " non_final_mask = ByteTensor(tuple(map(lambda s: s is not None, batch.next_state)))\n", 156 | " next_states = [s for s in batch.next_state if s is not None]\n", 157 | " non_final_next_states = Variable(torch.cat(next_states), \n", 158 | " volatile=True).type(Tensor)\n", 159 | " state_batch = Variable(torch.cat(batch.state)).type(Tensor)\n", 160 | " action_batch = Variable(torch.LongTensor(batch.action).view(-1,1)).type(LongTensor)\n", 161 | " reward_batch = Variable(torch.FloatTensor(batch.reward).view(-1,1)).type(Tensor)\n", 162 | "\n", 163 | " # Compute Q(s_t, a) - the model computes Q(s_t), then we select the\n", 164 | " # columns of actions taken\n", 165 | " state_action_values = model(state_batch).gather(1, action_batch)\n", 166 | " \n", 167 | " # Compute V(s_{t+1}) for all next states.\n", 168 | " next_state_values = Variable(torch.zeros(BATCH_SIZE, 1).type(Tensor)) \n", 169 | " next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]\n", 170 | " \n", 171 | " # Now, we don't want to mess up the loss with a volatile flag, so let's\n", 172 | " # clear it. 
After this, we'll just end up with a Variable that has\n", 173 | " # requires_grad=False\n", 174 | " next_state_values.volatile = False\n", 175 | " \n", 176 | " # Compute the expected Q values\n", 177 | " expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n", 178 | " \n", 179 | " # Compute loss\n", 180 | " loss = criterion(state_action_values, expected_state_action_values)\n", 181 | "\n", 182 | " # Optimize the model\n", 183 | " optimizer.zero_grad()\n", 184 | " loss.backward()\n", 185 | " optimizer.step()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "# Train the model" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": { 199 | "collapsed": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "# train procedure\n", 204 | "print('train the Q-network')\n", 205 | "for epoch in range(epochs):\n", 206 | " print('epoch: %d' %epoch)\n", 207 | " now = time.time()\n", 208 | " for i in range(len(image_names)):\n", 209 | " # the image part\n", 210 | " image_name = image_names[i]\n", 211 | " image = images[i]\n", 212 | " annotation = get_bb_of_gt_from_pascal_xml_annotation(image_name, path_voc)\n", 213 | " classes_gt_objects = get_ids_objects_from_annotation(annotation)\n", 214 | " gt_masks = generate_bounding_box_from_annotation(annotation, image.shape) \n", 215 | " \n", 216 | " # the iou part\n", 217 | " original_shape = (image.shape[0], image.shape[1])\n", 218 | " region_mask = np.ones((image.shape[0], image.shape[1]))\n", 219 | " #choose the max bouding box\n", 220 | " iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)\n", 221 | " \n", 222 | " # the initial part\n", 223 | " region_image = image\n", 224 | " size_mask = original_shape\n", 225 | " offset = (0, 0)\n", 226 | " history_vector = torch.zeros((4,6))\n", 227 | " state = get_state(region_image, history_vector, model_vgg)\n", 228 | " done = False\n", 229 | "\n", 230 | " for step in range(steps):\n", 231 | "\n", 232 | " # Select action, the author force terminal action if case actual IoU is higher than 0.5\n", 233 | " if iou > 0.5:\n", 234 | " action = 6\n", 235 | " else:\n", 236 | " action = select_action(state)\n", 237 | " \n", 238 | " # Perform the action and observe new state\n", 239 | " if action == 6:\n", 240 | " next_state = None\n", 241 | " reward = get_reward_trigger(iou)\n", 242 | " done = True\n", 243 | " else:\n", 244 | " offset, region_image, size_mask, region_mask = get_crop_image_and_mask(original_shape, offset,\n", 245 | " region_image, size_mask, action)\n", 246 | " # update history vector and get next state\n", 247 | " history_vector = update_history_vector(history_vector, action)\n", 248 | " next_state = get_state(region_image, history_vector, model_vgg)\n", 249 | " \n", 250 | " # find the max bounding box in the region image\n", 251 | " new_iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)\n", 252 | " reward = get_reward_movement(iou, new_iou)\n", 253 | " iou = new_iou\n", 254 | " print('epoch: %d, image: %d, step: %d, reward: %d' %(epoch ,i, step, reward)) \n", 255 | " # Store the transition in memory\n", 256 | " memory.push(state, action-1, next_state, reward)\n", 257 | " \n", 258 | " # Move to the next state\n", 259 | " state = next_state\n", 260 | " \n", 261 | " # Perform one step of the optimization (on the target network)\n", 262 | " optimizer_model()\n", 263 | " if done:\n", 264 | " break\n", 265 | " if epsilon > 
0.1:\n", 266 | " epsilon -= 0.1\n", 267 | " time_cost = time.time() - now\n", 268 | " print('epoch = %d, time_cost = %.4f' %(epoch, time_cost))\n" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "# Save the Q-Model" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "# save the whole model\n", 285 | "Q_NETWORK_PATH = '../models/' + 'one_object_model_2'\n", 286 | "torch.save(model, Q_NETWORK_PATH)\n", 287 | "print('Complete')" 288 | ] 289 | } 290 | ], 291 | "metadata": { 292 | "kernelspec": { 293 | "display_name": "Python 3", 294 | "language": "python", 295 | "name": "python3" 296 | }, 297 | "language_info": { 298 | "codemirror_mode": { 299 | "name": "ipython", 300 | "version": 3 301 | }, 302 | "file_extension": ".py", 303 | "mimetype": "text/x-python", 304 | "name": "python", 305 | "nbconvert_exporter": "python", 306 | "pygments_lexer": "ipython3", 307 | "version": "3.6.2" 308 | } 309 | }, 310 | "nbformat": 4, 311 | "nbformat_minor": 2 312 | } 313 | --------------------------------------------------------------------------------
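The tree above also lists test.ipynb, whose contents are not included in this dump. As a rough sketch only (not the repository's notebook), greedy inference with the saved Q-network could look like the following; the model path and the ten-step limit mirror main.py, everything else is an assumption:

    # Sketch: greedy localization on one image with the trained Q-network (paths are assumptions)
    import torch
    import numpy as np
    from torch.autograd import Variable
    from features import getVGG_16bn
    from reinforcement import get_state, update_history_vector
    from metrics import get_crop_image_and_mask
    from image_helper import load_image_data

    model_vgg = getVGG_16bn("../models").cuda()
    model = torch.load('../models/voc2012_2007_model')      # whole model saved by main.py

    image_names, images = load_image_data("../datas/VOCdevkit/VOC2007", '1')
    image = images[0]
    original_shape = (image.shape[0], image.shape[1])
    region_image, size_mask, offset = image, original_shape, (0, 0)
    history_vector = torch.zeros((4, 6))
    state = get_state(region_image, history_vector, model_vgg)

    for step in range(10):                                   # same step budget as training
        qval = model(Variable(state))
        _, predicted = torch.max(qval.data, 1)
        action = predicted[0] + 1                            # greedy action in 1..6
        if action == 6:                                      # trigger: the current region is the detection
            break
        offset, region_image, size_mask, _ = get_crop_image_and_mask(
            original_shape, offset, region_image, size_mask, action)
        history_vector = update_history_vector(history_vector, action)
        state = get_state(region_image, history_vector, model_vgg)

    print('predicted region: top-left %s, size %s' % (str(offset), str(size_mask)))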