├── .gitignore
├── LICENSE
├── README.md
├── creat_DukeV_database.py
├── create_MARS_database.py
├── evaluate.py
├── fig
│   ├── NVAN.jpg
│   └── STE-NVAN.jpg
├── net
│   ├── models.py
│   └── resnet.py
├── parser.py
├── run_NL.sh
├── run_baseline.sh
├── run_evaluate.sh
├── train_NL.py
├── train_baseline.py
└── util
    ├── cmc.py
    ├── loss.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | MARS_database/ 2 | DukeV_database/ 3 | __pycache__/ 4 | ckpt*/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Chih-Ting, Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatially and Temporally Efficient Non-local Attention Network for Video-based Person Re-Identification 2 | - **NVAN** 3 |
<img src="fig/NVAN.jpg">
4 | 5 | - **STE-NVAN** 6 |
<img src="fig/STE-NVAN.jpg">
7 | 8 | [[Paper]](http://media.ee.ntu.edu.tw/research/STE_NVAN/BMVC19_STE_NVAN_cam.pdf) [[arXiv]](https://arxiv.org/abs/1908.01683) 9 | 10 | [Chih-Ting Liu](https://jackie840129.github.io/), Chih-Wei Wu, [Yu-Chiang Frank Wang](http://vllab.ee.ntu.edu.tw/members.html) and [Shao-Yi Chien](http://www.ee.ntu.edu.tw/profile?id=101),
British Machine Vision Conference (**BMVC**), 2019 11 | 12 | This is the PyTorch implementation of the Spatially and Temporally Efficient Non-local Video Attention Network **(STE-NVAN)** for video-based person Re-ID. 13 |
It achieves **90.0%** (NVAN) and **88.9%** (ST-efficient STE-NVAN) rank-1 accuracy on the MARS dataset. 14 | 15 | ## News ## 16 | 17 | **`2021-06-13`**: We will update this repo to a new version, similar to our new work [CF-AAN](https://github.com/jackie840129/CF-AAN)! 18 | 19 | ## Prerequisites 20 | - Python 3.5+ 21 | - [Pytorch](https://pytorch.org/) (We run the code under version 1.0.) 22 | - torchvision (We run the code under version 0.2.2) 23 | 24 | ## Getting Started 25 | 26 | ### Installation 27 | - Install the dependencies with: 28 | ``` 29 | $ pip3 install numpy Pillow progressbar2 tqdm pandas 30 | ``` 31 | 32 | ### Datasets 33 | We conduct experiments on the [MARS](http://www.liangzheng.com.cn/Project/project_mars.html) and [DukeMTMC-VideoReID](https://github.com/Yu-Wu/DukeMTMC-VideoReID) (DukeV) datasets. 34 | 35 | **For the MARS dataset:** 36 | - Download and unzip the dataset from the official website. ([Google Drive](https://drive.google.com/drive/u/1/folders/0B6tjyrV1YrHeMVV2UFFXQld6X1E)) 37 | - Clone the [MARS-evaluation](https://github.com/liangzheng06/MARS-evaluation) repo. We will need the files under its **info/** directory. 38 |
You will have the following structure: 39 | ``` 40 | path/to/your/MARS dataset/ 41 | |-- bbox_train/ 42 | |-- bbox_test/ 43 | |-- MARS-evaluation/ 44 | | |-- info/ 45 | ``` 46 | - Run `create_MARS_database.py` to create the database files (.txt and .npy) in the `MARS_database` directory: 47 | ``` 48 | $ python3 create_MARS_database.py --data_dir /path/to/MARS dataset/ \ 49 | --info_dir /path/to/MARS dataset/MARS-evaluation/info/ \ 50 | --output_dir ./MARS_database/ 51 | ``` 52 | 53 | **For the DukeV dataset:** 54 | - Download and unzip the dataset from the official GitHub page. ([data link](http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip)) 55 |
You will have the following structure: 56 | ``` 57 | path/to/your/DukeV dataset/ 58 | |-- gallery/ 59 | |-- query/ 60 | |-- train/ 61 | ``` 62 | - Run `creat_DukeV_database.py` to create the database files (.txt and .npy) in the `DukeV_database` directory: 63 | ``` 64 | $ python3 creat_DukeV_database.py --data_dir /path/to/DukeV dataset/ \ 65 | --output_dir ./DukeV_database/ 66 | ``` 67 | ## Usage-Testing 68 | We re-implemented the evaluation code of [MARS-evaluation](https://github.com/liangzheng06/MARS-evaluation) in Python. 69 | 70 | Furthermore, we follow the video-based evaluation metric of this [paper](https://zpascal.net/cvpr2018/Li_Diversity_Regularized_Spatiotemporal_CVPR_2018_paper.pdf). 71 | 72 | In detail, each tracklet is split into `S` chunks and the first frame of each chunk is sampled (see the sketch at the end of this README). 73 | 74 | ### Prerequisite 75 | For testing, we provide three models trained on the **MARS** dataset at this [**link**](https://drive.google.com/drive/folders/1yi4RJHhu8iMtewdnWYpLCLkIi0okjl35?usp=sharing). 76 | 77 | First create a directory with `$ mkdir ckpt` and put the three models under it. 78 | 79 | All three execution commands are in the script `run_evaluate.sh`. 80 | Check and alter the arguments inside, then run 81 | ``` 82 | $ sh run_evaluate.sh 83 | ``` 84 | to obtain the rank-1 accuracy and the mAP score. 85 | 86 | Some scores differ from those in the paper because some models were lost with my previous computer. (I have retrained them.) 87 | 88 | The evaluation commands for the three models are as follows. 89 | 90 | ### Baseline model : Resnet50 + FPL (mean) 91 | Uncomment this part. You will get R1=87.42% and mAP=79.44%. 92 | ``` 93 | # Evaluate ResNet50 + FPL (mean or max) 94 | LOAD_CKPT=./ckpt/R50_baseline_mean.pth 95 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 96 | --batch_size 64 --model_type 'resnet50_s1' --num_workers 8 --S 8 \ 97 | --latent_dim 2048 --temporal mean --stride 1 --load_ckpt $LOAD_CKPT 98 | ``` 99 | ### NVAN : R50 + 5 Non-local layers + FPL 100 | Uncomment this part. You will get R1=90.00% and mAP=82.79%. 101 | ``` 102 | # Evaluate NVAN (R50 + 5 NL + FPL) 103 | LOAD_CKPT=./ckpt/NVAN.pth 104 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 105 | --batch_size 64 --model_type 'resnet50_NL' --num_workers 8 --S 8 --latent_dim 2048 \ 106 | --temporal Done --non_layers 0 2 3 0 --load_ckpt $LOAD_CKPT \ 107 | ``` 108 | ### STE-NVAN : NVAN + Spatial Reduction + Temporal Reduction 109 | Uncomment this part. You will get R1=88.69% and mAP=81.27%. 110 | ``` 111 | # Evaluate STE-NVAN (R50 + 5 NL + Stripe + Hierarchical + FPL) 112 | LOAD_CKPT=./ckpt/STE_NVAN.pth 113 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 114 | --batch_size 128 --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --S 8 --latent_dim 2048 \ 115 | --temporal Done --non_layers 0 2 3 0 --stripes 16 16 16 16 --load_ckpt $LOAD_CKPT \ 116 | ``` 117 | 118 | ## Usage-Training 119 | As mentioned in our paper, we have three kinds of models (Baseline, NVAN, STE-NVAN). 120 |
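All three training commands below share the same batch construction: each batch holds `class_per_batch` identities, `track_per_class` tracklets per identity, and `S` frames per tracklet (8 x 4 x 8 = 256 images with the script settings). The following is a minimal sketch of that PK-style layout with a hypothetical `sample_pk_batch` helper; the repo's actual sampler lives in `util/utils.py`.
```
import random

def sample_pk_batch(tracks_by_pid, class_per_batch=8, track_per_class=4, S=8):
    """Hypothetical sketch of the PK-style training batch used by train_NL.py:
    P identities x K tracklets, each tracklet reduced to S frames
    (one frame per equal chunk). The real sampler is in util/utils.py."""
    batch = []
    for pid in random.sample(sorted(tracks_by_pid), class_per_batch):
        for _ in range(track_per_class):
            track = random.choice(tracks_by_pid[pid])  # a list of frame paths
            chunk = max(len(track) // S, 1)
            # one frame per chunk; clamp indices for tracklets shorter than S
            frames = [track[min(i * chunk, len(track) - 1)] for i in range(S)]
            batch.append((pid, frames))
    return batch  # 8 identities x 4 tracklets = 32 tracklets of S frames each
```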
121 | ### Baseline model : Resnet50 + FPL (mean) 122 | You can alter the arguments in `run_baseline.sh` or just use this command: 123 | ``` 124 | $ sh run_baseline.sh 125 | ``` 126 | ### NVAN : R50 + 5 Non-local layers + FPL 127 | You can alter the arguments or uncomment this part in `run_NL.sh`: 128 | ``` 129 | # For NVAN 130 | CKPT=ckpt_NL_0230 131 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \ 132 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 133 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \ 134 | --model_type 'resnet50_NL' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \ 135 | --non_layers 0 2 3 0 136 | ``` 137 | Then run this script. 138 | ``` 139 | $ sh run_NL.sh 140 | ``` 141 | ### STE-NVAN : NVAN + Spatial Reduction + Temporal Reduction 142 | You can alter the arguments or uncomment this part in `run_NL.sh`: 143 | ``` 144 | # For STE-NVAN 145 | CKPT=ckpt_NL_stripe16_hr_0230 146 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \ 147 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 148 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \ 149 | --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \ 150 | --non_layers 0 2 3 0 --stripes 16 16 16 16 151 | ``` 152 | Then run this script. 153 | ``` 154 | $ sh run_NL.sh 155 | ``` 156 | 157 | ## Citation 158 | ``` 159 | @inproceedings{liu2019spatially, 160 | title={Spatially and Temporally Efficient Non-local Attention Network for Video-based Person Re-Identification}, 161 | author={Liu, Chih-Ting and Wu, Chih-Wei and Wang, Yu-Chiang Frank and Chien, Shao-Yi}, 162 | booktitle={British Machine Vision Conference}, 163 | year={2019} 164 | } 165 | ``` 166 | ## Reference 167 | 168 | Chih-Ting Liu, [Media IC & System Lab](https://github.com/mediaic), National Taiwan University 169 | 170 | E-mail : jackieliu@media.ee.ntu.edu.tw 171 |
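The test-time frame sampling referenced in the Usage-Testing section works as follows: each tracklet is cut into `S` equal chunks and only the first frame of every chunk is kept. Below is a minimal sketch of that rule (a hypothetical helper; the actual test loader is `utils.Get_Video_test_DataLoader` in `util/utils.py`).
```
def first_frame_per_chunk(track_len, S=8):
    """Return the S frame offsets evaluated at test time: split a tracklet of
    track_len frames into S equal chunks and keep each chunk's first frame
    (frames are reused when the tracklet is shorter than S). Sketch only."""
    if track_len < S:
        return [min(i, track_len - 1) for i in range(S)]
    chunk = track_len // S
    return [i * chunk for i in range(S)]

# e.g. a 20-frame tracklet with S=4 keeps frames 0, 5, 10 and 15
print(first_frame_per_chunk(20, S=4))
```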
-------------------------------------------------------------------------------- /creat_DukeV_database.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 9 | ] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_dir',help='path/to/DukeV/') 17 | parser.add_argument('--output_dir',help='path/to/save/database/',default='./DukeV_database') 18 | args = parser.parse_args() 19 | 20 | os.system('mkdir -p %s'%(args.output_dir)) 21 | # Read images 22 | # Train 23 | train_imgs_path = [] 24 | infos = [] 25 | count = 0 26 | data_dir = os.path.join(args.data_dir,'train') 27 | ids = sorted(os.listdir(data_dir)) 28 | for id in ids: 29 | tracks = sorted(os.listdir(os.path.join(data_dir,id))) 30 | for track in tracks: 31 | info = [] 32 | images = sorted(os.listdir(os.path.join(data_dir,id,track))) 33 | info.append(count) 34 | info.append(count+len(images)-1) 35 | info.append(int(id)) 36 | count = count+len(images) 37 | for image in images: 38 | if is_image_file(image): 39 | _,cam,_,_ = image.split('_') 40 | train_imgs_path.append(os.path.abspath(os.path.join(data_dir,id,track,image))) 41 | info.append(int(cam[1:])) 42 | infos.append(info) 43 | train_imgs_path = np.array(train_imgs_path) 44 | np.savetxt(os.path.join(args.output_dir,'train_path.txt'),train_imgs_path,fmt='%s',delimiter='\n') 45 | np.save(os.path.join(args.output_dir,'train_info.npy'),np.array(infos)) 46 | 47 | query_info = [] 48 | data_dir = os.path.join(args.data_dir,'query') 49 | ids = sorted(os.listdir(data_dir)) 50 | for id in ids: 51 | tracks = sorted(os.listdir(os.path.join(data_dir,id))) 52 | for track in tracks: 53 | query_info.append([id,track]) 54 | # Test 55 | gallery_imgs_path = [] 56 | track_idx = [] 57 | idx = 0 58 | infos = [] 59 | count = 0 60 | data_dir = os.path.join(args.data_dir,'gallery') 61 | ids = sorted(os.listdir(data_dir)) 62 | for id in ids: 63 | tracks = sorted(os.listdir(os.path.join(data_dir,id))) 64 | for track in tracks: 65 | if [id,track] == query_info[0]: 66 | track_idx.append(idx) 67 | del query_info[0] 68 | info = [] 69 | images = sorted(os.listdir(os.path.join(data_dir,id,track))) 70 | info.append(count) 71 | info.append(count+len(images)-1) 72 | info.append(int(id)) 73 | count = count+len(images) 74 | for image in images: 75 | if is_image_file(image): 76 | _,cam,_,_ = image.split('_') 77 | gallery_imgs_path.append(os.path.abspath(os.path.join(data_dir,id,track,image))) 78 | info.append(int(cam[1:])) 79 | infos.append(info) 80 | idx +=1 81 | gallery_imgs_path = np.array(gallery_imgs_path) 82 | np.savetxt(os.path.join(args.output_dir,'gallery_path.txt'),gallery_imgs_path,fmt='%s',delimiter='\n') 83 | np.save(os.path.join(args.output_dir,'gallery_info.npy'),np.array(infos)) 84 | np.save(os.path.join(args.output_dir,'query_IDX.npy'),np.array(track_idx)) 85 | 86 | -------------------------------------------------------------------------------- /create_MARS_database.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 9 | ] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_dir',help='path/to/MARS/') 17 | parser.add_argument('--info_dir',help='path/to/MARS-evaluation/info/') 18 | parser.add_argument('--output_dir',help='path/to/save/database',default='./MARS_database') 19 | args = parser.parse_args() 20 | 21 | os.system('mkdir -p %s'%(args.output_dir)) 22 | # Train 23 | train_imgs = [] 24 | data_dir = os.path.join(args.data_dir,'bbox_train') 25 | ids = sorted(os.listdir(data_dir)) 26 | for id in ids: 27 | images = sorted(os.listdir(os.path.join(data_dir,id))) 28 | for image in images: 29 | if is_image_file(image): 30 | train_imgs.append(os.path.abspath(os.path.join(data_dir,id,image))) 31 | train_imgs = np.array(train_imgs) 32 | np.savetxt(os.path.join(args.output_dir,'train_path.txt'),train_imgs,fmt='%s',delimiter='\n') 33 | # Test 34 | test_imgs = [] 35 | data_dir = os.path.join(args.data_dir,'bbox_test') 36 | ids = sorted(os.listdir(data_dir)) 37 | for id in ids: 38 | images = sorted(os.listdir(os.path.join(data_dir,id))) 39 | for image in images: 40 | if is_image_file(image): 41 |
test_imgs.append(os.path.abspath(os.path.join(data_dir,id,image))) 42 | test_imgs = np.array(test_imgs) 43 | np.savetxt(os.path.join(args.output_dir,'test_path.txt'),test_imgs,fmt='%s',delimiter='\n') 44 | 45 | ## process matfile 46 | train_info = sio.loadmat(os.path.join(args.info_dir,'tracks_train_info.mat'))['track_train_info'] 47 | test_info = sio.loadmat(os.path.join(args.info_dir,'tracks_test_info.mat'))['track_test_info'] 48 | query_IDX = sio.loadmat(os.path.join(args.info_dir,'query_IDX.mat'))['query_IDX'] 49 | 50 | # start from 0 (matlab starts from 1) 51 | train_info[:,0:2] = train_info[:,0:2]-1 52 | test_info[:,0:2] = test_info[:,0:2]-1 53 | query_IDX = query_IDX -1 54 | np.save(os.path.join(args.output_dir,'train_info.npy'),train_info) 55 | np.save(os.path.join(args.output_dir,'test_info.npy'),test_info) 56 | np.save(os.path.join(args.output_dir,'query_IDX.npy'),query_IDX) 57 | 58 |
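Both database scripts above emit the same two artifacts: a `*_path.txt` file with one absolute image path per line, and a `*_info.npy` array in which each row describes one tracklet as `[start_index, end_index, person_id, camera_id]` into that path list (for MARS, after the MATLAB-to-Python index shift above). A quick sanity-check sketch, assuming the default `./MARS_database/` output directory:
```
import numpy as np

# Sanity-check sketch (assumes the default ./MARS_database/ output directory).
paths = np.genfromtxt('./MARS_database/test_path.txt', dtype=str)
info = np.load('./MARS_database/test_info.npy')
query_idx = np.load('./MARS_database/query_IDX.npy')

start, end, pid, cam = info[0]  # one row = one tracklet
print('tracklet 0: id=%d cam=%d, %d frames' % (pid, cam, end - start + 1))
print('first frame:', paths[start])
print('%d query tracklets out of %d' % (len(query_idx.ravel()), len(info)))
```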
-------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from util import utils 2 | from util.cmc import Video_Cmc 3 | from net import models 4 | import parser 5 | import sys 6 | import random 7 | from tqdm import tqdm 8 | import numpy as np 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize 14 | import torch.backends.cudnn as cudnn 15 | cudnn.benchmark=True 16 | import os 17 | os.environ['CUDA_VISIBLE_DEVICES']='0' 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | 20 | def validation(network,dataloader,args): 21 | network.eval() 22 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True) 23 | pbar.set_description('Inference') 24 | gallery_features = [] 25 | gallery_labels = [] 26 | gallery_cams = [] 27 | with torch.no_grad(): 28 | for c,data in enumerate(dataloader): 29 | seqs = data[0].cuda() 30 | label = data[1] 31 | cams = data[2] 32 | 33 | if args.model_type != 'resnet50_s1': 34 | B,C,H,W = seqs.shape 35 | seqs = seqs.reshape(B//args.S,args.S,C,H,W) 36 | feat = network(seqs)#.cpu().numpy() #[xx,128] 37 | if args.temporal == 'max': 38 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 39 | elif args.temporal == 'mean': 40 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1) 41 | elif args.temporal in ['Done'] : 42 | feat = feat 43 | 44 | gallery_features.append(feat.cpu()) 45 | gallery_labels.append(label) 46 | gallery_cams.append(cams) 47 | pbar.update(1) 48 | pbar.close() 49 | 50 | gallery_features = torch.cat(gallery_features,dim=0).numpy() 51 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy() 52 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy() 53 | 54 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000) 55 | network.train() 56 | 57 | return Cmc[0],mAP 58 | 59 | if __name__ == '__main__': 60 | #Parse args 61 | args = parser.parse_args() 62 | 63 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) 64 | print('Start dataloader...') 65 | num_class = 625 # number of training identities in MARS 66 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\ 67 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True) 68 | print('End dataloader...') 69 | 70 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,non_layers=args.non_layers,stripes=args.stripes,temporal=args.temporal).cuda()) 71 | 72 | if args.load_ckpt is None: 73 | print('No ckpt!') 74 | exit() 75 | else: 76 | state = torch.load(args.load_ckpt) 77 | network.load_state_dict(state,strict=True) 78 | 79 | 80 | cmc,map = validation(network,test_dataloader,args) 81 | 82 | print('CMC : %.4f , mAP : %.4f'%(cmc,map)) 83 | -------------------------------------------------------------------------------- /fig/NVAN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/STE-NVAN/3042bc8e4b4e5a7608123fcd121bde975e20fc50/fig/NVAN.jpg -------------------------------------------------------------------------------- /fig/STE-NVAN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/STE-NVAN/3042bc8e4b4e5a7608123fcd121bde975e20fc50/fig/STE-NVAN.jpg -------------------------------------------------------------------------------- /net/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | import net.resnet as res 6 | 7 | def weights_init_kaiming(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Linear') != -1: 10 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 11 | nn.init.constant_(m.bias, 0.0) 12 | elif classname.find('Conv') != -1: 13 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0.0) 16 | elif classname.find('BatchNorm') != -1: 17 | if m.affine: 18 | nn.init.constant_(m.weight, 1.0) 19 | nn.init.constant_(m.bias, 0.0) 20 | 21 | def weights_init_classifier(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('Linear') != -1: 24 | nn.init.normal_(m.weight, std=0.001) 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0.0) 27 | 28 | 29 | class Resnet50_NL(nn.Module): 30 | def __init__(self,non_layers=[0,1,1,1],stripes=[16,16,16,16],non_type='normal',temporal=None): 31 | super(Resnet50_NL,self).__init__() 32 | original = models.resnet50(pretrained=True).state_dict() 33 | if non_type == 'normal': 34 | self.backbone = res.ResNet_Video_nonlocal(last_stride=1,non_layers=non_layers) 35 | elif non_type == 'stripe': 36 | self.backbone = res.ResNet_Video_nonlocal_stripe(last_stride = 1, non_layers=non_layers, stripes=stripes) 37 | elif non_type == 'hr': 38 | self.backbone = res.ResNet_Video_nonlocal_hr(last_stride = 1, non_layers=non_layers, stripes=stripes) 39 | elif non_type == 'stripe_hr': 40 | self.backbone = res.ResNet_Video_nonlocal_stripe_hr(last_stride = 1, non_layers=non_layers, stripes=stripes) 41 | for key in original: 42 | if key.find('fc') != -1: 43 | continue 44 | self.backbone.state_dict()[key].copy_(original[key]) 45 | del original 46 | 47 | self.temporal = temporal 48 | if self.temporal == 'Done': 49 | self.avgpool = nn.AdaptiveAvgPool3d(1) 50 | 51 | def forward(self,x): 52 | if self.temporal == 'Done': 53 | x = self.backbone(x) 54 | x = self.avgpool(x) 55 | x = x.reshape(x.shape[0],-1) 56 | return x 57 | 58 | 59 | class Resnet50_s1(nn.Module): 60 | def __init__(self,pooling=True,stride=1): 61 | super(Resnet50_s1,self).__init__() 62 | original = models.resnet50(pretrained=True).state_dict() 63 | self.backbone = res.ResNet(last_stride=stride) 64 | for
key in original: 65 | if key.find('fc') != -1: 66 | continue 67 | self.backbone.state_dict()[key].copy_(original[key]) 68 | del original 69 | if pooling == True: 70 | self.add_module('avgpool',nn.AdaptiveAvgPool2d(1)) 71 | else: 72 | self.avgpool = None 73 | 74 | self.out_dim = 2048 75 | 76 | def forward(self,x): 77 | x = self.backbone(x) 78 | if self.avgpool is not None: 79 | x = self.avgpool(x) 80 | x = x.view(x.shape[0],-1) 81 | return x 82 | 83 | class CNN(nn.Module): 84 | def __init__(self,out_dim,model_type='resnet50_s1',num_class=710,non_layers=[1,2,2],stripes=[16,16,16,16], temporal = 'Done',stride=1): 85 | super(CNN,self).__init__() 86 | self.model_type = model_type 87 | if model_type == 'resnet50_s1': 88 | self.features = Resnet50_s1(stride=stride) 89 | elif model_type == 'resnet50_NL': 90 | self.features = Resnet50_NL(non_layers=non_layers,temporal=temporal,non_type='normal') 91 | elif model_type == 'resnet50_NL_stripe': 92 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='stripe') 93 | elif model_type == 'resnet50_NL_hr': 94 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='hr') 95 | elif model_type == 'resnet50_NL_stripe_hr': 96 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='stripe_hr') 97 | 98 | self.bottleneck = nn.BatchNorm1d(out_dim) 99 | self.bottleneck.bias.requires_grad_(False) # no shift 100 | self.bottleneck.apply(weights_init_kaiming) 101 | 102 | self.classifier = nn.Linear(out_dim,num_class, bias=False) 103 | self.classifier.apply(weights_init_classifier) 104 | 105 | def forward(self,x,seg=None): 106 | if self.model_type == 'resnet50_s1': 107 | x = self.features(x) 108 | bn = self.bottleneck(x) 109 | if self.training == True: 110 | output = self.classifier(bn) 111 | return x,output 112 | else: 113 | return bn 114 | elif self.model_type == 'resnet50_NL' or self.model_type == 'resnet50_NL_stripe' or \ 115 | self.model_type=='resnet50_NL_hr' or self.model_type == 'resnet50_NL_stripe_hr': 116 | x = self.features(x) 117 | bn = self.bottleneck(x) 118 | if self.training == True: 119 | output = self.classifier(bn) 120 | return x,output 121 | else: 122 | return bn 123 | 124 | if __name__ == '__main__': 125 | model = Resnet50_s1() 126 | input = torch.ones(1,3,256,128) 127 | output = model(input) 128 | print(output.shape) 129 | -------------------------------------------------------------------------------- /net/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import os 5 | import torch 6 | from torch import nn 7 | ##################### Small Block ################################### 8 | class NonLocalBlock(nn.Module): 9 | def __init__(self, in_channels, inter_channels=None,sub_sample=False, bn_layer=True,instance='soft'): 10 | super(NonLocalBlock, self).__init__() 11 | self.sub_sample = sub_sample 12 | self.instance = instance 13 | self.in_channels = in_channels 14 | self.inter_channels = inter_channels 15 | 16 | if self.inter_channels is None: 17 | self.inter_channels = in_channels // 2 18 | if self.inter_channels == 0: 19 | self.inter_channels = 1 20 | 21 | conv_nd = nn.Conv3d 22 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 23 | bn = nn.BatchNorm3d 24 | 25 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 26 | kernel_size=1, stride=1, padding=0) 27 | if bn_layer: 28 | 
self.W = nn.Sequential( 29 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 30 | kernel_size=1, stride=1, padding=0), 31 | bn(self.in_channels) 32 | ) 33 | nn.init.constant_(self.W[1].weight, 0) 34 | nn.init.constant_(self.W[1].bias, 0) 35 | else: 36 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | nn.init.constant_(self.W.weight, 0) 39 | nn.init.constant_(self.W.bias, 0) 40 | 41 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, stride=1, padding=0) 43 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 44 | kernel_size=1, stride=1, padding=0) 45 | if sub_sample: 46 | self.g = nn.Sequential(self.g, max_pool_layer) 47 | self.phi = nn.Sequential(self.phi, max_pool_layer) 48 | 49 | def forward(self, x): 50 | ''' 51 | :param x: (b, c, t, h, w) 52 | :return: 53 | ''' 54 | batch_size = x.size(0) 55 | 56 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 57 | g_x = g_x.permute(0, 2, 1) 58 | 59 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 60 | theta_x = theta_x.permute(0, 2, 1) 61 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 62 | f = torch.matmul(theta_x, phi_x) 63 | if self.instance == 'soft': 64 | f_div_C = F.softmax(f, dim=-1) 65 | elif self.instance == 'dot': 66 | f_div_C = f / f.shape[1] 67 | 68 | y = torch.matmul(f_div_C, g_x) 69 | y = y.permute(0, 2, 1).contiguous() 70 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 71 | W_y = self.W(y) 72 | z = W_y + x 73 | 74 | return z 75 | 76 | class Stripe_NonLocalBlock(nn.Module): 77 | def __init__(self,stripe,in_channels,inter_channels=None,pool_type='mean',instance='soft'): 78 | super(Stripe_NonLocalBlock,self).__init__() 79 | self.instance = instance 80 | self.stripe=stripe 81 | self.in_channels = in_channels 82 | self.pool_type = pool_type 83 | if pool_type == 'max': 84 | self.pool = nn.AdaptiveMaxPool2d(1) 85 | elif pool_type == 'mean': 86 | self.pool = nn.AdaptiveAvgPool2d(1) 87 | elif pool_type == 'meanmax': 88 | self.avgpool = nn.AdaptiveAvgPool2d(1) 89 | self.maxpool = nn.AdaptiveMaxPool2d(1) 90 | self.in_channels*=2 91 | if inter_channels == None: 92 | self.inter_channels = in_channels//2 93 | else: 94 | self.inter_channels = inter_channels 95 | 96 | self.g = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, 97 | kernel_size=1, stride=1, padding=0) 98 | self.theta = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, 99 | kernel_size=1, stride=1, padding=0) 100 | self.phi = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels, 101 | kernel_size=1, stride=1, padding=0) 102 | if pool_type == 'meanmax': 103 | self.in_channels //=2 104 | 105 | self.W = nn.Sequential( 106 | nn.Conv3d(in_channels=self.inter_channels, out_channels=self.in_channels, 107 | kernel_size=1, stride=1, padding=0), 108 | nn.BatchNorm3d(self.in_channels) 109 | ) 110 | nn.init.constant_(self.W[1].weight, 0) 111 | nn.init.constant_(self.W[1].bias, 0) 112 | 113 | def forward(self,x): 114 | # x.shape = (b,c,t,h,w) 115 | b,c,t,h,w = x.shape 116 | assert self.stripe * (h//self.stripe) == h 117 | 118 | if self.pool_type == 'meanmax': 119 | discri_a = self.avgpool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1) 120 | discri_m = self.maxpool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1) 121 | discri = 
torch.cat([discri_a,discri_m],dim=1) 122 | else: 123 | discri = self.pool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1) 124 | g = self.g(discri).reshape(b,self.inter_channels,-1) 125 | g = g.permute(0,2,1) 126 | theta = self.theta(discri).reshape(b, self.inter_channels, -1) 127 | theta = theta.permute(0,2,1) 128 | phi = self.phi(discri).reshape(b, self.inter_channels, -1) 129 | 130 | f = torch.matmul(theta, phi) 131 | if self.instance == 'soft': 132 | f_div_C = F.softmax(f, dim=-1) 133 | elif self.instance == 'dot': 134 | f_div_C = f / f.shape[1] 135 | 136 | y = torch.matmul(f_div_C, g) 137 | y = y.permute(0, 2, 1).contiguous() 138 | y = y.reshape(b, self.inter_channels, *discri.size()[2:]) 139 | W_y = self.W(y) 140 | 141 | W_y = W_y.repeat(1,1,1,1,h//self.stripe*w).reshape(b,c,t,h,w) 142 | 143 | z = W_y + x 144 | return z 145 | 146 | class Bottleneck(nn.Module): 147 | expansion = 4 148 | def __init__(self, inplanes, planes, stride=1, downsample=None): 149 | super(Bottleneck, self).__init__() 150 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 151 | self.bn1 = nn.BatchNorm2d(planes) 152 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 153 | padding=1, bias=False) 154 | self.bn2 = nn.BatchNorm2d(planes) 155 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 156 | self.bn3 = nn.BatchNorm2d(planes * 4) 157 | self.relu = nn.ReLU(inplace=True) 158 | self.downsample = downsample 159 | self.stride = stride 160 | 161 | def forward(self, x): 162 | residual = x 163 | out = self.conv1(x) 164 | out = self.bn1(out) 165 | out = self.relu(out) 166 | 167 | out = self.conv2(out) 168 | out = self.bn2(out) 169 | out = self.relu(out) 170 | 171 | out = self.conv3(out) 172 | out = self.bn3(out) 173 | 174 | if self.downsample is not None: 175 | residual = self.downsample(x) 176 | out += residual 177 | out = self.relu(out) 178 | 179 | return out 180 | ############################################################################## 181 | 182 | ############################ backbone model ################################## 183 | class ResNet(nn.Module): 184 | def __init__(self, last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3]): 185 | self.inplanes = 64 186 | super().__init__() 187 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 188 | bias=False) 189 | self.bn1 = nn.BatchNorm2d(64) 190 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 191 | self.layer1 = self._make_layer(block, 64, layers[0]) 192 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 193 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 194 | self.layer4 = self._make_layer( 195 | block, 512, layers[3], stride=last_stride) 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1): 198 | downsample = None 199 | if stride != 1 or self.inplanes != planes * block.expansion: 200 | downsample = nn.Sequential( 201 | nn.Conv2d(self.inplanes, planes * block.expansion, 202 | kernel_size=1, stride=stride, bias=False), 203 | nn.BatchNorm2d(planes * block.expansion), 204 | ) 205 | layers = [] 206 | layers.append(block(self.inplanes, planes, stride, downsample)) 207 | self.inplanes = planes * block.expansion 208 | for i in range(1, blocks): 209 | layers.append(block(self.inplanes, planes)) 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | x = self.conv1(x) 214 | x = self.bn1(x) 215 | x = self.maxpool(x) 216 | 217 | x = self.layer1(x) 218 | x = self.layer2(x) 219 | x = 
self.layer3(x) 220 | x = self.layer4(x) 221 | 222 | return x 223 | 224 | class ResNet_Video_nonlocal(nn.Module): 225 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1]): 226 | self.inplanes = 64 227 | super().__init__() 228 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) 229 | self.bn1 = nn.BatchNorm2d(64) 230 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 231 | self.layer1 = self._make_layer(block, 64, layers[0]) 232 | non_idx = 0 233 | self.NL_1 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2,sub_sample=True) for i in range(non_layers[non_idx])]) 234 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])]) 235 | non_idx += 1 236 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 237 | self.NL_2 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 238 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])]) 239 | non_idx += 1 240 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 241 | self.NL_3 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 242 | self.NL_3_idx =sorted( [layers[2]-(i+1) for i in range(non_layers[non_idx])]) 243 | non_idx += 1 244 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 245 | self.NL_4 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 246 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])]) 247 | 248 | def _make_layer(self, block, planes, blocks, stride=1): 249 | downsample = None 250 | if stride != 1 or self.inplanes != planes * block.expansion: 251 | downsample = nn.Sequential( 252 | nn.Conv2d(self.inplanes, planes * block.expansion, 253 | kernel_size=1, stride=stride, bias=False), 254 | nn.BatchNorm2d(planes * block.expansion), 255 | ) 256 | layers = [] 257 | layers.append(block(self.inplanes, planes, stride, downsample)) 258 | self.inplanes = planes * block.expansion 259 | for i in range(1, blocks): 260 | layers.append(block(self.inplanes, planes)) 261 | return nn.ModuleList(layers) 262 | 263 | def forward(self, x): 264 | # x 's shape (B,T,C,H,W) 265 | B,T,C,H,W = x.shape 266 | x = x.reshape(B*T,C,H,W) 267 | x = self.conv1(x) 268 | x = self.bn1(x) 269 | x = self.maxpool(x) 270 | 271 | # Layer 1 272 | NL1_counter = 0 273 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1] 274 | for i in range(len(self.layer1)): 275 | x = self.layer1[i](x) 276 | if i == self.NL_1_idx[NL1_counter]: 277 | _,C,H,W = x.shape 278 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 279 | x = self.NL_1[NL1_counter](x) 280 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 281 | NL1_counter+=1 282 | # Layer 2 283 | NL2_counter = 0 284 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1] 285 | for i in range(len(self.layer2)): 286 | x = self.layer2[i](x) 287 | if i == self.NL_2_idx[NL2_counter]: 288 | _,C,H,W = x.shape 289 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 290 | x = self.NL_2[NL2_counter](x) 291 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 292 | NL2_counter+=1 293 | # Layer 3 294 | NL3_counter = 0 295 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1] 296 | for i in range(len(self.layer3)): 297 | x = self.layer3[i](x) 298 | if i == self.NL_3_idx[NL3_counter]: 299 | _,C,H,W = x.shape 300 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 301 | x = self.NL_3[NL3_counter](x) 302 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 303 | 
NL3_counter+=1 304 | # Layer 4 305 | NL4_counter = 0 306 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1] 307 | for i in range(len(self.layer4)): 308 | x = self.layer4[i](x) 309 | if i == self.NL_4_idx[NL4_counter]: 310 | _,C,H,W = x.shape 311 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 312 | x = self.NL_4[NL4_counter](x) 313 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 314 | NL4_counter+=1 315 | _,C,H,W = x.shape 316 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 317 | # Return is (B,C,T,H,W) 318 | return x 319 | 320 | class ResNet_Video_nonlocal_stripe(nn.Module): 321 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]): 322 | self.inplanes = 64 323 | super().__init__() 324 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) 325 | self.bn1 = nn.BatchNorm2d(64) 326 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 327 | self.layer1 = self._make_layer(block, 64, layers[0]) 328 | non_idx = 0 329 | self.NL_1 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 330 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])]) 331 | non_idx += 1 332 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 333 | self.NL_2 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 334 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])]) 335 | non_idx += 1 336 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 337 | self.NL_3 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 338 | self.NL_3_idx =sorted( [layers[2]-(i+1) for i in range(non_layers[non_idx])]) 339 | non_idx += 1 340 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 341 | self.NL_4 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 342 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])]) 343 | 344 | def _make_layer(self, block, planes, blocks, stride=1): 345 | downsample = None 346 | if stride != 1 or self.inplanes != planes * block.expansion: 347 | downsample = nn.Sequential( 348 | nn.Conv2d(self.inplanes, planes * block.expansion, 349 | kernel_size=1, stride=stride, bias=False), 350 | nn.BatchNorm2d(planes * block.expansion), 351 | ) 352 | layers = [] 353 | layers.append(block(self.inplanes, planes, stride, downsample)) 354 | self.inplanes = planes * block.expansion 355 | for i in range(1, blocks): 356 | layers.append(block(self.inplanes, planes)) 357 | return nn.ModuleList(layers) 358 | 359 | def forward(self, x): 360 | # x 's shape (B,T,C,H,W) 361 | B,T,C,H,W = x.shape 362 | x = x.reshape(B*T,C,H,W) 363 | x = self.conv1(x) 364 | x = self.bn1(x) 365 | x = self.maxpool(x) 366 | 367 | # Layer 1 368 | NL1_counter = 0 369 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1] 370 | for i in range(len(self.layer1)): 371 | x = self.layer1[i](x) 372 | if i == self.NL_1_idx[NL1_counter]: 373 | _,C,H,W = x.shape 374 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 375 | x = self.NL_1[NL1_counter](x) 376 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 377 | NL1_counter+=1 378 | # Layer 2 379 | NL2_counter = 0 380 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1] 381 | for i in range(len(self.layer2)): 382 | x = self.layer2[i](x) 383 | if i == 
self.NL_2_idx[NL2_counter]: 384 | _,C,H,W = x.shape 385 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 386 | x = self.NL_2[NL2_counter](x) 387 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 388 | NL2_counter+=1 389 | # Layer 3 390 | NL3_counter = 0 391 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1] 392 | for i in range(len(self.layer3)): 393 | x = self.layer3[i](x) 394 | if i == self.NL_3_idx[NL3_counter]: 395 | _,C,H,W = x.shape 396 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 397 | x = self.NL_3[NL3_counter](x) 398 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 399 | NL3_counter+=1 400 | # Layer 4 401 | NL4_counter = 0 402 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1] 403 | for i in range(len(self.layer4)): 404 | x = self.layer4[i](x) 405 | if i == self.NL_4_idx[NL4_counter]: 406 | _,C,H,W = x.shape 407 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 408 | x = self.NL_4[NL4_counter](x) 409 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W) 410 | NL4_counter+=1 411 | _,C,H,W = x.shape 412 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 413 | # Return is (B,C,T,H,W) 414 | return x 415 | 416 | class ResNet_Video_nonlocal_hr(nn.Module): 417 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]): 418 | self.inplanes = 64 419 | super().__init__() 420 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) 421 | self.bn1 = nn.BatchNorm2d(64) 422 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 423 | self.layer1 = self._make_layer(block, 64, layers[0]) 424 | non_idx = 0 425 | self.NL_1 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 426 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])]) 427 | non_idx += 1 428 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 429 | self.NL_2 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 430 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])]) 431 | non_idx += 1 432 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 433 | self.NL_3 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 434 | self.NL_3_idx =sorted( [layers[2]-(i+1) for i in range(non_layers[non_idx])]) 435 | non_idx += 1 436 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 437 | self.NL_4 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 438 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])]) 439 | 440 | def _make_layer(self, block, planes, blocks, stride=1): 441 | downsample = None 442 | if stride != 1 or self.inplanes != planes * block.expansion: 443 | downsample = nn.Sequential( 444 | nn.Conv2d(self.inplanes, planes * block.expansion, 445 | kernel_size=1, stride=stride, bias=False), 446 | nn.BatchNorm2d(planes * block.expansion), 447 | ) 448 | layers = [] 449 | layers.append(block(self.inplanes, planes, stride, downsample)) 450 | self.inplanes = planes * block.expansion 451 | for i in range(1, blocks): 452 | layers.append(block(self.inplanes, planes)) 453 | return nn.ModuleList(layers) 454 | 455 | def forward(self, x): 456 | # x 's shape (B,T,C,H,W) 457 | B,T,C,H,W = x.shape 458 | x = x.reshape(B*T,C,H,W) 459 | x = self.conv1(x) 460 | x = self.bn1(x) 461 | x = self.maxpool(x) 462 | # x 's shape (B*T,C,H,W) 463 | 464 | # Layer 1 465 | NL1_counter = 0 466 | if 
len(self.NL_1_idx)==0: self.NL_1_idx=[-1] 467 | for i in range(len(self.layer1)): 468 | x = self.layer1[i](x) 469 | if i == self.NL_1_idx[NL1_counter]: 470 | _,C,H,W = x.shape 471 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4) 472 | x = self.NL_1[NL1_counter](x) 473 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 474 | # x's shape (B*T//2,2,C,H,W) 475 | NL1_counter+=1 476 | # Max pool 477 | # _,C,H,W = x.shape 478 | # x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 479 | # T = T//2 480 | # Layer 2 481 | NL2_counter = 0 482 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1] 483 | for i in range(len(self.layer2)): 484 | x = self.layer2[i](x) 485 | if i == self.NL_2_idx[NL2_counter]: 486 | _,C,H,W = x.shape 487 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4) 488 | x = self.NL_2[NL2_counter](x) 489 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 490 | # x's shape (B*T//2,2,C,H,W) 491 | NL2_counter+=1 492 | # Max pool 493 | _,C,H,W = x.shape 494 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 495 | T = T//2 496 | # Layer 3 497 | NL3_counter = 0 498 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1] 499 | for i in range(len(self.layer3)): 500 | x = self.layer3[i](x) 501 | if i == self.NL_3_idx[NL3_counter]: 502 | _,C,H,W = x.shape 503 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4) 504 | x = self.NL_3[NL3_counter](x) 505 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 506 | # x's shape (B*T//2,2,C,H,W) 507 | NL3_counter+=1 508 | # Max pool 509 | _,C,H,W = x.shape 510 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 511 | T = T//2 512 | # Layer 4 513 | NL4_counter = 0 514 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1] 515 | for i in range(len(self.layer4)): 516 | x = self.layer4[i](x) 517 | if i == self.NL_4_idx[NL4_counter]: 518 | _,C,H,W = x.shape 519 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4) 520 | x = self.NL_4[NL4_counter](x) 521 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 522 | NL4_counter+=1 523 | _,C,H,W = x.shape 524 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 525 | # Return is (B,C,T,H,W) 526 | return x 527 | 528 | class ResNet_Video_nonlocal_stripe_hr(nn.Module): 529 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]): 530 | self.inplanes = 64 531 | super().__init__() 532 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) 533 | self.bn1 = nn.BatchNorm2d(64) 534 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 535 | self.layer1 = self._make_layer(block, 64, layers[0]) 536 | non_idx = 0 537 | self.NL_1 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 538 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])]) 539 | non_idx += 1 540 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 541 | self.NL_2 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 542 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])]) 543 | non_idx += 1 544 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 545 | self.NL_3 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 546 | self.NL_3_idx =sorted( [layers[2]-(i+1) for i in range(non_layers[non_idx])]) 547 | non_idx += 1 548 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 549 | self.NL_4 = 
nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])]) 550 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])]) 551 | 552 | def _make_layer(self, block, planes, blocks, stride=1): 553 | downsample = None 554 | if stride != 1 or self.inplanes != planes * block.expansion: 555 | downsample = nn.Sequential( 556 | nn.Conv2d(self.inplanes, planes * block.expansion, 557 | kernel_size=1, stride=stride, bias=False), 558 | nn.BatchNorm2d(planes * block.expansion), 559 | ) 560 | layers = [] 561 | layers.append(block(self.inplanes, planes, stride, downsample)) 562 | self.inplanes = planes * block.expansion 563 | for i in range(1, blocks): 564 | layers.append(block(self.inplanes, planes)) 565 | return nn.ModuleList(layers) 566 | 567 | def forward(self, x): 568 | # x 's shape (B,T,C,H,W) 569 | B,T,C,H,W = x.shape 570 | x = x.reshape(B*T,C,H,W) 571 | x = self.conv1(x) 572 | x = self.bn1(x) 573 | x = self.maxpool(x) 574 | # x 's shape (B*T,C,H,W) 575 | 576 | # Layer 1 577 | NL1_counter = 0 578 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1] 579 | for i in range(len(self.layer1)): 580 | x = self.layer1[i](x) 581 | if i == self.NL_1_idx[NL1_counter]: 582 | _,C,H,W = x.shape 583 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4) 584 | x = self.NL_1[NL1_counter](x) 585 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 586 | # x's shape (B*T//2,2,C,H,W) 587 | NL1_counter+=1 588 | # Max pool 589 | # _,C,H,W = x.shape 590 | # x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 591 | # T = T//2 592 | # Layer 2 593 | NL2_counter = 0 594 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1] 595 | for i in range(len(self.layer2)): 596 | x = self.layer2[i](x) 597 | if i == self.NL_2_idx[NL2_counter]: 598 | _,C,H,W = x.shape 599 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4) 600 | x = self.NL_2[NL2_counter](x) 601 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 602 | # x's shape (B*T//2,2,C,H,W) 603 | NL2_counter+=1 604 | # Max pool 605 | _,C,H,W = x.shape 606 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 607 | T = T//2 608 | # Layer 3 609 | NL3_counter = 0 610 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1] 611 | for i in range(len(self.layer3)): 612 | x = self.layer3[i](x) 613 | if i == self.NL_3_idx[NL3_counter]: 614 | _,C,H,W = x.shape 615 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4) 616 | x = self.NL_3[NL3_counter](x) 617 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 618 | # x's shape (B*T//2,2,C,H,W) 619 | NL3_counter+=1 620 | # Max pool 621 | _,C,H,W = x.shape 622 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0] 623 | T = T//2 624 | # Layer 4 625 | NL4_counter = 0 626 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1] 627 | for i in range(len(self.layer4)): 628 | x = self.layer4[i](x) 629 | if i == self.NL_4_idx[NL4_counter]: 630 | _,C,H,W = x.shape 631 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4) 632 | x = self.NL_4[NL4_counter](x) 633 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W) 634 | NL4_counter+=1 635 | _,C,H,W = x.shape 636 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4) 637 | # Return is (B,C,T,H,W) 638 | return x 639 | if __name__ == "__main__": 640 | net = ResNet(last_stride=1) 641 | print(net) 642 | import torch 643 | 644 | x = net(torch.zeros(1, 3, 256, 128)) 645 | print(x.shape) 646 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = 
argparse.ArgumentParser(description='Train Video-based Re-ID',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 5 | parser.add_argument('--train_txt',help='txt for train dataset') 6 | parser.add_argument('--train_info',help='npy for train dataset') 7 | parser.add_argument('--test_txt',help='txt for test dataset') 8 | parser.add_argument('--test_info',help='npy for test dataset') 9 | parser.add_argument('--query_info',help='npy for query dataset') 10 | parser.add_argument('--lr',type=float,default=0.001,help='learning rate') 11 | parser.add_argument('--lr_step_size',type=int,default=100,help='step size of lr') 12 | parser.add_argument('--class_per_batch',type=int,default=16) 13 | parser.add_argument('--track_per_class',type=int,default=3) 14 | parser.add_argument('--batch_size',type=int,default=32) 15 | parser.add_argument('--n_epochs',type=int,default=500) 16 | parser.add_argument('--num_workers',type=int,default=16) 17 | parser.add_argument('--S',type=int,default=6) 18 | parser.add_argument('--latent_dim',type=int,default=2048,help='resnet50:2048,densenet121:1024,densenet169:1664') 19 | parser.add_argument('--load_ckpt',type=str,default=None) 20 | parser.add_argument('--log_path',type=str,default='loss.txt') 21 | parser.add_argument('--ckpt',type=str,default=None) 22 | parser.add_argument('--optimizer',type=str,default='adam') 23 | parser.add_argument('--resume_validation',type=bool,default=False) 24 | parser.add_argument('--model_type',type=str,default='resnet50') 25 | parser.add_argument('--stride',type=int,default=1) 26 | parser.add_argument('--temporal',default='mean') 27 | parser.add_argument('--frame_id_loss',action='store_true',default=False) 28 | parser.add_argument('--track_id_loss',action='store_true',default=False) 29 | parser.add_argument('--non_layers',type=int, nargs='+') 30 | parser.add_argument('--stripes',type=int, nargs='+') 31 | 32 | 33 | # parser.add_argument( 34 | args = parser.parse_args() 35 | 36 | return args 37 | -------------------------------------------------------------------------------- /run_NL.sh: -------------------------------------------------------------------------------- 1 | TRAIN_TXT=./MARS_database/train_path.txt 2 | TRAIN_INFO=./MARS_database/train_info.npy 3 | TEST_TXT=./MARS_database/test_path.txt 4 | TEST_INFO=./MARS_database/test_info.npy 5 | QUERY_INFO=./MARS_database/query_IDX.npy 6 | 7 | # For NVAN 8 | CKPT=ckpt_NL_0230 9 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \ 10 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 11 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \ 12 | --model_type 'resnet50_NL' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \ 13 | --non_layers 0 2 3 0 14 | 15 | # For STE-NVAN 16 | #CKPT=ckpt_NL_stripe16_hr_0230 17 | #python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \ 18 | #--test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 19 | #--n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \ 20 | #--model_type 'resnet50_NL_stripe_hr' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \ 21 | #--non_layers 0 2 3 0 --stripes 16 16 16 16 22 |
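In `run_NL.sh` above, `--non_layers 0 2 3 0` sets how many non-local blocks are attached to each of the four ResNet-50 stages; `net/resnet.py` computes the insertion points as `layers[i]-(k+1)`, so the blocks are appended after the last residual blocks of a stage. A small sketch of that arithmetic, plus the temporal halving performed by the `*_hr` variants (values mirror the script settings):
```
# Mirrors the NL_i_idx computation in net/resnet.py.
layers = [3, 4, 6, 3]        # residual blocks per ResNet-50 stage
non_layers = [0, 2, 3, 0]    # --non_layers 0 2 3 0  -> 5 non-local blocks

for stage, (n_blocks, n_nl) in enumerate(zip(layers, non_layers), start=1):
    idx = sorted(n_blocks - (k + 1) for k in range(n_nl))
    print('stage %d: non-local blocks after residual blocks %r' % (stage, idx))
# -> stage 2 after blocks [2, 3]; stage 3 after blocks [3, 4, 5] (0-based)

# In the *_hr ("hierarchical reduction") models, the frame axis is halved by a
# pairwise max after stage 2 and after stage 3, so with --S 8: T = 8 -> 4 -> 2.
T = 8
for _ in range(2):
    T //= 2
print('frames entering stage 4:', T)
```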
-------------------------------------------------------------------------------- /run_baseline.sh: -------------------------------------------------------------------------------- 1 | TRAIN_TXT=./MARS_database/train_path.txt 2 | TRAIN_INFO=./MARS_database/train_info.npy 3 | TEST_TXT=./MARS_database/test_path.txt 4 | TEST_INFO=./MARS_database/test_info.npy 5 | QUERY_INFO=./MARS_database/query_IDX.npy 6 | 7 | CKPT=ckpt_baseline_mean 8 | python3 train_baseline.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \ 9 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 10 | --n_epochs 300 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt \ 11 | --model_type 'resnet50_s1' --num_workers 8 --class_per_batch 8 --track_per_class 4 --S 8 \ 12 | --latent_dim 2048 --temporal mean --track_id_loss --stride 1 \ 13 | -------------------------------------------------------------------------------- /run_evaluate.sh: -------------------------------------------------------------------------------- 1 | TEST_TXT=./MARS_database/test_path.txt 2 | TEST_INFO=./MARS_database/test_info.npy 3 | QUERY_INFO=./MARS_database/query_IDX.npy 4 | 5 | # Evaluate ResNet50 + FPL (mean or max) 6 | #LOAD_CKPT=./ckpt/R50_baseline_mean.pth 7 | #python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 8 | #--batch_size 64 --model_type 'resnet50_s1' --num_workers 8 --S 8 \ 9 | #--latent_dim 2048 --temporal mean --stride 1 --load_ckpt $LOAD_CKPT 10 | # Evaluate NVAN (R50 + 5 NL + FPL) 11 | LOAD_CKPT=./ckpt/NVAN.pth 12 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 13 | --batch_size 64 --model_type 'resnet50_NL' --num_workers 8 --S 8 --latent_dim 2048 \ 14 | --temporal Done --non_layers 0 2 3 0 --load_ckpt $LOAD_CKPT \ 15 | 16 | # Evaluate STE-NVAN (R50 + 5 NL + Stripe + Hierarchical + FPL) 17 | #LOAD_CKPT=./ckpt/STE_NVAN.pth 18 | #python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \ 19 | #--batch_size 128 --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --S 8 --latent_dim 2048 \ 20 | #--temporal Done --non_layers 0 2 3 0 --stripes 16 16 16 16 --load_ckpt $LOAD_CKPT \ 21 | -------------------------------------------------------------------------------- /train_NL.py: -------------------------------------------------------------------------------- 1 | from util import utils 2 | import parser 3 | from net import models 4 | import sys 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import math 9 | from util.loss import TripletLoss 10 | from util.cmc import Video_Cmc 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize 16 | import torch.backends.cudnn as cudnn 17 | cudnn.benchmark=True 18 | import os 19 | os.environ['CUDA_VISIBLE_DEVICES']='0' 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | def validation(network,dataloader,args): 24 | network.eval() 25 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True) 26 | pbar.set_description('Inference') 27 | gallery_features = [] 28 | gallery_labels = [] 29 | gallery_cams = [] 30 | with torch.no_grad(): 31 | for c,data in enumerate(dataloader): 32 | seqs = data[0].cuda() 33 | label = data[1] 34 | cams = data[2] 35 | 36 | B,C,H,W = seqs.shape 37 | seqs = seqs.reshape(B//args.S,args.S,C,H,W) 38 | feat = network(seqs)#.cpu().numpy() #[xx,128] 39 | if args.temporal == 'max': 40 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 41 | elif
/train_NL.py: -------------------------------------------------------------------------------- 1 | from util import utils 2 | import parser 3 | from net import models 4 | import sys 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import math 9 | from util.loss import TripletLoss 10 | from util.cmc import Video_Cmc 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize 16 | import torch.backends.cudnn as cudnn 17 | cudnn.benchmark=True 18 | import os 19 | os.environ['CUDA_VISIBLE_DEVICES']='0' 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | def validation(network,dataloader,args): 24 | network.eval() 25 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True) 26 | pbar.set_description('Inference') 27 | gallery_features = [] 28 | gallery_labels = [] 29 | gallery_cams = [] 30 | with torch.no_grad(): 31 | for c,data in enumerate(dataloader): 32 | seqs = data[0].cuda() 33 | label = data[1] 34 | cams = data[2] 35 | 36 | B,C,H,W = seqs.shape 37 | seqs = seqs.reshape(B//args.S,args.S,C,H,W) 38 | feat = network(seqs) 39 | if args.temporal == 'max': 40 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 41 | elif args.temporal == 'mean': 42 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1) 43 | elif args.temporal == 'Done': 44 | feat = feat # temporal pooling already happened inside the network 45 | 46 | gallery_features.append(feat.cpu()) 47 | gallery_labels.append(label) 48 | gallery_cams.append(cams) 49 | pbar.update(1) 50 | pbar.close() 51 | 52 | gallery_features = torch.cat(gallery_features,dim=0).numpy() 53 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy() 54 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy() 55 | 56 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000) 57 | network.train() 58 | 59 | return Cmc[0],mAP 60 | 61 | 62 | if __name__ == '__main__': 63 | # Parse args 64 | args = parser.parse_args() 65 | 66 | # set transformation (H flip is inside dataset) 67 | train_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) 68 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) 69 | print('Start dataloader...') 70 | train_dataloader = utils.Get_Video_train_DataLoader(args.train_txt,args.train_info, train_transform, shuffle=True,num_workers=args.num_workers,\ 71 | S=args.S,track_per_class=args.track_per_class,class_per_batch=args.class_per_batch) 72 | num_class = train_dataloader.dataset.n_id 73 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\ 74 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True) 75 | print('End dataloader...') 76 | 77 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,non_layers=args.non_layers,stripes=args.stripes,temporal=args.temporal).cuda()) 78 | if args.load_ckpt is not None: 79 | state = torch.load(args.load_ckpt) 80 | network.load_state_dict(state,strict=False) 81 | # log 82 | os.makedirs(args.ckpt,exist_ok=True) 83 | f = open(os.path.join(args.ckpt,args.log_path),'a') 84 | f.close() 85 | 86 | # Train loop 87 | # 1. Criterion 88 | criterion_triplet = TripletLoss('soft',True) 89 | 90 | criterion_id = nn.CrossEntropyLoss().cuda() 91 | # 2. Optimizer
92 | if args.optimizer == 'sgd': 93 | optimizer = optim.SGD(network.parameters(),lr = args.lr,momentum=0.9,weight_decay = 1e-4) 94 | else: 95 | optimizer = optim.Adam(network.parameters(),lr = args.lr,weight_decay = 5e-5) 96 | if args.lr_step_size != 0: 97 | scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, 0.1) 98 | 99 | id_loss_list = [] 100 | trip_loss_list = [] 101 | track_id_loss_list = [] 102 | best_cmc = 0 103 | for e in range(args.n_epochs): 104 | print('epoch',e) 105 | if (e+1)%10 == 0: 106 | cmc,mAP = validation(network,test_dataloader,args) 107 | print('CMC: %.4f, mAP : %.4f'%(cmc,mAP)) 108 | f = open(os.path.join(args.ckpt,args.log_path),'a') 109 | f.write('epoch %d, rank-1 %f , mAP %f\n'%(e,cmc,mAP)) 110 | if args.frame_id_loss: 111 | f.write('Frame ID loss : %r\n'%(id_loss_list)) 112 | if args.track_id_loss: 113 | f.write('Track ID loss : %r\n'%(track_id_loss_list)) 114 | f.write('Trip Loss : %r\n'%(trip_loss_list)) 115 | 116 | id_loss_list = [] 117 | trip_loss_list = [] 118 | track_id_loss_list = [] 119 | if cmc >= best_cmc: 120 | torch.save(network.state_dict(),os.path.join(args.ckpt,'ckpt_best.pth')) 121 | best_cmc = cmc 122 | f.write('best\n') 123 | f.close() 124 | 125 | total_id_loss = 0 126 | total_trip_loss = 0 127 | total_track_id_loss = 0 128 | pbar = tqdm(total=len(train_dataloader),ncols=100,leave=True) 129 | for i,data in enumerate(train_dataloader): 130 | seqs = data[0] # (B, S, C, H, W); DataParallel scatters it to the GPUs 131 | labels = data[1].cuda() 132 | B,T,C,H,W = seqs.shape 133 | feat, output = network(seqs) 134 | 135 | if args.temporal == 'max': 136 | pool_feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 137 | pool_output = torch.max(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)[0] 138 | elif args.temporal == 'mean': 139 | pool_feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1) 140 | pool_output = torch.mean(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1) 141 | elif args.temporal == 'Done': 142 | pool_feat = feat 143 | pool_output = output 144 | 145 | trip_loss = criterion_triplet(pool_feat,labels,dis_func='eu') 146 | total_trip_loss += trip_loss.mean().item() 147 | total_loss = trip_loss.mean() 148 | 149 | # Frame level ID loss 150 | if args.frame_id_loss: 151 | expand_labels = (labels.unsqueeze(1)).repeat(1,args.S).reshape(-1) 152 | id_loss = criterion_id(output,expand_labels) 153 | total_id_loss += id_loss.item() 154 | coeff = 1 155 | total_loss += coeff*id_loss 156 | if args.track_id_loss: 157 | track_id_loss = criterion_id(pool_output,labels) 158 | total_track_id_loss += track_id_loss.item() 159 | coeff = 1 160 | total_loss += coeff*track_id_loss 161 | 162 | 163 | ##################### 164 | optimizer.zero_grad() 165 | total_loss.backward() 166 | optimizer.step() 167 | pbar.update(1) 168 | pbar.close() 169 | 170 | if args.lr_step_size != 0: 171 | scheduler.step() 172 | 173 | avg_id_loss = '%.4f'%(total_id_loss/len(train_dataloader)) 174 | avg_trip_loss = '%.4f'%(total_trip_loss/len(train_dataloader)) 175 | avg_track_id_loss = '%.4f'%(total_track_id_loss/len(train_dataloader)) 176 | print('Trip : %s , ID : %s , Track_ID : %s'%(avg_trip_loss,avg_id_loss,avg_track_id_loss)) 177 | id_loss_list.append(avg_id_loss) 178 | trip_loss_list.append(avg_trip_loss) 179 | track_id_loss_list.append(avg_track_id_loss) 180 | --------------------------------------------------------------------------------
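**Note:** the step schedule built above decays the learning rate by 10x every `lr_step_size` epochs. A minimal, self-contained sketch with the run_NL.sh values (the single parameter is a stand-in for `network.parameters()`):
```
import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))                   # stand-in for the network's parameters
optimizer = optim.Adam([param], lr=1e-4, weight_decay=5e-5)  # train_NL.py's Adam setting
scheduler = optim.lr_scheduler.StepLR(optimizer, 50, 0.1)

for epoch in range(200):
    # ... one epoch of training would run here ...
    scheduler.step()  # lr: 1e-4 -> 1e-5 after epoch 50, 1e-6 after 100, 1e-7 after 150
```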
/train_baseline.py: -------------------------------------------------------------------------------- 1 | from util import utils 2 | import parser 3 | from net import models 4 | import sys 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | import math 9 | from util.loss import TripletLoss 10 | from util.cmc import Video_Cmc 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize 16 | import torch.backends.cudnn as cudnn 17 | cudnn.benchmark=True 18 | import os 19 | os.environ['CUDA_VISIBLE_DEVICES']='0' 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | def validation(network,dataloader,args): 24 | network.eval() 25 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True) 26 | pbar.set_description('Inference') 27 | gallery_features = [] 28 | gallery_labels = [] 29 | gallery_cams = [] 30 | with torch.no_grad(): 31 | for c,data in enumerate(dataloader): 32 | seqs = data[0].cuda() 33 | label = data[1] 34 | cams = data[2] 35 | 36 | feat = network(seqs) 37 | if args.temporal == 'max': 38 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 39 | elif args.temporal == 'mean': 40 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1) 41 | elif args.temporal == 'Done': 42 | feat = feat # temporal pooling already happened inside the network 43 | 44 | gallery_features.append(feat.cpu()) 45 | gallery_labels.append(label) 46 | gallery_cams.append(cams) 47 | pbar.update(1) 48 | pbar.close() 49 | 50 | gallery_features = torch.cat(gallery_features,dim=0).numpy() 51 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy() 52 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy() 53 | 54 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000) 55 | network.train() 56 | 57 | return Cmc[0],mAP 58 | 59 | 60 | 61 | if __name__ == '__main__': 62 | # Parse args 63 | args = parser.parse_args() 64 | 65 | # set transformation (H flip is inside dataset) 66 | train_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) 67 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) 68 | 69 | print('Start dataloader...') 70 | train_dataloader = utils.Get_Video_train_DataLoader(args.train_txt,args.train_info, train_transform, shuffle=True,num_workers=args.num_workers,\ 71 | S=args.S,track_per_class=args.track_per_class,class_per_batch=args.class_per_batch) 72 | num_class = train_dataloader.dataset.n_id 73 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\ 74 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True) 75 | print('End dataloader...\n') 76 | 77 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,stride=args.stride).cuda()) 78 | 79 | if args.load_ckpt is not None: 80 | state = torch.load(args.load_ckpt) 81 | network.load_state_dict(state) 82 | 83 | # log 84 | os.makedirs(args.ckpt,exist_ok=True) 85 | f = open(os.path.join(args.ckpt,args.log_path),'a') 86 | f.close() 87 | # Train loop 88 | # 1. Criterion 89 | criterion_triplet = TripletLoss('soft',True) 90 | 91 | criterion_ID = nn.CrossEntropyLoss().cuda() 92 | # 2. Optimizer
93 | if args.optimizer == 'sgd': 94 | optimizer = optim.SGD(network.parameters(),lr = args.lr,momentum=0.9,weight_decay = 1e-4) 95 | else: 96 | optimizer = optim.Adam(network.parameters(),lr = args.lr,weight_decay = 1e-5) 97 | if args.lr_step_size != 0: 98 | scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, 0.1) 99 | 100 | id_loss_list = [] 101 | trip_loss_list = [] 102 | track_id_loss_list = [] 103 | 104 | best_cmc = 0 105 | for e in range(args.n_epochs): 106 | print('Epoch',e) 107 | # Validation 108 | if (e+1)%10 == 0: 109 | cmc,mAP = validation(network,test_dataloader,args) 110 | print('CMC: %.4f, mAP : %.4f'%(cmc,mAP)) 111 | f = open(os.path.join(args.ckpt,args.log_path),'a') 112 | f.write('epoch %d, rank-1 %f , mAP %f\n'%(e,cmc,mAP)) 113 | if args.frame_id_loss: 114 | f.write('Frame ID loss : %r\n'%(id_loss_list)) 115 | if args.track_id_loss: 116 | f.write('Track ID loss : %r\n'%(track_id_loss_list)) 117 | f.write('Trip Loss : %r\n'%(trip_loss_list)) 118 | 119 | id_loss_list = [] 120 | trip_loss_list = [] 121 | track_id_loss_list = [] 122 | if cmc >= best_cmc: 123 | torch.save(network.state_dict(),os.path.join(args.ckpt,'ckpt_best.pth')) 124 | best_cmc = cmc 125 | f.write('best\n') 126 | f.close() 127 | # Training 128 | total_id_loss = 0 129 | total_trip_loss = 0 130 | total_track_id_loss = 0 131 | pbar = tqdm(total=len(train_dataloader),ncols=100,leave=True) 132 | for i,data in enumerate(train_dataloader): 133 | seqs = data[0] 134 | labels = data[1].cuda() 135 | seqs = seqs.reshape((seqs.shape[0]*seqs.shape[1],)+seqs.shape[2:]).cuda() # flatten (B,S,C,H,W) -> (B*S,C,H,W) for the frame-level CNN 136 | feat, output = network(seqs) 137 | 138 | if args.temporal == 'max': 139 | pool_feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0] 140 | pool_output = torch.max(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)[0] 141 | elif args.temporal == 'mean': 142 | pool_feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1) 143 | pool_output = torch.mean(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1) 144 | elif args.temporal == 'Done': 145 | pool_feat = feat 146 | pool_output = output 147 | 148 | trip_loss = criterion_triplet(pool_feat,labels,dis_func='eu') 149 | total_trip_loss += trip_loss.mean().item() 150 | total_loss = trip_loss.mean() 151 | 152 | # Frame level ID loss 153 | if args.frame_id_loss: 154 | expand_labels = (labels.unsqueeze(1)).repeat(1,args.S).reshape(-1) 155 | id_loss = criterion_ID(output,expand_labels) 156 | total_id_loss += id_loss.item() 157 | coeff = 1 158 | total_loss += coeff*id_loss 159 | if args.track_id_loss: 160 | track_id_loss = criterion_ID(pool_output,labels) 161 | total_track_id_loss += track_id_loss.item() 162 | coeff = 1 163 | total_loss += coeff*track_id_loss 164 | 165 | ##################### 166 | optimizer.zero_grad() 167 | total_loss.backward() 168 | optimizer.step() 169 | pbar.update(1) 170 | pbar.close() 171 | 172 | if args.lr_step_size != 0: 173 | scheduler.step() 174 | 175 | avg_id_loss = '%.4f'%(total_id_loss/len(train_dataloader)) 176 | avg_trip_loss = '%.4f'%(total_trip_loss/len(train_dataloader)) 177 | avg_track_id_loss = '%.4f'%(total_track_id_loss/len(train_dataloader)) 178 | print('Trip : %s , ID : %s , Track_ID : %s'%(avg_trip_loss,avg_id_loss,avg_track_id_loss)) 179 | id_loss_list.append(avg_id_loss) 180 | trip_loss_list.append(avg_trip_loss) 181 | track_id_loss_list.append(avg_track_id_loss) 182 | --------------------------------------------------------------------------------
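**Note:** one difference between the two trainers worth flagging: train_baseline.py flattens each batch of tracklets into independent frames before the CNN and pools features afterwards (FPL), while train_NL.py feeds the 5-D clip tensor so the non-local layers can attend across frames of the same tracklet. A small shape sketch (dimensions follow the loaders above):
```
import torch

B, S, C, H, W = 32, 8, 3, 256, 128     # 32 tracklets of 8 frames, as sampled above
clips = torch.zeros(B, S, C, H, W)     # what the train loader yields per step

frames = clips.reshape(B * S, C, H, W) # train_baseline.py: frame-level CNN input
print(frames.shape)                    # torch.Size([256, 3, 256, 128])
```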
/util/cmc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import sys 5 | import pandas as pd 6 | from progressbar import ProgressBar, AnimatedMarker, Percentage 7 | import math 8 | from tqdm import trange 9 | 10 | 11 | def Video_Cmc(features, ids, cams, query_idx,rank_size): 12 | """ 13 | features: numpy array of shape (n, d) 14 | ids: numpy array of shape (n,) 15 | """ 16 | # Sample query 17 | data = {'feature':features, 'id':ids, 'cam':cams} 18 | q_idx = query_idx 19 | g_idx = np.arange(len(ids)) 20 | q_data = {k:v[q_idx] for k, v in data.items()} 21 | g_data = {k:v[g_idx] for k, v in data.items()} 22 | if len(g_idx) < rank_size: rank_size = len(g_idx) 23 | 24 | CMC, mAP = Cmc(q_data, g_data, rank_size) 25 | 26 | return CMC, mAP 27 | 28 | 29 | def Cmc(q_data, g_data, rank_size): 30 | n_query = q_data['feature'].shape[0] 31 | n_gallery = g_data['feature'].shape[0] 32 | 33 | dist = np_cdist(q_data['feature'], g_data['feature']) # Returns an n_query x n_gallery array 34 | 35 | cmc = np.zeros((n_query, rank_size)) 36 | ap = np.zeros(n_query) 37 | 38 | widgets = ["I'm calculating cmc! ", AnimatedMarker(markers='←↖↑↗→↘↓↙'), ' (', Percentage(), ')'] 39 | pbar = ProgressBar(widgets=widgets, max_value=n_query) 40 | for k in range(n_query): 41 | good_idx = np.where((q_data['id'][k]==g_data['id']) & (q_data['cam'][k]!=g_data['cam']))[0] 42 | junk_mask1 = (g_data['id'] == -1) 43 | junk_mask2 = (q_data['id'][k]==g_data['id']) & (q_data['cam'][k]==g_data['cam']) 44 | junk_idx = np.where(junk_mask1 | junk_mask2)[0] 45 | score = dist[k, :] 46 | sort_idx = np.argsort(score) 47 | sort_idx = sort_idx[:rank_size] 48 | 49 | ap[k], cmc[k, :] = Compute_AP(good_idx, junk_idx, sort_idx) 50 | pbar.update(k) 51 | pbar.finish() 52 | CMC = np.mean(cmc, axis=0) 53 | mAP = np.mean(ap) 54 | return CMC, mAP 55 | 56 | def Compute_AP(good_image, junk_image, index): 57 | cmc = np.zeros((len(index),)) 58 | ngood = len(good_image) 59 | 60 | old_recall = 0 61 | old_precision = 1.
62 | ap = 0 63 | intersect_size = 0 64 | j = 0 65 | good_now = 0 66 | njunk = 0 67 | for n in range(len(index)): 68 | flag = 0 69 | if np.any(good_image == index[n]): 70 | cmc[n-njunk:] = 1 71 | flag = 1 # good image 72 | good_now += 1 73 | if np.any(junk_image == index[n]): 74 | njunk += 1 75 | continue # junk image 76 | 77 | if flag == 1: 78 | intersect_size += 1 79 | recall = intersect_size/ngood 80 | precision = intersect_size/(j+1) 81 | ap += (recall-old_recall) * (old_precision+precision) / 2 82 | old_recall = recall 83 | old_precision = precision 84 | j += 1 85 | 86 | if good_now == ngood: 87 | return ap, cmc 88 | return ap, cmc 89 | 90 | 91 | def cdist(feat1, feat2): 92 | """Cosine distance""" 93 | feat1 = torch.FloatTensor(feat1)#.cuda() 94 | feat2 = torch.FloatTensor(feat2)#.cuda() 95 | feat1 = torch.nn.functional.normalize(feat1, dim=1) 96 | feat2 = torch.nn.functional.normalize(feat2, dim=1).transpose(0, 1) 97 | dist = -1 * torch.mm(feat1, feat2) 98 | return dist.cpu().numpy() 99 | 100 | def np_cdist(feat1, feat2): 101 | """Cosine distance""" 102 | feat1_u = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True) # L2-normalize each row 103 | feat2_u = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True) # L2-normalize each row 104 | return -1 * np.dot(feat1_u, feat2_u.T) 105 | 106 | def np_norm_eudist(feat1,feat2): 107 | feat1_u = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True) # L2-normalize each row 108 | feat2_u = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True) # L2-normalize each row 109 | feat1_sq = np.sum(feat1_u * feat1_u, axis=1) 110 | feat2_sq = np.sum(feat2_u * feat2_u, axis=1) 111 | return np.sqrt(feat1_sq.reshape(-1,1) + feat2_sq.reshape(1,-1) - 2*np.dot(feat1_u, feat2_u.T)+ 1e-12) 112 | 113 | 114 | def sqdist(feat1, feat2, M=None): 115 | """Mahalanobis/Euclidean distance""" 116 | if M is None: M = np.eye(feat1.shape[1]) 117 | feat1_M = np.dot(feat1, M) 118 | feat2_M = np.dot(feat2, M) 119 | feat1_sq = np.sum(feat1_M * feat1, axis=1) 120 | feat2_sq = np.sum(feat2_M * feat2, axis=1) 121 | return feat1_sq.reshape(-1,1) + feat2_sq.reshape(1,-1) - 2*np.dot(feat1_M, feat2.T) 122 | 123 | if __name__ == '__main__': 124 | from scipy.io import loadmat 125 | q_feature = loadmat(sys.argv[1])['ff'] 126 | q_db_txt = sys.argv[2] 127 | g_feature = loadmat(sys.argv[3])['ff'] 128 | g_db_txt = sys.argv[4] 129 | # NOTE: Self_Cmc and Vanilla_Cmc below are legacy entry points that are not defined in this file. 130 | CMC, mAP = Self_Cmc(g_feature, g_db_txt, 100) 131 | #CMC, mAP = Vanilla_Cmc(q_feature, q_db_txt, g_feature, g_db_txt) 132 | print('r1 precision = %f, mAP = %f' % (CMC[0], mAP)) 133 | --------------------------------------------------------------------------------
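**Note:** `Compute_AP` follows the MARS/Market-1501 protocol: junk entries (distractors and same-camera matches) are skipped without advancing the rank counter, and average precision is accumulated with a trapezoidal rule over (recall, precision) points. A small sanity check, assuming the repo root is on `PYTHONPATH` (a perfect ranking gives AP = 1 and an all-ones CMC curve):
```
import numpy as np
from util.cmc import Compute_AP  # assumes the repo root is on PYTHONPATH

good = np.array([5, 7])          # gallery indices that match the query (cross-camera)
junk = np.array([], dtype=int)   # no distractor / same-camera entries in this toy case
rank = np.array([5, 7, 2, 9])    # gallery indices sorted by ascending distance

ap, cmc = Compute_AP(good, junk, rank)
print(ap, cmc)                   # 1.0 [1. 1. 1. 1.] -- both matches ranked first
```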
/util/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.autograd import Variable 5 | 6 | class TripletLoss(nn.Module): 7 | 8 | def __init__(self, margin=0, batch_hard=False,dim=2048): 9 | super(TripletLoss, self).__init__() 10 | self.batch_hard = batch_hard 11 | if isinstance(margin, float) or margin == 'soft': 12 | self.margin = margin 13 | else: 14 | raise NotImplementedError( 15 | 'The margin {} is not recognized in TripletLoss()'.format(margin)) 16 | 17 | def forward(self, feat, id=None, pos_mask=None, neg_mask=None, mode='id',dis_func='eu',n_dis=0): 18 | 19 | if dis_func == 'cdist': 20 | feat = feat / feat.norm(p=2,dim=1,keepdim=True) 21 | dist = self.cdist(feat, feat) 22 | elif dis_func == 'eu': 23 | dist = self.cdist(feat, feat) 24 | 25 | if mode == 'id': 26 | if id is None: 27 | raise RuntimeError('forward is in id mode, please input id!') 28 | else: 29 | identity_mask = torch.eye(feat.size(0)).byte() 30 | identity_mask = identity_mask.cuda() if id.is_cuda else identity_mask 31 | same_id_mask = torch.eq(id.unsqueeze(1), id.unsqueeze(0)) 32 | negative_mask = same_id_mask ^ 1 33 | positive_mask = same_id_mask ^ identity_mask 34 | elif mode == 'mask': 35 | if pos_mask is None or neg_mask is None: 36 | raise RuntimeError('forward is in mask mode, please input pos_mask & neg_mask!') 37 | else: 38 | positive_mask = pos_mask 39 | same_id_mask = neg_mask ^ 1 40 | negative_mask = neg_mask 41 | else: 42 | raise ValueError('unrecognized mode') 43 | 44 | if self.batch_hard: 45 | if n_dis != 0: 46 | img_dist = dist[:-n_dis,:-n_dis] 47 | max_positive = (img_dist * positive_mask[:-n_dis,:-n_dis].float()).max(1)[0] 48 | min_negative = (img_dist + 1e5*same_id_mask[:-n_dis,:-n_dis].float()).min(1)[0] 49 | dis_min_negative = dist[:-n_dis,-n_dis:].min(1)[0] 50 | z_origin = max_positive - min_negative 51 | # z_dis = max_positive - dis_min_negative 52 | else: 53 | max_positive = (dist * positive_mask.float()).max(1)[0] 54 | min_negative = (dist + 1e5*same_id_mask.float()).min(1)[0] 55 | z = max_positive - min_negative 56 | else: 57 | pos = positive_mask.topk(k=1, dim=1)[1].view(-1,1) 58 | positive = torch.gather(dist, dim=1, index=pos) 59 | pos = negative_mask.topk(k=1, dim=1)[1].view(-1,1) 60 | negative = torch.gather(dist, dim=1, index=pos) 61 | z = positive - negative 62 | 63 | if isinstance(self.margin, float): 64 | b_loss = torch.clamp(z + self.margin, min=0) 65 | elif self.margin == 'soft': 66 | if n_dis != 0: 67 | b_loss = torch.log(1+torch.exp(z_origin)) - 0.5*dis_min_negative # + torch.log(1+torch.exp(z_dis)) 68 | else: 69 | b_loss = torch.log(1 + torch.exp(z)) 70 | else: 71 | raise NotImplementedError("How do you even get here!") 72 | return b_loss 73 | 74 | def cdist(self, a, b): 75 | ''' 76 | Returns euclidean distance between a and b 77 | 78 | Args: 79 | a (2D Tensor): A batch of vectors shaped (B1, D) 80 | b (2D Tensor): A batch of vectors shaped (B2, D) 81 | Returns: 82 | A matrix of all pairwise distance between all vectors in a and b, 83 | will be shape of (B1, B2) 84 | ''' 85 | diff = a.unsqueeze(1) - b.unsqueeze(0) 86 | return ((diff**2).sum(2)+1e-12).sqrt() 87 | 88 | 89 | class ClusterLoss(nn.Module): 90 | def __init__(self, margin=0, batch_hard=False): 91 | super(ClusterLoss, self).__init__() 92 | self.batch_hard = batch_hard 93 | if isinstance(margin, float) or margin == 'soft': 94 | self.margin = margin 95 | else: 96 | raise NotImplementedError( 97 | 'The margin {} is not recognized in ClusterLoss()'.format(margin)) 98 | 99 | def forward(self, feat, id=None, mode='id',dis_func='eu',n_dis=0): 100 | 101 | # feat = feat.reshape(-1,1024) 102 | # diff = feat.unsqueeze(1)-feat.unsqueeze(0) 103 | # diff = ((diff**2).sum(2)+1e-14).sqrt() 104 | mean = torch.mean(feat,dim=1,keepdim=True) # (n_id, 1, d): per-identity centroid 105 | f2m_dist = (torch.sum((feat - mean.repeat(1,feat.shape[1],1))**2,dim=2)+1e-14).sqrt() 106 | m2m_dist = (((mean-mean.permute(1,0,2))**2).sum(2)+1e-14).sqrt() 107 | 108 | max_positive = torch.max(f2m_dist,dim=1)[0] 109 | identity_mask = torch.eye(mean.shape[0]).cuda() 110 | min_negative = torch.min(m2m_dist+1e5*identity_mask,dim=1)[0] 111 | z = max_positive - min_negative 112 | 113 | 114 | if isinstance(self.margin, float): 115 | b_loss = torch.clamp(z + self.margin, min=0) 116 | elif self.margin == 'soft': 117 | if n_dis != 0: 118 | raise NotImplementedError('the n_dis variant is not supported in ClusterLoss')
119 | else: 120 | b_loss = torch.log(1 + torch.exp(z)) 121 | else: 122 | raise NotImplementedError("How do you even get here!") 123 | return b_loss 124 | 125 | if __name__ == '__main__': 126 | criterion0 = TripletLoss(margin=0.5, batch_hard=False) 127 | criterion1 = TripletLoss(margin=0.5, batch_hard=True) 128 | 129 | t = np.random.randint(3, size=(10,)) 130 | print(t) 131 | 132 | feat = Variable(torch.rand(10, 2048), requires_grad=True).cuda() 133 | id = torch.from_numpy(t).cuda() # integer labels do not need gradients 134 | loss0 = criterion0(feat, id) 135 | loss1 = criterion1(feat, id) 136 | print('no batch hard:', loss0) 137 | print('batch hard:', loss1) 138 | loss0.mean().backward(retain_graph=True) # losses are per-anchor vectors; reduce before backward 139 | loss1.mean().backward() 140 | --------------------------------------------------------------------------------
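**Note:** both training scripts construct `TripletLoss('soft', True)`: batch-hard mining with a soft margin, i.e. per anchor the hinge `clamp(z + margin, 0)` is replaced by the softplus `log(1 + exp(z))`, where z is the hardest-positive distance minus the hardest-negative distance. A self-contained numeric sketch comparing the two margins:
```
import torch

# z = hardest-positive distance minus hardest-negative distance, per anchor
z = torch.tensor([-2.0, 0.0, 2.0])
soft = torch.log(1 + torch.exp(z))   # softplus margin: ~[0.127, 0.693, 2.127]
hinge = torch.clamp(z + 0.5, min=0)  # hard margin 0.5, for comparison: [0.0, 0.5, 2.5]
print(soft, hinge)
```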
/util/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import pandas as pd 6 | import collections.abc 7 | import random 8 | import math 9 | ## For torch lib 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | import torchvision.transforms as T 14 | import torch.nn.functional as F 15 | ## For Image lib 16 | from PIL import Image 17 | 18 | ''' 19 | For MARS, video-based Re-ID 20 | ''' 21 | def process_labels(labels): 22 | unique_id = np.unique(labels) 23 | id_count = len(unique_id) 24 | id_dict = {ID:i for i, ID in enumerate(unique_id.tolist())} 25 | for i in range(len(labels)): 26 | labels[i] = id_dict[labels[i]] 27 | assert len(unique_id)-1 == np.max(labels) 28 | return labels,id_count 29 | 30 | class Video_train_Dataset(Dataset): 31 | def __init__(self,db_txt,info,transform,S=6,track_per_class=4,flip_p=0.5,delete_one_cam=False,cam_type='normal'): 32 | with open(db_txt,'r') as f: 33 | self.imgs = np.array(f.read().strip().split('\n')) 34 | # For info (id,track) 35 | if delete_one_cam == True: 36 | info = np.load(info) 37 | info[:,2],id_count = process_labels(info[:,2]) 38 | for i in range(id_count): 39 | idx = np.where(info[:,2]==i)[0] 40 | if len(np.unique(info[idx,3])) ==1: 41 | info = np.delete(info,idx,axis=0) 42 | id_count -=1 43 | info[:,2],id_count = process_labels(info[:,2]) 44 | # identity count drops from 625 to 619 45 | else: 46 | info = np.load(info) 47 | info[:,2],id_count = process_labels(info[:,2]) 48 | 49 | self.info = [] 50 | for i in range(len(info)): 51 | sample_clip = [] 52 | n_frames = info[i][1]-info[i][0]+1 # renamed from F to avoid shadowing torch.nn.functional 53 | if n_frames < S: 54 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(S-n_frames) 55 | for s in range(S): 56 | pool = strip[s*1:(s+1)*1] 57 | sample_clip.append(list(pool)) 58 | else: 59 | interval = math.ceil(n_frames/S) 60 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(interval*S-n_frames) 61 | for s in range(S): 62 | pool = strip[s*interval:(s+1)*interval] 63 | sample_clip.append(list(pool)) 64 | self.info.append(np.array([np.array(sample_clip),info[i][2],info[i][3]])) 65 | 66 | self.info = np.array(self.info) 67 | self.transform = transform 68 | self.n_id = id_count 69 | self.n_tracklets = self.info.shape[0] 70 | self.flip_p = flip_p 71 | self.track_per_class = track_per_class 72 | self.cam_type = cam_type 73 | self.two_cam = False 74 | self.cross_cam = False 75 | 76 | def __getitem__(self,ID): 77 | sub_info = self.info[self.info[:,1] == ID] 78 | 79 | if self.cam_type == 'normal': 80 | tracks_pool = list(np.random.choice(sub_info[:,0],self.track_per_class)) 81 | elif self.cam_type == 'two_cam': 82 | unique_cam = np.random.permutation(np.unique(sub_info[:,2]))[:2] 83 | tracks_pool = list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[0],0],1))+\ 84 | list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[1],0],1)) 85 | elif self.cam_type == 'cross_cam': 86 | unique_cam = np.random.permutation(np.unique(sub_info[:,2])) 87 | while len(unique_cam) < self.track_per_class: 88 | unique_cam = np.append(unique_cam,unique_cam) 89 | unique_cam = unique_cam[:self.track_per_class] 90 | tracks_pool = [] 91 | for i in range(self.track_per_class): 92 | tracks_pool += list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[i],0],1)) 93 | 94 | one_id_tracks = [] 95 | for track_pool in tracks_pool: 96 | idx = np.random.choice(track_pool.shape[1],track_pool.shape[0]) # one random frame per chunk 97 | number = track_pool[np.arange(len(track_pool)),idx] 98 | imgs = [self.transform(Image.open(path)) for path in self.imgs[number]] 99 | imgs = torch.stack(imgs,dim=0) 100 | 101 | random_p = random.random() 102 | if random_p < self.flip_p: 103 | imgs = torch.flip(imgs,dims=[3]) 104 | one_id_tracks.append(imgs) 105 | return torch.stack(one_id_tracks,dim=0), ID*torch.ones(self.track_per_class,dtype=torch.int64)
106 | 107 | def __len__(self): 108 | return self.n_id 109 | 110 | def Video_train_collate_fn(data): 111 | if isinstance(data[0],collections.abc.Mapping): 112 | t_data = [tuple(d.values()) for d in data] 113 | values = Video_train_collate_fn(t_data) 114 | return {key:value for key,value in zip(data[0].keys(),values)} 115 | else: 116 | imgs,labels = zip(*data) 117 | imgs = torch.cat(imgs,dim=0) 118 | labels = torch.cat(labels,dim=0) 119 | return imgs,labels 120 | 121 | def Get_Video_train_DataLoader(db_txt,info,transform,shuffle=True,num_workers=8,S=10,track_per_class=4,class_per_batch=8): 122 | dataset = Video_train_Dataset(db_txt,info,transform,S,track_per_class) 123 | dataloader = DataLoader(dataset,batch_size=class_per_batch,collate_fn=Video_train_collate_fn,shuffle=shuffle,worker_init_fn=lambda _:np.random.seed(),drop_last=True,num_workers=num_workers) 124 | return dataloader 125 | 126 | class Video_test_Dataset(Dataset): 127 | def __init__(self,db_txt,info,query,transform,S=6,distractor=True): 128 | with open(db_txt,'r') as f: 129 | self.imgs = np.array(f.read().strip().split('\n')) 130 | # info 131 | info = np.load(info) 132 | self.info = [] 133 | for i in range(len(info)): 134 | if distractor == False and info[i][2]==0: 135 | continue 136 | sample_clip = [] 137 | n_frames = info[i][1]-info[i][0]+1 # renamed from F to avoid shadowing torch.nn.functional 138 | if n_frames < S: 139 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(S-n_frames) 140 | for s in range(S): 141 | pool = strip[s*1:(s+1)*1] 142 | sample_clip.append(list(pool)) 143 | else: 144 | interval = math.ceil(n_frames/S) 145 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(interval*S-n_frames) 146 | for s in range(S): 147 | pool = strip[s*interval:(s+1)*interval] 148 | sample_clip.append(list(pool)) 149 | self.info.append(np.array([np.array(sample_clip),info[i][2],info[i][3]])) 150 | 151 | self.info = np.array(self.info) 152 | self.transform = transform 153 | self.n_id = len(np.unique(self.info[:,1])) 154 | self.n_tracklets = self.info.shape[0] 155 | self.query_idx = np.load(query).reshape(-1) 156 | 157 | if distractor == False: 158 | zero = np.where(info[:,2]==0)[0] 159 | self.new_query = [] 160 | for i in self.query_idx: 161 | if i < zero[0]: 162 | self.new_query.append(i) 163 | elif i <= zero[-1]: 164 | continue 165 | elif i > zero[-1]: 166 | self.new_query.append(i-len(zero)) 167 | else: 168 | continue 169 | self.query_idx = np.array(self.new_query) 170 | 171 | def __getitem__(self,idx): 172 | clips = self.info[idx,0] 173 | imgs = [self.transform(Image.open(path)) for path in self.imgs[clips[:,0]]] # first frame of each chunk 174 | imgs = torch.stack(imgs,dim=0) 175 | label = self.info[idx,1]*torch.ones(1,dtype=torch.int32) 176 | cam = self.info[idx,2]*torch.ones(1,dtype=torch.int32) 177 | return imgs,label,cam 178 | def __len__(self): 179 | return len(self.info) 180 | 181 | def Video_test_collate_fn(data): 182 | if isinstance(data[0],collections.abc.Mapping): 183 | t_data = [tuple(d.values()) for d in data] 184 | values = Video_test_collate_fn(t_data) 185 | return {key:value for key,value in zip(data[0].keys(),values)} 186 | else: 187 | imgs,label,cam= zip(*data) 188 | imgs = torch.cat(imgs,dim=0) 189 | labels = torch.cat(label,dim=0) 190 | cams = torch.cat(cam,dim=0) 191 | return imgs,labels,cams 192 | 193 | def Get_Video_test_DataLoader(db_txt,info,query,transform,batch_size=10,shuffle=False,num_workers=8,S=6,distractor=True): 194 | dataset = Video_test_Dataset(db_txt,info,query,transform,S,distractor=distractor) 195 | dataloader = DataLoader(dataset,batch_size=batch_size,collate_fn=Video_test_collate_fn,shuffle=shuffle,worker_init_fn=lambda _:np.random.seed(),num_workers=num_workers) 196 | return dataloader 197 | 198 | 199 | 200 | --------------------------------------------------------------------------------
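**Note:** both datasets above implement restricted sampling: each tracklet is split into `S` equal chunks, and training draws one random frame per chunk while testing takes the first frame of each chunk. A self-contained sketch mirroring that chunking (the function name is hypothetical; the logic follows `Video_train_Dataset`/`Video_test_Dataset`):
```
import math
import numpy as np

def sample_clip_indices(start, end, S=8, train=True):
    """Split frames [start, end] into S chunks, padding short tracklets with the
    last frame; draw one random frame per chunk (train) or the first (test)."""
    n_frames = end - start + 1
    interval = 1 if n_frames < S else math.ceil(n_frames / S)
    strip = list(range(start, end + 1)) + [end] * (interval * S - n_frames)
    chunks = [strip[s * interval:(s + 1) * interval] for s in range(S)]
    if train:
        return [int(np.random.choice(chunk)) for chunk in chunks]
    return [chunk[0] for chunk in chunks]

print(sample_clip_indices(0, 19, S=8, train=False))  # [0, 3, 6, 9, 12, 15, 18, 19]
```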