├── .gitignore
├── LICENSE
├── README.md
├── creat_DukeV_database.py
├── create_MARS_database.py
├── evaluate.py
├── fig
│   ├── NVAN.jpg
│   └── STE-NVAN.jpg
├── net
│   ├── models.py
│   └── resnet.py
├── parser.py
├── run_NL.sh
├── run_baseline.sh
├── run_evaluate.sh
├── train_NL.py
├── train_baseline.py
└── util
    ├── cmc.py
    ├── loss.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | MARS_database/
2 | DukeV_database/
3 | __pycache__/
4 | ckpt*/
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Chih-Ting, Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatially and Temporally Efficient Non-local Attention Network for Video-based Person Re-Identification
2 | - **NVAN**
 3 | ![NVAN](fig/NVAN.jpg)
 4 |
5 | - **STE-NVAN**
 6 | ![STE-NVAN](fig/STE-NVAN.jpg)
7 |
8 | [[Paper]](http://media.ee.ntu.edu.tw/research/STE_NVAN/BMVC19_STE_NVAN_cam.pdf) [[arXiv]](https://arxiv.org/abs/1908.01683)
9 |
10 | [Chih-Ting Liu](https://jackie840129.github.io/), Chih-Wei Wu, [Yu-Chiang Frank Wang](http://vllab.ee.ntu.edu.tw/members.html) and [Shao-Yi Chien](http://www.ee.ntu.edu.tw/profile?id=101), British Machine Vision Conference (**BMVC**), 2019
11 |
12 | This is the PyTorch implementation of the Spatially and Temporally Efficient Non-local Video Attention Network (**STE-NVAN**) for video-based person Re-ID.
13 | It achieves **90.0%** rank-1 accuracy on the MARS dataset with NVAN and **88.9%** with the ST-efficient STE-NVAN.
14 |
15 | ## News ##
16 |
17 | **`2021-06-13`**: We will update this repo to a new version, similar to our new work [CF-AAN](https://github.com/jackie840129/CF-AAN)!
18 |
19 | ## Prerequisites
20 | - Python 3.5+
21 | - [PyTorch](https://pytorch.org/) (we run the code under version 1.0)
22 | - torchvision (we run the code under version 0.2.2)
23 |
24 | ## Getting Started
25 |
26 | ### Installation
27 | - Install dependencies. You can install all of them with:
28 | ```
29 | $ pip3 install numpy Pillow progressbar2 tqdm pandas
30 | ```
31 |
32 | ### Datasets
33 | We conduct experiments on [MARS](http://www.liangzheng.com.cn/Project/project_mars.html) and [DukeMTMC-VideoReID](https://github.com/Yu-Wu/DukeMTMC-VideoReID) (DukeV) datasets.
34 |
35 | **For MARS dataset:**
36 | - Download and unzip the dataset from the official website. ([Google Drive](https://drive.google.com/drive/u/1/folders/0B6tjyrV1YrHeMVV2UFFXQld6X1E))
37 | - Clone the repo of [MARS-evaluation](https://github.com/liangzheng06/MARS-evaluation). We need the files under its **info/** directory.
38 | You will then have the following structure:
39 | ```
40 | path/to/your/MARS dataset/
41 | |-- bbox_train/
42 | |-- bbox_test/
43 | |-- MARS-evaluation/
44 | | |-- info/
45 | ```
46 | - Run `create_MARS_database.py` to create the database files (.txt and .npy) in the "MARS_database" directory.
47 | ```
48 | $ python3 create_MARS_database.py --data_dir /path/to/MARS dataset/ \
49 | --info_dir /path/to/MARS dataset/MARS-evaluation/info/ \
50 | --output_dir ./MARS_database/
51 | ```
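
You can sanity-check the generated database with a few lines of Python (a minimal sketch; it assumes the MARS info layout where each `train_info` row is one tracklet `[start_idx, end_idx, person_id, camera_id]`, with indices already converted to 0-based by the script):
```
import numpy as np

train_info = np.load('./MARS_database/train_info.npy')
train_paths = np.loadtxt('./MARS_database/train_path.txt', dtype=str)

print(train_info.shape)               # (number of tracklets, 4)
start, end, pid, cam = train_info[0]  # end index is inclusive
print(train_paths[start:end + 1])     # frame paths of the first tracklet
```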
52 |
53 | **For DukeV dataset:**
54 | - Download and unzip the dataset from the official GitHub page. ([data link](http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip))
55 | You will then have the following structure:
56 | ```
57 | path/to/your/DukeV dataset/
58 | |-- gallery/
59 | |-- query/
60 | |-- train/
61 | ```
62 | - Run `creat_DukeV_database.py` to create the database files (.txt and .npy) in the "DukeV_database" directory.
63 | ```
64 | $ python3 creat_DukeV_database.py --data_dir /path/to/DukeV dataset/ \
65 | --output_dir ./DukeV_database/
66 | ```
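
Note that, unlike the MARS info files, each row this script writes to `train_info.npy` / `gallery_info.npy` is `[start_idx, end_idx, person_id]` followed by one camera ID per frame, so rows vary in length.
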
67 | ## Usage-Testing
68 | We reimplemented the evaluation code of [MARS-evaluation](https://github.com/liangzheng06/MARS-evaluation) in Python.
69 |
70 | Furthermore, we follow the video-based evaluation protocol of this [paper](https://zpascal.net/cvpr2018/Li_Diversity_Regularized_Spatiotemporal_CVPR_2018_paper.pdf).
71 |
72 | In detail, we sample the first frame of each chunk of a tracklet.
73 |
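A minimal sketch of this sampling (a hypothetical helper for illustration, not a function from this repo): split a tracklet into S chunks and keep the first frame of each chunk.
```
def sample_first_frames(frame_paths, S=8):
    # Split the tracklet into S equal chunks and take each chunk's first frame.
    chunk_len = len(frame_paths) // S
    if chunk_len == 0:  # tracklet shorter than S frames: pad with the last frame
        return frame_paths + [frame_paths[-1]] * (S - len(frame_paths))
    return [frame_paths[i * chunk_len] for i in range(S)]
```
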
74 | ### Prerequisite
75 | For testing, we provide three models trained on the **MARS** dataset at this [**link**](https://drive.google.com/drive/folders/1yi4RJHhu8iMtewdnWYpLCLkIi0okjl35?usp=sharing).
76 |
77 | First create a directory with `$ mkdir ckpt` and put the three models under it.
78 |
79 | All three execution commands are in the script `run_evaluate.sh`.
80 | You can check or alter the arguments inside, then run
81 | ```
82 | $ sh run_evaluate.sh
83 | ```
84 | to obtain the rank-1 accuracy and the mAP score.
85 |
86 | Some scores differ from those in the paper because some models were lost on my previous computer (I have retrained them).
87 |
88 | The evaluation commands for the three models are as follows.
89 |
90 | ### Baseline model : ResNet50 + FPL (mean)
91 | Uncomment this part. You will get R1=87.42% and mAP=79.44%.
92 | ```
93 | # Evaluate ResNet50 + FPL (mean or max)
94 | LOAD_CKPT=./ckpt/R50_baseline_mean.pth
95 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
96 | --batch_size 64 --model_type 'resnet50_s1' --num_workers 8 --S 8 \
97 | --latent_dim 2048 --temporal mean --stride 1 --load_ckpt $LOAD_CKPT
98 | ```
99 | ### NVAN : R50 + 5 Non-local layers + FPL
100 | Uncomment this part. You will get R1=90.00% and mAP=82.79%.
101 | ```
102 | # Evaluate NVAN (R50 + 5 NL + FPL)
103 | LOAD_CKPT=./ckpt/NVAN.pth
104 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
105 | --batch_size 64 --model_type 'resnet50_NL' --num_workers 8 --S 8 --latent_dim 2048 \
106 | --temporal Done --non_layers 0 2 3 0 --load_ckpt $LOAD_CKPT
107 | ```
108 | ### STE-NVAN : NVAN + Spatial Reduction + Temporal Reduction
109 | Uncomment this part. You will get R1=88.69% and mAP=81.27%.
110 | ```
111 | # Evaluate STE-NVAN (R50 + 5 NL + Stripe + Hierarchical + FPL)
112 | LOAD_CKPT=./ckpt/STE_NVAN.pth
113 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
114 | --batch_size 128 --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --S 8 --latent_dim 2048 \
115 | --temporal Done --non_layers 0 2 3 0 --stripes 16 16 16 16 --load_ckpt $LOAD_CKPT
116 | ```
117 |
118 | ## Usage-Training
119 | As described in our paper, there are three kinds of models: Baseline, NVAN, and STE-NVAN.
120 |
121 | ### Baseline model : ResNet50 + FPL (mean)
122 | You can alter the arguments in `run_baseline.sh` or just use this command:
123 | ```
124 | $ sh run_baseline.sh
125 | ```
126 | ### NVAN : R50 + 5 Non-local layers + FPL
127 | You can alter the arguments or uncomment this part in `run_NL.sh`:
128 | ```
129 | # For NVAN
130 | CKPT=ckpt_NL_0230
131 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \
132 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
133 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \
134 | --model_type 'resnet50_NL' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \
135 | --non_layers 0 2 3 0
136 | ```
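
With `--class_per_batch 8` and `--track_per_class 4`, each training batch is sampled as 8 identities × 4 tracklets (each of `--S 8` frames); in `train_NL.py`, `--batch_size` only affects the test dataloader.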
137 | Then run this script.
138 | ```
139 | $ sh run_NL.sh
140 | ```
141 | ### STE-NVAN : NVAN + Spatial Reduction + Temporal Reduction
142 | You can alter the arguments or uncomment this part in `run_NL.sh`:
143 | ```
144 | # For STE-NVAN
145 | CKPT=ckpt_NL_stripe16_hr_0230
146 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \
147 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
148 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \
149 | --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \
150 | --non_layers 0 2 3 0 --stripes 16 16 16 16
151 | ```
152 | Then run this script.
153 | ```
154 | $ sh run_NL.sh
155 | ```
156 |
157 | ## Citation
158 | ```
159 | @inproceedings{liu2019spatially,
160 | title={Spatially and Temporally Efficient Non-local Attention Network for Video-based Person Re-Identification},
161 | author={Liu, Chih-Ting and Wu, Chih-Wei and Wang, Yu-Chiang Frank and Chien, Shao-Yi},
162 | booktitle={British Machine Vision Conference},
163 | year={2019}
164 | }
165 | ```
166 | ## Reference
167 |
168 | Chih-Ting Liu, [Media IC & System Lab](https://github.com/mediaic), National Taiwan University
169 |
170 | E-mail : jackieliu@media.ee.ntu.edu.tw
171 |
--------------------------------------------------------------------------------
/creat_DukeV_database.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import scipy.io as sio
5 |
6 | IMG_EXTENSIONS = [
7 | '.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data_dir',help='path/to/DukeV/')
17 | parser.add_argument('--output_dir',help='path/to/save/database/',default='./DukeV_database')
18 | args = parser.parse_args()
19 |
20 | os.system('mkdir -p %s'%(args.output_dir))
21 | # Read images
22 | # Train
23 | train_imgs_path = []
24 | infos = []
25 | count = 0
26 | data_dir = os.path.join(args.data_dir,'train')
27 | ids = sorted(os.listdir(data_dir))
28 | for id in ids:
29 | tracks = sorted(os.listdir(os.path.join(data_dir,id)))
30 | for track in tracks:
31 | info = []
32 | images = sorted(os.listdir(os.path.join(data_dir,id,track)))
33 | info.append(count)
34 | info.append(count+len(images)-1)
35 | info.append(int(id))
36 | count = count+len(images)
37 | for image in images:
38 | if is_image_file(image):
39 | _,cam,_,_ = image.split('_')
40 | train_imgs_path.append(os.path.abspath(os.path.join(data_dir,id,track,image)))
41 | info.append(int(cam[1:]))
42 | infos.append(info)
43 | train_imgs_path = np.array(train_imgs_path)
44 | np.savetxt(os.path.join(args.output_dir,'train_path.txt'),train_imgs_path,fmt='%s',delimiter='\n')
45 | np.save(os.path.join(args.output_dir,'train_info.npy'),np.array(infos))
46 |
47 | query_info = []
48 | data_dir = os.path.join(args.data_dir,'query')
49 | ids = sorted(os.listdir(data_dir))
50 | for id in ids:
51 | tracks = sorted(os.listdir(os.path.join(data_dir,id)))
52 | for track in tracks:
53 | query_info.append([id,track])
54 | # Test
55 | gallery_imgs_path = []
56 | track_idx = []
57 | idx = 0
58 | infos = []
59 | count = 0
60 | data_dir = os.path.join(args.data_dir,'gallery')
61 | ids = sorted(os.listdir(data_dir))
62 | for id in ids:
63 | tracks = sorted(os.listdir(os.path.join(data_dir,id)))
64 | for track in tracks:
65 |             if len(query_info) > 0 and [id,track] == query_info[0]:
66 | track_idx.append(idx)
67 | del query_info[0]
68 | info = []
69 | images = sorted(os.listdir(os.path.join(data_dir,id,track)))
70 | info.append(count)
71 | info.append(count+len(images)-1)
72 | info.append(int(id))
73 | count = count+len(images)
74 | for image in images:
75 | if is_image_file(image):
76 | _,cam,_,_ = image.split('_')
77 | gallery_imgs_path.append(os.path.abspath(os.path.join(data_dir,id,track,image)))
78 | info.append(int(cam[1:]))
79 | infos.append(info)
80 | idx +=1
81 | gallery_imgs_path = np.array(gallery_imgs_path)
82 | np.savetxt(os.path.join(args.output_dir,'gallery_path.txt'),gallery_imgs_path,fmt='%s',delimiter='\n')
83 | np.save(os.path.join(args.output_dir,'gallery_info.npy'),np.array(infos))
84 | np.save(os.path.join(args.output_dir,'query_IDX.npy'),np.array(track_idx))
85 |
86 |
--------------------------------------------------------------------------------
/create_MARS_database.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import scipy.io as sio
5 |
6 | IMG_EXTENSIONS = [
7 | '.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data_dir',help='path/to/MARS/')
17 | parser.add_argument('--info_dir',help='path/to/MARS-evaluation/info/')
18 | parser.add_argument('--output_dir',help='path/to/save/database',default='./MARS_database')
19 | args = parser.parse_args()
20 |
21 | os.system('mkdir -p %s'%(args.output_dir))
22 | # Train
23 | train_imgs = []
24 | data_dir = os.path.join(args.data_dir,'bbox_train')
25 | ids = sorted(os.listdir(data_dir))
26 | for id in ids:
27 | images = sorted(os.listdir(os.path.join(data_dir,id)))
28 | for image in images:
29 | if is_image_file(image):
30 | train_imgs.append(os.path.abspath(os.path.join(data_dir,id,image)))
31 | train_imgs = np.array(train_imgs)
32 | np.savetxt(os.path.join(args.output_dir,'train_path.txt'),train_imgs,fmt='%s',delimiter='\n')
33 | # Test
34 | test_imgs = []
35 | data_dir = os.path.join(args.data_dir,'bbox_test')
36 | ids = sorted(os.listdir(data_dir))
37 | for id in ids:
38 | images = sorted(os.listdir(os.path.join(data_dir,id)))
39 | for image in images:
40 | if is_image_file(image):
41 | test_imgs.append(os.path.abspath(os.path.join(data_dir,id,image)))
42 | test_imgs = np.array(test_imgs)
43 | np.savetxt(os.path.join(args.output_dir,'test_path.txt'),test_imgs,fmt='%s',delimiter='\n')
44 |
45 | ## process matfile
46 | train_info = sio.loadmat(os.path.join(args.info_dir,'tracks_train_info.mat'))['track_train_info']
47 | test_info = sio.loadmat(os.path.join(args.info_dir,'tracks_test_info.mat'))['track_test_info']
48 | query_IDX = sio.loadmat(os.path.join(args.info_dir,'query_IDX.mat'))['query_IDX']
49 |
50 | # start from 0 (matlab starts from 1)
51 | train_info[:,0:2] = train_info[:,0:2]-1
52 | test_info[:,0:2] = test_info[:,0:2]-1
53 | query_IDX = query_IDX -1
54 | np.save(os.path.join(args.output_dir,'train_info.npy'),train_info)
55 | np.save(os.path.join(args.output_dir,'test_info.npy'),test_info)
56 | np.save(os.path.join(args.output_dir,'query_IDX.npy'),query_IDX)
57 |
58 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from util import utils
2 | from util.cmc import Video_Cmc
3 | from net import models
4 | import parser
5 | import sys
6 | import random
7 | from tqdm import tqdm
8 | import numpy as np
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize
14 | import torch.backends.cudnn as cudnn
15 | cudnn.benchmark=True
16 | import os
17 | os.environ['CUDA_VISIBLE_DEVICES']='0'
18 | torch.multiprocessing.set_sharing_strategy('file_system')
19 |
20 | def validation(network,dataloader,args):
21 | network.eval()
22 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True)
23 | pbar.set_description('Inference')
24 | gallery_features = []
25 | gallery_labels = []
26 | gallery_cams = []
27 | with torch.no_grad():
28 | for c,data in enumerate(dataloader):
29 | seqs = data[0].cuda()
30 | label = data[1]
31 | cams = data[2]
32 |
33 | if args.model_type != 'resnet50_s1':
34 | B,C,H,W = seqs.shape
35 | seqs = seqs.reshape(B//args.S,args.S,C,H,W)
36 | feat = network(seqs)#.cpu().numpy() #[xx,128]
37 | if args.temporal == 'max':
38 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0]
39 | elif args.temporal == 'mean':
40 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)
41 |             elif args.temporal == 'Done':
42 |                 feat = feat  # features are already temporally pooled inside the model
43 |
44 | gallery_features.append(feat.cpu())
45 | gallery_labels.append(label)
46 | gallery_cams.append(cams)
47 | pbar.update(1)
48 | pbar.close()
49 |
50 | gallery_features = torch.cat(gallery_features,dim=0).numpy()
51 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy()
52 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy()
53 |
54 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000)
55 | network.train()
56 |
57 | return Cmc[0],mAP
58 |
59 | if __name__ == '__main__':
60 | #Parse args
61 | args = parser.parse_args()
62 |
63 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
64 | print('Start dataloader...')
65 | num_class = 625
66 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\
67 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True)
68 | print('End dataloader...')
69 |
70 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,non_layers=args.non_layers,stripes=args.stripes,temporal=args.temporal).cuda())
71 |
72 | if args.load_ckpt is None:
73 | print('No ckpt!')
74 | exit()
75 | else:
76 | state = torch.load(args.load_ckpt)
77 | network.load_state_dict(state,strict=True)
78 |
79 |
80 | cmc,map = validation(network,test_dataloader,args)
81 |
82 | print('CMC : %.4f , mAP : %.4f'%(cmc,map))
83 |
--------------------------------------------------------------------------------
/fig/NVAN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/STE-NVAN/3042bc8e4b4e5a7608123fcd121bde975e20fc50/fig/NVAN.jpg
--------------------------------------------------------------------------------
/fig/STE-NVAN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/STE-NVAN/3042bc8e4b4e5a7608123fcd121bde975e20fc50/fig/STE-NVAN.jpg
--------------------------------------------------------------------------------
/net/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | import net.resnet as res
6 |
7 | def weights_init_kaiming(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Linear') != -1:
10 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
11 | nn.init.constant_(m.bias, 0.0)
12 | elif classname.find('Conv') != -1:
13 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | if m.bias is not None:
15 | nn.init.constant_(m.bias, 0.0)
16 | elif classname.find('BatchNorm') != -1:
17 | if m.affine:
18 | nn.init.constant_(m.weight, 1.0)
19 | nn.init.constant_(m.bias, 0.0)
20 |
21 | def weights_init_classifier(m):
22 | classname = m.__class__.__name__
23 | if classname.find('Linear') != -1:
24 | nn.init.normal_(m.weight, std=0.001)
25 |         if m.bias is not None:
26 | nn.init.constant_(m.bias, 0.0)
27 |
28 |
29 | class Resnet50_NL(nn.Module):
30 | def __init__(self,non_layers=[0,1,1,1],stripes=[16,16,16,16],non_type='normal',temporal=None):
31 | super(Resnet50_NL,self).__init__()
32 | original = models.resnet50(pretrained=True).state_dict()
33 | if non_type == 'normal':
34 | self.backbone = res.ResNet_Video_nonlocal(last_stride=1,non_layers=non_layers)
35 | elif non_type == 'stripe':
36 | self.backbone = res.ResNet_Video_nonlocal_stripe(last_stride = 1, non_layers=non_layers, stripes=stripes)
37 | elif non_type == 'hr':
38 | self.backbone = res.ResNet_Video_nonlocal_hr(last_stride = 1, non_layers=non_layers, stripes=stripes)
39 | elif non_type == 'stripe_hr':
40 | self.backbone = res.ResNet_Video_nonlocal_stripe_hr(last_stride = 1, non_layers=non_layers, stripes=stripes)
41 | for key in original:
42 | if key.find('fc') != -1:
43 | continue
44 | self.backbone.state_dict()[key].copy_(original[key])
45 | del original
46 |
47 | self.temporal = temporal
48 | if self.temporal == 'Done':
49 | self.avgpool = nn.AdaptiveAvgPool3d(1)
50 |
51 | def forward(self,x):
52 | if self.temporal == 'Done':
53 | x = self.backbone(x)
54 | x = self.avgpool(x)
55 | x = x.reshape(x.shape[0],-1)
56 | return x
57 |
58 |
59 | class Resnet50_s1(nn.Module):
60 | def __init__(self,pooling=True,stride=1):
61 | super(Resnet50_s1,self).__init__()
62 | original = models.resnet50(pretrained=True).state_dict()
63 | self.backbone = res.ResNet(last_stride=stride)
64 | for key in original:
65 | if key.find('fc') != -1:
66 | continue
67 | self.backbone.state_dict()[key].copy_(original[key])
68 | del original
69 | if pooling == True:
70 | self.add_module('avgpool',nn.AdaptiveAvgPool2d(1))
71 | else:
72 | self.avgpool = None
73 |
74 | self.out_dim = 2048
75 |
76 | def forward(self,x):
77 | x = self.backbone(x)
78 | if self.avgpool is not None:
79 | x = self.avgpool(x)
80 | x = x.view(x.shape[0],-1)
81 | return x
82 |
83 | class CNN(nn.Module):
84 | def __init__(self,out_dim,model_type='resnet50_s1',num_class=710,non_layers=[1,2,2],stripes=[16,16,16,16], temporal = 'Done',stride=1):
85 | super(CNN,self).__init__()
86 | self.model_type = model_type
87 | if model_type == 'resnet50_s1':
88 | self.features = Resnet50_s1(stride=stride)
89 | elif model_type == 'resnet50_NL':
90 | self.features = Resnet50_NL(non_layers=non_layers,temporal=temporal,non_type='normal')
91 | elif model_type == 'resnet50_NL_stripe':
92 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='stripe')
93 | elif model_type == 'resnet50_NL_hr':
94 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='hr')
95 | elif model_type == 'resnet50_NL_stripe_hr':
96 | self.features = Resnet50_NL(non_layers=non_layers,stripes=stripes,temporal=temporal,non_type='stripe_hr')
97 |
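        # BNNeck-style head: BatchNorm (with frozen bias, see below) normalizes
        # the feature before the bias-free identity classifier.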
98 | self.bottleneck = nn.BatchNorm1d(out_dim)
99 | self.bottleneck.bias.requires_grad_(False) # no shift
100 | self.bottleneck.apply(weights_init_kaiming)
101 |
102 | self.classifier = nn.Linear(out_dim,num_class, bias=False)
103 | self.classifier.apply(weights_init_classifier)
104 |
105 | def forward(self,x,seg=None):
106 | if self.model_type == 'resnet50_s1':
107 | x = self.features(x)
108 | bn = self.bottleneck(x)
109 | if self.training == True:
110 | output = self.classifier(bn)
111 | return x,output
112 | else:
113 | return bn
114 | elif self.model_type == 'resnet50_NL' or self.model_type == 'resnet50_NL_stripe' or \
115 | self.model_type=='resnet50_NL_hr' or self.model_type == 'resnet50_NL_stripe_hr':
116 | x = self.features(x)
117 | bn = self.bottleneck(x)
118 | if self.training == True:
119 | output = self.classifier(bn)
120 | return x,output
121 | else:
122 | return bn
123 |
124 | if __name__ == '__main__':
125 | model = Resnet50_s1()
126 | input = torch.ones(1,3,256,128)
127 | output = model(input)
128 | print(output.shape)
129 |
--------------------------------------------------------------------------------
/net/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.nn import functional as F
3 | import numpy as np
4 | import os
5 | import torch
6 | from torch import nn
7 | ##################### Small Block ###################################
8 | class NonLocalBlock(nn.Module):
9 | def __init__(self, in_channels, inter_channels=None,sub_sample=False, bn_layer=True,instance='soft'):
10 | super(NonLocalBlock, self).__init__()
11 | self.sub_sample = sub_sample
12 | self.instance = instance
13 | self.in_channels = in_channels
14 | self.inter_channels = inter_channels
15 |
16 | if self.inter_channels is None:
17 | self.inter_channels = in_channels // 2
18 | if self.inter_channels == 0:
19 | self.inter_channels = 1
20 |
21 | conv_nd = nn.Conv3d
22 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
23 | bn = nn.BatchNorm3d
24 |
25 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
26 | kernel_size=1, stride=1, padding=0)
27 | if bn_layer:
28 | self.W = nn.Sequential(
29 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
30 | kernel_size=1, stride=1, padding=0),
31 | bn(self.in_channels)
32 | )
33 | nn.init.constant_(self.W[1].weight, 0)
34 | nn.init.constant_(self.W[1].bias, 0)
35 | else:
36 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
37 | kernel_size=1, stride=1, padding=0)
38 | nn.init.constant_(self.W.weight, 0)
39 | nn.init.constant_(self.W.bias, 0)
40 |
41 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
42 | kernel_size=1, stride=1, padding=0)
43 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
44 | kernel_size=1, stride=1, padding=0)
45 | if sub_sample:
46 | self.g = nn.Sequential(self.g, max_pool_layer)
47 | self.phi = nn.Sequential(self.phi, max_pool_layer)
48 |
49 | def forward(self, x):
50 | '''
51 | :param x: (b, c, t, h, w)
52 | :return:
53 | '''
54 | batch_size = x.size(0)
55 |
56 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
57 | g_x = g_x.permute(0, 2, 1)
58 |
59 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
60 | theta_x = theta_x.permute(0, 2, 1)
61 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
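        # Pairwise affinity between all space-time positions (theta rows vs. phi columns)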
62 | f = torch.matmul(theta_x, phi_x)
63 | if self.instance == 'soft':
64 | f_div_C = F.softmax(f, dim=-1)
65 | elif self.instance == 'dot':
66 | f_div_C = f / f.shape[1]
67 |
68 | y = torch.matmul(f_div_C, g_x)
69 | y = y.permute(0, 2, 1).contiguous()
70 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
71 | W_y = self.W(y)
72 | z = W_y + x
73 |
74 | return z
75 |
76 | class Stripe_NonLocalBlock(nn.Module):
77 | def __init__(self,stripe,in_channels,inter_channels=None,pool_type='mean',instance='soft'):
78 | super(Stripe_NonLocalBlock,self).__init__()
79 | self.instance = instance
80 | self.stripe=stripe
81 | self.in_channels = in_channels
82 | self.pool_type = pool_type
83 | if pool_type == 'max':
84 | self.pool = nn.AdaptiveMaxPool2d(1)
85 | elif pool_type == 'mean':
86 | self.pool = nn.AdaptiveAvgPool2d(1)
87 | elif pool_type == 'meanmax':
88 | self.avgpool = nn.AdaptiveAvgPool2d(1)
89 | self.maxpool = nn.AdaptiveMaxPool2d(1)
90 | self.in_channels*=2
91 | if inter_channels == None:
92 | self.inter_channels = in_channels//2
93 | else:
94 | self.inter_channels = inter_channels
95 |
96 | self.g = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
97 | kernel_size=1, stride=1, padding=0)
98 | self.theta = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
99 | kernel_size=1, stride=1, padding=0)
100 | self.phi = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
101 | kernel_size=1, stride=1, padding=0)
102 | if pool_type == 'meanmax':
103 | self.in_channels //=2
104 |
105 | self.W = nn.Sequential(
106 | nn.Conv3d(in_channels=self.inter_channels, out_channels=self.in_channels,
107 | kernel_size=1, stride=1, padding=0),
108 | nn.BatchNorm3d(self.in_channels)
109 | )
110 | nn.init.constant_(self.W[1].weight, 0)
111 | nn.init.constant_(self.W[1].bias, 0)
112 |
113 | def forward(self,x):
114 | # x.shape = (b,c,t,h,w)
115 | b,c,t,h,w = x.shape
116 | assert self.stripe * (h//self.stripe) == h
117 |
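        # Spatial reduction: pool each frame into `stripe` horizontal parts, so the
        # non-local attention below runs over t*stripe descriptors instead of t*h*w positions.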
118 | if self.pool_type == 'meanmax':
119 | discri_a = self.avgpool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1)
120 | discri_m = self.maxpool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1)
121 | discri = torch.cat([discri_a,discri_m],dim=1)
122 | else:
123 | discri = self.pool(x.reshape(b*c*t,self.stripe,(h//self.stripe),w)).reshape(b,c,t,self.stripe,1)
124 | g = self.g(discri).reshape(b,self.inter_channels,-1)
125 | g = g.permute(0,2,1)
126 | theta = self.theta(discri).reshape(b, self.inter_channels, -1)
127 | theta = theta.permute(0,2,1)
128 | phi = self.phi(discri).reshape(b, self.inter_channels, -1)
129 |
130 | f = torch.matmul(theta, phi)
131 | if self.instance == 'soft':
132 | f_div_C = F.softmax(f, dim=-1)
133 | elif self.instance == 'dot':
134 | f_div_C = f / f.shape[1]
135 |
136 | y = torch.matmul(f_div_C, g)
137 | y = y.permute(0, 2, 1).contiguous()
138 | y = y.reshape(b, self.inter_channels, *discri.size()[2:])
139 | W_y = self.W(y)
140 |
141 | W_y = W_y.repeat(1,1,1,1,h//self.stripe*w).reshape(b,c,t,h,w)
142 |
143 | z = W_y + x
144 | return z
145 |
146 | class Bottleneck(nn.Module):
147 | expansion = 4
148 | def __init__(self, inplanes, planes, stride=1, downsample=None):
149 | super(Bottleneck, self).__init__()
150 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
151 | self.bn1 = nn.BatchNorm2d(planes)
152 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
153 | padding=1, bias=False)
154 | self.bn2 = nn.BatchNorm2d(planes)
155 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
156 | self.bn3 = nn.BatchNorm2d(planes * 4)
157 | self.relu = nn.ReLU(inplace=True)
158 | self.downsample = downsample
159 | self.stride = stride
160 |
161 | def forward(self, x):
162 | residual = x
163 | out = self.conv1(x)
164 | out = self.bn1(out)
165 | out = self.relu(out)
166 |
167 | out = self.conv2(out)
168 | out = self.bn2(out)
169 | out = self.relu(out)
170 |
171 | out = self.conv3(out)
172 | out = self.bn3(out)
173 |
174 | if self.downsample is not None:
175 | residual = self.downsample(x)
176 | out += residual
177 | out = self.relu(out)
178 |
179 | return out
180 | ##############################################################################
181 |
182 | ############################ backbone model ##################################
183 | class ResNet(nn.Module):
184 | def __init__(self, last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3]):
185 | self.inplanes = 64
186 | super().__init__()
187 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
188 | bias=False)
189 | self.bn1 = nn.BatchNorm2d(64)
190 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
191 | self.layer1 = self._make_layer(block, 64, layers[0])
192 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
193 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
194 | self.layer4 = self._make_layer(
195 | block, 512, layers[3], stride=last_stride)
196 |
197 | def _make_layer(self, block, planes, blocks, stride=1):
198 | downsample = None
199 | if stride != 1 or self.inplanes != planes * block.expansion:
200 | downsample = nn.Sequential(
201 | nn.Conv2d(self.inplanes, planes * block.expansion,
202 | kernel_size=1, stride=stride, bias=False),
203 | nn.BatchNorm2d(planes * block.expansion),
204 | )
205 | layers = []
206 | layers.append(block(self.inplanes, planes, stride, downsample))
207 | self.inplanes = planes * block.expansion
208 | for i in range(1, blocks):
209 | layers.append(block(self.inplanes, planes))
210 | return nn.Sequential(*layers)
211 |
212 | def forward(self, x):
213 | x = self.conv1(x)
214 | x = self.bn1(x)
215 | x = self.maxpool(x)
216 |
217 | x = self.layer1(x)
218 | x = self.layer2(x)
219 | x = self.layer3(x)
220 | x = self.layer4(x)
221 |
222 | return x
223 |
224 | class ResNet_Video_nonlocal(nn.Module):
225 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1]):
226 | self.inplanes = 64
227 | super().__init__()
228 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
229 | self.bn1 = nn.BatchNorm2d(64)
230 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
231 | self.layer1 = self._make_layer(block, 64, layers[0])
232 | non_idx = 0
233 | self.NL_1 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2,sub_sample=True) for i in range(non_layers[non_idx])])
234 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])])
235 | non_idx += 1
236 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
237 | self.NL_2 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
238 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])])
239 | non_idx += 1
240 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
241 | self.NL_3 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
242 |         self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[non_idx])])
243 | non_idx += 1
244 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
245 | self.NL_4 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
246 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])])
247 |
248 | def _make_layer(self, block, planes, blocks, stride=1):
249 | downsample = None
250 | if stride != 1 or self.inplanes != planes * block.expansion:
251 | downsample = nn.Sequential(
252 | nn.Conv2d(self.inplanes, planes * block.expansion,
253 | kernel_size=1, stride=stride, bias=False),
254 | nn.BatchNorm2d(planes * block.expansion),
255 | )
256 | layers = []
257 | layers.append(block(self.inplanes, planes, stride, downsample))
258 | self.inplanes = planes * block.expansion
259 | for i in range(1, blocks):
260 | layers.append(block(self.inplanes, planes))
261 | return nn.ModuleList(layers)
262 |
263 | def forward(self, x):
264 | # x 's shape (B,T,C,H,W)
265 | B,T,C,H,W = x.shape
266 | x = x.reshape(B*T,C,H,W)
267 | x = self.conv1(x)
268 | x = self.bn1(x)
269 | x = self.maxpool(x)
270 |
271 | # Layer 1
272 | NL1_counter = 0
273 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1]
274 | for i in range(len(self.layer1)):
275 | x = self.layer1[i](x)
276 | if i == self.NL_1_idx[NL1_counter]:
277 | _,C,H,W = x.shape
278 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
279 | x = self.NL_1[NL1_counter](x)
280 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
281 | NL1_counter+=1
282 | # Layer 2
283 | NL2_counter = 0
284 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1]
285 | for i in range(len(self.layer2)):
286 | x = self.layer2[i](x)
287 | if i == self.NL_2_idx[NL2_counter]:
288 | _,C,H,W = x.shape
289 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
290 | x = self.NL_2[NL2_counter](x)
291 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
292 | NL2_counter+=1
293 | # Layer 3
294 | NL3_counter = 0
295 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1]
296 | for i in range(len(self.layer3)):
297 | x = self.layer3[i](x)
298 | if i == self.NL_3_idx[NL3_counter]:
299 | _,C,H,W = x.shape
300 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
301 | x = self.NL_3[NL3_counter](x)
302 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
303 | NL3_counter+=1
304 | # Layer 4
305 | NL4_counter = 0
306 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1]
307 | for i in range(len(self.layer4)):
308 | x = self.layer4[i](x)
309 | if i == self.NL_4_idx[NL4_counter]:
310 | _,C,H,W = x.shape
311 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
312 | x = self.NL_4[NL4_counter](x)
313 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
314 | NL4_counter+=1
315 | _,C,H,W = x.shape
316 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
317 | # Return is (B,C,T,H,W)
318 | return x
319 |
320 | class ResNet_Video_nonlocal_stripe(nn.Module):
321 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]):
322 | self.inplanes = 64
323 | super().__init__()
324 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
325 | self.bn1 = nn.BatchNorm2d(64)
326 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
327 | self.layer1 = self._make_layer(block, 64, layers[0])
328 | non_idx = 0
329 | self.NL_1 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
330 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])])
331 | non_idx += 1
332 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
333 | self.NL_2 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
334 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])])
335 | non_idx += 1
336 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
337 | self.NL_3 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
338 |         self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[non_idx])])
339 | non_idx += 1
340 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
341 | self.NL_4 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
342 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])])
343 |
344 | def _make_layer(self, block, planes, blocks, stride=1):
345 | downsample = None
346 | if stride != 1 or self.inplanes != planes * block.expansion:
347 | downsample = nn.Sequential(
348 | nn.Conv2d(self.inplanes, planes * block.expansion,
349 | kernel_size=1, stride=stride, bias=False),
350 | nn.BatchNorm2d(planes * block.expansion),
351 | )
352 | layers = []
353 | layers.append(block(self.inplanes, planes, stride, downsample))
354 | self.inplanes = planes * block.expansion
355 | for i in range(1, blocks):
356 | layers.append(block(self.inplanes, planes))
357 | return nn.ModuleList(layers)
358 |
359 | def forward(self, x):
360 | # x 's shape (B,T,C,H,W)
361 | B,T,C,H,W = x.shape
362 | x = x.reshape(B*T,C,H,W)
363 | x = self.conv1(x)
364 | x = self.bn1(x)
365 | x = self.maxpool(x)
366 |
367 | # Layer 1
368 | NL1_counter = 0
369 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1]
370 | for i in range(len(self.layer1)):
371 | x = self.layer1[i](x)
372 | if i == self.NL_1_idx[NL1_counter]:
373 | _,C,H,W = x.shape
374 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
375 | x = self.NL_1[NL1_counter](x)
376 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
377 | NL1_counter+=1
378 | # Layer 2
379 | NL2_counter = 0
380 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1]
381 | for i in range(len(self.layer2)):
382 | x = self.layer2[i](x)
383 | if i == self.NL_2_idx[NL2_counter]:
384 | _,C,H,W = x.shape
385 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
386 | x = self.NL_2[NL2_counter](x)
387 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
388 | NL2_counter+=1
389 | # Layer 3
390 | NL3_counter = 0
391 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1]
392 | for i in range(len(self.layer3)):
393 | x = self.layer3[i](x)
394 | if i == self.NL_3_idx[NL3_counter]:
395 | _,C,H,W = x.shape
396 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
397 | x = self.NL_3[NL3_counter](x)
398 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
399 | NL3_counter+=1
400 | # Layer 4
401 | NL4_counter = 0
402 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1]
403 | for i in range(len(self.layer4)):
404 | x = self.layer4[i](x)
405 | if i == self.NL_4_idx[NL4_counter]:
406 | _,C,H,W = x.shape
407 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
408 | x = self.NL_4[NL4_counter](x)
409 | x = x.permute(0,2,1,3,4).reshape(B*T,C,H,W)
410 | NL4_counter+=1
411 | _,C,H,W = x.shape
412 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
413 | # Return is (B,C,T,H,W)
414 | return x
415 |
416 | class ResNet_Video_nonlocal_hr(nn.Module):
417 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]):
418 | self.inplanes = 64
419 | super().__init__()
420 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
421 | self.bn1 = nn.BatchNorm2d(64)
422 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
423 | self.layer1 = self._make_layer(block, 64, layers[0])
424 | non_idx = 0
425 | self.NL_1 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
426 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])])
427 | non_idx += 1
428 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
429 | self.NL_2 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
430 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])])
431 | non_idx += 1
432 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
433 | self.NL_3 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
434 |         self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[non_idx])])
435 | non_idx += 1
436 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
437 | self.NL_4 = nn.ModuleList([NonLocalBlock(self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
438 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])])
439 |
440 | def _make_layer(self, block, planes, blocks, stride=1):
441 | downsample = None
442 | if stride != 1 or self.inplanes != planes * block.expansion:
443 | downsample = nn.Sequential(
444 | nn.Conv2d(self.inplanes, planes * block.expansion,
445 | kernel_size=1, stride=stride, bias=False),
446 | nn.BatchNorm2d(planes * block.expansion),
447 | )
448 | layers = []
449 | layers.append(block(self.inplanes, planes, stride, downsample))
450 | self.inplanes = planes * block.expansion
451 | for i in range(1, blocks):
452 | layers.append(block(self.inplanes, planes))
453 | return nn.ModuleList(layers)
454 |
455 | def forward(self, x):
456 | # x 's shape (B,T,C,H,W)
457 | B,T,C,H,W = x.shape
458 | x = x.reshape(B*T,C,H,W)
459 | x = self.conv1(x)
460 | x = self.bn1(x)
461 | x = self.maxpool(x)
462 | # x 's shape (B*T,C,H,W)
463 |
464 | # Layer 1
465 | NL1_counter = 0
466 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1]
467 | for i in range(len(self.layer1)):
468 | x = self.layer1[i](x)
469 | if i == self.NL_1_idx[NL1_counter]:
470 | _,C,H,W = x.shape
471 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4)
472 | x = self.NL_1[NL1_counter](x)
473 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
474 | # x's shape (B*T//2,2,C,H,W)
475 | NL1_counter+=1
476 | # Max pool
477 | # _,C,H,W = x.shape
478 | # x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
479 | # T = T//2
480 | # Layer 2
481 | NL2_counter = 0
482 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1]
483 | for i in range(len(self.layer2)):
484 | x = self.layer2[i](x)
485 | if i == self.NL_2_idx[NL2_counter]:
486 | _,C,H,W = x.shape
487 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4)
488 | x = self.NL_2[NL2_counter](x)
489 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
490 | # x's shape (B*T//2,2,C,H,W)
491 | NL2_counter+=1
492 |         # Temporal reduction: element-wise max over adjacent frame pairs halves T
493 | _,C,H,W = x.shape
494 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
495 | T = T//2
496 | # Layer 3
497 | NL3_counter = 0
498 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1]
499 | for i in range(len(self.layer3)):
500 | x = self.layer3[i](x)
501 | if i == self.NL_3_idx[NL3_counter]:
502 | _,C,H,W = x.shape
503 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4)
504 | x = self.NL_3[NL3_counter](x)
505 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
506 | # x's shape (B*T//2,2,C,H,W)
507 | NL3_counter+=1
509 |         # Temporal reduction: element-wise max over adjacent frame pairs halves T
509 | _,C,H,W = x.shape
510 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
511 | T = T//2
512 | # Layer 4
513 | NL4_counter = 0
514 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1]
515 | for i in range(len(self.layer4)):
516 | x = self.layer4[i](x)
517 | if i == self.NL_4_idx[NL4_counter]:
518 | _,C,H,W = x.shape
519 | x = x.reshape(-1,T,C,H,W).permute(0,2,1,3,4)
520 | x = self.NL_4[NL4_counter](x)
521 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
522 | NL4_counter+=1
523 | _,C,H,W = x.shape
524 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
525 | # Return is (B,C,T,H,W)
526 | return x
527 |
528 | class ResNet_Video_nonlocal_stripe_hr(nn.Module):
529 | def __init__(self,last_stride=1,block=Bottleneck,layers=[3,4,6,3],non_layers=[0,1,1,1],stripes=[16,16,16,16]):
530 | self.inplanes = 64
531 | super().__init__()
532 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
533 | self.bn1 = nn.BatchNorm2d(64)
534 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
535 | self.layer1 = self._make_layer(block, 64, layers[0])
536 | non_idx = 0
537 | self.NL_1 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
538 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[non_idx])])
539 | non_idx += 1
540 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
541 | self.NL_2 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
542 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[non_idx])])
543 | non_idx += 1
544 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
545 | self.NL_3 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
546 |         self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[non_idx])])
547 | non_idx += 1
548 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
549 | self.NL_4 = nn.ModuleList([Stripe_NonLocalBlock(stripes[non_idx],self.inplanes,self.inplanes//2) for i in range(non_layers[non_idx])])
550 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[non_idx])])
551 |
552 | def _make_layer(self, block, planes, blocks, stride=1):
553 | downsample = None
554 | if stride != 1 or self.inplanes != planes * block.expansion:
555 | downsample = nn.Sequential(
556 | nn.Conv2d(self.inplanes, planes * block.expansion,
557 | kernel_size=1, stride=stride, bias=False),
558 | nn.BatchNorm2d(planes * block.expansion),
559 | )
560 | layers = []
561 | layers.append(block(self.inplanes, planes, stride, downsample))
562 | self.inplanes = planes * block.expansion
563 | for i in range(1, blocks):
564 | layers.append(block(self.inplanes, planes))
565 | return nn.ModuleList(layers)
566 |
567 | def forward(self, x):
568 | # x 's shape (B,T,C,H,W)
569 | B,T,C,H,W = x.shape
570 | x = x.reshape(B*T,C,H,W)
571 | x = self.conv1(x)
572 | x = self.bn1(x)
573 | x = self.maxpool(x)
574 | # x 's shape (B*T,C,H,W)
575 |
576 | # Layer 1
577 | NL1_counter = 0
578 | if len(self.NL_1_idx)==0: self.NL_1_idx=[-1]
579 | for i in range(len(self.layer1)):
580 | x = self.layer1[i](x)
581 | if i == self.NL_1_idx[NL1_counter]:
582 | _,C,H,W = x.shape
583 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4)
584 | x = self.NL_1[NL1_counter](x)
585 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
586 | # x's shape (B*T//2,2,C,H,W)
587 | NL1_counter+=1
588 | # Max pool
589 | # _,C,H,W = x.shape
590 | # x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
591 | # T = T//2
592 | # Layer 2
593 | NL2_counter = 0
594 | if len(self.NL_2_idx)==0: self.NL_2_idx=[-1]
595 | for i in range(len(self.layer2)):
596 | x = self.layer2[i](x)
597 | if i == self.NL_2_idx[NL2_counter]:
598 | _,C,H,W = x.shape
599 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4)
600 | x = self.NL_2[NL2_counter](x)
601 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
602 | # x's shape (B*T//2,2,C,H,W)
603 | NL2_counter+=1
604 |         # Temporal reduction: element-wise max over adjacent frame pairs halves T
605 | _,C,H,W = x.shape
606 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
607 | T = T//2
608 | # Layer 3
609 | NL3_counter = 0
610 | if len(self.NL_3_idx)==0: self.NL_3_idx=[-1]
611 | for i in range(len(self.layer3)):
612 | x = self.layer3[i](x)
613 | if i == self.NL_3_idx[NL3_counter]:
614 | _,C,H,W = x.shape
615 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4)
616 | x = self.NL_3[NL3_counter](x)
617 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
618 | # x's shape (B*T//2,2,C,H,W)
619 | NL3_counter+=1
620 |         # Temporal reduction: element-wise max over adjacent frame pairs halves T
621 | _,C,H,W = x.shape
622 | x = torch.max(x.reshape(-1,2,C,H,W),dim=1)[0]
623 | T = T//2
624 | # Layer 4
625 | NL4_counter = 0
626 | if len(self.NL_4_idx)==0: self.NL_4_idx=[-1]
627 | for i in range(len(self.layer4)):
628 | x = self.layer4[i](x)
629 | if i == self.NL_4_idx[NL4_counter]:
630 | _,C,H,W = x.shape
631 | x = x.reshape(-1,2,C,H,W).permute(0,2,1,3,4)
632 | x = self.NL_4[NL4_counter](x)
633 | x = x.permute(0,2,1,3,4).reshape(-1,C,H,W)
634 | NL4_counter+=1
635 | _,C,H,W = x.shape
636 | x = x.reshape(B,T,C,H,W).permute(0,2,1,3,4)
637 | # Return is (B,C,T,H,W)
638 | return x
639 | if __name__ == "__main__":
640 | net = ResNet(last_stride=1)
641 | print(net)
642 | import torch
643 |
644 | x = net(torch.zeros(1, 3, 256, 128))
645 | print(x.shape)
646 |
--------------------------------------------------------------------------------
/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description='Train Video-based Re-ID',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
5 | parser.add_argument('--train_txt',help='txt for train dataset')
6 | parser.add_argument('--train_info',help='npy for train dataset')
7 | parser.add_argument('--test_txt',help='txt for test dataset')
8 | parser.add_argument('--test_info',help='npy for test dataset')
9 |     parser.add_argument('--query_info',help='npy for query index (query_IDX.npy)')
10 | parser.add_argument('--lr',type=float,default=0.001,help='learning rate')
11 | parser.add_argument('--lr_step_size',type=int,default=100,help='step size of lr')
12 | parser.add_argument('--class_per_batch',type=int,default=16)
13 | parser.add_argument('--track_per_class',type=int,default=3)
14 | parser.add_argument('--batch_size',type=int,default=32)
15 | parser.add_argument('--n_epochs',type=int,default=500)
16 | parser.add_argument('--num_workers',type=int,default=16)
17 | parser.add_argument('--S',type=int,default=6)
18 | parser.add_argument('--latent_dim',type=int,default=2048,help='resnet50:2048,densenet121:1024,densenet169:1664')
19 | parser.add_argument('--load_ckpt',type=str,default=None)
20 | parser.add_argument('--log_path',type=str,default='loss.txt')
21 | parser.add_argument('--ckpt',type=str,default=None)
22 | parser.add_argument('--optimizer',type=str,default='adam')
23 |     parser.add_argument('--resume_validation',action='store_true',default=False)
24 | parser.add_argument('--model_type',type=str,default='resnet50')
25 | parser.add_argument('--stride',type=int,default=1)
26 | parser.add_argument('--temporal',default='mean')
27 | parser.add_argument('--frame_id_loss',action='store_true',default=False)
28 | parser.add_argument('--track_id_loss',action='store_true',default=False)
29 | parser.add_argument('--non_layers',type=int, nargs='+')
30 | parser.add_argument('--stripes',type=int, nargs='+')
31 |
32 |
33 | # parser.add_argument(
34 | args = parser.parse_args()
35 |
36 | return args
37 |
--------------------------------------------------------------------------------
/run_NL.sh:
--------------------------------------------------------------------------------
1 | TRAIN_TXT=./MARS_database/train_path.txt
2 | TRAIN_INFO=./MARS_database/train_info.npy
3 | TEST_TXT=./MARS_database/test_path.txt
4 | TEST_INFO=./MARS_database/test_info.npy
5 | QUERY_INFO=./MARS_database/query_IDX.npy
6 |
7 | # For NVAN
8 | CKPT=ckpt_NL_0230
9 | python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \
10 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
11 | --n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \
12 | --model_type 'resnet50_NL' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \
13 | --non_layers 0 2 3 0
14 |
15 | # For STE-NVAN
16 | #CKPT=ckpt_NL_stripe16_hr_0230
17 | #python3 train_NL.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \
18 | #--test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
19 | #--n_epochs 200 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt --class_per_batch 8 \
20 | #--model_type 'resnet50_NL_stripe_hr' --num_workers 8 --track_per_class 4 --S 8 --latent_dim 2048 --temporal Done --track_id_loss \
21 | #--non_layers 0 2 3 0 --stripes 16 16 16 16
22 |
--------------------------------------------------------------------------------
/run_baseline.sh:
--------------------------------------------------------------------------------
1 | TRAIN_TXT=./MARS_database/train_path.txt
2 | TRAIN_INFO=./MARS_database/train_info.npy
3 | TEST_TXT=./MARS_database/test_path.txt
4 | TEST_INFO=./MARS_database/test_info.npy
5 | QUERY_INFO=./MARS_database/query_IDX.npy
6 |
7 | CKPT=ckpt_baseline_mean
8 | python3 train_baseline.py --train_txt $TRAIN_TXT --train_info $TRAIN_INFO --batch_size 64 \
9 | --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
10 | --n_epochs 300 --lr 0.0001 --lr_step_size 50 --optimizer adam --ckpt $CKPT --log_path loss.txt \
11 | --model_type 'resnet50_s1' --num_workers 8 --class_per_batch 8 --track_per_class 4 --S 8 \
12 | --latent_dim 2048 --temporal mean --track_id_loss --stride 1
13 |
--------------------------------------------------------------------------------
/run_evaluate.sh:
--------------------------------------------------------------------------------
1 | TEST_TXT=./MARS_database/test_path.txt
2 | TEST_INFO=./MARS_database/test_info.npy
3 | QUERY_INFO=./MARS_database/query_IDX.npy
4 |
5 | # Evaluate ResNet50 + FPL (mean or max)
6 | #LOAD_CKPT=./ckpt/R50_baseline_mean.pth
7 | #python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
8 | #--batch_size 64 --model_type 'resnet50_s1' --num_workers 8 --S 8 \
9 | #--latent_dim 2048 --temporal mean --stride 1 --load_ckpt $LOAD_CKPT
10 | # Evaluate NVAN (R50 + 5 NL + FPL)
11 | LOAD_CKPT=./ckpt/NVAN.pth
12 | python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
13 | --batch_size 64 --model_type 'resnet50_NL' --num_workers 8 --S 8 --latent_dim 2048 \
14 | --temporal Done --non_layers 0 2 3 0 --load_ckpt $LOAD_CKPT
15 |
16 | # Evaluate STE-NVAN (R50 + 5 NL + Stripe + Hierarchical + FPL)
17 | #LOAD_CKPT=./ckpt/STE_NVAN.pth
18 | #python3 evaluate.py --test_txt $TEST_TXT --test_info $TEST_INFO --query_info $QUERY_INFO \
19 | #--batch_size 128 --model_type 'resnet50_NL_stripe_hr' --num_workers 8 --S 8 --latent_dim 2048 \
20 | #--temporal Done --non_layers 0 2 3 0 --stripes 16 16 16 16 --load_ckpt $LOAD_CKPT
21 |
--------------------------------------------------------------------------------
/train_NL.py:
--------------------------------------------------------------------------------
1 | from util import utils
2 | import parser
3 | from net import models
4 | import sys
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import math
9 | from util.loss import TripletLoss
10 | from util.cmc import Video_Cmc
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize
16 | import torch.backends.cudnn as cudnn
17 | cudnn.benchmark=True
18 | import os
19 | os.environ['CUDA_VISIBLE_DEVICES']='0'
20 | torch.multiprocessing.set_sharing_strategy('file_system')
21 |
22 |
23 | def validation(network,dataloader,args):
24 | network.eval()
25 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True)
26 | pbar.set_description('Inference')
27 | gallery_features = []
28 | gallery_labels = []
29 | gallery_cams = []
30 | with torch.no_grad():
31 | for c,data in enumerate(dataloader):
32 | seqs = data[0].cuda()
33 | label = data[1]
34 | cams = data[2]
35 |
36 |             B,C,H,W = seqs.shape
37 |             seqs = seqs.reshape(B//args.S,args.S,C,H,W)  # regroup flattened frames into (n_tracks, S, C, H, W)
38 |             feat = network(seqs)
39 | if args.temporal == 'max':
40 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0]
41 | elif args.temporal == 'mean':
42 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)
43 |             elif args.temporal == 'Done':  # pooling already done inside the model
44 |                 feat = feat
45 |
46 | gallery_features.append(feat.cpu())
47 | gallery_labels.append(label)
48 | gallery_cams.append(cams)
49 | pbar.update(1)
50 | pbar.close()
51 |
52 | gallery_features = torch.cat(gallery_features,dim=0).numpy()
53 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy()
54 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy()
55 |
56 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000)
57 | network.train()
58 |
59 | return Cmc[0],mAP
60 |
61 |
62 | if __name__ == '__main__':
63 | #Parse args
64 | args = parser.parse_args()
65 |
66 | # set transformation (H flip is inside dataset)
67 | train_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
68 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
69 | print('Start dataloader...')
70 | train_dataloader = utils.Get_Video_train_DataLoader(args.train_txt,args.train_info, train_transform, shuffle=True,num_workers=args.num_workers,\
71 | S=args.S,track_per_class=args.track_per_class,class_per_batch=args.class_per_batch)
72 | num_class = train_dataloader.dataset.n_id
73 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\
74 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True)
75 | print('End dataloader...')
76 |
77 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,non_layers=args.non_layers,stripes=args.stripes,temporal=args.temporal).cuda())
78 | if args.load_ckpt is not None:
79 | state = torch.load(args.load_ckpt)
80 | network.load_state_dict(state,strict=False)
81 |     # log
82 |     os.makedirs(args.ckpt, exist_ok=True)
83 |     f = open(os.path.join(args.ckpt,args.log_path),'a')  # touch the log file
84 |     f.close()
85 |
86 | # Train loop
87 | # 1. Criterion
88 | criterion_triplet = TripletLoss('soft',True)
89 |
90 |     criterion_id = nn.CrossEntropyLoss().cuda()
91 | # 2. Optimizer
92 | if args.optimizer == 'sgd':
93 | optimizer = optim.SGD(network.parameters(),lr = args.lr,momentum=0.9,weight_decay = 1e-4)
94 | else:
95 | optimizer = optim.Adam(network.parameters(),lr = args.lr,weight_decay = 5e-5)
96 | if args.lr_step_size != 0:
97 | scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, 0.1)
98 |
99 | id_loss_list = []
100 | trip_loss_list = []
101 | track_id_loss_list = []
102 | best_cmc = 0
103 | for e in range(args.n_epochs):
104 | print('epoch',e)
105 |         if (e+1)%10 == 0:  # run validation every 10 epochs
106 | cmc,map = validation(network,test_dataloader,args)
107 | print('CMC: %.4f, mAP : %.4f'%(cmc,map))
108 | f = open(os.path.join(args.ckpt,args.log_path),'a')
109 | f.write('epoch %d, rank-1 %f , mAP %f\n'%(e,cmc,map))
110 | if args.frame_id_loss:
111 | f.write('Frame ID loss : %r\n'%(id_loss_list))
112 | if args.track_id_loss:
113 | f.write('Track ID loss : %r\n'%(track_id_loss_list))
114 | f.write('Trip Loss : %r\n'%(trip_loss_list))
115 |
116 | id_loss_list = []
117 | trip_loss_list = []
118 | track_id_loss_list = []
119 | if cmc >= best_cmc:
120 | torch.save(network.state_dict(),os.path.join(args.ckpt,'ckpt_best.pth'))
121 | best_cmc = cmc
122 | f.write('best\n')
123 | f.close()
124 |
125 | total_id_loss = 0
126 | total_trip_loss = 0
127 | total_track_id_loss = 0
128 | pbar = tqdm(total=len(train_dataloader),ncols=100,leave=True)
129 | for i,data in enumerate(train_dataloader):
130 |             seqs = data[0]  # stays on CPU; nn.DataParallel scatters it to the GPUs
131 | labels = data[1].cuda()
132 | B,T,C,H,W = seqs.shape
133 | feat, output = network(seqs)
134 |
135 | if args.temporal == 'max':
136 | pool_feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0]
137 | pool_output = torch.max(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)[0]
138 | elif args.temporal == 'mean':
139 | pool_feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)
140 | pool_output = torch.mean(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)
141 |             elif args.temporal == 'Done':  # pooling already done inside the model
142 | pool_feat = feat
143 | pool_output = output
144 |
145 | trip_loss = criterion_triplet(pool_feat,labels,dis_func='eu')
146 | total_trip_loss += trip_loss.mean().item()
147 | total_loss = trip_loss.mean()
148 |
149 | # Frame level ID loss
150 |             if args.frame_id_loss:
151 |                 expand_labels = (labels.unsqueeze(1)).repeat(1,args.S).reshape(-1)  # repeat each track label S times, one per frame
152 |                 id_loss = criterion_id(output,expand_labels)
153 | total_id_loss += id_loss.item()
154 | coeff = 1
155 | total_loss += coeff*id_loss
156 |             if args.track_id_loss:
157 |                 track_id_loss = criterion_id(pool_output,labels)
158 | total_track_id_loss += track_id_loss.item()
159 | coeff = 1
160 | total_loss += coeff*track_id_loss
161 |
162 |
163 | #####################
164 | optimizer.zero_grad()
165 | total_loss.backward()
166 | optimizer.step()
167 | pbar.update(1)
168 | pbar.close()
169 |
170 | if args.lr_step_size !=0:
171 | scheduler.step()
172 |
173 | avg_id_loss = '%.4f'%(total_id_loss/len(train_dataloader))
174 | avg_trip_loss = '%.4f'%(total_trip_loss/len(train_dataloader))
175 | avg_track_id_loss = '%.4f'%(total_track_id_loss/len(train_dataloader))
176 | print('Trip : %s , ID : %s , Track_ID : %s'%(avg_trip_loss,avg_id_loss,avg_track_id_loss))
177 | id_loss_list.append(avg_id_loss)
178 | trip_loss_list.append(avg_trip_loss)
179 | track_id_loss_list.append(avg_track_id_loss)
180 |
--------------------------------------------------------------------------------
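
The `--temporal` switch in `validation()` above reduces per-frame features to one feature per tracklet. A self-contained sketch of the three modes (toy sizes, not repo code):

```
import torch

S = 8                                       # frames sampled per tracklet, as in the run scripts
feat = torch.randn(2 * S, 2048)             # frame-level features for 2 tracklets

max_feat  = feat.reshape(-1, S, 2048).max(dim=1)[0]   # --temporal max  -> [2, 2048]
mean_feat = feat.reshape(-1, S, 2048).mean(dim=1)     # --temporal mean -> [2, 2048]
# --temporal Done: the model (e.g. NVAN) already pooled over time internally,
# so validation() passes its output through unchanged.
print(max_feat.shape, mean_feat.shape)
```
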
/train_baseline.py:
--------------------------------------------------------------------------------
1 | from util import utils
2 | import parser
3 | from net import models
4 | import sys
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 | import math
9 | from util.loss import TripletLoss
10 | from util.cmc import Video_Cmc
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | from torchvision.transforms import Compose,ToTensor,Normalize,Resize
16 | import torch.backends.cudnn as cudnn
17 | cudnn.benchmark=True
18 | import os
19 | os.environ['CUDA_VISIBLE_DEVICES']='0'
20 | torch.multiprocessing.set_sharing_strategy('file_system')
21 |
22 |
23 | def validation(network,dataloader,args):
24 | network.eval()
25 | pbar = tqdm(total=len(dataloader),ncols=100,leave=True)
26 | pbar.set_description('Inference')
27 | gallery_features = []
28 | gallery_labels = []
29 | gallery_cams = []
30 | with torch.no_grad():
31 | for c,data in enumerate(dataloader):
32 | seqs = data[0].cuda()
33 | label = data[1]
34 | cams = data[2]
35 |
36 |             feat = network(seqs)
37 | if args.temporal == 'max':
38 | feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0]
39 | elif args.temporal == 'mean':
40 | feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)
41 |             elif args.temporal == 'Done':  # pooling already done inside the model
42 |                 feat = feat
43 |
44 | gallery_features.append(feat.cpu())
45 | gallery_labels.append(label)
46 | gallery_cams.append(cams)
47 | pbar.update(1)
48 | pbar.close()
49 |
50 | gallery_features = torch.cat(gallery_features,dim=0).numpy()
51 | gallery_labels = torch.cat(gallery_labels,dim=0).numpy()
52 | gallery_cams = torch.cat(gallery_cams,dim=0).numpy()
53 |
54 | Cmc,mAP = Video_Cmc(gallery_features,gallery_labels,gallery_cams,dataloader.dataset.query_idx,10000)
55 | network.train()
56 |
57 | return Cmc[0],mAP
58 |
59 |
60 |
61 | if __name__ == '__main__':
62 | #Parse args
63 | args = parser.parse_args()
64 |
65 | # set transformation (H flip is inside dataset)
66 | train_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
67 | test_transform = Compose([Resize((256,128)),ToTensor(),Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
68 |
69 | print('Start dataloader...')
70 | train_dataloader = utils.Get_Video_train_DataLoader(args.train_txt,args.train_info, train_transform, shuffle=True,num_workers=args.num_workers,\
71 | S=args.S,track_per_class=args.track_per_class,class_per_batch=args.class_per_batch)
72 | num_class = train_dataloader.dataset.n_id
73 | test_dataloader = utils.Get_Video_test_DataLoader(args.test_txt,args.test_info,args.query_info,test_transform,batch_size=args.batch_size,\
74 | shuffle=False,num_workers=args.num_workers,S=args.S,distractor=True)
75 | print('End dataloader...\n')
76 |
77 | network = nn.DataParallel(models.CNN(args.latent_dim,model_type=args.model_type,num_class=num_class,stride=args.stride).cuda())
78 |
79 | if args.load_ckpt is not None:
80 | state = torch.load(args.load_ckpt)
81 | network.load_state_dict(state)
82 |
83 |     # log
84 |     os.makedirs(args.ckpt, exist_ok=True)
85 |     f = open(os.path.join(args.ckpt,args.log_path),'a')  # touch the log file
86 |     f.close()
87 | # Train loop
88 | # 1. Criterion
89 | criterion_triplet = TripletLoss('soft',True)
90 |
91 | criterion_ID = nn.CrossEntropyLoss().cuda()
92 | # 2. Optimizer
93 | if args.optimizer == 'sgd':
94 | optimizer = optim.SGD(network.parameters(),lr = args.lr,momentum=0.9,weight_decay = 1e-4)
95 | else:
96 | optimizer = optim.Adam(network.parameters(),lr = args.lr,weight_decay = 1e-5)
97 | if args.lr_step_size != 0:
98 | scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, 0.1)
99 |
100 | id_loss_list = []
101 | trip_loss_list = []
102 | track_id_loss_list = []
103 |
104 | best_cmc = 0
105 | for e in range(args.n_epochs):
106 | print('Epoch',e)
107 | # Validation
108 | if (e+1)%10 == 0:
109 | cmc,map = validation(network,test_dataloader,args)
110 | print('CMC: %.4f, mAP : %.4f'%(cmc,map))
111 | f = open(os.path.join(args.ckpt,args.log_path),'a')
112 | f.write('epoch %d, rank-1 %f , mAP %f\n'%(e,cmc,map))
113 | if args.frame_id_loss:
114 | f.write('Frame ID loss : %r\n'%(id_loss_list))
115 | if args.track_id_loss:
116 | f.write('Track ID loss : %r\n'%(track_id_loss_list))
117 | f.write('Trip Loss : %r\n'%(trip_loss_list))
118 |
119 | id_loss_list = []
120 | trip_loss_list = []
121 | track_id_loss_list = []
122 | if cmc >= best_cmc:
123 | torch.save(network.state_dict(),os.path.join(args.ckpt,'ckpt_best.pth'))
124 | best_cmc = cmc
125 | f.write('best\n')
126 | f.close()
127 | # Training
128 | total_id_loss = 0
129 | total_trip_loss = 0
130 | total_track_id_loss = 0
131 | pbar = tqdm(total=len(train_dataloader),ncols=100,leave=True)
132 | for i,data in enumerate(train_dataloader):
133 |             seqs = data[0]
134 |             labels = data[1].cuda()
135 |             seqs = seqs.reshape((seqs.shape[0]*seqs.shape[1],)+seqs.shape[2:]).cuda()  # flatten (B,T,...) into one frame batch
136 | feat, output = network(seqs)
137 |
138 | if args.temporal == 'max':
139 | pool_feat = torch.max(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)[0]
140 | pool_output = torch.max(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)[0]
141 | elif args.temporal == 'mean':
142 | pool_feat = torch.mean(feat.reshape(feat.shape[0]//args.S,args.S,-1),dim=1)
143 | pool_output = torch.mean(output.reshape(output.shape[0]//args.S,args.S,-1),dim=1)
144 | elif args.temporal == 'Done':
145 | pool_feat = feat
146 | pool_output = output
147 |
148 | trip_loss = criterion_triplet(pool_feat,labels,dis_func='eu')
149 | total_trip_loss += trip_loss.mean().item()
150 | total_loss = trip_loss.mean()
151 |
152 | # Frame level ID loss
153 |             if args.frame_id_loss:
154 | expand_labels = (labels.unsqueeze(1)).repeat(1,args.S).reshape(-1)
155 | id_loss = criterion_ID(output,expand_labels)
156 | total_id_loss += id_loss.item()
157 | coeff = 1
158 | total_loss += coeff*id_loss
159 |             if args.track_id_loss:
160 | track_id_loss = criterion_ID(pool_output,labels)
161 | total_track_id_loss += track_id_loss.item()
162 | coeff = 1
163 | total_loss += coeff*track_id_loss
164 |
165 | #####################
166 | optimizer.zero_grad()
167 | total_loss.backward()
168 | optimizer.step()
169 | pbar.update(1)
170 | pbar.close()
171 |
172 | if args.lr_step_size !=0:
173 | scheduler.step()
174 |
175 | avg_id_loss = '%.4f'%(total_id_loss/len(train_dataloader))
176 | avg_trip_loss = '%.4f'%(total_trip_loss/len(train_dataloader))
177 | avg_track_id_loss = '%.4f'%(total_track_id_loss/len(train_dataloader))
178 | print('Trip : %s , ID : %s , Track_ID : %s'%(avg_trip_loss,avg_id_loss,avg_track_id_loss))
179 | id_loss_list.append(avg_id_loss)
180 | trip_loss_list.append(avg_trip_loss)
181 | track_id_loss_list.append(avg_track_id_loss)
182 |
--------------------------------------------------------------------------------
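
One detail worth spelling out from the frame-level ID loss above: each tracklet label is repeated `S` times so that every frame-level logit row is supervised. A toy example (S shrunk for readability; the scripts use S=8):

```
import torch

S = 3                                    # toy value; 3 keeps the output short
labels = torch.tensor([5, 9])            # one label per tracklet
expand_labels = labels.unsqueeze(1).repeat(1, S).reshape(-1)
print(expand_labels)                     # tensor([5, 5, 5, 9, 9, 9])
```
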
/util/cmc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import sys
5 | import pandas as pd
6 | from progressbar import ProgressBar, AnimatedMarker, Percentage
7 | import math
8 | from tqdm import trange
9 |
10 |
11 | def Video_Cmc(features, ids, cams, query_idx,rank_size):
12 | """
13 | features: numpy array of shape (n, d)
14 |     ids, cams: numpy arrays of shape (n,)
15 | """
16 | # Sample query
17 | data = {'feature':features, 'id':ids, 'cam':cams}
18 | q_idx = query_idx
19 | g_idx = np.arange(len(ids))
20 | q_data = {k:v[q_idx] for k, v in data.items()}
21 | g_data = {k:v[g_idx] for k, v in data.items()}
22 | if len(g_idx) < rank_size: rank_size = len(g_idx)
23 |
24 | CMC, mAP = Cmc(q_data, g_data, rank_size)
25 |
26 | return CMC, mAP
27 |
28 |
29 | def Cmc(q_data, g_data, rank_size):
30 | n_query = q_data['feature'].shape[0]
31 | n_gallery = g_data['feature'].shape[0]
32 |
33 |     dist = np_cdist(q_data['feature'], g_data['feature']) # returns an n_query x n_gallery array
34 |
35 | cmc = np.zeros((n_query, rank_size))
36 | ap = np.zeros(n_query)
37 |
38 | widgets = ["I'm calculating cmc! ", AnimatedMarker(markers='←↖↑↗→↘↓↙'), ' (', Percentage(), ')']
39 | pbar = ProgressBar(widgets=widgets, max_value=n_query)
40 | for k in range(n_query):
41 | good_idx = np.where((q_data['id'][k]==g_data['id']) & (q_data['cam'][k]!=g_data['cam']))[0]
42 | junk_mask1 = (g_data['id'] == -1)
43 | junk_mask2 = (q_data['id'][k]==g_data['id']) & (q_data['cam'][k]==g_data['cam'])
44 | junk_idx = np.where(junk_mask1 | junk_mask2)[0]
45 | score = dist[k, :]
46 | sort_idx = np.argsort(score)
47 | sort_idx = sort_idx[:rank_size]
48 |
49 | ap[k], cmc[k, :] = Compute_AP(good_idx, junk_idx, sort_idx)
50 | pbar.update(k)
51 | pbar.finish()
52 | CMC = np.mean(cmc, axis=0)
53 | mAP = np.mean(ap)
54 | return CMC, mAP
55 |
56 | def Compute_AP(good_image, junk_image, index):
57 | cmc = np.zeros((len(index),))
58 | ngood = len(good_image)
59 |
60 | old_recall = 0
61 | old_precision = 1.
62 | ap = 0
63 | intersect_size = 0
64 | j = 0
65 | good_now = 0
66 | njunk = 0
67 | for n in range(len(index)):
68 | flag = 0
69 | if np.any(good_image == index[n]):
70 | cmc[n-njunk:] = 1
71 | flag = 1 # good image
72 | good_now += 1
73 | if np.any(junk_image == index[n]):
74 | njunk += 1
75 | continue # junk image
76 |
77 | if flag == 1:
78 | intersect_size += 1
79 | recall = intersect_size/ngood
80 | precision = intersect_size/(j+1)
81 | ap += (recall-old_recall) * (old_precision+precision) / 2
82 | old_recall = recall
83 | old_precision = precision
84 | j += 1
85 |
86 | if good_now == ngood:
87 | return ap, cmc
88 | return ap, cmc
89 |
90 |
91 | def cdist(feat1, feat2):
92 | """Cosine distance"""
93 | feat1 = torch.FloatTensor(feat1)#.cuda()
94 | feat2 = torch.FloatTensor(feat2)#.cuda()
95 | feat1 = torch.nn.functional.normalize(feat1, dim=1)
96 | feat2 = torch.nn.functional.normalize(feat2, dim=1).transpose(0, 1)
97 | dist = -1 * torch.mm(feat1, feat2)
98 | return dist.cpu().numpy()
99 |
100 | def np_cdist(feat1, feat2):
101 | """Cosine distance"""
102 | feat1_u = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True) # n * d -> n
103 | feat2_u = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True) # n * d -> n
104 | return -1 * np.dot(feat1_u, feat2_u.T)
105 |
106 | def np_norm_eudist(feat1,feat2):
107 |     feat1_u = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True)  # L2-normalize rows
108 |     feat2_u = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True)
109 |     feat1_sq = np.sum(feat1_u * feat1_u, axis=1)
110 |     feat2_sq = np.sum(feat2_u * feat2_u, axis=1)
111 |     return np.sqrt(feat1_sq.reshape(-1,1) + feat2_sq.reshape(1,-1) - 2*np.dot(feat1_u, feat2_u.T) + 1e-12)
112 |
113 |
114 | def sqdist(feat1, feat2, M=None):
115 |     """Mahalanobis/Euclidean distance"""
116 | if M is None: M = np.eye(feat1.shape[1])
117 | feat1_M = np.dot(feat1, M)
118 | feat2_M = np.dot(feat2, M)
119 | feat1_sq = np.sum(feat1_M * feat1, axis=1)
120 | feat2_sq = np.sum(feat2_M * feat2, axis=1)
121 | return feat1_sq.reshape(-1,1) + feat2_sq.reshape(1,-1) - 2*np.dot(feat1_M, feat2.T)
122 |
123 | if __name__ == '__main__':
124 | from scipy.io import loadmat
125 | q_feature = loadmat(sys.argv[1])['ff']
126 | q_db_txt = sys.argv[2]
127 | g_feature = loadmat(sys.argv[3])['ff']
128 | g_db_txt = sys.argv[4]
129 | #print(feature.shape)
130 |     #CMC, mAP = Self_Cmc(g_feature, g_db_txt, 100)   # NOTE: Self_Cmc is not defined in this file; stale demo code
131 |     #CMC, mAP = Vanilla_Cmc(q_feature, q_db_txt, g_feature, g_db_txt)   # likewise undefined here
132 |     #print('r1 precision = %f, mAP = %f' % (CMC[0], mAP))
133 |
--------------------------------------------------------------------------------
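
The good/junk split in `Cmc()` follows the standard video Re-ID protocol: the same identity seen from a different camera counts as a correct match, while distractors (id -1) and same-camera hits are ignored in the ranking. A toy illustration with made-up ids and cameras:

```
import numpy as np

g_ids  = np.array([1, 1, 2, -1, 1])       # gallery identities (-1 = distractor)
g_cams = np.array([0, 1, 1,  0, 1])       # gallery camera ids
q_id, q_cam = 1, 0                        # one query

good_idx = np.where((q_id == g_ids) & (q_cam != g_cams))[0]
junk_idx = np.where((g_ids == -1) | ((q_id == g_ids) & (q_cam == g_cams)))[0]
print(good_idx, junk_idx)                 # [1 4] [0 3]
```
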
/util/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.autograd import Variable
5 |
6 | class TripletLoss(nn.Module):
7 |
8 | def __init__(self, margin=0, batch_hard=False,dim=2048):
9 | super(TripletLoss, self).__init__()
10 | self.batch_hard = batch_hard
11 | if isinstance(margin, float) or margin == 'soft':
12 | self.margin = margin
13 | else:
14 | raise NotImplementedError(
15 | 'The margin {} is not recognized in TripletLoss()'.format(margin))
16 |
17 | def forward(self, feat, id=None, pos_mask=None, neg_mask=None, mode='id',dis_func='eu',n_dis=0):
18 |
19 | if dis_func == 'cdist':
20 | feat = feat / feat.norm(p=2,dim=1,keepdim=True)
21 | dist = self.cdist(feat, feat)
22 | elif dis_func == 'eu':
23 | dist = self.cdist(feat, feat)
24 |
25 | if mode == 'id':
26 | if id is None:
27 |                 raise RuntimeError('forward is in id mode, please provide id!')
28 | else:
29 | identity_mask = torch.eye(feat.size(0)).byte()
30 | identity_mask = identity_mask.cuda() if id.is_cuda else identity_mask
31 | same_id_mask = torch.eq(id.unsqueeze(1), id.unsqueeze(0))
32 | negative_mask = same_id_mask ^ 1
33 | positive_mask = same_id_mask ^ identity_mask
34 | elif mode == 'mask':
35 | if pos_mask is None or neg_mask is None:
36 |                 raise RuntimeError('forward is in mask mode, please provide pos_mask & neg_mask!')
37 | else:
38 | positive_mask = pos_mask
39 | same_id_mask = neg_mask ^ 1
40 | negative_mask = neg_mask
41 | else:
42 | raise ValueError('unrecognized mode')
43 |
44 | if self.batch_hard:
45 | if n_dis != 0:
46 | img_dist = dist[:-n_dis,:-n_dis]
47 | max_positive = (img_dist * positive_mask[:-n_dis,:-n_dis].float()).max(1)[0]
48 | min_negative = (img_dist + 1e5*same_id_mask[:-n_dis,:-n_dis].float()).min(1)[0]
49 | dis_min_negative = dist[:-n_dis,-n_dis:].min(1)[0]
50 |                 z = z_origin = max_positive - min_negative  # keep `z` defined for the float-margin branch below
51 | # z_dis = max_positive - dis_min_negative
52 | else:
53 | max_positive = (dist * positive_mask.float()).max(1)[0]
54 | min_negative = (dist + 1e5*same_id_mask.float()).min(1)[0]
55 | z = max_positive - min_negative
56 | else:
57 |             pos = positive_mask.topk(k=1, dim=1)[1].view(-1,1)
58 |             positive = torch.gather(dist, dim=1, index=pos)
59 |             neg = negative_mask.topk(k=1, dim=1)[1].view(-1,1)
60 |             negative = torch.gather(dist, dim=1, index=neg)
61 |             z = positive - negative
62 |
63 | if isinstance(self.margin, float):
64 | b_loss = torch.clamp(z + self.margin, min=0)
65 | elif self.margin == 'soft':
66 | if n_dis != 0:
67 | b_loss = torch.log(1+torch.exp(z_origin))+ -0.5* dis_min_negative# + torch.log(1+torch.exp(z_dis))
68 | else:
69 | b_loss = torch.log(1 + torch.exp(z))
70 | else:
71 | raise NotImplementedError("How do you even get here!")
72 | return b_loss
73 |
74 | def cdist(self, a, b):
75 | '''
76 | Returns euclidean distance between a and b
77 |
78 | Args:
79 | a (2D Tensor): A batch of vectors shaped (B1, D)
80 | b (2D Tensor): A batch of vectors shaped (B2, D)
81 | Returns:
82 | A matrix of all pairwise distance between all vectors in a and b,
83 | will be shape of (B1, B2)
84 | '''
85 | diff = a.unsqueeze(1) - b.unsqueeze(0)
86 | return ((diff**2).sum(2)+1e-12).sqrt()
87 |
88 |
89 | class ClusterLoss(nn.Module):
90 | def __init__(self, margin=0, batch_hard=False):
91 | super(ClusterLoss, self).__init__()
92 | self.batch_hard = batch_hard
93 | if isinstance(margin, float) or margin == 'soft':
94 | self.margin = margin
95 | else:
96 | raise NotImplementedError(
97 | 'The margin {} is not recognized in TripletLoss()'.format(margin))
98 |
99 | def forward(self, feat, id=None, mode='id',dis_func='eu',n_dis=0):
100 |
101 | # feat = feat.reshape(-1,1024)
102 | # diff = feat.unsqueeze(1)-feat.unsqueeze(0)
103 | # diff = ((diff**2).sum(2)+1e-14).sqrt()
104 | mean = torch.mean(feat,dim=1,keepdim=True) # 8,1,1024
105 | f2m_dist = (torch.sum((feat - mean.repeat(1,feat.shape[1],1))**2,dim=2)+1e-14).sqrt()
106 | m2m_dist = (((mean-mean.permute(1,0,2))**2).sum(2)+1e-14).sqrt()
107 |
108 | max_positive = torch.max(f2m_dist,dim=1)[0]
109 | identity_mask = torch.eye(mean.shape[0]).cuda()
110 | min_negative = torch.min(m2m_dist+1e5*identity_mask,dim=1)[0]
111 | z = max_positive - min_negative
112 |
113 |
114 | if isinstance(self.margin, float):
115 | b_loss = torch.clamp(z + self.margin, min=0)
116 |         elif self.margin == 'soft':
117 |             # NOTE: the old n_dis != 0 branch here referenced z_origin and
118 |             # dis_min_negative, which ClusterLoss never defines (copied from
119 |             # TripletLoss), so the soft margin is applied unconditionally.
120 |             b_loss = torch.log(1 + torch.exp(z))
121 |         else:
122 |             raise NotImplementedError("How do you even get here!")
123 | return b_loss
124 |
125 | if __name__ == '__main__':
126 | criterion0 = TripletLoss(margin=0.5, batch_hard=False)
127 | criterion1 = TripletLoss(margin=0.5, batch_hard=True)
128 |
129 | t = np.random.randint(3, size=(10,))
130 | print(t)
131 |
132 |     feat = torch.rand(10, 2048, requires_grad=True).cuda()
133 |     id = torch.from_numpy(t).cuda()   # integer labels cannot (and need not) require grad
134 |     loss0 = criterion0(feat, id)
135 |     loss1 = criterion1(feat, id)
136 |     print('no batch hard:', loss0)
137 |     print('batch hard:', loss1)
138 |     loss0.mean().backward()   # reduce the per-sample losses to a scalar first
139 |     loss1.mean().backward()
140 |
--------------------------------------------------------------------------------
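
To make the batch-hard soft margin concrete, here is a minimal CPU-only sketch (toy features and my own variable names, assuming a recent PyTorch with bool tensors) of the path `TripletLoss` takes with `margin='soft'`, `batch_hard=True`:

```
import torch

feat = torch.tensor([[0., 0.], [1., 0.], [0., 3.], [0., 4.]])  # 2 ids x 2 samples
ids  = torch.tensor([0, 0, 1, 1])

diff = feat.unsqueeze(1) - feat.unsqueeze(0)
dist = (diff.pow(2).sum(2) + 1e-12).sqrt()                     # same cdist as the class above

same_id  = ids.unsqueeze(1).eq(ids.unsqueeze(0))
pos_mask = same_id & ~torch.eye(4, dtype=torch.bool)           # positives, excluding self
max_positive = (dist * pos_mask.float()).max(1)[0]             # hardest positive per anchor
min_negative = (dist + 1e5 * same_id.float()).min(1)[0]        # hardest negative per anchor
loss = torch.log1p(torch.exp(max_positive - min_negative))     # soft margin, no clamping
print(loss)
```
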
/util/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import pandas as pd
6 | import collections
7 | import random
8 | import math
9 | ## For torch lib
10 | import torch
11 | from torch.utils.data import Dataset, DataLoader
12 | from torch.utils.data.sampler import SubsetRandomSampler
13 | import torchvision.transforms as T
14 | import torch.nn.functional as F
15 | ## For Image lib
16 | from PIL import Image
17 |
18 | '''
19 | For MARS, video-based Re-ID
20 | '''
21 | def process_labels(labels):
22 | unique_id = np.unique(labels)
23 | id_count = len(unique_id)
24 | id_dict = {ID:i for i, ID in enumerate(unique_id.tolist())}
25 | for i in range(len(labels)):
26 | labels[i] = id_dict[labels[i]]
27 | assert len(unique_id)-1 == np.max(labels)
28 | return labels,id_count
29 |
30 | class Video_train_Dataset(Dataset):
31 | def __init__(self,db_txt,info,transform,S=6,track_per_class=4,flip_p=0.5,delete_one_cam=False,cam_type='normal'):
32 | with open(db_txt,'r') as f:
33 | self.imgs = np.array(f.read().strip().split('\n'))
34 | # For info (id,track)
35 |         if delete_one_cam:
36 | info = np.load(info)
37 | info[:,2],id_count = process_labels(info[:,2])
38 | for i in range(id_count):
39 | idx = np.where(info[:,2]==i)[0]
40 | if len(np.unique(info[idx,3])) ==1:
41 | info = np.delete(info,idx,axis=0)
42 | id_count -=1
43 | info[:,2],id_count = process_labels(info[:,2])
44 |             # identities seen by only one camera are removed (625 -> 619 ids on MARS)
45 | else:
46 | info = np.load(info)
47 | info[:,2],id_count = process_labels(info[:,2])
48 |
49 | self.info = []
50 | for i in range(len(info)):
51 | sample_clip = []
52 | F = info[i][1]-info[i][0]+1
53 | if F < S:
54 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(S-F)
55 | for s in range(S):
56 | pool = strip[s*1:(s+1)*1]
57 | sample_clip.append(list(pool))
58 | else:
59 | interval = math.ceil(F/S)
60 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(interval*S-F)
61 | for s in range(S):
62 | pool = strip[s*interval:(s+1)*interval]
63 | sample_clip.append(list(pool))
64 | self.info.append(np.array([np.array(sample_clip),info[i][2],info[i][3]]))
65 |
66 | self.info = np.array(self.info)
67 | self.transform = transform
68 | self.n_id = id_count
69 | self.n_tracklets = self.info.shape[0]
70 | self.flip_p = flip_p
71 | self.track_per_class = track_per_class
72 | self.cam_type = cam_type
73 | self.two_cam = False
74 | self.cross_cam = False
75 |
76 | def __getitem__(self,ID):
77 | sub_info = self.info[self.info[:,1] == ID]
78 |
79 | if self.cam_type == 'normal':
80 | tracks_pool = list(np.random.choice(sub_info[:,0],self.track_per_class))
81 | elif self.cam_type == 'two_cam':
82 | unique_cam = np.random.permutation(np.unique(sub_info[:,2]))[:2]
83 | tracks_pool = list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[0],0],1))+\
84 | list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[1],0],1))
85 | elif self.cam_type == 'cross_cam':
86 | unique_cam = np.random.permutation(np.unique(sub_info[:,2]))
87 | while len(unique_cam) < self.track_per_class:
88 | unique_cam = np.append(unique_cam,unique_cam)
89 | unique_cam = unique_cam[:self.track_per_class]
90 | tracks_pool = []
91 | for i in range(self.track_per_class):
92 | tracks_pool += list(np.random.choice(sub_info[sub_info[:,2]==unique_cam[i],0],1))
93 |
94 | one_id_tracks = []
95 | for track_pool in tracks_pool:
96 |             idx = np.random.choice(track_pool.shape[1],track_pool.shape[0])  # one random offset per segment
97 |             number = track_pool[np.arange(len(track_pool)),idx]  # the S sampled frame indices
98 | imgs = [self.transform(Image.open(path)) for path in self.imgs[number]]
99 | imgs = torch.stack(imgs,dim=0)
100 |
101 | random_p = random.random()
102 | if random_p < self.flip_p:
103 | imgs = torch.flip(imgs,dims=[3])
104 | one_id_tracks.append(imgs)
105 | return torch.stack(one_id_tracks,dim=0), ID*torch.ones(self.track_per_class,dtype=torch.int64)
106 |
107 | def __len__(self):
108 | return self.n_id
109 |
110 | def Video_train_collate_fn(data):
111 | if isinstance(data[0],collections.Mapping):
112 | t_data = [tuple(d.values()) for d in data]
113 |         values = Video_train_collate_fn(t_data)  # recurse on the tuple form (was the undefined MARS_collate_fn)
114 | return {key:value for key,value in zip(data[0].keys(),values)}
115 | else:
116 | imgs,labels = zip(*data)
117 | imgs = torch.cat(imgs,dim=0)
118 | labels = torch.cat(labels,dim=0)
119 | return imgs,labels
120 |
121 | def Get_Video_train_DataLoader(db_txt,info,transform,shuffle=True,num_workers=8,S=10,track_per_class=4,class_per_batch=8):
122 | dataset = Video_train_Dataset(db_txt,info,transform,S,track_per_class)
123 | dataloader = DataLoader(dataset,batch_size=class_per_batch,collate_fn=Video_train_collate_fn,shuffle=shuffle,worker_init_fn=lambda _:np.random.seed(),drop_last=True,num_workers=num_workers)
124 | return dataloader
125 |
126 | class Video_test_Dataset(Dataset):
127 | def __init__(self,db_txt,info,query,transform,S=6,distractor=True):
128 | with open(db_txt,'r') as f:
129 | self.imgs = np.array(f.read().strip().split('\n'))
130 | # info
131 | info = np.load(info)
132 | self.info = []
133 | for i in range(len(info)):
134 | if distractor == False and info[i][2]==0:
135 | continue
136 | sample_clip = []
137 | F = info[i][1]-info[i][0]+1
138 | if F < S:
139 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(S-F)
140 | for s in range(S):
141 | pool = strip[s*1:(s+1)*1]
142 | sample_clip.append(list(pool))
143 | else:
144 | interval = math.ceil(F/S)
145 | strip = list(range(info[i][0],info[i][1]+1))+[info[i][1]]*(interval*S-F)
146 | for s in range(S):
147 | pool = strip[s*interval:(s+1)*interval]
148 | sample_clip.append(list(pool))
149 | self.info.append(np.array([np.array(sample_clip),info[i][2],info[i][3]]))
150 |
151 | self.info = np.array(self.info)
152 | self.transform = transform
153 | self.n_id = len(np.unique(self.info[:,1]))
154 | self.n_tracklets = self.info.shape[0]
155 | self.query_idx = np.load(query).reshape(-1)
156 |
157 | if distractor == False:
158 | zero = np.where(info[:,2]==0)[0]
159 | self.new_query = []
160 | for i in self.query_idx:
161 | if i < zero[0]:
162 | self.new_query.append(i)
163 | elif i <= zero[-1]:
164 | continue
165 | elif i > zero[-1]:
166 | self.new_query.append(i-len(zero))
167 | else:
168 | continue
169 | self.query_idx = np.array(self.new_query)
170 |
171 | def __getitem__(self,idx):
172 | clips = self.info[idx,0]
173 | imgs = [self.transform(Image.open(path)) for path in self.imgs[clips[:,0]]]
174 | imgs = torch.stack(imgs,dim=0)
175 | label = self.info[idx,1]*torch.ones(1,dtype=torch.int32)
176 | cam = self.info[idx,2]*torch.ones(1,dtype=torch.int32)
177 | return imgs,label,cam
178 | def __len__(self):
179 | return len(self.info)
180 |
181 | def Video_test_collate_fn(data):
182 | if isinstance(data[0],collections.Mapping):
183 | t_data = [tuple(d.values()) for d in data]
184 |         values = Video_test_collate_fn(t_data)  # recurse on the tuple form (was the undefined MARS_collate_fn)
185 | return {key:value for key,value in zip(data[0].keys(),values)}
186 | else:
187 | imgs,label,cam= zip(*data)
188 | imgs = torch.cat(imgs,dim=0)
189 | labels = torch.cat(label,dim=0)
190 | cams = torch.cat(cam,dim=0)
191 | return imgs,labels,cams
192 |
193 | def Get_Video_test_DataLoader(db_txt,info,query,transform,batch_size=10,shuffle=False,num_workers=8,S=6,distractor=True):
194 | dataset = Video_test_Dataset(db_txt,info,query,transform,S,distractor=distractor)
195 | dataloader = DataLoader(dataset,batch_size=batch_size,collate_fn=Video_test_collate_fn,shuffle=shuffle,worker_init_fn=lambda _:np.random.seed(),num_workers=num_workers)
196 | return dataloader
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
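
Finally, the restricted sampling in `Video_train_Dataset` / `Video_test_Dataset` above pads each tracklet to `interval*S` frames and cuts it into `S` equal segments; training draws one random frame per segment, testing takes the first frame of each. A standalone sketch with toy numbers:

```
import math

F, S = 20, 8                       # a 20-frame tracklet, sampled into S=8 segments
interval = math.ceil(F / S)        # 3 frames per segment
strip = list(range(F)) + [F - 1] * (interval * S - F)   # pad by repeating the last frame
clips = [strip[s * interval:(s + 1) * interval] for s in range(S)]
print(clips)   # [[0, 1, 2], [3, 4, 5], ..., [18, 19, 19], [19, 19, 19]]
```
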