├── README.md ├── bert-base-uncased ├── added_tokens.json ├── config.json ├── special_tokens_map.json └── vocab.txt ├── config_bert.json ├── examples ├── test_joint.py └── train_joint.py ├── reid ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-38.pyc ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── data_builder.cpython-36.pyc │ │ ├── data_builder_384.cpython-36.pyc │ │ ├── data_builder_attr.cpython-36.pyc │ │ ├── data_builder_attr.cpython-38.pyc │ │ ├── data_builder_attr_cc_multi.cpython-36.pyc │ │ ├── data_builder_attr_multi.cpython-36.pyc │ │ ├── data_builder_attr_t2i.cpython-36.pyc │ │ ├── data_builder_cc.cpython-36.pyc │ │ ├── data_builder_cc.cpython-38.pyc │ │ ├── data_builder_cc_ceph.cpython-36.pyc │ │ ├── data_builder_cross.cpython-36.pyc │ │ ├── data_builder_cross.cpython-38.pyc │ │ ├── data_builder_ctcc.cpython-36.pyc │ │ ├── data_builder_ctcc.cpython-38.pyc │ │ ├── data_builder_multi.cpython-36.pyc │ │ ├── data_builder_multi_384.cpython-36.pyc │ │ ├── data_builder_sc.cpython-36.pyc │ │ ├── data_builder_sc.cpython-38.pyc │ │ ├── data_builder_sc_mnt.cpython-36.pyc │ │ ├── data_builder_sc_mnt.cpython-38.pyc │ │ ├── data_builder_t2i.cpython-36.pyc │ │ ├── data_builder_t2i.cpython-38.pyc │ │ ├── image_layer.cpython-36.pyc │ │ ├── image_layer.cpython-38.pyc │ │ ├── image_layer_attr.cpython-36.pyc │ │ ├── image_layer_ceph.cpython-36.pyc │ │ ├── image_layer_multi.cpython-36.pyc │ │ ├── image_layer_multi.cpython-38.pyc │ │ └── image_layer_sc.cpython-36.pyc │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── base_dataset.cpython-36.pyc │ │ │ ├── base_dataset.cpython-38.pyc │ │ │ ├── preprocessor.cpython-36.pyc │ │ │ ├── preprocessor_attr.cpython-36.pyc │ │ │ ├── preprocessor_attr.cpython-38.pyc │ │ │ ├── preprocessor_attr_t2i.cpython-36.pyc │ │ │ ├── preprocessor_cc.cpython-36.pyc │ 
│ │ ├── preprocessor_cc.cpython-38.pyc │ │ │ ├── preprocessor_cc_ceph.cpython-36.pyc │ │ │ ├── preprocessor_cross.cpython-36.pyc │ │ │ ├── preprocessor_cross.cpython-38.pyc │ │ │ ├── preprocessor_ctcc.cpython-36.pyc │ │ │ ├── preprocessor_ctcc.cpython-38.pyc │ │ │ ├── preprocessor_sc.cpython-36.pyc │ │ │ ├── preprocessor_sc.cpython-38.pyc │ │ │ ├── preprocessor_sc_ceph.cpython-36.pyc │ │ │ ├── preprocessor_sc_ceph.cpython-38.pyc │ │ │ ├── preprocessor_t2i.cpython-36.pyc │ │ │ ├── preprocessor_t2i.cpython-38.pyc │ │ │ ├── sampler.cpython-36.pyc │ │ │ ├── sampler.cpython-38.pyc │ │ │ ├── transforms.cpython-36.pyc │ │ │ └── transforms.cpython-38.pyc │ │ ├── base_dataset.py │ │ ├── preprocessor_attr.py │ │ ├── preprocessor_cc.py │ │ ├── preprocessor_cross.py │ │ ├── preprocessor_ctcc.py │ │ ├── preprocessor_sc.py │ │ ├── preprocessor_t2i.py │ │ ├── sampler.py │ │ └── transforms.py │ ├── data_builder_attr.py │ ├── data_builder_cc.py │ ├── data_builder_cross.py │ ├── data_builder_ctcc.py │ ├── data_builder_sc_mnt.py │ ├── data_builder_t2i.py │ ├── image_layer.py │ └── image_layer_multi.py ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── evaluators.cpython-36.pyc │ │ ├── evaluators.cpython-38.pyc │ │ ├── evaluators_cl.cpython-36.pyc │ │ ├── evaluators_d.cpython-36.pyc │ │ ├── evaluators_n.cpython-36.pyc │ │ ├── evaluators_t.cpython-36.pyc │ │ └── evaluators_t.cpython-38.pyc │ ├── evaluators.py │ ├── evaluators_cl.py │ ├── evaluators_d.py │ ├── evaluators_t.py │ └── evalutator_m.py ├── evaluators_cl.py ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── adaptive_triplet.cpython-36.pyc │ │ ├── adaptive_triplet.cpython-38.pyc │ │ ├── adv_loss.cpython-36.pyc │ │ ├── adv_loss.cpython-38.pyc │ │ ├── crossentropy.cpython-36.pyc │ │ ├── crossentropy.cpython-38.pyc │ │ ├── dual_causality_loss.cpython-36.pyc │ │ ├── 
dual_causality_loss.cpython-38.pyc │ │ ├── transloss.cpython-36.pyc │ │ ├── transloss.cpython-38.pyc │ │ ├── triplet.cpython-36.pyc │ │ ├── triplet.cpython-38.pyc │ │ └── triplet_new.cpython-36.pyc │ ├── adaptive_triplet.py │ ├── adv_loss.py │ ├── crossentropy.py │ ├── dual_causality_loss.py │ ├── transloss.py │ └── triplet.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── augmentor.cpython-36.pyc │ │ ├── augmentor.cpython-38.pyc │ │ ├── efficientnet.cpython-36.pyc │ │ ├── efficientnet.cpython-38.pyc │ │ ├── hap_transformer_joint.cpython-36.pyc │ │ ├── mgn.cpython-36.pyc │ │ ├── mgn.cpython-38.pyc │ │ ├── pass_transformer.cpython-36.pyc │ │ ├── pass_transformer_cl.cpython-36.pyc │ │ ├── pass_transformer_d.cpython-36.pyc │ │ ├── pass_transformer_gene.cpython-36.pyc │ │ ├── pass_transformer_gene.cpython-38.pyc │ │ ├── pass_transformer_joint.cpython-36.pyc │ │ ├── pass_transformer_joint.cpython-38.pyc │ │ ├── pass_transformer_pass.cpython-36.pyc │ │ ├── pass_transformer_pass.cpython-38.pyc │ │ ├── pass_transformer_t.cpython-36.pyc │ │ ├── pass_transformer_t_q.cpython-36.pyc │ │ ├── pass_transformer_t_q.cpython-38.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet.cpython-38.pyc │ │ ├── resnet_ibn.cpython-36.pyc │ │ ├── resnet_ibn.cpython-38.pyc │ │ ├── resnet_ibn_snr.cpython-36.pyc │ │ ├── resnet_ibn_snr.cpython-38.pyc │ │ ├── resnet_ibn_two_branch.cpython-36.pyc │ │ ├── resnet_ibn_two_branch.cpython-38.pyc │ │ ├── resnext_ibn.cpython-36.pyc │ │ ├── resnext_ibn.cpython-38.pyc │ │ ├── se_resnet_ibn.cpython-36.pyc │ │ ├── se_resnet_ibn.cpython-38.pyc │ │ ├── swin.cpython-36.pyc │ │ ├── swin.cpython-38.pyc │ │ ├── tokenization_bert.cpython-36.pyc │ │ ├── tokenization_bert.cpython-38.pyc │ │ ├── transformer.cpython-36.pyc │ │ ├── transformer.cpython-38.pyc │ │ ├── transformer_attr.cpython-36.pyc │ │ ├── transformer_attr_new_cross.cpython-36.pyc │ │ ├── xbert.cpython-36.pyc │ │ └── xbert.cpython-38.pyc │ 
├── augmentor.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── ckpt.cpython-36.pyc │ │ │ ├── ckpt.cpython-38.pyc │ │ │ ├── hap_vit.cpython-36.pyc │ │ │ ├── pass_vit.cpython-36.pyc │ │ │ ├── pass_vit.cpython-38.pyc │ │ │ ├── pass_vit_adapter.cpython-36.pyc │ │ │ ├── pass_vit_adapter.cpython-38.pyc │ │ │ ├── resnet_ibn_a.cpython-36.pyc │ │ │ ├── resnet_ibn_a.cpython-38.pyc │ │ │ ├── resnext_ibn.cpython-36.pyc │ │ │ ├── resnext_ibn.cpython-38.pyc │ │ │ ├── se_resnet_ibn.cpython-36.pyc │ │ │ ├── se_resnet_ibn.cpython-38.pyc │ │ │ ├── swin_transformer.cpython-36.pyc │ │ │ ├── swin_transformer.cpython-38.pyc │ │ │ ├── vit.cpython-36.pyc │ │ │ ├── vit.cpython-38.pyc │ │ │ ├── vit_.cpython-36.pyc │ │ │ ├── vit_albef.cpython-36.pyc │ │ │ ├── vit_albef.cpython-38.pyc │ │ │ ├── vit_human.cpython-36.pyc │ │ │ └── vitdet.cpython-36.pyc │ │ ├── ckpt.py │ │ ├── hap_vit.py │ │ ├── hap_vit_lem.py │ │ ├── modules │ │ │ ├── IBN.py │ │ │ ├── SE.py │ │ │ ├── SNR.py │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── IBN.cpython-36.pyc │ │ │ │ ├── IBN.cpython-38.pyc │ │ │ │ ├── SE.cpython-36.pyc │ │ │ │ ├── SE.cpython-38.pyc │ │ │ │ ├── SNR.cpython-36.pyc │ │ │ │ ├── SNR.cpython-38.pyc │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── spatial_snr.cpython-36.pyc │ │ │ │ └── spatial_snr.cpython-38.pyc │ │ │ └── spatial_snr.py │ │ ├── pass_vit.py │ │ ├── pass_vit_adapter.py │ │ ├── pass_vit_ori.py │ │ ├── resnet.py │ │ ├── resnet_ibn_a.py │ │ ├── resnext_ibn.py │ │ ├── se_resnet_ibn.py │ │ ├── vit.py │ │ ├── vit_albef.py │ │ ├── vit_albef_ori.py │ │ └── vit_ri.py │ ├── efficientnet.py │ ├── layers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cos_layer.cpython-36.pyc │ │ │ ├── cos_layer.cpython-38.pyc │ │ │ ├── gem.cpython-36.pyc │ │ │ ├── gem.cpython-38.pyc │ │ │ ├── metric.cpython-36.pyc │ │ │ 
├── metric.cpython-38.pyc │ │ │ ├── rbf_layer.cpython-36.pyc │ │ │ └── rbf_layer.cpython-38.pyc │ │ ├── cos_layer.py │ │ ├── gem.py │ │ ├── metric.py │ │ └── rbf_layer.py │ ├── mgn.py │ ├── pass_transformer_joint.py │ ├── resnet.py │ ├── resnet_ibn.py │ ├── resnet_ibn_snr.py │ ├── resnet_ibn_two_branch.py │ ├── resnext_ibn.py │ ├── se_resnet_ibn.py │ ├── swin.py │ ├── swin_trans.py │ ├── tokenization_bert.py │ ├── transformer.py │ └── xbert.py ├── multi_tasks_utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── multi_task_distributed_utils.cpython-36.pyc │ │ ├── multi_task_distributed_utils.cpython-38.pyc │ │ ├── multi_task_distributed_utils_pt.cpython-38.pyc │ │ ├── task_info.cpython-36.pyc │ │ ├── task_info.cpython-38.pyc │ │ └── task_info_pt.cpython-38.pyc │ ├── model_serialization.py │ ├── multi_task_distributed_utils.py │ ├── multi_task_distributed_utils_pt.py │ ├── task_info.py │ └── task_info_pt.py ├── trainer │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base_trainer.cpython-36.pyc │ │ ├── base_trainer.cpython-38.pyc │ │ ├── base_trainer_pt.cpython-38.pyc │ │ ├── base_trainer_t2i.cpython-36.pyc │ │ ├── base_trainer_t2i.cpython-38.pyc │ │ ├── general_clothes_trainer.cpython-36.pyc │ │ ├── general_clothes_trainer.cpython-38.pyc │ │ ├── general_trainer.cpython-36.pyc │ │ ├── general_trainer.cpython-38.pyc │ │ ├── general_transformer_trainer.cpython-36.pyc │ │ ├── general_transformer_trainer_attr.cpython-36.pyc │ │ ├── general_transformer_trainer_attr_tri.cpython-36.pyc │ │ ├── mgn_trainer.cpython-36.pyc │ │ ├── mgn_trainer.cpython-38.pyc │ │ ├── pass_trainer.cpython-36.pyc │ │ ├── pass_trainer.cpython-38.pyc │ │ ├── pass_trainer_cl.cpython-36.pyc │ │ ├── pass_trainer_joint.cpython-36.pyc │ │ ├── pass_trainer_joint.cpython-38.pyc │ │ ├── pass_trainer_t.cpython-36.pyc │ │ ├── pass_trainer_t_q.cpython-36.pyc │ │ ├── 
pass_trainer_t_q.cpython-38.pyc │ │ ├── snr_trainer.cpython-36.pyc │ │ ├── snr_trainer.cpython-38.pyc │ │ ├── swin_trainer.cpython-36.pyc │ │ ├── transreid_twobranch_aug_trainer.cpython-36.pyc │ │ ├── transreid_twobranch_aug_trainer.cpython-38.pyc │ │ ├── transreid_twobranch_trainer.cpython-36.pyc │ │ └── transreid_twobranch_trainer.cpython-38.pyc │ ├── base_trainer.py │ ├── base_trainer_ori.py │ ├── base_trainer_pt.py │ ├── base_trainer_t2i.py │ ├── general_clothes_trainer.py │ ├── general_trainer.py │ ├── general_trainer_pt.py │ ├── general_transformer_trainer.py │ ├── general_transformer_trainer_pt.py │ ├── mgn_trainer.py │ ├── pass_trainer.py │ ├── pass_trainer_cl.py │ ├── pass_trainer_joint.py │ ├── pass_trainer_joint_ori.py │ ├── pass_trainer_pt.py │ ├── pass_trainer_t.py │ ├── pass_trainer_t_q.py │ ├── snr_trainer.py │ ├── swin_trainer.py │ ├── transreid_twobranch_aug_trainer.py │ └── transreid_twobranch_trainer.py ├── utils.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── adamw.cpython-36.pyc │ ├── adamw.cpython-38.pyc │ ├── comm_.cpython-36.pyc │ ├── comm_.cpython-38.pyc │ ├── distributed_utils.cpython-36.pyc │ ├── distributed_utils.cpython-38.pyc │ ├── distributed_utils_pt.cpython-38.pyc │ ├── logging.cpython-36.pyc │ ├── logging.cpython-38.pyc │ ├── lr_scheduler.cpython-36.pyc │ ├── lr_scheduler.cpython-38.pyc │ ├── meters.cpython-36.pyc │ ├── meters.cpython-38.pyc │ ├── osutils.cpython-36.pyc │ ├── osutils.cpython-38.pyc │ ├── serialization.cpython-36.pyc │ ├── serialization.cpython-38.pyc │ ├── visualizer.cpython-36.pyc │ ├── visualizer.cpython-38.pyc │ ├── vit_rollout.cpython-36.pyc │ └── vit_rollout.cpython-38.pyc │ ├── adamw.py │ ├── comm_.py │ ├── distributed_utils.py │ ├── distributed_utils_pt.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── lr_scheduler.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ ├── serialization.py │ ├── visualizer.py │ └── 
vit_rollout.py └── scripts ├── config_ablation5.yaml ├── config_attr.yaml ├── config_ctcc.yaml ├── config_cuhk.yaml ├── config_cuhk_pedes.yaml ├── config_joint.yaml ├── config_llcm.yaml ├── config_ltcc.yaml ├── config_market.yaml ├── config_msmt.yaml ├── config_prcc.yaml ├── config_vc.yaml ├── test.sh └── train.sh /bert-base-uncased/added_tokens.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /bert-base-uncased/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /bert-base-uncased/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 
18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 768 21 | } 22 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation 5 | from . import models 6 | from . import utils 7 | from .evaluation import evaluators 8 | 9 | __version__ = '0.1.0' 10 | -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_builder_attr import DataBuilder_attr 2 | from .data_builder_sc_mnt import DataBuilder_sc 3 | from .data_builder_cc import DataBuilder_cc 4 | from .data_builder_ctcc import DataBuilder_ctcc 5 | from .data_builder_attr import DataBuilder_attr 6 | from .data_builder_t2i import DataBuilder_t2i 7 | from .data_builder_cross import DataBuilder_cross 8 | 9 | 10 | def dataset_entry(this_task_info): 11 | return globals()[this_task_info.task_name] -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_384.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_384.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_attr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_attr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_attr.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_attr.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_attr_cc_multi.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_attr_cc_multi.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_attr_multi.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_attr_multi.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_attr_t2i.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_attr_t2i.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_cc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_cc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_cc.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_cc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_cc_ceph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_cc_ceph.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_cross.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_cross.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_cross.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_cross.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_ctcc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_ctcc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_ctcc.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_ctcc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_multi.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_multi.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_multi_384.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_multi_384.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_sc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_sc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_sc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_sc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_sc_mnt.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_sc_mnt.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_sc_mnt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_sc_mnt.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_t2i.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_t2i.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/data_builder_t2i.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/data_builder_t2i.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer_attr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer_attr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer_ceph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer_ceph.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer_multi.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer_multi.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer_multi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer_multi.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/image_layer_sc.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/__pycache__/image_layer_sc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/base_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/base_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_attr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_attr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_attr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_attr.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_attr_t2i.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_attr_t2i.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_cc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_cc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_cc.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_cc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_cc_ceph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_cc_ceph.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_cross.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_cross.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_cross.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_cross.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_ctcc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_ctcc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_ctcc.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_ctcc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_sc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_sc.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_sc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_sc.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_sc_ceph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_sc_ceph.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_sc_ceph.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_sc_ceph.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_t2i.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_t2i.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/preprocessor_t2i.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/preprocessor_t2i.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/data/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/datasets/data/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /reid/datasets/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataset(object): 3 | """ 4 | Base class of reid dataset 5 | """ 6 | 7 | @staticmethod 8 | def get_imagedata_info(data): 9 | pids, cids, cams = [], [], [] 10 | for _, _, pid, cid, camid in data: 11 | pids += [pid] 12 | cids += [cid] 13 | cams += [camid] 14 | pids = set(pids) 15 | cids = set(cids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cids = len(cids) 19 | num_imgs = len(data) 20 | num_cams = len(cams) 21 | return num_pids, num_imgs, num_cids, num_cams 22 | 23 | 24 | class BaseImageDataset(BaseDataset): 25 | """ 26 | Base class of image reid dataset 27 | """ 28 | 29 | def print_dataset_statistics(self, dataset, dataset_type): 30 | num_train_pids, num_train_imgs, num_train_cids, num_train_cams = self.get_imagedata_info(dataset) 31 | 32 | print("Dataset statistics:") 33 | print(" ------------------------------------------") 34 | print(" {:<9s}| {:^5s} | {:^8s} | {:^8s} | {:^9s}".format('subset', '# ids', '# images', '# clothes', '# cameras')) 35 | print(" ------------------------------------------") 36 | print(" {:<9s}| {:^5d} | {:^8d} | {:^8d} | {:^9d}".format(dataset_type, num_train_pids, num_train_imgs, num_train_cids, num_train_cams)) 37 | print(" ------------------------------------------") 38 | 39 | @staticmethod 40 | def _relabel(label_list): 41 | sorted_pids = sorted(list(set(label_list))) 42 | label_dict = dict() 43 | for idx, pid in enumerate(sorted_pids): 44 | if pid in label_dict.keys(): 45 | continue 46 | label_dict[pid] = idx 47 | 48 | relabeled_list = [label_dict[pid] for pid in label_list] 49 | return relabeled_list 50 | 
-------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_attr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import json 10 | import re 11 | 12 | class PreProcessor(Dataset): 13 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 14 | super(PreProcessor, self).__init__() 15 | self.dataset = dataset 16 | self.root = root 17 | self.root_additional = root_additional 18 | self.transform = transform 19 | self.initialized = False 20 | self.clothes_transform = clothes_transform 21 | self.blur_clo = blur_clo 22 | attr_file = open(json_list, 'r', encoding='utf-8') 23 | self.attr_dict = json.load(attr_file) 24 | 25 | 26 | def __len__(self): 27 | return len(self.dataset) 28 | 29 | def __getitem__(self, indices): 30 | return self._get_single_item(indices) 31 | 32 | def _get_single_item(self, index): 33 | fname, attr_fname, pid, cid, cam = self.dataset[index] 34 | fpath = fname 35 | 36 | attr_item = self.attr_dict[attr_fname] 37 | if int(pid)==-1: 38 | if self.root_additional is not None: 39 | fpath = os.path.join(self.root_additional, fname) 40 | else: 41 | if self.root is not None: 42 | fpath = os.path.join(self.root, fname) 43 | 44 | 45 | img = Image.open(fpath).convert('RGB') 46 | 47 | attribute = pre_caption(attr_item, 50) 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | 53 | return img, attribute, fname, attr_fname, pid, cid, cam, index 54 | 55 | 56 | def pre_caption(caption, max_words): 57 | caption = re.sub( 58 | r"([,.'!?\"()*#:;~])", 59 | '', 60 | caption.lower(), 61 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 62 | caption = re.sub( 63 | r"\s{2,}", 64 | ' ', 
65 | caption, 66 | ) 67 | caption = caption.rstrip('\n') 68 | caption = caption.strip(' ') 69 | # truncate caption 70 | caption_words = caption.split(' ') 71 | if len(caption_words)>max_words: 72 | caption = ' '.join(caption_words[:max_words]) 73 | return caption -------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_cc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | class PreProcessor(Dataset): 11 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 12 | super(PreProcessor, self).__init__() 13 | self.dataset = dataset 14 | self.root = root 15 | self.root_additional = root_additional 16 | self.transform = transform 17 | self.initialized = False 18 | self.clothes_transform = clothes_transform 19 | self.blur_clo = blur_clo 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def __getitem__(self, indices): 24 | return self._get_single_item(indices) 25 | 26 | def _get_single_item(self, index): 27 | fname, clothes_fname, pid, cid, cam = self.dataset[index] 28 | fpath = fname 29 | clothes_path = clothes_fname 30 | 31 | if int(pid)==-1: 32 | if self.root_additional is not None: 33 | fpath = os.path.join(self.root_additional, fname) 34 | clothes_path = os.path.join(self.root_additional, clothes_fname) 35 | else: 36 | if self.root is not None: 37 | fpath = os.path.join(self.root, fname) 38 | clothes_path = os.path.join(self.root, clothes_fname) 39 | 40 | img = Image.open(fpath).convert('RGB') 41 | clothes_img = Image.open(clothes_path).convert('RGB') 42 | 43 | if self.blur_clo: 44 | clothes_img = cv2.cvtColor(np.asarray(clothes_img),cv2.COLOR_RGB2BGR) 45 | kernel_size = (3, 3) 46 | sigma = 1.5 47 | 
clothes_img = cv2.GaussianBlur(clothes_img, kernel_size, sigma) 48 | clothes_img = Image.fromarray(cv2.cvtColor(clothes_img, cv2.COLOR_BGR2RGB)) 49 | #print(clothes_img.size) 50 | if self.transform is not None: 51 | img = self.transform(img) 52 | clothes_img = self.clothes_transform(clothes_img) 53 | 54 | return img, clothes_img, fname, clothes_fname, pid, cid, cam, index 55 | -------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import json 10 | import random 11 | import re 12 | 13 | class PreProcessor(Dataset): 14 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 15 | super(PreProcessor, self).__init__() 16 | self.dataset = dataset 17 | self.root = root 18 | self.root_additional = root_additional 19 | self.transform = transform 20 | self.initialized = False 21 | self.clothes_transform = clothes_transform 22 | self.blur_clo = blur_clo 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def __getitem__(self, indices): 28 | return self._get_single_item(indices) 29 | 30 | 31 | def _get_single_item(self, index): 32 | fname, attr_fname, pid, cid, cam = self.dataset[index] 33 | fpath = fname 34 | attr_item = 'cross modiality' 35 | if int(pid)==-1: 36 | if self.root_additional is not None: 37 | fpath = os.path.join(self.root_additional, fname) 38 | else: 39 | if self.root is not None: 40 | fpath = os.path.join(self.root, fname) 41 | 42 | img = Image.open(fpath).convert('RGB') 43 | attribute = pre_caption(attr_item, 50) 44 | 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | 48 | return img, attribute, fname, 
attr_fname, pid, cid, cam, index 49 | 50 | def pre_caption(caption, max_words): 51 | caption = re.sub( 52 | r"([,.'!?\"()*#:;~])", 53 | '', 54 | caption.lower(), 55 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 56 | caption = re.sub( 57 | r"\s{2,}", 58 | ' ', 59 | caption, 60 | ) 61 | caption = caption.rstrip('\n') 62 | caption = caption.strip(' ') 63 | # truncate caption 64 | caption_words = caption.split(' ') 65 | if len(caption_words)>max_words: 66 | caption = ' '.join(caption_words[:max_words]) 67 | return caption -------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_ctcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | class PreProcessor(Dataset): 11 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 12 | super(PreProcessor, self).__init__() 13 | self.dataset = dataset 14 | self.root = root 15 | self.root_additional = root_additional 16 | self.transform = transform 17 | self.initialized = False 18 | self.clothes_transform = clothes_transform 19 | self.blur_clo = blur_clo 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def __getitem__(self, indices): 24 | return self._get_single_item(indices) 25 | 26 | 27 | def _get_single_item(self, index): 28 | fname, clothes_fname, pid, cid, cam = self.dataset[index] 29 | fpath = fname 30 | clothes_path = clothes_fname 31 | 32 | if int(pid)==-1: 33 | if self.root_additional is not None: 34 | fpath = os.path.join(self.root_additional, fname) 35 | clothes_path = os.path.join(self.root_additional, clothes_fname) 36 | else: 37 | if self.root is not None: 38 | fpath = os.path.join(self.root, fname) 39 | clothes_path = 
os.path.join(self.root, clothes_fname) 40 | 41 | img = Image.open(fpath).convert('RGB') 42 | clothes_img = Image.open(clothes_path).convert('RGB') 43 | 44 | if self.blur_clo: 45 | clothes_img = cv2.cvtColor(np.asarray(clothes_img),cv2.COLOR_RGB2BGR) 46 | kernel_size = (3, 3) 47 | sigma = 1.5 48 | clothes_img = cv2.GaussianBlur(clothes_img, kernel_size, sigma) 49 | clothes_img = Image.fromarray(cv2.cvtColor(clothes_img, cv2.COLOR_BGR2RGB)) 50 | #print(clothes_img.size) 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | clothes_img = self.clothes_transform(clothes_img) 54 | 55 | 56 | return img, clothes_img, fname, clothes_fname, pid, cid, cam, index 57 | -------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_sc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import json 10 | import random 11 | import re 12 | 13 | class PreProcessor(Dataset): 14 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 15 | super(PreProcessor, self).__init__() 16 | self.dataset = dataset 17 | self.root = root 18 | self.root_additional = root_additional 19 | self.transform = transform 20 | self.initialized = False 21 | self.clothes_transform = clothes_transform 22 | self.blur_clo = blur_clo 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def __getitem__(self, indices): 28 | return self._get_single_item(indices) 29 | 30 | 31 | def _get_single_item(self, index): 32 | fname, attr_fname, pid, cid, cam = self.dataset[index] 33 | fpath = fname 34 | attr_item = 'do not change clothes' 35 | if int(pid)==-1: 36 | if self.root_additional is not None: 37 | fpath = os.path.join(self.root_additional, 
fname) 38 | else: 39 | if self.root is not None: 40 | fpath = os.path.join(self.root, fname) 41 | 42 | img = Image.open(fpath).convert('RGB') 43 | attribute = pre_caption(attr_item, 50) 44 | 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | 48 | return img, attribute, fname, attr_fname, pid, cid, cam, index 49 | 50 | def pre_caption(caption, max_words): 51 | caption = re.sub( 52 | r"([,.'!?\"()*#:;~])", 53 | '', 54 | caption.lower(), 55 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 56 | caption = re.sub( 57 | r"\s{2,}", 58 | ' ', 59 | caption, 60 | ) 61 | caption = caption.rstrip('\n') 62 | caption = caption.strip(' ') 63 | # truncate caption 64 | caption_words = caption.split(' ') 65 | if len(caption_words)>max_words: 66 | caption = ' '.join(caption_words[:max_words]) 67 | return caption -------------------------------------------------------------------------------- /reid/datasets/data/preprocessor_t2i.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | 5 | from PIL import Image, ImageFilter 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import json 10 | import re 11 | 12 | class PreProcessor(Dataset): 13 | def __init__(self, dataset, json_list=None, root=None, root_additional=None, transform=None, clothes_transform=None, blur_clo=False): 14 | super(PreProcessor, self).__init__() 15 | self.dataset = dataset 16 | self.root = root 17 | self.root_additional = root_additional 18 | self.transform = transform 19 | self.initialized = False 20 | self.clothes_transform = clothes_transform 21 | self.blur_clo = blur_clo 22 | attr_file = open(json_list, 'r', encoding='utf-8') 23 | self.attr_dict = json.load(attr_file) 24 | 25 | 26 | def __len__(self): 27 | return len(self.dataset) 28 | 29 | def __getitem__(self, indices): 30 | return self._get_single_item(indices) 31 | 32 | def _get_single_item(self, 
index): 33 | fname, attr_fname, pid, cid, cam = self.dataset[index] 34 | fpath = fname 35 | 36 | attr_item = self.attr_dict[attr_fname] 37 | if int(pid)==-1: 38 | if self.root_additional is not None: 39 | fpath = os.path.join(self.root_additional, fname) 40 | else: 41 | if self.root is not None: 42 | fpath = os.path.join(self.root, fname) 43 | 44 | 45 | img = Image.open(fpath).convert('RGB') 46 | attribute = pre_caption(attr_item, 50) 47 | 48 | if self.transform is not None: 49 | img = self.transform(img) 50 | 51 | 52 | return img, attribute, fname, attr_fname, pid, cid, cam, index 53 | 54 | 55 | def pre_caption(caption, max_words): 56 | caption = re.sub( 57 | r"([,.'!?\"()*#:;~])", 58 | '', 59 | caption.lower(), 60 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 61 | caption = re.sub( 62 | r"\s{2,}", 63 | ' ', 64 | caption, 65 | ) 66 | caption = caption.rstrip('\n') 67 | caption = caption.strip(' ') 68 | # truncate caption 69 | caption_words = caption.split(' ') 70 | if len(caption_words)>max_words: 71 | caption = ' '.join(caption_words[:max_words]) 72 | return caption -------------------------------------------------------------------------------- /reid/datasets/image_layer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | from reid.datasets.data.base_dataset import BaseImageDataset 4 | 5 | 6 | class Image_Layer(BaseImageDataset): 7 | def __init__(self, image_list, image_list_additional=None, is_train=False, is_query=False, is_gallery=False, verbose=True): 8 | super(Image_Layer, self).__init__() 9 | imgs, clothes, pids, cids, cams =[], [], [], [], [] 10 | with open(image_list) as f: 11 | for line in f.readlines(): 12 | info = line.strip('\n').split(" ") 13 | imgs.append(info[0]) 14 | clothes.append(info[1]) 15 | pids.append(int(info[2])) 16 | cids.append(int(info[3])) 17 | if len(info) >4: 18 | cams.append(int(info[4])) 19 | elif is_train: 20 | 
cams.append(0) 21 | elif is_query: 22 | cams.append(-1) 23 | else: 24 | cams.append(-2) 25 | if image_list_additional is not None: 26 | with open(image_list_additional) as f: 27 | for line in f.readlines(): 28 | info = line.strip('\n').split(" ") 29 | imgs.append(info[0]) 30 | clothes.append(info[1]) 31 | pids.append(int(info[2])) 32 | cids.append(int(info[3])) 33 | if len(info) >4: 34 | cams.append(int(info[4])) 35 | elif is_train: 36 | cams.append(0) 37 | elif is_query: 38 | cams.append(-1) 39 | else: 40 | cams.append(-2) 41 | 42 | if is_train: 43 | pids = self._relabel(pids) 44 | 45 | self.data = list(zip(imgs, clothes, pids, cids, cams)) 46 | self.num_classes, self.num_imgs, self.num_cids, self.num_cams = self.get_imagedata_info(self.data) 47 | 48 | if verbose: 49 | print("=> {} Dataset information has been loaded.".format(image_list)) 50 | if is_train: 51 | self.print_dataset_statistics(self.data, 'train') 52 | if is_gallery: 53 | self.print_dataset_statistics(self.data, 'gallery') 54 | if is_query: 55 | self.print_dataset_statistics(self.data, 'query') 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /reid/datasets/image_layer_multi.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | from reid.datasets.data.base_dataset import BaseImageDataset 4 | 5 | class Image_Layer(BaseImageDataset): 6 | def __init__(self, image_list, image_list_name, image_list_additional=None, is_train=False, is_query=False, is_gallery=False, verbose=True): 7 | super(Image_Layer, self).__init__() 8 | imgs, clothes, pids, cids, cams =[], [], [], [], [] 9 | for line in image_list: 10 | info = line.strip('\n').split(" ") 11 | imgs.append(info[0]) 12 | clothes.append(info[1]) 13 | pids.append(int(info[2])) 14 | cids.append(int(info[3])) 15 | if len(info) >4: 16 | cams.append(int(info[4])) 17 | elif is_train: 18 | cams.append(0) 19 | elif 
is_query: 20 | cams.append(-1) 21 | else: 22 | cams.append(-2) 23 | if image_list_additional is not None: 24 | for line in image_list_additional: 25 | info = line.strip('\n').split(" ") 26 | imgs.append(info[0]) 27 | clothes.append(info[1]) 28 | pids.append(int(info[2])) 29 | cids.append(int(info[3])) 30 | if len(info) >4: 31 | cams.append(int(info[4])) 32 | elif is_train: 33 | cams.append(0) 34 | elif is_query: 35 | cams.append(-1) 36 | else: 37 | cams.append(-2) 38 | 39 | if is_train: 40 | pids = self._relabel(pids) 41 | 42 | self.data = list(zip(imgs, clothes, pids, cids, cams)) 43 | self.num_classes, self.num_imgs, self.num_cids, self.num_cams = self.get_imagedata_info(self.data) 44 | 45 | if verbose: 46 | print("=> {} Dataset information has been loaded.".format(image_list_name)) 47 | if is_train: 48 | self.print_dataset_statistics(self.data, 'train') 49 | if is_gallery: 50 | self.print_dataset_statistics(self.data, 'gallery') 51 | if is_query: 52 | self.print_dataset_statistics(self.data, 'query') 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /reid/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators.cpython-38.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators_cl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators_cl.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators_d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators_d.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators_n.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators_n.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators_t.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators_t.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation/__pycache__/evaluators_t.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/evaluation/__pycache__/evaluators_t.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .adaptive_triplet import TripletLoss, SoftTripletLoss 4 | from .crossentropy import CrossEntropyLabelSmooth, SoftEntropy 5 | from .transloss import TransLoss 6 | from .adv_loss import ClothesBasedAdversarialLoss, CosFaceLoss 7 | __all__ = [ 8 | 'TripletLoss', 9 | 'CrossEntropyLabelSmooth', 10 | 'SoftTripletLoss', 11 | 'SoftEntropy', 12 | 'TransLoss', 13 | 'ClothesBasedAdversarialLoss', 14 | 'CosFaceLoss' 15 | ] 16 | -------------------------------------------------------------------------------- /reid/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /reid/loss/__pycache__/adaptive_triplet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/adaptive_triplet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/adaptive_triplet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/adaptive_triplet.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/adv_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/adv_loss.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/adv_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/adv_loss.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/crossentropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/crossentropy.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/crossentropy.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/crossentropy.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/dual_causality_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/dual_causality_loss.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/dual_causality_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/dual_causality_loss.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/transloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/transloss.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/transloss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/loss/__pycache__/transloss.cpython-38.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/triplet.cpython-36.pyc: -------------------------------------------------------------------------------- 
import torch
import torch.nn.functional as F
from torch import nn
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    Unlike a plain ``dist.all_gather``, gradients flow back to the local
    input: ``backward`` keeps only the gradient slice that corresponds to
    the tensor this rank contributed in ``forward``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)

        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)

        # Keep only the gradient w.r.t. this rank's own contribution.
        grad_out[:] = grads[dist.get_rank()]

        return grad_out


class CosFaceLoss(nn.Module):
    """CosFace Loss based on the predictions of classifier.

    Reference:
        Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition.
        In CVPR, 2018.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """

    def __init__(self, scale=16, margin=0.1, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        # Subtract the margin only from the ground-truth class logit.
        one_hot = torch.zeros_like(inputs)
        one_hot.scatter_(1, targets.view(-1, 1), 1.0)

        output = self.s * (inputs - one_hot * self.m)

        return F.cross_entropy(output, targets)


class ClothesBasedAdversarialLoss(nn.Module):
    """Clothes-based Adversarial Loss.

    Reference:
        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only.
        In CVPR, 2022.

    Args:
        scale (float): scaling factor.
        epsilon (float): a trade-off hyper-parameter.
    """

    def __init__(self, scale=16, epsilon=0.1):
        super().__init__()
        self.scale = scale
        self.epsilon = epsilon

    def forward(self, inputs, targets, positive_mask):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
            positive_mask: positive mask matrix with shape (batch_size, num_classes).
                Entries are 1 for clothes classes sharing the anchor's identity
                (positive clothes classes) and 0 otherwise.
        """
        inputs = self.scale * inputs
        negative_mask = 1 - positive_mask
        # Build the one-hot identity mask directly on the inputs' device; the
        # previous CPU round-trip plus hard-coded .cuda() broke CPU-only runs.
        identity_mask = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1.0)

        exp_logits = torch.exp(inputs)
        # Denominator: each logit competes against all negative-clothes logits.
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negative_mask).sum(1, keepdim=True) + exp_logits)
        log_prob = inputs - log_sum_exp_pos_and_all_neg

        # Label smoothing spread over the positive clothes classes.
        mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask
        loss = (- mask * log_prob).sum(1).mean()

        return loss
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import *


class CrossEntropyLabelSmooth(nn.Module):
    """
    Cross entropy loss with label smoothing regularizer.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
    """

    def __init__(self, num_classes, epsilon=0.1):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        # LogSoftmax has no parameters, so the previous hard-coded .cuda()
        # was a no-op; dropping it keeps the module device-agnostic.
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        log_probs = self.logsoftmax(inputs)
        # One-hot encode the targets, then smooth towards uniform.
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).mean(0).sum()
        return loss


class SoftEntropy(nn.Module):
    # Cross entropy of `inputs` against the (detached) softmax of `targets`.
    def __init__(self):
        super(SoftEntropy, self).__init__()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # detach(): the teacher distribution receives no gradient.
        loss = (- self.softmax(targets).detach() * log_probs).mean(0).sum()
        return loss
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class DualCausalityLoss(nn.Module):
    """Dual causality ranking loss.

    Given three feature matrices (f, fp, fm) for the same batch, one random
    positive and one random negative are sampled per anchor; fp is pushed to
    be closer than f on anchor-positive pairs (and farther on negatives),
    and symmetrically fm is pushed the other way.
    """

    def __init__(self):
        super(DualCausalityLoss, self).__init__()

    def forward(self, s_dual, label):
        """
        Args:
            s_dual: tuple (f, fp, fm) of feature matrices, each (N, D).
            label: identity labels with shape (N,). Every anchor must have at
                least one other sample with the same label in the batch.
        """
        f, fp, fm = s_dual
        pos, negs = self._sample_triplet(label)
        f_ap, f_an = self._forward(f, pos, negs)
        fp_ap, fp_an = self._forward(fp, pos, negs)
        fm_ap, fm_an = self._forward(fm, pos, negs)

        l1 = torch.mean(self.soft_plus(fp_ap - f_ap)) + torch.mean(self.soft_plus(f_an - fp_an))
        l2 = torch.mean(self.soft_plus(f_ap - fm_ap)) + torch.mean(self.soft_plus(fm_an - f_an))
        return l1 + l2

    def _forward(self, f: torch.Tensor, pos, negs):
        # Collect anchor-positive / anchor-negative distances at the sampled
        # indices.
        n = f.shape[0]
        dist = self.pairwise_distance(f)
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][pos[i]].unsqueeze(dim=0))
            dist_an.append(dist[i][negs[i]].unsqueeze(dim=0))
        dist_ap = torch.cat(dist_ap, dim=0)
        dist_an = torch.cat(dist_an, dim=0)
        return dist_ap, dist_an

    @staticmethod
    def _sample_triplet(label):
        label = label.view(-1, 1)
        n = label.shape[0]
        mask = label.expand(n, n).eq(label.expand(n, n).t())
        mask = mask.detach().cpu().numpy()

        pos, negs = [], []
        for i in range(n):
            # Exclude the anchor itself up front instead of re-sampling in a
            # loop; raises IndexError (rather than hanging forever) if an
            # anchor has no other positive in the batch.
            pos_candidates = [j for j in np.where(mask[i, :] == 1)[0] if j != i]
            pos.append(random.choice(pos_candidates))
            neg_indices = np.where(mask[i, :] == 0)
            negs.append(random.choice(list(neg_indices[0])))
        return pos, negs

    @staticmethod
    def pairwise_distance(x: torch.Tensor):
        # Cosine-based distance, shifted so identical directions give -0.5.
        x = F.normalize(x)
        cosine = torch.matmul(x, x.t())
        distmat = -cosine + 0.5
        return distmat

    @staticmethod
    def soft_plus(x):
        # F.softplus is numerically stable; the previous log(1 + exp(x))
        # overflows to inf for large x.
        return F.softplus(x)
from __future__ import absolute_import

import torch
from torch import nn
import torch.nn.functional as F


def euclidean_dist(x, y):
    """Pairwise euclidean distance between rows of x (m, d) and y (n, d)."""
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # Keyword beta/alpha form: the positional signature
    # addmm_(beta, alpha, mat1, mat2) is deprecated and removed in recent
    # PyTorch releases.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def cosine_dist(x, y):
    """Pairwise cosine distance (1 - cosine similarity) between rows."""
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
    cosine = frac_up / frac_down
    return 1 - cosine


def _batch_hard(mat_distance, mat_similarity, indice=False):
    """Hardest positive (largest distance) and hardest negative (smallest)
    per row; optionally also their indices."""
    # Push non-positives to -inf so the hardest positive sorts first.
    sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1,
                                                       descending=True)
    hard_p = sorted_mat_distance[:, 0]
    hard_p_indice = positive_indices[:, 0]
    # Push positives to +inf so the hardest negative sorts first.
    sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1,
                                                       descending=False)
    hard_n = sorted_mat_distance[:, 0]
    hard_n_indice = negative_indices[:, 0]
    if indice:
        return hard_p, hard_n, hard_p_indice, hard_n_indice
    return hard_p, hard_n


class TransLoss(nn.Module):
    """Margin ranking loss over explicit (anchor, positive, negative) triplets."""

    def __init__(self, margin, normalize_feature=False):
        super(TransLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        # MarginRankingLoss is parameter-free; the previous .cuda() was a no-op.
        self.margin_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, emb, emb_pos, emb_neg):
        """Returns (loss, prec) where prec is the fraction of triplets with
        dist(anchor, neg) > dist(anchor, pos)."""
        if self.normalize_feature:
            # equal to cosine similarity
            emb = F.normalize(emb)
            emb_pos = F.normalize(emb_pos)
            emb_neg = F.normalize(emb_neg)
        dist_pos = torch.diag(euclidean_dist(emb, emb_pos))
        dist_neg = torch.diag(euclidean_dist(emb, emb_neg))

        y = torch.ones_like(dist_pos)
        loss = self.margin_loss(dist_neg, dist_pos, y)
        prec = (dist_neg.data > dist_pos.data).sum() * 1. / y.size(0)
        return loss, prec


class SoftTripletLoss(nn.Module):
    """Batch-hard triplet loss with a softmax formulation.

    With ``margin`` set, uses a fixed soft target; otherwise uses emb2's
    hard-pair distances as the (detached) reference distribution.
    """

    def __init__(self, margin=None, normalize_feature=False):
        super(SoftTripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, emb1, emb2, label):
        if self.normalize_feature:
            # equal to cosine similarity
            emb1 = F.normalize(emb1)
            emb2 = F.normalize(emb2)

        mat_dist = euclidean_dist(emb1, emb1)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()

        dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
        assert dist_an.size(0) == dist_ap.size(0)
        triple_dist = torch.stack((dist_ap, dist_an), dim=1)
        triple_dist = self.logsoftmax(triple_dist)
        if self.margin is not None:
            loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
            return loss

        # Reference branch: look up emb2's distances at emb1's hard indices.
        mat_dist_ref = euclidean_dist(emb2, emb2)
        dist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N, 1).expand(N, N))[:, 0]
        dist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N, 1).expand(N, N))[:, 0]
        triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)
        triple_dist_ref = self.softmax(triple_dist_ref).detach()

        loss = (- triple_dist_ref * triple_dist).mean(0).sum()
        return loss
from __future__ import absolute_import

import torch
from torch import nn
import torch.nn.functional as F


def euclidean_dist(x, y):
    """Pairwise euclidean distance between rows of x (m, d) and y (n, d)."""
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # Keyword beta/alpha form: the positional signature
    # addmm_(beta, alpha, mat1, mat2) is deprecated and removed in recent
    # PyTorch releases.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def cosine_dist(x, y):
    """Pairwise cosine distance (1 - cosine similarity) between rows."""
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
    cosine = frac_up / frac_down
    return 1 - cosine


def _batch_hard(mat_distance, mat_similarity, indice=False):
    """Hardest positive (largest distance) and hardest negative (smallest)
    per row; optionally also their indices."""
    # Push non-positives to -inf so the hardest positive sorts first.
    sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1,
                                                       descending=True)
    hard_p = sorted_mat_distance[:, 0]
    hard_p_indice = positive_indices[:, 0]
    # Push positives to +inf so the hardest negative sorts first.
    sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1,
                                                       descending=False)
    hard_n = sorted_mat_distance[:, 0]
    hard_n_indice = negative_indices[:, 0]
    if indice:
        return hard_p, hard_n, hard_p_indice, hard_n_indice
    return hard_p, hard_n


class TripletLoss(nn.Module):
    """Batch-hard triplet loss with a margin ranking formulation."""

    def __init__(self, margin, normalize_feature=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        # MarginRankingLoss is parameter-free; the previous .cuda() was a no-op.
        self.margin_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, emb, label, clot_feats_s=None):
        """Returns (loss, prec) where prec is the fraction of anchors whose
        hardest negative is farther than their hardest positive.

        Note: clot_feats_s is unused here; it is kept for interface
        compatibility with callers that pass it.
        """
        if self.normalize_feature:
            # equal to cosine similarity
            emb = F.normalize(emb)
        mat_dist = euclidean_dist(emb, emb)
        # mat_dist = cosine_dist(emb, emb)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()

        dist_ap, dist_an = _batch_hard(mat_dist, mat_sim)
        assert dist_an.size(0) == dist_ap.size(0)
        y = torch.ones_like(dist_ap)
        loss = self.margin_loss(dist_an, dist_ap, y)
        prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
        return loss, prec


class SoftTripletLoss(nn.Module):
    """Batch-hard triplet loss with a softmax formulation.

    With ``margin`` set, uses a fixed soft target; otherwise uses emb2's
    hard-pair distances as the (detached) reference distribution.
    """

    def __init__(self, margin=None, normalize_feature=False):
        super(SoftTripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, emb1, emb2, label):
        if self.normalize_feature:
            # equal to cosine similarity
            emb1 = F.normalize(emb1)
            emb2 = F.normalize(emb2)

        mat_dist = euclidean_dist(emb1, emb1)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()

        dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
        assert dist_an.size(0) == dist_ap.size(0)
        triple_dist = torch.stack((dist_ap, dist_an), dim=1)
        triple_dist = self.logsoftmax(triple_dist)
        if self.margin is not None:
            loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
            return loss

        # Reference branch: look up emb2's distances at emb1's hard indices.
        mat_dist_ref = euclidean_dist(emb2, emb2)
        dist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N, 1).expand(N, N))[:, 0]
        dist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N, 1).expand(N, N))[:, 0]
        triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)
        triple_dist_ref = self.softmax(triple_dist_ref).detach()

        loss = (- triple_dist_ref * triple_dist).mean(0).sum()
        return loss
from __future__ import absolute_import

from .resnet import *
from .resnet_ibn import *
from .resnext_ibn import *
from .se_resnet_ibn import *
from .efficientnet import *
from .mgn import MGN
from .resnet_ibn_snr import *
from .resnet_ibn_two_branch import *
from .transformer import *
from .augmentor import *
from .transformer import Transformer_local, Transformer_DualAttn, Transformer_DualAttn_multi
from .pass_transformer_joint import PASS_Transformer_DualAttn_joint

# Registry mapping model names to their constructors.
__factory = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnet_ibn50a': resnet_ibn50a,
    'resnet_ibn101a': resnet_ibn101a,
    'resnext_ibn101a': resnext101_ibn_a,
    'se_resnet_ibn101a': se_resnet101_ibn_a,
    'efficientnet_b0': efficientnet_b0,
    'efficientnet_b1': efficientnet_b1,
    'efficientnet_b2': efficientnet_b2,
    'efficientnet_b3': efficientnet_b3,
    'efficientnet_b4': efficientnet_b4,
    'efficientnet_b5': efficientnet_b5,
    'mgn': MGN,
    'resnet_ibn50a_snr': resnet_ibn50a_snr,
    'resnet_ibn101a_snr': resnet_ibn101a_snr,
    'resnet_ibn50a_snr_spatial': resnet_ibn50a_snr_spatial,
    'resnet_ibn101a_snr_spatial': resnet_ibn101a_snr_spatial,
    'resnet_ibn50a_two_branch': resnet_ibn50a_two_branch,
    'resnet_ibn101a_two_branch': resnet_ibn101a_two_branch,
    # 'deit_small_patch16_224_TransReID': deit_small_patch16_224_TransReID,
    # 'deit_small_patch16_224_TransReID_mask': deit_small_patch16_224_TransReID_mask,
    # 'deit_small_patch16_224_TransReID_aug': deit_small_patch16_224_TransReID,
    # 'deit_small_patch16_224_TransReID_mask_aug': deit_small_patch16_224_TransReID_mask,
    'augmentor': Augmentor,
    'transformer': Transformer_local,
    'transformer_dualattn': Transformer_DualAttn,
    'transformer_dualattn_multi': Transformer_DualAttn_multi,
    'PASS_Transformer_DualAttn_joint': PASS_Transformer_DualAttn_joint,
    'transformer_dualattn_joint': PASS_Transformer_DualAttn_joint,
}


def names():
    """Return the sorted list of registered model names."""
    return sorted(__factory.keys())


def create(name, *args, **kwargs):
    """
    Create a model instance.

    Parameters
    ----------
    name : str
        Model name. Must be one of the keys in the registry (see ``names()``),
        e.g. 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'.
    num_classes : int, optional
        If positive, will append a Linear layer at the end as the classifier
        with this number of output units. Default: 0
    net_config : ArgumentParser

    Raises
    ------
    KeyError
        If ``name`` is not a registered model.
    """
    if name not in __factory:
        raise KeyError('Unknown model: {}'.format(name))
    # Forward positional arguments too; they were silently dropped before,
    # despite the *args in the signature.
    return __factory[name](*args, **kwargs)
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/augmentor.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/efficientnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/efficientnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/efficientnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/efficientnet.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/hap_transformer_joint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/hap_transformer_joint.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mgn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/mgn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mgn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/mgn.cpython-38.pyc 
-------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_cl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_cl.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_d.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_gene.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_gene.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_gene.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_gene.cpython-38.pyc -------------------------------------------------------------------------------- 
/reid/models/__pycache__/pass_transformer_joint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_joint.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_joint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_joint.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_pass.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_pass.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_pass.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_pass.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_t.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_t.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_t_q.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_t_q.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/pass_transformer_t_q.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/pass_transformer_t_q.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn_snr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn_snr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn_snr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn_snr.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn_two_branch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn_two_branch.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet_ibn_two_branch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnet_ibn_two_branch.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnext_ibn.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnext_ibn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnext_ibn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/resnext_ibn.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/se_resnet_ibn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/se_resnet_ibn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/se_resnet_ibn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/se_resnet_ibn.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/swin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/swin.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/swin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/swin.cpython-38.pyc 
-------------------------------------------------------------------------------- /reid/models/__pycache__/tokenization_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/tokenization_bert.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/tokenization_bert.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/tokenization_bert.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transformer_attr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/transformer_attr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transformer_attr_new_cross.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/transformer_attr_new_cross.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/xbert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/xbert.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/xbert.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/__pycache__/xbert.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/augmentor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | import math 9 | 10 | def init_weights(m): 11 | if isinstance(m, nn.Linear): 12 | nn.init.kaiming_normal_(m.weight) 13 | 14 | 15 | class Linear(nn.Module): 16 | def __init__(self, linear_size): 17 | super(Linear, self).__init__() 18 | self.l_size = linear_size 19 | 20 | self.relu = nn.LeakyReLU(inplace=True) 21 | 22 | self.w1 = nn.Linear(self.l_size, self.l_size) 23 | self.batch_norm1 = nn.BatchNorm1d(self.l_size) 24 | 25 | self.w2 = nn.Linear(self.l_size, self.l_size) 26 | self.batch_norm2 = nn.BatchNorm1d(self.l_size) 27 | 28 | def forward(self, x): 29 | y = self.w1(x) 30 | y = self.batch_norm1(y) 31 | y = self.relu(y) 32 | 33 | y = self.w2(y) 34 | y = 
self.batch_norm2(y) 35 | y = self.relu(y) 36 | 37 | return y 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 45 | 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv3(out) 68 | out = self.bn3(out) 69 | 70 | if self.downsample is not None: 71 | residual = self.downsample(x) 72 | 73 | out += residual 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Augmentor(nn.Module): 80 | 81 | def __init__(self, num_classes=1000, net_config=None): 82 | inplane = 256 83 | planes = 64 84 | self.num_classes = num_classes 85 | super(Augmentor, self).__init__() 86 | self.conv1 = nn.Conv2d(3, inplane, kernel_size=7, stride=2, padding=3, 87 | bias=False) 88 | self.bn1 = nn.BatchNorm2d(inplane) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 91 | self.layer = Bottleneck(inplane,planes) 92 | self.gap = nn.AdaptiveAvgPool2d(1) 93 | if self.num_classes > 0: 94 | 95 | self.classifier = nn.Linear(inplane, self.num_classes, bias=False) 96 | init.normal_(self.classifier.weight, std=0.001) 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | elif isinstance(m, nn.InstanceNorm2d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | 109 | def forward(self, x): 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = self.relu(x) 113 | x = self.maxpool(x) 114 | 115 | x = self.layer(x) 116 | 117 | x = self.gap(x) 118 | x = x.view(x.size(0), -1) 119 | x = self.classifier(x) 120 | x = F.sigmoid(x) 121 | return x -------------------------------------------------------------------------------- /reid/models/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__init__.py -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/ckpt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/ckpt.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/ckpt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/ckpt.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/hap_vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/hap_vit.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/pass_vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/pass_vit.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/pass_vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/pass_vit.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/pass_vit_adapter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/pass_vit_adapter.cpython-36.pyc -------------------------------------------------------------------------------- 
/reid/models/backbone/__pycache__/pass_vit_adapter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/pass_vit_adapter.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/resnet_ibn_a.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/resnet_ibn_a.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/resnet_ibn_a.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/resnet_ibn_a.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/resnext_ibn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/resnext_ibn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/resnext_ibn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/resnext_ibn.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/se_resnet_ibn.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/se_resnet_ibn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/se_resnet_ibn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/se_resnet_ibn.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/swin_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/swin_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/swin_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/swin_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit_.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit_albef.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit_albef.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit_albef.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit_albef.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vit_human.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/__pycache__/vit_human.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/__pycache__/vitdet.cpython-36.pyc: -------------------------------------------------------------------------------- 
from functools import partial
from itertools import repeat
import collections.abc
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as checkpoint_train
import timm.models.vision_transformer


__all__ = [
    'vit_base_patch16',
    'vit_large_patch16',
    'vit_huge_patch14'
]


def _ntuple(n):
    """Return a parser that passes iterables through unchanged and repeats
    scalars ``n`` times into a tuple."""
    def parse(x):
        # Fix: ``from torch._six import container_abcs`` was removed in
        # PyTorch 1.9; container_abcs was an alias of collections.abc, so
        # this check is behavior-identical on every PyTorch version.
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """timm ViT with optional global average pooling and local-token output.

    Required kwargs: ``img_size`` and ``patch_size``; with ``global_pool=True``
    also ``norm_layer`` and ``embed_dim`` (used to build ``fc_norm``).
    """

    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)
        self.global_pool = global_pool
        img_size = to_2tuple(kwargs['img_size'])
        patch_size = to_2tuple(kwargs['patch_size'])
        stride_size_tuple = to_2tuple(kwargs['patch_size'])
        # Patch-grid extent along width (num_x) and height (num_y); the stride
        # equals the patch size here, i.e. non-overlapping patches.
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm; forward uses fc_norm instead

    def forward_features(self, x):
        """Return ``(global_feature, patch_tokens)`` for an image batch.

        NOTE(review): ``x + self.pos_embed`` assumes the input resolution
        matches the ``img_size`` the model was built with — confirm at callers.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]
        # NOTE(review): under global_pool=True, x is already pooled to 2-D here,
        # so this slice drops the first feature dim rather than the cls token;
        # the factories below always build with global_pool=False.
        local_outcome = x[:, 1:]

        return outcome, local_outcome

    def forward(self, x):
        x, local = self.forward_features(x)
        # logit = self.head(x)

        return x, local


def vit_base_patch16(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """ViT-Base/16 backbone; extra arguments are accepted only for
    factory-interface parity and are not used by this implementation."""
    model = VisionTransformer(img_size=img_size, num_classes=num_classes,
                              patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=drop_path_rate)
    return model


def vit_large_patch16(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """ViT-Large/16 backbone; extra arguments accepted for interface parity only."""
    model = VisionTransformer(img_size=img_size, num_classes=num_classes,
                              patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=drop_path_rate)
    return model


def vit_huge_patch14(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """ViT-Huge/14 backbone; extra arguments accepted for interface parity only."""
    model = VisionTransformer(img_size=img_size, num_classes=num_classes,
                              patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=drop_path_rate)
    return model
from functools import partial

import timm.models.vision_transformer
import torch
import torch.nn as nn
from timm.models.layers import create_classifier
from timm.models.resnet import Bottleneck, downsample_conv

__all__ = [
    'lem_base_patch16',
    'lem_large_patch16',
    'lem_huge_patch14'
]


class LocalityEnhancedModule(timm.models.vision_transformer.VisionTransformer):
    """timm ViT augmented with a convolutional local branch.

    Patch tokens are folded back into a (B, C, H, W) map and passed through a
    small ResNet-style head; the cls-token branch and the conv branch each get
    their own BN + classifier.
    """

    def __init__(self, global_pool=False, **kwargs):
        super(LocalityEnhancedModule, self).__init__(**kwargs)

        dim = kwargs['embed_dim']
        # Conv branch: one strided bottleneck (C -> 2C) plus two more at 2C.
        self.conv_head = nn.Sequential(
            Bottleneck(dim, dim // 2, stride=2,
                       downsample=downsample_conv(dim, dim * 2, kernel_size=1, stride=2)),
            Bottleneck(dim * 2, dim // 2, stride=1),
            Bottleneck(dim * 2, dim // 2, stride=1)
        )
        self.conv_global_pool, self.conv_fc = create_classifier(dim * 2, self.num_classes)
        self.conv_fc_norm = kwargs['norm_layer'](dim * 2)
        self.conv_bn = nn.BatchNorm1d(dim * 2)
        self.conv_bn.bias.requires_grad_(False)

        self.bn = nn.BatchNorm1d(kwargs['embed_dim'])
        self.bn.bias.requires_grad_(False)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # replaced by fc_norm in the pooled path

    def forward_features(self, x):
        """Return ``(token_feature, conv_feature)`` for an image batch."""
        batch = x.shape[0]
        channels = self.embed_dim
        grid_h, grid_w = self.patch_embed.grid_size
        tokens = self.patch_embed(x)

        cls_tok = self.cls_token.expand(batch, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        tokens = torch.cat((cls_tok, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        # Fold the patch tokens back into a (B, C, H, W) feature map and run
        # the convolutional head plus its pooled, normalized projection.
        fmap = tokens[:, 1:, :].transpose(1, 2).reshape(batch, channels, grid_h, grid_w)
        fmap = self.conv_head(fmap)
        fmap = self.conv_global_pool(fmap)
        fmap = self.conv_fc_norm(fmap)

        if self.global_pool:
            pooled = tokens[:, 1:, :].mean(dim=1)  # global pool without cls token
            return self.fc_norm(pooled), fmap
        normed = self.norm(tokens)
        return normed[:, 0], fmap

    def forward(self, x):
        feats = self.forward_features(x)  # tuple: (token feature, conv feature)
        if self.head_dist is None:
            logit = (self.head(self.bn(feats[0])), self.conv_fc(self.conv_bn(feats[1])))
            return feats, logit
        out, out_dist = self.head(feats[0]), self.head_dist(feats[1])
        if self.training and not torch.jit.is_scripting():
            return out, out_dist
        # during inference, return the average of both classifier predictions
        return (out + out_dist) / 2


def _build_lem(num_classes, img_size, drop_path_rate, patch_size, embed_dim, depth, num_heads):
    """Shared constructor for the LEM factory functions below."""
    return LocalityEnhancedModule(
        img_size=img_size, num_classes=num_classes,
        patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=drop_path_rate)


def lem_base_patch16(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """LEM ViT-Base/16; extra arguments are accepted for interface parity only."""
    return _build_lem(num_classes, img_size, drop_path_rate, 16, 768, 12, 12)


def lem_large_patch16(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """LEM ViT-Large/16; extra arguments are accepted for interface parity only."""
    return _build_lem(num_classes, img_size, drop_path_rate, 16, 1024, 24, 16)


def lem_huge_patch14(num_classes, img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs):
    """LEM ViT-Huge/14; extra arguments are accepted for interface parity only."""
    return _build_lem(num_classes, img_size, drop_path_rate, 14, 1280, 32, 16)
class IBN(nn.Module):
    """Instance-Batch Normalization layer.

    Splits the channel dimension: the first ``planes // 2`` channels go through
    an affine InstanceNorm2d, the remaining channels through BatchNorm2d, and
    the two halves are concatenated back together.

    Args:
        planes: total number of input channels.
    """

    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = planes // 2
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        # Fix: the original ``torch.split(x, self.half, 1)`` yields equal-sized
        # chunks, so for odd ``planes`` the second chunk had ``half1`` channels
        # while self.BN expects ``planes - half1`` — a runtime error. Explicit
        # per-chunk sizes split correctly for both even and odd channel counts.
        out1, out2 = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        out1 = self.IN(out1.contiguous())
        out2 = self.BN(out2.contiguous())
        return torch.cat((out1, out2), dim=1)
class SNR(nn.Module):
    """Style Normalization and Restitution block.

    Instance-normalizes the input to strip instance-specific style, then uses a
    channel-attention gate (SELayer) to add back only the part of the removed
    residual the gate selects.

    Returns:
        restored: feature map = IN(x) + gate * (x - IN(x)).
        (f, fp, fm): globally pooled vectors of the normalized, restored and
            discarded maps, for auxiliary losses.
    """

    def __init__(self, num_channels):
        super(SNR, self).__init__()
        self.num_features = num_channels
        self.IN = nn.InstanceNorm2d(num_channels, affine=True)
        self.SE = SELayer(num_channels)
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Strip per-instance style; keep what was removed as a residual.
        normalized = self.IN(x)
        residual = x - normalized
        # Channel-wise gate in (0, 1) deciding how much residual to restore.
        gate = self.SE(residual)
        restored = gate * residual + normalized
        discarded = (1 - gate) * residual + normalized

        def pooled(t):
            return self.avg_pooling(t).view(-1, self.num_features)

        return restored, (pooled(normalized), pooled(restored), pooled(discarded))
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/IBN.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/SE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/SE.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/SE.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/SE.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/SNR.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/SNR.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/SNR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/SNR.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/spatial_snr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/spatial_snr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/__pycache__/spatial_snr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/backbone/modules/__pycache__/spatial_snr.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/backbone/modules/spatial_snr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SELayer(nn.Module): 6 | def __init__(self, num_channels, reduction=16): 7 | super(SELayer, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.fc = nn.Sequential( 10 | nn.Linear(num_channels, int(num_channels / reduction), bias=False), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(int(num_channels / reduction), num_channels, bias=False), 13 | nn.Sigmoid() 14 | ) 15 | 
16 | def forward(self, x): 17 | b, c, _, _ = x.size() 18 | y = self.avg_pool(x).view(b, c) 19 | y = self.fc(y).view(b, c, 1, 1) 20 | return y 21 | 22 | 23 | class SpatialSNR(nn.Module): 24 | def __init__(self, num_channels): 25 | super(SpatialSNR, self).__init__() 26 | self.num_channels = num_channels 27 | self.channel_mask = SELayer(self.num_channels) 28 | self.spatial_mask = nn.Sequential( 29 | nn.Conv2d(self.num_channels, 128, 1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(128, 1, 1), 32 | nn.Sigmoid()) 33 | 34 | self.IN = nn.InstanceNorm2d(self.num_channels, affine=False) 35 | self.avg_pooling = nn.AdaptiveAvgPool2d(1) 36 | 37 | def forward(self, x): 38 | in_x = self.IN(x) 39 | r = x - in_x 40 | channel_mask = self.channel_mask(r) 41 | spatial_mask = self.spatial_mask(r) 42 | 43 | p_1 = channel_mask * r 44 | p_2 = spatial_mask * r 45 | 46 | d_1 = (1 - channel_mask) * r 47 | d_2 = (1 - spatial_mask) * r 48 | 49 | x = p_1 - 0.1*d_2 50 | y = p_2 - 0.1*d_1 51 | 52 | m = (x + y) / 2 53 | 54 | f = self.avg_pooling(in_x).view(-1, self.num_channels) 55 | f_p = self.avg_pooling(in_x + m).view(-1, self.num_channels) 56 | f_n = self.avg_pooling(in_x + (d_1 + d_2) / 2).view(-1, self.num_channels) 57 | 58 | return m + in_x, (f, f_p, f_n) 59 | 60 | -------------------------------------------------------------------------------- /reid/models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__init__.py -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/models/layers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/cos_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/cos_layer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/cos_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/cos_layer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/gem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/gem.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/gem.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/gem.cpython-38.pyc -------------------------------------------------------------------------------- /reid/models/layers/__pycache__/metric.cpython-36.pyc: 
# -----------------------------------------------------------------------------
# https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/metric.cpython-36.pyc
# -----------------------------------------------------------------------------
# /reid/models/layers/__pycache__/metric.cpython-38.pyc:
# https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/metric.cpython-38.pyc
# -----------------------------------------------------------------------------
# /reid/models/layers/__pycache__/rbf_layer.cpython-36.pyc:
# https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/rbf_layer.cpython-36.pyc
# -----------------------------------------------------------------------------
# /reid/models/layers/__pycache__/rbf_layer.cpython-38.pyc:
# https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/models/layers/__pycache__/rbf_layer.cpython-38.pyc
# -----------------------------------------------------------------------------
# /reid/models/layers/cos_layer.py:
import torch
import torch.nn as nn
from torch.nn import Parameter


def cosine_sim(x1, x2, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of x1 and the rows of x2.

    Args:
        x1: (m, d) tensor.
        x2: (n, d) tensor.
        dim: dimension along which the L2 norms are computed.
        eps: lower clamp on the norm product to avoid division by zero.

    Returns:
        (m, n) tensor of cosine similarities.
    """
    ip = torch.mm(x1, x2.t())
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    # torch.ger: outer product of the two norm vectors -> (m, n) denominator.
    return ip / torch.ger(w1, w2).clamp(min=eps)


class MarginCosineProduct(nn.Module):
    r"""Implement of large margin cosine distance (CosFace-style head):

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature 30~64
        m: margin 0.2~0.7
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(MarginCosineProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Scaled cosine logits with margin m subtracted at the target class.

        Args:
            input: (batch, in_features) feature tensor.
            label: (batch,) class-index tensor.

        Returns:
            (batch, out_features) logits.
        """
        cosine = cosine_sim(input, self.weight)
        one_hot = torch.zeros_like(cosine)
        # .long() added for consistency with MarginRBFLogits, which also
        # coerces the label dtype before scatter_.
        one_hot.scatter_(1, label.view(-1, 1).long(), 1.0)
        output = self.s * (cosine - one_hot * self.m)
        return output

# -----------------------------------------------------------------------------
# /reid/models/layers/gem.py:
# encoding: utf-8
import torch
import torch.nn.functional as F
from torch import nn


class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
        - At p = infinity, one gets Max Pooling
        - At p = 1, one gets Average Pooling
    The output is of size H x W, for any input size.
    The number of output features is equal to the number of input planes.
    Args:
        output_size: the target output size of the image of the form H x W.
                     Can be a tuple (H, W) or a single H for a square image H x H
                     H and W can be either a ``int``, or ``None`` which means the size will
                     be the same as that of the input.
    """

    def __init__(self, norm, output_size=1, eps=1e-6):
        super(GeneralizedMeanPooling, self).__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        # eps clamp keeps pow well-defined for non-positive activations.
        self.eps = eps

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + str(self.p) + ', ' \
               + 'output_size=' + str(self.output_size) + ')'


class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """ Same, but norm is trainable
    """

    def __init__(self, norm=3, output_size=1, eps=1e-6):
        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
        # Overrides the float p with a learnable scalar parameter.
        self.p = nn.Parameter(torch.ones(1) * norm)


class AdaptiveAvgMaxPool2d(nn.Module):
    """Sum of adaptive average pooling and adaptive max pooling."""

    def __init__(self, output_size):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        # BUGFIX (naming only): the original bound the *average* pool to x_max
        # and the *max* pool to x_avg. The returned sum is unchanged.
        x_avg = F.adaptive_avg_pool2d(x, self.output_size)
        x_max = F.adaptive_max_pool2d(x, self.output_size)
        x = x_avg + x_max
        return x

# -----------------------------------------------------------------------------
# /reid/models/layers/rbf_layer.py:
import torch
import torch.nn as nn
from torch.nn import Parameter

def l2_dist(x, y):
    """Squared Euclidean distance between every row of x and every row of y.

    Args:
        x: (m, d) tensor.
        y: (n, d) tensor.

    Returns:
        (m, n) tensor of squared distances (may be a hair negative from
        floating-point rounding; callers only exponentiate it).
    """
    m, n = x.size(0), y.size(0)
    # ||x||^2 + ||y||^2 - 2 x.y^T, expanded to an (m, n) grid.
    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
             torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_m = dist_m - 2 * torch.mm(x, y.t())
    return dist_m


class RBFLogits(nn.Module):
    """RBF-kernel classification head: logits = scale * exp(-d^2 / gamma).

    Suggested ranges from the original author: scale 8~35, gamma 2/4~16.
    """

    def __init__(self, feature_dim, class_num, scale=16.0, gamma=8.0):
        super(RBFLogits, self).__init__()
        self.feature_dim = feature_dim
        self.class_num = class_num
        self.weight = nn.Parameter(torch.FloatTensor(class_num, feature_dim))
        self.scale = scale
        self.gamma = gamma
        nn.init.xavier_uniform_(self.weight)

    def forward(self, feat, label):
        """Return (batch, class_num) RBF logits.

        `label` is accepted but unused here; it is kept so this head is
        call-compatible with the margin-based heads in this package.
        """
        kernal_metric = l2_dist(feat, self.weight)
        kernal_metric = torch.exp(-1.0 * kernal_metric / self.gamma)
        logits = self.scale * kernal_metric
        return logits


class MarginRBFLogits(nn.Module):
    """RBF-kernel head with an additive margin at the target class (training only)."""

    def __init__(self, feature_dim, class_num, scale=35.0, gamma=16.0, margin=0.1):
        super(MarginRBFLogits, self).__init__()
        self.feature_dim = feature_dim
        self.class_num = class_num
        self.weight = nn.Parameter(torch.FloatTensor(class_num, feature_dim))
        self.scale = scale
        self.gamma = gamma
        self.margin = margin
        nn.init.xavier_uniform_(self.weight)

    def forward(self, feat, label):
        """Margin logits in train mode; plain scaled kernel in eval mode."""
        metric = l2_dist(feat, self.weight)
        kernal_metric = torch.exp(-1.0 * metric / self.gamma)

        if self.training:
            phi = kernal_metric - self.margin
            # BUGFIX: was torch.zeros(...).cuda(), which crashed on CPU-only
            # runs; zeros_like follows feat's device and dtype.
            one_hot = torch.zeros_like(kernal_metric)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            train_logits = self.scale * ((one_hot * phi) + ((1.0 - one_hot) * kernal_metric))
            return train_logits
        else:
            test_logits = self.scale * kernal_metric
            return test_logits
-------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | from reid.models.layers.metric import build_metric 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | class ResNet(nn.Module): 14 | __factory = { 15 | 18: torchvision.models.resnet18, 16 | 34: torchvision.models.resnet34, 17 | 50: torchvision.models.resnet50, 18 | 101: torchvision.models.resnet101, 19 | 152: torchvision.models.resnet152, 20 | } 21 | 22 | def __init__(self, depth, num_classes=0, net_config=None): 23 | super(ResNet, self).__init__() 24 | self.depth = depth 25 | self.net_config = net_config 26 | 27 | # Construct base (pretrained) resnet 28 | if depth not in ResNet.__factory: 29 | raise KeyError("Unsupported depth:", depth) 30 | 31 | resnet = ResNet.__factory[depth](pretrained=True) 32 | resnet.layer4[0].conv2.stride = (1, 1) 33 | resnet.layer4[0].downsample[0].stride = (1, 1) 34 | self.base = nn.Sequential( 35 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 36 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 37 | self.gap = nn.AdaptiveAvgPool2d(1) 38 | 39 | self.num_features = self.net_config.num_features 40 | self.dropout = self.net_config.dropout 41 | self.has_embedding = self.net_config.num_features > 0 42 | self.num_classes = num_classes 43 | 44 | out_planes = resnet.fc.in_features 45 | 46 | if self.has_embedding: 47 | self.feat = nn.Linear(out_planes, self.num_features) 48 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 49 | init.constant_(self.feat.bias, 0) 50 | else: 51 | self.num_features = out_planes 52 | 53 | self.feat_bn = nn.BatchNorm1d(self.num_features) 54 | self.feat_bn.bias.requires_grad_(False) 
55 | init.constant_(self.feat_bn.weight, 1) 56 | init.constant_(self.feat_bn.bias, 0) 57 | 58 | if self.dropout > 0: 59 | self.drop = nn.Dropout(self.dropout) 60 | 61 | if self.num_classes > 0: 62 | if self.net_config.metric == 'linear': 63 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 64 | init.normal_(self.classifier.weight, std=0.001) 65 | else: 66 | self.classifier = build_metric(self.net_config.metric, self.num_features, 67 | self.num_classes, self.net_config.scale, self.net_config.metric_margin) 68 | 69 | def forward(self, x, y=None): 70 | x = self.base(x) 71 | x = self.gap(x) 72 | x = x.view(x.size(0), -1) 73 | 74 | if self.has_embedding: 75 | bn_x = self.feat_bn(self.feat(x)) 76 | else: 77 | bn_x = self.feat_bn(x) 78 | 79 | if not self.training: 80 | bn_x = F.normalize(bn_x) 81 | return bn_x 82 | 83 | if self.has_embedding: 84 | bn_x = F.relu(bn_x) 85 | 86 | if self.dropout > 0: 87 | bn_x = self.drop(bn_x) 88 | 89 | if self.num_classes > 0: 90 | if isinstance(self.classifier, nn.Linear): 91 | logits = self.classifier(bn_x) 92 | else: 93 | logits = self.classifier(bn_x, y) 94 | else: 95 | return bn_x 96 | 97 | return x, bn_x, logits 98 | 99 | 100 | def resnet18(**kwargs): 101 | return ResNet(18, **kwargs) 102 | 103 | 104 | def resnet34(**kwargs): 105 | return ResNet(34, **kwargs) 106 | 107 | 108 | def resnet50(**kwargs): 109 | return ResNet(50, **kwargs) 110 | 111 | 112 | def resnet101(**kwargs): 113 | return ResNet(101, **kwargs) 114 | 115 | 116 | def resnet152(**kwargs): 117 | return ResNet(152, **kwargs) 118 | -------------------------------------------------------------------------------- /reid/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | 7 | from reid.models.backbone.resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 8 
| from reid.models.layers.metric import build_metric 9 | 10 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 11 | 12 | 13 | class ResNetIBN(nn.Module): 14 | __factory = { 15 | '50a': resnet50_ibn_a, 16 | '101a': resnet101_ibn_a 17 | } 18 | 19 | def __init__(self, depth, num_classes=0, net_config=None): 20 | super(ResNetIBN, self).__init__() 21 | self.depth = depth 22 | self.net_config = net_config 23 | 24 | resnet = ResNetIBN.__factory[depth](pretrained=True) 25 | resnet.layer4[0].conv2.stride = (1, 1) 26 | resnet.layer4[0].downsample[0].stride = (1, 1) 27 | self.base = nn.Sequential( 28 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 29 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 30 | self.gap = nn.AdaptiveAvgPool2d(1) 31 | 32 | self.num_features = self.net_config.num_features 33 | self.dropout = self.net_config.dropout 34 | self.has_embedding = self.net_config.num_features > 0 35 | self.num_classes = num_classes 36 | 37 | out_planes = resnet.fc.in_features 38 | 39 | if self.has_embedding: 40 | self.feat = nn.Linear(out_planes, self.num_features) 41 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 42 | init.constant_(self.feat.bias, 0) 43 | else: 44 | self.num_features = out_planes 45 | 46 | self.feat_bn = nn.BatchNorm1d(self.num_features) 47 | self.feat_bn.bias.requires_grad_(False) 48 | init.constant_(self.feat_bn.weight, 1) 49 | init.constant_(self.feat_bn.bias, 0) 50 | 51 | if self.dropout > 0: 52 | self.drop = nn.Dropout(self.dropout) 53 | 54 | if self.num_classes > 0: 55 | if self.net_config.metric == 'linear': 56 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 57 | init.normal_(self.classifier.weight, std=0.001) 58 | else: 59 | self.classifier = build_metric(self.net_config.metric, self.num_features, 60 | self.num_classes, self.net_config.scale, self.net_config.metric_margin) 61 | 62 | def forward(self, x, y=None): 63 | x = self.base(x) 64 | x = self.gap(x) 65 | x = 
x.view(x.size(0), -1) 66 | 67 | if self.has_embedding: 68 | bn_x = self.feat_bn(self.feat(x)) 69 | else: 70 | bn_x = self.feat_bn(x) 71 | 72 | if not self.training: 73 | bn_x = F.normalize(bn_x) 74 | return bn_x 75 | 76 | if self.has_embedding: 77 | bn_x = F.relu(bn_x) 78 | 79 | if self.dropout > 0: 80 | bn_x = self.drop(bn_x) 81 | 82 | if self.num_classes > 0: 83 | if isinstance(self.classifier, nn.Linear): 84 | logits = self.classifier(bn_x) 85 | else: 86 | logits = self.classifier(bn_x, y) 87 | else: 88 | return bn_x 89 | 90 | return x, bn_x, logits 91 | 92 | 93 | def resnet_ibn50a(**kwargs): 94 | return ResNetIBN('50a', **kwargs) 95 | 96 | 97 | def resnet_ibn101a(**kwargs): 98 | return ResNetIBN('101a', **kwargs) 99 | -------------------------------------------------------------------------------- /reid/models/resnet_ibn_snr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | 7 | from reid.models.backbone.resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 8 | from reid.models.layers.metric import build_metric 9 | 10 | from .backbone.modules.SNR import SNR 11 | from .backbone.modules.spatial_snr import SpatialSNR 12 | 13 | __all__ = ['ResNetIBNSNR', 'resnet_ibn50a_snr', 'resnet_ibn101a_snr', 14 | 'resnet_ibn50a_snr_spatial', 'resnet_ibn101a_snr_spatial'] 15 | 16 | 17 | class ResNetIBNSNR(nn.Module): 18 | __factory = { 19 | '50a': resnet50_ibn_a, 20 | '101a': resnet101_ibn_a 21 | } 22 | 23 | def __init__(self, depth, spatial=False, num_classes=0, net_config=None): 24 | super(ResNetIBNSNR, self).__init__() 25 | self.depth = depth 26 | self.net_config = net_config 27 | 28 | resnet = ResNetIBNSNR.__factory[depth](pretrained=True) 29 | resnet.layer4[0].conv2.stride = (1, 1) 30 | resnet.layer4[0].downsample[0].stride = (1, 1) 31 | self.base = nn.Sequential( 32 | resnet.conv1, resnet.bn1, 
resnet.relu, resnet.maxpool, 33 | ) 34 | self.layers = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]) 35 | if spatial: 36 | self.decouple = nn.ModuleList([SpatialSNR(256), SpatialSNR(512), SpatialSNR(1024), SpatialSNR(2048)]) 37 | else: 38 | self.decouple = nn.ModuleList([SNR(256), SNR(512), SNR(1024), SNR(2048)]) 39 | self.gap = nn.AdaptiveAvgPool2d(1) 40 | 41 | self.num_features = self.net_config.num_features 42 | self.dropout = self.net_config.dropout 43 | self.has_embedding = self.net_config.num_features > 0 44 | self.num_classes = num_classes 45 | 46 | out_planes = resnet.fc.in_features 47 | 48 | if self.has_embedding: 49 | self.feat = nn.Linear(out_planes, self.num_features) 50 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 51 | init.constant_(self.feat.bias, 0) 52 | else: 53 | self.num_features = out_planes 54 | 55 | self.feat_bn = nn.BatchNorm1d(self.num_features) 56 | self.feat_bn.bias.requires_grad_(False) 57 | init.constant_(self.feat_bn.weight, 1) 58 | init.constant_(self.feat_bn.bias, 0) 59 | 60 | if self.dropout > 0: 61 | self.drop = nn.Dropout(self.dropout) 62 | 63 | if self.num_classes > 0: 64 | if self.net_config.metric == 'linear': 65 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 66 | init.normal_(self.classifier.weight, std=0.001) 67 | else: 68 | self.classifier = build_metric(self.net_config.metric, self.num_features, 69 | self.num_classes, self.net_config.scale, self.net_config.metric_margin) 70 | 71 | def forward(self, x, y=None): 72 | x = self.base(x) 73 | 74 | dual_feats = [] 75 | for item, layer in enumerate(self.layers): 76 | x = layer(x) 77 | x, f = self.decouple[item](x) 78 | dual_feats.append(f) 79 | 80 | x = self.gap(x) 81 | x = x.view(x.size(0), -1) 82 | 83 | if self.has_embedding: 84 | bn_x = self.feat_bn(self.feat(x)) 85 | else: 86 | bn_x = self.feat_bn(x) 87 | 88 | if not self.training: 89 | bn_x = F.normalize(bn_x) 90 | return bn_x 91 | 92 | if 
self.has_embedding: 93 | bn_x = F.relu(bn_x) 94 | 95 | if self.dropout > 0: 96 | bn_x = self.drop(bn_x) 97 | 98 | if self.num_classes > 0: 99 | if isinstance(self.classifier, nn.Linear): 100 | logits = self.classifier(bn_x) 101 | else: 102 | logits = self.classifier(bn_x, y) 103 | else: 104 | return bn_x 105 | 106 | return x, bn_x, logits, dual_feats 107 | 108 | 109 | def resnet_ibn50a_snr(**kwargs): 110 | return ResNetIBNSNR('50a', **kwargs) 111 | 112 | 113 | def resnet_ibn101a_snr(**kwargs): 114 | return ResNetIBNSNR('101a', **kwargs) 115 | 116 | 117 | def resnet_ibn50a_snr_spatial(**kwargs): 118 | return ResNetIBNSNR('50a', spatial=True, **kwargs) 119 | 120 | 121 | def resnet_ibn101a_snr_spatial(**kwargs): 122 | return ResNetIBNSNR('101a', spatial=True, **kwargs) -------------------------------------------------------------------------------- /reid/models/resnext_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | 7 | from reid.models.layers.gem import GeneralizedMeanPoolingP 8 | from reid.models.layers.metric import build_metric 9 | from reid.models.backbone.resnext_ibn import resnext101_ibn_a as resnext101_ibn_a_backbone 10 | 11 | __all__ = ['resnext101_ibn_a'] 12 | 13 | 14 | class ResNeXt_IBN(nn.Module): 15 | def __init__(self, depth, num_classes=0, net_config=None): 16 | super(ResNeXt_IBN, self).__init__() 17 | self.depth = depth 18 | self.net_config = net_config 19 | 20 | resnet = resnext101_ibn_a_backbone(pretrained=True) 21 | resnet.layer4[0].conv2.stride = (1, 1) 22 | resnet.layer4[0].downsample[0].stride = (1, 1) 23 | self.base = nn.Sequential( 24 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool1, 25 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 26 | self.gap = nn.AdaptiveAvgPool2d(1) 27 | 28 | self.num_features = self.net_config.num_features 29 | 
self.dropout = self.net_config.dropout 30 | self.has_embedding = self.net_config.num_features > 0 31 | self.num_classes = num_classes 32 | 33 | out_planes = resnet.fc.in_features 34 | 35 | if self.has_embedding: 36 | self.feat = nn.Linear(out_planes, self.num_features) 37 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 38 | init.constant_(self.feat.bias, 0) 39 | else: 40 | self.num_features = out_planes 41 | 42 | self.feat_bn = nn.BatchNorm1d(self.num_features) 43 | self.feat_bn.bias.requires_grad_(False) 44 | init.constant_(self.feat_bn.weight, 1) 45 | init.constant_(self.feat_bn.bias, 0) 46 | 47 | if self.dropout > 0: 48 | self.drop = nn.Dropout(self.dropout) 49 | 50 | if self.num_classes > 0: 51 | if self.net_config.metric == 'linear': 52 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 53 | init.normal_(self.classifier.weight, std=0.001) 54 | else: 55 | self.classifier = build_metric(self.net_config.metric, self.num_features, 56 | self.num_classes, self.net_config.scale, self.net_config.metric_margin) 57 | 58 | def forward(self, x, y=None): 59 | x = self.base(x) 60 | x = self.gap(x) 61 | x = x.view(x.size(0), -1) 62 | 63 | if self.has_embedding: 64 | bn_x = self.feat_bn(self.feat(x)) 65 | else: 66 | bn_x = self.feat_bn(x) 67 | 68 | if not self.training: 69 | bn_x = F.normalize(bn_x) 70 | return bn_x 71 | 72 | if self.has_embedding: 73 | bn_x = F.relu(bn_x) 74 | 75 | if self.dropout > 0: 76 | bn_x = self.drop(bn_x) 77 | 78 | if self.num_classes > 0: 79 | if isinstance(self.classifier, nn.Linear): 80 | logits = self.classifier(bn_x) 81 | else: 82 | logits = self.classifier(bn_x, y) 83 | else: 84 | return bn_x 85 | 86 | return x, bn_x, logits 87 | 88 | 89 | def resnext101_ibn_a(**kwargs): 90 | return ResNeXt_IBN('101a', **kwargs) 91 | -------------------------------------------------------------------------------- /reid/models/se_resnet_ibn.py: -------------------------------------------------------------------------------- 1 
| from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | 7 | from reid.models.layers.metric import build_metric 8 | from reid.models.backbone.se_resnet_ibn import se_resnet101_ibn_a as se 9 | 10 | __all__ = ['se_resnet101_ibn_a'] 11 | 12 | 13 | class ResNet_IBN(nn.Module): 14 | def __init__(self, depth, num_classes=0, net_config=None): 15 | super(ResNet_IBN, self).__init__() 16 | 17 | self.depth = depth 18 | self.net_config = net_config 19 | 20 | resnet = se(pretrained=True) 21 | resnet.layer4[0].conv2.stride = (1, 1) 22 | resnet.layer4[0].downsample[0].stride = (1, 1) 23 | self.base = nn.Sequential( 24 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 25 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 26 | self.gap = nn.AdaptiveAvgPool2d(1) 27 | 28 | self.num_features = self.net_config.num_features 29 | self.dropout = self.net_config.dropout 30 | self.has_embedding = self.net_config.num_features > 0 31 | self.num_classes = num_classes 32 | 33 | out_planes = resnet.fc.in_features 34 | 35 | if self.has_embedding: 36 | self.feat = nn.Linear(out_planes, self.num_features) 37 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 38 | init.constant_(self.feat.bias, 0) 39 | else: 40 | self.num_features = out_planes 41 | 42 | self.feat_bn = nn.BatchNorm1d(self.num_features) 43 | self.feat_bn.bias.requires_grad_(False) 44 | init.constant_(self.feat_bn.weight, 1) 45 | init.constant_(self.feat_bn.bias, 0) 46 | 47 | if self.dropout > 0: 48 | self.drop = nn.Dropout(self.dropout) 49 | 50 | if self.num_classes > 0: 51 | if self.net_config.metric == 'linear': 52 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 53 | init.normal_(self.classifier.weight, std=0.001) 54 | else: 55 | self.classifier = build_metric(self.net_config.metric, self.num_features, 56 | self.num_classes, self.net_config.scale, self.net_config.metric_margin) 57 | 58 | def 
forward(self, x, y=None): 59 | x = self.base(x) 60 | x = self.gap(x) 61 | x = x.view(x.size(0), -1) 62 | 63 | if self.has_embedding: 64 | bn_x = self.feat_bn(self.feat(x)) 65 | else: 66 | bn_x = self.feat_bn(x) 67 | 68 | if not self.training: 69 | bn_x = F.normalize(bn_x) 70 | return bn_x 71 | 72 | if self.has_embedding: 73 | bn_x = F.relu(bn_x) 74 | 75 | if self.dropout > 0: 76 | bn_x = self.drop(bn_x) 77 | 78 | if self.num_classes > 0: 79 | if isinstance(self.classifier, nn.Linear): 80 | logits = self.classifier(bn_x) 81 | else: 82 | logits = self.classifier(bn_x, y) 83 | else: 84 | return bn_x 85 | 86 | return x, bn_x, logits 87 | 88 | 89 | def se_resnet101_ibn_a(**kwargs): 90 | return ResNet_IBN('', **kwargs) 91 | -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__init__.py -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils.cpython-36.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils.cpython-38.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils_pt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/multi_task_distributed_utils_pt.cpython-38.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/task_info.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/task_info.cpython-36.pyc -------------------------------------------------------------------------------- /reid/multi_tasks_utils/__pycache__/task_info.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/task_info.cpython-38.pyc -------------------------------------------------------------------------------- 
# /reid/multi_tasks_utils/__pycache__/task_info_pt.cpython-38.pyc:
# https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/multi_tasks_utils/__pycache__/task_info_pt.cpython-38.pyc
# -----------------------------------------------------------------------------
# /reid/multi_tasks_utils/model_serialization.py:
import os
from reid.utils.serialization import save_checkpoint, load_checkpoint
import torch.distributed as dist

class ModelSerialization():
    """Checkpoint helper for multi-task training: the shared trunk ('base.*')
    is saved once by global rank 0, everything else per task by task rank 0.
    """

    def __init__(self, taskinfo, args):
        super(ModelSerialization, self).__init__()
        self.taskinfo = taskinfo
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.args = args

    def save(self, model, cur_iter):
        """Write base and task-specific checkpoints under logs_dir/checkpoint_<iter>."""
        dirs = os.path.join(self.args.logs_dir, 'checkpoint_{}'.format(cur_iter))
        base_fpath = os.path.join(dirs, "base_checkpoint.pth.tar")
        task_specific_fpath = os.path.join(dirs, self.taskinfo.task_name + "_checkpoint.pth.tar")

        # save base (shared trunk) — only once, by the global root rank
        if self.rank == 0:
            base_params = {'base.' + key: value
                           for key, value in model.base.state_dict().items()}
            save_checkpoint({
                "state_dict": base_params,
                "start_iter": cur_iter
            }, fpath=base_fpath)
            print("Stage 1: Base Checkpoint has saved...")

        # save task-specific parameters — once per task, by each task root
        if self.taskinfo.task_rank == 0:
            task_specific_params = {key: value
                                    for key, value in model.state_dict().items()
                                    if not key.startswith('base')}
            save_checkpoint(task_specific_params, fpath=task_specific_fpath)
            print("Stage 2: [Task-{}] Task-Specific Checkpoint has saved...".format(self.taskinfo.task_name))

    def load(self, checkpoint_path, load_task_specific=True):
        """Load base (and optionally task-specific) params.

        Returns:
            (state_dict, start_iter) — task-specific entries overwrite base
            entries when load_task_specific is True.
        """
        base_checkpoint_path = os.path.join(checkpoint_path, "base_checkpoint.pth.tar")
        base_state_dict = load_checkpoint(base_checkpoint_path)
        params_state_dict = base_state_dict["state_dict"]
        start_iter = base_state_dict['start_iter']

        if not load_task_specific:
            return params_state_dict, start_iter
        task_specific_checkpoint_path = os.path.join(checkpoint_path, self.taskinfo.task_name + "_checkpoint.pth.tar")
        task_specific_state_dict = load_checkpoint(task_specific_checkpoint_path)
        params_state_dict.update(task_specific_state_dict)
        return params_state_dict, start_iter

# -----------------------------------------------------------------------------
# /reid/multi_tasks_utils/multi_task_distributed_utils_pt.py:
import os
import time
import numpy as np
import torch

import torch.distributed as dist


class Multitask_DistModule(torch.nn.Module):
    """Distributed wrapper for multi-task training.

    Broadcasts initial parameters (shared params from global rank 0,
    task-specific params — name contains `ignore` — from the task root within
    `task_grp`), and, when sync=False, registers backward hooks that
    all-reduce each gradient as it becomes ready.
    """

    def __init__(self, module, sync=False, ignore=None, task_grp=None, task_root_rank=0):
        super(Multitask_DistModule, self).__init__()
        self.module = module
        self.ignore = ignore          # substring marking task-specific parameter names
        self.task_grp = task_grp
        self.task_root_rank = task_root_rank
        broadcast_params(self.module, self.ignore, self.task_grp, self.task_root_rank)

        if not sync:
            self._grad_accs = []
            self._register_hooks()

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        super(Multitask_DistModule, self).train(mode)
        self.module.train(mode)

    def _register_hooks(self):
        # Keep a reference to each AccumulateGrad node so its hook stays alive.
        for i, (name, p) in enumerate(self.named_parameters()):
            if p.requires_grad:
                # BUGFIX: the original tested `not self.ignore in name`, which
                # raises TypeError with the default ignore=None; None means
                # "all-reduce every parameter over the world group".
                if self.ignore is None or self.ignore not in name:
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(name, p))
                    self._grad_accs.append(grad_acc)
                else:
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(name, p, self.task_grp))
                    self._grad_accs.append(grad_acc)

    def _make_hook(self, name, p, task_grp=None):
        def hook(*ignore):
            if task_grp:
                allreduce_async(name, p.grad.data, group_idx=task_grp)
            else:
                allreduce_async(name, p.grad.data)
        return hook


def multitask_reduce_gradients(model, sync=False, ignore_list=None, task_grp=None):
    """ average gradients """
    if sync:
        if ignore_list is not None:
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    for ignore in ignore_list:
                        if not ignore in name:
                            # shared parameter: reduce over the world group
                            allreduce(param.grad.data)
                        else:
                            # task-specific parameter: reduce within the task group
                            allreduce(param.grad.data, group_idx=task_grp)
                elif param.grad is None:
                    # Parameter unused this step: substitute a zero gradient so
                    # every rank participates in the same collectives.
                    param.grad = param.data * 0
                    for ignore in ignore_list:
                        if not ignore in name:
                            allreduce(param.grad.data)
                        else:
                            allreduce(param.grad.data, group_idx=task_grp)
        else:
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    allreduce(param.grad.data)
    else:
        # BUGFIX: torch.distributed has no `synchronize()`; barrier() is the
        # collective that waits for all ranks.
        dist.barrier()

def allreduce(x, group_idx=None, ):
    # group_idx == 0 is treated as "default (world) group".
    if group_idx == 0:
        group_idx = None
    return dist.all_reduce(x, group=group_idx)

def allreduce_async(name, x, group_idx=None, ):
    # NOTE(review): despite the name this is a blocking all_reduce
    # (async_op is left False); `name` is unused.
    if group_idx == 0:
        group_idx = None
    return dist.all_reduce(x, group=group_idx)

def broadcast_params(model, ignore_list=None, task_grp=None, task_root_rank=0):
    """ broadcast model parameters """
    if ignore_list is not None:
        for name, p in model.state_dict().items():
            # Task-specific tensors (name matches an ignore pattern) come from
            # the task root within task_grp; everything else from global rank 0.
            Flag = False
            for ignore in ignore_list:
                if ignore in name:
                    Flag = True
            if Flag:
                dist.broadcast(p, task_root_rank, group=task_grp)
            else:
                dist.broadcast(p, 0)
    else:
        for name, p in model.state_dict().items():
            dist.broadcast(p, 0)

# -----------------------------------------------------------------------------
# /reid/multi_tasks_utils/task_info.py:
import numpy as np
from easydict import EasyDict as edict
# import linklink as link
# try:
#     import spring.linklink as link
# except:
#     import linklink as link

__all__ = ['get_taskinfo']

# NOTE(review): this module still calls the `link` (linklink) API whose import
# is commented out above, so calling specific_task_split() raises NameError.
# The torch.distributed port lives in task_info_pt.py; fix by restoring the
# import or switching callers to the _pt module.
def specific_task_split(task_spec, world_size, rank, tasks):
    """Split `world_size` ranks among tasks proportionally to task_spec.

    Returns (thistask_info, alltask_info) EasyDicts describing this rank's
    task group and the global task layout.
    """
    ## sanity check
    assert type(task_spec) is list
    assert all(map(lambda x: type(x) is int, task_spec))

    num_tasks = len(task_spec)
    splits = np.sum(task_spec)

    assert world_size % splits == 0
    unit = int(world_size / splits)
    # NOTE(review): the `rank` argument is immediately shadowed here.
    rank = link.get_rank()
    if rank == 0:
        print("processing unit num : {0}".format(unit))
    ## split
    Ltask_sizes = [x * unit for x in task_spec]
    Ltasks = []
    Lroots = []
    last = 0
    thistask_info = edict()
    alltask_info = edict()

    for i, gs in enumerate(Ltask_sizes):
        ranks = list(map(int, np.arange(last, last + gs)))
        Ltasks.append(link.new_group(ranks=ranks))  ## The handle for each task
        Lroots.append(last)

        if rank in ranks:
            thistask_info.task_handle = Ltasks[-1]
            thistask_info.task_size = gs
            thistask_info.task_id = i
            thistask_info.task_rank = rank - last
            thistask_info.task_root_rank = last
        last += gs

    alltask_info.root_handles = link.new_group(ranks=Lroots)
    alltask_info.task_sizes = Ltask_sizes
    alltask_info.task_root_ranks = Lroots
    alltask_info.task_num = num_tasks

    return thistask_info, alltask_info


def get_taskinfo(args, world_size, rank):
    """Build per-task info from a task table.

    `args` is indexed both by integer (tasks[i]) and via .values(), so it is
    presumably a dict keyed by int — TODO confirm against the caller.
    """
    # config = args.config
    # tasks = config['tasks']
    tasks = args
    num_tasks = len(tasks)
    task_spec = [tasks[i].get('gres_ratio', 1) for i in range(num_tasks)]
    thistask_info, alltask_info = specific_task_split(task_spec, world_size, rank, tasks)

    # Normalize per-task loss weights so they sum to 1.
    loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()])))

    thistask_info.task_name = tasks[thistask_info.task_id]['task_name']
    thistask_info.task_weight = float(tasks[thistask_info.task_id]['loss_weight']) / loss_weight_sum
    thistask_info.train_file_path = tasks[thistask_info.task_id].get('train_file_path', '')
    thistask_info.root_path = tasks[thistask_info.task_id].get('root_path', '')
    thistask_info.task_spec = tasks[thistask_info.task_id].get('task_spec', '')
    alltask_info.task_names = [tasks[i]['task_name'] for i in range(alltask_info.task_num)]
    return thistask_info, alltask_info

# -----------------------------------------------------------------------------
# /reid/multi_tasks_utils/task_info_pt.py:
import numpy as np
from easydict import EasyDict as edict
import torch.distributed as dist


__all__ = ['get_taskinfo']

def specific_task_split_(task_spec, world_size, rank, tasks):
    ## sanity check
    assert type(task_spec) is list
    assert all(map(lambda x: type(x) is int, task_spec))

    num_tasks = len(task_spec)
    splits = np.sum(task_spec)

    assert world_size % splits == 0
    unit = int(world_size / splits)
    rank = dist.get_rank()
    if rank == 0:
        print("processing unit num : {0}".format(unit))
    ## split
    Ltask_sizes = [x * unit for x in task_spec]
    Ltasks = []
    Lroots = []
    last = 0
    thistask_info = edict()
    alltask_info = edict()

    # (remainder of specific_task_split_ is truncated in this chunk of the dump;
    #  the visible fragment was:)
    # for i,gs in enumerate(Ltask_sizes):
    #     ranks = list(map(int, np.arange(last,
last+gs))) 31 | Ltasks.append(dist.new_group(ranks=ranks)) ## The handle for each task 32 | Lroots.append(last) 33 | 34 | if rank in ranks: 35 | thistask_info.task_handle = Ltasks[-1] 36 | thistask_info.task_size = gs 37 | thistask_info.task_id = i 38 | thistask_info.task_rank = rank - last 39 | thistask_info.task_root_rank = last 40 | last += gs 41 | 42 | alltask_info.root_handles = dist.new_group(ranks=Lroots) 43 | alltask_info.task_sizes = Ltask_sizes 44 | alltask_info.task_root_ranks = Lroots 45 | alltask_info.task_num = num_tasks 46 | 47 | return thistask_info, alltask_info 48 | 49 | def specific_task_split(task_spec, world_size, rank, tasks): 50 | ## sanity check 51 | assert type(task_spec) is list 52 | assert all(map(lambda x: type(x) is int, task_spec)) 53 | num_groups = len(task_spec) 54 | splits = np.sum(task_spec) 55 | assert world_size % splits == 0 56 | unit = int(world_size / splits) 57 | ## split 58 | task_spec = [x*unit for x in task_spec] 59 | groups = [] 60 | roots = [] 61 | last = 0 62 | group_info = edict() 63 | alltask_info = edict() 64 | for i,gs in enumerate(task_spec): 65 | ranks = list(map(int, np.arange(last, last+gs))) 66 | groups.append(dist.new_group(ranks=ranks)) 67 | roots.append(last) 68 | if rank in ranks: 69 | group_info.task_handle = groups[-1] 70 | group_info.task_size = gs 71 | group_info.task_id = i 72 | group_info.task_rank = rank - last 73 | group_info.task_root_rank = last 74 | last += gs 75 | group_info.task_roots = roots 76 | group_info.num_groups = num_groups 77 | 78 | alltask_info.root_handles = dist.new_group(ranks=roots) 79 | alltask_info.task_sizes = task_spec 80 | alltask_info.task_root_ranks = roots 81 | alltask_info.task_num = num_groups 82 | 83 | return group_info, alltask_info 84 | 85 | def get_taskinfo(args, world_size, rank): 86 | # config = args.config 87 | # tasks = config['tasks'] 88 | tasks = args 89 | num_tasks = len(tasks) 90 | task_spec = [tasks[i].get('gres_ratio',1) for i in range(num_tasks)] 91 | 
thistask_info, alltask_info = specific_task_split(task_spec, world_size, rank, tasks) 92 | 93 | loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()]))) 94 | 95 | thistask_info.task_name = tasks[thistask_info.task_id]['task_name'] 96 | thistask_info.task_weight = float(tasks[thistask_info.task_id]['loss_weight']) / loss_weight_sum 97 | thistask_info.train_file_path = tasks[thistask_info.task_id].get('train_file_path','') 98 | thistask_info.root_path = tasks[thistask_info.task_id].get('root_path', '') 99 | thistask_info.task_spec = tasks[thistask_info.task_id].get('task_spec', '') 100 | thistask_info.attt_file = tasks[thistask_info.task_id].get('attt_file', '') 101 | 102 | alltask_info.task_names = [tasks[i]['task_name'] for i in range(alltask_info.task_num)] 103 | return thistask_info, alltask_info 104 | -------------------------------------------------------------------------------- /reid/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .general_trainer import GeneralTrainer 2 | from .snr_trainer import SNRTrainer 3 | from .mgn_trainer import MGNTrainer 4 | from .general_clothes_trainer import GeneralClothesTrainer 5 | from .transreid_twobranch_trainer import TransreidTwobranchTrainer 6 | from .transreid_twobranch_aug_trainer import TransreidTwobranchAugTrainer 7 | from .pass_trainer import GeneralTransformerTrainer 8 | from .pass_trainer_t_q import GeneralTransformerTrainer_t2i 9 | from .pass_trainer_joint import GeneralTransformerTrainer_joint 10 | 11 | class TrainerFactory(object): 12 | def __init__(self): 13 | super(TrainerFactory, self).__init__() 14 | self.snr_net = ['resnet_ibn50a_snr', 'resnet_ibn101a_snr', 15 | 'resnet_ibn50a_snr_spatial', 'resnet_ibn101a_snr_spatial'] 16 | self.mgn_net = ['mgn'] 17 | self.clothes_net = ['resnet_ibn101a_two_branch', 'resnet_ibn50a_two_branch', 'transformer', 'transformer_kmeans'] 18 | self.transreid_two_branch = 
['deit_small_patch16_224_TransReID','deit_small_patch16_224_TransReID_mask'] 19 | self.transreid_two_branch_aug = ['deit_small_patch16_224_TransReID_aug','deit_small_patch16_224_TransReID_mask_aug'] 20 | self.transformer = ['transformer_dualattn'] 21 | self.transformer_t2i = ['transformer_dualattn_t2i'] 22 | self.transformer_joint = ['transformer_dualattn_joint'] 23 | 24 | def create(self, name, *args, **kwargs): 25 | if name in self.snr_net: 26 | return SNRTrainer(*args, **kwargs) 27 | if name in self.mgn_net: 28 | return MGNTrainer(*args, **kwargs) 29 | if name in self.clothes_net: 30 | return GeneralClothesTrainer(*args, **kwargs) 31 | if name in self.transreid_two_branch: 32 | return TransreidTwobranchTrainer(*args, **kwargs) 33 | if name in self.transreid_two_branch_aug: 34 | return TransreidTwobranchAugTrainer(*args, **kwargs) 35 | if name in self.transformer: 36 | return GeneralTransformerTrainer(*args, **kwargs) 37 | if name in self.transformer_t2i: 38 | return GeneralTransformerTrainer_t2i(*args, **kwargs) 39 | if name in self.transformer_joint: 40 | return GeneralTransformerTrainer_joint(*args, **kwargs) 41 | return GeneralTrainer(*args, **kwargs) 42 | -------------------------------------------------------------------------------- /reid/trainer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/reid/trainer/__pycache__/base_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/base_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/base_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/base_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/base_trainer_pt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/base_trainer_pt.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/base_trainer_t2i.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/base_trainer_t2i.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/base_trainer_t2i.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/base_trainer_t2i.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_clothes_trainer.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_clothes_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_clothes_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_clothes_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_transformer_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_transformer_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_transformer_trainer_attr.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_transformer_trainer_attr.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/general_transformer_trainer_attr_tri.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/general_transformer_trainer_attr_tri.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/mgn_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/mgn_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/mgn_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/mgn_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_cl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_cl.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_joint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_joint.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_joint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_joint.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_t.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_t.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_t_q.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_t_q.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/pass_trainer_t_q.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/pass_trainer_t_q.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/snr_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/snr_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/snr_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/snr_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/swin_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/swin_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/transreid_twobranch_aug_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/transreid_twobranch_aug_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/transreid_twobranch_aug_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/transreid_twobranch_aug_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/transreid_twobranch_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/transreid_twobranch_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/trainer/__pycache__/transreid_twobranch_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/trainer/__pycache__/transreid_twobranch_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/trainer/general_clothes_trainer.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.nn import CrossEntropyLoss 3 | 4 | from reid.loss import TripletLoss 5 | from reid.trainer.base_trainer import BaseTrainer 6 | from reid.utils import accuracy 7 | from reid.utils.meters import AverageMeter 8 | 9 | class GeneralClothesTrainer(BaseTrainer): 10 | def __init__(self, model, args, this_task_info=None): 11 | super(GeneralClothesTrainer, self).__init__(model, 
args, this_task_info) 12 | 13 | self.ce_loss = CrossEntropyLoss().cuda() 14 | self.triplet_loss = TripletLoss(margin=self.args.margin).cuda() 15 | 16 | self.losses_ce = AverageMeter() 17 | self.losses_bme = AverageMeter() 18 | self.losses_tr = AverageMeter() 19 | self.precisions = AverageMeter() 20 | 21 | def _logging(self, cur_iter): 22 | self._tensorboard_writer(cur_iter, data={ 23 | 'loss': self.losses_ce.val + self.losses_tr.val, 24 | 'loss_ce': self.losses_ce.val, 25 | 'loss_bme': self.losses_bme.val, 26 | 'loss_tr': self.losses_tr.val, 27 | 'prec': self.precisions.val 28 | }) 29 | local_rank = self.this_task_info.task_rank if self.this_task_info else dist.get_rank() 30 | if not (cur_iter % self.args.print_freq == 0 and local_rank == 0): 31 | return 32 | if self.this_task_info: 33 | task_id, task_name = self.this_task_info.task_id, self.this_task_info.task_name 34 | else: 35 | task_id, task_name = 0, 'single task' 36 | print('Iter: [{}/{}]\t' 37 | 'task{}: {}\t' 38 | 'Time {:.3f} ({:.3f}) (ETA: {:.2f}h)\t' 39 | 'Data {:.3f} ({:.3f})\t' 40 | 'Loss_ce {:.3f} ({:.3f})\t' 41 | 'Loss_tr {:.3f} ({:.3f})\t' 42 | 'Loss_bme {:.3f} ({:.3f})\t' 43 | 'Prec {:.2%} ({:.2%})' 44 | .format(cur_iter, self.args.iters, 45 | str(task_id), str(task_name), 46 | self.batch_time.val, self.batch_time.avg, 47 | (self.args.iters - cur_iter) * self.batch_time.avg / 3600, 48 | self.data_time.val, self.data_time.avg, 49 | self.losses_ce.val, self.losses_ce.avg, 50 | self.losses_tr.val, self.losses_tr.avg, 51 | self.losses_bme.val, self.losses_bme.avg, 52 | self.precisions.val, self.precisions.avg)) 53 | 54 | def _refresh_information(self, cur_iter, lr): 55 | if cur_iter % self.args.refresh_freq == 0 or cur_iter == 1: 56 | self.batch_time = AverageMeter() 57 | self.data_time = AverageMeter() 58 | self.losses_ce = AverageMeter() 59 | self.losses_tr = AverageMeter() 60 | self.losses_bme = AverageMeter() 61 | self.precisions = AverageMeter() 62 | local_rank = self.this_task_info.task_rank if 
self.this_task_info else dist.get_rank() 63 | if local_rank == 0: 64 | print("lr = {} \t".format(lr)) 65 | 66 | def _parse_data(self, inputs): 67 | imgs, clothes, _, _, pids, _, _, indices = inputs 68 | inputs = imgs.cuda() 69 | clothes = clothes.cuda() 70 | targets = pids.cuda() 71 | return inputs, clothes, targets 72 | 73 | def run(self, inputs): 74 | inputs, clothes, targets = self._parse_data(inputs) 75 | feat, fusion_feat, logits1, logits2 = self.model(inputs, clothes) 76 | 77 | loss_ce_biometric = self.ce_loss(logits1, targets) 78 | 79 | loss_ce = self.ce_loss(logits2, targets) 80 | # loss_tr, _ = self.triplet_loss(feat, targets) 81 | loss_tr, _ = self.triplet_loss(fusion_feat, targets) 82 | loss = loss_ce + loss_tr + loss_ce_biometric 83 | 84 | self.losses_ce.update(loss_ce.item()) 85 | self.losses_tr.update(loss_tr.item()) 86 | self.losses_bme.update(loss_ce_biometric.item()) 87 | 88 | prec, = accuracy(logits2.data, targets.data) 89 | prec = prec[0] 90 | self.precisions.update(prec) 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /reid/trainer/general_trainer_pt.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.nn import CrossEntropyLoss 3 | 4 | from reid.loss import TripletLoss 5 | from reid.trainer.base_trainer import BaseTrainer 6 | from reid.utils import accuracy 7 | from reid.utils.meters import AverageMeter 8 | 9 | 10 | class GeneralTrainer(BaseTrainer): 11 | def __init__(self, model, args, this_task_info=None): 12 | super(GeneralTrainer, self).__init__(model, args, this_task_info) 13 | self.ce_loss = CrossEntropyLoss().cuda() 14 | self.triplet_loss = TripletLoss(margin=self.args.margin).cuda() 15 | 16 | self.losses_ce = AverageMeter() 17 | self.losses_tr = AverageMeter() 18 | self.precisions = AverageMeter() 19 | 20 | def _logging(self, cur_iter): 21 | self._tensorboard_writer(cur_iter, data={ 22 | 'loss': 
self.losses_ce.val + self.losses_tr.val, 23 | 'loss_ce': self.losses_ce.val, 24 | 'loss_tr': self.losses_tr.val, 25 | 'prec': self.precisions.val 26 | }) 27 | 28 | if not (cur_iter % self.args.print_freq == 0 and dist.get_rank() == 0): 29 | return 30 | print('Iter: [{}/{}]\t' 31 | 'Time {:.3f} ({:.3f}) (ETA: {:.2f}h)\t' 32 | 'Data {:.3f} ({:.3f})\t' 33 | 'Loss_ce {:.3f} ({:.3f})\t' 34 | 'Loss_tr {:.3f} ({:.3f})\t' 35 | 'Prec {:.2%} ({:.2%})' 36 | .format(cur_iter, self.args.iters, 37 | self.batch_time.val, self.batch_time.avg, 38 | (self.args.iters - cur_iter) * self.batch_time.avg / 3600, 39 | self.data_time.val, self.data_time.avg, 40 | self.losses_ce.val, self.losses_ce.avg, 41 | self.losses_tr.val, self.losses_tr.avg, 42 | self.precisions.val, self.precisions.avg)) 43 | 44 | def _refresh_information(self, cur_iter, lr): 45 | if cur_iter % self.args.refresh_freq == 0 or cur_iter == 1: 46 | self.batch_time = AverageMeter() 47 | self.data_time = AverageMeter() 48 | self.losses_ce = AverageMeter() 49 | self.losses_tr = AverageMeter() 50 | self.precisions = AverageMeter() 51 | if dist.get_rank() == 0: 52 | print("lr = {} \t".format(lr)) 53 | def _parse_data(self, inputs): 54 | imgs, clothes, _, _, pids, _, _, indices = inputs 55 | inputs = imgs.cuda() 56 | clothes = clothes.cuda() 57 | targets = pids.cuda() 58 | return inputs, clothes, targets 59 | 60 | def run(self, inputs): 61 | inputs, clothes, targets = self._parse_data(inputs) 62 | feat, logits, clot_feats_s = self.model(inputs, clothes) 63 | 64 | loss_ce = self.ce_loss(logits, targets) 65 | loss_tr, _ = self.triplet_loss(feat, targets, clot_feats_s) 66 | # import pdb;pdb.set_trace() 67 | # loss_tr, _ = self.triplet_loss(feat, targets) 68 | loss = loss_ce + loss_tr 69 | 70 | self.losses_ce.update(loss_ce.item()) 71 | self.losses_tr.update(loss_tr.item()) 72 | 73 | prec, = accuracy(logits.data, targets.data) 74 | prec = prec[0] 75 | self.precisions.update(prec) 76 | 77 | return loss 78 | 
# --------------------------------------------------------------------------
# /reid/trainer/mgn_trainer.py
# --------------------------------------------------------------------------
import torch.distributed as dist
from torch.nn import CrossEntropyLoss

from reid.loss import TripletLoss
from reid.loss.dual_causality_loss import DualCausalityLoss
from reid.trainer.base_trainer import BaseTrainer
from reid.utils import accuracy
from reid.utils.meters import AverageMeter


class MGNTrainer(BaseTrainer):
    """Trainer for the multi-branch MGN model.

    The model returns several feature branches and several classification
    heads; the CE loss is averaged over the heads and the triplet loss over
    the branches.
    """

    def __init__(self, model, args):
        super(MGNTrainer, self).__init__(model, args)
        self.ce_loss = CrossEntropyLoss().cuda()
        self.triplet_loss = TripletLoss(margin=self.args.margin).cuda()
        # NOTE(review): created but never used by run(); kept so the
        # instance attribute surface stays unchanged.
        self.dual_loss = DualCausalityLoss().cuda()

        self.losses_ce = AverageMeter()
        self.losses_tr = AverageMeter()
        self.precisions = AverageMeter()

    def _logging(self, cur_iter):
        # Scalars always go to tensorboard; console output only every
        # `print_freq` iterations and only from the rank-0 process.
        self._tensorboard_writer(cur_iter, data={
            'loss': self.losses_ce.val + self.losses_tr.val,
            'loss_ce': self.losses_ce.val,
            'loss_tr': self.losses_tr.val,
            'prec': self.precisions.val
        })
        should_print = (cur_iter % self.args.print_freq == 0
                        and dist.get_rank() == 0)
        if should_print:
            eta_hours = (self.args.iters - cur_iter) * self.batch_time.avg / 3600
            print('Iter: [{}/{}]\t'
                  'Time {:.3f} ({:.3f}) (ETA: {:.2f}h)\t'
                  'Data {:.3f} ({:.3f})\t'
                  'Loss_ce {:.3f} ({:.3f})\t'
                  'Loss_tr {:.3f} ({:.3f})\t'
                  'Prec {:.2%} ({:.2%})'
                  .format(cur_iter, self.args.iters,
                          self.batch_time.val, self.batch_time.avg,
                          eta_hours,
                          self.data_time.val, self.data_time.avg,
                          self.losses_ce.val, self.losses_ce.avg,
                          self.losses_tr.val, self.losses_tr.avg,
                          self.precisions.val, self.precisions.avg))

    def _refresh_information(self, cur_iter, lr):
        # Reset the running meters at iteration 1 and every `refresh_freq`
        # iterations so the averages cover a recent window only.
        if cur_iter == 1 or cur_iter % self.args.refresh_freq == 0:
            for meter_name in ('batch_time', 'data_time', 'losses_ce',
                               'losses_tr', 'precisions'):
                setattr(self, meter_name, AverageMeter())
            if dist.get_rank() == 0:
                print("lr = {} \t".format(lr))

    def run(self, inputs):
        """One training step; returns the scalar loss to backprop."""
        inputs, targets = self._parse_data(inputs)
        feats, _, logits = self.model(inputs, targets)

        # CE averaged over the classification heads (per-term division
        # keeps the accumulation order of the original loop).
        loss_ce = sum(self.ce_loss(head, targets) / len(logits)
                      for head in logits)
        # Triplet loss averaged over the feature branches.
        loss_tr = sum(self.triplet_loss(branch, targets)[0] / len(feats)
                      for branch in feats)
        total_loss = loss_ce + loss_tr

        self.losses_ce.update(loss_ce.item())
        self.losses_tr.update(loss_tr.item())

        # Precision is reported from the first head only.
        prec, = accuracy(logits[0].data, targets.data)
        self.precisions.update(prec[0])

        return total_loss


# --------------------------------------------------------------------------
# /reid/trainer/snr_trainer.py
# --------------------------------------------------------------------------
import torch.distributed as dist
from torch.nn import CrossEntropyLoss

from reid.loss import TripletLoss
from reid.loss.dual_causality_loss import DualCausalityLoss
from reid.trainer.base_trainer import BaseTrainer
from reid.utils import accuracy
from reid.utils.meters import AverageMeter


class SNRTrainer(BaseTrainer):
    """Trainer for SNR models: CE + triplet plus a weighted dual-causality
    loss computed on the list of intermediate outputs returned by the model.
    """

    # Per-stage weights for the dual-causality terms (order matches the
    # `dual_list` the model returns).
    DUAL_WEIGHTS = (0.1, 0.1, 0.5, 0.5)

    def __init__(self, model, args):
        super(SNRTrainer, self).__init__(model, args)
        self.ce_loss = CrossEntropyLoss().cuda()
        self.triplet_loss = TripletLoss(margin=self.args.margin).cuda()
        self.dual_loss = DualCausalityLoss().cuda()

        self.losses_ce = AverageMeter()
        self.losses_tr = AverageMeter()
        self.losses_dual = AverageMeter()
        self.precisions = AverageMeter()

    def _logging(self, cur_iter):
        self._tensorboard_writer(cur_iter, data={
            'loss': self.losses_ce.val + self.losses_tr.val,
            'loss_ce': self.losses_ce.val,
            'loss_tr': self.losses_tr.val,
            'loss_dual': self.losses_dual.val,
            'prec': self.precisions.val
        })
        # Print only every `print_freq` iterations on rank 0.
        if cur_iter % self.args.print_freq != 0 or dist.get_rank() != 0:
            return
        print('Iter: [{}/{}]\t'
              'Time {:.3f} ({:.3f}) (ETA: {:.2f}h)\t'
              'Data {:.3f} ({:.3f})\t'
              'Loss_ce {:.3f} ({:.3f})\t'
              'Loss_tr {:.3f} ({:.3f})\t'
              'Loss_dual {:.3f} ({:.3f})\t'
              'Prec {:.2%} ({:.2%})'
              .format(cur_iter, self.args.iters,
                      self.batch_time.val, self.batch_time.avg,
                      (self.args.iters - cur_iter) * self.batch_time.avg / 3600,
                      self.data_time.val, self.data_time.avg,
                      self.losses_ce.val, self.losses_ce.avg,
                      self.losses_tr.val, self.losses_tr.avg,
                      self.losses_dual.val, self.losses_dual.avg,
                      self.precisions.val, self.precisions.avg))

    def _refresh_information(self, cur_iter, lr):
        # Guard-clause form of the original periodic reset.
        if not (cur_iter % self.args.refresh_freq == 0 or cur_iter == 1):
            return
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.losses_ce = AverageMeter()
        self.losses_tr = AverageMeter()
        self.losses_dual = AverageMeter()
        self.precisions = AverageMeter()
        if dist.get_rank() == 0:
            print("lr = {} \t".format(lr))

    def run(self, inputs):
        """One training step; returns the combined scalar loss."""
        inputs, targets = self._parse_data(inputs)
        feat, _, logits, dual_list = self.model(inputs, targets)

        loss_ce = self.ce_loss(logits, targets)
        loss_tr, _ = self.triplet_loss(feat, targets)
        # Weighted sum of the dual-causality terms, one per SNR stage.
        loss_dual = sum(w * self.dual_loss(item, targets)
                        for w, item in zip(self.DUAL_WEIGHTS, dual_list))

        loss = loss_ce + loss_tr + loss_dual

        self.losses_ce.update(loss_ce.item())
        self.losses_tr.update(loss_tr.item())
        self.losses_dual.update(loss_dual.item())

        prec, = accuracy(logits.data, targets.data)
        self.precisions.update(prec[0])

        return loss
# --------------------------------------------------------------------------
/reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | 23 | 24 | def accuracy(output, target, topk=(1,)): 25 | with torch.no_grad(): 26 | output, target = to_torch(output), to_torch(target) 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | 30 | _, pred = output.topk(maxk, 1, True, True) 31 | pred = pred.t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | 34 | ret = [] 35 | for k in topk: 36 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 37 | ret.append(correct_k.mul_(1. 
/ batch_size)) 38 | return ret 39 | -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/adamw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/adamw.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/adamw.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/adamw.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/comm_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/comm_.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/comm_.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/comm_.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/distributed_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/distributed_utils.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/distributed_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/distributed_utils.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/distributed_utils_pt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/distributed_utils_pt.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/logging.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/meters.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/meters.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/osutils.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/osutils.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/serialization.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/serialization.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/visualizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/visualizer.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/vit_rollout.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/vit_rollout.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/vit_rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwz-zju/Instruct-ReID/8250f44a301a50d8afcd2e09a46c3e96bf52090d/reid/utils/__pycache__/vit_rollout.cpython-38.pyc -------------------------------------------------------------------------------- /reid/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | # import linklink as link 6 | # try: 7 | # import spring.linklink as link 8 | # except: 9 | # import linklink as link 10 | 11 | def dist_linklink_init(): 12 | proc_id = int(os.environ['SLURM_PROCID']) 13 | ntasks = int(os.environ['SLURM_NTASKS']) 14 | node_list = os.environ['SLURM_NODELIST'] 15 | num_gpus = torch.cuda.device_count() 16 | torch.cuda.set_device(proc_id % num_gpus) 17 | link.initialize() 18 | world_size = link.get_world_size() 19 | rank = link.get_rank() 20 | 21 | return rank, world_size 22 | 23 | def dist_init(args): 24 | try: 25 | proc_id = int(os.environ['SLURM_PROCID']) 26 | ntasks = int(os.environ['SLURM_NTASKS']) 27 | node_list = os.environ['SLURM_NODELIST'] 28 | num_gpus = torch.cuda.device_count() 29 | torch.cuda.set_device(proc_id % num_gpus) 30 | 31 | if '[' in node_list: 32 | beg = node_list.find('[') 33 | pos1 = node_list.find('-', beg) 34 | if pos1 < 0: 35 | pos1 = 1000 36 | pos2 = node_list.find(',', beg) 37 | if pos2 < 0: 38 | pos2 = 1000 39 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 40 | addr = node_list[8:].replace('-', '.') 41 | print(addr) 42 | 43 | 
os.environ['MASTER_PORT'] = args.port 44 | os.environ['MASTER_ADDR'] = addr 45 | os.environ['WORLD_SIZE'] = str(ntasks) 46 | os.environ['RANK'] = str(proc_id) 47 | except BaseException: 48 | print("For debug....") 49 | num_gpus = torch.cuda.device_count() 50 | os.environ['MASTER_ADDR'] = 'localhost' 51 | os.environ['MASTER_PORT'] = str(args.port) 52 | torch.cuda.set_device(args.local_rank) 53 | 54 | dist.init_process_group(backend='nccl') 55 | print("Rank {}, World_size {}".format(dist.get_rank(), dist.get_world_size())) 56 | return num_gpus > 1 57 | -------------------------------------------------------------------------------- /reid/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | 7 | def swig_ptr_from_FloatTensor(x): 8 | assert x.is_contiguous() 9 | assert x.dtype == torch.float32 10 | return faiss.cast_integer_to_float_ptr( 11 | x.storage().data_ptr() + x.storage_offset() * 4) 12 | 13 | 14 | def swig_ptr_from_LongTensor(x): 15 | assert x.is_contiguous() 16 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 17 | return faiss.cast_integer_to_long_ptr( 18 | x.storage().data_ptr() + x.storage_offset() * 8) 19 | 20 | 21 | def search_index_pytorch(index, x, k, D=None, I=None): 22 | """call the search function of an index with pytorch tensor I/O (CPU 23 | and GPU supported)""" 24 | assert x.is_contiguous() 25 | n, d = x.size() 26 | assert d == index.d 27 | 28 | if D is None: 29 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 30 | else: 31 | assert D.size() == (n, k) 32 | 33 | if I is None: 34 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 35 | else: 36 | assert I.size() == (n, k) 37 | torch.cuda.synchronize() 38 | xptr = swig_ptr_from_FloatTensor(x) 39 | Iptr = swig_ptr_from_LongTensor(I) 40 | Dptr = swig_ptr_from_FloatTensor(D) 41 | index.search_c(n, xptr, 42 | k, Dptr, Iptr) 43 | 
torch.cuda.synchronize() 44 | return D, I 45 | 46 | 47 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 48 | metric=faiss.METRIC_L2): 49 | assert xb.device == xq.device 50 | 51 | nq, d = xq.size() 52 | if xq.is_contiguous(): 53 | xq_row_major = True 54 | elif xq.t().is_contiguous(): 55 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 56 | xq_row_major = False 57 | else: 58 | raise TypeError('matrix should be row or column-major') 59 | 60 | xq_ptr = swig_ptr_from_FloatTensor(xq) 61 | 62 | nb, d2 = xb.size() 63 | assert d2 == d 64 | if xb.is_contiguous(): 65 | xb_row_major = True 66 | elif xb.t().is_contiguous(): 67 | xb = xb.t() 68 | xb_row_major = False 69 | else: 70 | raise TypeError('matrix should be row or column-major') 71 | xb_ptr = swig_ptr_from_FloatTensor(xb) 72 | 73 | if D is None: 74 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 75 | else: 76 | assert D.shape == (nq, k) 77 | assert D.device == xb.device 78 | 79 | if I is None: 80 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 81 | else: 82 | assert I.shape == (nq, k) 83 | assert I.device == xb.device 84 | 85 | D_ptr = swig_ptr_from_FloatTensor(D) 86 | I_ptr = swig_ptr_from_LongTensor(I) 87 | 88 | faiss.bruteForceKnn(res, metric, 89 | xb_ptr, xb_row_major, nb, 90 | xq_ptr, xq_row_major, nq, 91 | d, k, D_ptr, I_ptr) 92 | 93 | return D, I 94 | 95 | 96 | def index_init_gpu(ngpus, feat_dim): 97 | flat_config = [] 98 | for i in range(ngpus): 99 | cfg = faiss.GpuIndexFlatConfig() 100 | cfg.useFloat16 = False 101 | cfg.device = i 102 | flat_config.append(cfg) 103 | 104 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 105 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 106 | index = faiss.IndexShards(feat_dim) 107 | for sub_index in indexes: 108 | index.add_shard(sub_index) 109 | index.reset() 110 | return index 111 | 112 | 113 | def index_init_cpu(feat_dim): 114 | return 
faiss.IndexFlatL2(feat_dim) 115 | -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from bisect import bisect_right 3 | import torch 4 | from torch.optim.lr_scheduler import * 5 | import math 6 | 7 | 8 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 9 | def __init__( 10 | self, 11 | optimizer, 12 | milestones, 13 | gamma=0.1, 14 | warmup_factor=1.0 / 3, 15 | warmup_iters=500, 16 | warmup_method="linear", 17 | last_epoch=-1, 18 | ): 19 | if not list(milestones) == sorted(milestones): 20 | raise ValueError( 21 | "Milestones should be a list of" " increasing integers. 
Got {}", 22 | milestones, 23 | ) 24 | 25 | if warmup_method not in ("constant", "linear"): 26 | raise ValueError( 27 | "Only 'constant' or 'linear' warmup_method accepted" 28 | "got {}".format(warmup_method) 29 | ) 30 | self.milestones = milestones 31 | self.gamma = gamma 32 | self.warmup_factor = warmup_factor 33 | self.warmup_iters = warmup_iters 34 | self.warmup_method = warmup_method 35 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | warmup_factor = 1 39 | if self.last_epoch < self.warmup_iters: 40 | if self.warmup_method == "constant": 41 | warmup_factor = self.warmup_factor 42 | elif self.warmup_method == "linear": 43 | alpha = float(self.last_epoch) / float(self.warmup_iters) 44 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 45 | return [ 46 | base_lr 47 | * warmup_factor 48 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 49 | for base_lr in self.base_lrs 50 | ] 51 | 52 | 53 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 54 | def __init__(self, optimizer, max_iters, warmup_iters, 55 | warmup_factor=1e-2, warmup_method="linear", last_epoch=-1): 56 | 57 | if warmup_method not in ("constant", "linear"): 58 | raise ValueError(f"Only 'constant' or 'linear' warmup_method accepted. 
Got {warmup_method}") 59 | 60 | self.max_iters = max_iters 61 | self.warmup_factor = warmup_factor 62 | self.warmup_iters = warmup_iters 63 | self.warmup_method = warmup_method 64 | super(WarmupCosineLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | warmup_factor = 1 68 | if self.last_epoch < self.warmup_iters: 69 | if self.warmup_method == "constant": 70 | warmup_factor = self.warmup_factor 71 | elif self.warmup_method == "linear": 72 | alpha = float(self.last_epoch) / self.warmup_iters 73 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 74 | return [ 75 | warmup_factor * base_lr * (1 + math.cos(math.pi * self.last_epoch / self.max_iters)) / 2 76 | for base_lr in self.base_lrs 77 | ] -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- 
/reid/utils/vit_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy 4 | import sys 5 | from torchvision import transforms 6 | import numpy as np 7 | import cv2 8 | 9 | def rollout(attentions, discard_ratio, head_fusion): 10 | print('attentions[0].shape:') 11 | for atten in attentions: 12 | print(atten.shape) 13 | result = torch.eye(attentions[0].size(-1)) 14 | with torch.no_grad(): 15 | for attention in attentions: 16 | if head_fusion == "mean": 17 | attention_heads_fused = attention.mean(axis=1) 18 | elif head_fusion == "max": 19 | attention_heads_fused = attention.max(axis=1)[0] 20 | elif head_fusion == "min": 21 | attention_heads_fused = attention.min(axis=1)[0] 22 | else: 23 | raise "Attention head fusion type Not supported" 24 | 25 | # Drop the lowest attentions, but 26 | # don't drop the class token 27 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 28 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 29 | indices = indices[indices != 0] 30 | flat[0, indices] = 0 31 | 32 | I = torch.eye(attention_heads_fused.size(-1)) 33 | a = (attention_heads_fused + 1.0*I)/2 34 | print('a.shape:',a.shape) 35 | a = a / a.sum(dim=-1,keepdim=True) 36 | 37 | result = torch.matmul(a, result) 38 | 39 | # Look at the total attention between the class token, 40 | # and the image patches 41 | mask = result[0, 0 , 1 :] 42 | # In case of 224x224 image, this brings us from 196 to 14 43 | # width = int(mask.size(-1)**0.5) 44 | print('mask shape:',mask.shape) 45 | height = 8 46 | width = 8 47 | mask = mask.reshape(height, width).numpy() 48 | mask = mask / np.max(mask) 49 | return mask 50 | 51 | class VITAttentionRollout: 52 | def __init__(self, model,test_feat_type, attention_layer_name='attn_drop', head_fusion="mean", 53 | discard_ratio=0.9): 54 | self.model = model 55 | self.head_fusion = head_fusion 56 | self.discard_ratio = discard_ratio 57 | 
for name, module in self.model.named_modules(): 58 | if test_feat_type=='b': 59 | if attention_layer_name in name and 'clothes' not in name and 'fusion' not in name and 'b2' not in name: 60 | module.register_forward_hook(self.get_attention) 61 | else: 62 | if attention_layer_name in name and 'clothes' in name and 'b2' not in name and 'fusion' not in name: 63 | module.register_forward_hook(self.get_attention) 64 | 65 | self.attentions = [] 66 | 67 | def get_attention(self, module, input, output): 68 | self.attentions.append(output.cpu()) 69 | 70 | def __call__(self, input_tensor,clothes_tensor): 71 | self.attentions = [] 72 | with torch.no_grad(): 73 | output = self.model(input_tensor,clothes_tensor) 74 | 75 | return rollout(self.attentions, self.discard_ratio, self.head_fusion) 76 | 77 | 78 | def show_mask_on_image(img, mask): 79 | img = np.float32(img) / 255 80 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 81 | heatmap = np.float32(heatmap) / 255 82 | # cam=np.float32(img) 83 | cam = heatmap + np.float32(img) 84 | cam = cam / np.max(cam) 85 | return np.uint8(255 * cam) 86 | -------------------------------------------------------------------------------- /scripts/config_ablation5.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | # transformer setting 3 | DROP_PATH: 0.0 4 | DROP_OUT: 0.0 5 | ATT_DROP_RATE: 0.0 6 | DROP_PATH_CLO: 0.0 7 | DROP_OUT_CLO: 0.0 8 | ATT_DROP_RATE_CLO: 0.0 9 | DROP_PATH_FUSION: 0.0 10 | DROP_OUT_FUSION: 0.0 11 | ATT_DROP_RATE_FUSION: 0.0 12 | STRIDE_SIZE: [16, 16] 13 | JPM: True 14 | SHIFT_NUM: 5 15 | SHUFFLE_GROUP: 2 16 | DEVIDE_LENGTH: 4 17 | RE_ARRANGE: True 18 | SIE_CAMERA: False 19 | SIE_VIEW: False 20 | FUSION_DEPTH: 3 21 | CLOTHES_AVGPOOL: True 22 | BIO_EMBEDDING: 768 23 | CLO_EMBEDDING: 768 24 | FUSION_EMBEDDING: 384 25 | PRETRAIN: True 26 | PATCH_SIZE_CLO: 16 27 | STRIDE_SIZE_CLO: [16,16] 28 | 29 | # structure setting 30 | NECK_FEAT: after 31 | 
ID_LOSS_TYPE: Linear 32 | # if cosine 33 | SCALE: 30.0 34 | MARGIN: 0.40 35 | 36 | # training tricks setting 37 | LABEL_SMOOTH: False 38 | CENTER_LOSS: False 39 | 40 | # circle loss 41 | CIRCLE_LOSS: False 42 | GAMMA: 256 43 | C_MARGIN: 0.25 44 | W: 1.0 45 | metric: cosine 46 | -------------------------------------------------------------------------------- /scripts/config_attr.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_attr 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/real1 8 | train_file_path: + Instruct-ReID/data/real1/datalist/train_attr.txt -------------------------------------------------------------------------------- /scripts/config_ctcc.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_ctcc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/real1 8 | train_file_path: + Instruct-ReID/data/real1/datalist/train_ctcc.txt -------------------------------------------------------------------------------- /scripts/config_cuhk.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_sc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/cuhk 8 | train_file_path: + Instruct-ReID/data/cuhk/datalist/train.txt -------------------------------------------------------------------------------- /scripts/config_cuhk_pedes.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_t2i 6 | attt_file: + Instruct-ReID/data/cuhk_pedes/caption_t2i_v2.json 7 | root_path: + Instruct-ReID/data/cuhk_pedes 8 | train_file_path: + Instruct-ReID/data/cuhk_pedes/train_t2i_v2.txt 
-------------------------------------------------------------------------------- /scripts/config_joint.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_ctcc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/real1 8 | train_file_path: + Instruct-ReID/data/real1/datalist/train_ctcc.txt 9 | 1: 10 | loss_weight: 1.0 11 | gres_ratio: 1 12 | task_name: DataBuilder_sc 13 | # ceph or cluster 14 | root_path: + Instruct-ReID/data/market 15 | train_file_path: + Instruct-ReID/data/market/datalist/train.txt 16 | 2: 17 | loss_weight: 1.0 18 | gres_ratio: 1 19 | task_name: DataBuilder_sc 20 | # ceph or cluster 21 | root_path: + Instruct-ReID/data/msmt 22 | train_file_path: + Instruct-ReID/data/msmt/datalist/train.txt 23 | 3: 24 | loss_weight: 1.0 25 | gres_ratio: 1 26 | task_name: DataBuilder_sc 27 | # ceph or cluster 28 | root_path: + Instruct-ReID/data/cuhk 29 | train_file_path: + Instruct-ReID/data/cuhk/datalist/train.txt 30 | 4: 31 | loss_weight: 1.0 32 | gres_ratio: 1 33 | task_name: DataBuilder_cc 34 | # ceph or cluster 35 | root_path: + Instruct-ReID/data/ltcc 36 | train_file_path: + Instruct-ReID/data/ltcc/datalist/train.txt 37 | 5: 38 | loss_weight: 1.0 39 | gres_ratio: 1 40 | task_name: DataBuilder_cross 41 | root_path: + Instruct-ReID/data/llcm 42 | train_file_path: + Instruct-ReID/data/llcm/train.txt 43 | 6: 44 | loss_weight: 1.0 45 | gres_ratio: 1 46 | task_name: DataBuilder_t2i 47 | attt_file: + Instruct-ReID/data/cuhk_pedes/caption_t2i_v2.json 48 | root_path: + Instruct-ReID/data/cuhk_pedes 49 | train_file_path: + Instruct-ReID/data/cuhk_pedes/train_t2i_v2.txt 50 | 7: 51 | loss_weight: 1.0 52 | gres_ratio: 1 53 | task_name: DataBuilder_attr 54 | # ceph or cluster 55 | root_path: + Instruct-ReID/data/real1 56 | train_file_path: + Instruct-ReID/data/real1/datalist/train_attr.txt 57 | 8: 58 | loss_weight: 1.0 59 | gres_ratio: 1 60 | 
task_name: DataBuilder_cc 61 | # ceph or cluster 62 | root_path: + Instruct-ReID/data/prcc 63 | train_file_path: + Instruct-ReID/data/prcc/datalist/train.txt 64 | 9: 65 | loss_weight: 1.0 66 | gres_ratio: 1 67 | task_name: DataBuilder_cc 68 | # ceph or cluster 69 | root_path: + Instruct-ReID/data/vc_clothes 70 | train_file_path: + Instruct-ReID/data/vc_clothes/datalist/train_cc_clo.txt 71 | 10: 72 | loss_weight: 1.0 73 | gres_ratio: 6 74 | task_name: DataBuilder_t2i 75 | # ceph or cluster 76 | root_path: + Instruct-ReID/data/PLIP 77 | train_file_path: + Instruct-ReID/data/PLIP/train_t2i.txt -------------------------------------------------------------------------------- /scripts/config_llcm.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_cross 6 | root_path: + Instruct-ReID/data/llcm 7 | train_file_path: + Instruct-ReID/data/llcm/train.txt -------------------------------------------------------------------------------- /scripts/config_ltcc.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_cc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/ltcc 8 | train_file_path: + Instruct-ReID/data/ltcc/datalist/train.txt -------------------------------------------------------------------------------- /scripts/config_market.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_sc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/market 8 | train_file_path: + Instruct-ReID/data/market/datalist/train.txt -------------------------------------------------------------------------------- /scripts/config_msmt.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 
| loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_sc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/msmt 8 | train_file_path: + Instruct-ReID/data/msmt/datalist/train.txt 9 | -------------------------------------------------------------------------------- /scripts/config_prcc.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_cc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/prcc 8 | train_file_path: + Instruct-ReID/data/prcc/datalist/train.txt -------------------------------------------------------------------------------- /scripts/config_vc.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | 0: 3 | loss_weight: 1.0 4 | gres_ratio: 1 5 | task_name: DataBuilder_cc 6 | # ceph or cluster 7 | root_path: + Instruct-ReID/data/vc_clothes 8 | train_file_path: + Instruct-ReID/data/vc_clothes/datalist/train_cc_clo.txt -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ARCH=$1 4 | export PATH=~/.local/bin/:$PATH 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -u examples/test_joint.py -a ${ARCH} --resume $2 --patch_size_clo 16 --stride_size_clo 16 --patch_size_bio 16 --stride_size_bio 16 -t $3 \ 7 | --query-list $4 \ 8 | --gallery-list $5 \ 9 | --validate_feat fusion --config ./scripts/config_ablation5.yaml \ 10 | --attn_type dual_attn --fusion_loss all --fusion_branch bio+clot --vit_type base --vit_fusion_layer 2 --test_feat_type f \ 11 | --root $6 12 | --------------------------------------------------------------------------------