├── LICENSE
├── README.md
├── commands.txt
├── data
│   ├── datasets.py
│   ├── dukevideoreid
│   │   ├── __pycache__
│   │   │   └── data_manager.cpython-36.pyc
│   │   └── data_manager.py
│   ├── eval_metrics.py
│   ├── mars
│   │   ├── __pycache__
│   │   │   └── data_manager.cpython-36.pyc
│   │   └── data_manager.py
│   ├── misc.py
│   ├── samplers.py
│   ├── temporal_transforms.py
│   └── veri
│       ├── __pycache__
│       │   └── data_manager.cpython-36.pyc
│       └── data_manager.py
├── images
│   ├── gradcam.png
│   ├── mars_all_withstudent.pdf
│   ├── mars_all_withstudent.png
│   └── mvd_framework.png
├── model
│   ├── cbam
│   │   ├── __pycache__
│   │   │   ├── bam.cpython-36.pyc
│   │   │   └── resnet_bam.cpython-36.pyc
│   │   ├── bam.py
│   │   └── resnet_bam.py
│   ├── loss.py
│   └── net.py
├── requirements.txt
├── tools
│   ├── eval.py
│   ├── save_heatmaps.py
│   ├── train_distill.py
│   └── train_v2v.py
└── utils
    ├── conf.py
    ├── misc.py
    └── saver.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Angelo Porrello
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Robust Re-Identification by Multiple Views Knowledge Distillation
2 | 
3 | This repository contains PyTorch code for the [ECCV20](https://eccv2020.eu/) paper "Robust Re-Identification by Multiple Views Knowledge Distillation" [[arXiv](http://arxiv.org/abs/2007.04174)]
4 | 
5 | ![VKD - Overview](images/mvd_framework.png)
6 | 
7 | ```bibtex
8 | @inproceedings{porrello2020robust,
9 |   title={Robust Re-Identification by Multiple Views Knowledge Distillation},
10 |   author={Porrello, Angelo and Bergamini, Luca and Calderara, Simone},
11 |   booktitle={European Conference on Computer Vision},
12 |   pages={93--110},
13 |   year={2020},
14 |   organization={Springer}
15 | }
16 | ```
17 | 
18 | ## Installation Note
19 | 
20 | Tested with Python 3.6.8 on Ubuntu (17.04, 18.04).
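If you need a starting point, here is a minimal sketch for creating the empty environment used below (assuming ``python3.6`` and the standard ``venv`` module; any other virtual-environment tool works equally well):

```shell
# hypothetical environment name
python3.6 -m venv vkd-env
source vkd-env/bin/activate
```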
21 | 
22 | - Set up an empty pip environment (e.g. as sketched above)
23 | - Install packages using ``pip install -r requirements.txt``
24 | - Install torch 1.3.1 using ``pip install torch==1.3.1+cu92 torchvision==0.4.2+cu92 -f https://download.pytorch.org/whl/torch_stable.html``
25 | 
26 | - Place datasets in ``./datasets/`` (please note you may need to request some of them from their respective authors)
27 | - Run scripts from ``commands.txt``
28 | 
29 | Please note that if you're running the code from PyCharm (or another IDE), you may need to manually set the working path to ``PROJECT_PATH``.
30 | 
31 | ## VKD Training (MARS [1])
32 | 
33 | ### Data preparation
34 | - Create the folder ``./datasets/mars``
35 | - Download the dataset from [here](https://drive.google.com/drive/u/1/folders/0B6tjyrV1YrHeMVV2UFFXQld6X1E)
36 | - Unzip the data and place the two folders (``bbox_train/``, ``bbox_test/``) inside ``./datasets/mars``
37 | - Download metadata from [here](https://github.com/liangzheng06/MARS-evaluation/tree/master/info)
38 | - Place them in a folder named ``info`` under the same path
39 | - You should end up with the following structure:
40 | 
41 | ```
42 | PROJECT_PATH/datasets/mars/
43 | |-- bbox_train/
44 | |-- bbox_test/
45 | |-- info/
46 | ```
47 | 
48 | ### Teacher-Student Training
49 | 
50 | **First step**: the backbone network is trained for the standard Video-To-Video setting. In this stage, each training example comprises N images drawn from the same tracklet (N=8 by default; you can change it through the argument ``--num_train_images``).
51 | 
52 | ```shell
53 | # To train ResNet-50 on MARS (teacher, first step) run:
54 | python ./tools/train_v2v.py mars --backbone resnet50 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet50 --first_milestone 100 --step_milestone 100
55 | ```
56 | 
57 | **Second step**: we appoint the network trained above as the teacher and freeze its parameters. Then, a new network with the role of the student is instantiated. In doing so, we feed N views (i.e. images captured from multiple cameras) as input to the teacher and ask the student to mimic the same outputs from fewer frames (M=2 by default, ``--num_student_images``).
58 | ```shell
59 | # To train a ResVKD-50 (student) run:
60 | python ./tools/train_distill.py mars ./logs/base_mars_resnet50 --exp_name distill_mars_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
61 | ```
62 | 
63 | ![](images/mars_all_withstudent.png)
64 | 
65 | ## Model Zoo
66 | 
67 | We provide several pre-trained checkpoints through two zip files (``baseline.zip`` containing the weights of the teacher networks, ``distilled.zip`` the student ones).
Therefore, to evaluate ResNet-50 and ResVKD-50 on MARS, proceed as follows: 68 | - Download ``baseline.zip`` from [here](https://ailb-web.ing.unimore.it/publicfiles/vkd_checkpoints/baseline.zip) and ``distilled.zip`` from [here](https://ailb-web.ing.unimore.it/publicfiles/vkd_checkpoints/distilled.zip) (~4.8 GB) 69 | - Unzip the two folders inside the ``PROJECT_PATH/logs`` folder 70 | - Then, you can evaluate both networks using the ``eval.py`` script: 71 | 72 | ```sh 73 | python ./tools/eval.py mars ./logs/baseline_public/mars/base_mars_resnet50 --trinet_chk_name chk_end 74 | ``` 75 | 76 | ```sh 77 | python ./tools/eval.py mars ./logs/distilled_public/mars/selfdistill/distill_mars_resnet50 --trinet_chk_name chk_di_1 78 | ``` 79 | 80 | You should end up with the following results on MARS (see Tab.1 of the paper for VeRi-776 and Duke-Video-ReID): 81 | 82 | Backbone|top1 I2V|mAP I2V|top1 V2V|mAP V2V 83 | :-:|:-:|:-:|:-:|:-: 84 | ``ResNet-34`` | 80.81 | 70.74 | 86.67 | 78.03 85 | ``ResVKD-34`` | **82.17** | **73.68** | **87.83** | **79.50** 86 | ``ResNet-50`` | 82.22 | 73.38 | 87.88 | 81.13 87 | ``ResVKD-50`` | **83.89** | **77.27** | **88.74** | **82.22** 88 | ``ResNet-101`` | 82.78 | 74.94 | 88.59 | 81.66 89 | ``ResVKD-101`` | **85.91** | **77.64** | **89.60** | **82.65** 90 | 91 | Backbone|top1 I2V|mAP I2V|top1 V2V|mAP V2V 92 | :-:|:-:|:-:|:-:|:-: 93 | ``ResNet-50bam`` | 82.58 | 74.11 | 88.54 | 81.19 94 | ``ResVKD-50bam`` | **84.34** | **78.13** | **89.39** | **83.07** 95 | 96 | Backbone|top1 I2V|mAP I2V|top1 V2V|mAP V2V 97 | :-:|:-:|:-:|:-:|:-: 98 | ``DenseNet-121`` | 82.68 | 74.34 | 89.75 | 81.93 99 | ``DenseVKD-121`` | **84.04** | **77.09** | **89.80** | **82.84** 100 | 101 | Backbone|top1 I2V|mAP I2V|top1 V2V|mAP V2V 102 | :-:|:-:|:-:|:-:|:-: 103 | ``MobileNet-V2`` | 78.64 | 67.94 | 85.96 | 77.10 104 | ``MobileVKD-V2`` | **83.33** | **73.95** | **88.13** | **79.62** 105 | 106 | ## Teacher-Student Explanations 107 | 108 | As discussed in the main paper, we have leveraged GradCam [2] to highlight the input regions that have been considered paramount for predicting the identity. We have performed the same analysis for the teacher network as well as for the student one: as can be seen, the latter pays more attention to the subject of interest compared to its teacher. 109 | 110 | ![Model Explanation](images/gradcam.png) 111 | 112 | You can draw the heatmaps with the following command: 113 | 114 | ```sh 115 | python -u ./tools/save_heatmaps.py mars --chk_net1 --chk_net2 --dest_path 116 | ``` 117 | 118 | ## References 119 | 120 | 1. Zheng, L., Bie, Z., Sun, Y., Wang, J., Su, C., Wang, S., Tian, Q.: Mars: A video benchmark for large-scale person re-identification. In: European Conference on Computer Vision (2016) 121 | 2. Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D. (2017). Grad-cam: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE international conference on computer vision (pp. 618-626). 
122 | -------------------------------------------------------------------------------- /commands.txt: -------------------------------------------------------------------------------- 1 | _______ _ _______ _ _ 2 | |__ __| | | |__ __| (_) (_) 3 | | | ___ __ _ ___| |__ ___ _ __ | |_ __ __ _ _ _ __ _ _ __ __ _ 4 | | |/ _ \/ _` |/ __| '_ \ / _ \ '__| | | '__/ _` | | '_ \| | '_ \ / _` | 5 | | | __/ (_| | (__| | | | __/ | | | | | (_| | | | | | | | | | (_| | 6 | |_|\___|\__,_|\___|_| |_|\___|_| |_|_| \__,_|_|_| |_|_|_| |_|\__, | 7 | __/ | 8 | |___/ 9 | 10 | ## VeRi Teacher Training 11 | python ./tools/train_v2v.py veri --backbone resnet101 --num_train_images 1 --p 18 --k 4 --exp_name base_veri_resnet101 --first_milestone 200 12 | python ./tools/train_v2v.py veri --backbone resnet50 --num_train_images 1 --p 18 --k 4 --exp_name base_veri_resnet50 --first_milestone 200 13 | python ./tools/train_v2v.py veri --backbone resnet18 --num_train_images 1 --p 18 --k 4 --exp_name base_veri_resnet18 --first_milestone 200 14 | python ./tools/train_v2v.py veri --backbone resnet50bam --num_train_images 1 --p 18 --k 4 --exp_name base_veri_resnet50bam --first_milestone 200 15 | python ./tools/train_v2v.py veri --backbone densenet121 --num_train_images 1 --p 18 --k 4 --exp_name base_veri_densenet121 --first_milestone 200 16 | python ./tools/train_v2v.py veri --backbone resnet34 --num_train_images 1 --p 18 --k 4 --exp_name base_veri_resnet34 --first_milestone 200 17 | python ./tools/train_v2v.py veri --backbone mobilenet --num_train_images 1 --p 18 --k 4 --exp_name base_veri_mobilenet --first_milestone 200 18 | 19 | ## MARS Teacher Training 20 | python ./tools/train_v2v.py mars --backbone resnet101 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet101 --first_milestone 100 --step_milestone 100 21 | python ./tools/train_v2v.py mars --backbone resnet50 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet50 --first_milestone 100 --step_milestone 100 22 | python ./tools/train_v2v.py mars --backbone resnet18 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet18 --first_milestone 50 23 | python ./tools/train_v2v.py mars --backbone resnet50bam --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet50bam --first_milestone 100 --step_milestone 100 24 | python ./tools/train_v2v.py mars --backbone densenet121 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_densenet121 --first_milestone 100 --step_milestone 100 25 | python ./tools/train_v2v.py mars --backbone resnet34 --num_train_images 8 --p 8 --k 4 --exp_name base_mars_resnet34 --first_milestone 100 --step_milestone 100 26 | python ./tools/train_v2v.py mars --backbone mobilenet --num_train_images 8 --p 8 --k 4 --exp_name base_mars_mobilenet --first_milestone 100 --step_milestone 100 27 | 28 | ## Duke-Video-ReId Teacher Training 29 | python ./tools/train_v2v.py duke-video-reid --backbone resnet101 --num_train_images 8 --p 8 --k 4 --exp_name base_duke_resnet101 --first_milestone 100 --step_milestone 100 30 | python ./tools/train_v2v.py duke-video-reid --backbone resnet50 --num_train_images 8 --p 8 --k 4 --exp_name base_duke_resnet50 --first_milestone 100 --step_milestone 100 31 | python ./tools/train_v2v.py duke-video-reid --backbone resnet18 --num_train_images 8 --p 8 --k 4 --exp_name base_duke_resnet18 --first_milestone 50 32 | python ./tools/train_v2v.py duke-video-reid --backbone resnet50bam --num_train_images 8 --p 8 --k 4 --exp_name base_duke_resnet50bam --first_milestone 100 --step_milestone 100 33 | python ./tools/train_v2v.py 
duke-video-reid --backbone densenet121 --num_train_images 8 --p 8 --k 4 --exp_name base_duke_densenet121 --first_milestone 100 --step_milestone 100
34 | python ./tools/train_v2v.py duke-video-reid --backbone resnet34 --num_train_images 8 --p 8 --k 4 --exp_name base_duke_resnet34 --first_milestone 100 --step_milestone 100
35 | python ./tools/train_v2v.py duke-video-reid --backbone mobilenet --num_train_images 8 --p 8 --k 4 --exp_name base_duke_mobilenet --first_milestone 100 --step_milestone 100
36 | 
37 | _____ _ _ _ _______ _ _ 
38 | / ____| | | | | | |__ __| (_) (_) 
39 | | (___ | |_ _ _ __| | ___ _ __ | |_ | |_ __ __ _ _ _ __ _ _ __ __ _ 
40 | \___ \| __| | | |/ _` |/ _ \ '_ \| __| | | '__/ _` | | '_ \| | '_ \ / _` | 
41 | ____) | |_| |_| | (_| | __/ | | | |_ | | | | (_| | | | | | | | | | (_| | 
42 | |_____/ \__|\__,_|\__,_|\___|_| |_|\__| |_|_| \__,_|_|_| |_|_|_| |_|\__, | 
43 | __/ | 
44 | |___/ 
45 | ## VeRi Student Training (self-distill)
46 | python ./tools/train_distill.py veri ./logs/base_veri_resnet50bam --exp_name distill_veri_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
47 | python ./tools/train_distill.py veri ./logs/base_veri_mobilenet --exp_name distill_veri_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
48 | python ./tools/train_distill.py veri ./logs/base_veri_resnet50 --exp_name distill_veri_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
49 | python ./tools/train_distill.py veri ./logs/base_veri_densenet121 --exp_name distill_veri_densenet121 --p 12 --k 4 --step_milestone 150 --num_epochs 500
50 | python ./tools/train_distill.py veri ./logs/base_veri_resnet101 --exp_name distill_veri_resnet101 --p 12 --k 4 --step_milestone 150 --num_epochs 500
51 | python ./tools/train_distill.py veri ./logs/base_veri_resnet34 --exp_name distill_veri_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
52 | 
53 | ## MARS Student Training (self-distill)
54 | python ./tools/train_distill.py mars ./logs/base_mars_resnet50bam --exp_name distill_mars_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
55 | python ./tools/train_distill.py mars ./logs/base_mars_mobilenet --exp_name distill_mars_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
56 | python ./tools/train_distill.py mars ./logs/base_mars_resnet50 --exp_name distill_mars_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
57 | python ./tools/train_distill.py mars ./logs/base_mars_densenet121 --exp_name distill_mars_densenet121 --p 12 --k 4 --step_milestone 150 --num_epochs 500
58 | python ./tools/train_distill.py mars ./logs/base_mars_resnet101 --exp_name distill_mars_resnet101 --p 12 --k 4 --step_milestone 150 --num_epochs 500
59 | python ./tools/train_distill.py mars ./logs/base_mars_resnet34 --exp_name distill_mars_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
60 | 
61 | ## Duke-Video-ReId Student Training (self-distill)
62 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet50bam --exp_name distill_duke_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
63 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_mobilenet --exp_name distill_duke_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
64 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet50 --exp_name distill_duke_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
65 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_densenet121 --exp_name distill_duke_densenet121 --p 12 --k 4 --step_milestone 150 --num_epochs 500
66 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet101 --exp_name distill_duke_resnet101 --p 12 --k 4 --step_milestone 150 --num_epochs 500
67 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet34 --exp_name distill_duke_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
68 | 
69 | ---------------------------------------------------------------------------------------------------
70 | 
71 | ## VeRi Student Training (cross-distill)
72 | python ./tools/train_distill.py veri ./logs/base_veri_resnet101 --student ./logs/base_veri_resnet34 --exp_name distill_veri_resnet101_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
73 | python ./tools/train_distill.py veri ./logs/base_veri_resnet101 --student ./logs/base_veri_resnet50bam --exp_name distill_veri_resnet101_to_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
74 | python ./tools/train_distill.py veri ./logs/base_veri_resnet50 --student ./logs/base_veri_resnet34 --exp_name distill_veri_resnet50_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
75 | python ./tools/train_distill.py veri ./logs/base_veri_resnet101 --student ./logs/base_veri_mobilenet --exp_name distill_veri_resnet101_to_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
76 | python ./tools/train_distill.py veri ./logs/base_veri_resnet101 --student ./logs/base_veri_resnet50 --exp_name distill_veri_resnet101_to_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
77 | 
78 | ## MARS Student Training (cross-distill)
79 | python ./tools/train_distill.py mars ./logs/base_mars_resnet101 --student ./logs/base_mars_resnet34 --exp_name distill_mars_resnet101_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
80 | python ./tools/train_distill.py mars ./logs/base_mars_resnet101 --student ./logs/base_mars_resnet50bam --exp_name distill_mars_resnet101_to_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
81 | python ./tools/train_distill.py mars ./logs/base_mars_resnet50 --student ./logs/base_mars_resnet34 --exp_name distill_mars_resnet50_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
82 | python ./tools/train_distill.py mars ./logs/base_mars_resnet101 --student ./logs/base_mars_mobilenet --exp_name distill_mars_resnet101_to_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
83 | python ./tools/train_distill.py mars ./logs/base_mars_resnet101 --student ./logs/base_mars_resnet50 --exp_name distill_mars_resnet101_to_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500
84 | 
85 | ## Duke-Video-ReId Student Training (cross-distill)
86 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet101 --student ./logs/base_duke_resnet34 --exp_name distill_duke_resnet101_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
87 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet101 --student ./logs/base_duke_resnet50bam --exp_name distill_duke_resnet101_to_resnet50bam --p 12 --k 4 --step_milestone 150 --num_epochs 500
88 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet50 --student ./logs/base_duke_resnet34 --exp_name distill_duke_resnet50_to_resnet34 --p 12 --k 4 --step_milestone 150 --num_epochs 500
89 | python ./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet101 --student ./logs/base_duke_mobilenet --exp_name distill_duke_resnet101_to_mobilenet --p 12 --k 4 --step_milestone 150 --num_epochs 500
90 | python 
./tools/train_distill.py duke-video-reid ./logs/base_duke_resnet101 --student ./logs/base_duke_resnet50 --exp_name distill_duke_resnet101_to_resnet50 --p 12 --k 4 --step_milestone 150 --num_epochs 500 91 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | from torch.utils.data import DataLoader 7 | 8 | from torchvision import transforms as T 9 | import data.temporal_transforms as TT 10 | from data.dukevideoreid.data_manager import DukeVideoreID 11 | from data.mars.data_manager import Mars 12 | from data.misc import get_default_video_loader 13 | from data.misc import get_transforms 14 | from data.samplers import ReIDBatchSampler 15 | from data.veri.data_manager import Veri 16 | from utils.misc import init_worker 17 | 18 | 19 | class Dataset(data.Dataset): 20 | """Video ReID Dataset. 21 | Note: 22 | Batch data has shape N x C x T x H x W 23 | Args: 24 | dataset (list): List with items (img_paths, pid, camid) 25 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 26 | and returns a transformed version 27 | target_transform (callable, optional): A function/transform that takes in the 28 | target and transforms it. 29 | """ 30 | 31 | def __init__(self, 32 | dataset, 33 | spatial_transform=None, 34 | temporal_transform=None, 35 | get_loader=get_default_video_loader): 36 | self.dataset = dataset 37 | self.spatial_transform = spatial_transform 38 | self.temporal_transform = temporal_transform 39 | self.loader = get_loader() 40 | self.teacher_mode = False 41 | 42 | def __len__(self): 43 | return len(self.dataset) 44 | 45 | def get_num_pids(self): 46 | return len(np.unique([el[1] for el in self.dataset])) 47 | 48 | def get_num_cams(self): 49 | return len(np.unique([el[2] for el in self.dataset])) 50 | 51 | def set_teacher_mode(self, is_teacher: bool): 52 | self.teacher_mode = is_teacher 53 | 54 | def __getitem__(self, index): 55 | """ 56 | Args: 57 | index (int): Index 58 | 59 | Returns: 60 | tuple: (clip, pid, camid) where pid is identity of the clip. 
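When ``teacher_mode`` is enabled, ``clip`` additionally contains a plainly-resized (non-augmented) copy of every frame, as built in the branch below.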
61 | """ 62 | img_paths, pid, camid = self.dataset[index] 63 | 64 | if isinstance(self.temporal_transform, TT.MultiViewTemporalTransform): 65 | candidates = list(filter(lambda x: x[1] == pid, self.dataset)) 66 | img_paths = self.temporal_transform(candidates, index) 67 | elif self.temporal_transform is not None: 68 | img_paths = self.temporal_transform(img_paths, index) 69 | 70 | clip = self.loader(img_paths) 71 | 72 | if not self.teacher_mode: 73 | clip = [self.spatial_transform(img) for img in clip] 74 | else: 75 | clip_aug = [self.spatial_transform(img) for img in clip] 76 | std_daug = T.Compose([ 77 | self.spatial_transform.transforms[0], 78 | T.ToTensor(), 79 | self.spatial_transform.transforms[-1] if not isinstance(self.spatial_transform.transforms[-1], T.RandomErasing) else self.spatial_transform.transforms[-2] 80 | ]) 81 | clip_std = [std_daug(img) for img in clip] 82 | clip = clip_aug + clip_std 83 | 84 | clip = torch.stack(clip, 0) 85 | 86 | return clip, pid, camid 87 | 88 | 89 | DATASETS = { 90 | 'mars': Mars, 91 | 'veri': Veri, 92 | 'duke-video-reid': DukeVideoreID, 93 | } 94 | 95 | 96 | class DataConf: 97 | def __init__(self, perform_x2i, perform_x2v, augment_gallery): 98 | self.perform_x2i = perform_x2i 99 | self.perform_x2v = perform_x2v 100 | self.augment_gallery = augment_gallery 101 | 102 | 103 | DATA_CONFS = { 104 | 'mars': DataConf(perform_x2i=False, perform_x2v=True, augment_gallery=True), 105 | 'veri': DataConf(perform_x2i=True, perform_x2v=True, augment_gallery=False), 106 | 'duke-video-reid': DataConf(perform_x2i=False, perform_x2v=True, augment_gallery=False), 107 | 108 | } 109 | 110 | 111 | def get_dataloaders(dataset_name: str, root: str, device: torch.device, args: Namespace): 112 | dataset_name = dataset_name.lower() 113 | assert dataset_name in DATASETS.keys() 114 | dataset = DATASETS[dataset_name](root) 115 | 116 | pin_memory = True if device == torch.device('cuda') else False 117 | 118 | s_tr_train, t_tr_train = get_transforms(True, args) 119 | s_tr_test, t_tr_test = get_transforms(False, args) 120 | 121 | train_loader = DataLoader( 122 | Dataset(dataset.train, spatial_transform=s_tr_train, 123 | temporal_transform=t_tr_train), 124 | batch_sampler=ReIDBatchSampler(dataset.train, p=args.p, k=args.k), 125 | num_workers=args.workers, pin_memory=pin_memory, 126 | worker_init_fn=init_worker 127 | ) 128 | 129 | query_loader = DataLoader( 130 | Dataset(dataset.query, spatial_transform=s_tr_test, 131 | temporal_transform=t_tr_test), 132 | batch_size=args.test_batch, shuffle=False, num_workers=2, 133 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker 134 | ) 135 | 136 | gallery_loader = DataLoader( 137 | Dataset(dataset.gallery, spatial_transform=s_tr_test, 138 | temporal_transform=t_tr_test), 139 | batch_size=args.test_batch, shuffle=False, num_workers=2, 140 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker 141 | ) 142 | 143 | queryimg_loader = DataLoader( 144 | Dataset(dataset.query_img, spatial_transform=s_tr_test), 145 | batch_size=args.img_test_batch, shuffle=False, num_workers=2, 146 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker 147 | ) 148 | 149 | galleryimg_loader = DataLoader( 150 | Dataset(dataset.gallery_img, spatial_transform=s_tr_test), 151 | batch_size=args.img_test_batch, shuffle=False, num_workers=2, 152 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker 153 | ) 154 | 155 | return train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader 156 | 
-------------------------------------------------------------------------------- /data/dukevideoreid/__pycache__/data_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/data/dukevideoreid/__pycache__/data_manager.cpython-36.pyc -------------------------------------------------------------------------------- /data/dukevideoreid/data_manager.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | from pathlib import Path 4 | from copy import deepcopy 5 | 6 | """Dataset classes""" 7 | 8 | 9 | class DukeVideoreID(object): 10 | """ 11 | Duke VIDEO re-id 12 | Reference: 13 | Ristani et al. Performance measures and a data set for multi-target, multi-camera tracking 14 | """ 15 | 16 | def __init__(self, root='/data/datasets/', min_seq_len=0): 17 | self.root = osp.join(root, 'DukeMTMC-VideoReID') 18 | self.train_name_path = osp.join(self.root, 'train') 19 | self.gallery_name_path = osp.join(self.root, 'gallery') 20 | self.query_name_path = osp.join(self.root, 'query') 21 | 22 | train_paths = self._get_paths(self.train_name_path) 23 | query_paths = self._get_paths(self.query_name_path) 24 | gallery_paths = self._get_paths(self.gallery_name_path) 25 | 26 | train, num_train_tracklets, num_train_pids, num_train_imgs = self._process_video(train_paths, re_label=True) 27 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = self._process_video(gallery_paths, re_label=False) 28 | query, num_query_tracklets, num_query_pids, num_query_imgs = self._process_video(query_paths, re_label=False) 29 | 30 | # query and gallery image are computed from first frames 31 | query_img = [] 32 | for el in query: 33 | query_img.append((el[0][:1], el[1], el[2])) # first image of gallery tracklet 34 | 35 | gallery_img = [] 36 | for el in gallery: 37 | gallery_img.append((el[0][:1], el[1], el[2])) # first image of gallery tracklet 38 | 39 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 40 | total_num = np.sum(num_imgs_per_tracklet) 41 | min_num = np.min(num_imgs_per_tracklet) 42 | max_num = np.max(num_imgs_per_tracklet) 43 | avg_num = np.mean(num_imgs_per_tracklet) 44 | 45 | num_total_pids = num_train_pids + num_query_pids 46 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 47 | 48 | print("=> DUKE re-ID loaded") 49 | print("Dataset statistics:") 50 | print(" -----------------------------------------") 51 | print(" subset | # ids | # tracklets | # images") 52 | print(" -----------------------------------------") 53 | print(" train | {:5d} | {:8d} | {:8d}".format(num_train_pids, num_train_tracklets, np.sum(num_train_imgs))) 54 | print(" query | {:5d} | {:8d} | {:8d}".format(num_query_pids, num_query_tracklets, np.sum(num_query_imgs))) 55 | print(" gallery | {:5d} | {:8d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets, np.sum(num_gallery_imgs))) 56 | print(" -----------------------------------------") 57 | print(" total | {:5d} | {:8d} | {:8d}".format(num_total_pids, num_total_tracklets, total_num)) 58 | print(" -----------------------------------------") 59 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 60 | print(" -----------------------------------------") 61 | 62 | self.train = train 63 | self.gallery = gallery 64 | self.query = query 65 | self.query_img = 
query_img 66 | self.gallery_img = gallery_img 67 | 68 | self.num_train_pids = num_train_pids 69 | self.num_query_pids = num_query_pids 70 | self.num_gallery_pids = num_gallery_pids 71 | 72 | def _get_paths(self, fpath): 73 | path = Path(fpath) 74 | return sorted(path.glob('**/*.jpg')) 75 | 76 | def _process_video(self, train_paths, re_label: bool): 77 | train = [] 78 | pids = [] 79 | num_images = [] 80 | train_names = [p.name for p in train_paths] 81 | hs = np.asarray([hash(el[:7]) for el in train_names]) 82 | displaces = np.nonzero(hs[1:] - hs[:-1])[0] + 1 83 | displaces = np.concatenate([[0], displaces, [len(hs)]]) 84 | for idx in range(len(displaces) - 1): 85 | names = train_names[displaces[idx]: displaces[idx + 1]] 86 | pid, camera = map(int, names[0].replace('C', '').split('_')[:2]) 87 | camera -= 1 88 | paths = [str(p) for p in train_paths[displaces[idx]: displaces[idx + 1]]] 89 | train.append((paths, pid, camera)) 90 | num_images.append(len(names)) 91 | pids.append(pid) 92 | 93 | # RE-LABEl 94 | if re_label: 95 | pid_map = {pid: idx for idx, pid in enumerate(np.unique(pids))} 96 | for i in range(len(train)): 97 | pid = pids[i] 98 | train[i] = (train[i][0], pid_map[pid], train[i][2]) 99 | pids[i] = pid_map[pid] 100 | 101 | return train, len(train), len(set(pids)), num_images 102 | -------------------------------------------------------------------------------- /data/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from warnings import warn 3 | 4 | 5 | def compute_ap_cmc(index, good_index, junk_index, 6 | return_is_correct: bool = False): 7 | ap = 0 8 | cmc = np.zeros(len(index)) 9 | 10 | # remove junk_index 11 | mask = np.in1d(index, junk_index, invert=True) 12 | index = index[mask] 13 | 14 | # find good_index index 15 | ngood = len(good_index) 16 | mask = np.in1d(index, good_index) 17 | rows_good = np.argwhere(mask == True) 18 | rows_good = rows_good.flatten() 19 | 20 | cmc[rows_good[0]:] = 1.0 21 | 22 | correct_ = np.any(np.in1d(index[:1], good_index)) 23 | predicted_ = index[:1] 24 | 25 | for i in range(ngood): 26 | 27 | d_recall = 1.0 / ngood 28 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 29 | if rows_good[i] != 0: 30 | old_precision = i * 1.0 / rows_good[i] 31 | else: 32 | old_precision = 1.0 33 | ap = ap + d_recall * (old_precision + precision) / 2 # trapezoid approximation 34 | 35 | if return_is_correct: 36 | return ap, cmc, (correct_, predicted_) 37 | return ap, cmc 38 | 39 | 40 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, return_wrong_matches: bool = False): 41 | num_q, num_g = distmat.shape 42 | index = np.argsort(distmat, axis=1) # from small to large 43 | 44 | num_no_gt = 0 # num of query imgs without groundtruth 45 | num_r1 = 0 46 | CMC = np.zeros(len(g_pids)) 47 | AP = 0 48 | 49 | wrong_matches = [] 50 | 51 | for i in range(num_q): 52 | # groundtruth index 53 | query_index = np.argwhere(g_pids == q_pids[i]) 54 | camera_index = np.argwhere(g_camids == q_camids[i]) 55 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 56 | if good_index.size == 0: 57 | num_no_gt += 1 58 | continue 59 | # remove gallery samples that have the same pid and camid with query 60 | junk_index = np.intersect1d(query_index, camera_index) 61 | 62 | ap_tmp, CMC_tmp, (correct_, predicted_) = compute_ap_cmc(index[i], good_index, 63 | junk_index, True) 64 | 65 | if not correct_: 66 | wrong_matches.append((i, predicted_[0])) 67 | 68 | if CMC_tmp[0] == 1: 69 | num_r1 += 1 70 | CMC = CMC 
+ CMC_tmp 71 | AP += ap_tmp 72 | 73 | if num_no_gt > 0: 74 | warn("{} query imgs do not have groundtruth.".format(num_no_gt)) 75 | 76 | CMC = CMC / (num_q - num_no_gt) 77 | mAP = AP / (num_q - num_no_gt) 78 | 79 | if return_wrong_matches: return CMC, mAP, wrong_matches 80 | 81 | return CMC, mAP 82 | -------------------------------------------------------------------------------- /data/mars/__pycache__/data_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/data/mars/__pycache__/data_manager.cpython-36.pyc -------------------------------------------------------------------------------- /data/mars/data_manager.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from scipy.io import loadmat 3 | import numpy as np 4 | 5 | 6 | class Mars(object): 7 | """ 8 | MARS 9 | 10 | Reference: 11 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 12 | """ 13 | 14 | def __init__(self, root='/data/datasets/', min_seq_len=0): 15 | self.root = osp.join(root, 'mars') 16 | self.train_name_path = osp.join(self.root, 'info/train_name.txt') 17 | self.test_name_path = osp.join(self.root, 'info/test_name.txt') 18 | self.track_train_info_path = osp.join(self.root, 'info/tracks_train_info.mat') 19 | self.track_test_info_path = osp.join(self.root, 'info/tracks_test_info.mat') 20 | self.query_IDX_path = osp.join(self.root, 'info/query_IDX.mat') 21 | 22 | self._check_before_run() 23 | 24 | # prepare meta data 25 | train_names = self._get_names(self.train_name_path) 26 | test_names = self._get_names(self.test_name_path) 27 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 28 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 29 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 30 | query_IDX -= 1 # index from 0 31 | track_query = track_test[query_IDX,:] 32 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 33 | track_gallery = track_test[gallery_IDX,:] 34 | # track_gallery = track_test 35 | 36 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 37 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 38 | 39 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 40 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 41 | 42 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 43 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 44 | 45 | train_img, _, _ = \ 46 | self._extract_1stfeame(train_names, track_train, home_dir='bbox_train', relabel=True) 47 | 48 | query_img, _, _ = \ 49 | self._extract_1stfeame(test_names, track_query, home_dir='bbox_test', relabel=False) 50 | 51 | gallery_img, _, _ = \ 52 | self._extract_1stfeame(test_names, track_gallery, home_dir='bbox_test', relabel=False) 53 | 54 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 55 | total_num = np.sum(num_imgs_per_tracklet) 56 | min_num = np.min(num_imgs_per_tracklet) 57 | max_num = np.max(num_imgs_per_tracklet) 58 | avg_num = np.mean(num_imgs_per_tracklet) 59 | 60 | num_total_pids = num_train_pids + num_query_pids 61 | 
num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 62 | 63 | print("=> MARS loaded") 64 | print("Dataset statistics:") 65 | print(" -----------------------------------------") 66 | print(" subset | # ids | # tracklets | # images") 67 | print(" -----------------------------------------") 68 | print(" train | {:5d} | {:8d} | {:8d}".format(num_train_pids, num_train_tracklets, np.sum(num_train_imgs))) 69 | print(" query | {:5d} | {:8d} | {:8d}".format(num_query_pids, num_query_tracklets, np.sum(num_query_imgs))) 70 | print(" gallery | {:5d} | {:8d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets, np.sum(num_gallery_imgs))) 71 | print(" -----------------------------------------") 72 | print(" total | {:5d} | {:8d} | {:8d}".format(num_total_pids, num_total_tracklets, total_num)) 73 | print(" -----------------------------------------") 74 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 75 | print(" -----------------------------------------") 76 | 77 | self.train = train 78 | self.query = query 79 | self.gallery = gallery 80 | 81 | self.train_img = train_img 82 | self.query_img = query_img 83 | self.gallery_img = gallery_img 84 | 85 | self.num_train_pids = num_train_pids 86 | self.num_query_pids = num_query_pids 87 | self.num_gallery_pids = num_gallery_pids 88 | 89 | def _check_before_run(self): 90 | """Check if all files are available before going deeper""" 91 | if not osp.exists(self.root): 92 | raise RuntimeError("'{}' is not available".format(self.root)) 93 | if not osp.exists(self.train_name_path): 94 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 95 | if not osp.exists(self.test_name_path): 96 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 97 | if not osp.exists(self.track_train_info_path): 98 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 99 | if not osp.exists(self.track_test_info_path): 100 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 101 | if not osp.exists(self.query_IDX_path): 102 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 103 | 104 | def _get_names(self, fpath): 105 | names = [] 106 | with open(fpath, 'r') as f: 107 | for line in f: 108 | new_line = line.rstrip() 109 | names.append(new_line) 110 | return names 111 | 112 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 113 | assert home_dir in ['bbox_train', 'bbox_test'] 114 | num_tracklets = meta_data.shape[0] 115 | pid_list = list(set(meta_data[:,2].tolist())) 116 | num_pids = len(pid_list) 117 | 118 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 119 | tracklets = [] 120 | num_imgs_per_tracklet = [] 121 | 122 | for tracklet_idx in range(num_tracklets): 123 | data = meta_data[tracklet_idx,...] 
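# each row of meta_data stores (start_index, end_index, pid, camid); frame indices are 1-based (MATLAB convention)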
124 | start_index, end_index, pid, camid = data 125 | if pid == -1: continue # junk images are just ignored 126 | assert 1 <= camid <= 6 127 | if relabel: pid = pid2label[pid] 128 | camid -= 1 # index starts from 0 129 | img_names = names[start_index-1:end_index] 130 | 131 | # make sure image names correspond to the same person 132 | pnames = [img_name[:4] for img_name in img_names] 133 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 134 | 135 | # make sure all images are captured under the same camera 136 | camnames = [img_name[5] for img_name in img_names] 137 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 138 | 139 | # append image names with directory information 140 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 141 | if len(img_paths) >= min_seq_len: 142 | img_paths = tuple(img_paths) 143 | tracklets.append((img_paths, pid, camid)) 144 | num_imgs_per_tracklet.append(len(img_paths)) 145 | 146 | num_tracklets = len(tracklets) 147 | 148 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 149 | 150 | def _extract_1stfeame(self, names, meta_data, home_dir=None, relabel=False): 151 | assert home_dir in ['bbox_train', 'bbox_test'] 152 | num_tracklets = meta_data.shape[0] 153 | pid_list = list(set(meta_data[:,2].tolist())) 154 | num_pids = len(pid_list) 155 | 156 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 157 | imgs = [] 158 | 159 | for tracklet_idx in range(num_tracklets): 160 | data = meta_data[tracklet_idx,...] 161 | start_index, end_index, pid, camid = data 162 | if pid == -1: continue # junk images are just ignored 163 | assert 1 <= camid <= 6 164 | if relabel: pid = pid2label[pid] 165 | camid -= 1 # index starts from 0 166 | img_name = names[start_index-1] 167 | 168 | # append image names with directory information 169 | img_path = osp.join(self.root, home_dir, img_name[:4], img_name) 170 | 171 | imgs.append(([img_path], pid, camid)) 172 | 173 | num_imgs = len(imgs) 174 | 175 | return imgs, num_imgs, num_pids 176 | -------------------------------------------------------------------------------- /data/misc.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from argparse import Namespace 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from data import temporal_transforms as TT 9 | 10 | 11 | def pil_loader(path): 12 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 13 | with open(path, 'rb') as f: 14 | with Image.open(f) as img: 15 | return img.convert('RGB') 16 | 17 | 18 | def accimage_loader(path): 19 | try: 20 | import accimage 21 | return accimage.Image(path) 22 | except IOError: 23 | # Potentially a decoding problem, fall back to PIL.Image 24 | return pil_loader(path) 25 | 26 | 27 | def get_default_image_loader(): 28 | from torchvision import get_image_backend 29 | if get_image_backend() == 'accimage': 30 | return accimage_loader 31 | else: 32 | return pil_loader 33 | 34 | 35 | def image_loader(path): 36 | from torchvision import get_image_backend 37 | if get_image_backend() == 'accimage': 38 | return accimage_loader(path) 39 | else: 40 | return pil_loader(path) 41 | 42 | 43 | def video_loader(img_paths, image_loader): 44 | video = [] 45 | for image_path in img_paths: 46 | if os.path.exists(image_path): 47 | 
video.append(image_loader(image_path)) 48 | else: 49 | return video 50 | 51 | return video 52 | 53 | 54 | def get_default_video_loader(): 55 | image_loader = get_default_image_loader() 56 | return functools.partial(video_loader, image_loader=image_loader) 57 | 58 | 59 | def get_transforms(train_mode: bool, args: Namespace): 60 | 61 | mean, var = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 62 | 63 | input_res = { 64 | 'mars': (256, 128), 65 | 'duke-video-reid': (256, 128), 66 | 'veri': (224, 224), 67 | }.get(args.dataset_name, (224, 224)) 68 | 69 | erase_ratio = { 70 | 'mars': (0.3, 3.3), 71 | 'duke-video-reid': (0.3, 3.3), 72 | 'veri': (0.7, 1.4), 73 | }.get(args.dataset_name, (0.7, 1.4)) 74 | 75 | erase_scale = (0.02, 0.4) 76 | 77 | resize_operation = { 78 | 'mars': T.Resize(input_res, interpolation=3), 79 | 'duke-video-reid': T.Resize(input_res, interpolation=3), 80 | 'veri': T.Resize(input_res, interpolation=3), 81 | }.get(args.dataset_name, AdaptiveResize(height=input_res[0], width=input_res[1])) 82 | 83 | if not train_mode: 84 | t_tr_test = TT.TemporalChunkCrop(args.num_test_images) 85 | 86 | s_tr_test = T.Compose([ 87 | resize_operation, 88 | T.ToTensor(), 89 | T.Normalize(mean, var) 90 | ]) 91 | return s_tr_test, t_tr_test 92 | 93 | tr_re = [T.RandomErasing(p=0.5, scale=erase_scale, ratio=erase_ratio)] \ 94 | if args.use_random_erasing else [] 95 | 96 | # Data augmentation 97 | s_tr_train = T.Compose([ 98 | resize_operation, 99 | T.Pad(10), 100 | T.RandomCrop(input_res), 101 | T.RandomHorizontalFlip(), 102 | T.ToTensor(), 103 | T.Normalize(mean, var) 104 | ] + tr_re) 105 | 106 | if args.train_strategy == 'random': 107 | t_tr_train = TT.TemporalRandomFrames(args.num_train_images) 108 | elif args.train_strategy == 'chunk': 109 | t_tr_train = TT.RandomTemporalChunkCrop(args.num_train_images) 110 | elif args.train_strategy == 'temporal': 111 | t_tr_train = TT.TemporalChunkCrop(args.num_train_images) 112 | elif args.train_strategy == 'multiview': 113 | t_tr_train = TT.MultiViewTemporalTransform(args.num_train_images) 114 | else: 115 | raise ValueError 116 | 117 | return s_tr_train, t_tr_train 118 | 119 | 120 | class AdaptiveResize: 121 | def __init__(self, width, height, interpolation=3): 122 | self.height = height 123 | self.width = width 124 | self.interpolation = interpolation 125 | 126 | @staticmethod 127 | def get_padding(padding): 128 | if padding == 0: 129 | p_1, p_2 = 0, 0 130 | elif padding % 2 == 0: 131 | p_1, p_2 = padding // 2, padding // 2 132 | else: 133 | p_1, p_2 = padding // 2 + 1, padding // 2 134 | return p_1, p_2 135 | 136 | def __call__(self, img: Image.Image): 137 | h, w = img.height, img.width 138 | # resize to ensure fit in target shape 139 | ratio_w = self.width / w 140 | ratio_h = self.height / h 141 | ratio = min(ratio_w, ratio_h) 142 | new_w, new_h = map(lambda x: int(np.floor(x * ratio)), (w, h)) 143 | img = img.resize((new_w, new_h), resample=self.interpolation) 144 | 145 | # compute padding 146 | h, w = img.height, img.width 147 | p_t, p_b = self.get_padding(self.height - h) 148 | p_l, p_r = self.get_padding(self.width - w) 149 | 150 | # copy into new buffer 151 | img = np.pad(np.asarray(img), ((p_t, p_b), (p_l, p_r), (0, 0)), mode='constant') 152 | 153 | return Image.fromarray(img) 154 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import chain 3 | 4 | import 
numpy as np 5 | import random 6 | 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | def compute_pids_and_pids_dict(data_source): 11 | 12 | index_dic = defaultdict(list) 13 | for index, (_, pid, _) in enumerate(data_source): 14 | index_dic[pid].append(index) 15 | pids = list(index_dic.keys()) 16 | return pids, index_dic 17 | 18 | 19 | class ReIDBatchSampler(Sampler): 20 | 21 | def __init__(self, data_source, p: int, k: int): 22 | 23 | self._p = p 24 | self._k = k 25 | 26 | pids, index_dic = compute_pids_and_pids_dict(data_source) 27 | 28 | self._unique_labels = np.array(pids) 29 | self._label_to_items = index_dic.copy() 30 | 31 | self._num_iterations = len(self._unique_labels) // self._p 32 | 33 | def __iter__(self): 34 | 35 | def sample(set, n): 36 | if len(set) < n: 37 | return np.random.choice(set, n, replace=True) 38 | return np.random.choice(set, n, replace=False) 39 | 40 | np.random.shuffle(self._unique_labels) 41 | 42 | for k, v in self._label_to_items.items(): 43 | random.shuffle(self._label_to_items[k]) 44 | 45 | curr_p = 0 46 | 47 | for idx in range(self._num_iterations): 48 | p_labels = self._unique_labels[curr_p: curr_p + self._p] 49 | curr_p += self._p 50 | batch = [sample(self._label_to_items[l], self._k) for l in p_labels] 51 | batch = list(chain(*batch)) 52 | yield batch 53 | 54 | def __len__(self): 55 | return self._num_iterations 56 | -------------------------------------------------------------------------------- /data/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | class TemporalChunkCrop(object): 6 | 7 | def __init__(self, size: int = 4): 8 | self.S = size 9 | 10 | def __call__(self, frame_indices, tracklet_index): 11 | sample_clip = [] 12 | F = len(frame_indices) 13 | if F < self.S: 14 | strip = list(range(0, F)) + [F-1] * (self.S - F) 15 | for s in range(self.S): 16 | pool = strip[s * 1:(s + 1) * 1] 17 | sample_clip.append(list(pool)) 18 | else: 19 | interval = math.ceil(F / self.S) 20 | strip = list(range(0, F)) + [F-1] * (interval * self.S - F) 21 | for s in range(self.S): 22 | pool = strip[s * interval:(s + 1) * interval] 23 | sample_clip.append(list(pool)) 24 | return [ frame_indices[idx] for idx 25 | in np.array(sample_clip)[:, 0].tolist() ] 26 | 27 | 28 | class RandomTemporalChunkCrop(object): 29 | 30 | def __init__(self, size: int = 4): 31 | self.S = size 32 | 33 | def __call__(self, frame_indices, tracklet_index): 34 | sample_clip = [] 35 | F = len(frame_indices) 36 | if F < self.S: 37 | strip = list(range(0, F)) + [F-1] * (self.S - F) 38 | for s in range(self.S): 39 | pool = strip[s * 1:(s + 1) * 1] 40 | sample_clip.append(list(pool)) 41 | else: 42 | interval = math.ceil(F / self.S) 43 | strip = list(range(0, F)) + [F-1] * (interval * self.S - F) 44 | for s in range(self.S): 45 | pool = strip[s * interval:(s + 1) * interval] 46 | sample_clip.append(list(pool)) 47 | 48 | sample_clip = np.array(sample_clip) 49 | sample_clip = sample_clip[np.arange(self.S), 50 | np.random.randint(0, sample_clip.shape[1], self.S)] 51 | return [ frame_indices[idx] for idx in sample_clip ] 52 | 53 | 54 | class MultiViewTemporalTransform(object): 55 | 56 | def __init__(self, size: int = 4): 57 | self.size = size 58 | 59 | def __call__(self, candidate, tracklet_index): 60 | img_paths = [] 61 | candidate_perm = np.random.permutation(len(candidate)) 62 | for idx in range(self.size): 63 | cur_tracklet = candidate_perm[idx % len(candidate_perm)] 64 | cur_frame = 
np.random.randint(0, len(candidate[cur_tracklet][0])) 65 | cur_img_path = candidate[cur_tracklet][0][cur_frame] 66 | img_paths.append(cur_img_path) 67 | return img_paths 68 | 69 | 70 | class TemporalRandomFrames(object): 71 | """ 72 | Get size random frames (without replacement if possible) from a video 73 | """ 74 | 75 | def __init__(self, num_images=4): 76 | self.num_images = num_images 77 | 78 | def __call__(self, frame_indices, tracklet_index): 79 | frame_indices = list(frame_indices) 80 | if len(frame_indices) < self.num_images: 81 | return list(np.random.choice(frame_indices, size=self.num_images, replace=True)) 82 | 83 | return list(np.random.choice(frame_indices, size=self.num_images, replace=False)) 84 | -------------------------------------------------------------------------------- /data/veri/__pycache__/data_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/data/veri/__pycache__/data_manager.cpython-36.pyc -------------------------------------------------------------------------------- /data/veri/data_manager.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | 6 | class Veri(object): 7 | """ 8 | VeRi 9 | 10 | Reference: 11 | """ 12 | 13 | def __init__(self, root='/data/datasets/', min_seq_len=0): 14 | self.root = osp.join(root, 'VeRi') 15 | self.train_name_path = osp.join(self.root, 'name_train.txt') 16 | self.query_name_path = osp.join(self.root, 'name_query.txt') 17 | self.track_gallery_info_path = osp.join(self.root, 'test_track.txt') 18 | 19 | train_names = self._get_names(self.train_name_path) 20 | query_names = self._get_names(self.query_name_path) 21 | train, num_train_tracklets, num_train_pids, num_train_imgs = self._process_train(train_names) 22 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = self._process_gallery() 23 | query, num_query_tracklets, num_query_pids, num_query_imgs = self._process_query(query_names, gallery) 24 | 25 | gallery_img = [] 26 | num_gallery_tracklets = 0 27 | for el in gallery: 28 | for fr in el[0]: 29 | gallery_img.append(([fr], el[1], el[2])) 30 | num_gallery_tracklets += 1 31 | 32 | query_img, _, _, _ = self._process_query_image(query_names) 33 | 34 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 35 | total_num = np.sum(num_imgs_per_tracklet) 36 | min_num = np.min(num_imgs_per_tracklet) 37 | max_num = np.max(num_imgs_per_tracklet) 38 | avg_num = np.mean(num_imgs_per_tracklet) 39 | 40 | num_total_pids = num_train_pids + num_query_pids 41 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 42 | 43 | print("=> VeRi loaded") 44 | print("Dataset statistics:") 45 | print(" -----------------------------------------") 46 | print(" subset | # ids | # tracklets | # images") 47 | print(" -----------------------------------------") 48 | print(" train | {:5d} | {:8d} | {:8d}".format(num_train_pids, num_train_tracklets, np.sum(num_train_imgs))) 49 | print(" query | {:5d} | {:8d} | {:8d}".format(num_query_pids, num_query_tracklets, np.sum(num_query_imgs))) 50 | print(" gallery | {:5d} | {:8d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets, np.sum(num_gallery_imgs))) 51 | print(" -----------------------------------------") 52 | print(" total | {:5d} | {:8d} | {:8d}".format(num_total_pids, 
num_total_tracklets, total_num)) 53 | print(" -----------------------------------------") 54 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 55 | print(" -----------------------------------------") 56 | 57 | self.train = train 58 | self.query = query 59 | self.gallery = gallery 60 | self.gallery_img = gallery_img 61 | self.query_img = query_img 62 | 63 | self.num_train_pids = num_train_pids 64 | self.num_query_pids = num_query_pids 65 | self.num_gallery_pids = num_gallery_pids 66 | 67 | def _get_names(self, fpath): 68 | names = [] 69 | with open(fpath, 'r') as f: 70 | for line in f: 71 | new_line = line.rstrip() 72 | names.append(new_line) 73 | return names 74 | 75 | def _process_gallery(self): 76 | gallery = [] 77 | pids = [] 78 | num_images = [] 79 | 80 | with open(self.track_gallery_info_path) as fp: 81 | for line in fp.readlines(): 82 | imgs_names = [osp.join(self.root, f'image_test/{el}') for el in line.split(' ')] 83 | imgs_names = imgs_names[1:len(imgs_names)-1] 84 | pid, camera = map(int, line.split(' ')[0].replace('c', '').split('_')[:2]) 85 | camera -= 1 86 | gallery.append((imgs_names, pid, camera)) 87 | num_images.append(len(imgs_names)) 88 | pids.append(pid) 89 | 90 | return gallery, len(gallery), len(set(pids)), num_images 91 | 92 | def _process_train(self, train_names): 93 | train = [] 94 | pids = [] 95 | num_images = [] 96 | train_names = sorted(train_names) 97 | hs = np.asarray([hash(el[:9]) for el in train_names]) 98 | displaces = np.nonzero(hs[1:] - hs[:-1])[0] + 1 99 | displaces = np.concatenate([[0], displaces, [len(hs)]]) 100 | for idx in range(len(displaces) - 1): 101 | names = train_names[displaces[idx]: displaces[idx + 1]] 102 | imgs_names = [osp.join(self.root, f'image_train/{el}') for el in names] 103 | pid, camera = map(int, names[0].replace('c', '').split('_')[:2]) 104 | camera -= 1 105 | train.append((imgs_names, pid, camera)) 106 | num_images.append(len(imgs_names)) 107 | pids.append(pid) 108 | 109 | # RE-LABEl 110 | pid_map = {pid: idx for idx, pid in enumerate(np.unique(pids))} 111 | for i in range(len(train)): 112 | pid = pids[i] 113 | train[i] = (train[i][0], pid_map[pid], train[i][2]) 114 | pids[i] = pid_map[pid] 115 | 116 | return train, len(train), len(set(pids)), num_images 117 | 118 | def _process_query(self, query_names, gallery): 119 | queries = [] 120 | pids = [] 121 | num_images = [] 122 | 123 | for qn in query_names: 124 | pid, camera = map(int, qn.replace('c', '').split('_')[:2]) 125 | camera -= 1 126 | # look into gallery 127 | for el in gallery: 128 | if el[1] == pid and el[2] == camera: 129 | queries.append((deepcopy(el[0]), el[1], el[2])) 130 | num_images.append(len(el[0])) 131 | pids.append(el[1]) 132 | break 133 | 134 | return queries, len(queries), len(set(pids)), num_images 135 | 136 | def _process_query_image(self, query_names): 137 | queries = [] 138 | pids = [] 139 | num_images = [] 140 | 141 | for qn in query_names: 142 | imgs_names = [osp.join(self.root, f'image_query/{qn}')] 143 | pid, camera = map(int, qn.replace('c', '').split('_')[:2]) 144 | camera -= 1 145 | queries.append((imgs_names, pid, camera)) 146 | num_images.append(len(imgs_names)) 147 | 148 | pids.append(pid) 149 | 150 | return queries, len(queries), len(set(pids)), num_images 151 | -------------------------------------------------------------------------------- /images/gradcam.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/images/gradcam.png -------------------------------------------------------------------------------- /images/mars_all_withstudent.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/images/mars_all_withstudent.pdf -------------------------------------------------------------------------------- /images/mars_all_withstudent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/images/mars_all_withstudent.png -------------------------------------------------------------------------------- /images/mvd_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/images/mvd_framework.png -------------------------------------------------------------------------------- /model/cbam/__pycache__/bam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/model/cbam/__pycache__/bam.cpython-36.pyc -------------------------------------------------------------------------------- /model/cbam/__pycache__/resnet_bam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/VKD/a56b12e4f687c9c05f341381632d3eb3dd122dfb/model/cbam/__pycache__/resnet_bam.cpython-36.pyc -------------------------------------------------------------------------------- /model/cbam/bam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, x): 8 | return x.view(x.size(0), -1) 9 | 10 | 11 | class ChannelGate(nn.Module): 12 | def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): 13 | super(ChannelGate, self).__init__() 14 | self.gate_c = nn.Sequential() 15 | self.gate_c.add_module('flatten', Flatten()) 16 | gate_channels = [gate_channel] 17 | gate_channels += [gate_channel // reduction_ratio] * num_layers 18 | gate_channels += [gate_channel] 19 | for i in range(len(gate_channels) - 2): 20 | self.gate_c.add_module('gate_c_fc_%d' % i, 21 | nn.Linear(gate_channels[i], gate_channels[i + 1])) 22 | self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1])) 23 | self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU()) 24 | self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1])) 25 | 26 | def forward(self, in_tensor): 27 | avg_pool = F.avg_pool2d(in_tensor, (in_tensor.size(2), in_tensor.size(3)), 28 | stride=(in_tensor.size(2), in_tensor.size(3))) 29 | return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor) 30 | 31 | 32 | class SpatialGate(nn.Module): 33 | def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4): 34 | super(SpatialGate, self).__init__() 35 | self.gate_s = nn.Sequential() 36 | self.gate_s.add_module('gate_s_conv_reduce0', 37 | nn.Conv2d(gate_channel, gate_channel // reduction_ratio, 38 | kernel_size=1)) 39 | self.gate_s.add_module('gate_s_bn_reduce0', 
nn.BatchNorm2d(gate_channel // reduction_ratio)) 40 | self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU()) 41 | for i in range(dilation_conv_num): 42 | self.gate_s.add_module('gate_s_conv_di_%d' % i, 43 | nn.Conv2d(gate_channel // reduction_ratio, 44 | gate_channel // reduction_ratio, kernel_size=3, 45 | padding=dilation_val, dilation=dilation_val)) 46 | self.gate_s.add_module('gate_s_bn_di_%d' % i, 47 | nn.BatchNorm2d(gate_channel // reduction_ratio)) 48 | self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU()) 49 | self.gate_s.add_module('gate_s_conv_final', 50 | nn.Conv2d(gate_channel // reduction_ratio, 1, kernel_size=1)) 51 | 52 | def forward(self, in_tensor): 53 | return self.gate_s(in_tensor).expand_as(in_tensor) 54 | 55 | 56 | class BAM(nn.Module): 57 | def __init__(self, gate_channel): 58 | super(BAM, self).__init__() 59 | self.channel_att = ChannelGate(gate_channel) 60 | self.spatial_att = SpatialGate(gate_channel) 61 | 62 | def forward(self, in_tensor): 63 | att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor)) 64 | return att * in_tensor 65 | -------------------------------------------------------------------------------- /model/cbam/resnet_bam.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.nn import init 5 | 6 | from model.cbam.bam import BAM 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def 
forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, layers, network_type, num_classes, zero_init_residual=False): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.network_type = network_type 102 | # different model config between ImageNet and CIFAR 103 | if network_type == "ImageNet": 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.avgpool = nn.AvgPool2d(7) 107 | else: 108 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 109 | 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | bam1 = BAM(64 * block.expansion) 114 | bam2 = BAM(128 * block.expansion) 115 | bam3 = BAM(256 * block.expansion) 116 | 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer1.add_module('bam', bam1) 119 | 120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 121 | self.layer2.add_module('bam', bam2) 122 | 123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 124 | self.layer3.add_module('bam', bam3) 125 | 126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 127 | 128 | self.fc = nn.Linear(512 * block.expansion, num_classes) 129 | 130 | init.kaiming_normal_(self.fc.weight) 131 | for key in self.state_dict(): 132 | if key.split('.')[-1] == "weight": 133 | if "conv" in key: 134 | init.kaiming_normal_(self.state_dict()[key], mode='fan_out') 135 | if "bn" in key: 136 | if "SpatialGate" in key: 137 | self.state_dict()[key][...] = 0 138 | else: 139 | self.state_dict()[key][...] = 1 140 | elif key.split(".")[-1] == 'bias': 141 | self.state_dict()[key][...] 
= 0 142 | 143 | if zero_init_residual: 144 | for m in self.modules(): 145 | if isinstance(m, Bottleneck): 146 | nn.init.constant_(m.bn3.weight, 0) 147 | elif isinstance(m, BasicBlock): 148 | nn.init.constant_(m.bn2.weight, 0) 149 | 150 | def _make_layer(self, block, planes, blocks, stride=1): 151 | downsample = None 152 | if stride != 1 or self.inplanes != planes * block.expansion: 153 | downsample = nn.Sequential( 154 | nn.Conv2d(self.inplanes, planes * block.expansion, 155 | kernel_size=1, stride=stride, bias=False), 156 | nn.BatchNorm2d(planes * block.expansion), 157 | ) 158 | 159 | layers = [] 160 | layers.append(block(self.inplanes, planes, stride, downsample)) 161 | self.inplanes = planes * block.expansion 162 | for i in range(1, blocks): 163 | layers.append(block(self.inplanes, planes)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, x): 168 | x = self.conv1(x) 169 | x = self.bn1(x) 170 | x = self.relu(x) 171 | if self.network_type == "ImageNet": 172 | x = self.maxpool(x) 173 | 174 | x = self.layer1(x) 175 | x = self.layer2(x) 176 | x = self.layer3(x) 177 | x = self.layer4(x) 178 | 179 | if self.network_type == "ImageNet": 180 | x = self.avgpool(x) 181 | else: 182 | x = F.avg_pool2d(x, 4) 183 | x = x.view(x.size(0), -1) 184 | x = self.fc(x) 185 | return x 186 | 187 | 188 | def ResidualNet(network_type, depth, num_classes, zero_init_residual): 189 | assert network_type in ["ImageNet", "CIFAR10", 190 | "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100" 191 | assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101' 192 | 193 | if depth == 18: 194 | model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, zero_init_residual) 195 | 196 | elif depth == 34: 197 | model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, zero_init_residual) 198 | 199 | elif depth == 50: 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, zero_init_residual) 201 | 202 | elif depth == 101: 203 | model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, zero_init_residual) 204 | else: 205 | raise ValueError() 206 | 207 | return model 208 | 209 | 210 | def resnet50_bam(pretrained=False, **kwargs): 211 | """Constructs a ResNet-50 model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResidualNet('ImageNet', 50, 1000, **kwargs) 216 | if pretrained: 217 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet50']) 218 | now_state_dict = model.state_dict() 219 | now_state_dict.update(pretrained_state_dict) 220 | model.load_state_dict(now_state_dict) 221 | return model 222 | 223 | 224 | def resnet101_bam(pretrained=False, **kwargs): 225 | """Constructs a ResNet-50 model. 
226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResidualNet('ImageNet', 101, 1000, **kwargs) 230 | if pretrained: 231 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet101']) 232 | now_state_dict = model.state_dict() 233 | now_state_dict.update(pretrained_state_dict) 234 | model.load_state_dict(now_state_dict) 235 | return model 236 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def l2_norm(x: torch.Tensor): 7 | return x / torch.norm(x, dim=-1, keepdim=True) 8 | 9 | 10 | class MatrixPairwiseDistances(nn.Module): 11 | 12 | def __init__(self): 13 | super(MatrixPairwiseDistances, self).__init__() 14 | 15 | def __call__(self, x: torch.Tensor, y: torch.Tensor = None): 16 | if y is not None: # exact form of squared distances 17 | differences = x.unsqueeze(1) - y.unsqueeze(0) 18 | else: 19 | differences = x.unsqueeze(1) - x.unsqueeze(0) 20 | distances = torch.sum(differences * differences, -1) 21 | return distances 22 | 23 | 24 | class SmoothedCCE(nn.Module): 25 | def __init__(self, num_classes: int, eps: float = 0.1, reduction: str='sum'): 26 | super(SmoothedCCE, self).__init__() 27 | self.reduction = reduction 28 | assert reduction in ['sum', 'mean'] 29 | self.eps = eps 30 | self.num_classes = num_classes 31 | 32 | self.logsoftmax = nn.LogSoftmax(dim=1) 33 | self.factor_0 = self.eps / self.num_classes 34 | self.factor_1 = 1 - ((self.num_classes - 1) / self.num_classes) * self.eps 35 | 36 | def labels_to_one_hot(self, labels): 37 | onehot = torch.ones(len(labels), self.num_classes).to(labels.device) * self.factor_0 38 | onehot[torch.arange(0, len(labels), dtype=torch.long), labels.long()] = self.factor_1 39 | return onehot 40 | 41 | def forward(self, feats, target): 42 | """ 43 | target are long in [0, num_classes)! 
44 | """ 45 | one_hots = self.labels_to_one_hot(target) 46 | if self.reduction == 'sum': 47 | loss = torch.sum(-torch.sum(one_hots * self.logsoftmax(feats), -1)) 48 | else: 49 | loss = torch.mean(-torch.sum(one_hots * self.logsoftmax(feats), -1)) 50 | 51 | return loss 52 | 53 | def __call__(self, *args, **kwargs): 54 | return super(SmoothedCCE, self).__call__(*args, **kwargs) 55 | 56 | 57 | class KDLoss(nn.Module): 58 | 59 | def __init__(self, temp: float, reduction: str): 60 | super(KDLoss, self).__init__() 61 | 62 | self.temp = temp 63 | self.reduction = reduction 64 | self.kl_loss = nn.KLDivLoss(reduction=reduction) 65 | 66 | def forward(self, teacher_logits: torch.Tensor, student_logits: torch.Tensor): 67 | 68 | student_softmax = F.log_softmax(student_logits / self.temp, dim=-1) 69 | teacher_softmax = F.softmax(teacher_logits / self.temp, dim=-1) 70 | 71 | kl = nn.KLDivLoss(reduction='none')(student_softmax, teacher_softmax) 72 | kl = kl.sum() if self.reduction == 'sum' else kl.sum(1).mean() 73 | kl = kl * (self.temp ** 2) 74 | 75 | return kl 76 | 77 | def __call__(self, *args, **kwargs): 78 | return super(KDLoss, self).__call__(*args, **kwargs) 79 | 80 | 81 | class LogitsMatching(nn.Module): 82 | 83 | def __init__(self, reduction: str): 84 | super(LogitsMatching, self).__init__() 85 | self.mse_loss = nn.MSELoss(reduction=reduction) 86 | 87 | def forward(self, teacher_logits: torch.Tensor, student_logits: torch.Tensor): 88 | return self.mse_loss(student_logits, teacher_logits) 89 | 90 | def __call__(self, *args, **kwargs): 91 | return super(LogitsMatching, self).__call__(*args, **kwargs) 92 | 93 | 94 | class SimilarityDistillationLoss(nn.Module): 95 | 96 | def __init__(self, metric: str): 97 | assert metric in ['l2', 'l1', 'huber'] 98 | super(SimilarityDistillationLoss, self).__init__() 99 | self.distances = MatrixPairwiseDistances() 100 | self.metric = metric 101 | 102 | def forward(self, teacher_embs: torch.Tensor, student_embs: torch.Tensor): 103 | teacher_distances = self.distances(teacher_embs) 104 | student_distances = self.distances(student_embs) 105 | 106 | if self.metric == 'l2': 107 | return 0.5 * nn.MSELoss(reduction='mean')(student_distances, teacher_distances) 108 | if self.metric == 'l1': 109 | return 0.5 * nn.L1Loss(reduction='mean')(student_distances, teacher_distances) 110 | if self.metric == 'huber': 111 | return 0.5 * nn.SmoothL1Loss(reduction='mean')(student_distances, teacher_distances) 112 | raise ValueError() 113 | 114 | def __call__(self, *args, **kwargs): 115 | return super(SimilarityDistillationLoss, self).__call__(*args, **kwargs) 116 | 117 | 118 | class OnlineTripletLoss(nn.Module): 119 | 120 | def __init__(self, margin='soft', batch_hard=True, reduction='mean'): 121 | super(OnlineTripletLoss, self).__init__() 122 | self.batch_hard = batch_hard 123 | self.reduction = reduction 124 | if isinstance(margin, float) or margin == 'soft': 125 | self.margin = margin 126 | else: 127 | raise NotImplementedError( 128 | 'The margin {} is not recognized in TripletLoss()'.format(margin)) 129 | 130 | def forward(self, feat, id=None, pos_mask=None, neg_mask=None, mode='id', dis_func='eu', 131 | n_dis=0): 132 | 133 | if dis_func == 'cdist': 134 | feat = feat / feat.norm(p=2, dim=1, keepdim=True) 135 | dist = self.cdist(feat, feat) 136 | elif dis_func == 'eu': 137 | dist = self.cdist(feat, feat) 138 | 139 | if mode == 'id': 140 | if id is None: 141 | raise RuntimeError('foward is in id mode, please input id!') 142 | else: 143 | identity_mask = torch.eye(feat.size(0)).byte() 
144 | identity_mask = identity_mask.cuda() if id.is_cuda else identity_mask 145 | same_id_mask = torch.eq(id.unsqueeze(1), id.unsqueeze(0)) 146 | negative_mask = same_id_mask ^ 1 147 | positive_mask = same_id_mask ^ identity_mask.bool() 148 | elif mode == 'mask': 149 | if pos_mask is None or neg_mask is None: 150 | raise RuntimeError('foward is in mask mode, please input pos_mask & neg_mask!') 151 | else: 152 | positive_mask = pos_mask 153 | same_id_mask = neg_mask ^ 1 154 | negative_mask = neg_mask 155 | else: 156 | raise ValueError('unrecognized mode') 157 | 158 | if self.batch_hard: 159 | if n_dis != 0: 160 | img_dist = dist[:-n_dis, :-n_dis] 161 | max_positive = (img_dist * positive_mask[:-n_dis, :-n_dis].float()).max(1)[0] 162 | min_negative = (img_dist + 1e5 * same_id_mask[:-n_dis, :-n_dis].float()).min(1)[0] 163 | dis_min_negative = dist[:-n_dis, -n_dis:].min(1)[0] 164 | z_origin = max_positive - min_negative 165 | # z_dis = max_positive - dis_min_negative 166 | else: 167 | max_positive = (dist * positive_mask.float()).max(1)[0] 168 | min_negative = (dist + 1e5 * same_id_mask.float()).min(1)[0] 169 | z = max_positive - min_negative 170 | else: 171 | pos = positive_mask.topk(k=1, dim=1)[1].view(-1, 1) 172 | positive = torch.gather(dist, dim=1, index=pos) 173 | pos = negative_mask.topk(k=1, dim=1)[1].view(-1, 1) 174 | negative = torch.gather(dist, dim=1, index=pos) 175 | z = positive - negative 176 | 177 | if isinstance(self.margin, float): 178 | b_loss = torch.clamp(z + self.margin, min=0) 179 | elif self.margin == 'soft': 180 | if n_dis != 0: 181 | b_loss = torch.log(1 + torch.exp( 182 | z_origin)) + -0.5 * dis_min_negative # + torch.log(1+torch.exp(z_dis)) 183 | else: 184 | b_loss = torch.log(1 + torch.exp(z)) 185 | else: 186 | raise NotImplementedError("How do you even get here!") 187 | 188 | if self.reduction == 'mean': 189 | return b_loss.mean() 190 | 191 | return b_loss.sum() 192 | 193 | def cdist(self, a, b): 194 | ''' 195 | Returns euclidean distance between a and b 196 | 197 | Args: 198 | a (2D Tensor): A batch of vectors shaped (B1, D) 199 | b (2D Tensor): A batch of vectors shaped (B2, D) 200 | Returns: 201 | A matrix of all pairwise distance between all vectors in a and b, 202 | will be shape of (B1, B2) 203 | ''' 204 | diff = a.unsqueeze(1) - b.unsqueeze(0) 205 | return ((diff ** 2).sum(2) + 1e-12).sqrt() 206 | 207 | def __call__(self, *args, **kwargs): 208 | return super(OnlineTripletLoss, self).__call__(*args, **kwargs) 209 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models.densenet import densenet121 4 | 5 | from torchvision.models.resnet import resnet18 6 | from torchvision.models.resnet import resnet34 7 | from torchvision.models.resnet import resnet50 8 | from torchvision.models.resnet import resnet101 9 | 10 | from model.cbam.resnet_bam import resnet50_bam 11 | from model.cbam.resnet_bam import resnet101_bam 12 | 13 | from argparse import Namespace 14 | from itertools import chain 15 | 16 | import inspect 17 | 18 | 19 | class Backbone(nn.Module): 20 | 21 | RESNET_18 = 'resnet18' 22 | RESNET_34 = 'resnet34' 23 | RESNET_50 = 'resnet50' 24 | RESNET_101 = 'resnet101' 25 | RESNET_50_BAM = 'resnet50bam' 26 | RESNET_101_BAM = 'resnet101bam' 27 | DENSENET = 'densenet121' 28 | MOBILENET = 'mobilenet' 29 | 30 | def __init__(self, btype: str, pretrained: bool = True, last_stride: 
int = 1): 31 | 32 | super(Backbone, self).__init__() 33 | 34 | assert btype in [Backbone.RESNET_18, Backbone.RESNET_34, Backbone.RESNET_50, 35 | Backbone.RESNET_50_BAM, Backbone.RESNET_101, Backbone.RESNET_101_BAM, 36 | Backbone.DENSENET, Backbone.MOBILENET] 37 | 38 | if btype in [Backbone.RESNET_18, Backbone.RESNET_34, Backbone.RESNET_50, 39 | Backbone.RESNET_101, Backbone.RESNET_50_BAM, Backbone.RESNET_101_BAM]: 40 | self.features_layers, self.output_shape = \ 41 | self.get_resnet_backbone_layers(btype, pretrained, last_stride) 42 | 43 | if btype in [Backbone.DENSENET]: 44 | self.features_layers, self.output_shape = self.get_denset_backbone_layers(pretrained) 45 | 46 | if btype in [Backbone.MOBILENET]: 47 | self.features_layers, self.output_shape = self.get_mobilenet_backbone_layers(pretrained, last_stride) 48 | 49 | self.btype = btype 50 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 51 | 52 | @staticmethod 53 | def get_net(btype: str, pretrained: bool): 54 | if btype == Backbone.RESNET_18: 55 | return resnet18(pretrained=pretrained, zero_init_residual=True), 512 56 | elif btype == Backbone.RESNET_34: 57 | return resnet34(pretrained=pretrained, zero_init_residual=True), 512 58 | elif btype == Backbone.RESNET_50: 59 | return resnet50(pretrained=pretrained, zero_init_residual=True), 2048 60 | elif btype == Backbone.RESNET_101: 61 | return resnet101(pretrained=pretrained, zero_init_residual=True), 2048 62 | elif btype == Backbone.RESNET_50_BAM: 63 | return resnet50_bam(pretrained=pretrained, zero_init_residual=True), 2048 64 | elif btype == Backbone.RESNET_101_BAM: 65 | return resnet101_bam(pretrained=pretrained, zero_init_residual=True), 2048 66 | elif btype == Backbone.DENSENET: 67 | return densenet121(pretrained=pretrained), 1024 68 | elif btype == Backbone.MOBILENET: 69 | model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', 70 | pretrained=pretrained), 1280 71 | return model 72 | raise ValueError() 73 | 74 | @staticmethod 75 | def get_mobilenet_backbone_layers(pretrained: bool, last_stride: int): 76 | mobilenet, num_out_channels = Backbone.get_net(Backbone.MOBILENET, pretrained=pretrained) 77 | mobilenet_features = mobilenet.features 78 | mobilenet_features[14].conv[1][0].stride = (last_stride, last_stride) 79 | mobilenet_features[18][2] = nn.Sequential() 80 | return nn.ModuleList(mobilenet_features.children()), num_out_channels 81 | 82 | @staticmethod 83 | def get_resnet_backbone_layers(btype: str, pretrained: bool, last_stride: int): 84 | 85 | assert last_stride in [1, 2] 86 | 87 | resnet, output_shape = Backbone.get_net(btype, pretrained) 88 | 89 | if btype in [Backbone.RESNET_18, Backbone.RESNET_34]: 90 | resnet.layer4[0].conv1.stride = (last_stride, last_stride) 91 | if btype in [Backbone.RESNET_50, Backbone.RESNET_101, 92 | Backbone.RESNET_50_BAM, Backbone.RESNET_101_BAM]: 93 | resnet.layer4[0].conv2.stride = (last_stride, last_stride) 94 | 95 | resnet.layer4[0].downsample[0].stride = (last_stride, last_stride) 96 | 97 | resnet.layer4[-1].relu = nn.Sequential() # replace relu with empty sequential 98 | 99 | resnet_layers = [ 100 | nn.Sequential( 101 | resnet.conv1, resnet.bn1, 102 | resnet.relu, 103 | resnet.maxpool), 104 | nn.Sequential(resnet.layer1), 105 | nn.Sequential(resnet.layer2), 106 | nn.Sequential(resnet.layer3), 107 | nn.Sequential(resnet.layer4), 108 | ] 109 | 110 | return nn.ModuleList(resnet_layers), output_shape 111 | 112 | @staticmethod 113 | def get_denset_backbone_layers(pretrained: bool): 114 | dnet = densenet121(pretrained=pretrained) 115 | 
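# torchvision's densenet121 exposes two children, (features, classifier);
# only the convolutional 'features' trunk is kept as the backbone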
original_model = list(dnet.children())[0] 116 | return nn.ModuleList(original_model.children()), 1024 117 | 118 | def get_output_shape(self): 119 | return self.output_shape 120 | 121 | def backbone_features(self, x: torch.Tensor): 122 | for m in self.features_layers: 123 | x = m(x) 124 | return x 125 | 126 | def forward(self, x: torch.Tensor): 127 | b, v, c, h, w = x.shape 128 | x = x.reshape(b * v, c, h, w) 129 | x = self.backbone_features(x) 130 | x = self.avgpool(x) 131 | return x 132 | 133 | def __call__(self, *args, **kwargs): 134 | return super(Backbone, self).__call__(*args, **kwargs) 135 | 136 | 137 | BACKBONES = [Backbone.RESNET_18, Backbone.RESNET_34, Backbone.RESNET_50, 138 | Backbone.RESNET_101, Backbone.RESNET_50_BAM, Backbone.RESNET_101_BAM, 139 | Backbone.DENSENET, Backbone.MOBILENET] 140 | 141 | 142 | class ClassificationLayer(nn.Module): 143 | 144 | def __init__(self, num_classes, feat_in: int): 145 | super(ClassificationLayer, self).__init__() 146 | 147 | self.feat_in = feat_in 148 | self.num_classes = num_classes 149 | 150 | self.bottleneck = nn.BatchNorm1d(self.feat_in) 151 | self.bottleneck.bias.requires_grad_(False) # no shift 152 | self.classifier = nn.Linear(self.feat_in, self.num_classes, bias=False) 153 | 154 | self.bottleneck.apply(self.weights_init_kaiming) 155 | self.classifier.apply(self.weights_init_classifier) 156 | 157 | def forward(self, feats): 158 | return self.classifier(self.bottleneck(feats)) 159 | 160 | @staticmethod 161 | def weights_init_kaiming(m): 162 | classname = m.__class__.__name__ 163 | if classname.find('Linear') != -1: 164 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 165 | nn.init.constant_(m.bias, 0.0) 166 | elif classname.find('Conv') != -1: 167 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 168 | if m.bias is not None: 169 | nn.init.constant_(m.bias, 0.0) 170 | elif classname.find('BatchNorm') != -1: 171 | if m.affine: 172 | nn.init.constant_(m.weight, 1.0) 173 | nn.init.constant_(m.bias, 0.0) 174 | 175 | @staticmethod 176 | def weights_init_classifier(m): 177 | classname = m.__class__.__name__ 178 | if classname.find('Linear') != -1: 179 | nn.init.normal_(m.weight, std=0.001) 180 | if m.bias: 181 | nn.init.constant_(m.bias, 0.0) 182 | 183 | def __call__(self, *args, **kwargs): 184 | return super(ClassificationLayer, self).__call__(*args, **kwargs) 185 | 186 | 187 | class TriNet(nn.Module): 188 | 189 | def __init__(self, backbone_type: str, num_classes: int, pretrained: bool): 190 | 191 | super(TriNet, self).__init__() 192 | 193 | _, _, _, values = inspect.getargvalues(inspect.currentframe()) 194 | self.hparams = {key: values[key] for key in values.keys() 195 | if key not in ('self', '__class__')} 196 | 197 | self.backbone = Backbone(btype=backbone_type, pretrained=pretrained) 198 | 199 | self.aggregator = MeanAggregator() 200 | 201 | self.classifier = ClassificationLayer(num_classes=num_classes, 202 | feat_in=self.backbone.get_output_shape()) 203 | 204 | def get_hparams(self): 205 | return self.hparams 206 | 207 | def backbone_features(self, x: torch.Tensor): 208 | b, v, c, h, w = x.shape 209 | x = self.backbone(x) 210 | out_shape = [b, v, self.backbone.output_shape] 211 | x = x.reshape(*out_shape) 212 | return x 213 | 214 | def forward(self, x: torch.Tensor, return_logits: bool = False): 215 | 216 | if len(x.shape) == 4: 217 | x = x.unsqueeze(1) 218 | 219 | b, v, c, h, w = x.shape 220 | x = self.backbone(x) # out before BN 221 | 222 | x_agg = self.aggregator(x.view(b, v, self.backbone.output_shape)) 223 | 224 | if 
return_logits: 225 | # Note: this applies mean AFTER classifier, not before. 226 | x_class = self.classifier(x.view(b * v, self.backbone.output_shape)) 227 | x_class = x_class.view(b, v, self.classifier.num_classes) 228 | x_class = torch.mean(x_class, dim=1) 229 | return x_agg, x_class 230 | 231 | return x_agg 232 | 233 | def teacher_mode(self): 234 | for m in self.modules(): 235 | if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)): 236 | m.track_running_stats = False 237 | self.train() 238 | 239 | def student_mode(self): 240 | for m in self.modules(): 241 | if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)): 242 | m.track_running_stats = True 243 | self.train() 244 | 245 | def reinit_layers(self, reinitl4: bool, reinitl3: bool): 246 | 247 | if reinitl4 or reinitl3: 248 | self.classifier.load_state_dict( 249 | ClassificationLayer(self.classifier.num_classes, 250 | self.classifier.feat_in).state_dict()) 251 | 252 | r, _ = Backbone.get_net(self.backbone.btype, True) 253 | 254 | if reinitl4: 255 | if self.backbone.btype == Backbone.DENSENET: 256 | block = self.backbone.features_layers[-2] 257 | block.load_state_dict(r._modules['features'].denseblock4.state_dict()) 258 | elif self.backbone.btype == Backbone.MOBILENET: 259 | pass 260 | else: 261 | block = self.backbone.features_layers[-1][0] 262 | block.load_state_dict(r.layer4.state_dict()) 263 | 264 | if reinitl3: 265 | raise ValueError() 266 | 267 | def block_parameters(self, reinitl4: bool, reinitl3: bool): 268 | first_idx = [] 269 | if reinitl4: first_idx.append(4) 270 | if reinitl3: first_idx.append(3) 271 | base_params = [ list(f.parameters()) for i, f in enumerate(self.backbone.features_layers) 272 | if i not in first_idx ] 273 | upper_params = [ list(f.parameters()) for i, f in enumerate(self.backbone.features_layers) 274 | if i in first_idx ] 275 | upper_params.append(list(self.classifier.parameters())) 276 | base_params = chain(*base_params) 277 | upper_params = chain(*upper_params) 278 | return base_params, upper_params 279 | 280 | def __call__(self, *args, **kwargs): 281 | return super(TriNet, self).__call__(*args, **kwargs) 282 | 283 | 284 | def get_model(args: Namespace, num_pids: int): 285 | return TriNet(backbone_type=args.backbone, pretrained=args.pretrained, 286 | num_classes=num_pids) 287 | 288 | 289 | class MeanAggregator(nn.Module): 290 | 291 | def __init__(self): 292 | super(MeanAggregator, self).__init__() 293 | 294 | def forward(self, x: torch.Tensor): 295 | return x.mean(dim=1) 296 | 297 | def __call__(self, *args, **kwargs): 298 | return super(MeanAggregator, self).__call__(*args, **kwargs) 299 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.14 2 | numpy==1.17.3 3 | opencv-python==4.0.0.21 4 | Pillow==6.0.0 5 | scikit-learn==0.18.1 6 | scipy==1.2.1 7 | tensorboard==2.0.2 8 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from data.datasets import get_dataloaders 10 | from data.eval_metrics import evaluate 11 | from model.net import get_model 12 | from utils.conf import Conf 13 | from utils.saver import Saver 14 | from data.datasets import 
DataConf, DATA_CONFS 15 | 16 | 17 | class Evaluator: 18 | 19 | def __init__(self, model: torch.nn.Module, query_loader: DataLoader, 20 | gallery_loader: DataLoader, queryimg_loader: DataLoader, 21 | galleryimg_loader: DataLoader, data_conf: DataConf, 22 | device: torch.device): 23 | 24 | self.perform_x2i = data_conf.perform_x2i 25 | self.perform_x2v = data_conf.perform_x2v 26 | model.eval() 27 | 28 | self.gallery_loader = gallery_loader 29 | self.query_loader = query_loader 30 | self.queryimg_loader = queryimg_loader 31 | self.galleryimg_loader = galleryimg_loader 32 | 33 | # ----------- QUERY 34 | vid_qf, self.vid_q_pids, self.vid_q_camids = self.extract_features(model, query_loader, 35 | device) 36 | img_qf, self.img_q_pids, self.img_q_camids = self.extract_features(model, queryimg_loader, 37 | device) 38 | # ----------- GALLERY 39 | if self.perform_x2v: 40 | vid_gf, self.vid_g_pids, self.vid_g_camids = self.extract_features(model, 41 | gallery_loader, 42 | device) 43 | if self.perform_x2i: 44 | img_gf, self.img_g_pids, self.img_g_camids = self.extract_features(model, 45 | galleryimg_loader, 46 | device) 47 | 48 | if data_conf.augment_gallery: 49 | # the gallery must also contain the query set; otherwise 140 MARS queries would have no ground truth. 50 | if self.perform_x2v: 51 | vid_gf = torch.cat((vid_qf, vid_gf), 0) 52 | self.vid_g_pids = np.append(self.vid_q_pids, self.vid_g_pids) 53 | self.vid_g_camids = np.append(self.vid_q_camids, self.vid_g_camids) 54 | if self.perform_x2i: 55 | img_gf = torch.cat((img_qf, img_gf), 0) 56 | self.img_g_pids = np.append(self.img_q_pids, self.img_g_pids) 57 | self.img_g_camids = np.append(self.img_q_camids, self.img_g_camids) 58 | 59 | if self.perform_x2v: 60 | self.v2v_distmat = self.compute_distance_matrix(vid_qf, vid_gf, metric='cosine').numpy() 61 | self.i2v_distmat = self.compute_distance_matrix(img_qf, vid_gf, metric='cosine').numpy() 62 | if self.perform_x2i: 63 | self.v2i_distmat = self.compute_distance_matrix(vid_qf, img_gf, metric='cosine').numpy() 64 | self.i2i_distmat = self.compute_distance_matrix(img_qf, img_gf, metric='cosine').numpy() 65 | 66 | @staticmethod 67 | def compute_distance_matrix(x: torch.Tensor, y: torch.Tensor, metric='cosine'): 68 | if metric == 'cosine': 69 | x = x / torch.norm(x, dim=-1, keepdim=True) 70 | y = y / torch.norm(y, dim=-1, keepdim=True) 71 | 72 | return 1 - torch.mm(x, y.T) 73 | 74 | def evaluate_v2v(self, verbose: bool = True): 75 | cmc, mAP, wrong_matches = evaluate(self.v2v_distmat, self.vid_q_pids, 76 | self.vid_g_pids, self.vid_q_camids, 77 | self.vid_g_camids, True) 78 | 79 | if verbose: 80 | print(f'V2V') 81 | print(f'top1:{cmc[0]:.2%} top5:{cmc[4]:.2%} top10:{cmc[9]:.2%} mAP:{mAP:.2%}') 82 | 83 | return cmc, mAP 84 | 85 | def evaluate_i2v(self, verbose: bool = True): 86 | 87 | cmc, mAP, wrong_matches = evaluate(self.i2v_distmat, self.img_q_pids, 88 | self.vid_g_pids, self.img_q_camids, 89 | self.vid_g_camids, True) 90 | if verbose: 91 | print(f'I2V') 92 | print(f'top1:{cmc[0]:.2%} top5:{cmc[4]:.2%} top10:{cmc[9]:.2%} mAP:{mAP:.2%}') 93 | 94 | return cmc, mAP 95 | 96 | def evaluate_v2i(self, verbose: bool = True): 97 | 98 | cmc, mAP, wrong_matches = evaluate(self.v2i_distmat, self.vid_q_pids, 99 | self.img_g_pids, self.vid_q_camids, 100 | self.img_g_camids, True) 101 | 102 | if verbose: 103 | print(f'V2I') 104 | print(f'top1:{cmc[0]:.2%} top5:{cmc[4]:.2%} top10:{cmc[9]:.2%} mAP:{mAP:.2%}') 105 | 106 | return cmc, mAP 107 | 108 | def evaluate_i2i(self, verbose: bool = True): 109 | 110 | cmc, mAP, wrong_matches =
evaluate(self.i2i_distmat, self.img_q_pids, 111 | self.img_g_pids, self.img_q_camids, 112 | self.img_g_camids, True) 113 | if verbose: 114 | print(f'I2I') 115 | print(f'top1:{cmc[0]:.2%} top5:{cmc[4]:.2%} top10:{cmc[9]:.2%} mAP:{mAP:.2%}') 116 | 117 | return cmc, mAP 118 | 119 | @torch.no_grad() 120 | def extract_features(self, model: torch.nn.Module, loader: DataLoader, 121 | device: torch.device): 122 | """ 123 | Extract features for the entire dataloader. It returns also pids and cams. 124 | """ 125 | features, pids, cams = [], [], [] 126 | 127 | for vids, pidids, camids in loader: 128 | vids = vids.to(device) 129 | 130 | feat = model(vids) 131 | feat = feat.data 132 | 133 | features.append(feat) 134 | pids.extend(pidids) 135 | cams.extend(camids) 136 | 137 | features = torch.cat(features, 0).to('cpu') 138 | pids = np.asarray(pids) 139 | cams = np.asarray(cams) 140 | return features, pids, cams 141 | 142 | @staticmethod 143 | def tb_cmc(saver: Saver, cmc_scores, it, method): 144 | for cmc_v in [0, 4, 9]: 145 | saver.dump_metric_tb(cmc_scores[cmc_v], it, 146 | f'{method}', f'cmc{cmc_v + 1}') 147 | 148 | def eval(self, saver: Saver, iteration: int, verbose: bool, do_tb: bool = True): 149 | 150 | if self.perform_x2v: 151 | cmc_scores_i2v, mAP_i2v = self.evaluate_i2v(verbose=verbose) 152 | if do_tb: 153 | saver.dump_metric_tb(mAP_i2v, iteration, 'i2v', f'mAP') 154 | self.tb_cmc(saver, cmc_scores_i2v, iteration, 'i2v') 155 | 156 | cmc_scores_v2v, mAP_v2v = self.evaluate_v2v(verbose=verbose) 157 | if do_tb: 158 | saver.dump_metric_tb(mAP_v2v, iteration, 'v2v', f'mAP') 159 | self.tb_cmc(saver, cmc_scores_v2v, iteration, 'v2v') 160 | 161 | if self.perform_x2i: 162 | cmc_scores_i2i, mAP_i2i = self.evaluate_i2i(verbose=verbose) 163 | if do_tb: 164 | saver.dump_metric_tb(mAP_i2i, iteration, 'i2i', f'mAP') 165 | self.tb_cmc(saver, cmc_scores_i2i, iteration, 'i2i') 166 | 167 | cmc_scores_v2i, mAP_v2i = self.evaluate_v2i(verbose=verbose) 168 | if do_tb: 169 | saver.dump_metric_tb(mAP_v2i, iteration, 'v2i', f'mAP') 170 | self.tb_cmc(saver, cmc_scores_v2i, iteration, 'v2i') 171 | 172 | 173 | def parse(conf: Conf): 174 | 175 | parser = argparse.ArgumentParser(description='Train img to video model') 176 | parser = conf.add_default_args(parser) 177 | 178 | parser.add_argument('trinet_folder', type=str, help='Path to TriNet base folder.') 179 | parser.add_argument('--trinet_chk_name', type=str, help='checkpoint name', default='chk_end') 180 | 181 | args = parser.parse_args() 182 | args.train_strategy = 'chunk' 183 | args.use_random_erasing = False 184 | args.num_train_images = 0 185 | 186 | return args 187 | 188 | 189 | def main(): 190 | conf = Conf() 191 | conf.suppress_random() 192 | device = conf.get_device() 193 | 194 | args = parse(conf) 195 | 196 | # ---- SAVER OLD NET TO RESTORE PARAMS 197 | saver_trinet = Saver(Path(args.trinet_folder).parent, Path(args.trinet_folder).name) 198 | old_params, old_hparams = saver_trinet.load_logs() 199 | args.backbone = old_params['backbone'] 200 | args.metric = old_params['metric'] 201 | 202 | train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ 203 | get_dataloaders(args.dataset_name, conf.nas_path, device, args) 204 | num_pids = train_loader.dataset.get_num_pids() 205 | 206 | assert num_pids == old_hparams['num_classes'] 207 | 208 | net = get_model(args, num_pids).to(device) 209 | state_dict = torch.load(Path(args.trinet_folder) / 'chk' / args.trinet_chk_name) 210 | net.load_state_dict(state_dict) 211 | 212 | e = Evaluator(net, 
query_loader, gallery_loader, queryimg_loader, galleryimg_loader, 213 | device=device, data_conf=DATA_CONFS[args.dataset_name]) 214 | 215 | e.eval(None, 0, verbose=True, do_tb=False) 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /tools/save_heatmaps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import cv2 5 | import matplotlib 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from data.datasets import get_dataloaders 11 | from utils.conf import Conf 12 | from utils.saver import Saver 13 | 14 | matplotlib.use('Agg') 15 | from matplotlib import pyplot as plt 16 | from torch.nn.functional import adaptive_avg_pool2d 17 | from torch.nn.functional import relu 18 | from torch.nn.functional import normalize 19 | 20 | from model.net import TriNet 21 | 22 | class Hook: 23 | def __init__(self): 24 | self.buffer = [] 25 | 26 | def __call__(self, module, _, ten_out): 27 | self.buffer.append(ten_out) 28 | 29 | def reset(self): 30 | self.buffer = [] 31 | 32 | 33 | def parse(conf: Conf): 34 | 35 | parser = argparse.ArgumentParser(description='Train img to video model') 36 | parser = conf.add_default_args(parser) 37 | 38 | parser.add_argument('net1', type=str, help='Path to TriNet base folder.') 39 | parser.add_argument('--chk_net1', type=str, help='checkpoint name', default='chk_end') 40 | parser.add_argument('net2', type=str, help='Path to TriNet base folder.') 41 | parser.add_argument('--chk_net2', type=str, help='checkpoint name', default='chk_end') 42 | parser.add_argument('--dest_path', type=Path, default='/tmp/heatmaps_out') 43 | 44 | args = parser.parse_args() 45 | args.train_strategy = 'multiview' 46 | args.use_random_erasing = False 47 | args.num_train_images = 0 48 | args.img_test_batch = 32 49 | 50 | return args 51 | 52 | 53 | def extract_grad_cam(net: TriNet, inputs: torch.Tensor, device: torch.device, 54 | hook: Hook): 55 | 56 | _, logits = net(inputs, return_logits=True) # forward calls hooks 57 | logits_max = torch.max(logits, 1)[0] 58 | 59 | conv_features = hook.buffer[0] 60 | 61 | grads = torch.autograd.grad(logits_max, conv_features, 62 | grad_outputs=torch.ones(len(conv_features)).to(device))[0] 63 | 64 | with torch.no_grad(): 65 | weights = adaptive_avg_pool2d(grads, (1, 1)) 66 | attn = relu(torch.sum(conv_features * weights, 1)) 67 | old_shape = attn.shape 68 | attn = normalize(attn.view(attn.shape[0], -1)) 69 | attn = attn.view(old_shape) 70 | 71 | return attn.view(*inputs.shape[:2], *attn.shape[1:]) 72 | 73 | 74 | def save_img(img, attn, dest_path): 75 | height, width = img.shape[0], img.shape[1] 76 | fig = plt.figure() 77 | fig.set_size_inches(width / height, 1, forward=False) 78 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 79 | ax.set_axis_off() 80 | fig.add_axes(ax) 81 | ax.imshow(img, origin='upper') 82 | if attn is not None: 83 | ax.imshow(attn, origin='upper', extent=[0, width, height, 0], 84 | alpha=0.4, cmap=plt.cm.get_cmap('jet')) 85 | fig.canvas.draw() 86 | plt.savefig(dest_path, dpi=height) 87 | plt.close() 88 | 89 | 90 | def main(): 91 | conf = Conf() 92 | conf.suppress_random() 93 | device = conf.get_device() 94 | 95 | args = parse(conf) 96 | 97 | dest_path = args.dest_path / (Path(args.net1).name + '__vs__' + Path(args.net2).name) 98 | dest_path.mkdir(exist_ok=True, parents=True) 99 | 100 | both_path = dest_path / 'both' 101 | both_path.mkdir(exist_ok=True, 
parents=True) 102 | 103 | net1_path = dest_path / Path(args.net1).name 104 | net1_path.mkdir(exist_ok=True, parents=True) 105 | 106 | net2_path = dest_path / Path(args.net2).name 107 | net2_path.mkdir(exist_ok=True, parents=True) 108 | 109 | orig_path = dest_path / 'orig' 110 | orig_path.mkdir(exist_ok=True, parents=True) 111 | 112 | # ---- Restore net 113 | net1 = Saver.load_net(args.net1, args.chk_net1, args.dataset_name).to(device) 114 | net2 = Saver.load_net(args.net2, args.chk_net2, args.dataset_name).to(device) 115 | 116 | net1.eval() 117 | net2.eval() 118 | 119 | train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ 120 | get_dataloaders(args.dataset_name, conf.nas_path, device, args) 121 | 122 | # register hooks 123 | hook_net_1, hook_net_2 = Hook(), Hook() 124 | 125 | net1.backbone.features_layers[4].register_forward_hook(hook_net_1) 126 | net2.backbone.features_layers[4].register_forward_hook(hook_net_2) 127 | 128 | dst_idx = 0 129 | 130 | for idx_batch, (vids, *_) in enumerate(tqdm(galleryimg_loader, 'iterating..')): 131 | if idx_batch < len(galleryimg_loader) - 50: 132 | continue 133 | net1.zero_grad() 134 | net2.zero_grad() 135 | 136 | hook_net_1.reset() 137 | hook_net_2.reset() 138 | 139 | vids = vids.to(device) 140 | attn_1 = extract_grad_cam(net1, vids, device, hook_net_1) 141 | attn_2 = extract_grad_cam(net2, vids, device, hook_net_2) 142 | 143 | B, N_VIEWS = attn_1.shape[0], attn_1.shape[1] 144 | 145 | for idx_b in range(B): 146 | for idx_v in range(N_VIEWS): 147 | 148 | el_img = vids[idx_b, idx_v] 149 | el_attn_1 = attn_1[idx_b, idx_v] 150 | el_attn_2 = attn_2[idx_b, idx_v] 151 | 152 | el_img = el_img.cpu().numpy().transpose(1, 2, 0) 153 | el_attn_1 = el_attn_1.cpu().numpy() 154 | el_attn_2 = el_attn_2.cpu().numpy() 155 | 156 | mean, var = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 157 | el_img = (el_img * var) + mean 158 | el_img = np.clip(el_img, 0, 1) 159 | 160 | el_attn_1 = cv2.blur(el_attn_1, (3, 3)) 161 | el_attn_1 = cv2.resize(el_attn_1, (el_img.shape[1], el_img.shape[0]), 162 | interpolation=cv2.INTER_CUBIC) 163 | 164 | el_attn_2 = cv2.blur(el_attn_2, (3, 3)) 165 | el_attn_2 = cv2.resize(el_attn_2, (el_img.shape[1], el_img.shape[0]), 166 | interpolation=cv2.INTER_CUBIC) 167 | 168 | save_img(el_img, el_attn_1, net1_path / f'{dst_idx}.png') 169 | save_img(el_img, el_attn_2, net2_path / f'{dst_idx}.png') 170 | 171 | save_img(el_img, None, orig_path / f'{dst_idx}.png') 172 | 173 | save_img(np.concatenate([el_img, el_img], 1), 174 | np.concatenate([el_attn_1, el_attn_2], 1), both_path / f'{dst_idx}.png') 175 | 176 | dst_idx += 1 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | 182 | -------------------------------------------------------------------------------- /tools/train_distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import deepcopy 3 | from uuid import uuid4 4 | 5 | import torch 6 | from torch import nn 7 | from torch.optim import Adam 8 | from torch.optim import lr_scheduler 9 | from torch.utils.data import DataLoader 10 | 11 | from data.datasets import get_dataloaders, DATA_CONFS 12 | from model.loss import KDLoss 13 | from model.loss import OnlineTripletLoss 14 | from model.loss import SimilarityDistillationLoss 15 | from model.loss import LogitsMatching 16 | from model.net import TriNet 17 | from tools.eval import Evaluator 18 | from utils.conf import Conf 19 | from utils.misc import AvgMeter 20 | from utils.misc import str2bool 21 | 
from utils.saver import Saver 22 | 23 | 24 | class LearningRateGenDecayer(object): 25 | 26 | def __init__(self, initial_lr: float, decay: float, min: float = 1e-5): 27 | self.decay = decay 28 | self.initial_lr, self.min = initial_lr, min 29 | 30 | def __call__(self, epoch: int): 31 | return max(self.initial_lr * (self.decay ** epoch), self.min) 32 | 33 | 34 | def parse(conf: Conf): 35 | parser = argparse.ArgumentParser(description='Train img to video model') 36 | parser = conf.add_default_args(parser) 37 | 38 | parser.add_argument('teacher', type=str) 39 | parser.add_argument('--teacher_chk_name', type=str, default='chk_end') 40 | 41 | parser.add_argument('--student', type=str) 42 | parser.add_argument('--student_chk_name', type=str, default='chk_end') 43 | 44 | parser.add_argument('--exp_name', type=str, default=str(uuid4())) 45 | parser.add_argument('--num_generations', type=int, default=1) 46 | 47 | parser.add_argument('--eval_epoch_interval', type=int, default=50) 48 | parser.add_argument('--print_epoch_interval', type=int, default=5) 49 | 50 | parser.add_argument('--lr', type=float, default=1e-4) 51 | parser.add_argument('--lr_decay', type=float, default=0.1) 52 | parser.add_argument('--temp', type=float, default=10.) 53 | parser.add_argument('--lambda_coeff', type=float, default=0.0001) 54 | parser.add_argument('--kl_coeff', type=float, default=0.1) 55 | 56 | parser.add_argument('--num_train_images', type=int, default=8) 57 | parser.add_argument('--num_student_images', type=int, default=2) 58 | 59 | parser.add_argument('--train_strategy', type=str, default='multiview', 60 | choices=['multiview', 'temporal']) 61 | 62 | parser.add_argument('--num_epochs', type=int, default=400) 63 | parser.add_argument('--gamma', type=float, default=0.1) 64 | parser.add_argument('--first_milestone', type=int, default=300) 65 | parser.add_argument('--step_milestone', type=int, default=50) 66 | 67 | parser.add_argument('--reinit_l4', type=str2bool, default=True) 68 | parser.add_argument('--reinit_l3', type=str2bool, default=False) 69 | 70 | parser.add_argument('--logits_dist', type=str, default='kl', 71 | choices=['kl', 'mse']) 72 | 73 | args = parser.parse_args() 74 | args.use_random_erasing = True 75 | 76 | return args 77 | 78 | 79 | class DistillationTrainer: 80 | 81 | def __init__(self, train_loader: DataLoader, query_loader: DataLoader, 82 | gallery_loader: DataLoader, queryimg_loader: DataLoader, 83 | galleryimg_loader: DataLoader, device: torch.device, saver: Saver, 84 | args: argparse.Namespace, conf: Conf): 85 | 86 | self.class_loss = nn.CrossEntropyLoss(reduction='mean').to(device) 87 | self.distill_loss = KDLoss(temp=args.temp, reduction='mean').to(device) \ 88 | if args.logits_dist == 'kl' else LogitsMatching(reduction='mean').to(device) 89 | self.similarity_loss = SimilarityDistillationLoss(metric='l2').to(device) 90 | self.triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device) 91 | 92 | self.train_loader = train_loader 93 | self.query_loader = query_loader 94 | self.gallery_loader = gallery_loader 95 | self.queryimg_loader = queryimg_loader 96 | self.galleryimg_loader = galleryimg_loader 97 | 98 | self.device = device 99 | self.saver = saver 100 | self.args = args 101 | self.conf = conf 102 | 103 | self.lr = LearningRateGenDecayer(initial_lr=self.args.lr, 104 | decay=self.args.lr_decay) 105 | self._epoch = 0 106 | self._gen = 0 107 | 108 | def evaluate(self, net: nn.Module): 109 | ev = Evaluator(net, self.query_loader, self.gallery_loader, self.queryimg_loader, self.galleryimg_loader, 110 | DATA_CONFS[self.args.dataset_name], self.device) 111 | ev.eval(self.saver, self._epoch, self.args.verbose) 112 | 113 | def __call__(self, teacher_net: TriNet, student_net: TriNet): 114 | 115 | opt = Adam(student_net.parameters(), lr=self.lr(self._gen), weight_decay=1e-5) 116 | 117 | milestones = list(range(self.args.first_milestone, self.args.num_epochs, 118 | self.args.step_milestone)) 119 | 120 | scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=self.args.gamma) 121 | 122 | for e in range(self.args.num_epochs): 123 | 124 | if e % self.args.eval_epoch_interval == 0 and e > 0: 125 | self.evaluate(student_net) 126 | 127 | avm = AvgMeter(['kl', 'triplet', 'class', 'similarity', 'loss']) 128 | 129 | student_net.student_mode() 130 | teacher_net.teacher_mode() 131 | 132 | for x, y, cams in self.train_loader: 133 | 134 | x, y = x.to(self.device), y.to(self.device) 135 | x_ = torch.stack([x[i, torch.randperm(x.shape[1])] for i in range(x.shape[0])]) 136 | 137 | x_teacher, x_student = x, x_[:, :self.args.num_student_images] 138 | 139 | with torch.no_grad(): 140 | teacher_emb, teacher_logits = teacher_net(x_teacher, return_logits=True) 141 | 142 | opt.zero_grad() 143 | 144 | student_emb, student_logits = student_net(x_student, return_logits=True) 145 | 146 | kl_div_batch = self.distill_loss(teacher_logits, student_logits) 147 | similarity_loss_batch = self.similarity_loss(teacher_emb, student_emb) 148 | triplet_loss_batch = self.triplet_loss(student_emb, y) 149 | class_loss_batch = self.class_loss(student_logits, y) 150 | 151 | loss = (triplet_loss_batch + class_loss_batch) + \ 152 | self.args.lambda_coeff * (similarity_loss_batch) + \ 153 | self.args.kl_coeff * (kl_div_batch) 154 | 155 | avm.add([kl_div_batch.item(), triplet_loss_batch.item(), 156 | class_loss_batch.item(), similarity_loss_batch.item(), 157 | loss.item()]) 158 | 159 | loss.backward() 160 | opt.step() 161 | 162 | scheduler.step() 163 | 164 | if self._epoch % self.args.print_epoch_interval == 0: 165 | stats = avm() 166 | str_ = f"Epoch: {self._epoch}" 167 | for (l, m) in stats: 168 | str_ += f" - {l} {m:.2f}" 169 | self.saver.dump_metric_tb(m, self._epoch, 'losses', f"avg_{l}") 170 | self.saver.dump_metric_tb(opt.param_groups[0]['lr'], self._epoch, 'lr', 'lr') 171 | print(str_) 172 | 173 | self._epoch += 1 174 | 175 | self._gen += 1 176 | 177 | return student_net 178 | 179 | 180 | if __name__ == '__main__': 181 | conf = Conf() 182 | device = conf.get_device() 183 | args = parse(conf) 184 | 185 | conf.suppress_random(set_determinism=args.set_determinism) 186 | 187 | train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ 188 | get_dataloaders(args.dataset_name, conf.nas_path, device, args) 189 | 190 | teacher_net: TriNet = Saver.load_net(args.teacher, 191 | args.teacher_chk_name, args.dataset_name).to(device) 192 | 193 | student_net: TriNet = deepcopy(teacher_net) if args.student is None \ 194 | else Saver.load_net(args.student, args.student_chk_name, args.dataset_name) 195 | student_net = student_net.to(device) 196 | 197 | ev = Evaluator(student_net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader, 198 | DATA_CONFS[args.dataset_name], device) 199 | 200 | print('v' * 100) 201 | ev.eval(saver=None, iteration=None, verbose=True, do_tb=False) 202 | print('v' * 100) 203 | 204 | student_net.reinit_layers(args.reinit_l4, args.reinit_l3) 205 | 206 | saver = Saver(conf.log_path, args.exp_name) 207 | saver.write_logs(student_net, vars(args)) 208 | 209 |
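# generational self-distillation: each pass trains the student against the
# frozen teacher, promotes it to teacher, snapshots the weights and then
# re-initializes the upper layers before the next generation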
d_trainer: DistillationTrainer = DistillationTrainer(train_loader, query_loader, 210 | gallery_loader, queryimg_loader, galleryimg_loader, conf.get_device(), 211 | saver, args, conf) 212 | 213 | print("EXP_NAME: ", args.exp_name) 214 | 215 | for idx_iteration in range(args.num_generations): 216 | print(f'starting generation {idx_iteration+1}') 217 | print('#'*100) 218 | teacher_net = d_trainer(teacher_net, student_net) 219 | d_trainer.evaluate(teacher_net) 220 | teacher_net.teacher_mode() 221 | 222 | student_net = deepcopy(teacher_net) 223 | saver.save_net(student_net, f'chk_di_{idx_iteration + 1}') 224 | 225 | student_net.reinit_layers(args.reinit_l4, args.reinit_l3) 226 | 227 | saver.writer.close() 228 | -------------------------------------------------------------------------------- /tools/train_v2v.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from uuid import uuid4 3 | 4 | from torch import nn 5 | from torch.optim import Adam 6 | from torch.optim import lr_scheduler 7 | 8 | from data.datasets import get_dataloaders, DATA_CONFS 9 | from model.loss import OnlineTripletLoss 10 | from model.net import get_model 11 | from tools.eval import Evaluator 12 | from utils.conf import Conf 13 | from utils.saver import Saver 14 | from utils.misc import str2bool 15 | 16 | from utils.misc import AvgMeter 17 | 18 | 19 | def parse(conf: Conf): 20 | 21 | parser = argparse.ArgumentParser(description='Train img to video model') 22 | parser = conf.add_default_args(parser) 23 | 24 | parser.add_argument('--exp_name', type=str, default=str(uuid4()), help='Experiment name.') 25 | parser.add_argument('--metric', type=str, default='euclidean', 26 | choices=['euclidean', 'cosine'], help='Metric for distances') 27 | parser.add_argument('--num_train_images', type=int, default=8, help='Num. 
of bag images.') 28 | 29 | parser.add_argument('--num_epochs', type=int, default=300) 30 | parser.add_argument('--eval_epoch_interval', type=int, default=50) 31 | parser.add_argument('--save_epoch_interval', type=int, default=50) 32 | parser.add_argument('--print_epoch_interval', type=int, default=5) 33 | 34 | parser.add_argument('--wd', type=float, default=1e-5) 35 | 36 | parser.add_argument('--gamma', type=float, default=0.1) 37 | parser.add_argument('--first_milestone', type=int, default=200) 38 | parser.add_argument('--step_milestone', type=int, default=50) 39 | 40 | parser.add_argument('--use_random_erasing', type=str2bool, default=True) 41 | parser.add_argument('--train_strategy', type=str, default='chunk', 42 | choices=['multiview', 'chunk']) 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | conf = Conf() 50 | args = parse(conf) 51 | device = conf.get_device() 52 | 53 | conf.suppress_random(set_determinism=args.set_determinism) 54 | saver = Saver(conf.log_path, args.exp_name) 55 | 56 | train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ 57 | get_dataloaders(args.dataset_name, conf.nas_path, device, args) 58 | 59 | num_pids = train_loader.dataset.get_num_pids() 60 | 61 | net = nn.DataParallel(get_model(args, num_pids)) 62 | net = net.to(device) 63 | 64 | saver.write_logs(net.module, vars(args)) 65 | 66 | opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd) 67 | milestones = list(range(args.first_milestone, args.num_epochs, 68 | args.step_milestone)) 69 | scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=args.gamma) 70 | 71 | triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device) 72 | class_loss = nn.CrossEntropyLoss(reduction='mean').to(device) 73 | 74 | print("EXP_NAME: ", args.exp_name) 75 | 76 | for e in range(args.num_epochs): 77 | 78 | if e % args.eval_epoch_interval == 0 and e > 0: 79 | ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader, 80 | DATA_CONFS[args.dataset_name], device) 81 | ev.eval(saver, e, args.verbose) 82 | 83 | if e % args.save_epoch_interval == 0 and e > 0: 84 | saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}') 85 | 86 | avm = AvgMeter(['triplet', 'class']) 87 | 88 | for it, (x, y, cams) in enumerate(train_loader): 89 | net.train() 90 | 91 | x, y = x.to(device), y.to(device) 92 | 93 | opt.zero_grad() 94 | embeddings, f_class = net(x, return_logits=True) 95 | 96 | triplet_loss_batch = triplet_loss(embeddings, y) 97 | class_loss_batch = class_loss(f_class, y) 98 | loss = triplet_loss_batch + class_loss_batch 99 | 100 | avm.add([triplet_loss_batch.item(), class_loss_batch.item()]) 101 | 102 | loss.backward() 103 | opt.step() 104 | 105 | if e % args.print_epoch_interval == 0: 106 | stats = avm() 107 | str_ = f"Epoch: {e}" 108 | for (l, m) in stats: 109 | str_ += f" - {l} {m:.2f}" 110 | saver.dump_metric_tb(m, e, 'losses', f"avg_{l}") 111 | saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr') 112 | print(str_) 113 | 114 | scheduler.step() 115 | 116 | ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader, 117 | DATA_CONFS[args.dataset_name], device) 118 | ev.eval(saver, e, args.verbose) 119 | 120 | saver.save_net(net.module, 'chk_end') 121 | saver.writer.close() 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /utils/conf.py: 
-------------------------------------------------------------------------------- 1 | import socket 2 | import tempfile 3 | import torch 4 | 5 | from argparse import ArgumentParser 6 | 7 | from model.net import BACKBONES 8 | from model.net import Backbone 9 | 10 | from utils.misc import str2bool 11 | from data.datasets import DATASETS 12 | 13 | 14 | class Conf: 15 | """ 16 | This class encapsulate some configuration variables useful to have around. 17 | """ 18 | 19 | SEED = 1897 20 | 21 | def __init__(self): 22 | """ Constructor class. """ 23 | 24 | self.host_name = socket.gethostname() 25 | self.nas_path = self._get_nas_path() 26 | self.log_path = self._get_log_path() 27 | 28 | @staticmethod 29 | def add_default_args(parser: ArgumentParser): 30 | parser.add_argument('dataset_name', choices=list(DATASETS.keys()), type=str, help='dataset name') 31 | # Network 32 | parser.add_argument('--backbone', type=str, choices=BACKBONES, default=Backbone.RESNET_50, 33 | help='Backbone network type.') 34 | parser.add_argument('--pretrained', type=str2bool, default=True, help='No pretraining.') 35 | 36 | # Others 37 | parser.add_argument('--set_determinism', type=str2bool, default=False) 38 | parser.add_argument('--test_batch', default=32, type=int) 39 | parser.add_argument('--img_test_batch', default=512, type=int) 40 | parser.add_argument('--verbose', type=str2bool, default=True, help='Debug mode') 41 | parser.add_argument('-j', '--workers', default=4, type=int) 42 | parser.add_argument('--p', type=int, default=18, help='') 43 | parser.add_argument('--k', type=int, default=4, help='') 44 | 45 | parser.add_argument('--num_test_images', type=int, default=8) 46 | 47 | return parser 48 | 49 | @staticmethod 50 | def get_tmp_path(): 51 | return tempfile.gettempdir() 52 | 53 | @staticmethod 54 | def get_hostname_config(): 55 | 56 | default_config = { 57 | 'log_path': './logs', 58 | 'nas_path': './datasets' 59 | } 60 | 61 | return default_config 62 | 63 | def _get_log_path(self) -> str: 64 | 65 | default_config = self.get_hostname_config() 66 | return default_config["log_path"] 67 | 68 | def _get_nas_path(self) -> str: 69 | 70 | default_config = self.get_hostname_config() 71 | return default_config["nas_path"] 72 | 73 | @staticmethod 74 | def suppress_random(seed: int = SEED, set_determinism: bool = False): 75 | import random 76 | import torch 77 | import numpy as np 78 | random.seed(seed) 79 | np.random.seed(seed) 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed_all(seed) 82 | if set_determinism: 83 | torch.backends.cudnn.deterministic = True 84 | 85 | @staticmethod 86 | def get_device(): 87 | return torch.device('cuda') if torch.cuda.is_available() \ 88 | else torch.device('cpu') 89 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | import numpy as np 6 | 7 | 8 | def init_worker(worker_id): 9 | np.random.seed(1234 + worker_id) 10 | 11 | 12 | def str2bool(v): 13 | if isinstance(v, bool): 14 | return v 15 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 16 | return True 17 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 18 | return False 19 | else: 20 | raise argparse.ArgumentTypeError('Boolean value expected.') 21 | 22 | 23 | class SingleAvgMeter: 24 | 25 | def __init__(self, label: str): 26 | self.values = [] 27 | self.label = label 28 | 29 | def add(self, value: float): 30 | 
/utils/misc.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | from typing import List
 3 | 
 4 | import numpy as np
 5 | 
 6 | 
 7 | def init_worker(worker_id):
 8 |     # Give each DataLoader worker its own numpy seed.
 9 |     np.random.seed(1234 + worker_id)
10 | 
11 | 
12 | def str2bool(v):
13 |     if isinstance(v, bool):
14 |         return v
15 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
16 |         return True
17 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
18 |         return False
19 |     else:
20 |         raise argparse.ArgumentTypeError('Boolean value expected.')
21 | 
22 | 
23 | class SingleAvgMeter:
24 |     """ Tracks the running average of a single scalar (e.g. a loss term). """
25 | 
26 |     def __init__(self, label: str):
27 |         self.values = []
28 |         self.label = label
29 | 
30 |     def add(self, value: float):
31 |         self.values.append(value)
32 | 
33 |     def __call__(self):
34 |         return (self.label, np.array(self.values).mean())
35 | 
36 |     def reset(self):
37 |         self.values = []  # a list, so that add() can keep appending after a reset
38 | 
39 | 
40 | class AvgMeter:
41 |     """ A collection of SingleAvgMeters, one per label. """
42 | 
43 |     def __init__(self, labels: List[str]):
44 |         self.avg_meters = [SingleAvgMeter(label) for label in labels]
45 | 
46 |     def add(self, values: List[float]):
47 |         for i, v in enumerate(values):
48 |             self.avg_meters[i].add(v)
49 | 
50 |     def __call__(self):
51 |         return [avg_meter() for avg_meter in self.avg_meters]
52 | 
53 |     def reset(self):
54 |         for avg_meter in self.avg_meters:
55 |             avg_meter.reset()
--------------------------------------------------------------------------------
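`AvgMeter` simply fans out to one `SingleAvgMeter` per label. A small illustrative example, with made-up loss values:

```python
# Illustrative only: track two loss terms across an epoch.
from utils.misc import AvgMeter

meter = AvgMeter(['triplet_loss', 'ce_loss'])
for batch_losses in [(0.52, 1.30), (0.48, 1.21)]:   # e.g. per-batch loss values
    meter.add(list(batch_losses))

print(meter())  # [('triplet_loss', 0.5), ('ce_loss', 1.255)], up to numpy scalar repr
meter.reset()   # start fresh for the next epoch
```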
/utils/saver.py:
--------------------------------------------------------------------------------
 1 | import subprocess
 2 | from pathlib import Path
 3 | 
 4 | import cv2
 5 | import numpy as np
 6 | import torch
 7 | import json
 8 | 
 9 | from torch.utils.tensorboard import SummaryWriter
10 | from model.net import TriNet
11 | 
12 | 
13 | class Saver(object):
14 |     """
15 |     Handles experiment bookkeeping: checkpoints, TensorBoard logs,
16 |     params/hparams JSON dumps and the git state of the working tree.
17 |     """
18 |     def __init__(self, path: str, uuid: str):
19 |         self.path = Path(path) / uuid
20 |         self.path.mkdir(exist_ok=True, parents=True)
21 | 
22 |         self.chk_path = self.path / 'chk'
23 |         self.chk_path.mkdir(exist_ok=True)
24 | 
25 |         self.log_path = self.path / 'logs'
26 |         self.log_path.mkdir(exist_ok=True)
27 | 
28 |         self.params_path = self.path / 'params'
29 |         self.params_path.mkdir(exist_ok=True)
30 | 
31 |         # TB logs
32 |         self.writer = SummaryWriter(str(self.path))
33 | 
34 |         # Dump `git log` and `git diff`: this way, one can check out the last
35 |         # commit, apply the diff, and recover the exact state of the code.
36 |         for cmd in ['log', 'diff']:
37 |             with open(self.path / f'git_{cmd}.txt', mode='wt') as f:
38 |                 subprocess.run(['git', cmd], stdout=f)
39 | 
40 |     def load_logs(self):
41 |         with open(str(self.params_path / 'params.json'), 'r') as fp:
42 |             params = json.load(fp)
43 |         with open(str(self.params_path / 'hparams.json'), 'r') as fp:
44 |             hparams = json.load(fp)
45 |         return params, hparams
46 | 
47 |     @staticmethod
48 |     def load_net(path: str, chk_name: str, dataset_name: str):
49 |         with open(str(Path(path) / 'params' / 'hparams.json'), 'r') as fp:
50 |             net_hparams = json.load(fp)
51 |         with open(str(Path(path) / 'params' / 'params.json'), 'r') as fp:
52 |             net_params = json.load(fp)
53 | 
54 |         assert dataset_name == net_params['dataset_name']
55 |         net = TriNet(backbone_type=net_hparams['backbone_type'], pretrained=True,
56 |                      num_classes=net_hparams['num_classes'])
57 |         net_state_dict = torch.load(Path(path) / 'chk' / chk_name)
58 |         net.load_state_dict(net_state_dict)
59 |         return net
60 | 
61 |     def write_logs(self, model: torch.nn.Module, params: dict):
62 |         with open(str(self.params_path / 'params.json'), 'w') as fp:
63 |             json.dump(params, fp)
64 |         with open(str(self.params_path / 'hparams.json'), 'w') as fp:
65 |             json.dump(model.get_hparams(), fp)
66 | 
67 |     def write_image(self, image: np.ndarray, epoch: int, name: str):
68 |         out_image_path = self.log_path / f'{epoch:05d}_{name}.jpg'
69 |         cv2.imwrite(str(out_image_path), image)
70 | 
71 |         image = image[..., ::-1]  # BGR (OpenCV) -> RGB (TensorBoard)
72 |         self.writer.add_image(f'{name}', image, epoch, dataformats='HWC')
73 | 
74 |     def dump_metric_tb(self, value: float, epoch: int, m_type: str, m_desc: str):
75 |         self.writer.add_scalar(f'{m_type}/{m_desc}', value, epoch)
76 | 
77 |     def save_net(self, net: torch.nn.Module, name: str = 'weights', overwrite: bool = False):
78 |         weights_path = self.chk_path / name
79 |         if weights_path.exists() and not overwrite:
80 |             raise ValueError(f'Refusing to overwrite existing checkpoint: {weights_path}')
81 |         torch.save(net.state_dict(), weights_path)
82 | 
83 |     def dump_hparams(self, hparams: dict, metrics: dict):
84 |         self.writer.add_hparams(hparams, metrics)
--------------------------------------------------------------------------------
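To close, a sketch of a typical `Saver` round trip. The experiment name echoes the commands in the README, but the checkpoint name and the commented lines are illustrative assumptions; note also that `Saver.__init__` expects to run inside a git working tree.

```python
# Illustrative usage of Saver (names and paths are made up for the example).
from utils.saver import Saver

saver = Saver('./logs', 'base_mars_resnet50')    # creates chk/, logs/, params/
saver.dump_metric_tb(0.71, 10, 'cmc', 'rank-1')  # scalar logged under cmc/rank-1 at epoch 10
# saver.write_logs(net, vars(args))              # dumps params.json / hparams.json
# saver.save_net(net, name='epoch_10')           # refuses to overwrite unless overwrite=True

# Later, rebuild the trained network from the run directory:
# net = Saver.load_net('./logs/base_mars_resnet50', 'epoch_10', dataset_name='mars')
```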