├── CVPR2024 Harnessing the Power of MLLMs for Transferable Text-to-Image Person ReID.pdf ├── MLLM4Text-reid-sup.pdf ├── README.md ├── captions └── captioner.py ├── data └── bpe_simple_vocab_16e6.txt.gz ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── bases.cpython-310.pyc │ ├── bases.cpython-38.pyc │ ├── bases.cpython-39.pyc │ ├── build.cpython-310.pyc │ ├── build.cpython-38.pyc │ ├── build.cpython-39.pyc │ ├── cap2img.cpython-310.pyc │ ├── cap2img.cpython-38.pyc │ ├── cap2img.cpython-39.pyc │ ├── cuhkpedes.cpython-310.pyc │ ├── cuhkpedes.cpython-38.pyc │ ├── cuhkpedes.cpython-39.pyc │ ├── icfgpedes.cpython-310.pyc │ ├── icfgpedes.cpython-38.pyc │ ├── icfgpedes.cpython-39.pyc │ ├── luperson.cpython-310.pyc │ ├── luperson.cpython-38.pyc │ ├── luperson.cpython-39.pyc │ ├── luperson_att.cpython-310.pyc │ ├── luperson_att.cpython-38.pyc │ ├── luperson_att.cpython-39.pyc │ ├── mals.cpython-310.pyc │ ├── mals.cpython-38.pyc │ ├── mals.cpython-39.pyc │ ├── plip.cpython-38.pyc │ ├── rstpreid.cpython-310.pyc │ ├── rstpreid.cpython-38.pyc │ ├── rstpreid.cpython-39.pyc │ ├── sampler.cpython-310.pyc │ ├── sampler.cpython-38.pyc │ ├── sampler.cpython-39.pyc │ ├── sampler_ddp.cpython-310.pyc │ ├── sampler_ddp.cpython-38.pyc │ └── sampler_ddp.cpython-39.pyc ├── bases.py ├── build.py ├── cuhkpedes.py ├── icfgpedes.py ├── luperson.py ├── preprocessing.py ├── rstpreid.py ├── sampler.py └── sampler_ddp.py ├── figures ├── example.png └── framework.png ├── finetune.py ├── finetune.sh ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── build.cpython-310.pyc │ ├── build.cpython-38.pyc │ ├── build.cpython-39.pyc │ ├── build_finetune.cpython-38.pyc │ ├── build_finetune.cpython-39.pyc │ ├── clip_model.cpython-310.pyc │ ├── clip_model.cpython-38.pyc │ ├── clip_model.cpython-39.pyc │ ├── memory.cpython-38.pyc │ ├── objectives.cpython-310.pyc │ ├── objectives.cpython-38.pyc │ ├── objectives.cpython-39.pyc │ ├── style.cpython-310.pyc │ ├── style.cpython-38.pyc │ └── style.cpython-39.pyc ├── build.py ├── build_finetune.py ├── clip_model.py ├── memory.py ├── objectives.py └── style.py ├── processor ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── processor.cpython-310.pyc │ ├── processor.cpython-38.pyc │ ├── processor.cpython-39.pyc │ └── processor_finetune.cpython-39.pyc ├── processor.py └── processor_finetune.py ├── run.sh ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── build.cpython-310.pyc │ ├── build.cpython-38.pyc │ ├── build.cpython-39.pyc │ ├── lr_scheduler.cpython-310.pyc │ ├── lr_scheduler.cpython-38.pyc │ └── lr_scheduler.cpython-39.pyc ├── build.py └── lr_scheduler.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-38.pyc ├── __init__.cpython-39.pyc ├── checkpoint.cpython-310.pyc ├── checkpoint.cpython-38.pyc ├── checkpoint.cpython-39.pyc ├── comm.cpython-310.pyc ├── comm.cpython-38.pyc ├── comm.cpython-39.pyc ├── iotools.cpython-310.pyc ├── iotools.cpython-38.pyc ├── iotools.cpython-39.pyc ├── logger.cpython-310.pyc ├── logger.cpython-38.pyc ├── logger.cpython-39.pyc ├── meter.cpython-310.pyc ├── meter.cpython-38.pyc ├── meter.cpython-39.pyc ├── metrics.cpython-310.pyc ├── 
metrics.cpython-38.pyc ├── metrics.cpython-39.pyc ├── options.cpython-310.pyc ├── options.cpython-38.pyc ├── options.cpython-39.pyc ├── simple_tokenizer.cpython-310.pyc ├── simple_tokenizer.cpython-38.pyc └── simple_tokenizer.cpython-39.pyc ├── checkpoint.py ├── comm.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py ├── options.py └── simple_tokenizer.py /CVPR2024 Harnessing the Power of MLLMs for Transferable Text-to-Image Person ReID.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/CVPR2024 Harnessing the Power of MLLMs for Transferable Text-to-Image Person ReID.pdf -------------------------------------------------------------------------------- /MLLM4Text-reid-sup.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/MLLM4Text-reid-sup.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Harnessing the Power of MLLMs for Transferable Text-to-Image Person ReID (CVPR 2024) 2 | 3 | 5 | 6 | ![](figures/framework.png) 7 | 8 | ### Requirements 9 | ``` 10 | pytorch 1.9.0 11 | torchvision 0.10.0 12 | prettytable 13 | easydict 14 | ``` 15 | 16 | ### 1. Construct LUPerson-MLLM 17 | - Download the LUPerson images from [here](https://github.com/DengpanFu/LUPerson). 18 | - Use MLLMs to annotate the LUPerson images. Take [Qwen](https://github.com/QwenLM/Qwen-VL) as an example: the code for image captioning is provided in the ```captions``` folder, where you will find 46 templates along with the static and dynamic instructions. You can also download all the descriptions for the final LUPerson-MLLM from [here](https://huggingface.co/datasets/TwT-6/MLLM4Text-ReID). 19 | - Place the generated descriptions in the ```captions``` folder. 20 | 21 | ### 2. Prepare Downstream Datasets 22 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description), the ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN), and the RSTPReid dataset from [here](https://github.com/NjtechCVLab/RSTPReid-Dataset). 23 | 24 | ### 3. Pretrain Model (direct transfer setting) 25 | To pretrain your model, simply run ```sh run.sh```. After training is completed, it will report the performance under the direct transfer setting. 26 | 27 | ### 4. Fine-tune the Pretrained Model on Downstream Datasets (fine-tune setting) 28 | We release the pretrained model checkpoints [here](https://huggingface.co/datasets/TwT-6/MLLM4Text-ReID). \ 29 | To fine-tune your model, simply run ```sh finetune.sh --finetune checkpoint.pth```. After training is completed, it will report the performance under the fine-tune setting. 30 | 31 | ### Acknowledgments 32 | This repo borrows partially from [IRRA](https://github.com/anosorae/IRRA). 33 | 34 | ### Citation 35 | ``` 36 | @inproceedings{tan2024harnessing, 37 | title={Harnessing the Power of MLLMs for Transferable Text-to-Image Person ReID}, 38 | author={Wentao Tan and Changxing Ding and Jiayu Jiang and Fei Wang and Yibing Zhan and Dapeng Tao}, 39 | booktitle={CVPR}, 40 | year={2024}, 41 | } 42 | ``` 43 | ### Contact 44 | Email: ftwentaotan@mail.scut.edu.cn or 731584671@qq.com 45 | 46 | Of course, if you can, I would still prefer that you contact me in Chinese!
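### Batch Captioning Sketch (for reference)
The script in ```captions/captioner.py``` captions a single example image. To build LUPerson-MLLM you need one description per image per instruction, saved as a JSON file that maps each image name to its caption, which is the format ```datasets/luperson.py``` reads (e.g. ```Ts-qwen.json``` for the static instruction and ```Td-qwen.json``` for the dynamic one). The sketch below is not part of the original code: it reuses the ```model```, ```tokenizer``` and ```text``` prompt set up in ```captions/captioner.py```, and the ```image_dir``` / ```out_file``` paths are placeholders you should adapt.
```python
# Minimal sketch (assumed paths/names): caption every LUPerson image with Qwen-VL-Chat
# and dump {image_name: caption} to a JSON file in the format datasets/luperson.py expects.
import os
import json

image_dir = '/path/to/LUPerson_images'   # placeholder: your local LUPerson image folder
out_file = './captions/Ts-qwen.json'     # static-instruction captions (use Td-*.json for dynamic ones)

captions = {}
for name in sorted(os.listdir(image_dir)):
    if not name.endswith('.jpg'):
        continue
    # build the query exactly as in captions/captioner.py, but for the current image
    query = tokenizer.from_list_format([
        {'image': os.path.join(image_dir, name)},
        {'text': text},
    ])
    caption, _ = model.chat(tokenizer, query=query, history=None)
    captions[name] = caption

with open(out_file, 'w') as f:
    json.dump(captions, f)
```
Running this once per MLLM and per instruction yields the four caption files (Ts/Td from Qwen and Shikra) that ```datasets/luperson.py``` merges.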
47 | -------------------------------------------------------------------------------- /captions/captioner.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from transformers.generation import GenerationConfig 3 | import torch 4 | torch.manual_seed(1234) 5 | 6 | model_path = 'Your_download_path' 7 | # Note: The default behavior now has injection attack prevention off. 8 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 9 | 10 | # use bf16 11 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval() 12 | # use fp16 13 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval() 14 | # use cpu only 15 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cpu", trust_remote_code=True).eval() 16 | # use cuda device 17 | model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda", trust_remote_code=True).eval() 18 | 19 | # Specify hyperparameters for generation 20 | model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True) 21 | 22 | templates = [ 23 | "Wearing [clothing description], the [person/woman/man] also has [hair description] and is carrying [belongings description].", 24 | "Sporting [hair description], the [person/woman/man] is dressed in [clothing description] and is carrying [belongings description].", 25 | "With [hair description], the [person/woman/man] is wearing [clothing description] and is also carrying [belongings description].", 26 | "In [clothing description] and [footwear description], the [person/woman/man] is also carrying [belongings description].", 27 | "With [hair description], the [person/woman/man] is wearing [clothing description] and is also carrying [belongings description].", 28 | "Carrying [belongings description], the [person/woman/man] is dressed in [clothing description] and [footwear description].", 29 | "In [clothing description] and [footwear description], the [person/woman/man] also has [hair description].", 30 | "Carrying [belongings description], the [person/woman/man] is wearing [clothing description] and [footwear description].", 31 | "In [clothing description] and [accessory description], the [person/woman/man] is also carrying [belongings description].", 32 | "With [hair description], the [person/woman/man] is dressed in [clothing description] and [accessory description].", 33 | "Sporting [hair description], the [person/woman/man] is wearing [clothing description] with [accessory description].", 34 | "With [footwear description], the [person/woman/man] is wearing [clothing description] and [accessory description].", 35 | "With [hair description], the [person/woman/man] is wearing [clothing description] with [accessory description].", 36 | "In [clothing description] and [accessory description], the [person/woman/man] also has [hair description].", 37 | "In [accessory description], the [person/woman/man] also has [hair description] and is carrying [belongings description].", 38 | "With [accessory description], the [person/woman/man] also has [hair description] and is carrying [belongings description].", 39 | "Wearing [clothing description] and [footwear description], the [person/woman/man] also has [hair description].", 40 | "The [person/woman/man] is wearing [footwear description], [accessory 
description], [clothing description], and [belongings description]. The [person/woman/man] has [hair description].", 41 | "The [person/woman/man] has [hair description] and is wearing [accessory description], [footwear description], [clothing description], and carrying [belongings description].", 42 | "The [person/woman/man] is dressed in [footwear description], [clothing description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 43 | "With [footwear description], the [person/woman/man] is wearing [clothing description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 44 | "The [person/woman/man] sports [hair description] and is dressed in [footwear description], [clothing description], [accessory description], and carrying [belongings description].", 45 | "Wearing [footwear description], [accessory description], [clothing description], the [person/woman/man] is also carrying [belongings description]. The [person/woman/man] has [hair description].", 46 | "The [person/woman/man] is attired in [clothing description], [accessory description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 47 | "The [person/woman/man] is seen wearing [footwear description], [clothing description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 48 | "With [hair description], the [person/woman/man] is wearing [footwear description], [clothing description], [accessory description], and carrying [belongings description].", 49 | "Dressed in [footwear description], [accessory description], [clothing description], and carrying [belongings description], the [person/woman/man] has [hair description].", 50 | "The [person/woman/man] can be seen wearing [footwear description], [clothing description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 51 | "The [person/woman/man] is dressed in [clothing description], [footwear description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 52 | "The [person/woman/man] is wearing [footwear description], [accessory description], [clothing description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 53 | "The [person/woman/man] is attired in [accessory description], [footwear description], [clothing description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 54 | "The [person/woman/man] has [hair description] and is wearing [clothing description], [footwear description], [accessory description], and carrying [belongings description].", 55 | "In [accessory description], [footwear description], [clothing description], and carrying [belongings description], the [person/woman/man] has [hair description].", 56 | "The [person/woman/man] is seen wearing [clothing description], [footwear description], [accessory description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 57 | "The [person/woman/man] is wearing [accessory description], [footwear description], [clothing description], and carrying [belongings description]. 
The [person/woman/man] has [hair description].", 58 | "Sporting [hair description], the [person/woman/man] is wearing [footwear description], [clothing description], [accessory description], and carrying [belongings description].", 59 | "The [person/woman/man] is seen in [footwear description], [accessory description], [clothing description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 60 | "The [person/woman/man] can be spotted wearing [accessory description], [footwear description], [clothing description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 61 | "The [person/woman/man] has [hair description] and is dressed in [accessory description], [footwear description], [clothing description], and carrying [belongings description].", 62 | "The [person/woman/man] is attired in [accessory description], [clothing description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 63 | "The [person/woman/man] is wearing [accessory description], [clothing description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 64 | "With [hair description], the [person/woman/man] is wearing [accessory description], [clothing description], [footwear description], and carrying [belongings description].", 65 | "Dressed in [accessory description], [clothing description], [footwear description], and carrying [belongings description], the [person/woman/man] has [hair description].", 66 | "The [person/woman/man] can be seen wearing [accessory description], [clothing description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 67 | "The [person/woman/man] is dressed in [clothing description], [accessory description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description].", 68 | "The [person/woman/man] is wearing [clothing description], [accessory description], [footwear description], and carrying [belongings description]. The [person/woman/man] has [hair description]." 69 | ] 70 | att = ['clothing','shoes','hairstyle','gender','belongings'] 71 | 72 | text = f'Write a description about the overall appearance of the person in the image, including the attributions: {att[0]}, {att[1]}, {att[2]}, {att[3]} and {att[4]}. If any attribute is not visible, you can ignore. Do not imagine any contents that are not in the image.'  # static instruction; note it is overwritten by the dynamic instruction on the next line, keep whichever you need 73 | text = f'Generate a description about the overall appearance of the person, including the {att[0]}, {att[1]}, {att[2]}, {att[3]} and {att[4]}, in a style similar to the template:"{templates[0]}". If some requirements in the template are not visible, you can ignore. Do not imagine any contents that are not in the image.'  # dynamic instruction; templates[0] is only an example, in practice a different template can be sampled for each image
74 | query = tokenizer.from_list_format([ 75 | {'image': './figures/framework.png'}, # Either a local path or an url 76 | {'text': text},] 77 | ) 78 | caption, history = model.chat(tokenizer, query=query, history=None) 79 | print(query) -------------------------------------------------------------------------------- /data/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/data/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_dataloader -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/bases.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/bases.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/bases.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/build.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/build.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cap2img.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cap2img.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cap2img.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cap2img.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cap2img.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cap2img.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhkpedes.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cuhkpedes.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhkpedes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cuhkpedes.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhkpedes.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/cuhkpedes.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/icfgpedes.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/icfgpedes.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/icfgpedes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/icfgpedes.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/icfgpedes.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/icfgpedes.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson_att.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson_att.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson_att.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson_att.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/luperson_att.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/luperson_att.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/mals.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/mals.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/mals.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/mals.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/mals.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/mals.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/plip.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/plip.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/rstpreid.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/rstpreid.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/rstpreid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/rstpreid.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/rstpreid.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/rstpreid.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler_ddp.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler_ddp.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/datasets/__pycache__/sampler_ddp.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | from typing import List 
2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import logging 6 | import torch 7 | from utils.iotools import read_image 8 | from utils.simple_tokenizer import SimpleTokenizer 9 | from prettytable import PrettyTable 10 | import random 11 | import regex as re 12 | import copy 13 | 14 | 15 | class BaseDataset(object): 16 | """ 17 | Base class of text to image reid dataset 18 | """ 19 | logger = logging.getLogger("IRRA.dataset") 20 | 21 | def show_dataset_info(self): 22 | num_train_pids, num_train_imgs, num_train_captions = len( 23 | self.train_id_container), len(self.train_annos), len(self.train) 24 | num_test_pids, num_test_imgs, num_test_captions = len( 25 | self.test_id_container), len(self.test_annos), len( 26 | self.test['captions']) 27 | num_val_pids, num_val_imgs, num_val_captions = len( 28 | self.val_id_container), len(self.val_annos), len( 29 | self.val['captions']) 30 | 31 | # TODO use prettytable print comand line table 32 | 33 | self.logger.info(f"{self.__class__.__name__} Dataset statistics:") 34 | table = PrettyTable(['subset', 'ids', 'images', 'captions']) 35 | table.add_row( 36 | ['train', num_train_pids, num_train_imgs, num_train_captions]) 37 | table.add_row( 38 | ['test', num_test_pids, num_test_imgs, num_test_captions]) 39 | table.add_row(['val', num_val_pids, num_val_imgs, num_val_captions]) 40 | self.logger.info('\n' + str(table)) 41 | 42 | 43 | def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor: 44 | sot_token = tokenizer.encoder["<|startoftext|>"] 45 | eot_token = tokenizer.encoder["<|endoftext|>"] 46 | tokens = [sot_token] + tokenizer.encode(caption) + [eot_token] 47 | 48 | result = torch.zeros(text_length, dtype=torch.long) 49 | if len(tokens) > text_length: 50 | if truncate: 51 | tokens = tokens[:text_length] 52 | tokens[-1] = eot_token 53 | else: 54 | raise RuntimeError( 55 | f"Input {caption} is too long for context length {text_length}" 56 | ) 57 | result[:len(tokens)] = torch.tensor(tokens) 58 | return result 59 | 60 | 61 | class ImageTextDataset(Dataset): 62 | def __init__(self, 63 | dataset, 64 | transform=None, 65 | text_length: int = 77, 66 | truncate: bool = True): 67 | self.dataset = dataset 68 | self.transform = transform 69 | self.text_length = text_length 70 | self.truncate = truncate 71 | self.tokenizer = SimpleTokenizer() 72 | 73 | def __len__(self): 74 | return len(self.dataset) 75 | 76 | def __getitem__(self, index): 77 | pid, image_id, img_path, caption = self.dataset[index] 78 | img = read_image(img_path) 79 | if self.transform is not None: 80 | img = self.transform(img) 81 | 82 | tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate) 83 | 84 | ret = { 85 | 'img_path':img_path, 86 | 'caption':caption, 87 | 'pids': pid, 88 | 'image_ids': image_id, 89 | 'images': img, 90 | 'caption_ids': tokens, 91 | } 92 | 93 | return ret 94 | 95 | 96 | class ImageDataset(Dataset): 97 | def __init__(self, image_pids, img_paths, transform=None): 98 | self.image_pids = image_pids 99 | self.img_paths = img_paths 100 | self.transform = transform 101 | 102 | def __len__(self): 103 | return len(self.image_pids) 104 | 105 | def __getitem__(self, index): 106 | pid, img_path = self.image_pids[index], self.img_paths[index] 107 | img = read_image(img_path) 108 | if self.transform is not None: 109 | img = self.transform(img) 110 | return pid, img 111 | 112 | 113 | class TextDataset(Dataset): 114 | def __init__(self, 115 | caption_pids, 116 | 
captions, 117 | text_length: int = 77, 118 | truncate: bool = True): 119 | self.caption_pids = caption_pids 120 | self.captions = captions 121 | self.text_length = text_length 122 | self.truncate = truncate 123 | self.tokenizer = SimpleTokenizer() 124 | 125 | def __len__(self): 126 | return len(self.caption_pids) 127 | 128 | def __getitem__(self, index): 129 | pid, caption = self.caption_pids[index], self.captions[index] 130 | 131 | caption = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate) 132 | 133 | return pid, caption 134 | 135 | def softmax(x): 136 | """Compute softmax values for each sets of scores in x.""" 137 | e_x = np.exp(x - np.max(x)) 138 | return e_x / e_x.sum() 139 | 140 | class ImageTextMLMDataset(Dataset): 141 | def __init__(self, 142 | dataset, 143 | transform=None, 144 | text_length: int = 77, 145 | truncate: bool = True): 146 | self.dataset = dataset 147 | self.transform = transform 148 | self.text_length = text_length 149 | self.truncate = truncate 150 | 151 | self.tokenizer = SimpleTokenizer() 152 | 153 | 154 | def __len__(self): 155 | return len(self.dataset) 156 | 157 | def __getitem__(self, index): 158 | pid, image_id, img_path, caption, sim = self.dataset[index] 159 | img = read_image(img_path) 160 | if self.transform is not None: 161 | img = self.transform(img) 162 | 163 | caption_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate) 164 | mlm_tokens, mlm_labels = self._build_random_masked_tokens_and_labels(caption_tokens.cpu().numpy()) 165 | ret = { 166 | 'pids': pid, 167 | 'image_ids': image_id, 168 | 'images': img, 169 | 'caption_ids': caption_tokens, 170 | 'mlm_ids': mlm_tokens, 171 | 'mlm_labels': mlm_labels, 172 | } 173 | 174 | return ret 175 | 176 | def _build_random_masked_tokens_and_labels(self, tokens): 177 | """ 178 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper. 179 | :param tokens: list of int, tokenized sentence. 
180 | :return: (list of int, list of int), masked tokens and related labels for MLM prediction 181 | """ 182 | mask = self.tokenizer.encoder["<|mask|>"] 183 | token_range = list(range(1, len(self.tokenizer.encoder)-3)) # 1 ~ 49405 184 | 185 | labels = [] 186 | for i, token in enumerate(tokens): 187 | if 0 < token < 49405: 188 | prob = random.random() 189 | # mask token with 15% probability 190 | if prob < 0.15: 191 | prob /= 0.15 192 | 193 | # 80% randomly change token to mask token 194 | if prob < 0.8: 195 | tokens[i] = mask 196 | 197 | # 10% randomly change token to random token 198 | elif prob < 0.9: 199 | tokens[i] = random.choice(token_range) 200 | 201 | # -> rest 10% randomly keep current token 202 | 203 | # append current token to output (we will predict these later) 204 | labels.append(token) 205 | else: 206 | # no masking token (will be ignored by loss function later) 207 | labels.append(0) 208 | else: 209 | labels.append(0) 210 | 211 | if all(l == 0 for l in labels): 212 | # at least mask 1 213 | labels[1] = tokens[1] 214 | tokens[1] = mask 215 | 216 | return torch.tensor(tokens), torch.tensor(labels) 217 | 218 | class FilterDataset(Dataset): 219 | def __init__(self, 220 | dataset, 221 | transform=None, 222 | text_length: int = 77, 223 | truncate: bool = True): 224 | self.dataset = dataset 225 | self.transform = transform 226 | self.text_length = text_length 227 | self.truncate = truncate 228 | 229 | self.tokenizer = SimpleTokenizer() 230 | 231 | 232 | def __len__(self): 233 | return len(self.dataset) 234 | 235 | def __getitem__(self, index): 236 | pid, image_id, img_path, caption, sim = self.dataset[index] 237 | img = read_image(img_path) 238 | if self.transform is not None: 239 | img = self.transform(img) 240 | 241 | caption_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate) 242 | mlm_tokens, mlm_labels = self._build_random_masked_tokens_and_labels(caption_tokens.cpu().numpy(), sim) 243 | ori_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate) 244 | 245 | ret = { 246 | 'pids': pid, 247 | 'image_ids': image_id, 248 | 'images': img, 249 | 'caption_ids': caption_tokens, 250 | 'mlm_ids': mlm_tokens, 251 | 'mlm_labels': mlm_labels, 252 | 'caption_ids_ori':ori_tokens 253 | } 254 | 255 | return ret 256 | 257 | def _build_random_masked_tokens_and_labels(self, tokens, sim): 258 | """ 259 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper. 260 | :param tokens: list of int, tokenized sentence. 
261 | :return: (list of int, list of int), masked tokens and related labels for MLM prediction 262 | """ 263 | mask = self.tokenizer.encoder["<|mask|>"] 264 | token_range = list(range(1, len(self.tokenizer.encoder)-3)) # 1 ~ 49405 265 | 266 | labels = [] 267 | 268 | if tokens[-1] == 0: 269 | valid_token_num = np.where(tokens == 0)[0][0] 270 | else: 271 | valid_token_num = len(tokens) 272 | ori_sim = np.array(sim) 273 | ori_pro = 1 - ori_sim 274 | if ori_pro[-1] != 0.15: 275 | valid_prob = ori_pro[1:valid_token_num-1] 276 | # normalize the probisibility to match E = 0.15 277 | mean_prob = np.mean(valid_prob) 278 | normed_prob = valid_prob - mean_prob 279 | normalized_prob = normed_prob + 0.15 280 | normalized_prob = np.clip(normalized_prob, 0, 1) 281 | ori_pro[1:valid_token_num-1] = normalized_prob 282 | 283 | for i, token in enumerate(tokens): 284 | if 0 < token < 49405: 285 | prob = random.random() 286 | # mask token with 15% probability 287 | if prob < ori_pro[i]: 288 | prob /= ori_pro[i] 289 | 290 | # 80% randomly change token to mask token 291 | if prob < 0.8: 292 | tokens[i] = mask 293 | 294 | # 10% randomly change token to random token 295 | elif prob < 0.9: 296 | tokens[i] = random.choice(token_range) 297 | 298 | # -> rest 10% randomly keep current token 299 | 300 | # append current token to output (we will predict these later) 301 | labels.append(token) 302 | else: 303 | # no masking token (will be ignored by loss function later) 304 | labels.append(0) 305 | else: 306 | labels.append(0) 307 | 308 | if all(l == 0 for l in labels): 309 | # at least mask 1 310 | labels[1] = tokens[1] 311 | tokens[1] = mask 312 | 313 | return torch.tensor(tokens), torch.tensor(labels) -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torchvision.transforms as T 4 | from torch.utils.data import DataLoader 5 | from datasets.luperson import LuPerson_PEDES 6 | from datasets.sampler import RandomIdentitySampler 7 | from datasets.sampler_ddp import RandomIdentitySampler_DDP 8 | from torch.utils.data.distributed import DistributedSampler 9 | 10 | from utils.comm import get_world_size 11 | 12 | from .bases import FilterDataset, ImageDataset, TextDataset, ImageTextDataset, ImageTextMLMDataset 13 | 14 | from .cuhkpedes import CUHKPEDES 15 | from .icfgpedes import ICFGPEDES 16 | from .rstpreid import RSTPReid 17 | 18 | __factory = {'CUHK-PEDES': CUHKPEDES, 'ICFG-PEDES': ICFGPEDES, 'RSTPReid': RSTPReid, 19 | 'LuPerson_PEDES':LuPerson_PEDES,} 20 | 21 | def build_transforms(img_size=(384, 128), aug=False, is_train=True): 22 | height, width = img_size 23 | 24 | mean = [0.48145466, 0.4578275, 0.40821073] 25 | std = [0.26862954, 0.26130258, 0.27577711] 26 | 27 | if not is_train: 28 | transform = T.Compose([ 29 | T.Resize((height, width)), 30 | T.ToTensor(), 31 | T.Normalize(mean=mean, std=std), 32 | ]) 33 | return transform 34 | 35 | # transform for training 36 | if aug: 37 | transform = T.Compose([ 38 | T.Resize((height, width)), 39 | T.RandomHorizontalFlip(0.5), 40 | T.Pad(10), 41 | T.RandomCrop((height, width)), 42 | T.ToTensor(), 43 | T.Normalize(mean=mean, std=std), 44 | T.RandomErasing(scale=(0.02, 0.4), value=mean), 45 | ]) 46 | else: 47 | transform = T.Compose([ 48 | T.Resize((height, width)), 49 | T.RandomHorizontalFlip(0.5), 50 | T.ToTensor(), 51 | T.Normalize(mean=mean, std=std), 52 | ]) 53 | return transform 54 | 55 | 56 | def 
collate(batch): 57 | keys = set([key for b in batch for key in b.keys()]) 58 | # turn list of dicts data structure to dict of lists data structure 59 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 60 | 61 | batch_tensor_dict = {} 62 | for k, v in dict_batch.items(): 63 | if isinstance(v[0], int): 64 | batch_tensor_dict.update({k: torch.tensor(v)}) 65 | elif torch.is_tensor(v[0]): 66 | batch_tensor_dict.update({k: torch.stack(v)}) 67 | else: 68 | raise TypeError(f"Unexpect data type: {type(v[0])} in a batch.") 69 | 70 | return batch_tensor_dict 71 | 72 | def build_dataloader(args, tranforms=None): 73 | logger = logging.getLogger("IRRA.dataset") 74 | 75 | num_workers = args.num_workers 76 | dataset = __factory[args.dataset_name](root=args.root_dir) 77 | num_classes = len(dataset.train_id_container) 78 | 79 | if args.training: 80 | train_transforms = build_transforms(img_size=args.img_size, 81 | aug=args.img_aug, 82 | is_train=True) 83 | val_transforms = build_transforms(img_size=args.img_size, 84 | is_train=False) 85 | 86 | if args.MLM: 87 | if args.pretrain: 88 | syn_dataset = __factory[args.pretrain](root=args.root_dir) 89 | train_set = ImageTextMLMDataset(syn_dataset.train, 90 | train_transforms, 91 | text_length=args.text_length) 92 | num_classes = len(syn_dataset.train) 93 | else: 94 | train_set = ImageTextMLMDataset(dataset.train, 95 | train_transforms, 96 | text_length=args.text_length) 97 | else: 98 | train_set = ImageTextDataset(dataset.train, 99 | train_transforms, 100 | text_length=args.text_length) 101 | 102 | if args.sampler == 'identity': 103 | if args.distributed: 104 | logger.info('using ddp random identity sampler') 105 | logger.info('DISTRIBUTED TRAIN START') 106 | mini_batch_size = args.batch_size // get_world_size() 107 | # TODO wait to fix bugs 108 | data_sampler = RandomIdentitySampler_DDP( 109 | dataset.train, args.batch_size, args.num_instance) 110 | batch_sampler = torch.utils.data.sampler.BatchSampler( 111 | data_sampler, mini_batch_size, True) 112 | 113 | else: 114 | logger.info( 115 | f'using random identity sampler: batch_size: {args.batch_size}, id: {args.batch_size // args.num_instance}, instance: {args.num_instance}' 116 | ) 117 | train_loader = DataLoader(train_set, 118 | batch_size=args.batch_size, 119 | sampler=RandomIdentitySampler( 120 | dataset.train, args.batch_size, 121 | args.num_instance), 122 | num_workers=num_workers, 123 | collate_fn=collate) 124 | elif args.sampler == 'random': 125 | # TODO add distributed condition 126 | logger.info('using random sampler') 127 | train_loader = DataLoader(train_set, 128 | batch_size=args.batch_size, 129 | shuffle=True, 130 | num_workers=num_workers, 131 | collate_fn=collate) 132 | else: 133 | logger.error('unsupported sampler! 
expected softmax or triplet but got {}'.format(args.sampler)) 134 | 135 | # use test set as validate set 136 | ds = dataset.val if args.val_dataset == 'val' else dataset.test 137 | val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], 138 | val_transforms) 139 | val_txt_set = TextDataset(ds['caption_pids'], 140 | ds['captions'], 141 | text_length=args.text_length) 142 | # val_txt_set2 = TextDataset(ds['caption_pids'], 143 | # ds['inblip'], 144 | # text_length=args.text_length) 145 | 146 | val_img_loader = DataLoader(val_img_set, 147 | batch_size=args.batch_size, 148 | shuffle=False, 149 | num_workers=num_workers) 150 | val_txt_loader = DataLoader(val_txt_set, 151 | batch_size=args.batch_size, 152 | shuffle=False, 153 | num_workers=num_workers) 154 | # val_txt_loader2 = DataLoader(val_txt_set2, 155 | # batch_size=args.batch_size, 156 | # shuffle=False, 157 | # num_workers=num_workers) 158 | 159 | return train_loader, val_img_loader, val_txt_loader, num_classes 160 | else: 161 | # build dataloader for testing 162 | if tranforms: 163 | test_transforms = tranforms 164 | else: 165 | test_transforms = build_transforms(img_size=args.img_size, 166 | is_train=False) 167 | 168 | ds = dataset.test 169 | test_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], 170 | test_transforms) 171 | test_txt_set = TextDataset(ds['caption_pids'], 172 | ds['captions'], 173 | text_length=args.text_length) 174 | 175 | test_img_loader = DataLoader(test_img_set, 176 | batch_size=args.test_batch_size, 177 | shuffle=False, 178 | num_workers=num_workers) 179 | test_txt_loader = DataLoader(test_txt_set, 180 | batch_size=args.test_batch_size, 181 | shuffle=False, 182 | num_workers=num_workers) 183 | return test_img_loader, test_txt_loader, num_classes 184 | 185 | 186 | def build_zero_shot_loader(args, finetune=False): 187 | logger = logging.getLogger("IRRA.dataset") 188 | 189 | num_workers = args.num_workers 190 | dataset0 = __factory['CUHK-PEDES'](root=args.root_dir) 191 | dataset1 = __factory['ICFG-PEDES'](root=args.root_dir) 192 | dataset2 = __factory['RSTPReid'](root=args.root_dir) 193 | 194 | train_transforms = build_transforms(img_size=args.img_size, 195 | aug=args.img_aug, 196 | is_train=True) 197 | val_transforms = build_transforms(img_size=args.img_size, 198 | is_train=False) 199 | 200 | ds = dataset0.test 201 | val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], 202 | val_transforms) 203 | val_txt_set = TextDataset(ds['caption_pids'], 204 | ds['captions'], 205 | text_length=args.text_length) 206 | val_img_loader0 = DataLoader(val_img_set, 207 | batch_size=args.batch_size, 208 | shuffle=False, 209 | num_workers=num_workers) 210 | val_txt_loader0 = DataLoader(val_txt_set, 211 | batch_size=args.batch_size, 212 | shuffle=False, 213 | num_workers=num_workers) 214 | 215 | ds = dataset1.test 216 | val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], 217 | val_transforms) 218 | val_txt_set = TextDataset(ds['caption_pids'], 219 | ds['captions'], 220 | text_length=args.text_length) 221 | val_img_loader1 = DataLoader(val_img_set, 222 | batch_size=args.batch_size, 223 | shuffle=False, 224 | num_workers=num_workers) 225 | val_txt_loader1 = DataLoader(val_txt_set, 226 | batch_size=args.batch_size, 227 | shuffle=False, 228 | num_workers=num_workers) 229 | 230 | ds = dataset2.test 231 | val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], 232 | val_transforms) 233 | val_txt_set = TextDataset(ds['caption_pids'], 234 | ds['captions'], 235 | text_length=args.text_length) 236 | 
val_img_loader2 = DataLoader(val_img_set, 237 | batch_size=args.batch_size, 238 | shuffle=False, 239 | num_workers=num_workers) 240 | val_txt_loader2 = DataLoader(val_txt_set, 241 | batch_size=args.batch_size, 242 | shuffle=False, 243 | num_workers=num_workers) 244 | if finetune: 245 | syn_dataset = __factory[args.dataset_name](root=args.root_dir) 246 | else: 247 | syn_dataset = __factory[args.pretrain](root=args.root_dir) 248 | train_set = ImageTextMLMDataset(syn_dataset.train, 249 | train_transforms, 250 | text_length=args.text_length) 251 | num_classes = len(syn_dataset.train) 252 | 253 | logger.info('using random sampler') 254 | train_loader = DataLoader(train_set, 255 | batch_size=args.batch_size, 256 | shuffle=True, 257 | num_workers=num_workers, 258 | ) 259 | 260 | return syn_dataset.train, train_loader, val_img_loader0, val_txt_loader0, val_img_loader1, val_txt_loader1, val_img_loader2, val_txt_loader2, num_classes 261 | 262 | def build_filter_loader(args, dataset): 263 | logger = logging.getLogger("IRRA.dataset") 264 | 265 | num_workers = args.num_workers 266 | 267 | train_transforms = build_transforms(img_size=args.img_size, 268 | aug=args.img_aug, 269 | is_train=True) 270 | train_set = FilterDataset(dataset, 271 | train_transforms, 272 | text_length=args.text_length) 273 | train_loader = DataLoader(train_set, 274 | batch_size=args.batch_size, 275 | shuffle=True, 276 | num_workers=num_workers) 277 | 278 | return train_loader 279 | -------------------------------------------------------------------------------- /datasets/cuhkpedes.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | from utils.iotools import read_json 4 | from .bases import BaseDataset 5 | import random 6 | 7 | class CUHKPEDES(BaseDataset): 8 | """ 9 | CUHK-PEDES 10 | 11 | Reference: 12 | Person Search With Natural Language Description (CVPR 2017) 13 | 14 | URL: https://openaccess.thecvf.com/content_cvpr_2017/html/Li_Person_Search_With_CVPR_2017_paper.html 15 | 16 | Dataset statistics: 17 | ### identities: 13003 18 | ### images: 40206, (train) (test) (val) 19 | ### captions: 20 | ### 9 images have more than 2 captions 21 | ### 4 identity have only one image 22 | 23 | annotation format: 24 | [{'split', str, 25 | 'captions', list, 26 | 'file_path', str, 27 | 'processed_tokens', list, 28 | 'id', int}...] 
29 | """ 30 | dataset_dir = 'CUHK-PEDES' 31 | 32 | def __init__(self, root='', verbose=True): 33 | super(CUHKPEDES, self).__init__() 34 | self.dataset_dir = op.join(root, self.dataset_dir) 35 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 36 | 37 | self.anno_path = op.join(self.dataset_dir, 'reid_raw.json') 38 | self._check_before_run() 39 | 40 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 41 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 42 | self.test, self.test_id_container = self._process_anno(self.test_annos) 43 | self.val, self.val_id_container = self._process_anno(self.val_annos) 44 | 45 | if verbose: 46 | self.logger.info("=> CUHK-PEDES Images and Captions are loaded") 47 | self.show_dataset_info() 48 | 49 | 50 | def _split_anno(self, anno_path: str): 51 | train_annos, test_annos, val_annos = [], [], [] 52 | annos = read_json(anno_path) 53 | for anno in annos: 54 | if anno['split'] == 'train': 55 | train_annos.append(anno) 56 | elif anno['split'] == 'test': 57 | test_annos.append(anno) 58 | else: 59 | val_annos.append(anno) 60 | return train_annos, test_annos, val_annos 61 | 62 | 63 | def _process_anno(self, annos: List[dict], training=False): 64 | 65 | pid_container = set() 66 | if training: 67 | dataset = [] 68 | image_id = 0 69 | for anno in annos: 70 | pid = int(anno['id']) - 1 # make pid begin from 0 71 | pid_container.add(pid) 72 | img_path = op.join(self.img_dir, anno['file_path']) 73 | captions = anno['captions'] # caption list 74 | for caption in captions: 75 | dataset.append((pid, image_id, img_path, caption)) 76 | 77 | image_id += 1 78 | for idx, pid in enumerate(pid_container): 79 | # check pid begin from 0 and no break 80 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 81 | return dataset, pid_container 82 | else: 83 | dataset = {} 84 | img_paths = [] 85 | captions = [] 86 | image_pids = [] 87 | caption_pids = [] 88 | for anno in annos: 89 | pid = int(anno['id']) 90 | pid_container.add(pid) 91 | img_path = op.join(self.img_dir, anno['file_path']) 92 | img_paths.append(img_path) 93 | image_pids.append(pid) 94 | caption_list = anno['captions'] # caption list 95 | for caption in caption_list: 96 | captions.append(caption) 97 | caption_pids.append(pid) 98 | dataset = { 99 | "image_pids": image_pids, 100 | "img_paths": img_paths, 101 | "caption_pids": caption_pids, 102 | "captions": captions 103 | } 104 | return dataset, pid_container 105 | 106 | 107 | def _check_before_run(self): 108 | """Check if all files are available before going deeper""" 109 | if not op.exists(self.dataset_dir): 110 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 111 | if not op.exists(self.img_dir): 112 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 113 | if not op.exists(self.anno_path): 114 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 115 | import re 116 | 117 | def remove_punctuation_and_spaces(text): 118 | # 使用正则表达式去掉标点符号和空格 119 | cleaned_text = re.sub(r'[^\w\s]', ' ', text) 120 | # cleaned_text = re.sub(r'\s+', '', cleaned_text) 121 | return cleaned_text -------------------------------------------------------------------------------- /datasets/icfgpedes.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | 4 | from utils.iotools import read_json 5 | from .bases import BaseDataset 6 | 7 | 8 | class ICFGPEDES(BaseDataset): 9 | """ 10 | 
ICFG-PEDES 11 | 12 | Reference: 13 | Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification arXiv 2107 14 | 15 | URL: http://arxiv.org/abs/2107.12666 16 | 17 | Dataset statistics: 18 | # identities: 4102 19 | # images: 34674 (train) + 4855 (query) + 14993 (gallery) 20 | # cameras: 15 21 | """ 22 | dataset_dir = 'ICFG-PEDES' 23 | 24 | def __init__(self, root='', verbose=True): 25 | super(ICFGPEDES, self).__init__() 26 | self.dataset_dir = op.join(root, self.dataset_dir) 27 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 28 | self.anno_path = op.join(self.dataset_dir, 'ICFG-PEDES.json') 29 | self._check_before_run() 30 | 31 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 32 | 33 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 34 | self.test, self.test_id_container = self._process_anno(self.test_annos) 35 | self.val, self.val_id_container = self._process_anno(self.val_annos) 36 | 37 | if verbose: 38 | self.logger.info("=> ICFG-PEDES Images and Captions are loaded") 39 | self.show_dataset_info() 40 | 41 | 42 | def _split_anno(self, anno_path: str): 43 | train_annos, test_annos, val_annos = [], [], [] 44 | annos = read_json(anno_path) 45 | for anno in annos: 46 | if anno['split'] == 'train': 47 | train_annos.append(anno) 48 | elif anno['split'] == 'test': 49 | test_annos.append(anno) 50 | else: 51 | val_annos.append(anno) 52 | return train_annos, test_annos, val_annos 53 | 54 | 55 | def _process_anno(self, annos: List[dict], training=False): 56 | pid_container = set() 57 | if training: 58 | dataset = [] 59 | image_id = 0 60 | for anno in annos: 61 | pid = int(anno['id']) 62 | pid_container.add(pid) 63 | img_path = op.join(self.img_dir, anno['file_path']) 64 | captions = anno['captions'] # caption list 65 | for caption in captions: 66 | dataset.append((pid, image_id, img_path, caption)) 67 | 68 | image_id += 1 69 | 70 | for idx, pid in enumerate(pid_container): 71 | # check pid begin from 0 and no break 72 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 73 | return dataset, pid_container 74 | else: 75 | dataset = {} 76 | img_paths = [] 77 | captions = [] 78 | image_pids = [] 79 | caption_pids = [] 80 | for anno in annos: 81 | pid = int(anno['id']) 82 | pid_container.add(pid) 83 | img_path = op.join(self.img_dir, anno['file_path']) 84 | img_paths.append(img_path) 85 | image_pids.append(pid) 86 | caption_list = anno['captions'] # caption list 87 | for caption in caption_list: 88 | captions.append(caption) 89 | caption_pids.append(pid) 90 | dataset = { 91 | "image_pids": image_pids, 92 | "img_paths": img_paths, 93 | "caption_pids": caption_pids, 94 | "captions": captions 95 | } 96 | return dataset, pid_container 97 | 98 | 99 | def _check_before_run(self): 100 | """Check if all files are available before going deeper""" 101 | if not op.exists(self.dataset_dir): 102 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 103 | if not op.exists(self.img_dir): 104 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 105 | if not op.exists(self.anno_path): 106 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 107 | import re 108 | 109 | def remove_punctuation_and_spaces(text): 110 | # 使用正则表达式去掉标点符号和空格 111 | cleaned_text = re.sub(r'[^\w\s]', ' ', text) 112 | # cleaned_text = re.sub(r'\s+', '', cleaned_text) 113 | return cleaned_text -------------------------------------------------------------------------------- 
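The CUHK-PEDES and ICFG-PEDES classes above implement the same contract that datasets/build.py consumes: train is a list of (pid, image_id, img_path, caption) tuples, test/val are dicts with image_pids, img_paths, caption_pids and captions, and each split has a matching *_id_container set. Below is a minimal sketch (not part of the repo) of how a new downstream dataset could be plugged into this interface; the class name, directory layout and anno.json file are hypothetical, and the annotation schema is assumed to match the reid_raw.json format described in cuhkpedes.py.
```python
# Hypothetical example: a new downstream dataset following the same interface
# as CUHKPEDES / ICFGPEDES so it can be registered in datasets/build.py.
import os.path as op
from utils.iotools import read_json
from .bases import BaseDataset


class MyPEDES(BaseDataset):
    dataset_dir = 'My-PEDES'  # assumed layout: <root>/My-PEDES/imgs/ and <root>/My-PEDES/anno.json

    def __init__(self, root='', verbose=True):
        super(MyPEDES, self).__init__()
        self.dataset_dir = op.join(root, self.dataset_dir)
        self.img_dir = op.join(self.dataset_dir, 'imgs/')
        annos = read_json(op.join(self.dataset_dir, 'anno.json'))  # same schema as reid_raw.json

        self.train, self.train_id_container = [], set()
        image_id = 0
        for anno in (a for a in annos if a['split'] == 'train'):
            pid = int(anno['id']) - 1  # pids must start from 0 and be contiguous
            self.train_id_container.add(pid)
            img_path = op.join(self.img_dir, anno['file_path'])
            for caption in anno['captions']:
                self.train.append((pid, image_id, img_path, caption))
            image_id += 1
        # test/val splits would be built as in CUHKPEDES._process_anno(training=False);
        # finally, add an entry such as 'My-PEDES': MyPEDES to the __factory dict in datasets/build.py.
```
Once registered in __factory, the class is selected through args.dataset_name, just like the existing datasets.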
/datasets/luperson.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | import random 3 | from typing import List 4 | 5 | from utils.iotools import read_json 6 | from .bases import BaseDataset 7 | 8 | import os 9 | import json 10 | from prettytable import PrettyTable 11 | import collections 12 | import numpy as np 13 | 14 | class LuPerson_PEDES(BaseDataset): 15 | dataset_dir = 'LUPerson_images' 16 | def __init__(self, root='', verbose=True): 17 | super(LuPerson_PEDES, self).__init__() 18 | self.dataset_dir = '/data0/wentao/data/LuPerson-T/LUPerson_images' 19 | self.image_dir = op.join(self.dataset_dir, 'LUPerson-MLLM') 20 | self.caption_dir = self.dataset_dir 21 | self.train_img_paths = [] 22 | self.train_cap_paths = [] 23 | 24 | self.test_img_paths = [] 25 | self.test_cap_paths = [] 26 | 27 | for filename in os.listdir(self.image_dir): # part1234 28 | image_path = os.path.join(self.image_dir, filename) 29 | if filename.endswith('.jpg'): 30 | self.train_img_paths.append(image_path) 31 | for filename in os.listdir(self.caption_dir): 32 | caption_path = os.path.join(self.caption_dir, filename) 33 | if filename.endswith('.json'): 34 | self.train_cap_paths.append(caption_path) 35 | 36 | train_cap_dict = self._merged_multi_json_file(self.train_cap_paths) 37 | test_cap_dict = self._merged_json_file(self.test_cap_paths) 38 | 39 | self.train, self.train_id_container, self.part_dataset, num_caption,self.fpath2part_cap,self.fpaht2sim = self._get_dataset(self.train_img_paths, train_cap_dict) 40 | self.test = self._get_test_dataset(self.test_img_paths, test_cap_dict) 41 | 42 | self.logger.info("=> LuPerson-MLLM Images and Captions are loaded") 43 | self.logger.info("LuPerson-MLLM Dataset statistics:") 44 | table = PrettyTable(['subset', 'ids', 'images', 'captions']) 45 | table.add_row(['train', len(set(self.train_id_container)),len(self.train), num_caption]) 46 | table.add_row(['test', len(self.test["image_pids"]),len(self.test["image_pids"]), len(self.test["image_pids"])]) 47 | self.logger.info('\n' + str(table)) 48 | 49 | 50 | def _merged_json_file(self, json_path_list): 51 | merged_dict = {} 52 | 53 | # read the JSON files one by one and merge them into a single dict 54 | for file_path in json_path_list: 55 | with open(file_path, 'r') as json_file: 56 | data = json.load(json_file) 57 | merged_dict.update(data) 58 | return merged_dict 59 | 60 | def _merged_multi_json_file(self, json_path_list): 61 | merged_dict = collections.defaultdict(list) 62 | json_path_list = [ 63 | "./caption/Ts-qwen.json", 64 | "./caption/Td-qwen.json", 65 | "./caption/Ts-shikra.json", 66 | "./caption/Td-shikra.json", 67 | ] 68 | for file_path in json_path_list: 69 | with open(file_path, 'r') as json_file: 70 | data = json.load(json_file) 71 | print(file_path, len(data)) 72 | for k,v in data.items(): 73 | img_name = k.split('/')[-1] 74 | merged_dict[img_name].append(v) 75 | return merged_dict 76 | 77 | def _get_test_dataset(self, test_img_paths, cap_dict): 78 | dataset = {} 79 | img_paths = [] 80 | captions = [] 81 | image_pids = [] 82 | caption_pids = [] 83 | for i in range(len(test_img_paths)): 84 | pid = i 85 | img_path = test_img_paths[i] 86 | img_paths.append(img_path) 87 | image_pids.append(pid) 88 | path2cap = '/'.join(img_path.split('/')[-1]) 89 | caption = cap_dict[path2cap][0] 90 | captions.append(caption) 91 | caption_pids.append(pid) 92 | dataset = { 93 | "image_pids": image_pids, 94 | "img_paths": img_paths, 95 | "caption_pids": caption_pids, 96 | "captions": captions 97 | } 98 | return dataset 99 |
100 | def _get_dataset(self, img_paths, cap_dict): 101 | safe_dict = collections.defaultdict(list) 102 | with open('./caption/Ts-shikra.json', 'r') as json_file: 103 | data = json.load(json_file) 104 | for k,v in data.items(): 105 | img_name = k.split('/')[-1] 106 | safe_dict[img_name].append(v) 107 | 108 | with open('./caption/Ts-qwen.json', 'r') as json_file: 109 | data = json.load(json_file) 110 | for k,v in data.items(): 111 | img_name = k.split('/')[-1] 112 | safe_dict[img_name].append(v) 113 | pid_container = set() 114 | img_paths = sorted(img_paths) 115 | 116 | dataset = [] 117 | part_dataset = [] 118 | idx_count = 0 119 | pid_count = 0 120 | num_caption = 0 121 | 122 | fpath2part_cap = {} 123 | fpaht2sim = {} 124 | for i in range(len(img_paths)): 125 | img_path = img_paths[i] 126 | 127 | path2cap = img_path.split('/')[-1] 128 | caption = cap_dict[path2cap] 129 | 130 | # if len(caption) != 4: 131 | # continue 132 | fpath2part_cap[img_path] = {} 133 | fpaht2sim[img_path] = {} 134 | pid = pid_count 135 | image_id = idx_count 136 | pid_container.add(pid) 137 | for cap in caption: 138 | if 'description]' in cap or '<' in cap: 139 | try: 140 | cap = random.choice(safe_dict[path2cap]) 141 | except: 142 | pass 143 | part2sim = 77 * [1- 0.15] 144 | part2sim = np.array(part2sim) 145 | dataset.append([pid,idx_count,img_path, cap, part2sim]) 146 | num_caption += 1 147 | idx_count += 1 148 | pid_count += 1 149 | assert idx_count == len(dataset) 150 | 151 | return dataset, pid_container, part_dataset,num_caption,fpath2part_cap,fpaht2sim -------------------------------------------------------------------------------- /datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 
15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /datasets/rstpreid.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | 4 | from utils.iotools import read_json 5 | from .bases import BaseDataset 6 | 7 | 8 | class RSTPReid(BaseDataset): 9 | """ 10 | RSTPReid 11 | 12 | Reference: 13 | DSSL: Deep Surroundings-person Separation Learning for Text-based Person Retrieval MM 21 14 | 15 | URL: http://arxiv.org/abs/2109.05534 16 | 17 | Dataset statistics: 18 | # identities: 4101 19 | """ 20 | dataset_dir = 'RSTPReid' 21 | 22 | def __init__(self, root='', verbose=True): 23 | super(RSTPReid, self).__init__() 24 | self.dataset_dir = op.join(root, self.dataset_dir) 25 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 26 | 27 | self.anno_path = op.join(self.dataset_dir, 'data_captions.json') 28 | self._check_before_run() 29 | 30 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 31 | 32 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 33 | self.test, self.test_id_container = self._process_anno(self.test_annos) 34 | self.val, self.val_id_container = self._process_anno(self.val_annos) 35 | 36 | if verbose: 37 | self.logger.info("=> RSTPReid Images and Captions are loaded") 38 | self.show_dataset_info() 39 | 40 | 41 | def _split_anno(self, anno_path: str): 42 | train_annos, test_annos, val_annos = [], [], [] 43 | annos = read_json(anno_path) 44 | for anno in annos: 45 | if anno['split'] == 'train': 46 | train_annos.append(anno) 47 | elif anno['split'] == 'test': 48 | test_annos.append(anno) 49 | else: 50 | val_annos.append(anno) 51 | return train_annos, test_annos, val_annos 52 | 53 | 54 | def _process_anno(self, annos: List[dict], training=False): 55 | pid_container = set() 56 | if training: 57 | dataset = [] 58 | image_id = 0 59 | for anno in annos: 60 | pid = int(anno['id']) 61 | pid_container.add(pid) 62 | img_path = op.join(self.img_dir, anno['img_path']) 63 | captions = anno['captions'] # caption list 64 | for caption in captions: 65 | dataset.append((pid, image_id, img_path, caption)) 66 | image_id += 1 67 | for idx, pid in enumerate(pid_container): 68 | # check pid begin from 0 and no break 69 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 70 | return dataset, pid_container 71 | else: 72 | dataset = {} 73 | img_paths = [] 
74 | captions = [] 75 | image_pids = [] 76 | caption_pids = [] 77 | for anno in annos: 78 | pid = int(anno['id']) 79 | pid_container.add(pid) 80 | img_path = op.join(self.img_dir, anno['img_path']) 81 | img_paths.append(img_path) 82 | image_pids.append(pid) 83 | caption_list = anno['captions'] # caption list 84 | for caption in caption_list: 85 | captions.append(caption) 86 | caption_pids.append(pid) 87 | dataset = { 88 | "image_pids": image_pids, 89 | "img_paths": img_paths, 90 | "caption_pids": caption_pids, 91 | "captions": captions 92 | } 93 | return dataset, pid_container 94 | 95 | 96 | def _check_before_run(self): 97 | """Check if all files are available before going deeper""" 98 | if not op.exists(self.dataset_dir): 99 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 100 | if not op.exists(self.img_dir): 101 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 102 | if not op.exists(self.anno_path): 103 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 104 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | 7 | class RandomIdentitySampler(Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | Args: 12 | - data_source (list): list of (img_path, pid, camid). 13 | - num_instances (int): number of instances per identity in a batch. 14 | - batch_size (int): number of examples in a batch. 15 | """ 16 | 17 | def __init__(self, data_source, batch_size, num_instances): 18 | self.data_source = data_source 19 | self.batch_size = batch_size 20 | self.num_instances = num_instances 21 | self.num_pids_per_batch = self.batch_size // self.num_instances 22 | self.index_dic = defaultdict(list) #dict with list value 23 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 24 | for index, (pid, _, _, _) in enumerate(self.data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | 28 | # estimate number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | batch_idxs_dict = defaultdict(list) 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | batch_idxs_dict[pid].append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | avai_pids = copy.deepcopy(self.pids) 53 | final_idxs = [] 54 | 55 | while len(avai_pids) >= self.num_pids_per_batch: 56 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 57 | for pid in selected_pids: 58 | batch_idxs = batch_idxs_dict[pid].pop(0) 59 | final_idxs.extend(batch_idxs) 60 | if len(batch_idxs_dict[pid]) == 0: 61 | avai_pids.remove(pid) 62 | 63 | return iter(final_idxs) 64 | 65 | def __len__(self): 66 | return self.length 67 | 68 | -------------------------------------------------------------------------------- 
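Stepping back to `datasets/preprocessing.py` above: `RandomErasing` operates on a CHW tensor, so it has to be applied after `ToTensor()` in the training transform. Below is a minimal sketch under that assumption; the input resolution and probability are illustrative values, not necessarily the ones configured in this repo.

```python
# Hypothetical augmentation pipeline; 384x128 and p=0.5 are illustrative choices.
from PIL import Image
import torchvision.transforms as T
from datasets.preprocessing import RandomErasing

train_transform = T.Compose([
    T.Resize((384, 128)),
    T.RandomHorizontalFlip(0.5),
    T.ToTensor(),                       # RandomErasing expects a CHW tensor
    RandomErasing(probability=0.5),     # fills a random rectangle with the configured mean
])

dummy = Image.new('RGB', (128, 384))    # stand-in for a person crop (W, H)
aug = train_transform(dummy)            # tensor of shape [3, 384, 128]
```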
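`RandomIdentitySampler` implements the usual P×K scheme: each batch contains `batch_size // num_instances` identities with `num_instances` samples each, which is what identity-level objectives such as the SDM and ID losses rely on. The following is a self-contained sketch using synthetic `(pid, image_id, img_path, caption)` tuples shaped like the training split above; all paths, captions, and sizes are placeholders.

```python
# Hypothetical P*K sampling demo with synthetic training tuples.
from torch.utils.data import DataLoader
from datasets.sampler import RandomIdentitySampler

data_source = [(pid, 6 * pid + i, f'imgs/{pid}_{i}.jpg', f'caption {pid}-{i}')
               for pid in range(8) for i in range(6)]

batch_size, num_instances = 8, 4                    # 2 identities x 4 samples per batch
sampler = RandomIdentitySampler(data_source, batch_size, num_instances)

loader = DataLoader(data_source, batch_size=batch_size, sampler=sampler)
for pids, image_ids, img_paths, captions in loader:
    # every batch holds exactly batch_size // num_instances distinct identities
    assert len(set(pids.tolist())) == batch_size // num_instances
```

The DDP variant in `datasets/sampler_ddp.py` below mirrors this logic, but first synchronizes a random seed across ranks and then slices the identity-grouped index list into per-process mini-batch blocks.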
/datasets/sampler_ddp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import math 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | import torch 10 | import pickle 11 | 12 | def _get_global_gloo_group(): 13 | """ 14 | Return a process group based on gloo backend, containing all the ranks 15 | The result is cached. 16 | """ 17 | if dist.get_backend() == "nccl": 18 | return dist.new_group(backend="gloo") 19 | else: 20 | return dist.group.WORLD 21 | 22 | def _serialize_to_tensor(data, group): 23 | backend = dist.get_backend(group) 24 | assert backend in ["gloo", "nccl"] 25 | device = torch.device("cpu" if backend == "gloo" else "cuda") 26 | 27 | buffer = pickle.dumps(data) 28 | if len(buffer) > 1024 ** 3: 29 | print( 30 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 31 | dist.get_rank(), len(buffer) / (1024 ** 3), device 32 | ) 33 | ) 34 | storage = torch.ByteStorage.from_buffer(buffer) 35 | tensor = torch.ByteTensor(storage).to(device=device) 36 | return tensor 37 | 38 | def _pad_to_largest_tensor(tensor, group): 39 | """ 40 | Returns: 41 | list[int]: size of the tensor, on each rank 42 | Tensor: padded tensor that has the max size 43 | """ 44 | world_size = dist.get_world_size(group=group) 45 | assert ( 46 | world_size >= 1 47 | ), "comm.gather/all_gather must be called from ranks within the given group!" 48 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 49 | size_list = [ 50 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 51 | ] 52 | dist.all_gather(size_list, local_size, group=group) 53 | size_list = [int(size.item()) for size in size_list] 54 | 55 | max_size = max(size_list) 56 | 57 | # we pad the tensor because torch all_gather does not support 58 | # gathering tensors of different shapes 59 | if local_size != max_size: 60 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 61 | tensor = torch.cat((tensor, padding), dim=0) 62 | return size_list, tensor 63 | 64 | def all_gather(data, group=None): 65 | """ 66 | Run all_gather on arbitrary picklable data (not necessarily tensors). 67 | Args: 68 | data: any picklable object 69 | group: a torch process group. By default, will use a group which 70 | contains all ranks on gloo backend. 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | if dist.get_world_size() == 1: 75 | return [data] 76 | if group is None: 77 | group = _get_global_gloo_group() 78 | if dist.get_world_size(group) == 1: 79 | return [data] 80 | 81 | tensor = _serialize_to_tensor(data, group) 82 | 83 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 84 | max_size = max(size_list) 85 | 86 | # receiving Tensor from all ranks 87 | tensor_list = [ 88 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 89 | ] 90 | dist.all_gather(tensor_list, tensor, group=group) 91 | 92 | data_list = [] 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | def shared_random_seed(): 100 | """ 101 | Returns: 102 | int: a random number that is the same across all workers. 
103 | If workers need a shared RNG, they can use this shared seed to 104 | create one. 105 | All workers must call this function, otherwise it will deadlock. 106 | """ 107 | ints = np.random.randint(2 ** 31) 108 | all_ints = all_gather(ints) 109 | return all_ints[0] 110 | 111 | class RandomIdentitySampler_DDP(Sampler): 112 | """ 113 | Randomly sample N identities, then for each identity, 114 | randomly sample K instances, therefore batch size is N*K. 115 | Args: 116 | - data_source (list): list of (img_path, pid, camid). 117 | - num_instances (int): number of instances per identity in a batch. 118 | - batch_size (int): number of examples in a batch. 119 | """ 120 | 121 | def __init__(self, data_source, batch_size, num_instances): 122 | self.data_source = data_source 123 | self.batch_size = batch_size 124 | self.world_size = dist.get_world_size() 125 | self.num_instances = num_instances 126 | self.mini_batch_size = self.batch_size // self.world_size 127 | self.num_pids_per_batch = self.mini_batch_size // self.num_instances 128 | self.index_dic = defaultdict(list) 129 | 130 | for index, (pid, _, _, _) in enumerate(self.data_source): 131 | self.index_dic[pid].append(index) 132 | self.pids = list(self.index_dic.keys()) 133 | 134 | # estimate number of examples in an epoch 135 | self.length = 0 136 | for pid in self.pids: 137 | idxs = self.index_dic[pid] 138 | num = len(idxs) 139 | if num < self.num_instances: 140 | num = self.num_instances 141 | self.length += num - num % self.num_instances 142 | 143 | self.rank = dist.get_rank() 144 | #self.world_size = dist.get_world_size() 145 | self.length //= self.world_size 146 | 147 | def __iter__(self): 148 | seed = shared_random_seed() 149 | np.random.seed(seed) 150 | self._seed = int(seed) 151 | final_idxs = self.sample_list() 152 | length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) 153 | #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] 154 | final_idxs = self.__fetch_current_node_idxs(final_idxs, length) 155 | self.length = len(final_idxs) 156 | return iter(final_idxs) 157 | 158 | 159 | def __fetch_current_node_idxs(self, final_idxs, length): 160 | total_num = len(final_idxs) 161 | block_num = (length // self.mini_batch_size) 162 | index_target = [] 163 | for i in range(0, block_num * self.world_size, self.world_size): 164 | index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) 165 | index_target.extend(index) 166 | index_target_npy = np.array(index_target) 167 | final_idxs = list(np.array(final_idxs)[index_target_npy]) 168 | return final_idxs 169 | 170 | 171 | def sample_list(self): 172 | #np.random.seed(self._seed) 173 | avai_pids = copy.deepcopy(self.pids) 174 | batch_idxs_dict = {} 175 | 176 | batch_indices = [] 177 | while len(avai_pids) >= self.num_pids_per_batch: 178 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 179 | for pid in selected_pids: 180 | if pid not in batch_idxs_dict: 181 | idxs = copy.deepcopy(self.index_dic[pid]) 182 | if len(idxs) < self.num_instances: 183 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 184 | np.random.shuffle(idxs) 185 | batch_idxs_dict[pid] = idxs 186 | 187 | avai_idxs = batch_idxs_dict[pid] 188 | for _ in range(self.num_instances): 189 | batch_indices.append(avai_idxs.pop(0)) 190 | 191 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 192 | 193 | return batch_indices 194 | 195 | 
def __len__(self): 196 | return self.length 197 | 198 | -------------------------------------------------------------------------------- /figures/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/figures/example.png -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/figures/framework.png -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import os.path as op 4 | from model.build_finetune import build_finetune_model 5 | import torch 6 | import numpy as np 7 | import random 8 | import time 9 | import torch.nn as nn 10 | 11 | from datasets import build_dataloader 12 | from datasets.bases import ImageTextMLMDataset 13 | from datasets.build import build_mix_loader, build_zero_shot_loader 14 | from processor.processor_finetune import do_train 15 | from utils.checkpoint import Checkpointer 16 | from utils.iotools import save_train_configs 17 | from utils.logger import setup_logger 18 | from solver import build_optimizer, build_lr_scheduler 19 | from model import build_model 20 | from utils.metrics import Evaluator 21 | from utils.options import get_args 22 | from utils.comm import get_rank, synchronize 23 | 24 | 25 | def set_seed(seed=0): 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(seed) 30 | random.seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = True 33 | 34 | 35 | if __name__ == '__main__': 36 | args = get_args() 37 | set_seed(1+get_rank()) 38 | name = args.name 39 | 40 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 41 | args.distributed = num_gpus > 1 42 | 43 | if args.distributed: 44 | torch.cuda.set_device(args.local_rank) 45 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 46 | synchronize() 47 | 48 | device = "cuda" 49 | cur_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 50 | args.output_dir = op.join(args.output_dir, args.dataset_name, f'{cur_time}_{name}') 51 | logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training, distributed_rank=get_rank()) 52 | logger.info("Using {} GPUs".format(num_gpus)) 53 | logger.info(str(args).replace(',', '\n')) 54 | save_train_configs(args.output_dir, args) 55 | 56 | # get image-text pair datasets dataloader 57 | trainset ,train_loader, val_img_loader0, val_txt_loader0, val_img_loader1, val_txt_loader1, val_img_loader2, val_txt_loader2, num_classes = build_zero_shot_loader(args,finetune=True) 58 | model = build_finetune_model(args, num_classes) 59 | logger.info('Total params: %2.fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 60 | if args.finetune: 61 | logger.info("loading {} model".format(args.finetune)) 62 | param_dict = torch.load(args.finetune,map_location='cpu')['model'] 63 | for k in list(param_dict.keys()): 64 | refine_k = k.replace('module.','') 65 | param_dict[refine_k] = param_dict[k].detach().clone() 66 | del param_dict[k] 67 | model.load_state_dict(param_dict, False) 68 | # model 
= model.float() 69 | model.cuda() 70 | model = nn.DataParallel(model) 71 | 72 | if args.distributed: 73 | model = torch.nn.parallel.DistributedDataParallel( 74 | model, 75 | device_ids=[args.local_rank], 76 | output_device=args.local_rank, 77 | # this should be removed if we update BatchNorm stats 78 | broadcast_buffers=False, 79 | ) 80 | optimizer = build_optimizer(args, model) 81 | scheduler = build_lr_scheduler(args, optimizer) 82 | 83 | is_master = get_rank() == 0 84 | checkpointer = Checkpointer(model, optimizer, scheduler, args.output_dir, is_master) 85 | evaluator0 = Evaluator(val_img_loader0, val_txt_loader0) 86 | evaluator1 = Evaluator(val_img_loader1, val_txt_loader1) 87 | evaluator2 = Evaluator(val_img_loader2, val_txt_loader2) 88 | 89 | start_epoch = 1 90 | if args.resume: 91 | checkpoint = checkpointer.resume(args.resume_ckpt_file) 92 | start_epoch = checkpoint['epoch'] 93 | 94 | 95 | do_train(start_epoch, args, model, train_loader, evaluator0,evaluator1,evaluator2, optimizer, scheduler, checkpointer, trainset) 96 | 97 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET_NAME="CUHK-PEDES" 3 | 4 | CUDA_VISIBLE_DEVICES=0 \ 5 | python finetune.py \ 6 | --name finetune \ 7 | --img_aug \ 8 | --batch_size 64 \ 9 | --MLM \ 10 | --dataset_name $DATASET_NAME \ 11 | --loss_names 'sdm+id+mlm' \ 12 | --num_epoch 60 \ 13 | --root_dir /data0/wentao/data/textReID \ 14 | --finetune The_pretrained_checkpoint 15 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model 2 | from .build_finetune import build_finetune_model -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/build.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/build.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/build.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/build.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/build_finetune.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/build_finetune.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/build_finetune.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/build_finetune.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/clip_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/clip_model.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/clip_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/clip_model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/clip_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/clip_model.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/memory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/memory.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/objectives.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/objectives.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/objectives.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/objectives.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/objectives.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/objectives.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/style.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/style.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/style.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/style.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/style.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/model/__pycache__/style.cpython-39.pyc -------------------------------------------------------------------------------- /model/build.py: -------------------------------------------------------------------------------- 1 | from model import objectives 2 | from .clip_model import ResidualAttentionBlock, ResidualCrossAttentionBlock, Transformer, QuickGELU, LayerNorm, build_CLIP_from_openai_pretrained, convert_weights 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from collections import OrderedDict 7 | import torch.nn.functional as F 8 | 9 | class IRRA(nn.Module): 10 | def __init__(self, args, num_classes=11003): 11 | super().__init__() 12 | self.args = args 13 | self.num_classes = num_classes 14 | self._set_task() 15 | 16 | self.base_model, base_cfg = build_CLIP_from_openai_pretrained(args.pretrain_choice, args.img_size, args.stride_size) 17 | self.embed_dim = base_cfg['embed_dim'] 18 | self.logit_scale = torch.ones([]) * (1 / args.temperature) 19 | 20 | if 'id' in args.loss_names: 21 | self.classifier = nn.Linear(self.embed_dim, self.num_classes) 22 | nn.init.normal_(self.classifier.weight.data, std=0.001) 23 | nn.init.constant_(self.classifier.bias.data, val=0.0) 24 | 25 | if 'mlm' in args.loss_names: 26 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, 27 | self.embed_dim // 64, 28 | batch_first=True) 29 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 30 | layers=args.cmt_depth, 31 | heads=self.embed_dim // 32 | 64) 33 | scale = self.cross_modal_transformer.width**-0.5 34 | 35 | self.ln_pre_t = LayerNorm(self.embed_dim) 36 | self.ln_pre_i = LayerNorm(self.embed_dim) 37 | self.ln_post = LayerNorm(self.embed_dim) 38 | 39 | proj_std = scale * ((2 * self.cross_modal_transformer.layers)**-0.5) 40 | attn_std = scale 41 | fc_std = (2 * self.cross_modal_transformer.width)**-0.5 42 | for block in self.cross_modal_transformer.resblocks: 43 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 44 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 45 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 46 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 47 | 48 | # init cross attn 49 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std) 50 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std) 51 | 52 | self.mlm_head = nn.Sequential( 53 | OrderedDict([('dense', nn.Linear(self.embed_dim, 
self.embed_dim)), 54 | ('gelu', QuickGELU()), 55 | ('ln', LayerNorm(self.embed_dim)), 56 | ('fc', nn.Linear(self.embed_dim, args.vocab_size))])) 57 | # init mlm head 58 | nn.init.normal_(self.mlm_head.dense.weight, std=fc_std) 59 | nn.init.normal_(self.mlm_head.fc.weight, std=proj_std) 60 | 61 | def _set_task(self): 62 | loss_names = self.args.loss_names 63 | self.current_task = [l.strip() for l in loss_names.split('+')] 64 | print(f'Training Model with {self.current_task} tasks') 65 | 66 | 67 | def cross_former(self, q, k, v): 68 | x = self.cross_attn( 69 | self.ln_pre_t(q), 70 | self.ln_pre_i(k), 71 | self.ln_pre_i(v), 72 | need_weights=False)[0] 73 | x = x.permute(1, 0, 2) # NLD -> LND 74 | x = self.cross_modal_transformer(x) 75 | x = x.permute(1, 0, 2) # LND -> NLD 76 | 77 | x = self.ln_post(x) 78 | return x 79 | 80 | def encode_image(self, image): 81 | image_feats = self.base_model.encode_image(image) 82 | return image_feats[:, 0, :].float() 83 | # return x[:, 0, :].float() 84 | # return x.float() # for CLIP ResNet visual model 85 | 86 | def encode_text(self, text): 87 | x = self.base_model.encode_text(text) 88 | return x[torch.arange(x.shape[0]), text.argmax(dim=-1)].float() 89 | 90 | def forward(self, image, text, ori_text): 91 | images = image 92 | caption_ids = text 93 | ori_caption_ids = ori_text 94 | mix_ids = torch.cat([caption_ids,ori_caption_ids],dim=0) 95 | with torch.autocast(dtype=torch.float16, device_type='cuda'): 96 | image_feats, text_feats = self.base_model(images, mix_ids) 97 | image_feats, fu_img_feats = image_feats.chunk(2,dim=0) 98 | text_feats, fu_txt_feats = text_feats.chunk(2,dim=0) 99 | return image_feats.float(), text_feats.float(), fu_img_feats.float(),fu_txt_feats.float() 100 | 101 | ret = {} 102 | i_feats = image_feats[:, 0, :].float() 103 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() 104 | 105 | if 'itc' in self.current_task: 106 | ret.update({'itc_loss':objectives.compute_itc(i_feats, t_feats, logit_scale)}) 107 | 108 | if 'sdm' in self.current_task: 109 | ret.update({'sdm_loss':objectives.compute_sdm(i_feats, t_feats, batch['pids'], logit_scale)}) 110 | 111 | 112 | if 'cmpm' in self.current_task: 113 | ret.update({'cmpm_loss':objectives.compute_cmpm(i_feats, t_feats, batch['pids'])}) 114 | 115 | if 'id' in self.current_task: 116 | image_logits = self.classifier(i_feats.half()).float() 117 | text_logits = self.classifier(t_feats.half()).float() 118 | ret.update({'id_loss':objectives.compute_id(image_logits, text_logits, batch['pids'])*self.args.id_loss_weight}) 119 | 120 | image_pred = torch.argmax(image_logits, dim=1) 121 | text_pred = torch.argmax(text_logits, dim=1) 122 | 123 | image_precision = (image_pred == batch['pids']).float().mean() 124 | text_precision = (text_pred == batch['pids']).float().mean() 125 | ret.update({'img_acc': image_precision}) 126 | ret.update({'txt_acc': text_precision}) 127 | 128 | if 'mlm' in self.current_task: 129 | mlm_ids = batch['mlm_ids'] 130 | 131 | mlm_feats = self.base_model.encode_text(mlm_ids) 132 | 133 | x = self.cross_former(mlm_feats, image_feats, image_feats) 134 | 135 | x = self.mlm_head(x) # [batch_size, text_len, num_colors] 136 | 137 | scores = x.float().reshape(-1, self.args.vocab_size) 138 | mlm_labels = batch['mlm_labels'].reshape(-1) 139 | ret.update({'mlm_loss': objectives.compute_mlm(scores, mlm_labels)*self.args.mlm_loss_weight}) 140 | 141 | pred = scores.max(1)[1] 142 | mlm_label_idx = torch.nonzero(mlm_labels) 143 | acc = (pred[mlm_label_idx] == 
mlm_labels[mlm_label_idx]).float().mean() 144 | ret.update({'mlm_acc': acc}) 145 | 146 | if 'att_mlm' in self.current_task: 147 | for att_type in ['shoes','hairstyle','genders','top','trousers','belongings']: 148 | mlm_ids = batch[att_type+'_mlm_ids'] 149 | 150 | mlm_feats = self.base_model.encode_text(mlm_ids) 151 | 152 | x = self.cross_former(mlm_feats, image_feats, image_feats) 153 | 154 | x = self.mlm_head(x) # [batch_size, text_len, num_colors] 155 | 156 | scores = x.float().reshape(-1, self.args.vocab_size) 157 | mlm_labels = batch[att_type+'_mlm_labels'].reshape(-1) 158 | ret.update({att_type+'_loss': objectives.compute_mlm(scores, mlm_labels)*self.args.mlm_loss_weight}) 159 | 160 | pred = scores.max(1)[1] 161 | mlm_label_idx = torch.nonzero(mlm_labels) 162 | acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean() 163 | ret.update({att_type+'_acc': acc}) 164 | 165 | return ret 166 | 167 | 168 | def build_model(args, num_classes=11003): 169 | model = IRRA(args, num_classes) 170 | # covert model to fp16 171 | convert_weights(model) 172 | return model 173 | -------------------------------------------------------------------------------- /model/build_finetune.py: -------------------------------------------------------------------------------- 1 | from model import objectives 2 | from .clip_model import ResidualAttentionBlock, ResidualCrossAttentionBlock, Transformer, QuickGELU, LayerNorm, build_CLIP_from_openai_pretrained, convert_weights 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from collections import OrderedDict 7 | import torch.nn.functional as F 8 | 9 | 10 | class IRRA(nn.Module): 11 | def __init__(self, args, num_classes=11003): 12 | super().__init__() 13 | self.args = args 14 | self.num_classes = num_classes 15 | self._set_task() 16 | 17 | self.base_model, base_cfg = build_CLIP_from_openai_pretrained(args.pretrain_choice, args.img_size, args.stride_size) 18 | self.embed_dim = base_cfg['embed_dim'] 19 | self.logit_scale = torch.ones([]) * (1 / args.temperature) 20 | 21 | 22 | if 'id' in args.loss_names: 23 | self.classifier = nn.Linear(self.embed_dim, self.num_classes) 24 | nn.init.normal_(self.classifier.weight.data, std=0.001) 25 | nn.init.constant_(self.classifier.bias.data, val=0.0) 26 | 27 | if 'mlm' in args.loss_names: 28 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, 29 | self.embed_dim // 64, 30 | batch_first=True) 31 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 32 | layers=args.cmt_depth, 33 | heads=self.embed_dim // 34 | 64) 35 | scale = self.cross_modal_transformer.width**-0.5 36 | 37 | self.ln_pre_t = LayerNorm(self.embed_dim) 38 | self.ln_pre_i = LayerNorm(self.embed_dim) 39 | self.ln_post = LayerNorm(self.embed_dim) 40 | 41 | proj_std = scale * ((2 * self.cross_modal_transformer.layers)**-0.5) 42 | attn_std = scale 43 | fc_std = (2 * self.cross_modal_transformer.width)**-0.5 44 | for block in self.cross_modal_transformer.resblocks: 45 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 46 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 47 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 48 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 49 | 50 | # init cross attn 51 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std) 52 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std) 53 | 54 | self.mlm_head = nn.Sequential( 55 | OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)), 56 | ('gelu', QuickGELU()), 57 | ('ln', 
LayerNorm(self.embed_dim)), 58 | ('fc', nn.Linear(self.embed_dim, args.vocab_size))])) 59 | # init mlm head 60 | nn.init.normal_(self.mlm_head.dense.weight, std=fc_std) 61 | nn.init.normal_(self.mlm_head.fc.weight, std=proj_std) 62 | 63 | def _set_task(self): 64 | loss_names = self.args.loss_names 65 | self.current_task = [l.strip() for l in loss_names.split('+')] 66 | print(f'Training Model with {self.current_task} tasks') 67 | 68 | 69 | def cross_former(self, q, k, v): 70 | x = self.cross_attn( 71 | self.ln_pre_t(q), 72 | self.ln_pre_i(k), 73 | self.ln_pre_i(v), 74 | need_weights=False)[0] 75 | x = x.permute(1, 0, 2) # NLD -> LND 76 | x = self.cross_modal_transformer(x) 77 | x = x.permute(1, 0, 2) # LND -> NLD 78 | 79 | x = self.ln_post(x) 80 | return x 81 | 82 | def encode_image(self, image): 83 | image_feats = self.base_model.encode_image(image) 84 | return image_feats[:, 0, :].float() 85 | 86 | def encode_text(self, text): 87 | x = self.base_model.encode_text(text) 88 | return x[torch.arange(x.shape[0]), text.argmax(dim=-1)].float() 89 | 90 | def forward(self, batch): 91 | ret = dict() 92 | 93 | images = batch['images'] 94 | caption_ids = batch['caption_ids'] 95 | with torch.autocast(dtype=torch.float16, device_type='cuda'): 96 | image_feats, text_feats = self.base_model(images, caption_ids) 97 | 98 | i_feats = image_feats[:, 0, :].float() 99 | # i_feats = image_feats.float() # for CLIP ResNet visual model 100 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() 101 | 102 | logit_scale = self.logit_scale 103 | 104 | if 'itc' in self.current_task: 105 | ret.update({'itc_loss':objectives.compute_itc(i_feats, t_feats, logit_scale)}) 106 | 107 | if 'sdm' in self.current_task: 108 | ret.update({'sdm_loss':objectives.compute_sdm(i_feats, t_feats, batch['pids'], logit_scale)}) 109 | 110 | if 'cmpm' in self.current_task: 111 | ret.update({'cmpm_loss':objectives.compute_cmpm(i_feats, t_feats, batch['pids'])}) 112 | 113 | if 'id' in self.current_task: 114 | image_logits = self.classifier(i_feats.half()).float() 115 | text_logits = self.classifier(t_feats.half()).float() 116 | ret.update({'id_loss':objectives.compute_id(image_logits, text_logits, batch['pids'])*self.args.id_loss_weight}) 117 | 118 | image_pred = torch.argmax(image_logits, dim=1) 119 | text_pred = torch.argmax(text_logits, dim=1) 120 | 121 | image_precision = (image_pred == batch['pids']).float().mean() 122 | text_precision = (text_pred == batch['pids']).float().mean() 123 | ret.update({'img_acc': image_precision}) 124 | ret.update({'txt_acc': text_precision}) 125 | 126 | if 'mlm' in self.current_task: 127 | mlm_ids = batch['mlm_ids'] 128 | 129 | mlm_feats = self.base_model.encode_text(mlm_ids) 130 | 131 | x = self.cross_former(mlm_feats, image_feats, image_feats) 132 | 133 | x = self.mlm_head(x) # [batch_size, text_len, num_colors] 134 | 135 | scores = x.float().reshape(-1, self.args.vocab_size) 136 | mlm_labels = batch['mlm_labels'].reshape(-1) 137 | ret.update({'mlm_loss': objectives.compute_mlm(scores, mlm_labels)*self.args.mlm_loss_weight}) 138 | 139 | pred = scores.max(1)[1] 140 | mlm_label_idx = torch.nonzero(mlm_labels) 141 | acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean() 142 | ret.update({'mlm_acc': acc}) 143 | 144 | if 'att_mlm' in self.current_task: 145 | for att_type in ['shoes','hairstyle','genders','top','trousers','belongings']: 146 | mlm_ids = batch[att_type+'_mlm_ids'] 147 | 148 | mlm_feats = self.base_model.encode_text(mlm_ids) 149 | 150 | x = 
self.cross_former(mlm_feats, image_feats, image_feats) 151 | 152 | x = self.mlm_head(x) # [batch_size, text_len, num_colors] 153 | 154 | scores = x.float().reshape(-1, self.args.vocab_size) 155 | mlm_labels = batch[att_type+'_mlm_labels'].reshape(-1) 156 | ret.update({att_type+'_loss': objectives.compute_mlm(scores, mlm_labels)*self.args.mlm_loss_weight}) 157 | 158 | pred = scores.max(1)[1] 159 | mlm_label_idx = torch.nonzero(mlm_labels) 160 | acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean() 161 | ret.update({att_type+'_acc': acc}) 162 | 163 | return ret 164 | 165 | 166 | def build_finetune_model(args, num_classes=11003): 167 | model = IRRA(args, num_classes) 168 | # covert model to fp16 169 | convert_weights(model) 170 | return model 171 | -------------------------------------------------------------------------------- /model/clip_model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 3 | """ 4 | from collections import OrderedDict 5 | import logging 6 | import math 7 | import os 8 | from typing import List, Tuple, Union 9 | import hashlib 10 | import urllib 11 | from tqdm import tqdm 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | 18 | from model.style import AdaIN 19 | 20 | 21 | logger = logging.getLogger("IRRA.model") 22 | 23 | _MODELS = { 24 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 25 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 26 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 27 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 28 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 29 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 30 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 31 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 32 | } 33 | 34 | def available_models() -> List[str]: 35 | """Returns the names of available CLIP models""" 36 | return list(_MODELS.keys()) 37 | 38 | def _download(url: str, root: str): 39 | os.makedirs(root, exist_ok=True) 40 | filename = os.path.basename(url) 41 | 42 | expected_sha256 = url.split("/")[-2] 43 | download_target = os.path.join(root, filename) 44 | 45 | if os.path.exists(download_target) and not os.path.isfile(download_target): 46 | raise RuntimeError(f"{download_target} exists and is not a regular file") 47 | 48 | if os.path.isfile(download_target): 49 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 50 | return download_target 51 | else: 52 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 53 | 54 | with urllib.request.urlopen(url) as source, open(download_target, "wb") 
as output: 55 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 56 | while True: 57 | buffer = source.read(8192) 58 | if not buffer: 59 | break 60 | 61 | output.write(buffer) 62 | loop.update(len(buffer)) 63 | 64 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 65 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 66 | 67 | return download_target 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1): 74 | super().__init__() 75 | 76 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 77 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(planes) 79 | 80 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | 83 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 84 | 85 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 86 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 87 | 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = None 90 | self.stride = stride 91 | 92 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 93 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 94 | self.downsample = nn.Sequential(OrderedDict([ 95 | ("-1", nn.AvgPool2d(stride)), 96 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 97 | ("1", nn.BatchNorm2d(planes * self.expansion)) 98 | ])) 99 | 100 | def forward(self, x: torch.Tensor): 101 | identity = x 102 | 103 | out = self.relu(self.bn1(self.conv1(x))) 104 | out = self.relu(self.bn2(self.conv2(out))) 105 | out = self.avgpool(out) 106 | out = self.bn3(self.conv3(out)) 107 | 108 | if self.downsample is not None: 109 | identity = self.downsample(x) 110 | 111 | out += identity 112 | out = self.relu(out) 113 | return out 114 | 115 | 116 | class AttentionPool2d(nn.Module): 117 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, cls_token=None): 118 | super().__init__() 119 | # self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 120 | self.positional_embedding = nn.Parameter(torch.randn((spacial_dim[0] * spacial_dim[1]) + 1, embed_dim)/ embed_dim ** 0.5) 121 | self.k_proj = nn.Linear(embed_dim, embed_dim) 122 | self.q_proj = nn.Linear(embed_dim, embed_dim) 123 | self.v_proj = nn.Linear(embed_dim, embed_dim) 124 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 125 | self.num_heads = num_heads 126 | self.cls_token = cls_token 127 | if self.cls_token is not None: 128 | self.cls = nn.Parameter(torch.randn([1,2048])) 129 | def forward(self, x): 130 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 131 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 132 | 133 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 134 | if self.cls_token is not None: 135 | q = self.cls.unsqueeze(1).repeat(1,x.size(1),1).to(x.dtype).to(x.device) 136 | 137 | else: 138 | q = x 139 | x, _ = F.multi_head_attention_forward( 140 | query=q, key=x, value=x, 141 | embed_dim_to_check=x.shape[-1], 142 | num_heads=self.num_heads, 143 | q_proj_weight=self.q_proj.weight, 144 | k_proj_weight=self.k_proj.weight, 145 | 
v_proj_weight=self.v_proj.weight, 146 | in_proj_weight=None, 147 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 148 | bias_k=None, 149 | bias_v=None, 150 | add_zero_attn=False, 151 | dropout_p=0, 152 | out_proj_weight=self.c_proj.weight, 153 | out_proj_bias=self.c_proj.bias, 154 | use_separate_proj_weight=True, 155 | training=self.training, 156 | need_weights=False 157 | ) 158 | return x.permute(1,0,2) 159 | # return x[0] 160 | 161 | 162 | class ModifiedResNet(nn.Module): 163 | """ 164 | A ResNet class that is similar to torchvision's but contains the following changes: 165 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 166 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 167 | - The final pooling layer is a QKV attention instead of an average pool 168 | """ 169 | 170 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 171 | super().__init__() 172 | self.output_dim = output_dim 173 | self.input_resolution = input_resolution 174 | 175 | # the 3-layer stem 176 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 177 | self.bn1 = nn.BatchNorm2d(width // 2) 178 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 179 | self.bn2 = nn.BatchNorm2d(width // 2) 180 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 181 | self.bn3 = nn.BatchNorm2d(width) 182 | self.avgpool = nn.AvgPool2d(2) 183 | self.relu = nn.ReLU(inplace=True) 184 | 185 | # residual layers 186 | self._inplanes = width # this is a *mutable* variable used during construction 187 | self.layer1 = self._make_layer(width, layers[0]) 188 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 189 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 190 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 191 | 192 | embed_dim = width * 32 # the ResNet feature dimension 193 | spacial_dim = ( 194 | input_resolution[0] // 32, 195 | input_resolution[1] // 32, 196 | ) 197 | self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim) 198 | # self.attnpool0 = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim,1) 199 | # self.attnpool1 = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim,1) 200 | # self.attnpool2 = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim,1) 201 | # self.attnpool3 = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim,1) 202 | self.style = AdaIN(p=0.5) 203 | def _make_layer(self, planes, blocks, stride=1): 204 | layers = [Bottleneck(self._inplanes, planes, stride)] 205 | 206 | self._inplanes = planes * Bottleneck.expansion 207 | for _ in range(1, blocks): 208 | layers.append(Bottleneck(self._inplanes, planes)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | def stem(x): 214 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 215 | x = self.relu(bn(conv(x))) 216 | x = self.avgpool(x) 217 | return x 218 | 219 | x = x.type(self.conv1.weight.dtype) 220 | x = stem(x) 221 | x = self.layer1(x) 222 | # if self.training: 223 | # x = self.style(x) 224 | x = self.layer2(x) 225 | # if self.training: 226 | # x = self.style(x) 227 | x = self.layer3(x) 228 | x = self.layer4(x) 229 | x_glo = self.attnpool(x) 230 | # x0= self.attnpool0(x) 231 | # x1 = self.attnpool1(x) 232 | # x2= self.attnpool2(x) 233 | # x3 = 
self.attnpool3(x) 234 | # if not self.training:return x_glo 235 | return x_glo 236 | 237 | 238 | class LayerNorm(nn.LayerNorm): 239 | """Subclass torch's LayerNorm to handle fp16.""" 240 | 241 | def forward(self, x: torch.Tensor): 242 | orig_type = x.dtype 243 | ret = super().forward(x.type(torch.float32)) 244 | return ret.type(orig_type) 245 | 246 | 247 | class QuickGELU(nn.Module): 248 | def forward(self, x: torch.Tensor): 249 | return x * torch.sigmoid(1.702 * x) 250 | 251 | 252 | class ResidualAttentionBlock(nn.Module): 253 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 254 | super().__init__() 255 | 256 | self.attn = nn.MultiheadAttention(d_model, n_head) 257 | self.ln_1 = LayerNorm(d_model) 258 | self.mlp = nn.Sequential(OrderedDict([ 259 | ("c_fc", nn.Linear(d_model, d_model * 4)), 260 | ("gelu", QuickGELU()), 261 | ("c_proj", nn.Linear(d_model * 4, d_model)) 262 | ])) 263 | self.ln_2 = LayerNorm(d_model) 264 | self.attn_mask = attn_mask 265 | 266 | def attention(self, x: torch.Tensor): 267 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 268 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 269 | 270 | def forward(self, x: torch.Tensor): 271 | x = x + self.attention(self.ln_1(x)) 272 | x = x + self.mlp(self.ln_2(x)) 273 | return x 274 | 275 | class ResidualCrossAttentionBlock(nn.Module): 276 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 277 | super().__init__() 278 | 279 | self.attn = nn.MultiheadAttention(d_model, n_head) 280 | self.ln_1 = LayerNorm(d_model) 281 | self.mlp = nn.Sequential(OrderedDict([ 282 | ("c_fc", nn.Linear(d_model, d_model * 4)), 283 | ("gelu", QuickGELU()), 284 | ("c_proj", nn.Linear(d_model * 4, d_model)) 285 | ])) 286 | self.ln_2 = LayerNorm(d_model) 287 | self.attn_mask = attn_mask 288 | 289 | def attention(self, q,k,v): 290 | self.attn_mask = self.attn_mask.to(dtype=q.dtype, device=q.device) if self.attn_mask is not None else None 291 | return self.attn(q, k, v, need_weights=False, attn_mask=self.attn_mask)[0] 292 | 293 | def forward(self, q, k): 294 | x = q + self.attention(q, self.ln_1(k), self.ln_1(k)) 295 | x = x + self.mlp(self.ln_2(x)) 296 | return x 297 | 298 | class Transformer(nn.Module): 299 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 300 | super().__init__() 301 | self.width = width 302 | self.layers = layers 303 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 304 | self.type = type 305 | def forward(self, x, modal): 306 | if self.training: 307 | i = 2 308 | if modal=='visual': 309 | x_fu1 = self.resblocks[:self.layers-i](x) 310 | x = self.resblocks[self.layers-i:](x_fu1) 311 | return torch.cat([x,x_fu1],dim=1) 312 | elif modal == 'text': 313 | mix_token = self.resblocks[:self.layers-i](x) 314 | mlm_token, ori_token = mix_token.chunk(2,dim=1) 315 | x = self.resblocks[self.layers-i:](mlm_token) 316 | return torch.cat([x,ori_token],dim=1) 317 | return self.resblocks(x) 318 | 319 | 320 | class VisionTransformer(nn.Module): 321 | def __init__(self, input_resolution: Tuple[int, int], patch_size: int, stride_size: int, width: int, layers: int, heads: int, output_dim: int): 322 | super().__init__() 323 | self.input_resolution = input_resolution # (384, 128) 324 | self.num_x = (input_resolution[1] - patch_size) // stride_size + 1 325 | self.num_y = (input_resolution[0] - patch_size) // 
stride_size + 1 326 | num_patches = self.num_x * self.num_y 327 | 328 | self.output_dim = output_dim 329 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False) 330 | 331 | scale = width ** -0.5 # 1/sqrt(768) 332 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 333 | self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width)) 334 | self.ln_pre = LayerNorm(width) 335 | 336 | self.transformer = Transformer(width, layers, heads) 337 | 338 | self.ln_post = LayerNorm(width) 339 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 340 | 341 | 342 | # self.style = AdaIN(p=0.5) 343 | 344 | def forward(self, x, modal=None): 345 | x = self.conv1(x) # shape = [*, width, grid, grid] 346 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 347 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 348 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 349 | x = x + self.positional_embedding.to(x.dtype) 350 | 351 | x = self.ln_pre(x) 352 | 353 | x = x.permute(1, 0, 2) # NLD -> LND 354 | x = self.transformer(x,modal) 355 | x = x.permute(1, 0, 2) # LND -> NLD 356 | 357 | # x = self.ln_post(x[:, 0, :]) 358 | x = self.ln_post(x) 359 | 360 | if self.proj is not None: 361 | x = x @ self.proj 362 | return x 363 | 364 | 365 | 366 | class CLIP(nn.Module): 367 | def __init__(self, 368 | embed_dim: int, 369 | # vision 370 | image_resolution: Union[int, Tuple[int, int]], 371 | vision_layers: Union[Tuple[int, int, int, int], int], 372 | vision_width: int, 373 | vision_patch_size: int, 374 | stride_size: int, 375 | # text 376 | context_length: int, 377 | vocab_size: int, 378 | transformer_width: int, 379 | transformer_heads: int, 380 | transformer_layers: int 381 | ): 382 | super().__init__() 383 | 384 | self.context_length = context_length 385 | 386 | if isinstance(vision_layers, (tuple, list)): 387 | vision_heads = vision_width * 32 // 64 388 | self.visual = ModifiedResNet( 389 | layers=vision_layers, 390 | output_dim=embed_dim, 391 | heads=vision_heads, 392 | input_resolution=image_resolution, 393 | width=vision_width 394 | ) 395 | else: 396 | vision_heads = vision_width // 64 397 | self.visual = VisionTransformer( 398 | input_resolution=image_resolution, 399 | patch_size=vision_patch_size, 400 | stride_size=stride_size, 401 | width=vision_width, 402 | layers=vision_layers, 403 | heads=vision_heads, 404 | output_dim=embed_dim 405 | ) 406 | 407 | self.transformer = Transformer( 408 | width=transformer_width, 409 | layers=transformer_layers, 410 | heads=transformer_heads, 411 | attn_mask=self.build_attention_mask() 412 | ) 413 | 414 | self.vocab_size = vocab_size 415 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 416 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 417 | # self.token_embedding.requires_grad_(False) 418 | # self.positional_embedding.requires_grad_(False) 419 | self.ln_final = LayerNorm(transformer_width) 420 | 421 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 422 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 423 | 424 | self.initialize_parameters() 425 | 426 | def initialize_parameters(self): 427 | nn.init.normal_(self.token_embedding.weight, std=0.02) 428 | nn.init.normal_(self.positional_embedding, std=0.01) 429 | 
430 | if isinstance(self.visual, ModifiedResNet):
431 | if self.visual.attnpool is not None:
432 | std = self.visual.attnpool.c_proj.in_features ** -0.5
433 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
434 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
435 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
436 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
437 | 
438 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
439 | for name, param in resnet_block.named_parameters():
440 | if name.endswith("bn3.weight"):
441 | nn.init.zeros_(param)
442 | 
443 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
444 | attn_std = self.transformer.width ** -0.5
445 | fc_std = (2 * self.transformer.width) ** -0.5
446 | for block in self.transformer.resblocks:
447 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
448 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
449 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
450 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
451 | 
452 | if self.text_projection is not None:
453 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
454 | 
455 | def build_attention_mask(self):
456 | # lazily create causal attention mask, with full attention between the vision tokens
457 | # pytorch uses additive attention mask; fill with -inf
458 | mask = torch.empty(self.context_length, self.context_length)
459 | mask.fill_(float("-inf"))
460 | mask.triu_(1) # zero out the lower diagonal
461 | return mask
462 | 
463 | @property
464 | def dtype(self):
465 | return self.visual.conv1.weight.dtype
466 | 
467 | def encode_image(self, image, modal=None):
468 | return self.visual(image.type(self.dtype),modal)
469 | 
470 | def encode_text(self, text, modal=None):
471 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
472 | 
473 | x = x + self.positional_embedding.type(self.dtype)
474 | x = x.permute(1, 0, 2) # NLD -> LND
475 | x = self.transformer(x, modal)
476 | x = x.permute(1, 0, 2) # LND -> NLD
477 | x = self.ln_final(x).type(self.dtype)
478 | 
479 | # x.shape = [batch_size, n_ctx, transformer.width]
480 | # take features from the eot embedding (eot_token is the highest number in each sequence)
481 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
482 | x = x @ self.text_projection
483 | 
484 | return x
485 | 
486 | def forward(self, image, text):
487 | if text.size(0)==2*image.size(0) :
488 | image_features = self.encode_image(image,modal='visual')
489 | text_features = self.encode_text(text,modal='text')
490 | else:
491 | image_features = self.encode_image(image)
492 | text_features = self.encode_text(text)
493 | 
494 | # # normalized features
495 | # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
496 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
497 | 
498 | # # cosine similarity as logits
499 | # logit_scale = self.logit_scale.exp()
500 | # logits_per_image = logit_scale * image_features @ text_features.t()
501 | # logits_per_text = logits_per_image.t()
502 | 
503 | # # shape = [global_batch_size, global_batch_size]
504 | # return logits_per_image, logits_per_text
505 | 
506 | return image_features, text_features
507 | 
508 | 
509 | def load_param(self, state_dict):
510 | # drop any keys in the pretrained state_dict that are not present in the model's own state_dict
511 | param_dict = {k: v for k, v in state_dict.items() if k in
self.state_dict()} 512 | if 'model' in param_dict: 513 | param_dict = param_dict['model'] 514 | if 'state_dict' in param_dict: 515 | param_dict = param_dict['state_dict'] 516 | for k, v in param_dict.items(): 517 | if k == 'visual.positional_embedding' and v.shape != self.visual.positional_embedding.shape: 518 | v = resize_pos_embed(v, self.visual.positional_embedding, self.visual.num_y, self.visual.num_x) 519 | elif k == 'positional_embedding' and v.shape != self.positional_embedding.shape: 520 | v = resize_text_pos_embed(v, self.context_length) 521 | # elif 'visual.attnpool' in k and '.positional_embedding' not in k: 522 | # k0 = k.replace('attnpool','attnpool0') 523 | # k1 = k.replace('attnpool','attnpool1') 524 | # k2 = k.replace('attnpool','attnpool2') 525 | # k3 = k.replace('attnpool','attnpool3') 526 | # self.state_dict()[k0].copy_(v) 527 | # self.state_dict()[k1].copy_(v) 528 | # self.state_dict()[k2].copy_(v) 529 | # self.state_dict()[k3].copy_(v) 530 | # elif 'visual.transformer.resblocks.11' in k: 531 | # k0 = k.replace('resblocks.11','copy_block.0') 532 | # k1 = k.replace('resblocks.11','copy_block.1') 533 | # k2 = k.replace('resblocks.11','copy_block.2') 534 | # k3 = k.replace('resblocks.11','copy_block.3') 535 | # self.state_dict()[k0].copy_(v) 536 | # self.state_dict()[k1].copy_(v) 537 | # self.state_dict()[k2].copy_(v) 538 | # self.state_dict()[k3].copy_(v) 539 | 540 | try: 541 | self.state_dict()[k].copy_(v) 542 | except: 543 | print(f'===========================ERROR occur in copy {k}, {v.shape}=========================') 544 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 545 | 546 | 547 | 548 | def resize_pos_embed(posemb, posemb_new, hight, width): 549 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 550 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 551 | posemb = posemb.unsqueeze(0) 552 | posemb_new = posemb_new.unsqueeze(0) 553 | 554 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 555 | 556 | gs_old = int(math.sqrt(len(posemb_grid))) 557 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 558 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 559 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 560 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 561 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 562 | return posemb.squeeze(0) 563 | 564 | 565 | def convert_weights(model: nn.Module): 566 | """Convert applicable model parameters to fp16""" 567 | 568 | def _convert_weights_to_fp16(l): 569 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 570 | l.weight.data = l.weight.data.half() 571 | if l.bias is not None: 572 | l.bias.data = l.bias.data.half() 573 | 574 | if isinstance(l, nn.MultiheadAttention): 575 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 576 | tensor = getattr(l, attr) 577 | if tensor is not None: 578 | tensor.data = tensor.data.half() 579 | 580 | for name in ["text_projection", "proj", "mcq_proj"]: 581 | if hasattr(l, name): 582 | attr = getattr(l, name) 583 | if attr is not None: 584 | attr.data = attr.data.half() 585 | 586 | model.apply(_convert_weights_to_fp16) 587 | 588 | 589 | def build_CLIP_from_openai_pretrained(name: str, image_size: Union[int, Tuple[int, int]], stride_size: int, jit: bool = False, download_root: str = None): 590 | """Load a CLIP model 591 | 592 | Parameters 593 | ---------- 594 | name : str 595 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 596 | 597 | image_size: Union[int, Tuple[int, int]] 598 | Input image size, in Re-ID task, image size commonly set to 384x128, instead of 224x224 599 | 600 | jit : bool 601 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 602 | 603 | download_root: str 604 | path to download the model files; by default, it uses "~/.cache/clip" 605 | 606 | Returns 607 | ------- 608 | model : torch.nn.Module 609 | The CLIP model 610 | """ 611 | if name in _MODELS: 612 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 613 | elif os.path.isfile(name): 614 | model_path = name 615 | else: 616 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 617 | 618 | try: 619 | # loading JIT archive 620 | model = torch.jit.load(model_path, map_location="cpu") 621 | state_dict = None 622 | except RuntimeError: 623 | # loading saved state dict 624 | if jit: 625 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 626 | jit = False 627 | state_dict = torch.load(model_path, map_location="cpu") 628 | 629 | state_dict = state_dict or model.state_dict() 630 | 631 | vit = "visual.proj" in state_dict 632 | 633 | if vit: 634 | vision_width = state_dict["visual.conv1.weight"].shape[0] 635 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 636 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 637 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 638 | image_resolution = vision_patch_size * grid_size 639 | else: 640 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 641 | vision_layers = tuple(counts) 642 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 643 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 644 | vision_patch_size = None 645 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 646 | image_resolution = output_width * 32 647 | 648 | embed_dim = state_dict["text_projection"].shape[1] 649 | context_length = state_dict["positional_embedding"].shape[0] 650 | vocab_size = state_dict["token_embedding.weight"].shape[0] 651 | transformer_width = state_dict["ln_final.weight"].shape[0] 652 | transformer_heads = transformer_width // 64 653 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 654 | 655 | model_cfg = { 656 | 'embed_dim': embed_dim, 657 | 'image_resolution': image_resolution, 658 | 'vision_layers': vision_layers, 659 | 'vision_width': vision_width, 660 | 'vision_patch_size': vision_patch_size, 661 | 'context_length': context_length, 662 | 'vocab_size': vocab_size, 663 | 'transformer_width': transformer_width, 664 | 'transformer_heads': transformer_heads, 665 | 'transformer_layers': transformer_layers 666 | } 667 | 668 | 669 | # modify image resolution to adapt Re-ID task 670 | model_cfg['image_resolution'] = image_size 671 | model_cfg['stride_size'] = stride_size 672 | logger.info(f"Load pretrained {name} CLIP model with model config: {model_cfg}") 673 | model = CLIP(**model_cfg) 674 | 675 | # covert model to fp16 676 | # convert_weights(model) 677 | 678 | # resize modified pos embedding 679 | model.load_param(state_dict) 680 | return model, model_cfg 681 | 682 | 683 | -------------------------------------------------------------------------------- /model/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import init 4 | from torch import nn, autograd 5 | import numpy as np 6 | 7 | 8 | class MC(autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx, inputs, indexes, features, momentum): 12 | ctx.features = features 13 | ctx.momentum = momentum 14 | ctx.save_for_backward(inputs, indexes) 15 | outputs = inputs.mm(ctx.features.t()) 16 | 17 | return outputs 18 | 19 | @staticmethod 20 | def backward(ctx, grad_outputs): 21 | inputs, indexes = ctx.saved_tensors 22 | grad_inputs = None 23 | if ctx.needs_input_grad[0]: 24 | grad_inputs = grad_outputs.mm(ctx.features) 25 | 26 | return grad_inputs, None, None, None 27 | 28 | 29 | def mc(inputs, indexes, features, momentum=0.5): 30 | return MC.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 31 | 32 | 33 | class 
MemoryClassifier(nn.Module): 34 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2): 35 | super(MemoryClassifier, self).__init__() 36 | self.num_features = num_features 37 | self.num_samples = num_samples 38 | self.momentum = momentum 39 | self.temp = temp 40 | 41 | self.register_buffer('features_txt', torch.zeros(num_samples, num_features)) 42 | self.register_buffer('features_img', torch.zeros(num_samples, num_features)) 43 | # self.register_buffer('cam_features',torch.zeros(num_samples,num_features,16,8)) 44 | 45 | def MomentumUpdate(self, img, txt, indexes): 46 | # momentum update 47 | for im,tx, y in zip(img, txt, indexes): 48 | self.features_img[y] = self.momentum * self.features_img[y] + (1. - self.momentum) * im 49 | self.features_img[y] = self.features_img[y] / self.features_img[y].norm() 50 | 51 | self.features_txt[y] = self.momentum * self.features_txt[y] + (1. - self.momentum) * tx 52 | self.features_txt[y] = self.features_txt[y] / self.features_txt[y].norm() 53 | 54 | def forward(self, img, txt , indexes): 55 | sim_i2t = mc(img, indexes, self.features_txt, self.momentum) ## B * C 56 | sim_i2i = mc(img, indexes, self.features_img, self.momentum) 57 | sim_t2i = mc(txt, indexes, self.features_img, self.momentum) 58 | sim_t2t = mc(txt, indexes, self.features_txt, self.momentum) 59 | 60 | loss_i2t = F.cross_entropy(sim_i2t / self.temp, indexes) 61 | loss_i2i = F.cross_entropy(sim_i2i / self.temp, indexes) 62 | loss_t2i = F.cross_entropy(sim_t2i / self.temp, indexes) 63 | loss_t2t = F.cross_entropy(sim_t2t / self.temp, indexes) 64 | return loss_i2t ,loss_i2i,loss_t2i,loss_t2t 65 | 66 | 67 | class Memory(nn.Module): 68 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2): 69 | super(Memory, self).__init__() 70 | self.num_features = num_features 71 | self.num_samples = num_samples 72 | self.momentum = momentum 73 | self.temp = temp 74 | 75 | # self.register_buffer('features', torch.zeros(num_samples, num_features)) 76 | self.register_buffer('labels', torch.zeros(num_samples).long()) 77 | self.register_buffer('cam_features', torch.zeros(num_samples, num_features, 16, 8)) 78 | 79 | def MomentumUpdate(self, inputs, indexes): 80 | # momentum update 81 | for x, y in zip(inputs, indexes): 82 | self.cam_features[y] = self.momentum * self.cam_features[y] + (1. 
- self.momentum) * x
83 | self.cam_features[y] = self.cam_features[y] / self.cam_features[y].norm()
84 | 
85 | def forward(self, inputs, indexes):
86 | return 0
87 | 
-------------------------------------------------------------------------------- /model/objectives.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | def find_mutual_nearest_neighbors(image_features, text_features, k):
6 | # L2-normalize the feature vectors
7 | image_features = F.normalize(image_features, p=2, dim=1)
8 | text_features = F.normalize(text_features, p=2, dim=1)
9 | 
10 | # compute cosine similarities
11 | image_text_similarities = torch.matmul(image_features, text_features.t())
12 | text_image_similarities = torch.matmul(text_features, image_features.t())
13 | 
14 | # find the k nearest text neighbors of each image feature (and vice versa)
15 | _, image_to_text_nearest_neighbors = torch.topk(image_text_similarities, k, dim=1)
16 | _, text_to_image_nearest_neighbors = torch.topk(text_image_similarities, k, dim=1)
17 | 
18 | mutual_nearest_neighbors = torch.zeros([image_features.size(0),text_features.size(0)])
19 | for i in range(image_features.size(0)):
20 | image_k_nearest = image_to_text_nearest_neighbors[i]
21 | text_k_nearest = text_to_image_nearest_neighbors[image_k_nearest]
22 | 
23 | # check whether the current image is among the k nearest neighbors of those texts
24 | has_mutual = torch.where(text_k_nearest == i, 1, 0).sum(dim=1)
25 | mutual_text_index = has_mutual.nonzero()
26 | if len(mutual_text_index) !=0:
27 | for idx in mutual_text_index:
28 | mutual_nearest_neighbors[i,image_k_nearest[idx]] = 1
29 | 
30 | 
31 | return mutual_nearest_neighbors
32 | 
33 | def compute_part(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-6):
34 | batch_size = image_fetures.shape[0]
35 | pid = pid.reshape((batch_size, 1)) # make sure pid size is [batch_size, 1]
36 | 
37 | pid_dist = pid - pid.t()
38 | labels = (pid_dist == 0).to(torch.int)
39 | 
40 | if image_id != None:
41 | # print("Mix PID and ImageID to create soft label.")
42 | image_id = image_id.reshape((-1, 1))
43 | image_id_dist = image_id - image_id.t()
44 | image_id_mask = (image_id_dist == 0).float()
45 | labels = (labels - image_id_mask) * factor + image_id_mask
46 | # labels = (labels + image_id_mask) / 2
47 | 
48 | image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True)
49 | text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True)
50 | 
51 | t2i_cosine_theta = text_norm @ image_norm.t()
52 | i2t_cosine_theta = t2i_cosine_theta.t()
53 | 
54 | text_proj_image = logit_scale * t2i_cosine_theta
55 | image_proj_text = logit_scale * i2t_cosine_theta
56 | # ['trousers','genders','belongings','shoes','top','hairstyle']
57 | # text_proj_text = F.softmax(logit_scale * text_norm @ text_norm.t(),dim=1)
58 | # labels = torch.where(text_proj_text>0.4,1,0) | labels
59 | k = 5 if text_fetures.size(0) > 5 else text_fetures.size(0)
60 | t2i_labels = find_mutual_nearest_neighbors(text_fetures,image_fetures,k=k).to(text_fetures.device).to(torch.int)
61 | i2t_labels = find_mutual_nearest_neighbors(image_fetures,text_fetures,k=k).to(text_fetures.device).to(torch.int)
62 | 
63 | labels = (i2t_labels * t2i_labels) | labels
64 | # normalize the true matching distribution
65 | labels_distribute = labels / labels.sum(dim=1)
66 | 
67 | i2t_pred = F.softmax(image_proj_text, dim=1)
68 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon))
69 | t2i_pred = F.softmax(text_proj_image, dim=1)
70 | t2i_loss = t2i_pred * 
(F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon)) 71 | 72 | loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 73 | 74 | return loss 75 | 76 | def compute_patch(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8): 77 | """ 78 | Similarity Distribution Matching 79 | """ 80 | batch_size,length,num_dim = image_fetures.size() #32,192,512 81 | patch_feats = image_fetures.reshape(batch_size*length,num_dim) 82 | patch_feats = patch_feats / patch_feats.norm(dim=1, keepdim=True) 83 | text_fetures = text_fetures/ text_fetures.norm(dim=1, keepdim=True) 84 | sim_i2t = patch_feats @ text_fetures.t() 85 | sim_t2i = text_fetures @ patch_feats.t() 86 | 87 | label_i2t = sim_i2t.argmax(dim=-1) 88 | label_t2i = sim_t2i.argmax(dim=-1) 89 | loss = F.cross_entropy(logit_scale * sim_i2t, label_i2t) + F.cross_entropy(logit_scale * sim_t2i, label_t2i) 90 | 91 | return loss 92 | 93 | 94 | def compute_sdm(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8): 95 | """ 96 | Similarity Distribution Matching 97 | """ 98 | batch_size = image_fetures.shape[0] 99 | pid = pid.reshape((batch_size, 1)) # make sure pid size is [batch_size, 1] 100 | 101 | pid_dist = pid - pid.t() 102 | labels = (pid_dist == 0).float() 103 | 104 | if image_id != None: 105 | # print("Mix PID and ImageID to create soft label.") 106 | image_id = image_id.reshape((-1, 1)) 107 | image_id_dist = image_id - image_id.t() 108 | image_id_mask = (image_id_dist == 0).float() 109 | labels = (labels - image_id_mask) * factor + image_id_mask 110 | # labels = (labels + image_id_mask) / 2 111 | 112 | image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True) 113 | text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True) 114 | 115 | t2i_cosine_theta = text_norm @ image_norm.t() 116 | i2t_cosine_theta = t2i_cosine_theta.t() 117 | 118 | text_proj_image = logit_scale * t2i_cosine_theta 119 | image_proj_text = logit_scale * i2t_cosine_theta 120 | 121 | # normalize the true matching distribution 122 | labels_distribute = labels / labels.sum(dim=1) 123 | 124 | i2t_pred = F.softmax(image_proj_text, dim=1) 125 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon)) 126 | t2i_pred = F.softmax(text_proj_image, dim=1) 127 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon)) 128 | 129 | loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 130 | 131 | return loss 132 | 133 | 134 | def compute_mlm(scores, labels): 135 | ce = nn.CrossEntropyLoss(ignore_index=0) 136 | return ce(scores, labels) 137 | 138 | 139 | def compute_itc(image_features, text_features, logit_scale): 140 | """ 141 | image-text contrastive (ITC) loss, InfoNCE 142 | """ 143 | batch_size = image_features.shape[0] 144 | labels = torch.arange(start=0, end=batch_size, dtype=torch.int64) 145 | labels = labels.to(image_features.device) 146 | 147 | 148 | # normalized features 149 | image_norm = image_features / image_features.norm(dim=-1, keepdim=True) 150 | text_norm = text_features / text_features.norm(dim=-1, keepdim=True) 151 | 152 | # cosine similarity as logits 153 | logits_per_image = logit_scale * image_norm @ text_norm.t() 154 | logits_per_text = logits_per_image.t() 155 | 156 | loss_i = F.cross_entropy(logits_per_image, labels) 157 | loss_t = F.cross_entropy(logits_per_text, labels) 158 | loss = (loss_i + loss_t)/2 159 | 
160 | return loss 161 | 162 | 163 | def compute_id(image_logits, text_logits, labels): 164 | """ 165 | Instance loss proposed at http://arxiv.org/abs/1711.05535 166 | """ 167 | criterion = nn.CrossEntropyLoss(reduction="mean") 168 | 169 | loss = criterion(image_logits, labels) + criterion(text_logits, labels) 170 | 171 | return loss / 2 172 | 173 | 174 | def compute_cmpm(image_embeddings, text_embeddings, labels, epsilon=1e-8): 175 | """ 176 | Cross-Modal Projection Matching Loss(CMPM) 177 | :param image_embeddings: Tensor with dtype torch.float32 178 | :param text_embeddings: Tensor with dtype torch.float32 179 | :param labels: Tensor with dtype torch.int32 180 | :return: 181 | i2t_loss: cmpm loss for image projected to text 182 | t2i_loss: cmpm loss for text projected to image 183 | pos_avg_sim: average cosine-similarity for positive pairs 184 | neg_avg_sim: averate cosine-similarity for negative pairs 185 | """ 186 | 187 | batch_size = image_embeddings.shape[0] 188 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 189 | labels_dist = labels_reshape - labels_reshape.t() 190 | labels_mask = (labels_dist == 0).float() 191 | 192 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 193 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 194 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 195 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 196 | 197 | # normalize the true matching distribution 198 | labels_mask_norm = labels_mask / labels_mask.norm(dim=1) 199 | 200 | i2t_pred = F.softmax(image_proj_text, dim=1) 201 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon)) 202 | t2i_pred = F.softmax(text_proj_image, dim=1) 203 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon)) 204 | 205 | cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 206 | 207 | return cmpm_loss 208 | 209 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_pretrain, do_inference -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/processor.cpython-310.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/processor.cpython-38.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/processor.cpython-39.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor_finetune.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/processor/__pycache__/processor_finetune.cpython-39.pyc -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import random 4 | import time 5 | import torch 6 | from datasets.build import build_filter_loader 7 | from model import objectives 8 | from utils.meter import AverageMeter 9 | from utils.metrics import Evaluator 10 | from utils.comm import get_rank, synchronize 11 | from torch.utils.tensorboard import SummaryWriter 12 | from prettytable import PrettyTable 13 | import torch.nn.functional as F 14 | 15 | def do_pretrain(start_epoch, args, model, train_loader, evaluator0,evaluator1,evaluator2, optimizer, 16 | scheduler, checkpointer, trainset): 17 | 18 | log_period = args.log_period 19 | eval_period = args.eval_period 20 | device = "cuda" 21 | num_epoch = args.num_epoch 22 | arguments = {} 23 | arguments["num_epoch"] = num_epoch 24 | arguments["iteration"] = 0 25 | 26 | logger = logging.getLogger("IRRA.train") 27 | if get_rank() == 0: 28 | logger.info("Validation before training - Epoch: {}".format(-1)) 29 | # top1 = evaluator0.eval(model.module.eval()) 30 | # top1 = evaluator1.eval(model.module.eval()) 31 | # top1 = evaluator2.eval(model.module.eval()) 32 | logger.info('start training') 33 | 34 | meters = { 35 | "loss": AverageMeter(), 36 | "sdm_loss": AverageMeter(), 37 | "itc_loss": AverageMeter(), 38 | "id_loss": AverageMeter(), 39 | "mlm_loss": AverageMeter(), 40 | "img_acc": AverageMeter(), 41 | "txt_acc": AverageMeter(), 42 | "mlm_acc": AverageMeter() 43 | } 44 | 45 | tb_writer = SummaryWriter(log_dir=args.output_dir) 46 | 47 | best_top1_0 = 0.0 48 | best_top1_1 = 0.0 49 | best_top1_2 = 0.0 50 | 51 | # train 52 | for epoch in range(start_epoch, num_epoch + 1): 53 | with torch.no_grad(): 54 | if epoch % 1 == 0: 55 | logger.info('Reconstruct the train loader') 56 | train_loader = build_filter_loader(args, trainset) 57 | 58 | start_time = time.time() 59 | for meter in meters.values(): 60 | meter.reset() 61 | model.train() 62 | 63 | for n_iter, batch in enumerate(train_loader): 64 | # batch = {k: v.cuda() for k, v in batch.items()} 65 | 66 | image = 
batch['images'].cuda() 67 | text = batch['caption_ids'].cuda() 68 | ori_text = batch['caption_ids_ori'].cuda() 69 | 70 | i_feats, text_feats,fu_i_feats,fu_t_feats = model(image, text, ori_text) 71 | 72 | caption_ids = text 73 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() 74 | logit_scale = torch.ones([]) * (1 / args.temperature) 75 | 76 | loss_sdm = objectives.compute_sdm(i_feats[:,0,:], t_feats, batch['pids'].cuda(), logit_scale) 77 | 78 | total_loss = loss_sdm 79 | with torch.no_grad(): 80 | similarity_matrix = torch.einsum('nld,nkd->nlk', [F.normalize(fu_t_feats,dim=-1), F.normalize(fu_i_feats[:,1:,:],dim=-1)]) 81 | similarity_matrix = similarity_matrix.max(-1)[0] 82 | for idx, sim in zip(batch['image_ids'].data, similarity_matrix): 83 | trainset[idx][-1] = sim.data.cpu().numpy() 84 | 85 | batch_size = batch['images'].shape[0] 86 | meters['loss'].update(total_loss.item(), batch_size) 87 | meters['sdm_loss'].update(loss_sdm, batch_size) 88 | 89 | optimizer.zero_grad() 90 | total_loss.backward() 91 | optimizer.step() 92 | synchronize() 93 | 94 | if (n_iter + 1) % log_period == 0: 95 | info_str = f"Epoch[{epoch}] Iteration[{n_iter + 1}/{len(train_loader)}]" 96 | # log loss and acc info 97 | for k, v in meters.items(): 98 | if v.avg > 0: 99 | info_str += f", {k}: {v.avg:.4f}" 100 | info_str += f", Base Lr: {scheduler.get_lr()[0]:.2e}" 101 | logger.info(info_str) 102 | 103 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch) 104 | # tb_writer.add_scalar('temperature', ret['temperature'], epoch) 105 | for k, v in meters.items(): 106 | if v.avg > 0: 107 | tb_writer.add_scalar(k, v.avg, epoch) 108 | 109 | 110 | scheduler.step() 111 | if get_rank() == 0: 112 | end_time = time.time() 113 | time_per_batch = (end_time - start_time) / 60 114 | logger.info( 115 | "Epoch {} done. 
Time per batch: {:.3f}[min] Speed: {:.1f}[samples/s]" 116 | .format(epoch, time_per_batch, 117 | train_loader.batch_size / time_per_batch)) 118 | if epoch % eval_period == 0: 119 | logger.info(f"best R1: CUHK {best_top1_0}, ICFG {best_top1_1}, RSTP {best_top1_2}") 120 | if get_rank() == 0: 121 | logger.info("Validation Results - Epoch: {}".format(epoch)) 122 | if args.distributed: 123 | top1_0 = evaluator0.eval(model.module.eval()) 124 | top1_1 = evaluator1.eval(model.module.eval()) 125 | top1_2 = evaluator2.eval(model.module.eval()) 126 | else: 127 | top1_0 = evaluator0.eval(model.module.eval()) 128 | top1_1 = evaluator1.eval(model.module.eval()) 129 | top1_2 = evaluator2.eval(model.module.eval()) 130 | torch.cuda.empty_cache() 131 | if best_top1_0 < top1_0: 132 | best_top1_0 = top1_0 133 | arguments["epoch"] = epoch 134 | checkpointer.save("best0", **arguments) 135 | if best_top1_1 < top1_1: 136 | best_top1_1 = top1_1 137 | arguments["epoch"] = epoch 138 | checkpointer.save("best1", **arguments) 139 | if best_top1_2 < top1_2: 140 | best_top1_2 = top1_2 141 | arguments["epoch"] = epoch 142 | checkpointer.save("best2", **arguments) 143 | if get_rank() == 0: 144 | logger.info(f"best R1: {best_top1_0}, {best_top1_1}, {best_top1_2} at epoch {arguments['epoch']}") 145 | 146 | 147 | def do_inference(model, test_img_loader, test_txt_loader): 148 | 149 | logger = logging.getLogger("IRRA.test") 150 | logger.info("Enter inferencing") 151 | 152 | evaluator = Evaluator(test_img_loader, test_txt_loader) 153 | top1 = evaluator.eval(model.eval()) 154 | -------------------------------------------------------------------------------- /processor/processor_finetune.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import random 4 | import time 5 | import torch 6 | from datasets.build import build_filter_loader 7 | from model import objectives 8 | from utils.meter import AverageMeter 9 | from utils.metrics import Evaluator 10 | from utils.comm import get_rank, synchronize 11 | from torch.utils.tensorboard import SummaryWriter 12 | from prettytable import PrettyTable 13 | import torch.nn.functional as F 14 | 15 | def do_train(start_epoch, args, model, train_loader, evaluator0,evaluator1,evaluator2, optimizer, 16 | scheduler, checkpointer, trainset): 17 | 18 | log_period = args.log_period 19 | eval_period = args.eval_period 20 | device = "cuda" 21 | num_epoch = args.num_epoch 22 | arguments = {} 23 | arguments["num_epoch"] = num_epoch 24 | arguments["iteration"] = 0 25 | 26 | logger = logging.getLogger("IRRA.train") 27 | if get_rank() == 0: 28 | logger.info("Validation before training - Epoch: {}".format(-1)) 29 | top1 = evaluator0.eval(model.module.eval()) 30 | top1 = evaluator1.eval(model.module.eval()) 31 | top1 = evaluator2.eval(model.module.eval()) 32 | logger.info('start training') 33 | 34 | meters = { 35 | "loss": AverageMeter(), 36 | "sdm_loss": AverageMeter(), 37 | "itc_loss": AverageMeter(), 38 | "id_loss": AverageMeter(), 39 | "mlm_loss": AverageMeter(), 40 | "img_acc": AverageMeter(), 41 | "txt_acc": AverageMeter(), 42 | "mlm_acc": AverageMeter() 43 | } 44 | 45 | tb_writer = SummaryWriter(log_dir=args.output_dir) 46 | 47 | best_top1_0 = 0.0 48 | best_top1_1 = 0.0 49 | best_top1_2 = 0.0 50 | 51 | # train 52 | for epoch in range(start_epoch, num_epoch + 1): 53 | start_time = time.time() 54 | for meter in meters.values(): 55 | meter.reset() 56 | model.train() 57 | 58 | for n_iter, batch in enumerate(train_loader): 59 | 
batch = {k: v.cuda() for k, v in batch.items()} 60 | 61 | ret = model(batch) 62 | ret = {key: values.mean() for key, values in ret.items()} 63 | total_loss = sum([v for k, v in ret.items() if "loss" in k]) 64 | 65 | batch_size = batch['images'].shape[0] 66 | 67 | meters['loss'].update(total_loss.item(), batch_size) 68 | meters['sdm_loss'].update(ret.get('sdm_loss', 0), batch_size) 69 | meters['itc_loss'].update(ret.get('itc_loss', 0), batch_size) 70 | meters['id_loss'].update(ret.get('id_loss', 0), batch_size) 71 | meters['mlm_loss'].update(ret.get('mlm_loss', 0), batch_size) 72 | 73 | meters['img_acc'].update(ret.get('img_acc', 0), batch_size) 74 | meters['txt_acc'].update(ret.get('txt_acc', 0), batch_size) 75 | meters['mlm_acc'].update(ret.get('mlm_acc', 0), 1) 76 | 77 | optimizer.zero_grad() 78 | total_loss.backward() 79 | optimizer.step() 80 | synchronize() 81 | 82 | if (n_iter + 1) % log_period == 0: 83 | info_str = f"Epoch[{epoch}] Iteration[{n_iter + 1}/{len(train_loader)}]" 84 | # log loss and acc info 85 | for k, v in meters.items(): 86 | if v.avg > 0: 87 | info_str += f", {k}: {v.avg:.4f}" 88 | info_str += f", Base Lr: {scheduler.get_lr()[0]:.2e}" 89 | logger.info(info_str) 90 | 91 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch) 92 | # tb_writer.add_scalar('temperature', ret['temperature'], epoch) 93 | for k, v in meters.items(): 94 | if v.avg > 0: 95 | tb_writer.add_scalar(k, v.avg, epoch) 96 | 97 | 98 | scheduler.step() 99 | if get_rank() == 0: 100 | end_time = time.time() 101 | time_per_batch = (end_time - start_time) / 60 102 | logger.info( 103 | "Epoch {} done. Time per batch: {:.3f}[min] Speed: {:.1f}[samples/s]" 104 | .format(epoch, time_per_batch, 105 | train_loader.batch_size / time_per_batch)) 106 | if epoch % eval_period == 0: 107 | logger.info(f"best R1: CUHK {best_top1_0}, ICFG {best_top1_1}, RSTP {best_top1_2}") 108 | if get_rank() == 0: 109 | logger.info("Validation Results - Epoch: {}".format(epoch)) 110 | if args.distributed: 111 | top1_0 = evaluator0.eval(model.module.eval()) 112 | top1_1 = evaluator1.eval(model.module.eval()) 113 | top1_2 = evaluator2.eval(model.module.eval()) 114 | else: 115 | top1_0 = evaluator0.eval(model.module.eval()) 116 | top1_1 = evaluator1.eval(model.module.eval()) 117 | top1_2 = evaluator2.eval(model.module.eval()) 118 | torch.cuda.empty_cache() 119 | if best_top1_0 < top1_0: 120 | best_top1_0 = top1_0 121 | arguments["epoch"] = epoch 122 | checkpointer.save("best0", **arguments) 123 | if best_top1_1 < top1_1: 124 | best_top1_1 = top1_1 125 | arguments["epoch"] = epoch 126 | checkpointer.save("best1", **arguments) 127 | if best_top1_2 < top1_2: 128 | best_top1_2 = top1_2 129 | arguments["epoch"] = epoch 130 | checkpointer.save("best2", **arguments) 131 | if get_rank() == 0: 132 | logger.info(f"best R1: {best_top1_0}, {best_top1_1}, {best_top1_2} at epoch {arguments['epoch']}") 133 | 134 | 135 | def do_inference(model, test_img_loader, test_txt_loader): 136 | 137 | logger = logging.getLogger("IRRA.test") 138 | logger.info("Enter inferencing") 139 | 140 | evaluator = Evaluator(test_img_loader, test_txt_loader) 141 | top1 = evaluator.eval(model.eval()) 142 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET_NAME="Testing" 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 5 | python train.py \ 6 | --name Pretrain \ 7 | --img_aug \ 8 | --batch_size 512 \ 9 | --MLM \ 10 | --dataset_name 
$DATASET_NAME \ 11 | --loss_names 'sdm' \ 12 | --num_epoch 30 \ 13 | --root_dir /data0/wentao/data/textReID \ 14 | --pretrain LuPerson_PEDES \ 15 | --nam 16 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_optimizer, build_lr_scheduler 2 | 3 | __all__ = ["build_optimizer", "build_lr_scheduler"] -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/build.cpython-310.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/build.cpython-39.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/lr_scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/solver/__pycache__/lr_scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lr_scheduler import LRSchedulerWithWarmup 4 | 5 | 6 | def build_optimizer(args, model): 7 | params = [] 8 | 9 | print(f'Using {args.lr_factor} times learning rate for random init module ') 10 | 11 | for key, value in model.named_parameters(): 12 | if not value.requires_grad: 13 | continue 14 | lr = args.lr 15 | weight_decay = args.weight_decay 16 | 17 | # if "cross" in key: 18 | # # use large learning rate for random initialized cross modal module 19 | # lr = args.lr * args.lr_factor # default 5.0 20 | # if "bias" in key: 21 | # lr = args.lr * args.bias_lr_factor 22 | # weight_decay = args.weight_decay_bias 23 | # if "classifier" in key or "mlm_head" in key: 24 | # lr = args.lr * args.lr_factor 25 | # if "base_model" in key: 26 | # lr = args.lr 27 | # else: 28 | # print(key, lr*5) 29 | # lr = args.lr * 5 30 | 31 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 32 | 33 | if args.optimizer == "SGD": 34 | optimizer = torch.optim.SGD( 35 | params, lr=args.lr, momentum=args.momentum 36 | ) 37 | elif args.optimizer == "Adam": 38 | optimizer = torch.optim.Adam( 39 | params, 40 | lr=args.lr, 41 | betas=(args.alpha, args.beta), 42 | eps=1e-3, 43 | ) 44 | elif args.optimizer == "AdamW": 45 | optimizer = torch.optim.AdamW( 46 | params, 47 | lr=args.lr, 48 | betas=(args.alpha, args.beta), 49 | eps=1e-8, 50 | ) 51 | else: 52 | NotImplementedError 53 | 54 | return optimizer 55 | 56 | 57 | def build_lr_scheduler(args, optimizer): 58 | return LRSchedulerWithWarmup( 59 | optimizer, 60 | milestones=args.milestones, 61 | gamma=args.gamma, 62 | warmup_factor=args.warmup_factor, 63 | warmup_epochs=args.warmup_epochs, 64 | warmup_method=args.warmup_method, 65 | total_epochs=args.num_epoch, 66 | mode=args.lrscheduler, 67 | target_lr=args.target_lr, 68 | power=args.power, 69 | ) 70 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from math import cos, pi 3 | 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | class LRSchedulerWithWarmup(_LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | mode="step", 14 | warmup_factor=1.0 / 3, 15 | warmup_epochs=10, 16 | warmup_method="linear", 17 | total_epochs=100, 18 | target_lr=0, 19 | power=0.9, 20 | last_epoch=-1, 21 | ): 22 | if not list(milestones) == sorted(milestones): 23 | raise ValueError( 24 | "Milestones should be a list of" 25 | " increasing integers. 
Got {}".format(milestones), 26 | ) 27 | if mode not in ("step", "exp", "poly", "cosine", "linear"): 28 | raise ValueError( 29 | "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted" 30 | "got {}".format(mode) 31 | ) 32 | if warmup_method not in ("constant", "linear"): 33 | raise ValueError( 34 | "Only 'constant' or 'linear' warmup_method accepted" 35 | "got {}".format(warmup_method) 36 | ) 37 | self.milestones = milestones 38 | self.mode = mode 39 | self.gamma = gamma 40 | self.warmup_factor = warmup_factor 41 | self.warmup_epochs = warmup_epochs 42 | self.warmup_method = warmup_method 43 | self.total_epochs = total_epochs 44 | self.target_lr = target_lr 45 | self.power = power 46 | super().__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | 50 | if self.last_epoch < self.warmup_epochs: 51 | if self.warmup_method == "constant": 52 | warmup_factor = self.warmup_factor 53 | elif self.warmup_method == "linear": 54 | alpha = self.last_epoch / self.warmup_epochs 55 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 56 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 57 | 58 | if self.mode == "step": 59 | return [ 60 | base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 61 | for base_lr in self.base_lrs 62 | ] 63 | 64 | epoch_ratio = (self.last_epoch - self.warmup_epochs) / ( 65 | self.total_epochs - self.warmup_epochs 66 | ) 67 | 68 | if self.mode == "exp": 69 | factor = epoch_ratio 70 | return [base_lr * self.power ** factor for base_lr in self.base_lrs] 71 | if self.mode == "linear": 72 | factor = 1 - epoch_ratio 73 | return [base_lr * factor for base_lr in self.base_lrs] 74 | 75 | if self.mode == "poly": 76 | factor = 1 - epoch_ratio 77 | return [ 78 | self.target_lr + (base_lr - self.target_lr) * self.power ** factor 79 | for base_lr in self.base_lrs 80 | ] 81 | if self.mode == "cosine": 82 | factor = 0.5 * (1 + cos(pi * epoch_ratio)) 83 | return [ 84 | self.target_lr + (base_lr - self.target_lr) * factor 85 | for base_lr in self.base_lrs 86 | ] 87 | raise NotImplementedError 88 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | import os 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3' 4 | import torch 5 | import numpy as np 6 | import time 7 | import os.path as op 8 | 9 | from datasets import build_dataloader 10 | from processor.processor import do_inference 11 | from utils.checkpoint import Checkpointer 12 | from utils.logger import setup_logger 13 | from model import build_model 14 | from utils.metrics import Evaluator 15 | import argparse 16 | from utils.iotools import load_train_configs 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description="IRRA Test") 21 | parser.add_argument("--config_file", default='logs/CUHK-PEDES/iira/configs.yaml') 22 | args = parser.parse_args() 23 | args = load_train_configs(args.config_file) 24 | 25 | args.training = False 26 | logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training) 27 | logger.info(args) 28 | device = "cuda" 29 | 30 | test_img_loader, test_txt_loader, num_classes = build_dataloader(args) 31 | model = build_model(args, num_classes=num_classes) 32 | checkpointer = Checkpointer(model) 33 | checkpointer.load(f=op.join(args.output_dir, 'best.pth')) 34 | model.to(device) 35 | do_inference(model, test_img_loader, test_txt_loader) 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import os.path as op 4 | import torch 5 | import numpy as np 6 | import random 7 | import time 8 | import torch.nn as nn 9 | 10 | from datasets import build_dataloader 11 | from datasets.bases import ImageTextMLMDataset 12 | from datasets.build import build_zero_shot_loader 13 | from processor.processor import do_pretrain 14 | from processor.processor_finetune import do_train 15 | from utils.checkpoint import Checkpointer 16 | from utils.iotools import save_train_configs 17 | from utils.logger import setup_logger 18 | from solver import build_optimizer, build_lr_scheduler 19 | from model import build_model,build_finetune_model 20 | from utils.metrics import Evaluator 21 | from utils.options import get_args 22 | from utils.comm import get_rank, synchronize 23 | 24 | 25 | def set_seed(seed=0): 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(seed) 30 | random.seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = True 33 | 34 | 35 | if __name__ == '__main__': 36 | args = get_args() 37 | set_seed(1+get_rank()) 38 | name = args.name 39 | 40 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 41 | args.distributed = num_gpus > 1 42 | 43 | if args.distributed: 44 | torch.cuda.set_device(args.local_rank) 45 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 46 | synchronize() 47 | 48 | device = "cuda" 49 | cur_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 50 | args.output_dir = op.join(args.output_dir, args.dataset_name, f'{cur_time}_{name}') 51 | logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training, distributed_rank=get_rank()) 52 | logger.info("Using {} GPUs".format(num_gpus)) 53 | logger.info(str(args).replace(',', '\n')) 54 | save_train_configs(args.output_dir, args) 55 | 56 | # get image-text pair datasets dataloader 57 | trainset ,train_loader, val_img_loader0, val_txt_loader0, val_img_loader1, val_txt_loader1, val_img_loader2, val_txt_loader2, num_classes = build_zero_shot_loader(args) 58 | if args.nam: 59 | model = build_model(args, num_classes) 60 | else: 61 | model = build_finetune_model(args, num_classes) 62 | logger.info('Total params: %2.fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 63 | if args.finetune: 64 | logger.info("loading {} model".format(args.finetune)) 65 | param_dict = torch.load(args.finetune,map_location='cpu')['model'] 66 | for k in list(param_dict.keys()): 67 | refine_k = k.replace('module.','') 68 | param_dict[refine_k] = param_dict[k].detach().clone() 69 | del param_dict[k] 70 | model.load_state_dict(param_dict) 71 | # model = model.float() 72 | model.cuda() 73 | model = nn.DataParallel(model) 74 | 75 | if args.distributed: 76 | model = torch.nn.parallel.DistributedDataParallel( 77 | model, 78 | device_ids=[args.local_rank], 79 | output_device=args.local_rank, 80 | # this should be removed if we update BatchNorm stats 81 | broadcast_buffers=False, 82 | ) 83 | optimizer = build_optimizer(args, model) 84 | scheduler = build_lr_scheduler(args, optimizer) 85 | 86 | is_master = get_rank() == 0 87 | checkpointer = Checkpointer(model, optimizer, scheduler, args.output_dir, is_master) 88 | evaluator0 = Evaluator(val_img_loader0, val_txt_loader0) 
89 | evaluator1 = Evaluator(val_img_loader1, val_txt_loader1) 90 | evaluator2 = Evaluator(val_img_loader2, val_txt_loader2) 91 | 92 | start_epoch = 1 93 | if args.resume: 94 | checkpoint = checkpointer.resume(args.resume_ckpt_file) 95 | start_epoch = checkpoint['epoch'] 96 | 97 | if args.nam: 98 | do_pretrain(start_epoch, args, model, train_loader, evaluator0,evaluator1,evaluator2, optimizer, scheduler, checkpointer, trainset) 99 | else: 100 | do_train(start_epoch, args, model, train_loader, evaluator0,evaluator1,evaluator2, optimizer, scheduler, checkpointer, trainset) 101 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/checkpoint.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/checkpoint.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/checkpoint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/checkpoint.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/checkpoint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/checkpoint.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/comm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/comm.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/comm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/comm.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/iotools.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/iotools.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/iotools.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/meter.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/meter.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/meter.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/options.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/options.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/options.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/simple_tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WentaoTan/MLLM4Text-ReID/8f19e67c5233b38dce040ac320f97aad79cf4b06/utils/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. 2 | import logging 3 | import os 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | 9 | class Checkpointer: 10 | def __init__( 11 | self, 12 | model, 13 | optimizer=None, 14 | scheduler=None, 15 | save_dir="", 16 | save_to_disk=None, 17 | logger=None, 18 | ): 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.scheduler = scheduler 22 | self.save_dir = save_dir 23 | self.save_to_disk = save_to_disk 24 | if logger is None: 25 | logger = logging.getLogger(__name__) 26 | self.logger = logger 27 | 28 | def save(self, name, **kwargs): 29 | if not self.save_dir: 30 | return 31 | 32 | if not self.save_to_disk: 33 | return 34 | 35 | data = {} 36 | data["model"] = self.model.state_dict() 37 | if self.optimizer is not None: 38 | data["optimizer"] = self.optimizer.state_dict() 39 | if self.scheduler is not None: 40 | data["scheduler"] = self.scheduler.state_dict() 41 | data.update(kwargs) 42 | 43 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 44 | self.logger.info("Saving checkpoint to {}".format(save_file)) 45 | torch.save(data, save_file) 46 | 47 | def load(self, f=None): 48 | if not f: 49 | # no checkpoint could be found 50 | self.logger.info("No checkpoint found.") 51 | return {} 52 | self.logger.info("Loading checkpoint from {}".format(f)) 53 | checkpoint = self._load_file(f) 54 | self._load_model(checkpoint) 55 | 56 | def resume(self, f=None): 57 | if not f: 58 | # no checkpoint could be found 59 | self.logger.info("No checkpoint found.") 60 | raise IOError(f"No Checkpoint file found on {f}") 61 | self.logger.info("Loading checkpoint from {}".format(f)) 62 | checkpoint = self._load_file(f) 63 | self._load_model(checkpoint) 64 | if "optimizer" in checkpoint and self.optimizer: 65 | self.logger.info("Loading optimizer from {}".format(f)) 66 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 67 | if "scheduler" in checkpoint and self.scheduler: 68 | self.logger.info("Loading scheduler from {}".format(f)) 69 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 70 | # return any further checkpoint data 71 | return checkpoint 72 | 73 | def _load_file(self, f): 74 | return torch.load(f, map_location=torch.device("cpu")) 75 | 76 | def _load_model(self, checkpoint, except_keys=None): 77 | load_state_dict(self.model, checkpoint.pop("model"), except_keys) 78 | 79 | 80 | def check_key(key, except_keys): 81 | if except_keys is None: 82 | return False 83 | else: 84 | for except_key in except_keys: 85 | if except_key in key: 86 | return True 87 | return False 88 | 89 | 90 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None): 91 | current_keys = sorted(list(model_state_dict.keys())) 92 | loaded_keys = sorted(list(loaded_state_dict.keys())) 93 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 94 | # loaded_key string, if it matches 95 | match_matrix = [ 96 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 97 | ] 98 | match_matrix = torch.as_tensor(match_matrix).view( 99 | len(current_keys), len(loaded_keys) 100 | ) 101 | max_match_size, idxs = match_matrix.max(1) 102 | # remove indices that correspond to no-match 103 | idxs[max_match_size == 0] = -1 104 | 105 | # used for logging 106 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 107 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 108 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 109 | logger 
= logging.getLogger("PersonSearch.checkpoint") 110 | for idx_new, idx_old in enumerate(idxs.tolist()): 111 | if idx_old == -1: 112 | continue 113 | key = current_keys[idx_new] 114 | key_old = loaded_keys[idx_old] 115 | if check_key(key, except_keys): 116 | continue 117 | model_state_dict[key] = loaded_state_dict[key_old] 118 | logger.info( 119 | log_str_template.format( 120 | key, 121 | max_size, 122 | key_old, 123 | max_size_loaded, 124 | tuple(loaded_state_dict[key_old].shape), 125 | ) 126 | ) 127 | 128 | 129 | def strip_prefix_if_present(state_dict, prefix): 130 | keys = sorted(state_dict.keys()) 131 | if not all(key.startswith(prefix) for key in keys): 132 | return state_dict 133 | stripped_state_dict = OrderedDict() 134 | for key, value in state_dict.items(): 135 | stripped_state_dict[key.replace(prefix, "")] = value 136 | return stripped_state_dict 137 | 138 | 139 | def load_state_dict(model, loaded_state_dict, except_keys=None): 140 | model_state_dict = model.state_dict() 141 | # if the state_dict comes from a model that was wrapped in a 142 | # DataParallel or DistributedDataParallel during serialization, 143 | # remove the "module" prefix before performing the matching 144 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 145 | align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys) 146 | 147 | # use strict loading 148 | model.load_state_dict(model_state_dict) 149 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | 6 | import pickle 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | if not dist.is_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize (barrier) among all processes when 35 | using distributed training 36 | """ 37 | if not dist.is_available(): 38 | return 39 | if not dist.is_initialized(): 40 | return 41 | world_size = dist.get_world_size() 42 | if world_size == 1: 43 | return 44 | dist.barrier() 45 | 46 | 47 | def all_gather(data): 48 | """ 49 | Run all_gather on arbitrary picklable data (not necessarily tensors) 50 | Args: 51 | data: any picklable object 52 | Returns: 53 | list[data]: list of data gathered from each rank 54 | """ 55 | world_size = get_world_size() 56 | if world_size == 1: 57 | return [data] 58 | 59 | # serialized to a Tensor 60 | buffer = pickle.dumps(data) 61 | storage = torch.ByteStorage.from_buffer(buffer) 62 | tensor = torch.ByteTensor(storage).to("cuda") 63 | 64 | # obtain Tensor size of each rank 65 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 66 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 67 | dist.all_gather(size_list, local_size) 68 | size_list = [int(size.item()) for size in size_list] 69 | max_size = max(size_list) 70 | 71 | # receiving Tensor from all ranks 72 | # we pad the tensor because torch all_gather does not support 73 | # gathering tensors of different shapes 74 | 
tensor_list = [] 75 | for _ in size_list: 76 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 77 | if local_size != max_size: 78 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 79 | tensor = torch.cat((tensor, padding), dim=0) 80 | dist.all_gather(tensor_list, tensor) 81 | 82 | data_list = [] 83 | for size, tensor in zip(size_list, tensor_list): 84 | buffer = tensor.cpu().numpy().tobytes()[:size] 85 | data_list.append(pickle.loads(buffer)) 86 | 87 | return data_list 88 | 89 | 90 | def reduce_dict(input_dict, average=True): 91 | """ 92 | Args: 93 | input_dict (dict): all the values will be reduced 94 | average (bool): whether to do average or sum 95 | Reduce the values in the dictionary from all processes so that process with rank 96 | 0 has the averaged results. Returns a dict with the same fields as 97 | input_dict, after reduction. 98 | """ 99 | world_size = get_world_size() 100 | if world_size < 2: 101 | return input_dict 102 | with torch.no_grad(): 103 | names = [] 104 | values = [] 105 | # sort the keys so that they are consistent across processes 106 | for k in sorted(input_dict.keys()): 107 | names.append(k) 108 | values.append(input_dict[k]) 109 | values = torch.stack(values, dim=0) 110 | dist.reduce(values, dst=0) 111 | if dist.get_rank() == 0 and average: 112 | # only main process gets accumulated, so only divide by 113 | # world_size in this case 114 | values /= world_size 115 | reduced_dict = {k: v for k, v in zip(names, values)} 116 | return reduced_dict 117 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from PIL import Image, ImageFile 7 | import errno 8 | import json 9 | import pickle as pkl 10 | import os 11 | import os.path as osp 12 | import yaml 13 | from easydict import EasyDict as edict 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def read_image(img_path): 19 | """Keep reading image until succeed. 20 | This can avoid IOError incurred by heavy IO process.""" 21 | got_img = False 22 | if not osp.exists(img_path): 23 | raise IOError("{} does not exist".format(img_path)) 24 | while not got_img: 25 | try: 26 | img = Image.open(img_path).convert('RGB') 27 | got_img = True 28 | except IOError: 29 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 30 | pass 31 | return img 32 | 33 | 34 | def mkdir_if_missing(directory): 35 | if not osp.exists(directory): 36 | try: 37 | os.makedirs(directory) 38 | except OSError as e: 39 | if e.errno != errno.EEXIST: 40 | raise 41 | 42 | 43 | def check_isfile(path): 44 | isfile = osp.isfile(path) 45 | if not isfile: 46 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 47 | return isfile 48 | 49 | 50 | def read_json(fpath): 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | mkdir_if_missing(osp.dirname(fpath)) 58 | with open(fpath, 'w') as f: 59 | json.dump(obj, f, indent=4, separators=(',', ': ')) 60 | 61 | 62 | def get_text_embedding(path, length): 63 | with open(path, 'rb') as f: 64 | word_frequency = pkl.load(f) 65 | 66 | 67 | def save_train_configs(path, args): 68 | if not os.path.exists(path): 69 | os.makedirs(path) 70 | with open(f'{path}/configs.yaml', 'w') as f: 71 | yaml.dump(vars(args), f, default_flow_style=False) 72 | 73 | def load_train_configs(path): 74 | with open(path, 'r') as f: 75 | args = yaml.load(f, Loader=yaml.FullLoader) 76 | return edict(args) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as op 5 | 6 | 7 | def setup_logger(name, save_dir, if_train, distributed_rank=0): 8 | logger = logging.getLogger(name) 9 | logger.setLevel(logging.DEBUG) 10 | 11 | # don't log results for the non-master process 12 | if distributed_rank > 0: 13 | return logger 14 | 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logger.addHandler(ch) 20 | 21 | if not op.exists(save_dir): 22 | print(f"{save_dir} is not exists, create given directory") 23 | os.makedirs(save_dir) 24 | if if_train: 25 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 26 | else: 27 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='a') 28 | fh.setLevel(logging.DEBUG) 29 | fh.setFormatter(formatter) 30 | logger.addHandler(fh) 31 | 32 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | import torch 3 | import numpy as np 4 | import os 5 | import torch.nn.functional as F 6 | import logging 7 | 8 | 9 | def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True): 10 | if get_mAP: 11 | indices = torch.argsort(similarity.data.cpu(), dim=1, descending=True) 12 | indices = indices.to(similarity.device) 13 | else: 14 | # acclerate sort 
with topk 15 | _, indices = torch.topk( 16 | similarity, k=max_rank, dim=1, largest=True, sorted=True 17 | ) # q * topk 18 | pred_labels = g_pids[indices.cpu()] # q * k 19 | matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k 20 | 21 | all_cmc = matches[:, :max_rank].cumsum(1) # cumulative sum 22 | all_cmc[all_cmc > 1] = 1 23 | all_cmc = all_cmc.float().mean(0) * 100 24 | # all_cmc = all_cmc[topk - 1] 25 | 26 | if not get_mAP: 27 | return all_cmc, indices 28 | 29 | num_rel = matches.sum(1) # q 30 | tmp_cmc = matches.cumsum(1) # q * k 31 | 32 | inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)] 33 | mINP = torch.cat(inp).mean() * 100 34 | 35 | tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])] 36 | tmp_cmc = torch.stack(tmp_cmc, 1) * matches 37 | AP = tmp_cmc.sum(1) / num_rel # q 38 | mAP = AP.mean() * 100 39 | 40 | return all_cmc, mAP, mINP, indices 41 | 42 | 43 | class Evaluator(): 44 | def __init__(self, img_loader, txt_loader): 45 | self.img_loader = img_loader # gallery 46 | self.txt_loader = txt_loader # query 47 | self.logger = logging.getLogger("IRRA.eval") 48 | 49 | def _compute_embedding(self, model): 50 | model = model.eval() 51 | device = next(model.parameters()).device 52 | 53 | qids, gids, qfeats, gfeats = [], [], [], [] 54 | # text 55 | for pid, caption in self.txt_loader: 56 | caption = caption.to(device) 57 | with torch.no_grad(): 58 | text_feat = model.encode_text(caption) 59 | qids.append(pid.view(-1)) # flatten 60 | qfeats.append(text_feat.data.cpu()) 61 | qids = torch.cat(qids, 0) 62 | qfeats = torch.cat(qfeats, 0) 63 | 64 | # image 65 | for pid, img in self.img_loader: 66 | img = img.to(device) 67 | with torch.no_grad(): 68 | img_feat = model.encode_image(img) 69 | gids.append(pid.view(-1)) # flatten 70 | gfeats.append(img_feat.data.cpu()) 71 | gids = torch.cat(gids, 0) 72 | gfeats = torch.cat(gfeats, 0) 73 | 74 | return qfeats.cuda(), gfeats.cuda(), qids, gids 75 | 76 | def eval(self, model, i2t_metric=False): 77 | 78 | qfeats, gfeats, qids, gids = self._compute_embedding(model) 79 | 80 | qfeats = F.normalize(qfeats, p=2, dim=1) # text features 81 | gfeats = F.normalize(gfeats, p=2, dim=1) # image features 82 | 83 | similarity = qfeats @ gfeats.t() 84 | 85 | t2i_cmc, t2i_mAP, t2i_mINP, _ = rank(similarity=similarity, q_pids=qids, g_pids=gids, max_rank=10, get_mAP=True) 86 | t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy() 87 | table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"]) 88 | table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP]) 89 | 90 | if i2t_metric: 91 | i2t_cmc, i2t_mAP, i2t_mINP, _ = rank(similarity=similarity.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True) 92 | i2t_cmc, i2t_mAP, i2t_mINP = i2t_cmc.numpy(), i2t_mAP.numpy(), i2t_mINP.numpy() 93 | table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP, i2t_mINP]) 94 | # table.float_format = '.4' 95 | table.custom_format["R1"] = lambda f, v: f"{v:.3f}" 96 | table.custom_format["R5"] = lambda f, v: f"{v:.3f}" 97 | table.custom_format["R10"] = lambda f, v: f"{v:.3f}" 98 | table.custom_format["mAP"] = lambda f, v: f"{v:.3f}" 99 | table.custom_format["mINP"] = lambda f, v: f"{v:.3f}" 100 | self.logger.info('\n' + str(table)) 101 | 102 | return t2i_cmc[0] 103 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser(description="IRRA Args") 6 | ######################## general settings ######################## 7 | parser.add_argument("--local_rank", default=0, type=int) 8 | parser.add_argument("--name", default="baseline", help="experiment name to save") 9 | parser.add_argument("--output_dir", default="logs") 10 | parser.add_argument("--log_period", default=100) 11 | parser.add_argument("--eval_period", default=1) 12 | parser.add_argument("--val_dataset", default="test") # use the val set for evaluation; if "test", use the test set 13 | parser.add_argument("--resume", default=False, action='store_true') 14 | parser.add_argument("--resume_ckpt_file", default="", help='resume from ...') 15 | 16 | parser.add_argument("--finetune", type=str, default="") 17 | parser.add_argument("--pretrain", type=str, default="") 18 | parser.add_argument("--nam", default=False, action='store_true') 19 | 20 | ######################## model general settings ######################## 21 | parser.add_argument("--pretrain_choice", default='ViT-B/16') # whether to use a pretrained model 22 | parser.add_argument("--temperature", type=float, default=0.02, help="initial temperature value, if 0, don't use temperature") 23 | parser.add_argument("--img_aug", default=False, action='store_true') 24 | 25 | ## cross modal transformer setting 26 | parser.add_argument("--cmt_depth", type=int, default=4, help="cross modal transformer self attn layers") 27 | parser.add_argument("--masked_token_rate", type=float, default=0.8, help="masked token rate for mlm task") 28 | parser.add_argument("--masked_token_unchanged_rate", type=float, default=0.1, help="masked token unchanged rate") 29 | parser.add_argument("--lr_factor", type=float, default=5.0, help="lr factor for randomly initialized self-implemented modules") 30 | parser.add_argument("--MLM", default=False, action='store_true', help="whether to use Mask Language Modeling dataset") 31 | 32 | ######################## loss settings ######################## 33 | parser.add_argument("--loss_names", default='sdm+id+mlm', help="which loss to use ['mlm', 'cmpm', 'id', 'itc', 'sdm']") 34 | parser.add_argument("--mlm_loss_weight", type=float, default=1.0, help="mlm loss weight") 35 | parser.add_argument("--id_loss_weight", type=float, default=1.0, help="id loss weight") 36 | 37 | ######################## vision transformer settings ######################## 38 | parser.add_argument("--img_size", type=tuple, default=(384, 128)) 39 | parser.add_argument("--stride_size", type=int, default=16) 40 | 41 | ######################## text transformer settings ######################## 42 | parser.add_argument("--text_length", type=int, default=77) 43 | parser.add_argument("--vocab_size", type=int, default=49408) 44 | 45 | ######################## solver ######################## 46 | parser.add_argument("--optimizer", type=str, default="Adam", help="[SGD, Adam, Adamw]") 47 | parser.add_argument("--lr", type=float, default=1e-5) 48 | parser.add_argument("--bias_lr_factor", type=float, default=2.) 49 | parser.add_argument("--momentum", type=float, default=0.9) 50 | parser.add_argument("--weight_decay", type=float, default=4e-5) 51 | parser.add_argument("--weight_decay_bias", type=float, default=0.)
52 | parser.add_argument("--alpha", type=float, default=0.9) 53 | parser.add_argument("--beta", type=float, default=0.999) 54 | 55 | ######################## scheduler ######################## 56 | parser.add_argument("--num_epoch", type=int, default=60) 57 | parser.add_argument("--milestones", type=int, nargs='+', default=(20, 40)) 58 | parser.add_argument("--gamma", type=float, default=0.1) 59 | parser.add_argument("--warmup_factor", type=float, default=0.1) 60 | parser.add_argument("--warmup_epochs", type=int, default=5) 61 | parser.add_argument("--warmup_method", type=str, default="linear") 62 | parser.add_argument("--lrscheduler", type=str, default="cosine") 63 | parser.add_argument("--target_lr", type=float, default=0) 64 | parser.add_argument("--power", type=float, default=0.9) 65 | 66 | ######################## dataset ######################## 67 | parser.add_argument("--dataset_name", default="CUHK-PEDES", help="[CUHK-PEDES, ICFG-PEDES, RSTPReid]") 68 | parser.add_argument("--sampler", default="random", help="choose sampler from [idtentity, random]") 69 | parser.add_argument("--num_instance", type=int, default=4) 70 | parser.add_argument("--root_dir", default="./data") 71 | parser.add_argument("--batch_size", type=int, default=128) 72 | parser.add_argument("--test_batch_size", type=int, default=512) 73 | parser.add_argument("--num_workers", type=int, default=8) 74 | parser.add_argument("--test", dest='training', default=True, action='store_false') 75 | 76 | args = parser.parse_args() 77 | 78 | return args -------------------------------------------------------------------------------- /utils/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | 74 | vocab.pop(-1) # remove last one in vocab(jekyll) to keep vocab_size unchanged 75 | vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 76 | # vocab.extend(['<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 77 | self.encoder = dict(zip(vocab, range(len(vocab)))) 78 | self.decoder = {v: k for k, v in self.encoder.items()} 79 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | --------------------------------------------------------------------------------