├── .gitignore
├── image.npy
├── video.npy
├── material
│   ├── 9707366.jpg
│   ├── shancun.JPG
│   ├── 9707366-1.jpg
│   └── FM190311-10.mp4
├── images
│   ├── test_img_1.jpg
│   ├── test_img_2.jpg
│   ├── theta_dist.png
│   ├── video_match.JPG
│   └── cosine_similarity_vs_time.png
├── config.py
├── list_names.py
├── extract_image.py
├── README.md
├── utils.py
├── LICENSE
├── detect.py
├── demo.py
├── gen_features.py
├── extract_feature.py
└── models.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__/

--------------------------------------------------------------------------------
/image.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/image.npy

--------------------------------------------------------------------------------
/video.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/video.npy

--------------------------------------------------------------------------------
/material/9707366.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/material/9707366.jpg

--------------------------------------------------------------------------------
/material/shancun.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/material/shancun.JPG

--------------------------------------------------------------------------------
/images/test_img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/images/test_img_1.jpg

--------------------------------------------------------------------------------
/images/test_img_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/images/test_img_2.jpg

--------------------------------------------------------------------------------
/images/theta_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/images/theta_dist.png

--------------------------------------------------------------------------------
/images/video_match.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/images/video_match.JPG

--------------------------------------------------------------------------------
/material/9707366-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/material/9707366-1.jpg

--------------------------------------------------------------------------------
/material/FM190311-10.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/material/FM190311-10.mp4

--------------------------------------------------------------------------------
/images/cosine_similarity_vs_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Video-Matching/HEAD/images/cosine_similarity_vs_time.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # sets device for model and PyTorch tensors
# print('device: ' + str(device))

im_size = 224
# Data parameters
num_classes = 9935

pickle_file = 'video_index.pkl'

--------------------------------------------------------------------------------
/list_names.py:
--------------------------------------------------------------------------------
import os

if __name__ == "__main__":
    files = [f for f in os.listdir('video') if f.endswith('.mp4')]
    print('num_files: ' + str(len(files)))

    folder = 'cache'
    if not os.path.isdir(folder):
        os.makedirs(folder)

    print('building index...')
    i = 0
    frames = []
    for file in files:
        filename = os.path.join('video', file)
        # strip the 3-character prefix, then keep the first two '-'-separated tokens as the name
        file = file[3:]
        tokens = file.split('-')
        name = tokens[0] + '-' + tokens[1]
        print(name)

--------------------------------------------------------------------------------
/extract_image.py:
--------------------------------------------------------------------------------
import os

import cv2 as cv
from tqdm import tqdm

if __name__ == "__main__":
    files = [f for f in os.listdir('video') if f.endswith('.mp4')]
    print('num_files: ' + str(len(files)))

    folder = 'v_images'
    if not os.path.isdir(folder):
        os.makedirs(folder)

    idx = 0
    for file in tqdm(files):
        filename = os.path.join('video', file)
        print(filename)

        cap = cv.VideoCapture(filename)
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            cv.imwrite('v_images/{}.jpg'.format(idx), frame)
            idx = idx + 1

        cap.release()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Video Matching

## Principle

![image](https://github.com/foamliu/Video-Matching/raw/master/images/video_match.JPG)

## Algorithm pipeline

1. Compute a feature for every frame of the video material. Taking the 妙可蓝多 ad as an example: at a frame rate of 25 fps, the 10-second ad has 251 frames, each with a 512-dimensional feature.
2. Stack the per-frame features into a matrix mat (frames × dimensions).
3. Compute the feature vector feature of the inspection photo.
4. Compute: cosine = np.dot(mat, feature)
5. Compute: max_index = np.argmax(cosine)
6. Compute: max_value = cosine[max_index]
7. Compute: theta = math.acos(max_value)
8. Threshold test: if theta < threshold, return OK and max_index (see the sketch below).
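
Put together, the eight steps above reduce to a few lines of NumPy. This is a minimal sketch (the function name `match` and its default arguments are illustrative, not part of the repo); it assumes `mat` and `feature` are already L2-normalized, as produced by extract_feature.py:

```python
import math

import numpy as np


def match(mat, feature, threshold=25.50393648495902, fps=25.0):
    """Match one photo feature against per-frame video features.

    mat: (num_frames, 512) array of L2-normalized frame features.
    feature: (512,) L2-normalized photo feature.
    Returns (is_match, frame_index, time_in_seconds, theta_in_degrees).
    """
    cosine = np.clip(np.dot(mat, feature), -1, 1)        # cosine similarity per frame
    max_index = int(np.argmax(cosine))                   # best-matching frame
    theta = math.degrees(math.acos(cosine[max_index]))   # angle between the two features
    return theta < threshold, max_index, max_index / fps, theta
```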

## Plots

Cosine similarity plotted against the time of each frame:

![image](https://github.com/foamliu/Video-Matching/raw/master/images/cosine_similarity_vs_time.png)

## Threshold

Threshold: 25.50393648495902

![image](https://github.com/foamliu/Video-Matching/raw/master/images/theta_dist.png)

## Results

|Property|Value|
|---|---|
|Video frame count|251|
|max (cosine similarity)|0.948707|
|theta (degrees)|18.43065670378278|
|theta threshold|25.50393648495902|
|Match|Yes (*)|
|Confidence|0.9967838283534795|
|Match position (frame)|82|
|Match position (seconds)|3.28|

(*) The match is accepted because theta is below the threshold.

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from PIL import Image
from scipy.stats import norm
from torchvision import transforms

from config import device

# model params: theta (in degrees) is modeled as a mixture of two Gaussians,
# N(mu_0, sigma_0) for non-matches and N(mu_1, sigma_1) for matches
threshold = 25.50393648495902
mu_0 = 46.1028
sigma_0 = 6.4981
mu_1 = 9.6851
sigma_1 = 3.060

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
transformer = data_transforms['val']


def get_image(img):
    img = img[..., ::-1]  # BGR (OpenCV) -> RGB
    img = Image.fromarray(img, 'RGB')
    img = transformer(img)
    return img.to(device)


def get_prob(theta):
    # posterior probability that theta comes from the "match" Gaussian
    prob_0 = norm.pdf(theta, mu_0, sigma_0)
    prob_1 = norm.pdf(theta, mu_1, sigma_1)
    total = prob_0 + prob_1
    return prob_1 / total
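
A quick check of get_prob against the numbers in the README's results table (a standalone snippet; the theta value is copied from the table above):

```python
from utils import get_prob, threshold

theta = 18.43065670378278    # angle reported in the results table
print(theta < threshold)     # True -> accepted as a match
print(get_prob(theta))       # ~0.9967838283534795, the reported confidence
```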
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 刘杨

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
import math

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np

im_size = 224


def get_frame_list():
    video = 'material/FM190311-10.mp4'
    cap = cv.VideoCapture(video)
    frame_list = []
    print('collecting frames...')
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        frame = cv.resize(frame, (im_size, im_size))
        frame_list.append(frame)
    cap.release()
    frame_count = len(frame_list)
    print('frame_count: ' + str(frame_count))
    return frame_list


if __name__ == "__main__":
    mat = np.load('video.npy')
    feature = np.load('image.npy')
    print(mat.shape)
    # print(feature)
    print(feature.shape)
    frame_count = mat.shape[0]
    cosine = np.dot(mat, feature)
    cosine = np.clip(cosine, -1, 1)
    print(cosine.shape)
    max_index = np.argmax(cosine)
    max_value = cosine[max_index]
    print(max_index)
    print(max_value)

    threshold = 50
    theta = math.acos(max_value)
    theta = theta * 180 / math.pi  # radians -> degrees
    print(theta)
    print(theta < threshold)

    fps = 25.
    time = 1 / fps * max_index
    print(time)

    x = np.linspace(0, frame_count / fps, frame_count)
    plt.plot(x, cosine)
    plt.show()

    frame_list = get_frame_list()
    matched_frame = frame_list[max_index]
    cv.imwrite('match.jpg', matched_frame)

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import math
import pickle

import numpy as np
import torch
import cv2 as cv

from config import device, pickle_file, im_size
from utils import get_image, get_prob

if __name__ == "__main__":
    with open(pickle_file, 'rb') as file:
        frames = pickle.load(file)

    num_frames = len(frames)
    features = np.empty((num_frames, 512), dtype=np.float32)
    name_list = []
    idx_list = []
    fps_list = []

    for i, frame in enumerate(frames):
        name = frame['name']
        feature = frame['feature']
        fps = frame['fps']
        idx = frame['idx']
        features[i] = feature
        name_list.append(name)
        idx_list.append(idx)
        fps_list.append(fps)

    print(features.shape)
    assert (len(name_list) == num_frames)

    checkpoint = 'BEST_checkpoint.tar'
    print('loading model: {}...'.format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model = checkpoint['model']
    model = model.to(device)
    model.eval()

    test_fn = 'images/test_img_1.jpg'
    img = cv.imread(test_fn)
    img = cv.resize(img, (im_size, im_size))
    img = get_image(img)
    imgs = torch.zeros([1, 3, im_size, im_size], dtype=torch.float)
    imgs[0] = img
    imgs = imgs.to(device)  # keep the input on the same device as the model
    with torch.no_grad():
        output = model(imgs)
        feature = output[0].cpu().numpy()
        x = feature / np.linalg.norm(feature)

    cosine = np.dot(features, x)
    cosine = np.clip(cosine, -1, 1)
    print('cosine.shape: ' + str(cosine.shape))
    max_index = int(np.argmax(cosine))
    max_value = cosine[max_index]
    print('max_index: ' + str(max_index))
    print('max_value: ' + str(max_value))
    print('name: ' + name_list[max_index])
    print('fps: ' + str(fps_list[max_index]))
    print('idx: ' + str(idx_list[max_index]))

    theta = math.acos(max_value)
    theta = theta * 180 / math.pi  # radians -> degrees
    print('theta: ' + str(theta))
    prob = get_prob(theta)
    print('prob: ' + str(prob))
--------------------------------------------------------------------------------
/gen_features.py:
--------------------------------------------------------------------------------
import json
import os
import pickle

import cv2 as cv
import numpy as np
import torch
from tqdm import tqdm

from config import device
from config import im_size
from utils import get_image

if __name__ == "__main__":
    files = [f for f in os.listdir('video') if f.endswith('.mp4')]
    print('num_files: ' + str(len(files)))

    folder = 'cache'
    if not os.path.isdir(folder):
        os.makedirs(folder)

    print('building index...')
    i = 0
    frames = []
    for file in tqdm(files):
        filename = os.path.join('video', file)
        # strip the 3-character prefix, then keep the first two '-'-separated tokens as the name
        file = file[3:]
        tokens = file.split('-')
        name = tokens[0] + '-' + tokens[1]

        cap = cv.VideoCapture(filename)
        fps = cap.get(cv.CAP_PROP_FPS)
        frame_idx = 0
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            frame_info = dict()
            frame_info['name'] = name
            frame_info['idx'] = frame_idx
            frame_info['fps'] = fps
            image_fn = os.path.join(folder, str(i) + '.jpg')
            cv.imwrite(image_fn, frame)
            frame_info['image_fn'] = image_fn
            frames.append(frame_info)
            frame_idx += 1
            i += 1

        cap.release()

    with open('video_index.json', 'w') as file:
        json.dump(frames, file, ensure_ascii=False, indent=4)

    num_frames = len(frames)
    print('num_frames: ' + str(num_frames))
    assert (i == num_frames)

    checkpoint = 'BEST_checkpoint.tar'
    print('loading model: {}...'.format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model = checkpoint['model'].module.to(device)
    model.eval()

    print('generating features...')
    with torch.no_grad():
        for frame in tqdm(frames):
            image_fn = frame['image_fn']
            img = cv.imread(image_fn)
            img = cv.resize(img, (im_size, im_size))
            img = get_image(img)
            imgs = torch.zeros([1, 3, im_size, im_size], dtype=torch.float)
            imgs[0] = img
            imgs = imgs.to(device)
            output = model(imgs)
            feature = output[0].cpu().numpy()
            feature = feature / np.linalg.norm(feature)
            frame['feature'] = feature

    with open('video_index.pkl', 'wb') as file:
        pickle.dump(frames, file)

--------------------------------------------------------------------------------
/extract_feature.py:
--------------------------------------------------------------------------------
import time

import cv2 as cv
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from config import im_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # sets device for model and PyTorch tensors

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

checkpoint = 'BEST_checkpoint.tar'
print('loading model: {}...'.format(checkpoint))
checkpoint = torch.load(checkpoint)
model = checkpoint['model'].to(device)
model.eval()
transformer = data_transforms['val']


def get_image(img, transformer):
    img = img[..., ::-1]  # BGR (OpenCV) -> RGB
    img = Image.fromarray(img, 'RGB')
    img = transformer(img)
    return img.to(device)


def gen_feature(frame_list):
    file_count = len(frame_list)
    batch_size = 128
    ret_mat = np.empty((file_count, 512), np.float32)

    with torch.no_grad():
        for start_idx in tqdm(range(0, file_count, batch_size)):
            end_idx = min(file_count, start_idx + batch_size)
            length = end_idx - start_idx

            imgs = torch.zeros([length, 3, im_size, im_size], dtype=torch.float)
            for idx in range(0, length):
                i = start_idx + idx
                imgs[idx] = get_image(frame_list[i], transformer)

            features = model(imgs.to(device)).cpu().numpy()
            for idx in range(0, length):
                feature = features[idx]
                feature = feature / np.linalg.norm(feature)  # L2-normalize so dot product = cosine
                i = start_idx + idx
                ret_mat[i] = feature

    return ret_mat


if __name__ == "__main__":
    video = 'material/FM190311-10.mp4'
    image = 'material/shancun.JPG'

    cap = cv.VideoCapture(video)

    frame_list = []

    print('collecting frames...')
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        frame = cv.resize(frame, (im_size, im_size))
        frame_list.append(frame)
    cap.release()
    frame_count = len(frame_list)
    print('frame_count: ' + str(frame_count))

    print('generating features...')
    start = time.time()
    mat = gen_feature(frame_list)
    np.save('video', mat)
    end = time.time()
    elapsed = end - start
    elapsed_per_frame = elapsed / frame_count
    print('elapsed: ' + str(elapsed))
    print('elapsed_per_frame: ' + str(elapsed_per_frame))

    img = cv.imread(image)
    img = cv.resize(img, (im_size, im_size))
    img_list = [img]
    mat = gen_feature(img_list)
    np.save('image', mat[0])

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import Parameter
from torchsummary import summary

from config import device, num_classes

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 14 * 14, 512)
        self.bn3 = nn.BatchNorm1d(512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)

        return x


def resnet18(args, **kwargs):
    model = ResNet(IRBlock, [2, 2, 2, 2], use_se=args.use_se, **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(args, **kwargs):
    model = ResNet(IRBlock, [3, 4, 6, 3], use_se=args.use_se, **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(args, **kwargs):
    model = ResNet(IRBlock, [3, 4, 6, 3], use_se=args.use_se, **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(args, **kwargs):
    model = ResNet(IRBlock, [3, 4, 23, 3], use_se=args.use_se, **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(args, **kwargs):
    model = ResNet(IRBlock, [3, 8, 36, 3], use_se=args.use_se, **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model


def resnet_face18(use_se=True, **kwargs):
    model = ResNet(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)
    return model


class ArcMarginModel(nn.Module):
    def __init__(self, args):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # apply the margin only to the true class
        output *= self.s
        return output


if __name__ == "__main__":
    # parse_args() is not defined in this repo; build a minimal argparse
    # stand-in so the summary below can run
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--use-se', dest='use_se', action='store_true', default=True)
    args = parser.parse_args()

    model = resnet50(args).to(device)
    summary(model, (3, 224, 224))
--------------------------------------------------------------------------------
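
models.py demonstrates the backbone in its `__main__` block but not ArcMarginModel, the ArcFace margin head used during training. A minimal smoke test follows (a sketch: the args values are illustrative stand-ins, since the actual training configuration is not part of this repo):

```python
import argparse

import torch

from config import device, num_classes
from models import ArcMarginModel

# emb_size matches the 512-d features used throughout the repo; margin_m/margin_s
# are common ArcFace defaults, assumed here for illustration
args = argparse.Namespace(emb_size=512, easy_margin=False, margin_m=0.5, margin_s=64.0)
head = ArcMarginModel(args).to(device)

embeddings = torch.randn(4, 512, device=device)            # a batch of 4 frame features
labels = torch.randint(0, num_classes, (4,), device=device)
logits = head(embeddings, labels)                          # scaled cos(theta + m) logits
print(logits.shape)                                        # torch.Size([4, 9935])
```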