├── .gitattributes
├── README.md
├── download.py
├── SRPN.py
├── video2pic.py
├── data_otb.py
├── test_otb.py
├── train.py
└── axis.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Siamese-RPN (PyTorch Implementation)

This is my PyTorch implementation of the Siamese Region Proposal Network (Siamese-RPN).

Python version: 3.6 (3.4 or 3.5 may also work, though I have not tried them)

PyTorch version: 0.4.0 or higher

NOTE:
This project is still being polished, and I will add more useful comments in the near future.
If you find this work helpful, please click "Watch", "Star", or "Fork". Thanks for your support!

Paper:

    @InProceedings{Li_2018_CVPR,
      author    = {Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin},
      title     = {High Performance Visual Tracking With Siamese Region Proposal Network},
      booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
      month     = {June},
      year      = {2018}
    }
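A minimal usage sketch (my own illustration of the input/output contract; the model class and the 127/255 crop sizes come from SRPN.py and data_otb.py):

    import torch
    from SRPN import SiameseRPN

    model = SiameseRPN()                      # loads ImageNet AlexNet weights into the backbone
    template = torch.randn(1, 3, 127, 127)    # exemplar crop around the target
    detection = torch.randn(1, 3, 255, 255)   # search-region crop
    coutput, routput = model(template, detection)
    # coutput: (1, 2*k, 17, 17) classification scores; routput: (1, 4*k, 17, 17)
    # box regressions, with k = 5 anchors per position. Batch size must be 1.

To train on OTB2015, run train.py (checkpoints are written to ./pth_OTB2015/); to render a tracking video from a checkpoint, run test_otb.py.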
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 11 13:45:30 2018

@author: ZK
"""
import re
import urllib.request

def getHtml(url):
    """Fetch a page and return its decoded HTML."""
    page = urllib.request.urlopen(url)
    html = page.read()
    html = html.decode('utf-8')
    return html

def getMp4(html):
    """Find all .mp4 links in the page and download them as 1.mp4, 2.mp4, ..."""
    r = r"href='(http.*\.mp4)'"
    re_mp4 = re.compile(r)
    mp4List = re.findall(re_mp4, html)
    filename = 1
    for mp4url in mp4List:
        urllib.request.urlretrieve(mp4url, "%s.mp4" % filename)
        print('file "%s.mp4" done' % filename)
        filename += 1

url = 'https://v.youku.com/v_show/id_XMzcxMDc1MjYwMA==.html?spm=a2hww.11359951.m_26659.5~5!2~5~5~5~5~5~A'
#url = "http://youtu.be/AAxYohQXjmY"
html = getHtml(url)
getMp4(html)
#%%
# Drive clipconverter.cc with Selenium to download YouTube BB training videos.
from selenium import webdriver
from time import sleep

with open('./youtube_BB/detection_train.txt') as f:
    a = f.read().split(', ')
browser = webdriver.Chrome()
#%%
for i in range(4, 10000):
    tmp = a[i][1:len(a[i])-1]   # strip the surrounding quotes from the video id
    browser.get('https://www.clipconverter.cc/')
    elem = browser.find_element_by_name('mediaurl')
    elem.clear()
    elem.send_keys('http://youtu.be/' + tmp)

    browser.find_element_by_id('submiturl').click()
#    browser.switch_to_window(browser.window_handles[-1])
#    browser.close()
#    browser.switch_to_window(browser.window_handles[-1])
    sleep(4)
    elem = browser.find_element_by_name('filename')
    elem.clear()
    elem.send_keys(tmp)
    browser.find_element_by_xpath('//*[@id="submitconvert"]/input').click()
    sleep(30)
    browser.find_element_by_xpath('//*[@id="downloadbutton"]').click()

#%%
# Download YouTube BB validation videos directly with pytube.
from pytube import YouTube
import os

if not os.path.exists('./webm/'):
    os.makedirs('./webm/')
with open('./youtube_BB/detection_validation.txt') as f:
    a = f.read().split(', ')

for i in range(200, len(a)):
#    i = 0   # debug override left from development; enabling it re-downloads video 0 forever
    name = a[i]
    name = name[1:-1]   # strip the surrounding quotes, e.g. 'AACebVo-JXY'
    url = 'http://youtu.be/' + name
    print(name)
    try:
        yt = YouTube(url)
    except:
        continue
    print(name)
    stream = yt.streams.first()
    try:
        stream.download('./webm/', name)
    except:
        continue
#%%
# Rename extracted frames from <id>_<ts>.jpg to <id>@<ts>.jpg.
import os
for item in os.listdir('./YTBB_jpg/'):
    newname = item.split('_')
    os.rename('./YTBB_jpg/'+item, './YTBB_jpg/'+''.join(newname[:-1]) + '@' + newname[-1])
os.rename('./127_.jpg', './abc@127.jpg')   # one-off manual fix

--------------------------------------------------------------------------------
/SRPN.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 20 17:34:57 2018
@author: ZK
"""
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
#%%
model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'}

class SiameseRPN(nn.Module):
    def __init__(self):
        super(SiameseRPN, self).__init__()
        # AlexNet-style shared backbone (no padding, so a 127x127 template
        # yields a 6x6x256 feature map and a 255x255 detection yields 22x22x256).
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3),
        )

        self.k = 5   # number of anchors per position
        self.conv1 = nn.Conv2d(256, 2*self.k*256, kernel_size=3)   # template -> cls kernel
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(256, 4*self.k*256, kernel_size=3)   # template -> reg kernel
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3)            # detection -> cls feature
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3)            # detection -> reg feature
        self.relu4 = nn.ReLU(inplace=True)

        # Correlation layers; their weights are overwritten with the template
        # kernels in forward(), which is why they have no bias.
        self.cconv = nn.Conv2d(256, 2*self.k, kernel_size=4, bias=False)
        self.rconv = nn.Conv2d(256, 4*self.k, kernel_size=4, bias=False)

        self.reset_params()

    def reset_params(self):
        """Initialize the backbone with ImageNet-pretrained AlexNet weights."""
        pretrained_dict = model_zoo.load_url(model_urls['alexnet'])
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def forward(self, template, detection):
        template = self.features(template)
        detection = self.features(detection)

        # Classification branch: the template features become the kernel of a
        # correlation over the detection features. Note this assumes batch size 1.
        ckernel = self.conv1(template)
#        ckernel = self.relu1(ckernel)
        ckernel = ckernel.view(2*self.k, 256, 4, 4)
        self.cconv.weight = nn.Parameter(ckernel)
        cinput = self.conv3(detection)
#        cinput = self.relu3(cinput)
        coutput = self.cconv(cinput)    # (1, 2k, 17, 17)

        # Regression branch, using the same correlation trick.
        rkernel = self.conv2(template)
#        rkernel = self.relu2(rkernel)
        rkernel = rkernel.view(4*self.k, 256, 4, 4)
        self.rconv.weight = nn.Parameter(rkernel)
        rinput = self.conv4(detection)
#        rinput = self.relu4(rinput)
        routput = self.rconv(rinput)    # (1, 4k, 17, 17)

        return coutput, routput

#%%
if __name__ == '__main__':
    model = SiameseRPN()
#    y1, y2 = model(template, detection)
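#%%
# The weight-overwriting trick in forward() handles one template/detection pair
# at a time, which is why every script in this repo uses batch_size=1. Below is
# a minimal sketch (my own illustration, not part of the original code) of an
# equivalent stateless correlation using the functional API.
import torch
import torch.nn.functional as F

def xcorr(kernel, x, out_channels, ch=256):
    # kernel: (1, out_channels*ch, 4, 4), as produced by conv1/conv2 on the template
    # x:      (1, ch, 20, 20), as produced by conv3/conv4 on the detection
    return F.conv2d(x, kernel.view(out_channels, ch, 4, 4))   # -> (1, out_channels, 17, 17)

#print(xcorr(torch.randn(1, 10*256, 4, 4), torch.randn(1, 256, 20, 20), 10).shape)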
--------------------------------------------------------------------------------
/video2pic.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 3 17:06:28 2018

@author: ZK
"""
import pandas as pd
import cv2
import os

# Extract the annotated frames from each downloaded YouTube BB video.
f = pd.read_csv('./youtube_BB/youtube_boundingboxes_detection_validation.csv', header=None)
f.columns = ['youtube_id','timestamp_ms','class_id','class_name','object_id','object_presence','xmin','xmax','ymin','ymax']
for mp4 in os.listdir('./YTBB_mp4/'):
    mp4 = mp4.split('.')[0]
#    mp4 = 'ABQJpBm9hP8'   # debug override left from development
    # The per-video folder only marks the video as processed; the frames below
    # are written flat, with '@' separating id and timestamp.
    if not os.path.exists('./YTBB_jpg/' + mp4 + '/'):
        os.makedirs('./YTBB_jpg/' + mp4 + '/')
        print(mp4)
    else:
        continue
#    mp4 = 'AAQmL_BlrRs'

    # Timestamps (ms) of the frames in which the object is present.
    id0 = f.loc[f['youtube_id'] == mp4]
    id0 = id0.loc[id0['object_presence'] == 'present']
    id0 = id0[['timestamp_ms']]
    if len(id0) == 0:   # skip videos with no usable annotations
        continue

    vc = cv2.VideoCapture('./YTBB_mp4/' + mp4 + '.mp4')
    c = 0
    i = 0
    if vc.isOpened():
        rval, frame = vc.read()
    else:
        rval = False
    while rval:
        rval, frame = vc.read()
        # Frame index of the next annotated timestamp, assuming 30 fps.
        if c == int(id0['timestamp_ms'].iloc[i]*30/1000):
            cv2.imwrite('./YTBB_jpg/' + mp4 + '@'+str(int(id0['timestamp_ms'].iloc[i]/1000))+'.jpg', frame)
            i += 1
        c += 1
        if i == len(id0):
            break
        cv2.waitKey(1)
    vc.release()

#f = f[['youtube_id']]
#f = f['youtube_id'].unique()
#f = list(f)
#with open('./youtube_BB/detection_train.txt', 'w') as file:
#    file.write(str(f))
#%%
# Convert OTB2015 groundtruth_rect.txt (one line per frame, sometimes
# tab-separated) into one comma-separated label file per frame.
import os
list1 = os.listdir('./OTB2015/')
for item in list1:
    print(item)
    with open('./OTB2015/'+item+'/groundtruth_rect.txt') as f:
        a = f.read().split('\n')
    if '\t' in a[0]:
        print('...')
        a = [','.join(i.split('\t')) for i in a]
    # sorted() matters: os.listdir gives no ordering guarantee, and the label
    # lines are in frame order.
    l = sorted(os.listdir('./OTB2015/'+item+'/img/'))
    if not os.path.exists('./OTB2015/'+item+'/label/'):
        os.makedirs('./OTB2015/'+item+'/label/')
    for j, k in zip(a, l):
        with open('./OTB2015/'+item+'/label/'+k.split('.')[0]+'.txt', 'w') as f:
            f.write(j)
#%%
# Remove any non-jpg files from the image folders.
import os
l = os.listdir('./OTB2015/')
for item in l:
    print(item)
    d = './OTB2015/'+item+'/img/'
    for j in os.listdir(d):
        if j.split('.')[-1] != 'jpg':
            print('...')
            os.remove(d+j)
#%%
# Check which sequences are grayscale (their first frame has no RGB channels).
import os
import cv2
from PIL import Image
l = os.listdir('./OTB2015/')
for item in l:
#    item = 'Car1'        # debug overrides left from development
#    item = 'Basketball'
    print(item)
    d = './OTB2015/'+item+'/img/'
    f = os.listdir(d)[0]
    img = Image.open(d+f)
    try:
        r, g, b = img.split()
    except ValueError:
        print('...')
#%%
# Sanity check: every sequence should have as many labels as images.
import os
l = os.listdir('./OTB2015/')
for item in l:
    img = './OTB2015/'+item+'/img/'
    label = './OTB2015/'+item+'/label/'
    if len(os.listdir(img)) != len(os.listdir(label)):
        print(item)
#%%
# Strip the leading two fields from the ./lq/ labels and re-save them comma-separated.
import os
l = os.listdir('./lq/label/')
for item in l:
    with open('./lq/label/'+item, 'r') as f:
        a = f.read().split()[2:]
    with open('./lq/label/'+item, 'w') as f:
        f.write(','.join(a))
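#%%
# The extraction loop above assumes a fixed 30 fps when mapping timestamp_ms to
# a frame index. A more robust sketch (my own suggestion, not original code)
# reads the actual frame rate from the container first; the path is hypothetical.
import cv2
vc = cv2.VideoCapture('./YTBB_mp4/' + 'some_video' + '.mp4')   # hypothetical path
fps = vc.get(cv2.CAP_PROP_FPS) or 30.0   # fall back to 30 if the metadata is missing
frame_index = int(1234 * fps / 1000)     # e.g. timestamp_ms = 1234
vc.release()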
--------------------------------------------------------------------------------
/data_otb.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 20 17:36:51 2018
@author: ZK
"""
import numpy as np
import math
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.transforms import functional as F
import os
from axis import x1y1x2y2_to_xywh, xywh_to_x1y1x2y2, x1y1wh_to_xywh, x1y1wh_to_x1y1x2y2, point_center_crop, resize
import random

#%%
data_dir = './OTB2015/'
interval = 20   # maximum frame gap between template and detection

# Per-sequence frame counts, minus the interval so that index+rand stays in range.
list1 = sorted(os.listdir(data_dir))
number = []
for item in list1:
    number.append(len(os.listdir(data_dir+item+'/img/')))

number = [i-interval for i in number]

# Cumulative counts: sum1[i] is the number of samples in the first i sequences.
sum1 = [0]
for a in range(len(number)):
    sum1.append(sum1[a]+number[a])

#%%
class MyDataset(Dataset):

    def __init__(self, root_dir, anchor_scale=64, k=5):
        self.root_dir = root_dir
        self.anchor_shape = self._get_anchor_shape(anchor_scale)
        self.k = k

    """Compute the widths and heights of the 5 anchors from anchor_scale:
    aspect ratios 3:1, 2:1, 1:1, 1:2 and 1:3, all with roughly equal area.
    """
    def _get_anchor_shape(self, a):
        s = a**2
        r = [[3*math.sqrt(s/3.),math.sqrt(s/3.)], [2*math.sqrt(s/2.),math.sqrt(s/2.)],
             [a,a], [math.sqrt(s/2.),2*math.sqrt(s/2.)], [math.sqrt(s/3.),3*math.sqrt(s/3.)]]
        return [list(map(round, i)) for i in r]

    def __len__(self):
        return sum1[-1]

    """Binary search over sum1: find which sequence a global sample index falls into."""
    def _which(self, index, sum1):
        low = 0
        high = len(sum1) - 1
        while (high - low > 1):
            mid = (high+low) // 2
            if sum1[mid] <= index:
                low = mid
            elif sum1[mid] > index:
                high = mid
        return low

    """Fetch one training sample: a template frame, a detection frame taken
    1..interval-1 frames later, and the classification/regression labels.
    """
    def __getitem__(self, index):
        low = self._which(index, sum1)
        index -= sum1[low]
        folder = list1[low]

        # sorted() matters here: os.listdir gives no ordering guarantee, and
        # frames and labels must be paired and read in frame order.
        img = sorted(os.listdir(self.root_dir + folder + '/img/'))[index]
        img = Image.open(self.root_dir + folder + '/img/' + img)
        if img.mode != 'RGB':
            img = img.convert('RGB')

        gtbox = sorted(os.listdir(self.root_dir + folder + '/label/'))[index]
        with open(self.root_dir + folder + '/label/' + gtbox) as f:
            gtbox = f.read().split(',')
        gtbox = [round(float(i)) for i in gtbox]
        gtbox = x1y1wh_to_xywh(gtbox)
        template, _, _ = self._transform(img, gtbox, 1, 127)

        rand = random.randrange(1, interval)

        img = sorted(os.listdir(self.root_dir + folder + '/img/'))[index + rand]
        img = Image.open(self.root_dir + folder + '/img/' + img)
        if img.mode != 'RGB':
            img = img.convert('RGB')

        gtbox = sorted(os.listdir(self.root_dir + folder + '/label/'))[index + rand]
        with open(self.root_dir + folder + '/label/' + gtbox) as f:
            gtbox = f.read().split(',')
        gtbox = [round(float(i)) for i in gtbox]
        gtbox = x1y1wh_to_xywh(gtbox)
        detection, pcc, ratio = self._transform(img, gtbox, 2, 255)

        # Express the ground-truth box in the 255x255 detection crop: the crop is
        # centered on the target, so the center maps to (127, 127) and the width
        # and height are rescaled by the crop side a.
        a = (gtbox[2]+gtbox[3]) / 2.
        a = math.sqrt((gtbox[2]+a)*(gtbox[3]+a)) * 2
        gtbox = [127, 127, round(255*gtbox[2]/a), round(255*gtbox[3]/a)]

        clabel, rlabel = self._gtbox_to_label(gtbox)
        return template, detection, clabel, rlabel, torch.from_numpy(np.array(pcc).reshape((1,4))), torch.from_numpy(np.array(ratio).reshape((1,1)))

    """Data transform: center crop, resize, conversion to tensor (normalization disabled)."""
    def _transform(self, img, gtbox, area, size):
        img, pcc = point_center_crop(img, gtbox, area)
        img, ratio = resize(img, size)
        img = F.to_tensor(img)
#        img = F.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return img, pcc, ratio

    """Build the classification label and regression label from the ground-truth box."""
    def _gtbox_to_label(self, gtbox):
        # -100 marks "ignore": it is the default ignore_index of nn.CrossEntropyLoss.
        clabel = np.zeros([5, 17, 17]) - 100
        rlabel = np.zeros([20, 17, 17], dtype=np.float32)
        pos, neg = self._get_64_anchors(gtbox)
        for i in range(len(pos)):
            clabel[pos[i, 2], pos[i, 0], pos[i, 1]] = 1
        for i in range(len(neg)):
            clabel[neg[i, 2], neg[i, 0], neg[i, 1]] = 0
        pos_coord = self._anchor_coord(pos)
        # Standard RPN regression targets: dx, dy normalized by anchor size; dw, dh in log space.
        channel0 = (gtbox[0] - pos_coord[:, 0]) / pos_coord[:, 2]
        channel1 = (gtbox[1] - pos_coord[:, 1]) / pos_coord[:, 3]
        channel2 = np.array([math.log(i) for i in (gtbox[2] / pos_coord[:, 2]).tolist()])
        channel3 = np.array([math.log(i) for i in (gtbox[3] / pos_coord[:, 3]).tolist()])
        for i in range(len(pos)):
            rlabel[pos[i][2]*4, pos[i][0], pos[i][1]] = channel0[i]
            rlabel[pos[i][2]*4 + 1, pos[i][0], pos[i][1]] = channel1[i]
            rlabel[pos[i][2]*4 + 2, pos[i][0], pos[i][1]] = channel2[i]
            rlabel[pos[i][2]*4 + 3, pos[i][0], pos[i][1]] = channel3[i]
        return torch.Tensor(clabel).long(), torch.Tensor(rlabel).float()

    """Get each anchor's coordinates in the detection frame from its position in
    the label: the 17x17 grid has stride 15 and starts at offset 7.
    """
    def _anchor_coord(self, pos):
        result = np.ndarray([0, 4])
        for i in pos:
            tmp = [7+15*i[0], 7+15*i[1], self.anchor_shape[i[2]][0], self.anchor_shape[i[2]][1]]
            result = np.concatenate([result, np.array(tmp).reshape([1,4])], axis=0)
        return result

    """Pick up to 16 positive anchors (IOU >= 0.5) and fill up to 64 with negatives (IOU <= 0.2)."""
    def _get_64_anchors(self, gtbox):
        pos = {}
        neg = {}
        for a in range(17):
            for b in range(17):
                for c in range(5):
                    anchor = [7+15*a, 7+15*b, self.anchor_shape[c][0], self.anchor_shape[c][1]]
                    anchor = xywh_to_x1y1x2y2(anchor)
                    if anchor[0] >= 0 and anchor[1] >= 0 and anchor[2] <= 255 and anchor[3] <= 255:
                        iou = self._IOU(anchor, gtbox)
                        if iou >= 0.5:
                            pos['%d,%d,%d' % (a,b,c)] = iou
                        elif iou <= 0.2:
                            neg['%d,%d,%d' % (a,b,c)] = iou
        pos = sorted(pos.items(), key=lambda x: x[1], reverse=True)
        pos = [list(map(int, i[0].split(','))) for i in pos[:16]]
        neg = sorted(neg.items(), key=lambda x: x[1], reverse=True)
        neg = [list(map(int, i[0].split(','))) for i in neg[:(64-len(pos))]]
        return np.array(pos), np.array(neg)

#    def _f(self, x):
#        if x <= 0: return 0
#        elif x >= 254: return 254
#        else: return x

    """IOU of a (x1y1x2y2) and b (xywh)."""
    def _IOU(self, a, b):
#        a = xywh_to_x1y1x2y2(a)
        b = xywh_to_x1y1x2y2(b)
        sa = (a[2] - a[0]) * (a[3] - a[1])
        sb = (b[2] - b[0]) * (b[3] - b[1])
        w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
        h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
        area = w * h
        return area / (sa + sb - area)
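#%%
# Worked example (illustration only): the five anchor shapes at the default
# anchor_scale of 64, and one regression-target round trip.
if __name__ == '__main__':
    ds = MyDataset(root_dir=data_dir)
    print(ds.anchor_shape)   # approximately [[111, 37], [91, 45], [64, 64], [45, 91], [37, 111]]

    # Encode a ground-truth box against an anchor, then decode it back.
    anchor = [7 + 15*8, 7 + 15*8, 64, 64]   # grid cell (8, 8) maps to center (127, 127)
    gt = [127, 127, 80, 50]
    dx = (gt[0] - anchor[0]) / anchor[2]    # 0.0
    dy = (gt[1] - anchor[1]) / anchor[3]    # 0.0
    dw = math.log(gt[2] / anchor[2])
    dh = math.log(gt[3] / anchor[3])
    decoded = [anchor[0] + dx*anchor[2], anchor[1] + dy*anchor[3],
               anchor[2]*math.exp(dw), anchor[3]*math.exp(dh)]
    print(decoded)                          # [127.0, 127.0, 80.0, 50.0]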
#%%
transformed_dataset_train = MyDataset(root_dir=data_dir)
train_dataloader = DataLoader(transformed_dataset_train, batch_size=1, shuffle=True, num_workers=0)
dataloader = {'train': train_dataloader, 'valid': train_dataloader}

#transformed_dataset_test = MyDataset(detection_root_dir = './lq/JPEGImages/',
#                                     gtbox_root_dir = './lq/label/')
#test_dataloader = DataLoader(transformed_dataset_train, batch_size=1, shuffle=False, num_workers=0)
#%%
#with open('./vot2015/bag/groundtruth.txt') as f:
#    a = f.read().split()
#b = [float(i) for i in a[0].split(',')]
#x = (b[0]+b[4]) / 2.
#y = (b[1]+b[5]) / 2.
#w = 1.1 * math.sqrt(math.sqrt(((b[0]-b[2])**2+(b[1]-b[3])**2) * ((b[2]-b[4])**2+(b[3]-b[5])**2)))
#h = w
#if (b[0]-b[2]) >= (b[1]-b[3]):
#    w = (b[0]-b[2])
#
#list1 = xywh_to_x1y1x2y2([x, y, w, h])
#list1 = os.listdir('./vot2015')
#number = [len(os.listdir('./vot2015/'+item)) for item in os.listdir('./vot2015')]
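#%%
# Illustration only: how a global sample index maps to (sequence, frame) via the
# cumulative counts in sum1. With e.g. number = [5, 3, 7], sum1 = [0, 5, 8, 15]:
# index 6 falls in sequence 1 (since sum1[1] <= 6 < sum1[2]) at local frame 6 - 5 = 1.
if __name__ == '__main__':
    idx = 0
    low = transformed_dataset_train._which(idx, sum1)
    print(list1[low], idx - sum1[low])   # first sequence, frame 0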
--------------------------------------------------------------------------------
/test_otb.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 4 09:08:38 2017
@author: ZK
"""
#%%
import torch
from torch.autograd import Variable as V
import numpy as np
import cv2
import math
import os
from torchvision.transforms import functional as F
from PIL import Image
from axis import x1y1x2y2_to_xywh, xywh_to_x1y1x2y2, point_center_crop, resize, x1y1wh_to_xywh
from SRPN import SiameseRPN

#%%
#interval = 1
#
#imgdir = os.listdir('./OTB2015_small/')
#number = []
#for item in imgdir:
#    number.append(len(os.listdir('./OTB2015_small/'+item+'/img/')))
#
#number = [i-interval for i in number]
#
#sum1 = [number[0]]
#for a in range(1, len(number)):
#    sum1.append(sum1[a-1]+number[a])
#%%
def _transform(img, gtbox, area, size):
    """Crop around gtbox, resize, and return a batched tensor plus the crop geometry."""
    img, pcc = point_center_crop(img, gtbox, area)
    img, ratio = resize(img, size)
    img = F.to_tensor(img)
#    img = F.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    img = img.unsqueeze(0)
    return img, pcc, ratio
#%%
def transform2(img, gtbox, area, size):
    """Same crop and resize as _transform, but returns the PIL image (for visualization)."""
    img, _ = point_center_crop(img, gtbox, area)
    img, _ = resize(img, size)
    return img
#%%
def IOU(a, b):
    """IOU of two boxes in x1y1x2y2 format."""
    sa = (a[2] - a[0]) * (a[3] - a[1])
    sb = (b[2] - b[0]) * (b[3] - b[1])
    w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    area = w * h
    return area / (sa + sb - area)
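#%%
# Quick sanity check for IOU (illustration only): identical boxes give 1.0, and
# a box shifted by half its width gives 1/3.
if __name__ == '__main__':
    print(IOU([0, 0, 10, 10], [0, 0, 10, 10]))   # 1.0
    print(IOU([0, 0, 10, 10], [5, 0, 15, 10]))   # 50 / (100 + 100 - 50) = 1/3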
#%%
def show_output(root_dir, model, pth_file, video_dir, use_gpu):

    print('load model params...')
    model.load_state_dict(torch.load(pth_file))
    model = model.train(False)   # set model to evaluate mode

    print('test...')

    # Initialize last_box with the first frame's ground-truth box (the template's gtbox).
    last_box = sorted(os.listdir(root_dir+'/label/'))[0]
    with open(root_dir + '/label/' + last_box) as f:
        last_box = f.read().split(',')
    last_box = [float(i) for i in last_box]
    last_box = x1y1wh_to_xywh(last_box)
#    last_box = x1y1x2y2_to_xywh(last_box)

    # Initialize the template from the first frame.
    template = sorted(os.listdir(root_dir+'/img/'))[0]
    template = Image.open(root_dir + '/img/' + template)
    video_size = template.size   # the output video keeps the original frame size
#    video_size = (255, 255)
    template, _, _ = _transform(template, last_box, 1, 127)
    """Debug visualization of the template crop:"""
#    template = template.squeeze()
#    template = template.numpy()
#    template = np.transpose(template, (1, 2, 0))
#    template = cv2.cvtColor(template, cv2.COLOR_RGB2BGR)
#    cv2.imshow('img', template)
#    cv2.waitKey(0)
    """"""
    if use_gpu:
        template = V(template.cuda())
        model = model.cuda()
    else:
        template = V(template)

    fps = 20   # frame rate of the output video
    fourcc = cv2.VideoWriter_fourcc('M','J','P','G')
    videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, video_size)

    RESET = True
    count = 0

    for index in range(1, len(os.listdir(root_dir+'/img/'))):

        gtbox = sorted(os.listdir(root_dir + '/label/'))[index]
        with open(root_dir + '/label/' + gtbox) as f:
            gtbox = f.read().split(',')
        gtbox = [float(i) for i in gtbox]
        gtbox = x1y1wh_to_xywh(gtbox)
#        gtbox = x1y1x2y2_to_xywh(gtbox)
#        """Use the ground truth as last_box:"""
#        if RESET:
#            last_box = gtbox
#        """Update the template every frame:"""
#        template = sorted(os.listdir(root_dir+'/img/'))[index]
#        template = Image.open(root_dir + '/img/' + template)
#        template, _, _ = _transform(template, last_box, 1, 127)
#        if use_gpu:
#            template = V(template.cuda())
#            model = model.cuda()
#        else:
#            template = V(template)
        """"""
        print(last_box)
        detection = sorted(os.listdir(root_dir+'/img/'))[index]
        detection = Image.open(root_dir + '/img/' + detection)
        detection, pcc, ratio = _transform(detection, last_box, 2, 255)
        """Debug visualization of the detection crop:"""
#        detection = detection.squeeze()
#        detection = detection.numpy()
#        detection = np.transpose(detection, (1, 2, 0))
#        detection = cv2.cvtColor(detection, cv2.COLOR_RGB2BGR)
#        cv2.imshow('img', detection)
#        cv2.waitKey(0)
        """"""
        if use_gpu:
            detection = V(detection.cuda())
        else:
            detection = V(detection)

        coutput, routput = model(template, detection)
        coutput, routput = coutput.squeeze(), routput.squeeze()

        # (2k, 17, 17) -> (k, 2, 17, 17), then softmax over the 2 class channels.
        coutput = coutput.view(5, 2, 17, 17)
        coutput = torch.nn.Softmax2d()(coutput)
        coutput1 = coutput[:, 1, :, :]   # foreground probability per anchor

        if use_gpu:
            coutput1, routput = coutput1.data.cpu().numpy().astype(np.float64), routput.data.cpu().numpy()
        else:
            coutput1, routput = coutput1.data.numpy().astype(np.float64), routput.data.numpy()

        # The same five anchor shapes as in data_otb.py.
        a = 64
        s = a**2
        r = [[3*math.sqrt(s/3.),math.sqrt(s/3.)], [2*math.sqrt(s/2.),math.sqrt(s/2.)], [a,a], [math.sqrt(s/2.),2*math.sqrt(s/2.)], [math.sqrt(s/3.),3*math.sqrt(s/3.)]]
        r = [list(map(round, i)) for i in r]

        center_size = 5
        # Keep only score-map positions within center_size of the center of coutput1:
        coutput1 = coutput1[:, (8-center_size):(10+center_size), (8-center_size):(10+center_size)]

        # Locate the best-scoring anchor and its regression output:
        loc1 = np.where(coutput1 == np.max(coutput1))
#        loc1 = np.where(coutput1 > 0.1)
        img = cv2.imread(root_dir+'/img/'+sorted(os.listdir(root_dir+'/img/'))[index])
#        # Crop img to 255x255 around last_box (for visualization):
#        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
#        img = transform2(img, last_box, 2, 255)
#        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
#        total = {}
        for where in range(len(loc1[0])):
            # Undo the center-crop offset to index the full 17x17 map.
            loc = [loc1[0][where], loc1[1][where]+8-center_size, loc1[2][where]+8-center_size]

            anchor = [7+15*loc[1], 7+15*loc[2]] + r[loc[0]]   # the anchor box at this location
            # Regression correction for this anchor:
            reg = [routput[loc[0]*4, loc[1], loc[2]], routput[loc[0]*4+1, loc[1], loc[2]], routput[loc[0]*4+2, loc[1], loc[2]], routput[loc[0]*4+3, loc[1], loc[2]]]
            # Apply the correction to get the proposal:
            pro = [anchor[0]+reg[0]*anchor[2], anchor[1]+reg[1]*anchor[3], anchor[2]*math.exp(reg[2]), anchor[3]*math.exp(reg[3])]
#            pro = anchor
            # Map the proposal from the 255x255 crop back into the original image:
            pro = [pro[0]*ratio+pcc[2]-pcc[0], pro[1]*ratio+pcc[3]-pcc[1], pro[2]*ratio, pro[3]*ratio]

            list1 = xywh_to_x1y1x2y2(pro)
            list1 = list(map(lambda x: int(round(x)), list1))

#            total[','.join([str(i) for i in list1])] = sum(list1) - sum(gtbox)

            # Hand the proposal to last_box; the next frame's detection crop is centered on it.
            last_box = pro
            last_box = list(map(lambda x: int(round(x)), last_box))

#        list1 = list(total.keys())[list(total.values()).index(min(total.values()))].split(',')
#        list1 = [int(i) for i in list1]

        gtbox = xywh_to_x1y1x2y2(gtbox)
        gtbox = list(map(lambda x: int(round(x)), gtbox))
        try:
            cv2.rectangle(img, (list1[0],list1[1]), (list1[2],list1[3]), (0,255,0), 1)   # prediction in green
        except OverflowError:
            print(list1)
        cv2.rectangle(img, (gtbox[0],gtbox[1]), (gtbox[2],gtbox[3]), (255,0,0), 1)       # ground truth in blue

#        if IOU(gtbox, list1) < 0.1:
#            RESET = True
#            count += 1
#        cv2.imshow('img', img)
#        cv2.waitKey(0)

        videoWriter.write(img)
    videoWriter.release()
    print(count)
    return
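#%%
# A common refinement not implemented above (axis.py contains an unfinished
# cosine_window stub for it): instead of the hard center crop, damp scores far
# from the previous position with a Hanning window before taking the argmax.
# Sketch only; the 0.4 blending weight is an assumption, not a tuned value.
#
#     window = np.outer(np.hanning(17), np.hanning(17))
#     coutput1 = coutput1 * (1 - 0.4) + window[None, :, :] * 0.4
#     loc1 = np.where(coutput1 == np.max(coutput1))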
#%%
if __name__ == '__main__':
    model = SiameseRPN()
    show_output(
        root_dir='./OTB2015/Dog/',
#        root_dir='./lq/',
        model=model,
        pth_file='./pth_OTB2015/epoch_31.pth',
        video_dir='./OTB2015_Dog.avi',
        use_gpu=True,
    )

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 27 11:49:42 2018
@author: ZK
"""
import torch
from torch import nn
from torch.autograd import Variable as V
import os
from axis import SmoothL1Loss
from axis import Myloss
#%%
def train_model(dataloader, model, optimizer, lmbda, scheduler, num_epochs, pth_dir, use_gpu):
    # Resume from the latest checkpoint in pth_dir, if any.
    if not os.path.exists(pth_dir):
        os.makedirs(pth_dir)
    dirlist = os.listdir(pth_dir)
    if dirlist:
#        del dirlist[dirlist.index('record.txt')]
        l = [int(i.split('.')[0].split('_')[-1]) for i in dirlist]
        former_epoch = max(l)
        model.load_state_dict(torch.load(os.path.join(pth_dir, 'epoch_'+str(former_epoch)+'.pth')))
        print('former_epoch %d loaded.' % former_epoch)
    else:
        former_epoch = 0
        print('first train begin.')
    for epoch in range(former_epoch+1, num_epochs+1):
        print('-' * 20)
        print('Epoch {}/{}'.format(epoch, num_epochs))

#        for phase in ['train', 'valid']:
        phase = 'train'
        epoch_loss = 0
        epoch_closs = 0
        epoch_rloss = 0

        if phase == 'train':
            print('-----train-----')
            if scheduler:
                scheduler.step()
            model.train(True)    # set model to training mode
        else:
            print('-----valid-----')
            model.train(False)   # set model to evaluate mode

        for i, tvdata in enumerate(dataloader[phase]):
            template, detection, clabel, rlabel, pcc, ratio = tvdata
            """Debug visualization: decode the positive-anchor proposals from the
            labels and draw them onto the detection crop."""
#            template, detection, clabel, rlabel, pcc, ratio = template.squeeze(), detection.squeeze(), clabel.squeeze(), rlabel.squeeze(), pcc.squeeze(), ratio.squeeze()
#            template, detection, clabel, rlabel, pcc, ratio = template.numpy(), detection.numpy(), clabel.numpy(), rlabel.numpy(), pcc.numpy(), ratio.numpy()
#            import cv2
#            import numpy as np
#            import math
#            from axis import xywh_to_x1y1x2y2
#            template = np.transpose(template, (1, 2, 0))
#            template = cv2.cvtColor(template, cv2.COLOR_RGB2BGR)
#            cv2.imshow('img', template)
#            cv2.waitKey(0)
#
#            detection = np.transpose(detection, (1, 2, 0))
#            detection = cv2.cvtColor(detection, cv2.COLOR_RGB2BGR)
#
#            a = 64
#            s = a**2
#            r = [[3*math.sqrt(s/3.),math.sqrt(s/3.)], [2*math.sqrt(s/2.),math.sqrt(s/2.)], [a,a], [math.sqrt(s/2.),2*math.sqrt(s/2.)], [math.sqrt(s/3.),3*math.sqrt(s/3.)]]
#            r = [list(map(round, i)) for i in r]
#
#            loc1 = np.where(clabel > 0.5)
#            for where in range(len(loc1[0])):
#                loc = [loc1[0][where], loc1[1][where], loc1[2][where]]
#                anchor = [7+15*loc[1], 7+15*loc[2]] + r[loc[0]]   # anchor at this location
#                # Regression correction for this anchor:
#                reg = [rlabel[loc[0]*4, loc[1], loc[2]], rlabel[loc[0]*4+1, loc[1], loc[2]], rlabel[loc[0]*4+2, loc[1], loc[2]], rlabel[loc[0]*4+3, loc[1], loc[2]]]
#                # Apply the correction to get the proposal:
#                pro = [anchor[0]+reg[0]*anchor[2], anchor[1]+reg[1]*anchor[3], anchor[2]*math.exp(reg[2]), anchor[3]*math.exp(reg[3])]
#                list1 = xywh_to_x1y1x2y2(pro)
#                list1 = list(map(lambda x: int(round(x)), list1))
#                cv2.rectangle(detection, (list1[0],list1[1]), (list1[2],list1[3]), (0,255,0), 1)
#            cv2.imshow('img', detection)
#            cv2.waitKey(0)
#            cv2.imwrite('./tmp/'+str(i)+'.jpg', detection, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
            """"""
            if use_gpu:
                target = torch.zeros(clabel.shape).cuda() + 1   # all-ones tensor marking the positive class
                template = V(template.cuda())
                detection = V(detection.cuda())
                clabel = V(clabel.cuda())
                rlabel = V(rlabel.cuda())
                model = model.cuda()
            else:
                target = torch.zeros(clabel.shape) + 1
                template = V(template)
                detection = V(detection)
                clabel = V(clabel)
                rlabel = V(rlabel)

            optimizer.zero_grad()

            # forward
            coutput, routput = model(template, detection)
            coutput, clabel = coutput.squeeze(), clabel.squeeze()
            coutput = coutput.view(5, 2, 17, 17)   # (k, 2, 17, 17)

            # closs and rloss are recomputed inside Myloss; here they are kept for logging.
            closs = nn.CrossEntropyLoss()(coutput, clabel)
            rloss = SmoothL1Loss(use_gpu=use_gpu)(clabel, target, routput, rlabel)
#            rloss = nn.SmoothL1Loss()(routput, rlabel)
            loss = Myloss()(coutput, clabel, target, routput, rlabel, lmbda)

            loss2 = torch.add(closs, lmbda, rloss)   # i.e. closs + lmbda * rloss
            epoch_loss += loss2.data.item()
            epoch_closs += closs.data.item()
            epoch_rloss += rloss.data.item()
            # backward + optimize only if in training phase
            if phase == 'train':
                loss.backward()
                optimizer.step()

            # statistics
            if phase == 'train':
                if i+1 == 2 or (i+1) % 100 == 0:
                    print('batch %d, train loss:%.6f' % (i+1, loss.data.item()))
                if i+1 == len(dataloader[phase]):
                    print('train loss:%.6f' % (epoch_loss/len(dataloader[phase])))
                    print('closs:%.6f' % (epoch_closs/len(dataloader[phase])))
                    print('rloss:%.6f' % (epoch_rloss/len(dataloader[phase])))

        torch.save(model.state_dict(), os.path.join(pth_dir, 'epoch_%d.pth' % epoch))
#        print('current model saved to epoch_%d.pth' % epoch)
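#%%
# Illustration only: the masked regression loss in axis.py averages smooth-L1
# over the positive anchors and ignores everything else. A miniature version of
# the same masking on random tensors (uses the old reduce=False signature, as
# this repo targets PyTorch 0.4):
#
#     from torch.nn import functional as F
#     routput = torch.randn(20, 17, 17)
#     rlabel = torch.randn(20, 17, 17)
#     mask = torch.zeros(20, 17, 17)
#     mask[0:4, 8, 8] = 1   # pretend anchor 0 at grid cell (8, 8) is positive
#     per_elem = F.smooth_l1_loss(routput, rlabel, size_average=False, reduce=False)
#     rloss = (per_elem * mask).sum() / (mask.nonzero().shape[0] + 1e-4)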
#%%
from SRPN import SiameseRPN
#from data import dataloader
from data_otb import dataloader
import torch.optim as optim
from torch.optim import lr_scheduler
#%%
if __name__ == '__main__':

    model = SiameseRPN()

    # Only the last two backbone convs and the four RPN convs are passed to the
    # optimizer; the earlier pretrained AlexNet layers are left untouched.
    params = []
#    params += list(model.features[0].parameters())
#    params += list(model.features[3].parameters())
#    params += list(model.features[6].parameters())
    params += list(model.features[8].parameters())
    params += list(model.features[10].parameters())
    params += list(model.conv1.parameters())
    params += list(model.conv2.parameters())
    params += list(model.conv3.parameters())
    params += list(model.conv4.parameters())

    optimizer = optim.Adam(params, lr=1e-3, eps=1e-8, weight_decay=0)
#    optimizer = optim.SGD(params, lr=1e-3)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

    train_model(
        dataloader=dataloader,
        model=model,
        optimizer=optimizer,
#        scheduler=scheduler,
        scheduler=None,
        lmbda=1,
        num_epochs=100,
        pth_dir='./pth_OTB2015/',
        use_gpu=True,
    )

--------------------------------------------------------------------------------
/axis.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 27 11:42:36 2018
@author: ZK
"""

def x1y1x2y2_to_xywh(gtbox):
    """(x1, y1, x2, y2) -> (center x, center y, w, h), rounded."""
    return list(map(round, [(gtbox[0]+gtbox[2])/2., (gtbox[1]+gtbox[3])/2., gtbox[2]-gtbox[0], gtbox[3]-gtbox[1]]))

def xywh_to_x1y1x2y2(gtbox):
    """(center x, center y, w, h) -> (x1, y1, x2, y2), rounded."""
    return list(map(round, [gtbox[0]-gtbox[2]/2., gtbox[1]-gtbox[3]/2., gtbox[0]+gtbox[2]/2., gtbox[1]+gtbox[3]/2.]))

def x1y1wh_to_xywh(gtbox):
    """(x1, y1, w, h) -> (center x, center y, w, h), rounded."""
    x1, y1, w, h = gtbox
    return [round(x1 + w/2.), round(y1 + h/2.), w, h]

def x1y1wh_to_x1y1x2y2(gtbox):
    """(x1, y1, w, h) -> (x1, y1, x2, y2)."""
    x1, y1, w, h = gtbox
    return [x1, y1, x1+w, y1+h]
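#%%
# Round-trip examples for the conversions above (illustration only). Note that
# xywh_to_x1y1x2y2 rounds, so round-tripping odd widths can lose half a pixel.
if __name__ == '__main__':
    print(x1y1wh_to_xywh([10, 20, 40, 60]))      # [30, 50, 40, 60]
    print(xywh_to_x1y1x2y2([30, 50, 40, 60]))    # [10, 20, 50, 80]
    print(x1y1x2y2_to_xywh([10, 20, 50, 80]))    # [30, 50, 40, 60]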
#%%
import torch
from torch.nn import Module
from torch.nn import functional as F
#%%
class SmoothL1Loss(Module):
    """Smooth-L1 regression loss averaged over the positive anchors only."""
    def __init__(self, use_gpu):
        super(SmoothL1Loss, self).__init__()
        self.use_gpu = use_gpu

    def forward(self, clabel, target, routput, rlabel):
#        rloss = F.smooth_l1_loss(routput, rlabel)
        rloss = F.smooth_l1_loss(routput, rlabel, size_average=False, reduce=False)

        # Positive-anchor mask: clabel == 1 per anchor channel, repeated 4 times
        # to cover that anchor's (dx, dy, dw, dh) regression channels.
        e = torch.eq(clabel.float(), target)
        e = e.squeeze()
        e0, e1, e2, e3, e4 = e[0].unsqueeze(0), e[1].unsqueeze(0), e[2].unsqueeze(0), e[3].unsqueeze(0), e[4].unsqueeze(0)
        eq = torch.cat([e0,e0,e0,e0,e1,e1,e1,e1,e2,e2,e2,e2,e3,e3,e3,e3,e4,e4,e4,e4], dim=0).float()

        rloss = rloss.squeeze()
        rloss = torch.mul(eq, rloss)
        rloss = torch.sum(rloss)
        rloss = torch.div(rloss, eq.nonzero().shape[0] + 1e-4)   # mean over positive entries
        return rloss
#%%
class Myloss(Module):
    """Combined loss: cross-entropy classification plus lmbda times the masked
    smooth-L1 regression (same masking as SmoothL1Loss above)."""
    def __init__(self):
        super(Myloss, self).__init__()

    def forward(self, coutput, clabel, target, routput, rlabel, lmbda):
        closs = F.cross_entropy(coutput, clabel)

#        rloss = F.smooth_l1_loss(routput, rlabel)
        rloss = F.smooth_l1_loss(routput, rlabel, size_average=False, reduce=False)

        e = torch.eq(clabel.float(), target)
        e = e.squeeze()
        e0, e1, e2, e3, e4 = e[0].unsqueeze(0), e[1].unsqueeze(0), e[2].unsqueeze(0), e[3].unsqueeze(0), e[4].unsqueeze(0)
        eq = torch.cat([e0,e0,e0,e0,e1,e1,e1,e1,e2,e2,e2,e2,e3,e3,e3,e3,e4,e4,e4,e4], dim=0).float()

        rloss = rloss.squeeze()
        rloss = torch.mul(eq, rloss)
        rloss = torch.sum(rloss)
        rloss = torch.div(rloss, eq.nonzero().shape[0] + 1e-4)

        loss = torch.add(closs, lmbda, rloss)   # i.e. closs + lmbda * rloss
        return loss
#%%
import math
from PIL import ImageStat, Image
from torchvision.transforms import functional as F2
#%%
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize a square image to size x size; also return the scale ratio."""
    assert img.size[0] == img.size[1]
    return img.resize((size, size), interpolation), img.size[0] / size
#%%
def point_center_crop(img, gtbox, area):
    """Crop a square context region of side a = area * sqrt((w+p)(h+p)), with
    p = (w+h)/2, centered on gtbox, padding with the image mean wherever the
    crop leaves the image. Returns the crop plus [left, top, i, j], which is
    what test_otb.py uses to map proposals back into the original image."""
    x, y, dw, dh = gtbox
    p = (dw + dh) / 2.
    a = math.sqrt((dw + p) * (dh + p))
    a *= area
    i = round(x - a/2.)
    j = round(y - a/2.)
    mean = tuple(map(round, ImageStat.Stat(img).mean))
    if i < 0:
        left = -i
        i = 0
    else:
        left = 0
    if j < 0:
        top = -j
        j = 0
    else:
        top = 0
    if x+a/2. > img.size[0]:
        right = round(x+a/2.-img.size[0])
    else:
        right = 0
    if y+a/2. > img.size[1]:
        bottom = round(y+a/2.-img.size[1])
    else:
        bottom = 0

    img = F2.pad(img, padding=(left, top, right, bottom), fill=mean, padding_mode='constant')
    img = img.crop((i, j, i+round(a), j+round(a)))

    return img, [left, top, i, j]
#%%
def cosine_window(coutput1):
    """Left unfinished in the original code; a minimal sketch (my assumption):
    damp scores far from the map center with a Hanning window. The 0.4
    blending weight is an assumed value, not tuned."""
    import numpy as np
    n = coutput1.shape[-1]
    window = np.outer(np.hanning(n), np.hanning(n))
    return coutput1 * 0.6 + window[None, :, :] * 0.4

#%%
#class PointCenterCrop(object):
#    def __init__(self, gtbox, area):
#        self.gtbox = gtbox
#        self.area = area
#
#    def __call__(self, img):
#        return point_center_crop(img, self.gtbox, self.area)
#
#    def __repr__(self):
#        return self.__class__.__name__ + '(gtbox={0})'.format(self.gtbox)
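#%%
# Worked example of the context-region formula in point_center_crop (illustration
# only): for a target of w=40, h=60, the padding is p = (40+60)/2 = 50, so the
# template crop side is sqrt((40+50)*(60+50)) = sqrt(9900), about 99.5 pixels
# (area=1); the detection crop (area=2) doubles that.
if __name__ == '__main__':
    img = Image.new('RGB', (320, 240))
    crop, pcc = point_center_crop(img, [160, 120, 40, 60], 1)
    print(crop.size, pcc)   # (99, 99) [0, 0, 110, 70]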
--------------------------------------------------------------------------------