├── .gitignore ├── README.md ├── configs ├── activitynet.yaml ├── didemo.yaml ├── msrvtt-7k.yaml ├── msrvtt-9k.yaml ├── msvd.yaml └── vatex.yaml ├── dataloaders ├── data_dataloaders.py ├── dataloader_activitynet_retrieval.py ├── dataloader_didemo_retrieval.py ├── dataloader_msrvtt_retrieval.py ├── dataloader_msvd_retrieval.py ├── rawvideo_util.py ├── text_dataloader.py └── video_dataloader.py ├── do_eval.sh ├── do_extract_text_feat.sh ├── do_extract_video_feat.sh ├── evaluation.py ├── extract_feat.py ├── images └── teachclip.png ├── metrics.py ├── modules ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── cross-base │ └── cross_config.json ├── differential_topk.py ├── file_utils.py ├── modeling.py ├── modeling_ts2net.py ├── modeling_xclip.py ├── modeling_xpool.py ├── module_clip.py ├── module_clip_ts2net.py ├── module_cross.py ├── optimization.py ├── tokenization_clip.py ├── transformer_xpool.py ├── until_config.py └── until_module.py ├── requirements.txt ├── train.py ├── util.py └── utils ├── bigfile.py └── txt2bin.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | **__pycache__** 3 | **checkpoint** 4 | ***ckpts* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Holistic Features are almost Sufficient for Text-to-Video Retrieval 2 | 3 | The official source code of our CVPR24 paper TeachCLIP, "[Holistic Features are almost Sufficient for Text-to-Video Retrieval](https://openaccess.thecvf.com/content/CVPR2024/papers/Tian_Holistic_Features_are_almost_Sufficient_for_Text-to-Video_Retrieval_CVPR_2024_paper.pdf)". 4 | 5 | ![](./images/teachclip.png) 6 | 7 | ## Environment 8 | 9 | We used Anaconda to setup a deep learning workspace that supports PyTorch. Run the following script to install all the required packages. 10 | 11 | ```shell 12 | conda create -n TeachCLIP python==3.9 -y 13 | conda activate TeachCLIP 14 | git clone https://github.com/ruc-aimc-lab/TeachCLIP.git 15 | cd TeachCLIP 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | 20 | ## Data 21 | 22 | ### Data download 23 | 24 | + We provide annotations of five datasets and checkpoints of three teachers ([X-CLIP](https://github.com/xuguohai/X-CLIP), [TS2-Net](https://github.com/yuqi657/ts2_net) and [XPool](https://github.com/layer6ai-labs/xpool)) trained on five datasets at [Google drive](https://drive.google.com/drive/folders/1cU0ehXfucf4M5IyDRSxywBadCt1LyZWz?usp=sharing). Video captions and data splits are provided in `Annotations` and `VideoSet`. 25 | 26 | + For raw videos, you can refer to the guides from [CLIP4Clip: Data Preparing](https://github.com/ArrowLuo/CLIP4Clip?tab=readme-ov-file#data-preparing). Put the videos into the corresponding `VideoData` folder for each dataset. (It is recommended to use symbolic links.) 27 | 28 | ### Data organization 29 | 30 | Before starting to run the code, please organize the downloaded data in the following format: (The `Models` and `FeatureData` folders will be automatically generated during training and testing, respectively.) 31 | 32 | ```shell 33 | data 34 | ├── datasets 35 | │   ├── msrvtt 36 | │   │   ├── Annotations 37 | │   │   │   ├── MSRVTT_data.json 38 | │   │   │   ├── MSRVTT_JSFUSION_test.csv 39 | │   │   │   └── ... 40 | │ │ ├── FeatureData 41 | │ │ ├── Models 42 | │ │ │ └── msrvtt-7k_xclip+ts2net-as-teacher_vit32 43 | │ │ │ ├── run0 44 | │ │ │ └── ... 
45 | │ │ ├── QuerySet 46 | │ │ │ ├── msrvtt1k-test-query.txt 47 | │ │ │ ├── msrvtt3k-test-query.txt 48 | │ │ │ └── ... 49 | │   │   └── VideoData 50 | │   │   │ ├── video0.mp4 51 | │   │   │ ├── video1.mp4 52 | │   │   │ └── ... 53 | │ │ └── VideoSet 54 | │ │ ├── msrvtt1k-test.txt 55 | │ │ ├── msrvtt1k-train.txt 56 | │ │ └── ... 57 | │   ├── activitynet 58 | │   ├── didemo 59 | │   ├── msvd 60 | │   └── vatex 61 | └── teacher_checkpoints 62 | ├── xclip 63 | │   ├── didemo_xclip_model.bin 64 | │   ├── msrvtt-7k_xclip_model.bin 65 | │   └── ... 66 | ├── ts2net 67 | └── xpool 68 | ``` 69 | 70 | ## Code 71 | 72 | ### Training 73 | 74 | Prepare a config file before training. [Here](https://github.com/ruc-aimc-lab/TeachCLIP/tree/main/configs), we provide a demo config file for each dataset. You can train TeachCLIP on the specified GPUs and dataset using the following command (taking `msrvtt-9k` as an example): 75 | 76 | ```shell 77 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py --config_path configs/msrvtt-9k.yaml 78 | ``` 79 | 80 | ### Inference 81 | 82 | Use the following commands to extract video / text features: 83 | 84 | ```shell 85 | bash do_extract_video_feat.sh $test_collection $videoset $model_name 86 | # e.g. bash do_extract_video_feat.sh msrvtt msrvtt1k-test msrvtt/Models/msrvtt-9k_xclip+ts2net-as-teacher_vit32/run0 87 | 88 | bash do_extract_text_feat.sh $test_collection $queryset $model_name 89 | # e.g. bash do_extract_text_feat.sh msrvtt msrvtt1k-test-query msrvtt/Models/msrvtt-9k_xclip+ts2net-as-teacher_vit32/run0 90 | ``` 91 | 92 | ### Evaluation 93 | 94 | After obtaining the text and video features, the evaluation metrics can be computed with the following command: 95 | 96 | ```shell 97 | bash do_eval.sh $test_collection $text_feat_name $video_feat_name $gt_file_name 98 | # e.g. bash do_eval.sh msrvtt msrvtt1k-test-query/msrvtt/msrvtt-9k_xclip+ts2net-as-teacher_vit32/run0 msrvtt1k-test/msrvtt/msrvtt-9k_xclip+ts2net-as-teacher_vit32/run0 msrvtt1k-gt 99 | ``` 100 | 101 | ## Citation 102 | 103 | If you find our method useful in your work, please cite: 104 | 105 | ```bibtex 106 | @inproceedings{teachclip, 107 | title = {Holistic Features are almost Sufficient for Text-to-Video Retrieval}, 108 | author = {Tian, Kaibin and Zhao, Ruixiang and Xin, Zijie and Lan, Bangxiang and Li, Xirong}, 109 | year = {2024}, 110 | booktitle = {CVPR} 111 | } 112 | ``` 113 | 114 | 115 | ## Acknowledgments 116 | 117 | The implementation of TeachCLIP relies on resources from [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip "CLIP4Clip"), [X-CLIP](https://github.com/xuguohai/X-CLIP "X-CLIP") and [XPool](https://github.com/layer6ai-labs/xpool "XPool"). We thank the original authors for open-sourcing their code. 118 | 119 | 120 | ## Contact 121 | 122 | If you encounter any issues when running the code, please feel free to reach us either by creating a new issue on GitHub or by emailing: 123 | 124 | - Ruixiang Zhao ([ruixiangzhao@ruc.edu.cn](mailto:ruixiangzhao@ruc.edu.cn)) 125 | -------------------------------------------------------------------------------- /configs/activitynet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choose a similarity header.
["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "activity" 12 | data_path: "data/datasets/activitynet/Annotations" 13 | video_path: "data/datasets/activitynet/VideoData" 14 | 15 | num_thread_reader: 8 16 | max_words: 64 17 | max_frames: 64 18 | feature_framerate: 1 19 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 20 | eval_frame_order: 0 21 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 22 | 23 | 24 | distillation: 25 | beta: 1.0 26 | distill_method: "pear" 27 | fine_method: "ce" 28 | teacher_num: 1 29 | teacher1_name: "XCLIP" 30 | init_teacher1_model: "data/teacher_checkpoints/xclip/activitynet_xclip_model.bin" 31 | 32 | train: 33 | overwrite: False 34 | seed: 42 35 | lr: 0.0001 36 | coef_lr: 0.001 37 | epochs: 10 38 | batch_size: 64 39 | batch_size_val: 16 40 | lr_decay: 0.9 41 | n_display: 10 42 | warmup_proportion: 0.1 43 | gradient_accumulation_steps: 1 44 | output_dir: "data/datasets/activitynet/Models/xclip-as-teacher_vit32" -------------------------------------------------------------------------------- /configs/didemo.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choice a similarity header. ["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "didemo" 12 | data_path: "data/datasets/didemo/Annotations" 13 | video_path: "data/datasets/didemo/VideoData" 14 | 15 | num_thread_reader: 8 16 | max_words: 64 17 | max_frames: 64 18 | feature_framerate: 1 19 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 20 | eval_frame_order: 0 21 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 22 | 23 | 24 | distillation: 25 | beta: 1.0 26 | distill_method: "pear" 27 | fine_method: "ce" 28 | teacher_num: 1 29 | teacher1_name: "XCLIP" 30 | init_teacher1_model: "data/teacher_checkpoints/xclip/didemo_xclip_model.bin" 31 | 32 | train: 33 | overwrite: False 34 | seed: 42 35 | lr: 0.0001 36 | coef_lr: 0.001 37 | epochs: 10 38 | batch_size: 64 39 | batch_size_val: 16 40 | lr_decay: 0.9 41 | n_display: 10 42 | warmup_proportion: 0.1 43 | gradient_accumulation_steps: 1 44 | output_dir: "data/datasets/didemo/Models/xclip-as-teacher_vit32" -------------------------------------------------------------------------------- /configs/msrvtt-7k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choice a similarity header. 
["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "msrvtt" 12 | train_csv: "data/datasets/msrvtt/Annotations/MSRVTT_train.7k.csv" 13 | val_csv: "data/datasets/msrvtt/Annotations/MSRVTT_full_split_val.csv" 14 | test_csv: "data/datasets/msrvtt/Annotations/MSRVTT_full_split_test.csv" 15 | data_path: "data/datasets/msrvtt/Annotations/MSRVTT_data.json" 16 | video_path: "data/datasets/msrvtt/VideoData" 17 | 18 | num_thread_reader: 8 19 | max_words: 32 20 | max_frames: 12 21 | feature_framerate: 1 22 | expand_msrvtt_sentences: True 23 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 24 | eval_frame_order: 0 25 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 26 | 27 | 28 | distillation: 29 | beta: 1.0 30 | distill_method: "pear" 31 | fine_method: "ce" 32 | teacher_num: 2 33 | teacher1_name: "XCLIP" 34 | init_teacher1_model: "data/teacher_checkpoints/xclip/msrvtt-7k_xclip_model.bin" 35 | teacher2_name: "TS2Net" 36 | init_teacher2_model: "data/teacher_checkpoints/ts2net/msrvtt-7k_ts2net_model.bin" 37 | 38 | train: 39 | overwrite: False 40 | seed: 42 41 | lr: 0.0001 42 | coef_lr: 0.001 43 | epochs: 10 44 | batch_size: 120 45 | batch_size_val: 40 46 | lr_decay: 0.9 47 | n_display: 10 48 | warmup_proportion: 0.1 49 | gradient_accumulation_steps: 1 50 | output_dir: "data/datasets/msrvtt/Models/msrvtt-7k_xclip+ts2net-as-teacher_vit32" -------------------------------------------------------------------------------- /configs/msrvtt-9k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choice a similarity header. ["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "msrvtt" 12 | train_csv: "data/datasets/msrvtt/Annotations/MSRVTT_train.9k.csv" 13 | val_csv: "data/datasets/msrvtt/Annotations/MSRVTT_JSFUSION_test.csv" 14 | test_csv: "data/datasets/msrvtt/Annotations/MSRVTT_JSFUSION_test.csv" 15 | data_path: "data/datasets/msrvtt/Annotations/MSRVTT_data.json" 16 | video_path: "data/datasets/msrvtt/VideoData" 17 | 18 | num_thread_reader: 8 19 | max_words: 32 20 | max_frames: 12 21 | feature_framerate: 1 22 | expand_msrvtt_sentences: True 23 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 24 | eval_frame_order: 0 25 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 
26 | 27 | 28 | distillation: 29 | beta: 1.0 30 | distill_method: "pear" 31 | fine_method: "ce" 32 | teacher_num: 2 33 | teacher1_name: "XCLIP" 34 | init_teacher1_model: "data/teacher_checkpoints/xclip/msrvtt-9k_xclip_model.bin" 35 | teacher2_name: "TS2Net" 36 | init_teacher2_model: "data/teacher_checkpoints/ts2net/msrvtt-9k_ts2net_model.bin" 37 | 38 | train: 39 | overwrite: False 40 | seed: 42 41 | lr: 0.0001 42 | coef_lr: 0.001 43 | epochs: 10 44 | batch_size: 120 45 | batch_size_val: 40 46 | lr_decay: 0.9 47 | n_display: 10 48 | warmup_proportion: 0.1 49 | gradient_accumulation_steps: 1 50 | output_dir: "data/datasets/msrvtt/Models/msrvtt-9k_xclip+ts2net-as-teacher_vit32" -------------------------------------------------------------------------------- /configs/msvd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choice a similarity header. ["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "msvd" 12 | data_path: "data/datasets/msvd/Annotations" 13 | video_path: "data/datasets/msvd/VideoData" 14 | 15 | num_thread_reader: 8 16 | max_words: 32 17 | max_frames: 12 18 | feature_framerate: 1 19 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 20 | eval_frame_order: 0 21 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 22 | 23 | 24 | distillation: 25 | beta: 1.0 26 | distill_method: "pear" 27 | fine_method: "ce" 28 | teacher_num: 2 29 | teacher1_name: "XCLIP" 30 | init_teacher1_model: "data/teacher_checkpoints/xclip/msvd_xclip_model.bin" 31 | teacher2_name: "TS2Net" 32 | init_teacher2_model: "data/teacher_checkpoints/ts2net/msvd_ts2net_model.bin" 33 | 34 | train: 35 | overwrite: True 36 | seed: 42 37 | lr: 0.0001 38 | coef_lr: 0.001 39 | epochs: 10 40 | batch_size: 120 41 | batch_size_val: 40 42 | lr_decay: 0.9 43 | n_display: 10 44 | warmup_proportion: 0.1 45 | gradient_accumulation_steps: 1 46 | output_dir: "data/datasets/msvd/Models/xclip+ts2net-as-teacher_vit32" -------------------------------------------------------------------------------- /configs/vatex.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_clip_name: "ViT-B/32" 3 | freeze_layer_num: 0 4 | linear_patch: "2d" # linear projection of flattened patches. ["2d", "3d"] 5 | sim_header: "seqTransf" # choice a similarity header. ["meanP", "seqLSTM", "seqTransf", "tightTransf"] 6 | cross_model: "cross-base" 7 | loose_type: True 8 | cross_num_hidden_layers: 4 9 | 10 | datasets: 11 | data_type: "msrvtt" 12 | train_csv: "data/datasets/vatex/Annotations/VATEX_train.csv" 13 | val_csv: "data/datasets/vatex/Annotations/VATEX_val.csv" 14 | test_csv: "data/datasets/vatex/Annotations/VATEX_test.csv" 15 | data_path: "data/datasets/vatex/Annotations/VATEX_data.json" 16 | video_path: "data/datasets/vatex/VideoData" 17 | 18 | num_thread_reader: 8 19 | max_words: 32 20 | max_frames: 12 21 | feature_framerate: 1 22 | expand_msrvtt_sentences: True 23 | train_frame_order: 0 # Frame order, 0: ordinary order; 1: reverse order; 2: random order. 24 | eval_frame_order: 0 25 | slice_framepos: 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." 
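# (Added note, not part of the original config.) In vatex.yaml, `data_type` is set to "msrvtt"
# because VATEX reuses the MSR-VTT-style CSV/JSON annotation format and dataloader
# (see DATALOADER_DICT["msrvtt"] in dataloaders/data_dataloaders.py); only the annotation
# and video paths point to the VATEX files.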
26 | 27 | 28 | distillation: 29 | beta: 1.0 30 | distill_method: "pear" 31 | fine_method: "ce" 32 | teacher_num: 2 33 | teacher1_name: "XCLIP" 34 | init_teacher1_model: "data/teacher_checkpoints/xclip/vatex_xclip_model.bin" 35 | teacher2_name: "TS2Net" 36 | init_teacher2_model: "data/teacher_checkpoints/ts2net/vatex_ts2net_model.bin" 37 | 38 | train: 39 | overwrite: True 40 | seed: 42 41 | lr: 0.0001 42 | coef_lr: 0.001 43 | epochs: 10 44 | batch_size: 120 45 | batch_size_val: 40 46 | lr_decay: 0.9 47 | n_display: 10 48 | warmup_proportion: 0.1 49 | gradient_accumulation_steps: 1 50 | output_dir: "data/datasets/vatex/Models/xclip+ts2net-as-teacher_vit32" -------------------------------------------------------------------------------- /dataloaders/data_dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader 4 | from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader 5 | from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader 6 | from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader 7 | from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader 8 | from dataloaders.video_dataloader import Video_DataLoader 9 | from dataloaders.text_dataloader import Text_DataLoader 10 | 11 | def dataloader_msrvtt_train(args, tokenizer): 12 | msrvtt_dataset = MSRVTT_TrainDataLoader( 13 | csv_path=args.train_csv, 14 | json_path=args.data_path, 15 | video_path=args.video_path, 16 | max_words=args.max_words, 17 | feature_framerate=args.feature_framerate, 18 | tokenizer=tokenizer, 19 | max_frames=args.max_frames, 20 | unfold_sentences=args.expand_msrvtt_sentences, 21 | frame_order=args.train_frame_order, 22 | slice_framepos=args.slice_framepos, 23 | ) 24 | 25 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset,num_replicas=args.world_size,rank=args.rank) 26 | dataloader = DataLoader( 27 | msrvtt_dataset, 28 | batch_size=args.batch_size // args.n_gpu, 29 | num_workers=args.num_thread_reader, 30 | pin_memory=True, 31 | shuffle=(train_sampler is None), 32 | sampler=train_sampler, 33 | drop_last=True, 34 | ) 35 | return dataloader, len(msrvtt_dataset), train_sampler 36 | 37 | def dataloader_msrvtt_test(args, tokenizer, subset="test"): 38 | msrvtt_testset = MSRVTT_DataLoader( 39 | subset=subset, 40 | csv_path=args.test_csv, 41 | video_path=args.video_path, 42 | max_words=args.max_words, 43 | feature_framerate=args.feature_framerate, 44 | tokenizer=tokenizer, 45 | max_frames=args.max_frames, 46 | frame_order=args.eval_frame_order, 47 | slice_framepos=args.slice_framepos, 48 | ) 49 | dataloader_msrvtt = DataLoader( 50 | msrvtt_testset, 51 | batch_size=args.batch_size_val, 52 | num_workers=args.num_thread_reader, 53 | shuffle=False, 54 | drop_last=False, 55 | ) 56 | return dataloader_msrvtt, len(msrvtt_testset) 57 | 58 | 59 | 60 | def dataloader_msrvtt_val(args, tokenizer, subset="val"): 61 | msrvtt_valset = MSRVTT_DataLoader( 62 | subset=subset, 63 | csv_path=args.val_csv, 64 | video_path=args.video_path, 65 | max_words=args.max_words, 66 | feature_framerate=args.feature_framerate, 67 | tokenizer=tokenizer, 68 | max_frames=args.max_frames, 69 | frame_order=args.eval_frame_order, 70 | slice_framepos=args.slice_framepos, 71 | ) 72 | dataloader_msrvtt = DataLoader( 73 | msrvtt_valset, 74 | batch_size=args.batch_size_val, 75 | 
num_workers=args.num_thread_reader, 76 | shuffle=False, 77 | drop_last=False, 78 | ) 79 | return dataloader_msrvtt, len(msrvtt_valset) 80 | 81 | 82 | 83 | def dataloader_msvd_train(args, tokenizer): 84 | msvd_dataset = MSVD_DataLoader( 85 | subset="train", 86 | data_path=args.data_path, 87 | video_path=args.video_path, 88 | max_words=args.max_words, 89 | feature_framerate=args.feature_framerate, 90 | tokenizer=tokenizer, 91 | max_frames=args.max_frames, 92 | frame_order=args.train_frame_order, 93 | slice_framepos=args.slice_framepos, 94 | ) 95 | 96 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset,num_replicas=args.world_size,rank=args.rank) 97 | dataloader = DataLoader( 98 | msvd_dataset, 99 | batch_size=args.batch_size // args.n_gpu, 100 | num_workers=args.num_thread_reader, 101 | pin_memory=True, 102 | shuffle=(train_sampler is None), 103 | sampler=train_sampler, 104 | drop_last=True, 105 | ) 106 | 107 | return dataloader, len(msvd_dataset), train_sampler 108 | 109 | def dataloader_msvd_test(args, tokenizer, subset="test"): 110 | msvd_testset = MSVD_DataLoader( 111 | subset=subset, 112 | data_path=args.data_path, 113 | video_path=args.video_path, 114 | max_words=args.max_words, 115 | feature_framerate=args.feature_framerate, 116 | tokenizer=tokenizer, 117 | max_frames=args.max_frames, 118 | frame_order=args.eval_frame_order, 119 | slice_framepos=args.slice_framepos, 120 | ) 121 | dataloader_msrvtt = DataLoader( 122 | msvd_testset, 123 | batch_size=args.batch_size_val, 124 | num_workers=args.num_thread_reader, 125 | shuffle=False, 126 | drop_last=False, 127 | ) 128 | return dataloader_msrvtt, len(msvd_testset) 129 | 130 | 131 | def dataloader_lsmdc_train(args, tokenizer): 132 | lsmdc_dataset = LSMDC_DataLoader( 133 | subset="train", 134 | data_path=args.data_path, 135 | video_path=args.video_path, 136 | max_words=args.max_words, 137 | feature_framerate=args.feature_framerate, 138 | tokenizer=tokenizer, 139 | max_frames=args.max_frames, 140 | frame_order=args.train_frame_order, 141 | slice_framepos=args.slice_framepos, 142 | ) 143 | 144 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) 145 | dataloader = DataLoader( 146 | lsmdc_dataset, 147 | batch_size=args.batch_size // args.n_gpu, 148 | num_workers=args.num_thread_reader, 149 | pin_memory=True, 150 | shuffle=(train_sampler is None), 151 | sampler=train_sampler, 152 | drop_last=True, 153 | ) 154 | 155 | return dataloader, len(lsmdc_dataset), train_sampler 156 | 157 | def dataloader_lsmdc_test(args, tokenizer, subset="test"): 158 | lsmdc_testset = LSMDC_DataLoader( 159 | subset=subset, 160 | data_path=args.data_path, 161 | video_path=args.video_path, 162 | max_words=args.max_words, 163 | feature_framerate=args.feature_framerate, 164 | tokenizer=tokenizer, 165 | max_frames=args.max_frames, 166 | frame_order=args.eval_frame_order, 167 | slice_framepos=args.slice_framepos, 168 | ) 169 | dataloader_msrvtt = DataLoader( 170 | lsmdc_testset, 171 | batch_size=args.batch_size_val, 172 | num_workers=args.num_thread_reader, 173 | shuffle=False, 174 | drop_last=False, 175 | ) 176 | return dataloader_msrvtt, len(lsmdc_testset) 177 | 178 | 179 | def dataloader_activity_train(args, tokenizer): 180 | activity_dataset = ActivityNet_DataLoader( 181 | subset="train", 182 | data_path=args.data_path, 183 | video_path=args.video_path, 184 | max_words=args.max_words, 185 | feature_framerate=args.feature_framerate, 186 | tokenizer=tokenizer, 187 | max_frames=args.max_frames, 188 | 
frame_order=args.train_frame_order, 189 | slice_framepos=args.slice_framepos, 190 | ) 191 | 192 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset,num_replicas=args.world_size,rank=args.rank) 193 | dataloader = DataLoader( 194 | activity_dataset, 195 | batch_size=args.batch_size // args.n_gpu, 196 | num_workers=args.num_thread_reader, 197 | pin_memory=True, 198 | shuffle=(train_sampler is None), 199 | sampler=train_sampler, 200 | drop_last=True, 201 | ) 202 | 203 | return dataloader, len(activity_dataset), train_sampler 204 | 205 | def dataloader_activity_test(args, tokenizer, subset="val"): 206 | activity_testset = ActivityNet_DataLoader( 207 | subset=subset, 208 | data_path=args.data_path, 209 | video_path=args.video_path, 210 | max_words=args.max_words, 211 | feature_framerate=args.feature_framerate, 212 | tokenizer=tokenizer, 213 | max_frames=args.max_frames, 214 | frame_order=args.eval_frame_order, 215 | slice_framepos=args.slice_framepos, 216 | ) 217 | dataloader = DataLoader( 218 | activity_testset, 219 | batch_size=args.batch_size_val, 220 | num_workers=args.num_thread_reader, 221 | shuffle=False, 222 | drop_last=False, 223 | ) 224 | return dataloader, len(activity_testset) 225 | 226 | 227 | def dataloader_didemo_train(args, tokenizer): 228 | didemo_dataset = DiDeMo_DataLoader( 229 | subset="train", 230 | data_path=args.data_path, 231 | video_path=args.video_path, 232 | max_words=args.max_words, 233 | feature_framerate=args.feature_framerate, 234 | tokenizer=tokenizer, 235 | max_frames=args.max_frames, 236 | frame_order=args.train_frame_order, 237 | slice_framepos=args.slice_framepos, 238 | ) 239 | 240 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) 241 | dataloader = DataLoader( 242 | didemo_dataset, 243 | batch_size=args.batch_size // args.n_gpu, 244 | num_workers=args.num_thread_reader, 245 | pin_memory=False, 246 | shuffle=(train_sampler is None), 247 | sampler=train_sampler, 248 | drop_last=True, 249 | ) 250 | 251 | return dataloader, len(didemo_dataset), train_sampler 252 | 253 | def dataloader_didemo_test(args, tokenizer, subset="test"): 254 | didemo_testset = DiDeMo_DataLoader( 255 | subset=subset, 256 | data_path=args.data_path, 257 | video_path=args.video_path, 258 | max_words=args.max_words, 259 | feature_framerate=args.feature_framerate, 260 | tokenizer=tokenizer, 261 | max_frames=args.max_frames, 262 | frame_order=args.eval_frame_order, 263 | slice_framepos=args.slice_framepos, 264 | ) 265 | dataloader_didemo = DataLoader( 266 | didemo_testset, 267 | batch_size=args.batch_size_val, 268 | num_workers=args.num_thread_reader, 269 | shuffle=False, 270 | drop_last=False, 271 | ) 272 | return dataloader_didemo, len(didemo_testset) 273 | 274 | def video_dataloader(args): 275 | video_set = Video_DataLoader( 276 | videofile_path=args.videofile_path, 277 | videodata_dir=args.video_path, 278 | feature_framerate=args.feature_framerate, 279 | max_frames=args.max_frames, 280 | frame_order=args.frame_order, 281 | slice_framepos=args.slice_framepos, 282 | image_resolution=args.image_resolution, 283 | ) 284 | dataloader_video = DataLoader( 285 | video_set, 286 | batch_size=args.batch_size, 287 | num_workers=args.num_thread_reader, 288 | shuffle=False, 289 | drop_last=False, 290 | ) 291 | return dataloader_video, len(video_set) 292 | 293 | def text_dataloader(args): 294 | text_set = Text_DataLoader( 295 | queryfile_path=args.queryfile_path, 296 | ) 297 | dataloader_text = DataLoader( 298 | text_set, 299 | 
batch_size=args.batch_size, 300 | num_workers=args.num_thread_reader, 301 | shuffle=False, 302 | drop_last=False, 303 | ) 304 | return dataloader_text, len(text_set) 305 | 306 | DATALOADER_DICT = {} 307 | DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_val, "test":dataloader_msrvtt_test} 308 | DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} 309 | DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} 310 | DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":dataloader_activity_test} 311 | DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test} 312 | DATALOADER_DICT["video"] = video_dataloader 313 | DATALOADER_DICT["text"] = text_dataloader 314 | -------------------------------------------------------------------------------- /dataloaders/dataloader_activitynet_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | import math 11 | from dataloaders.rawvideo_util import RawVideoExtractor 12 | import random 13 | 14 | class ActivityNet_DataLoader(Dataset): 15 | def __init__( 16 | self, 17 | subset, 18 | data_path, 19 | video_path, 20 | tokenizer, 21 | max_words=30, 22 | feature_framerate=1.0, 23 | max_frames=100, 24 | image_resolution=224, 25 | frame_order=0, 26 | slice_framepos=0, 27 | ): 28 | self.data_path = data_path 29 | self.video_path = video_path 30 | self.feature_framerate = feature_framerate 31 | self.max_words = max_words 32 | self.max_frames = max_frames 33 | self.tokenizer = tokenizer 34 | # 0: ordinary order; 1: reverse order; 2: random order. 35 | self.frame_order = frame_order 36 | assert self.frame_order in [0, 1, 2] 37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
38 | self.slice_framepos = slice_framepos 39 | assert self.slice_framepos in [0, 1, 2] 40 | 41 | self.subset = subset 42 | assert self.subset in ["train", "val"] 43 | 44 | video_id_path_dict = {} 45 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") 46 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") 47 | 48 | video_json_path_dict = {} 49 | video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") 50 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") 51 | 52 | pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) 53 | pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) 54 | 55 | print("video id list: {}".format(len(video_id_list))) 56 | print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) 57 | 58 | video_dict = {} 59 | for root, dub_dir, video_files in os.walk(self.video_path): 60 | for video_file in video_files: 61 | video_id_ = ".".join(video_file.split(".")[:-1])[2:] 62 | if video_id_ not in video_id_list: 63 | continue 64 | file_path_ = os.path.join(root, video_file) 65 | video_dict[video_id_] = file_path_ 66 | self.video_dict = video_dict 67 | print("video dict: {}".format(len(video_dict))) 68 | 69 | self.pseudo_video_id_list = pseudo_video_id_list 70 | self.video_id_list = video_id_list 71 | self.pseudo_caption_dict = pseudo_caption_dict 72 | 73 | # Get iterator video ids 74 | self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} 75 | # Get all captions 76 | self.iter2video_pairs_dict = {} 77 | for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): 78 | if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: 79 | continue 80 | caption = self.pseudo_caption_dict[pseudo_video_id] 81 | n_caption = len(caption['start']) 82 | for sub_id in range(n_caption): 83 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) 84 | 85 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 86 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 87 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 88 | 89 | def __len__(self): 90 | return len(self.iter2video_pairs_dict) 91 | 92 | def _get_video_id_from_pseduo(self, pseudo_video_id): 93 | video_id = pseudo_video_id[2:] 94 | return video_id 95 | 96 | def _get_video_id_single(self, path): 97 | pseudo_video_id_list = [] 98 | video_id_list = [] 99 | print('Loading json: {}'.format(path)) 100 | with open(path, 'r') as f: 101 | json_data = json.load(f) 102 | 103 | for pseudo_video_id in json_data: 104 | if pseudo_video_id in pseudo_video_id_list: 105 | print("reduplicate.") 106 | else: 107 | video_id = self._get_video_id_from_pseduo(pseudo_video_id) 108 | pseudo_video_id_list.append(pseudo_video_id) 109 | video_id_list.append(video_id) 110 | return pseudo_video_id_list, video_id_list 111 | 112 | def _get_captions_single(self, path): 113 | pseudo_caption_dict = {} 114 | with open(path, 'r') as f: 115 | json_data = json.load(f) 116 | 117 | for pseudo_video_id, v_ in json_data.items(): 118 | pseudo_caption_dict[pseudo_video_id] = {} 119 | duration = v_["duration"] 120 | pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) 121 | pseudo_caption_dict[pseudo_video_id]["end"] = 
np.array([int(math.ceil(float(duration)))], dtype=object) 122 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) 123 | return pseudo_caption_dict 124 | 125 | def _get_text(self, pseudo_video_id, sub_id): 126 | caption = self.pseudo_caption_dict[pseudo_video_id] 127 | k = 1 128 | r_ind = [sub_id] 129 | 130 | starts = np.zeros(k, dtype=np.long) 131 | ends = np.zeros(k, dtype=np.long) 132 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 133 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 134 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 135 | 136 | for i in range(k): 137 | ind = r_ind[i] 138 | start_, end_ = caption['start'][ind], caption['end'][ind] 139 | words = self.tokenizer.tokenize(caption['text'][ind]) 140 | starts[i], ends[i] = start_, end_ 141 | 142 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 143 | total_length_with_CLS = self.max_words - 1 144 | if len(words) > total_length_with_CLS: 145 | words = words[:total_length_with_CLS] 146 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 147 | 148 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 149 | input_mask = [1] * len(input_ids) 150 | segment_ids = [0] * len(input_ids) 151 | while len(input_ids) < self.max_words: 152 | input_ids.append(0) 153 | input_mask.append(0) 154 | segment_ids.append(0) 155 | assert len(input_ids) == self.max_words 156 | assert len(input_mask) == self.max_words 157 | assert len(segment_ids) == self.max_words 158 | 159 | pairs_text[i] = np.array(input_ids) 160 | pairs_mask[i] = np.array(input_mask) 161 | pairs_segment[i] = np.array(segment_ids) 162 | 163 | return pairs_text, pairs_mask, pairs_segment, starts, ends 164 | 165 | def _get_rawvideo(self, idx, s, e): 166 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 167 | max_video_length = [0] * len(s) 168 | 169 | # Pair x L x T x 3 x H x W 170 | video = np.zeros((len(s), self.max_frames, 1, 3, 171 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 172 | video_path = self.video_dict[idx] 173 | try: 174 | for i in range(len(s)): 175 | start_time = int(s[i]) 176 | end_time = int(e[i]) 177 | start_time = start_time if start_time >= 0. else 0. 178 | end_time = end_time if end_time >= 0. else 0. 179 | if start_time > end_time: 180 | start_time, end_time = end_time, start_time 181 | elif start_time == end_time: 182 | end_time = end_time + 1 183 | 184 | # Should be optimized by gathering all asking of this video 185 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 186 | raw_video_data = raw_video_data['video'] 187 | 188 | if len(raw_video_data.shape) > 3: 189 | raw_video_data_clip = raw_video_data 190 | # L x T x 3 x H x W 191 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 192 | if self.max_frames < raw_video_slice.shape[0]: 193 | if self.slice_framepos == 0: 194 | video_slice = raw_video_slice[:self.max_frames, ...] 195 | elif self.slice_framepos == 1: 196 | video_slice = raw_video_slice[-self.max_frames:, ...] 197 | else: 198 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 199 | video_slice = raw_video_slice[sample_indx, ...] 
200 | else: 201 | video_slice = raw_video_slice 202 | 203 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 204 | 205 | slice_len = video_slice.shape[0] 206 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 207 | if slice_len < 1: 208 | pass 209 | else: 210 | video[i][:slice_len, ...] = video_slice 211 | else: 212 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 213 | except Exception as excep: 214 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 215 | raise excep 216 | 217 | for i, v_length in enumerate(max_video_length): 218 | video_mask[i][:v_length] = [1] * v_length 219 | 220 | return video, video_mask 221 | 222 | def __getitem__(self, feature_idx): 223 | pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 224 | idx = self.video_id2idx_dict[pseudo_video_id] 225 | 226 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) 227 | video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) 228 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 229 | -------------------------------------------------------------------------------- /dataloaders/dataloader_didemo_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | from dataloaders.rawvideo_util import RawVideoExtractor 11 | 12 | class DiDeMo_DataLoader(Dataset): 13 | def __init__( 14 | self, 15 | subset, 16 | data_path, 17 | video_path, 18 | tokenizer, 19 | max_words=30, 20 | feature_framerate=1.0, 21 | max_frames=100, 22 | image_resolution=224, 23 | frame_order=0, 24 | slice_framepos=0, 25 | ): 26 | self.data_path = data_path 27 | self.video_path = video_path 28 | self.feature_framerate = feature_framerate 29 | self.max_words = max_words 30 | self.max_frames = max_frames 31 | self.tokenizer = tokenizer 32 | # 0: ordinary order; 1: reverse order; 2: random order. 33 | self.frame_order = frame_order 34 | assert self.frame_order in [0, 1, 2] 35 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
36 | self.slice_framepos = slice_framepos 37 | assert self.slice_framepos in [0, 1, 2] 38 | 39 | self.subset = subset 40 | assert self.subset in ["train", "val", "test"] 41 | 42 | video_id_path_dict = {} 43 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 44 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 45 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 46 | 47 | video_json_path_dict = {} 48 | video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") 49 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") 50 | video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") 51 | 52 | with open(video_id_path_dict[self.subset], 'r') as fp: 53 | video_ids = [itm.strip() for itm in fp.readlines()] 54 | 55 | caption_dict = {} 56 | with open(video_json_path_dict[self.subset], 'r') as f: 57 | json_data = json.load(f) 58 | for itm in json_data: 59 | description = itm["description"] 60 | times = itm["times"] 61 | video = itm["video"] 62 | if video not in video_ids: 63 | continue 64 | 65 | # each video is split into 5-second temporal chunks 66 | # average the points from each annotator 67 | start_ = np.mean([t_[0] for t_ in times]) * 5 68 | end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 69 | if video in caption_dict: 70 | caption_dict[video]["start"].append(start_) 71 | caption_dict[video]["end"].append(end_) 72 | caption_dict[video]["text"].append(description) 73 | else: 74 | caption_dict[video] = {} 75 | caption_dict[video]["start"] = [start_] 76 | caption_dict[video]["end"] = [end_] 77 | caption_dict[video]["text"] = [description] 78 | 79 | for k_ in caption_dict.keys(): 80 | caption_dict[k_]["start"] = [0] 81 | # trick to save time on obtaining each video length 82 | # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: 83 | # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. 
84 | caption_dict[k_]["end"] = [31] 85 | caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] 86 | 87 | video_dict = {} 88 | for root, dub_dir, video_files in os.walk(self.video_path): 89 | for video_file in video_files: 90 | video_id_ = video_file[:-4] 91 | if video_id_ not in video_ids: 92 | continue 93 | file_path_ = os.path.join(root, video_file) 94 | video_dict[video_id_] = file_path_ 95 | 96 | self.caption_dict = caption_dict 97 | self.video_dict = video_dict 98 | video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) 99 | 100 | # Get all captions 101 | self.iter2video_pairs_dict = {} 102 | for video_id in self.caption_dict.keys(): 103 | if video_id not in video_ids: 104 | continue 105 | caption = self.caption_dict[video_id] 106 | n_caption = len(caption['start']) 107 | for sub_id in range(n_caption): 108 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) 109 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 110 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 111 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 112 | 113 | def __len__(self): 114 | return len(self.iter2video_pairs_dict) 115 | 116 | def _get_text(self, video_id, sub_id): 117 | caption = self.caption_dict[video_id] 118 | k = 1 119 | r_ind = [sub_id] 120 | 121 | starts = np.zeros(k, dtype=np.long) 122 | ends = np.zeros(k, dtype=np.long) 123 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 124 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 125 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 126 | 127 | for i in range(k): 128 | ind = r_ind[i] 129 | start_, end_ = caption['start'][ind], caption['end'][ind] 130 | words = self.tokenizer.tokenize(caption['text'][ind]) 131 | starts[i], ends[i] = start_, end_ 132 | 133 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 134 | total_length_with_CLS = self.max_words - 1 135 | if len(words) > total_length_with_CLS: 136 | words = words[:total_length_with_CLS] 137 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 138 | 139 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 140 | input_mask = [1] * len(input_ids) 141 | segment_ids = [0] * len(input_ids) 142 | while len(input_ids) < self.max_words: 143 | input_ids.append(0) 144 | input_mask.append(0) 145 | segment_ids.append(0) 146 | assert len(input_ids) == self.max_words 147 | assert len(input_mask) == self.max_words 148 | assert len(segment_ids) == self.max_words 149 | 150 | pairs_text[i] = np.array(input_ids) 151 | pairs_mask[i] = np.array(input_mask) 152 | pairs_segment[i] = np.array(segment_ids) 153 | 154 | return pairs_text, pairs_mask, pairs_segment, starts, ends 155 | 156 | def _get_rawvideo(self, idx, s, e): 157 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 158 | max_video_length = [0] * len(s) 159 | 160 | # Pair x L x T x 3 x H x W 161 | video = np.zeros((len(s), self.max_frames, 1, 3, 162 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 163 | video_path = self.video_dict[idx] 164 | 165 | try: 166 | for i in range(len(s)): 167 | start_time = int(s[i]) 168 | end_time = int(e[i]) 169 | start_time = start_time if start_time >= 0. else 0. 170 | end_time = end_time if end_time >= 0. else 0. 
171 | if start_time > end_time: 172 | start_time, end_time = end_time, start_time 173 | elif start_time == end_time: 174 | end_time = end_time + 1 175 | 176 | cache_id = "{}_{}_{}".format(video_path, start_time, end_time) 177 | # Should be optimized by gathering all asking of this video 178 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 179 | raw_video_data = raw_video_data['video'] 180 | 181 | if len(raw_video_data.shape) > 3: 182 | raw_video_data_clip = raw_video_data 183 | # L x T x 3 x H x W 184 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 185 | if self.max_frames < raw_video_slice.shape[0]: 186 | if self.slice_framepos == 0: 187 | video_slice = raw_video_slice[:self.max_frames, ...] 188 | elif self.slice_framepos == 1: 189 | video_slice = raw_video_slice[-self.max_frames:, ...] 190 | else: 191 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 192 | video_slice = raw_video_slice[sample_indx, ...] 193 | else: 194 | video_slice = raw_video_slice 195 | 196 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 197 | 198 | slice_len = video_slice.shape[0] 199 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 200 | if slice_len < 1: 201 | pass 202 | else: 203 | video[i][:slice_len, ...] = video_slice 204 | else: 205 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 206 | except Exception as excep: 207 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 208 | pass 209 | # raise e 210 | 211 | for i, v_length in enumerate(max_video_length): 212 | video_mask[i][:v_length] = [1] * v_length 213 | 214 | return video, video_mask 215 | 216 | def __getitem__(self, feature_idx): 217 | video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 218 | 219 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) 220 | video, video_mask = self._get_rawvideo(video_id, starts, ends) 221 | return pairs_text, pairs_mask, pairs_segment, video, video_mask -------------------------------------------------------------------------------- /dataloaders/dataloader_msvd_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle 10 | from dataloaders.rawvideo_util import RawVideoExtractor 11 | 12 | class MSVD_DataLoader(Dataset): 13 | """MSVD dataset loader.""" 14 | def __init__( 15 | self, 16 | subset, 17 | data_path, 18 | video_path, 19 | tokenizer, 20 | max_words=30, 21 | feature_framerate=1.0, 22 | max_frames=100, 23 | image_resolution=224, 24 | frame_order=0, 25 | slice_framepos=0, 26 | ): 27 | self.data_path = data_path 28 | self.video_path = video_path 29 | self.feature_framerate = feature_framerate 30 | self.max_words = max_words 31 | self.max_frames = max_frames 32 | self.tokenizer = tokenizer 33 | # 0: ordinary order; 1: reverse order; 2: random order. 34 | self.frame_order = frame_order 35 | assert self.frame_order in [0, 1, 2] 36 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
37 | self.slice_framepos = slice_framepos 38 | assert self.slice_framepos in [0, 1, 2] 39 | 40 | self.subset = subset 41 | assert self.subset in ["train", "val", "test"] 42 | video_id_path_dict = {} 43 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 44 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 45 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 46 | caption_file = os.path.join(self.data_path, "raw-captions.pkl") # msrvttpretrain_1masksaliencynoveltytoken.pkl raw-captions.pkl msvdpretrain_recover1masksaliencynoveltytoken.pkl msvdpretrain_1masksaliencynoveltytoken.pkl 47 | 48 | with open(video_id_path_dict[self.subset], 'r') as fp: 49 | video_ids = [itm.strip() for itm in fp.readlines()] 50 | 51 | with open(caption_file, 'rb') as f: 52 | captions = pickle.load(f) 53 | 54 | video_dict = {} 55 | tmpvideo_path = self.video_path.replace("frames", "videos") 56 | # print(tmpvideo_path) 57 | for root, dub_dir, video_files in os.walk(tmpvideo_path): 58 | for video_file in video_files: 59 | video_id_ = video_file.split(".")[0] 60 | if video_id_ not in video_ids: 61 | continue 62 | file_path_ = os.path.join(root, video_file) 63 | video_dict[video_id_] = file_path_ 64 | self.video_dict = video_dict 65 | 66 | 67 | self.sample_len = 0 68 | self.sentences_dict = {} 69 | self.cut_off_points = [] 70 | for video_id in video_ids: 71 | assert video_id in captions 72 | for cap in captions[video_id]: 73 | cap_txt = " ".join(cap) 74 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) 75 | self.cut_off_points.append(len(self.sentences_dict)) 76 | 77 | ## below variables are used to multi-sentences retrieval 78 | # self.cut_off_points: used to tag the label when calculate the metric 79 | # self.sentence_num: used to cut the sentence representation 80 | # self.video_num: used to cut the video representation 81 | self.multi_sentence_per_video = True # !!! 
important tag for eval 82 | if self.subset == "val" or self.subset == "test": 83 | self.sentence_num = len(self.sentences_dict) 84 | self.video_num = len(video_ids) 85 | assert len(self.cut_off_points) == self.video_num 86 | print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) 87 | print("For {}, video number: {}".format(self.subset, self.video_num)) 88 | 89 | print("Video number: {}".format(len(self.video_dict))) 90 | print("Total Pair: {}".format(len(self.sentences_dict))) 91 | 92 | self.sample_len = len(self.sentences_dict) 93 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 94 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 95 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 96 | 97 | def __len__(self): 98 | return self.sample_len 99 | 100 | def _get_text(self, video_id, caption): 101 | k = 1 102 | choice_video_ids = [video_id] 103 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 104 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 105 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 106 | 107 | for i, video_id in enumerate(choice_video_ids): 108 | words = self.tokenizer.tokenize(caption) 109 | 110 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 111 | total_length_with_CLS = self.max_words - 1 112 | if len(words) > total_length_with_CLS: 113 | words = words[:total_length_with_CLS] 114 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 115 | 116 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 117 | input_mask = [1] * len(input_ids) 118 | segment_ids = [0] * len(input_ids) 119 | while len(input_ids) < self.max_words: 120 | input_ids.append(0) 121 | input_mask.append(0) 122 | segment_ids.append(0) 123 | assert len(input_ids) == self.max_words 124 | assert len(input_mask) == self.max_words 125 | assert len(segment_ids) == self.max_words 126 | 127 | pairs_text[i] = np.array(input_ids) 128 | pairs_mask[i] = np.array(input_mask) 129 | pairs_segment[i] = np.array(segment_ids) 130 | 131 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 132 | 133 | def _get_rawvideo(self, choice_video_ids): 134 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 135 | max_video_length = [0] * len(choice_video_ids) 136 | 137 | # Pair x L x T x 3 x H x W 138 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 139 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 140 | 141 | 142 | for i, video_id in enumerate(choice_video_ids): 143 | video_path = self.video_dict[video_id] 144 | 145 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 146 | raw_video_data = raw_video_data['video'] 147 | 148 | if len(raw_video_data.shape) > 3: 149 | raw_video_data_clip = raw_video_data 150 | # L x T x 3 x H x W 151 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 152 | if self.max_frames < raw_video_slice.shape[0]: 153 | if self.slice_framepos == 0: 154 | video_slice = raw_video_slice[:self.max_frames, ...] 155 | elif self.slice_framepos == 1: 156 | video_slice = raw_video_slice[-self.max_frames:, ...] 157 | else: 158 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 159 | video_slice = raw_video_slice[sample_indx, ...] 
160 | else: 161 | video_slice = raw_video_slice 162 | 163 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 164 | 165 | slice_len = video_slice.shape[0] 166 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 167 | if slice_len < 1: 168 | pass 169 | else: 170 | video[i][:slice_len, ...] = video_slice 171 | else: 172 | print("video path: {} error. video id: {}".format(video_path, video_id)) 173 | 174 | for i, v_length in enumerate(max_video_length): 175 | video_mask[i][:v_length] = [1] * v_length 176 | 177 | return video, video_mask 178 | 179 | def __getitem__(self, idx): 180 | video_id, caption = self.sentences_dict[idx] 181 | 182 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) 183 | video, video_mask = self._get_rawvideo(choice_video_ids) 184 | return pairs_text, pairs_mask, pairs_segment, video, video_mask -------------------------------------------------------------------------------- /dataloaders/rawvideo_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from PIL import Image 4 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 5 | import cv2 6 | 7 | class RawVideoExtractorCV2(): 8 | def __init__(self, centercrop=False, size=224, framerate=-1, ): 9 | self.centercrop = centercrop 10 | self.size = size 11 | self.framerate = framerate 12 | self.transform = self._transform(self.size) 13 | 14 | 15 | def _transform(self, n_px): 16 | return Compose([ 17 | Resize(n_px, interpolation=Image.BICUBIC), 18 | CenterCrop(n_px), 19 | lambda image: image.convert("RGB"), 20 | ToTensor(), 21 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 22 | ]) 23 | 24 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): 25 | if start_time is not None or end_time is not None: 26 | assert isinstance(start_time, int) and isinstance(end_time, int) \ 27 | and start_time > -1 and end_time > start_time 28 | assert sample_fp > -1 29 | 30 | # Samples a frame sample_fp X frames. 
31 | cap = cv2.VideoCapture(video_file) 32 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 33 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 34 | 35 | total_duration = (frameCount + fps - 1) // fps 36 | start_sec, end_sec = 0, total_duration 37 | 38 | if start_time is not None: 39 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration 40 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) 41 | 42 | interval = 1 43 | if sample_fp > 0: 44 | interval = fps // sample_fp 45 | else: 46 | sample_fp = fps 47 | if interval == 0: interval = 1 48 | 49 | inds = [ind for ind in np.arange(0, fps, interval)] 50 | assert len(inds) >= sample_fp 51 | inds = inds[:sample_fp] 52 | 53 | ret = True 54 | images, included = [], [] 55 | 56 | for sec in np.arange(start_sec, end_sec + 1): 57 | if not ret: break 58 | sec_base = int(sec * fps) 59 | for ind in inds: 60 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) 61 | ret, frame = cap.read() 62 | if not ret: break 63 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 64 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 65 | 66 | cap.release() 67 | 68 | if len(images) > 0: 69 | video_data = th.tensor(np.stack(images)) 70 | else: 71 | video_data = th.zeros(1) 72 | return {'video': video_data} 73 | 74 | 75 | 76 | def get_video_data(self, video_path, start_time=None, end_time=None): 77 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) 78 | return image_input 79 | 80 | 81 | def process_raw_data(self, raw_video_data): 82 | tensor_size = raw_video_data.size() 83 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) 84 | return tensor 85 | 86 | def process_frame_order(self, raw_video_data, frame_order=0): 87 | # 0: ordinary order; 1: reverse order; 2: random order. 88 | if frame_order == 0: 89 | pass 90 | elif frame_order == 1: 91 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) 92 | raw_video_data = raw_video_data[reverse_order, ...] 93 | elif frame_order == 2: 94 | random_order = np.arange(raw_video_data.size(0)) 95 | np.random.shuffle(random_order) 96 | raw_video_data = raw_video_data[random_order, ...] 
97 | 98 | return raw_video_data 99 | 100 | # An ordinary video frame extractor based CV2 101 | RawVideoExtractor = RawVideoExtractorCV2 -------------------------------------------------------------------------------- /dataloaders/text_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pandas as pd 10 | import json 11 | import random 12 | 13 | class Text_DataLoader(Dataset): 14 | def __init__(self, queryfile_path, ): 15 | self.sentence_ids = [] 16 | self.sentences = [] 17 | with open(queryfile_path, 'r') as f: 18 | for line in f.readlines(): 19 | sentence_id, sentence = line.strip().split('\t', 1) 20 | self.sentence_ids.append(sentence_id) 21 | self.sentences.append(sentence) 22 | self.sentence_num = len(self.sentence_ids) 23 | print("Sentence number: {}".format(self.sentence_num)) 24 | 25 | def __len__(self): 26 | return self.sentence_num 27 | 28 | def __getitem__(self, idx): 29 | sentence_id = self.sentence_ids[idx] 30 | sentence = self.sentences[idx] 31 | 32 | return sentence_id, sentence -------------------------------------------------------------------------------- /dataloaders/video_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pandas as pd 10 | from collections import defaultdict 11 | import json 12 | import random 13 | from dataloaders.rawvideo_util import RawVideoExtractor 14 | # from dataloaders.rawframe_util import RawVideoExtractor 15 | 16 | class Video_DataLoader(Dataset): 17 | def __init__( 18 | self, 19 | videofile_path, 20 | videodata_dir, 21 | feature_framerate=1.0, 22 | max_frames=100, 23 | image_resolution=224, 24 | frame_order=0, 25 | slice_framepos=0 26 | ): 27 | self.video_ids = [] 28 | with open(videofile_path, 'r') as f: 29 | for line in f.readlines(): 30 | self.video_ids.append(line.strip()) 31 | 32 | self.video_paths = [] 33 | for i, video_id in enumerate(self.video_ids): 34 | video_path = os.path.join(videodata_dir, "{}".format(video_id)) 35 | if os.path.exists(video_path) is False: 36 | video_path = video_path + '.mp4' 37 | if os.path.exists(video_path) is False: 38 | video_path = video_path.replace(".mp4", ".avi") 39 | if os.path.exists(video_path) is False: 40 | video_path = video_path.replace(".avi", "") 41 | if os.path.exists(video_path) is False: 42 | print('video path = {} is not exists.'.format(video_path)) 43 | break 44 | self.video_paths.append(video_path) 45 | 46 | self.test_set_start_time = {} 47 | self.test_set_end_time = {} 48 | if 'didemo' in videofile_path: # get start and end timestamps 49 | for k_ in self.video_ids: 50 | self.test_set_start_time[k_] = 0 51 | # trick to save time on obtaining each video length 52 | # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: 53 | # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. 
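                # (31 is simply a safe upper bound: video_to_tensor() clamps end_time to the
                #  actual clip duration, so each <=30-second DiDeMo video is read in full.)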
54 |                 self.test_set_end_time[k_] = 31
55 | 
56 |         self.feature_framerate = feature_framerate
57 |         self.max_frames = max_frames
58 |         # 0: ordinary order; 1: reverse order; 2: random order.
59 |         self.frame_order = frame_order
60 |         assert self.frame_order in [0, 1, 2]
61 |         # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
62 |         self.slice_framepos = slice_framepos
63 |         assert self.slice_framepos in [0, 1, 2]
64 | 
65 |         self.video_num = len(self.video_ids)
66 |         print("Video number: {}".format(self.video_num))
67 | 
68 |         self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
69 | 
70 |     def __len__(self):
71 |         return self.video_num
72 | 
73 |     def _get_rawvideo(self, choice_video_ids, choice_video_paths):
74 |         # NOTE: np.long raises an error with numpy==1.24.3 (the alias was removed); use np.int64, and np.float64 for np.float.
75 |         video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
76 |         max_video_length = [0] * len(choice_video_ids)
77 | 
78 |         # Pair x L x T x 3 x H x W
79 |         video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
80 |                           self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
81 |         try:
82 |             for i, video_id in enumerate(choice_video_ids):
83 |                 video_path = choice_video_paths[i]
84 |                 if len(self.test_set_start_time) == 0:
85 |                     raw_video_data = self.rawVideoExtractor.get_video_data(video_path)
86 |                 else:
87 |                     raw_video_data = self.rawVideoExtractor.get_video_data(video_path, self.test_set_start_time[video_id], self.test_set_end_time[video_id])
88 | 
89 |                 raw_video_data = raw_video_data['video']
90 |                 if len(raw_video_data.shape) > 3:
91 |                     raw_video_data_clip = raw_video_data
92 |                     # L x T x 3 x H x W
93 |                     raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
94 |                     if self.max_frames < raw_video_slice.shape[0]:
95 |                         if self.slice_framepos == 0:
96 |                             video_slice = raw_video_slice[:self.max_frames, ...]
97 |                         elif self.slice_framepos == 1:
98 |                             video_slice = raw_video_slice[-self.max_frames:, ...]
99 |                         else:
100 |                             sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int)
101 |                             video_slice = raw_video_slice[sample_indx, ...]
102 |                     else:
103 |                         video_slice = raw_video_slice
104 | 
105 |                     video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
106 | 
107 |                     slice_len = video_slice.shape[0]
108 |                     max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
109 |                     if slice_len < 1:
110 |                         pass
111 |                     else:
112 |                         video[i][:slice_len, ...] = video_slice
113 |                 else:
114 |                     print("video path: {} error. video id: {}".format(video_path, video_id))
115 |         except Exception as excep:
116 |             print("video path: {} error.
Error: {}".format(video_path, excep)) 117 | pass 118 | 119 | for i, v_length in enumerate(max_video_length): 120 | video_mask[i][:v_length] = [1] * v_length 121 | 122 | return video, video_mask 123 | 124 | def __getitem__(self, idx): 125 | video_id = self.video_ids[idx] 126 | video_path = self.video_paths[idx] 127 | video, video_mask = self._get_rawvideo([video_id], [video_path]) 128 | return video_id, video, video_mask -------------------------------------------------------------------------------- /do_eval.sh: -------------------------------------------------------------------------------- 1 | test_collection=$1 2 | text_feat_name=$2 3 | video_feat_name=$3 4 | gt_file_name=$4 5 | 6 | python evaluation.py --local_rank=0 \ 7 | --text_feat_path="data/datasets/$test_collection/FeatureData/$text_feat_name" \ 8 | --video_feat_path="data/datasets/$test_collection/FeatureData/$video_feat_name" \ 9 | --gt_file_path="data/datasets/$test_collection/Annotations/$gt_file_name.txt" -------------------------------------------------------------------------------- /do_extract_text_feat.sh: -------------------------------------------------------------------------------- 1 | test_collection=$1 2 | queryset=$2 3 | model_name=$3 # $train_collection/Models/$config/run$ID 4 | 5 | train_collection=$(echo "$model_name" | cut -d'/' -f1) 6 | config=$(echo "$model_name" | cut -d'/' -f3) 7 | run=$(echo "$model_name" | cut -d'/' -f4) 8 | 9 | python -m torch.distributed.launch extract_feat.py --datatype text \ 10 | --local_rank=3 --num_thread_reader=8 --batch_size=100 \ 11 | --queryfile_path "data/datasets/$test_collection/QuerySet/$queryset.txt" \ 12 | --output_dir "data/datasets/$test_collection/FeatureData/$queryset/$train_collection/$config/$run" \ 13 | --init_model "data/datasets/$model_name/best_model.bin" \ 14 | --max_frames=12 --max_words=32 --feature_framerate 1 --slice_framepos 2 \ 15 | --linear_patch 2d --sim_header seqTransf \ 16 | --pretrained_clip_name ViT-B/32 17 | 18 | -------------------------------------------------------------------------------- /do_extract_video_feat.sh: -------------------------------------------------------------------------------- 1 | test_collection=$1 2 | videoset=$2 3 | model_name=$3 # $train_collection/Models/$config/run$ID 4 | 5 | train_collection=$(echo "$model_name" | cut -d'/' -f1) 6 | config=$(echo "$model_name" | cut -d'/' -f3) 7 | run=$(echo "$model_name" | cut -d'/' -f4) 8 | 9 | python -m torch.distributed.launch extract_feat.py --datatype video \ 10 | --local_rank=3 --num_thread_reader=8 --batch_size=100 \ 11 | --videofile_path "data/datasets/$test_collection/VideoSet/$videoset.txt" \ 12 | --video_path "data/datasets/$test_collection/VideoData" \ 13 | --output_dir "data/datasets/$test_collection/FeatureData/$videoset/$train_collection/$config/$run" \ 14 | --init_model "data/datasets/$model_name/best_model.bin" \ 15 | --max_frames=12 --max_words=32 --feature_framerate 1 --slice_framepos 2 \ 16 | --linear_patch 2d --sim_header seqTransf \ 17 | --pretrained_clip_name ViT-B/32 18 | 19 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from utils.bigfile import BigFile 4 | from tqdm import tqdm 5 | from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim 6 | import argparse 7 | import numpy as np 8 | 9 | def get_args(description='CLIP4Clip Distill on Retrieval Task'): 10 | 
parser = argparse.ArgumentParser(description=description) 11 | parser.add_argument('--local_rank', type=int, default=0, help='gpu id') 12 | parser.add_argument('--text_feat_path', type=str, default='', help='text_feat_path') 13 | parser.add_argument('--video_feat_path', type=str, default='', help='video_feat_path') 14 | parser.add_argument('--gt_file_path', type=str, default='', help='gt_file_path') 15 | args = parser.parse_args() 16 | return args 17 | 18 | def cal_sim_matrix(args): 19 | query_ids = [] 20 | video_ids = [] 21 | with open(args.gt_file_path, 'r') as f: 22 | for line in f.readlines(): 23 | query_id, video_id = line.strip().split('\t') 24 | if query_id not in query_ids: 25 | query_ids.append(query_id) 26 | if video_id not in video_ids: 27 | video_ids.append(video_id) 28 | 29 | video_file = BigFile(args.video_feat_path) 30 | text_file = BigFile(args.text_feat_path) 31 | 32 | device = torch.device('cuda:{}'.format(args.local_rank)) 33 | text_feats = [] 34 | for query_id in query_ids: 35 | text_feat = text_file.read_one(query_id) 36 | text_feats.append(text_feat) 37 | text_feats = torch.tensor(text_feats).to(device) 38 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 39 | 40 | video_feats = [] 41 | for video_id in video_ids: 42 | video_feat = video_file.read_one(video_id) 43 | video_feats.append(video_feat) 44 | video_feats = torch.tensor(video_feats).to(device) 45 | video_feats = video_feats / video_feats.norm(dim=-1, keepdim=True) 46 | 47 | sim_matrix = torch.einsum('md,nd->mn', text_feats, video_feats) 48 | sim_matrix_npy = sim_matrix.cpu().numpy() 49 | return sim_matrix_npy 50 | 51 | def get_metrics(sim_matrix, args): 52 | if sim_matrix.shape[0] != sim_matrix.shape[1]: 53 | print("before reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1])) 54 | cut_off_points2len_ = [] 55 | tmp_video_ids = [] 56 | i = 0 57 | for line in open(args.gt_file_path, 'r'): 58 | query_id, video_id = line.strip().split('\t') 59 | if video_id not in tmp_video_ids: 60 | tmp_video_ids.append(video_id) 61 | cut_off_points2len_.append(i) 62 | i = i + 1 63 | cut_off_points2len_ = cut_off_points2len_[1:] 64 | cut_off_points2len_.append(i) 65 | 66 | max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)]) 67 | sim_matrix_new = [] 68 | for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_): 69 | sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_], 70 | np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0)) 71 | sim_matrix = np.stack(tuple(sim_matrix_new), axis=0) 72 | print("after reshape, sim matrix size: {} x {} x {}". 73 | format(sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])) 74 | 75 | tv_metrics = tensor_text_to_video_metrics(sim_matrix) 76 | vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix)) 77 | else: 78 | print("sim matrix size: {}, {}".format(sim_matrix.shape[0], sim_matrix.shape[1])) 79 | tv_metrics = compute_metrics(sim_matrix) 80 | vt_metrics = compute_metrics(sim_matrix.T) 81 | print('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) 82 | 83 | print("Text-to-Video:") 84 | print('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'. 85 | format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])) 86 | print("Video-to-Text:") 87 | print('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'. 
88 | format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])) 89 | 90 | def main(): 91 | args = get_args() 92 | sim_matrix = cal_sim_matrix(args) 93 | get_metrics(sim_matrix, args) 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /extract_feat.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import torch 7 | import numpy as np 8 | import random 9 | import os 10 | import time 11 | import argparse 12 | import sys 13 | from tqdm import tqdm 14 | 15 | from dataloaders.data_dataloaders import DATALOADER_DICT 16 | from modules.modeling import CLIP4Clip 17 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 18 | from modules.tokenization_clip import Tokenizer 19 | 20 | import utils.txt2bin as txt2bin 21 | 22 | def get_args(description='TeachCLIP Feature Extraction on a sigle GPU'): 23 | parser = argparse.ArgumentParser(description=description) 24 | parser.add_argument("--datatype", default="video", type=str, required=True, choices=['video', 'text'], help="Point the dataset to extract feature.") 25 | 26 | # arguments for dataloder 27 | parser.add_argument('--seed', type=int, default=42, help='random seed') 28 | parser.add_argument('--local_rank', type=int, default=0, help='gpu id') 29 | parser.add_argument('--num_thread_reader', type=int, default=1, help='') 30 | parser.add_argument('--batch_size', type=int, default=256, help='batch size for all dataloder') 31 | 32 | parser.add_argument('--overwrite', action='store_true', help='overwrite output feature file if true') 33 | parser.add_argument('--queryfile_path', type=str, default='data/datasets/msrvtt/QuerySet/msrvtt1k-test-query.txt', help='query id, query') 34 | parser.add_argument('--videofile_path', type=str, default='data/datasets/msrvtt/VideoSet/msrvtt1k-test.txt', help='video id') 35 | parser.add_argument('--video_path', type=str, default='data/datasets/msrvtt/VideoData', help='video data dir') 36 | parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the feature file will be written.") 37 | 38 | # arguments for model 39 | parser.add_argument("--init_model", default=None, type=str, required=True, help="Initial model.") 40 | parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version") 41 | parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module") 42 | parser.add_argument("--cache_dir", default="", type=str, 43 | help="Where do you want to store the pre-trained models downloaded from s3") 44 | parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. 
of cross.") 45 | parser.add_argument('--sim_header', type=str, default="seqTransf", 46 | choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"], 47 | help="choice a similarity header.") 48 | parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"], 49 | help="linear projection of flattened patches.") 50 | 51 | # arguments for video feature extraction 52 | parser.add_argument('--max_words', type=int, default=20, help='') 53 | parser.add_argument('--image_resolution', type=int, default=224, help='') 54 | parser.add_argument('--max_frames', type=int, default=100, help='') 55 | parser.add_argument('--feature_framerate', type=int, default=1, help='') 56 | parser.add_argument('--frame_order', type=int, default=0, choices=[0, 1, 2], 57 | help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.") 58 | parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2], 59 | help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.") 60 | 61 | args = parser.parse_args() 62 | 63 | return args 64 | 65 | def set_seed(args): 66 | # predefining random initial seeds 67 | random.seed(args.seed) 68 | os.environ['PYTHONHASHSEED'] = str(args.seed) 69 | np.random.seed(args.seed) 70 | torch.manual_seed(args.seed) 71 | torch.cuda.manual_seed(args.seed) 72 | 73 | def init_model(args, device): 74 | if args.init_model: 75 | model_state_dict = torch.load(args.init_model, map_location='cpu') 76 | else: 77 | model_state_dict = None 78 | 79 | # Prepare model 80 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') 81 | model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) 82 | model.to(device) 83 | 84 | return model 85 | 86 | def extract_video_feature(args, model, video_dataloader, device): 87 | model.eval() 88 | 89 | id_feature_path = os.path.join(args.output_dir, 'id.feature.txt') 90 | if os.path.exists(id_feature_path): 91 | if args.overwrite: 92 | print('%s exists. overwrite', id_feature_path) 93 | else: 94 | print('%s exists. skip', id_feature_path) 95 | sys.exit(0) 96 | if not os.path.exists(args.output_dir): 97 | os.makedirs(args.output_dir) 98 | open_type = 'w' 99 | fw = open(id_feature_path, open_type) 100 | 101 | with torch.no_grad(): 102 | # ---------------------------- 103 | # 1. 
cache the features 104 | # ---------------------------- 105 | print("***** Extracting video featrures *****") 106 | for batch in tqdm(video_dataloader): 107 | video_ids = batch[0] 108 | batch = tuple(t.to(device) for t in batch[1:]) 109 | video, video_mask = batch 110 | 111 | # video_features: [batch_size, out_dim] 112 | video_features = model.get_video_output(video, video_mask) 113 | 114 | # write video features to txt file 115 | video_features_numpy = video_features.cpu().numpy() 116 | for i in range(len(video_ids)): 117 | line = str(video_ids[i]) + ' ' + ' '.join([str(num) for num in video_features_numpy[i, :]]) + '\n' 118 | fw.write(line) 119 | 120 | fw.close() 121 | # transform to bin format 122 | print("***** txt to bin format *****") 123 | overwrite = args.overwrite 124 | txt2bin.process(0, [id_feature_path], args.output_dir, overwrite) 125 | # delete id.feature.txt 126 | os.remove(id_feature_path) 127 | 128 | def extract_text_feature(args, model, text_dataloader, device): 129 | model.eval() 130 | 131 | id_feature_path = os.path.join(args.output_dir, 'id.feature.txt') 132 | if os.path.exists(id_feature_path): 133 | if args.overwrite: 134 | print('%s exists. overwrite', id_feature_path) 135 | else: 136 | print('%s exists. skip', id_feature_path) 137 | sys.exit(0) 138 | if not os.path.exists(args.output_dir): 139 | os.makedirs(args.output_dir) 140 | open_type = 'w' 141 | fw = open(id_feature_path, open_type) 142 | 143 | with torch.no_grad(): 144 | # ---------------------------- 145 | # 1. cache the features 146 | # ---------------------------- 147 | print("***** Extracting text featrures *****") 148 | tokenizer = Tokenizer(max_words=args.max_words) 149 | for batch in tqdm(text_dataloader): 150 | sentence_ids, sentences = batch[0], batch[1] 151 | 152 | # text tokenization 153 | input_ids, input_mask, segment_ids = tokenizer._get_text(sentence_ids, sentences) 154 | input_ids, input_mask, segment_ids = input_ids.to(device), input_mask.to(device), segment_ids.to(device) 155 | 156 | # text_features: [batch_size, out_dim] 157 | text_features = model.get_sequence_output(input_ids, segment_ids, input_mask).squeeze() 158 | 159 | # write text features to txt file 160 | text_features_numpy = text_features.cpu().numpy() 161 | if torch.is_tensor(sentence_ids): 162 | sentence_ids = sentence_ids.cpu().numpy() 163 | for i in range(len(sentence_ids)): 164 | line = str(sentence_ids[i]) + ' ' + ' '.join([str(num) for num in text_features_numpy[i, :]]) + '\n' 165 | fw.write(line) 166 | 167 | fw.close() 168 | # transform to bin format 169 | print("***** txt to bin format *****") 170 | txt2bin.process(0, [id_feature_path], args.output_dir, args.overwrite) 171 | # delete id.feature.txt 172 | os.remove(id_feature_path) 173 | 174 | def main(): 175 | args = get_args() 176 | set_seed(args) 177 | device = None 178 | if torch.cuda.is_available(): 179 | device = torch.device('cuda:{}'.format(args.local_rank)) 180 | else: 181 | device = torch.device('cpu') 182 | raise Error('GPU is not available, infer on cpu is too slow!') 183 | 184 | ## #################################### 185 | # model loading 186 | ## #################################### 187 | model = init_model(args, device) 188 | 189 | ## #################################### 190 | # dataloader loading 191 | ## #################################### 192 | assert args.datatype in DATALOADER_DICT 193 | video_dataloader = None 194 | if args.datatype == 'video': 195 | video_dataloader, video_set_length = DATALOADER_DICT[args.datatype](args) 196 | print("***** Video 
feature extraction *****") 197 | print(" Num examples = ", video_set_length) 198 | print(" Batch size = ", args.batch_size) 199 | print(" Num steps = ", len(video_dataloader)) 200 | elif args.datatype == 'text': 201 | text_dataloader, text_set_length = DATALOADER_DICT[args.datatype](args) 202 | print("***** Text feature extraction *****") 203 | print(" Num examples = ", text_set_length) 204 | print(" Batch size = ", args.batch_size) 205 | print(" Num steps = ", len(text_dataloader)) 206 | 207 | ## #################################### 208 | # featue extraction 209 | ## #################################### 210 | if args.datatype == 'video': 211 | extract_video_feature(args, model, video_dataloader, device) 212 | elif args.datatype == 'text': 213 | extract_text_feature(args, model, text_dataloader, device) 214 | else: 215 | print('not implemented!') 216 | 217 | if __name__ == "__main__": 218 | main() -------------------------------------------------------------------------------- /images/teachclip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruc-aimc-lab/TeachCLIP/23c991fd7e4fe7bdbd7eb431f78e8e1340fdc736/images/teachclip.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | from scipy.special import softmax 9 | 10 | def compute_metrics(x): 11 | sx = np.sort(-x, axis=1) 12 | d = np.diag(-x) 13 | d = d[:, np.newaxis] 14 | ind = sx - d 15 | ind = np.where(ind == 0) 16 | ind = ind[1] 17 | metrics = {} 18 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 19 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 20 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 21 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind) 22 | metrics['MR'] = np.median(ind) + 1 23 | metrics["MedianR"] = metrics['MR'] 24 | metrics["MeanR"] = np.mean(ind) + 1 25 | metrics["cols"] = [int(i) for i in list(ind)] 26 | return metrics 27 | 28 | def compute_dsl_metrics(x): 29 | x = softmax(x, axis=0) * x 30 | sx = np.sort(-x, axis=1) 31 | d = np.diag(-x) 32 | d = d[:, np.newaxis] 33 | ind = sx - d 34 | ind = np.where(ind == 0) 35 | ind = ind[1] 36 | metrics = {} 37 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 38 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 39 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 40 | metrics['MR'] = np.median(ind) + 1 41 | metrics["MedianR"] = metrics['MR'] 42 | metrics["MeanR"] = np.mean(ind) + 1 43 | metrics["cols"] = [int(i) for i in list(ind)] 44 | return metrics 45 | 46 | def print_computed_metrics(metrics): 47 | r1 = metrics['R1'] 48 | r5 = metrics['R5'] 49 | r10 = metrics['R10'] 50 | mr = metrics['MR'] 51 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) 52 | 53 | # below two functions directly come from: https://github.com/Deferf/Experiments 54 | def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10,50]): 55 | if not torch.is_tensor(sim_tensor): 56 | sim_tensor = torch.tensor(sim_tensor) 57 | 58 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices. 
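    # (Here sim_tensor is [n_videos, max_captions, n_videos], padded with -inf in evaluation.py;
    #  a double argsort turns each row of scores into 0-based ranks, e.g. scores [0.2, 0.9, 0.5]
    #  -> descending argsort [1, 2, 0] -> ascending argsort [2, 0, 1].)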
59 | # Then obtain the double argsort to position the rank on the diagonal 60 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2) 61 | first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) 62 | second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) 63 | 64 | # Extracts ranks i.e diagonals 65 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) 66 | 67 | # Now we need to extract valid ranks, as some belong to inf padding values 68 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) 69 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) 70 | valid_ranks = ranks[mask] 71 | # A quick dimension check validates our results, there may be other correctness tests pending 72 | # Such as dot product localization, but that is for other time. 73 | #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) 74 | if not torch.is_tensor(valid_ranks): 75 | valid_ranks = torch.tensor(valid_ranks) 76 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} 77 | results["MedianR"] = float(torch.median(valid_ranks + 1)) 78 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) 79 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) 80 | results['MR'] = results["MedianR"] 81 | return results 82 | 83 | def tensor_video_to_text_sim(sim_tensor): 84 | if not torch.is_tensor(sim_tensor): 85 | sim_tensor = torch.tensor(sim_tensor) 86 | # Code to avoid nans 87 | sim_tensor[sim_tensor != sim_tensor] = float('-inf') 88 | # Forms a similarity matrix for use with rank at k 89 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True) 90 | return torch.squeeze(values).T 91 | 92 | def compute_classification_metrics(pred_labels, gt_labels): 93 | ''' 94 | pred_labels: (n_videos, n_labels) 95 | gt_labels: (n_videos, 1) 96 | ''' 97 | sx = pred_labels 98 | d = gt_labels 99 | ind = sx - d 100 | ind = np.where(ind == 0) 101 | ind = ind[1] 102 | metrics = {} 103 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 104 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 105 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 106 | metrics['MR'] = np.median(ind) + 1 107 | metrics["MedianR"] = metrics['MR'] 108 | metrics["MeanR"] = np.mean(ind) + 1 109 | metrics["cols"] = [int(i) for i in list(ind)] 110 | return metrics 111 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruc-aimc-lab/TeachCLIP/23c991fd7e4fe7bdbd7eb431f78e8e1340fdc736/modules/__init__.py -------------------------------------------------------------------------------- /modules/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruc-aimc-lab/TeachCLIP/23c991fd7e4fe7bdbd7eb431f78e8e1340fdc736/modules/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /modules/cross-base/cross_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 512, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 128, 9 | 
"num_attention_heads": 8, 10 | "num_hidden_layers": 4, 11 | "vocab_size": 512 12 | } -------------------------------------------------------------------------------- /modules/differential_topk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/yuqi657/ts2_net/blob/master/modules/modeling.py 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | ########################################### 9 | ############# differential topK ########### 10 | ########################################### 11 | # Calculation of differential topK is based on [Top-K](https://arxiv.org/pdf/2104.03059.pdf), thanks 12 | class PerturbedTopK(nn.Module): 13 | def __init__(self, k: int, num_samples: int=500, sigma: float=0.05): 14 | super().__init__() 15 | self.num_samples = num_samples 16 | self.sigma = sigma 17 | self.k = k 18 | 19 | def __call__(self, x): 20 | return PerturbedTopKFuntion.apply(x, self.k, self.num_samples, self.sigma) 21 | 22 | class PerturbedTopKFuntion(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, x, k: int, num_samples: int=500, sigma: float=0.05): 25 | # input here is scores with (bs, num_patches) 26 | b, d = x.shape 27 | noise = torch.normal(mean=0.0, std=1.0, size=(b, num_samples, d)).to(dtype=x.dtype, device=x.device) 28 | perturbed_x = x.unsqueeze(1) + noise*sigma # b, nS, d 29 | topk_results = torch.topk(perturbed_x, k=k, dim=-1, sorted=False) 30 | indices = topk_results.indices # b, nS, k 31 | indices = torch.sort(indices, dim=-1).values # b, nS, k 32 | 33 | perturbed_output = F.one_hot(indices, num_classes=d).float() # b, nS, k, d 34 | indicators = perturbed_output.mean(dim=1) # b, k, d 35 | 36 | # context for backward 37 | ctx.k = k 38 | ctx.num_samples = num_samples 39 | ctx.sigma = sigma 40 | 41 | ctx.perturbed_output = perturbed_output 42 | ctx.noise = noise 43 | 44 | return indicators 45 | 46 | @staticmethod 47 | def backward(ctx, grad_output): 48 | if grad_output is None: 49 | return tuple([None]*5) 50 | 51 | noise_gradient = ctx.noise 52 | expected_gradient = ( 53 | torch.einsum("bnkd,bnd->bkd", ctx.perturbed_output, noise_gradient) 54 | / ctx.num_samples 55 | / ctx.sigma 56 | ) 57 | grad_input = torch.einsum("bkd,bkd->bd", grad_output, expected_gradient) 58 | return (grad_input,) + tuple([None]*5) 59 | 60 | ########################################### 61 | ############# differential topK ########### 62 | ########################################### 63 | 64 | class PredictorLG(nn.Module): 65 | """ Image to Patch Embedding 66 | """ 67 | def __init__(self, embed_dim=512): 68 | super().__init__() 69 | self.in_conv = nn.Sequential( 70 | nn.LayerNorm(embed_dim), 71 | nn.Linear(embed_dim, embed_dim // 2, bias=False), 72 | nn.GELU() 73 | ) 74 | 75 | self.out_conv = nn.Sequential( 76 | nn.Linear(embed_dim, embed_dim // 2, bias=False), 77 | nn.GELU(), 78 | # nn.Linear(embed_dim // 2, embed_dim // 4, bias=False), 79 | # nn.GELU(), 80 | nn.Linear(embed_dim // 2, 1, bias=False), 81 | nn.Tanh() 82 | # nn.Sigmoid() 83 | # nn.Softmax(dim=-1) 84 | # nn.LogSoftmax(dim=-1) 85 | ) 86 | 87 | def forward(self, x): 88 | ''' 89 | x: shape (bs*n_length, num_tokens, hid_dim) 90 | ''' 91 | x = self.in_conv(x) 92 | B, N, C = x.size() 93 | local_x = x[:,:, :] 94 | global_x = x[:,:1, :] 95 | # print("global_x.shape: ", global_x.shape) 96 | x = torch.cat([local_x, global_x.expand(B, N, C)], dim=-1) 97 | return self.out_conv(x) 98 | 99 | class VisualTokenSelection(nn.Module): 100 | 
def __init__(self, max_frames, embed_dim=512, topk=3): 101 | super().__init__() 102 | self.max_frames = max_frames 103 | self.score_predictor = PredictorLG(embed_dim=embed_dim) 104 | self.topk_selector = PerturbedTopK(topk) 105 | 106 | def forward(self, x, training=True): 107 | ''' 108 | x: input embed, shape is (bs, length*Ntokens, hid_dim) 109 | use cls token as global representation 110 | prob = Tanh(MLP(x)) 111 | ''' 112 | 113 | B, L, D = x.shape 114 | N = L // self.max_frames 115 | x = x.reshape(B, -1, N, D) # shape here is (bs, max_frames, n_patches, hid_dim) 116 | x = x.reshape(-1, N, D) # shape here is (bs*max_frames, n_patches, hid_dim) 117 | pred_score = self.score_predictor(x).squeeze() # (bs*max_frames, n_patches) 118 | 119 | spatial_pred_score = pred_score[:, 1:] # seperate the cls_token (bs*max_frames, n_patches-1) 120 | topk_indicator = self.topk_selector(spatial_pred_score) # (bs*max_frames, k, n_patches-1)) 121 | 122 | # cls token as cls token 123 | cls_x_feature = x[:, :1, :] # cls_token, shape here is (bs*max_frames, 1, hid_dim) 124 | # # avg pool of all tokens as cls token 125 | # cls_x_feature = torch.mean(x, dim=1, keepdim=True) 126 | 127 | spatial_x_feature = x[:, 1:, :] # seperate the cls_token, shape here is (bs*max_frames, n_patches-1, hid_dim) 128 | selected_patch_feature = torch.einsum("bkl,bld->bkd", topk_indicator, spatial_x_feature) 129 | 130 | output = torch.cat((cls_x_feature, selected_patch_feature), dim=1) # shape here is (bs*max_frames, topkPatches, hid_dim) 131 | output = output.reshape(B, self.max_frames, -1, D).reshape(B, -1, D) # shape here is (B, max_frames*topkPatches, D) 132 | 133 | return output 134 | 135 | class STPredictorConv(nn.Module): 136 | """ Image to Patch Embedding 137 | """ 138 | def __init__(self, embed_dim=512): 139 | super().__init__() 140 | self.in_conv = nn.Sequential( 141 | nn.LayerNorm(embed_dim), 142 | nn.Linear(embed_dim, embed_dim // 2, bias=False), 143 | nn.GELU() 144 | ) 145 | 146 | self.out_conv = nn.Sequential( 147 | nn.Linear(embed_dim, embed_dim // 2, bias=False), 148 | nn.GELU(), 149 | # nn.Linear(embed_dim // 2, embed_dim // 4, bias=False), 150 | # nn.GELU(), 151 | nn.Linear(embed_dim // 2, 1, bias=False), 152 | # nn.Tanh() 153 | nn.Softmax(dim=-1) 154 | # nn.LogSoftmax(dim=-1) 155 | ) 156 | 157 | def forward(self, x, max_frames): 158 | ''' 159 | x: shape (bs*n_length, num_tokens, hid_dim) 160 | ''' 161 | x = self.in_conv(x) 162 | B_frame, N, C = x.size() 163 | B = B_frame // max_frames 164 | local_x = x[:,:, :] 165 | 166 | global_x = x[:,:1, :].reshape(B, max_frames, 1, C) # shape (bs, n_length, cls_tokens, hid_dim) 167 | global_x = torch.mean(global_x, 1, True).expand(B, max_frames, 1, C).reshape(B_frame, 1, C) 168 | # print("global_x.shape: ", global_x.shape) 169 | 170 | x = torch.cat([local_x, global_x.expand(B_frame, N, C)], dim=-1) 171 | return self.out_conv(x) 172 | 173 | 174 | class STVisualTokenSelection(nn.Module): 175 | def __init__(self, max_frames, embed_dim=512, topk=3): 176 | super().__init__() 177 | self.max_frames = max_frames 178 | self.score_predictor = STPredictorConv(embed_dim=embed_dim) 179 | self.topk_selector = PerturbedTopK(topk) 180 | 181 | def forward(self, x, training=True): 182 | ''' 183 | x: input embed, shape is (bs, length*Ntokens, hid_dim) 184 | use cls token as global representation 185 | prob = Tanh(MLP(x)) 186 | ''' 187 | 188 | B, L, D = x.shape 189 | N = L // self.max_frames 190 | x = x.reshape(B, -1, N, D) # shape here is (bs, max_frames, n_patches, hid_dim) 191 | x = 
x.reshape(-1, N, D) # shape here is (bs*max_frames, n_patches, hid_dim) 192 | pred_score = self.score_predictor(x, self.max_frames).squeeze() # (bs*max_frames, n_patches) 193 | 194 | spatial_pred_score = pred_score[:, 1:] # seperate the cls_token (bs*max_frames, n_patches-1) 195 | topk_indicator = self.topk_selector(spatial_pred_score) # (bs*max_frames, k, n_patches-1)) 196 | 197 | # cls token as cls token 198 | cls_x_feature = x[:, :1, :] # cls_token, shape here is (bs*max_frames, 1, hid_dim) 199 | # # avg pool of all tokens as cls token 200 | # cls_x_feature = torch.mean(x, dim=1, keepdim=True) 201 | 202 | spatial_x_feature = x[:, 1:, :] # seperate the cls_token, shape here is (bs*max_frames, n_patches-1, hid_dim) 203 | selected_patch_feature = torch.einsum("bkl,bld->bkd", topk_indicator, spatial_x_feature) 204 | 205 | output = torch.cat((cls_x_feature, selected_patch_feature), dim=1) # shape here is (bs*max_frames, topkPatches, hid_dim) 206 | output = output.reshape(B, self.max_frames, -1, D).reshape(B, -1, D) # shape here is (B, max_frames*topkPatches, D) 207 | 208 | return output 209 | 210 | class VisualTokenRandomSelection(nn.Module): 211 | def __init__(self, max_frames, embed_dim=512, topk=3): 212 | super().__init__() 213 | self.max_frames = max_frames 214 | self.topk = topk 215 | 216 | def forward(self, x, training=True): 217 | ''' 218 | x: input embed, shape is (bs, length*Ntokens, hid_dim) 219 | use cls token as global representation 220 | prob = Tanh(MLP(x)) 221 | ''' 222 | 223 | B, L, D = x.shape 224 | N = L // self.max_frames 225 | x = x.reshape(B, -1, N, D) # shape here is (bs, max_frames, n_patches, hid_dim) 226 | x = x.reshape(-1, N, D) # shape here is (bs*max_frames, n_patches, hid_dim) 227 | 228 | # cls token as cls token 229 | cls_x_feature = x[:, :1, :] # cls_token, shape here is (bs*max_frames, 1, hid_dim) 230 | # # avg pool of all tokens as cls token 231 | # cls_x_feature = torch.mean(x, dim=1, keepdim=True) 232 | 233 | spatial_x_feature = x[:, 1:, :] # seperate the cls_token, shape here is (bs*max_frames, n_patches-1, hid_dim) 234 | patch_len = spatial_x_feature.shape[1] 235 | selected_indices = torch.randperm(patch_len)[:self.topk].sort()[0] 236 | selected_patch_feature = spatial_x_feature[:, selected_indices, :] 237 | 238 | output = torch.cat((cls_x_feature, selected_patch_feature), dim=1) # shape here is (bs*max_frames, topkPatches, hid_dim) 239 | output = output.reshape(B, self.max_frames, -1, D).reshape(B, -1, D) # shape here is (B, max_frames*topkPatches, D) 240 | 241 | return output 242 | 243 | class TextPredictorLG(nn.Module): 244 | """ Text to Patch Embedding 245 | """ 246 | def __init__(self, embed_dim=512): 247 | super().__init__() 248 | self.in_conv = nn.Sequential( 249 | nn.LayerNorm(embed_dim), 250 | nn.Linear(embed_dim, embed_dim // 2), 251 | nn.GELU() 252 | ) 253 | 254 | self.out_conv = nn.Sequential( 255 | nn.Linear(embed_dim, embed_dim // 2, bias=False), 256 | nn.GELU(), 257 | # nn.Linear(embed_dim // 2, embed_dim // 4, bias=False), 258 | # nn.GELU(), 259 | nn.Linear(embed_dim // 2, 1, bias=False), 260 | # nn.Tanh() 261 | nn.Sigmoid() 262 | ) 263 | 264 | def forward(self, x, text): 265 | ''' 266 | x: shape (bs, num_tokens, hid_dim) 267 | ''' 268 | x = self.in_conv(x) 269 | B, N, C = x.size() 270 | local_x = x[:, :, :] 271 | global_x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)].unsqueeze(1) 272 | x = torch.cat([local_x, global_x.expand(B, N, C)], dim=-1) 273 | return self.out_conv(x) 274 | 275 | class TextTokenSelection(nn.Module): 276 | def 
__init__(self, embed_dim=512, topk=1): 277 | super().__init__() 278 | self.score_predictor = TextPredictorLG(embed_dim=embed_dim) 279 | self.topk_selector = PerturbedTopK(topk) 280 | 281 | def forward(self, x, input_ids, attention_mask, training=True): 282 | ''' 283 | x: input embed, shape is (bs, max_words, hid_dim) 284 | input_ids: (bs, max_words) token id, cls is the max 285 | attention_mask: (bs, max_words) 286 | use cls token as global representation 287 | prob = Tanh(MLP(x)) 288 | ''' 289 | B, N, D = x.shape 290 | pred_score = self.score_predictor(x, input_ids).squeeze() # (bs, max_words) 291 | 292 | attention_mask_new = torch.cat((attention_mask[:, 1:], torch.zeros(B,1).to(device=attention_mask.device, dtype=attention_mask.dtype)), dim=1) 293 | # print("attention_mask: ", attention_mask[0], "\nattention_mask_new: ", attention_mask_new[0]) 294 | word_pred_score = pred_score*attention_mask_new # seperate the cls_token (bs, n_token-1) 295 | # print("word_pred_score: ", word_pred_score[0]) 296 | topk_indicator = self.topk_selector(word_pred_score) # (bs, k, n_token-1)) 297 | 298 | # cls token as cls token 299 | cls_x_feature = x[torch.arange(x.shape[0]), input_ids.argmax(dim=-1)].unsqueeze(1) # cls_token, shape here is (bs, 1, hid_dim) 300 | 301 | selected_patch_feature = torch.einsum("bkl,bld->bkd", topk_indicator, x) 302 | 303 | output = torch.cat((cls_x_feature, selected_patch_feature), dim=1) # shape here is (bs, topkPatches, hid_dim) 304 | 305 | return output -------------------------------------------------------------------------------- /modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 
52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
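    (A 404 ClientError from boto3, for instance, is re-raised as FileNotFoundError so that
    callers can treat missing S3 objects like missing local files.)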
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
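        # (NamedTemporaryFile is deleted as soon as this with-block exits, which is why the
        #  flush / seek / copy into the cache path below all happen inside it.)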
195 |         with tempfile.NamedTemporaryFile() as temp_file:
196 |             logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
197 | 
198 |             # GET file object
199 |             if url.startswith("s3://"):
200 |                 s3_get(url, temp_file)
201 |             else:
202 |                 http_get(url, temp_file)
203 | 
204 |             # we are copying the file before closing it, so flush to avoid truncation
205 |             temp_file.flush()
206 |             # shutil.copyfileobj() starts at the current position, so go to the start
207 |             temp_file.seek(0)
208 | 
209 |             logger.info("copying %s to cache at %s", temp_file.name, cache_path)
210 |             with open(cache_path, 'wb') as cache_file:
211 |                 shutil.copyfileobj(temp_file, cache_file)
212 | 
213 |             logger.info("creating metadata file for %s", cache_path)
214 |             meta = {'url': url, 'etag': etag}
215 |             meta_path = cache_path + '.json'
216 |             with open(meta_path, 'w') as meta_file:
217 |                 json.dump(meta, meta_file)
218 | 
219 |             logger.info("removing temp file %s", temp_file.name)
220 | 
221 |     return cache_path
222 | 
223 | 
224 | def read_set_from_file(filename: str) -> Set[str]:
225 |     '''
226 |     Extract a de-duped collection (set) of text from a file.
227 |     Expected file format is one item per line.
228 |     '''
229 |     collection = set()
230 |     with open(filename, 'r', encoding='utf-8') as file_:
231 |         for line in file_:
232 |             collection.add(line.rstrip())
233 |     return collection
234 | 
235 | 
236 | def get_file_extension(path: str, dot=True, lower: bool = True):
237 |     ext = os.path.splitext(path)[1]
238 |     ext = ext if dot else ext[1:]
239 |     return ext.lower() if lower else ext
240 | 
--------------------------------------------------------------------------------
/modules/modeling_xpool.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from: https://github.com/layer6ai-labs/xpool
3 | """
4 | import torch
5 | from torch import nn
6 | 
7 | from modules.transformer_xpool import Transformer
8 | 
9 | from modules.until_module import AllGather
10 | allgather = AllGather.apply
11 | 
12 | class XPool(nn.Module):
13 |     def __init__(self, task_config=None):
14 |         super(XPool, self).__init__()
15 |         self.task_config = task_config  # only world_size and rank are used from task_config
16 |         self.huggingface = True
17 |         self.max_frames = 12
18 | 
19 |         if self.huggingface:
20 |             from transformers import CLIPModel
21 |             self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
22 | 
23 |         self.pool_frames = Transformer()
24 | 
25 | 
26 |     def get_origin_text_and_video_features(self, input_ids, attention_mask, video):
27 |         batch_size = video.shape[0]
28 |         text_data = {'input_ids': input_ids.squeeze(), 'attention_mask': attention_mask.squeeze()}
29 |         video_data = video
30 |         video_data = video_data.reshape(-1, 3, 224, 224)
31 | 
32 |         if self.huggingface:
33 |             text_features = self.clip.get_text_features(**text_data)
34 |             video_features = self.clip.get_image_features(video_data.float())
35 |         else:
36 |             text_features = self.clip.encode_text(text_data)
37 |             video_features = self.clip.encode_image(video_data)
38 | 
39 |         video_features = video_features.reshape(batch_size, self.max_frames, -1)
40 | 
41 |         return text_features, video_features
42 | 
43 | 
44 |     def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None, return_fine=False):
45 |         text_features, video_features = self.get_origin_text_and_video_features(input_ids, attention_mask, video)
46 | 
47 |         if self.training:
48 |             text_features = allgather(text_features, self.task_config)
49 |             video_features = allgather(video_features, self.task_config)
50 |
torch.distributed.barrier() 51 | else: 52 | text_features = allgather(text_features, self.task_config) 53 | video_features = allgather(video_features, self.task_config) 54 | torch.distributed.barrier() 55 | 56 | video_features_pooled, frame_attention_weights = self.pool_frames(text_features, video_features) 57 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 58 | video_features_pooled = video_features_pooled / video_features_pooled.norm(dim=-1, keepdim=True) 59 | 60 | video_features_pooled = video_features_pooled.permute(1,2,0) 61 | text_features = text_features.unsqueeze(1) 62 | 63 | logit_scale = self.clip.logit_scale.exp() 64 | sims = logit_scale * torch.bmm(text_features, video_features_pooled).squeeze(1) 65 | 66 | if return_fine: 67 | return sims,frame_attention_weights 68 | else: 69 | return sims 70 | 71 | 72 | -------------------------------------------------------------------------------- /modules/module_clip_ts2net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | """ 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | 7 | import hashlib 8 | import os 9 | import urllib 10 | import warnings 11 | from tqdm import tqdm 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 23 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 25 | } 26 | _PT_NAME = { 27 | "RN50": "RN50.pt", 28 | "RN101": "RN101.pt", 29 | "RN50x4": "RN50x4.pt", 30 | "RN50x16": "RN50x16.pt", 31 | "ViT-B/32": "ViT-B-32.pt", 32 | "ViT-B/16": "ViT-B-16.pt", 33 | } 34 | 35 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 36 | os.makedirs(root, exist_ok=True) 37 | filename = os.path.basename(url) 38 | 39 | expected_sha256 = url.split("/")[-2] 40 | download_target = os.path.join(root, filename) 41 | 42 | if os.path.exists(download_target) and not os.path.isfile(download_target): 43 | raise RuntimeError(f"{download_target} exists and is not a regular file") 44 | 45 | if os.path.isfile(download_target): 46 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 47 | return download_target 48 | else: 49 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 50 | 51 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 52 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 53 | while True: 54 | buffer = source.read(8192) 55 | if not buffer: 56 | break 57 | 58 | output.write(buffer) 59 | loop.update(len(buffer)) 60 | 61 | if 
hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 62 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 63 | 64 | return download_target 65 | 66 | def available_models(): 67 | """Returns the names of available CLIP models""" 68 | return list(_MODELS.keys()) 69 | 70 | # ============================= 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1): 76 | super().__init__() 77 | 78 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 79 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | 82 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | 85 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 86 | 87 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 88 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 89 | 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = None 92 | self.stride = stride 93 | 94 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 95 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 96 | self.downsample = nn.Sequential(OrderedDict([ 97 | ("-1", nn.AvgPool2d(stride)), 98 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 99 | ("1", nn.BatchNorm2d(planes * self.expansion)) 100 | ])) 101 | 102 | def forward(self, x: torch.Tensor): 103 | identity = x 104 | 105 | out = self.relu(self.bn1(self.conv1(x))) 106 | out = self.relu(self.bn2(self.conv2(out))) 107 | out = self.avgpool(out) 108 | out = self.bn3(self.conv3(out)) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out += identity 114 | out = self.relu(out) 115 | return out 116 | 117 | 118 | class AttentionPool2d(nn.Module): 119 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 120 | super().__init__() 121 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 122 | self.k_proj = nn.Linear(embed_dim, embed_dim) 123 | self.q_proj = nn.Linear(embed_dim, embed_dim) 124 | self.v_proj = nn.Linear(embed_dim, embed_dim) 125 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 126 | self.num_heads = num_heads 127 | 128 | def forward(self, x): 129 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 130 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 131 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 132 | x, _ = F.multi_head_attention_forward( 133 | query=x, key=x, value=x, 134 | embed_dim_to_check=x.shape[-1], 135 | num_heads=self.num_heads, 136 | q_proj_weight=self.q_proj.weight, 137 | k_proj_weight=self.k_proj.weight, 138 | v_proj_weight=self.v_proj.weight, 139 | in_proj_weight=None, 140 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 141 | bias_k=None, 142 | bias_v=None, 143 | add_zero_attn=False, 144 | dropout_p=0, 145 | out_proj_weight=self.c_proj.weight, 146 | out_proj_bias=self.c_proj.bias, 147 | use_separate_proj_weight=True, 148 | training=self.training, 149 | need_weights=False 150 | ) 151 | 152 | return x[0] 153 | 154 | 155 | class ModifiedResNet(nn.Module): 156 | """ 157 | A ResNet class that is similar to 
torchvision's but contains the following changes: 158 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 159 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 160 | - The final pooling layer is a QKV attention instead of an average pool 161 | """ 162 | 163 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 164 | super().__init__() 165 | self.output_dim = output_dim 166 | self.input_resolution = input_resolution 167 | 168 | # the 3-layer stem 169 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 170 | self.bn1 = nn.BatchNorm2d(width // 2) 171 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 172 | self.bn2 = nn.BatchNorm2d(width // 2) 173 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 174 | self.bn3 = nn.BatchNorm2d(width) 175 | self.avgpool = nn.AvgPool2d(2) 176 | self.relu = nn.ReLU(inplace=True) 177 | 178 | # residual layers 179 | self._inplanes = width # this is a *mutable* variable used during construction 180 | self.layer1 = self._make_layer(width, layers[0]) 181 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 182 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 183 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 184 | 185 | embed_dim = width * 32 # the ResNet feature dimension 186 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 187 | 188 | def _make_layer(self, planes, blocks, stride=1): 189 | layers = [Bottleneck(self._inplanes, planes, stride)] 190 | 191 | self._inplanes = planes * Bottleneck.expansion 192 | for _ in range(1, blocks): 193 | layers.append(Bottleneck(self._inplanes, planes)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | def stem(x): 199 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 200 | x = self.relu(bn(conv(x))) 201 | x = self.avgpool(x) 202 | return x 203 | 204 | x = x.type(self.conv1.weight.dtype) 205 | x = stem(x) 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | x = self.attnpool(x) 211 | 212 | return x 213 | 214 | 215 | class LayerNorm(nn.LayerNorm): 216 | """Subclass torch's LayerNorm to handle fp16.""" 217 | 218 | def forward(self, x: torch.Tensor): 219 | orig_type = x.dtype 220 | ret = super().forward(x.type(torch.float32)) 221 | return ret.type(orig_type) 222 | 223 | 224 | class QuickGELU(nn.Module): 225 | def forward(self, x: torch.Tensor): 226 | return x * torch.sigmoid(1.702 * x) 227 | 228 | 229 | class ResidualAttentionBlock(nn.Module): 230 | def __init__(self, d_model: int, n_head: int, attn_mask=None): 231 | super().__init__() 232 | 233 | self.attn = nn.MultiheadAttention(d_model, n_head) 234 | self.ln_1 = LayerNorm(d_model) 235 | self.mlp = nn.Sequential(OrderedDict([ 236 | ("c_fc", nn.Linear(d_model, d_model * 4)), 237 | ("gelu", QuickGELU()), 238 | ("c_proj", nn.Linear(d_model * 4, d_model)) 239 | ])) 240 | self.ln_2 = LayerNorm(d_model) 241 | self.attn_mask = attn_mask 242 | 243 | def attention(self, x: torch.Tensor): 244 | attn_mask_ = self.attn_mask 245 | if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): 246 | attn_mask_ = self.attn_mask(x.size(0)) # LND 247 | 248 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else 
None 249 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 250 | 251 | def forward(self, x_tuple:tuple): 252 | x, video_frame = x_tuple 253 | x = x + self.attention(self.ln_1(x)) 254 | x = x + self.mlp(self.ln_2(x)) 255 | return (x, video_frame) 256 | 257 | 258 | class Transformer(nn.Module): 259 | def __init__(self, width: int, layers: int, heads: int, attn_mask = None): 260 | super().__init__() 261 | self.width = width 262 | self.layers = layers 263 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 264 | 265 | def forward(self, x: torch.Tensor, video_frame=-1): 266 | return self.resblocks((x, video_frame))[0] 267 | 268 | 269 | class VisualTransformer(nn.Module): 270 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, 271 | linear_patch: str = '2d',): 272 | super().__init__() 273 | self.input_resolution = input_resolution 274 | self.output_dim = output_dim 275 | 276 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 277 | 278 | scale = width ** -0.5 279 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 280 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 281 | self.ln_pre = LayerNorm(width) 282 | 283 | self.transformer = Transformer(width, layers, heads) 284 | 285 | self.ln_post = LayerNorm(width) 286 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 287 | 288 | # For 3D 289 | assert linear_patch in ['2d', '3d'] 290 | self.linear_patch = linear_patch 291 | if self.linear_patch == '3d': 292 | self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size), 293 | stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False) 294 | 295 | def forward(self, x: torch.Tensor, video_frame=-1): 296 | 297 | if self.linear_patch == '3d': 298 | assert video_frame != -1 299 | x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1]) 300 | x_3d = x_3d.permute(0, 2, 1, 3, 4) 301 | x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] 302 | x_3d = x_3d.permute(0, 2, 1, 3, 4) # shape = [*, frame, width, grid, grid] 303 | x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] 304 | else: 305 | x = self.conv1(x) # shape = [*, width, grid, grid] 306 | 307 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 308 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 309 | # print("xshape:", x.shape) 310 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 311 | x = x + self.positional_embedding.to(x.dtype) 312 | x = self.ln_pre(x) 313 | 314 | x = x.permute(1, 0, 2) # NLD -> LND 315 | x = self.transformer(x, video_frame=video_frame) 316 | x = x.permute(1, 0, 2) # LND -> NLD 317 | 318 | # Move the three lines below to `encode_image` for entire hidden sequence 319 | # x = self.ln_post(x[:, 0, :]) 320 | # if self.proj is not None: 321 | # x = x @ self.proj 322 | 323 | return x 324 | 325 | 326 | class CLIP(nn.Module): 327 | def __init__(self, 328 | embed_dim: int, 329 | # vision 330 | image_resolution: int, 331 | vision_layers: Union[Tuple[int, int, int, int], int], 332 | vision_width: int, 333 | vision_patch_size: int, 334 | # text 335 | 
context_length: int, 336 | vocab_size: int, 337 | transformer_width: int, 338 | transformer_heads: int, 339 | transformer_layers: int, 340 | # vision linear of patch 341 | linear_patch: str = '2d', 342 | ): 343 | super().__init__() 344 | 345 | self.context_length = context_length 346 | 347 | if isinstance(vision_layers, (tuple, list)): 348 | vision_heads = vision_width * 32 // 64 349 | self.visual = ModifiedResNet( 350 | layers=vision_layers, 351 | output_dim=embed_dim, 352 | heads=vision_heads, 353 | input_resolution=image_resolution, 354 | width=vision_width 355 | ) 356 | else: 357 | vision_heads = vision_width // 64 358 | self.visual = VisualTransformer( 359 | input_resolution=image_resolution, 360 | patch_size=vision_patch_size, 361 | width=vision_width, 362 | layers=vision_layers, 363 | heads=vision_heads, 364 | output_dim=embed_dim, 365 | linear_patch=linear_patch 366 | ) 367 | 368 | self.transformer = Transformer( 369 | width=transformer_width, 370 | layers=transformer_layers, 371 | heads=transformer_heads, 372 | attn_mask=self.build_attention_mask 373 | ) 374 | 375 | self.vocab_size = vocab_size 376 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 377 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 378 | self.ln_final = LayerNorm(transformer_width) 379 | 380 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 381 | self.logit_scale = nn.Parameter(torch.ones([])) 382 | 383 | self.initialize_parameters() 384 | 385 | def initialize_parameters(self): 386 | nn.init.normal_(self.token_embedding.weight, std=0.02) 387 | nn.init.normal_(self.positional_embedding, std=0.01) 388 | 389 | if isinstance(self.visual, ModifiedResNet): 390 | if self.visual.attnpool is not None: 391 | std = self.visual.attnpool.c_proj.in_features ** -0.5 392 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 393 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 394 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 395 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 396 | 397 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 398 | for name, param in resnet_block.named_parameters(): 399 | if name.endswith("bn3.weight"): 400 | nn.init.zeros_(param) 401 | 402 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 403 | attn_std = self.transformer.width ** -0.5 404 | fc_std = (2 * self.transformer.width) ** -0.5 405 | for block in self.transformer.resblocks: 406 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 407 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 408 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 409 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 410 | 411 | if self.text_projection is not None: 412 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 413 | 414 | @staticmethod 415 | def get_config(pretrained_clip_name="ViT-B/32"): 416 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt") 417 | if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME: 418 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name]) 419 | 420 | if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path): 421 | pass 422 | else: 423 | if pretrained_clip_name in _MODELS: 424 | model_path = 
_download(_MODELS[pretrained_clip_name]) 425 | elif os.path.isfile(pretrained_clip_name): 426 | model_path = pretrained_clip_name 427 | else: 428 | raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}") 429 | 430 | try: 431 | # loading JIT archive 432 | model = torch.jit.load(model_path, map_location="cpu").eval() 433 | state_dict = model.state_dict() 434 | except RuntimeError: 435 | state_dict = torch.load(model_path, map_location="cpu") 436 | 437 | return state_dict 438 | 439 | def build_attention_mask(self, context_length): 440 | # lazily create causal attention mask, with full attention between the vision tokens 441 | # pytorch uses additive attention mask; fill with -inf 442 | mask = torch.zeros(context_length, context_length) 443 | mask.fill_(float("-inf")) 444 | mask.triu_(1) # zero out the lower diagonal 445 | return mask 446 | 447 | @property 448 | def dtype(self): 449 | return self.visual.conv1.weight.dtype 450 | 451 | def encode_image(self, image, return_hidden=False, video_frame=-1): 452 | # with torch.no_grad(): 453 | hidden = self.visual(image.type(self.dtype), video_frame=video_frame) 454 | hidden = self.visual.ln_post(hidden) @ self.visual.proj 455 | 456 | return hidden 457 | 458 | def encode_text(self, text, return_hidden=False): 459 | # with torch.no_grad(): 460 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 461 | 462 | pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) 463 | x = x + pos_emd 464 | 465 | x = x.permute(1, 0, 2) # NLD -> LND 466 | x = self.transformer(x) 467 | x = x.permute(1, 0, 2) # LND -> NLD 468 | 469 | hidden = self.ln_final(x).type(self.dtype) @ self.text_projection 470 | # x.shape = [batch_size, n_ctx, transformer.width] 471 | # take features from the eot embedding (eot_token is the highest number in each sequence) 472 | x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] 473 | return x 474 | 475 | 476 | def forward(self, image, text): 477 | image_features = self.encode_image(image) 478 | text_features = self.encode_text(text) 479 | 480 | # normalized features 481 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 482 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 483 | 484 | # cosine similarity as logits 485 | logit_scale = self.logit_scale.exp() 486 | logits_per_image = logit_scale * image_features @ text_features.t() 487 | logits_per_text = logit_scale * text_features @ image_features.t() 488 | 489 | # shape = [global_batch_size, global_batch_size] 490 | return logits_per_image, logits_per_text 491 | 492 | 493 | def convert_weights(model: nn.Module): 494 | """Convert applicable model parameters to fp16""" 495 | 496 | def _convert_weights_to_fp16(l): 497 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 498 | l.weight.data = l.weight.data.half() 499 | if l.bias is not None: 500 | l.bias.data = l.bias.data.half() 501 | 502 | if isinstance(l, nn.MultiheadAttention): 503 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 504 | tensor = getattr(l, attr) 505 | if tensor is not None: 506 | tensor.data = tensor.data.half() 507 | 508 | for name in ["text_projection", "proj"]: 509 | if hasattr(l, name): 510 | attr = getattr(l, name) 511 | if attr is not None: 512 | attr.data = attr.data.half() 513 | 514 | model.apply(_convert_weights_to_fp16) 515 | 516 | 517 | def build_model(state_dict: dict): 518 | vit = "visual.proj" in 
state_dict 519 | 520 | if vit: 521 | vision_width = state_dict["visual.conv1.weight"].shape[0] 522 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 523 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 524 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 525 | image_resolution = vision_patch_size * grid_size 526 | else: 527 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 528 | vision_layers = tuple(counts) 529 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 530 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 531 | vision_patch_size = None 532 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 533 | image_resolution = output_width * 32 534 | 535 | embed_dim = state_dict["text_projection"].shape[1] 536 | context_length = state_dict["positional_embedding"].shape[0] 537 | vocab_size = state_dict["token_embedding.weight"].shape[0] 538 | transformer_width = state_dict["ln_final.weight"].shape[0] 539 | transformer_heads = transformer_width // 64 540 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 541 | 542 | model = CLIP( 543 | embed_dim, 544 | image_resolution, vision_layers, vision_width, vision_patch_size, 545 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 546 | ) 547 | 548 | for key in ["input_resolution", "context_length", "vocab_size"]: 549 | if key in state_dict: 550 | del state_dict[key] 551 | 552 | convert_weights(model) 553 | model.load_state_dict(state_dict) 554 | return model.eval() 555 | -------------------------------------------------------------------------------- /modules/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | import json 8 | import math 9 | import logging 10 | import tarfile 11 | import tempfile 12 | import shutil 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from .file_utils import cached_path 18 | from .until_config import PretrainedConfig 19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 20 | from collections import OrderedDict 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 25 | CONFIG_NAME = 'cross_config.json' 26 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 27 | 28 | 29 | class CrossConfig(PretrainedConfig): 30 | """Configuration class to store the configuration of a `CrossModel`. 31 | """ 32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 33 | config_name = CONFIG_NAME 34 | weights_name = WEIGHTS_NAME 35 | def __init__(self, 36 | vocab_size_or_config_json_file, 37 | hidden_size=768, 38 | num_hidden_layers=12, 39 | num_attention_heads=12, 40 | intermediate_size=3072, 41 | hidden_act="gelu", 42 | hidden_dropout_prob=0.1, 43 | attention_probs_dropout_prob=0.1, 44 | max_position_embeddings=512, 45 | type_vocab_size=2, 46 | initializer_range=0.02): 47 | """Constructs CrossConfig. 48 | 49 | Args: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. 
51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `CrossModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 70 | """ 71 | if isinstance(vocab_size_or_config_json_file, str): 72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 73 | json_config = json.loads(reader.read()) 74 | for key, value in json_config.items(): 75 | self.__dict__[key] = value 76 | elif isinstance(vocab_size_or_config_json_file, int): 77 | self.vocab_size = vocab_size_or_config_json_file 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | else: 89 | raise ValueError("First argument must be either a vocabulary size (int)" 90 | "or the path to a pretrained model config file (str)") 91 | 92 | class QuickGELU(nn.Module): 93 | def forward(self, x: torch.Tensor): 94 | return x * torch.sigmoid(1.702 * x) 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, d_model: int, n_head: int): 98 | super().__init__() 99 | 100 | self.attn = nn.MultiheadAttention(d_model, n_head) 101 | self.ln_1 = LayerNorm(d_model) 102 | self.mlp = nn.Sequential(OrderedDict([ 103 | ("c_fc", nn.Linear(d_model, d_model * 4)), 104 | ("gelu", QuickGELU()), 105 | ("c_proj", nn.Linear(d_model * 4, d_model)) 106 | ])) 107 | self.ln_2 = LayerNorm(d_model) 108 | self.n_head = n_head 109 | 110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 111 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) 112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 113 | 114 | def forward(self, para_tuple: tuple): 115 | # x: torch.Tensor, attn_mask: torch.Tensor 116 | # print(para_tuple) 117 | x, attn_mask = para_tuple 118 | x = x + self.attention(self.ln_1(x), attn_mask) 119 | x = x + self.mlp(self.ln_2(x)) 120 | return (x, attn_mask) 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, width: int, layers: int, heads: int): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 128 | 129 | def 
forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 130 | return self.resblocks((x, attn_mask))[0] 131 | 132 | class CrossEmbeddings(nn.Module): 133 | """Construct the embeddings from word, position and token_type embeddings. 134 | """ 135 | def __init__(self, config): 136 | super(CrossEmbeddings, self).__init__() 137 | 138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, concat_embeddings, concat_type=None): 144 | 145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 146 | # if concat_type is None: 147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 148 | 149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 151 | 152 | # token_type_embeddings = self.token_type_embeddings(concat_type) 153 | position_embeddings = self.position_embeddings(position_ids) 154 | 155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 156 | # embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | class CrossPooler(nn.Module): 161 | def __init__(self, config): 162 | super(CrossPooler, self).__init__() 163 | self.ln_pool = LayerNorm(config.hidden_size) 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.activation = QuickGELU() 166 | 167 | def forward(self, hidden_states, hidden_mask): 168 | # We "pool" the model by simply taking the hidden state corresponding 169 | # to the first token. 
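        # Shape-wise: hidden_states is (batch, seq_len, hidden_size) and the pooled
        # output is (batch, hidden_size), taken from position 0 after ln_pool and a
        # dense + QuickGELU projection. Note that the `hidden_mask` argument is
        # accepted for interface compatibility but is not used in the computation below.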
170 | hidden_states = self.ln_pool(hidden_states) 171 | pooled_output = hidden_states[:, 0] 172 | pooled_output = self.dense(pooled_output) 173 | pooled_output = self.activation(pooled_output) 174 | return pooled_output 175 | 176 | class CrossModel(PreTrainedModel): 177 | 178 | def initialize_parameters(self): 179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 180 | attn_std = self.transformer.width ** -0.5 181 | fc_std = (2 * self.transformer.width) ** -0.5 182 | for block in self.transformer.resblocks: 183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 187 | 188 | def __init__(self, config): 189 | super(CrossModel, self).__init__(config) 190 | 191 | self.embeddings = CrossEmbeddings(config) 192 | 193 | transformer_width = config.hidden_size 194 | transformer_layers = config.num_hidden_layers 195 | transformer_heads = config.num_attention_heads 196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,) 197 | self.pooler = CrossPooler(config) 198 | self.apply(self.init_weights) 199 | 200 | def build_attention_mask(self, attention_mask): 201 | extended_attention_mask = attention_mask.unsqueeze(1) 202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 205 | return extended_attention_mask 206 | 207 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True): 208 | 209 | if attention_mask is None: 210 | attention_mask = torch.ones(concat_input.size(0), concat_input.size(1)) 211 | if concat_type is None: 212 | concat_type = torch.zeros_like(attention_mask) 213 | 214 | extended_attention_mask = self.build_attention_mask(attention_mask) 215 | 216 | embedding_output = self.embeddings(concat_input, concat_type) 217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 218 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 220 | 221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 222 | 223 | return embedding_output, pooled_output 224 | -------------------------------------------------------------------------------- /modules/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
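# A minimal usage sketch for the BertAdam optimizer defined below, assuming a model
# exposing .parameters() and a known total number of optimization steps `t_total`
# (both `model` and `compute_loss` are hypothetical names, not part of this repo):
#
#   optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1,
#                        t_total=t_total, schedule='warmup_linear',
#                        weight_decay=0.01, max_grad_norm=1.0)
#   for batch in batches:
#       loss = compute_loss(batch)   # hypothetical forward pass + loss
#       loss.backward()
#       optimizer.step()             # applies the scheduled learning rate and weight decay
#       optimizer.zero_grad()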
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 
105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /modules/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | import json 9 | import numpy as np 10 | import torch 11 | 12 | @lru_cache() 13 | def default_bpe(): 14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings). 43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class SimpleTokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | 79 | # self.saliencyencoder = self.construct_saliencyencoder() 80 | 81 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 82 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 83 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 84 | 85 | self.vocab = self.encoder 86 | 87 | 88 | def bpe(self, token): 89 | if token in self.cache: 90 | return self.cache[token] 91 | word = tuple(token[:-1]) + ( token[-1] + '',) 92 | pairs = get_pairs(word) 93 | 94 | if not pairs: 95 | return token+'' 96 | 97 | while True: 98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 99 | if bigram not in self.bpe_ranks: 100 | break 101 | first, second = bigram 102 | new_word = [] 103 | i = 0 104 | while i < len(word): 105 | try: 106 | j = word.index(first, i) 107 | new_word.extend(word[i:j]) 108 | i = j 109 | except: 110 | new_word.extend(word[i:]) 111 | break 112 | 113 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 114 | new_word.append(first+second) 115 | i += 2 116 | else: 117 | new_word.append(word[i]) 118 | i += 1 119 | new_word = tuple(new_word) 120 | word = new_word 121 | if len(word) == 1: 122 | break 123 | else: 124 | pairs = get_pairs(word) 125 | word = ' '.join(word) 126 | self.cache[token] = word 127 | return word 128 | 129 | def encode(self, text): 130 | bpe_tokens = [] 131 | text = whitespace_clean(basic_clean(text)).lower() 132 | for token in re.findall(self.pat, text): 133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 135 | return bpe_tokens 136 | 137 | def decode(self, tokens): 138 | text = ''.join([self.decoder[token] for token in tokens]) 139 | text = 
bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 140 | return text 141 | 142 | def tokenize(self, text): 143 | tokens = [] 144 | text = whitespace_clean(basic_clean(text)).lower() 145 | for token in re.findall(self.pat, text): 146 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 147 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 148 | return tokens 149 | 150 | def convert_tokens_to_ids(self, tokens): 151 | return [self.encoder[bpe_token] for bpe_token in tokens] 152 | 153 | # def convert_tokens_to_saliencyids(self, tokens): 154 | # return [self.saliencyencoder[bpe_token] for bpe_token in tokens] 155 | 156 | def convert_ids_to_tokens(self, ids): 157 | return [self.decoder[bpe_id] for bpe_id in ids] 158 | 159 | 160 | class Tokenizer(object): 161 | def __init__( 162 | self, 163 | max_words=30, 164 | ): 165 | 166 | self.max_words = max_words 167 | self.tokenizer = SimpleTokenizer() 168 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 169 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 170 | 171 | def _get_text(self, sentence_ids, sentences): 172 | n_caption = len(sentence_ids) 173 | 174 | k = n_caption 175 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 176 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 177 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 178 | 179 | for i, sentence_id in enumerate(sentence_ids): 180 | words = self.tokenizer.tokenize(sentences[i]) 181 | 182 | # ########################## 183 | # # add a cls in the mid # 184 | # ########################## 185 | # mid = len(words) // 2 186 | # words.insert(mid, self.SPECIAL_TOKEN["SEP_TOKEN"]) 187 | # ########################## 188 | # # add a cls in the mid # 189 | # ########################## 190 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 191 | total_length_with_CLS = self.max_words - 1 192 | if len(words) > total_length_with_CLS: 193 | words = words[:total_length_with_CLS] 194 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 195 | 196 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 197 | input_mask = [1] * len(input_ids) 198 | segment_ids = [0] * len(input_ids) 199 | while len(input_ids) < self.max_words: 200 | input_ids.append(0) 201 | input_mask.append(0) 202 | segment_ids.append(0) 203 | assert len(input_ids) == self.max_words 204 | assert len(input_mask) == self.max_words 205 | assert len(segment_ids) == self.max_words 206 | 207 | pairs_text[i] = np.array(input_ids) 208 | pairs_mask[i] = np.array(input_mask) 209 | pairs_segment[i] = np.array(segment_ids) 210 | 211 | pairs_text, pairs_mask, pairs_segment = torch.tensor(pairs_text), torch.tensor(pairs_mask), torch.tensor(pairs_segment) 212 | return pairs_text, pairs_mask, pairs_segment -------------------------------------------------------------------------------- /modules/transformer_xpool.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class MultiHeadedAttention(nn.Module): 7 | def __init__(self): 8 | super(MultiHeadedAttention, self).__init__() 9 | self.embed_dim = 512 10 | self.num_heads = 1 11 | assert self.embed_dim % self.num_heads == 0 12 | self.head_dim = self.embed_dim // self.num_heads 13 | 14 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) 15 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) 16 
| self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) 17 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) 18 | 19 | 20 | def forward(self, text_embeds, video_embeds): 21 | """ 22 | Input 23 | text_embeds: num_texts x embed_dim 24 | video_embeds: num_vids x num_frames x embed_dim 25 | Output 26 | o: num_vids x num_texts x embed_dim 27 | """ 28 | num_texts, _ = text_embeds.shape 29 | # num_texts x embed_dim 30 | q = self.q_proj(text_embeds) 31 | q = q.reshape(num_texts, self.num_heads, self.head_dim) 32 | # num_heads x head_dim x num_texts 33 | q = q.permute(1,2,0) 34 | 35 | num_vids, num_frames, _ = video_embeds.shape 36 | # num_vids x num_frames x embed_dim 37 | k = self.k_proj(video_embeds) 38 | k = k.reshape(num_vids, num_frames, self.num_heads, self.head_dim) 39 | # num_vids x num_heads x num_frames x head_dim 40 | k = k.permute(0,2,1,3) 41 | 42 | # num_vids x num_frames x embed_dim 43 | v = self.v_proj(video_embeds) 44 | v = v.reshape(num_vids, num_frames, self.num_heads, self.head_dim) 45 | # num_vids x num_heads x head_dim x num_frames 46 | v = v.permute(0,2,3,1) 47 | 48 | # num_vids x num_heads x num_frames x num_texts 49 | attention_logits = k @ q 50 | attention_logits = attention_logits / math.sqrt(self.head_dim) 51 | attention_weights = F.softmax(attention_logits, dim=2) 52 | 53 | # num_vids x num_heads x head_dim x num_texts 54 | attention = v @ attention_weights 55 | # num_vids x num_texts x num_heads x head_dim 56 | attention = attention.permute(0,3,1,2) 57 | attention = attention.reshape(num_vids, num_texts, self.embed_dim) 58 | 59 | # num_vids x num_texts x embed_dim 60 | o = self.out_proj(attention) 61 | # num_vids x num_texts x num_frames 62 | attention_weights_out = attention_weights.squeeze().permute(0,2,1) 63 | # num_vids x num_frames 64 | attention_weights_out = torch.diagonal(attention_weights_out).T 65 | 66 | return o, attention_weights_out 67 | 68 | 69 | class Transformer(nn.Module): 70 | def __init__(self): 71 | super(Transformer, self).__init__() 72 | self.embed_dim = 512 73 | dropout = 0.3 74 | 75 | self.cross_attn = MultiHeadedAttention() 76 | 77 | self.linear_proj = nn.Linear(self.embed_dim, self.embed_dim) 78 | 79 | self.layer_norm1 = nn.LayerNorm(self.embed_dim) 80 | self.layer_norm2 = nn.LayerNorm(self.embed_dim) 81 | self.layer_norm3 = nn.LayerNorm(self.embed_dim) 82 | self.dropout = nn.Dropout(dropout) 83 | 84 | self._init_parameters() 85 | 86 | 87 | def _init_parameters(self): 88 | for name, param in self.named_parameters(): 89 | if 'linear' in name or 'proj' in name: 90 | if 'weight' in name: 91 | nn.init.eye_(param) 92 | elif 'bias' in name: 93 | param.data.fill_(0.) 
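    # With the identity/zero initialization above, the q/k/v/out projections and
    # linear_proj start as (approximate) identity maps, so early in training the
    # pooled output behaves roughly like a softmax-weighted average of the
    # layer-normalized frame embeddings. A minimal shape sketch with hypothetical
    # sizes, assuming 512-d features and a paired text-video batch:
    #
    #   pool = Transformer()
    #   text_embeds = torch.randn(8, 512)        # num_texts x embed_dim
    #   video_embeds = torch.randn(8, 12, 512)   # num_vids x num_frames x embed_dim
    #   out, attn = pool(text_embeds, video_embeds)
    #   # out:  8 x 8 x 512  (num_vids x num_texts x embed_dim)
    #   # attn: 8 x 12       (frame attention weights for each paired text-video)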
94 | 95 | 96 | def forward(self, text_embeds, video_embeds): 97 | """ 98 | Input 99 | text_embeds: num_texts x embed_dim 100 | video_embeds: num_vids x num_frames x embed_dim 101 | Output 102 | out: num_vids x num_texts x embed_dim 103 | """ 104 | text_embeds = self.layer_norm1(text_embeds) 105 | video_embeds = self.layer_norm1(video_embeds) 106 | 107 | # num_vids x num_texts x embed_dim 108 | attn_out, attention_weights_out = self.cross_attn(text_embeds, video_embeds) 109 | attn_out = self.layer_norm2(attn_out) 110 | 111 | linear_out = self.linear_proj(attn_out) 112 | out = attn_out + self.dropout(linear_out) 113 | out = self.layer_norm3(out) 114 | 115 | return out, attention_weights_out 116 | -------------------------------------------------------------------------------- /modules/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | # redirect to the cache, if necessary 50 | try: 51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 52 | except FileNotFoundError: 53 | if task_config is None or task_config.local_rank == 0: 54 | logger.error( 55 | "Model name '{}' was not found in model name list. 
" 56 | "We assumed '{}' was a path or url but couldn't find any file " 57 | "associated to this path or url.".format( 58 | pretrained_model_name, 59 | archive_file)) 60 | return None 61 | if resolved_archive_file == archive_file: 62 | if task_config is None or task_config.local_rank == 0: 63 | logger.info("loading archive file {}".format(archive_file)) 64 | else: 65 | if task_config is None or task_config.local_rank == 0: 66 | logger.info("loading archive file {} from cache at {}".format( 67 | archive_file, resolved_archive_file)) 68 | tempdir = None 69 | if os.path.isdir(resolved_archive_file): 70 | serialization_dir = resolved_archive_file 71 | else: 72 | # Extract archive to temp dir 73 | tempdir = tempfile.mkdtemp() 74 | if task_config is None or task_config.local_rank == 0: 75 | logger.info("extracting archive file {} to temp dir {}".format( 76 | resolved_archive_file, tempdir)) 77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 78 | archive.extractall(tempdir) 79 | serialization_dir = tempdir 80 | # Load config 81 | config_file = os.path.join(serialization_dir, cls.config_name) 82 | config = cls.from_json_file(config_file) 83 | config.type_vocab_size = type_vocab_size 84 | if task_config is None or task_config.local_rank == 0: 85 | logger.info("Model config {}".format(config)) 86 | 87 | if state_dict is None: 88 | weights_path = os.path.join(serialization_dir, cls.weights_name) 89 | if os.path.exists(weights_path): 90 | state_dict = torch.load(weights_path, map_location='cpu') 91 | else: 92 | if task_config is None or task_config.local_rank == 0: 93 | logger.info("Weight doesn't exsits. {}".format(weights_path)) 94 | 95 | if tempdir: 96 | # Clean up temp dir 97 | shutil.rmtree(tempdir) 98 | 99 | return config, state_dict 100 | 101 | @classmethod 102 | def from_dict(cls, json_object): 103 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 104 | config = cls(vocab_size_or_config_json_file=-1) 105 | for key, value in json_object.items(): 106 | config.__dict__[key] = value 107 | return config 108 | 109 | @classmethod 110 | def from_json_file(cls, json_file): 111 | """Constructs a `BertConfig` from a json file of parameters.""" 112 | with open(json_file, "r", encoding='utf-8') as reader: 113 | text = reader.read() 114 | return cls.from_dict(json.loads(text)) 115 | 116 | def __repr__(self): 117 | return str(self.to_json_string()) 118 | 119 | def to_dict(self): 120 | """Serializes this instance to a Python dictionary.""" 121 | output = copy.deepcopy(self.__dict__) 122 | return output 123 | 124 | def to_json_string(self): 125 | """Serializes this instance to a JSON string.""" 126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /modules/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | import logging 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | import math 24 | from modules.until_config import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def gelu(x): 29 | """Implementation of the gelu activation function. 30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | """ 33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 39 | 40 | class LayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | class PreTrainedModel(nn.Module): 56 | """ An abstract class to handle weights initialization and 57 | a simple interface for dowloading and loading pretrained models. 58 | """ 59 | def __init__(self, config, *inputs, **kwargs): 60 | super(PreTrainedModel, self).__init__() 61 | if not isinstance(config, PretrainedConfig): 62 | raise ValueError( 63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 64 | "To create a model from a Google pretrained model use " 65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 66 | self.__class__.__name__, self.__class__.__name__ 67 | )) 68 | self.config = config 69 | 70 | def init_weights(self, module): 71 | """ Initialize the weights. 
72 | """ 73 | if isinstance(module, (nn.Linear, nn.Embedding)): 74 | # Slightly different from the TF version which uses truncated_normal for initialization 75 | # cf https://github.com/pytorch/pytorch/pull/5617 76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 77 | elif isinstance(module, LayerNorm): 78 | if 'beta' in dir(module) and 'gamma' in dir(module): 79 | module.beta.data.zero_() 80 | module.gamma.data.fill_(1.0) 81 | else: 82 | module.bias.data.zero_() 83 | module.weight.data.fill_(1.0) 84 | if isinstance(module, nn.Linear) and module.bias is not None: 85 | module.bias.data.zero_() 86 | 87 | def resize_token_embeddings(self, new_num_tokens=None): 88 | raise NotImplementedError 89 | 90 | @classmethod 91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None): 92 | old_keys = [] 93 | new_keys = [] 94 | for key in state_dict.keys(): 95 | new_key = None 96 | if 'gamma' in key: 97 | new_key = key.replace('gamma', 'weight') 98 | if 'beta' in key: 99 | new_key = key.replace('beta', 'bias') 100 | if new_key: 101 | old_keys.append(key) 102 | new_keys.append(new_key) 103 | for old_key, new_key in zip(old_keys, new_keys): 104 | state_dict[new_key] = state_dict.pop(old_key) 105 | 106 | if prefix is not None: 107 | old_keys = [] 108 | new_keys = [] 109 | for key in state_dict.keys(): 110 | old_keys.append(key) 111 | new_keys.append(prefix + key) 112 | for old_key, new_key in zip(old_keys, new_keys): 113 | state_dict[new_key] = state_dict.pop(old_key) 114 | 115 | missing_keys = [] 116 | unexpected_keys = [] 117 | error_msgs = [] 118 | # copy state_dict so _load_from_state_dict can modify it 119 | metadata = getattr(state_dict, '_metadata', None) 120 | state_dict = state_dict.copy() 121 | if metadata is not None: 122 | state_dict._metadata = metadata 123 | 124 | def load(module, prefix=''): 125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 126 | module._load_from_state_dict( 127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 128 | for name, child in module._modules.items(): 129 | if child is not None: 130 | load(child, prefix + name + '.') 131 | 132 | load(model, prefix='') 133 | 134 | if prefix is None and (task_config is None or task_config.local_rank == 0): 135 | logger.info("-" * 20) 136 | if len(missing_keys) > 0: 137 | logger.info("Weights of {} not initialized from pretrained model: {}" 138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 139 | if len(unexpected_keys) > 0: 140 | logger.info("Weights from pretrained model not used in {}: {}" 141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 142 | if len(error_msgs) > 0: 143 | logger.error("Weights from pretrained model cause errors in {}: {}" 144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 145 | 146 | return model 147 | 148 | @property 149 | def dtype(self): 150 | """ 151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
152 | """ 153 | try: 154 | return next(self.parameters()).dtype 155 | except StopIteration: 156 | # For nn.DataParallel compatibility in PyTorch 1.5 157 | def find_tensor_attributes(module: nn.Module): 158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 159 | return tuples 160 | 161 | gen = self._named_members(get_members_fn=find_tensor_attributes) 162 | first_tuple = next(gen) 163 | return first_tuple[1].dtype 164 | 165 | @classmethod 166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs): 167 | """ 168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. 169 | Download and cache the pre-trained model file if needed. 170 | """ 171 | # Instantiate model. 172 | model = cls(config, *inputs, **kwargs) 173 | if state_dict is None: 174 | return model 175 | model = cls.init_preweight(model, state_dict) 176 | 177 | return model 178 | 179 | ################################## 180 | ###### LOSS FUNCTION ############# 181 | ################################## 182 | class CrossEn(nn.Module): 183 | def __init__(self,): 184 | super(CrossEn, self).__init__() 185 | 186 | def forward(self, sim_matrix): 187 | logpt = F.log_softmax(sim_matrix, dim=-1) 188 | logpt = torch.diag(logpt) 189 | nce_loss = -logpt 190 | # here means add a square in sim matrix to enlarge the sim score 191 | sim_loss = nce_loss.mean() 192 | return sim_loss 193 | 194 | def off_diagonal(x): 195 | # return a flattened view of the off-diagonal elements of a square matrix 196 | n, m = x.shape 197 | assert n == m 198 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 199 | 200 | class BTloss(nn.Module): 201 | def __init__(self, lambd: float=0.005): 202 | super(BTloss, self).__init__() 203 | self.lambd = lambd 204 | 205 | def forward(self, sim_matrix): 206 | on_diag = torch.diagonal(sim_matrix).add_(-1).pow_(2).sum() 207 | off_diag = off_diagonal(sim_matrix).pow_(2).sum() 208 | loss = on_diag + self.lambd * off_diag 209 | return loss 210 | 211 | class ClassifyCrossEn(nn.Module): 212 | def __init__(self): 213 | super(ClassifyCrossEn, self).__init__() 214 | self.loss = nn.CrossEntropyLoss() 215 | 216 | def forward(self, sim_matrix, label): 217 | ''' 218 | sim_matrix: (bs, num_classes) 219 | label: (bs, ) 220 | ''' 221 | nce_loss = self.loss(sim_matrix, label) 222 | sim_loss = nce_loss.mean() 223 | return sim_loss 224 | 225 | 226 | class MILNCELoss(nn.Module): 227 | def __init__(self, batch_size=1, n_pair=1,): 228 | super(MILNCELoss, self).__init__() 229 | self.batch_size = batch_size 230 | self.n_pair = n_pair 231 | torch_v = float(".".join(torch.__version__.split(".")[:2])) 232 | self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8 233 | 234 | def forward(self, sim_matrix): 235 | mm_mask = np.eye(self.batch_size) 236 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair))) 237 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device) 238 | 239 | from_text_matrix = sim_matrix + mm_mask * -1e12 240 | from_video_matrix = sim_matrix.transpose(1, 0) 241 | 242 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1) 243 | logpt = F.log_softmax(new_sim_matrix, dim=-1) 244 | 245 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1) 246 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12 247 | 248 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1) 249 | 250 | logpt_choice = torch.zeros_like(new_logpt) 251 | mark_ind = 
torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2) 252 | logpt_choice[mark_ind] = 1 253 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean() 254 | return sim_loss 255 | 256 | class MaxMarginRankingLoss(nn.Module): 257 | def __init__(self, 258 | margin=1.0, 259 | negative_weighting=False, 260 | batch_size=1, 261 | n_pair=1, 262 | hard_negative_rate=0.5, 263 | ): 264 | super(MaxMarginRankingLoss, self).__init__() 265 | self.margin = margin 266 | self.n_pair = n_pair 267 | self.batch_size = batch_size 268 | easy_negative_rate = 1 - hard_negative_rate 269 | self.easy_negative_rate = easy_negative_rate 270 | self.negative_weighting = negative_weighting 271 | if n_pair > 1 and batch_size > 1: 272 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate)) 273 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha 274 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair))) 275 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate)) 276 | self.mm_mask = mm_mask.float() 277 | 278 | def forward(self, x): 279 | d = torch.diag(x) 280 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \ 281 | F.relu(self.margin + x - d.view(1, -1)) 282 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1: 283 | max_margin = max_margin * self.mm_mask.to(max_margin.device) 284 | return max_margin.mean() 285 | 286 | class AllGather(torch.autograd.Function): 287 | """An autograd function that performs allgather on a tensor.""" 288 | 289 | @staticmethod 290 | def forward(ctx, tensor, args): 291 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 292 | torch.distributed.all_gather(output, tensor) 293 | ctx.rank = args.rank 294 | ctx.batch_size = tensor.shape[0] 295 | return torch.cat(output, dim=0) 296 | 297 | @staticmethod 298 | def backward(ctx, grad_output): 299 | return ( 300 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 301 | None, 302 | ) 303 | 304 | class TextPromptEncoder(nn.Module): 305 | def __init__(self, prompt_len=5, hid_dim=512, reduction=1): 306 | super().__init__() 307 | self.hidden_size = hid_dim 308 | self.embedding = nn.Embedding(prompt_len, self.hidden_size) 309 | self.pos_embedding = nn.Parameter(torch.empty(prompt_len, self.hidden_size)) 310 | # self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size // reduction), 311 | # nn.ReLU(), 312 | # nn.Linear(self.hidden_size // reduction, self.hidden_size)) 313 | 314 | def forward(self, input): 315 | ''' 316 | Args: input (bs*n_concept, prompt_len) 317 | ''' 318 | input_embed = self.embedding(input) # shape here is (bs*n_concept, prompt_len, hid_dim) 319 | pos_embed = self.pos_embedding 320 | output_embed = input_embed + pos_embed 321 | # output_embeds = self.mlp_head(input_embeds) # shape here is (bs*n_concept, prompt_len, hid_dim) 322 | return output_embed 323 | 324 | class VideoPromptEncoder(nn.Module): 325 | def __init__(self, prompt_len=4, embed_dim=768, vision_patch_size=32): 326 | super().__init__() 327 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=vision_patch_size, stride=vision_patch_size, bias=False) # shape here is (bs*frames, embed_dim, grid_size 7, grid_size) 328 | self.avgpool = nn.AdaptiveAvgPool2d(int(prompt_len ** 0.5)) 329 | scale = embed_dim ** -0.5 330 | self.positional_embedding = nn.Parameter(scale * torch.randn(prompt_len, embed_dim)) 331 | self.layer_norm = nn.LayerNorm(embed_dim) 332 | 333 | def forward(self, 
x): 334 | ''' 335 | Args: x (bs*n_concept, c, h, w) 336 | ''' 337 | x = F.relu(x) 338 | x = self.conv1(x) # shape = [*, embed_dim, grid 7, grid] 339 | x = self.avgpool(x) # shape here is [*, embed_dim, 2, 2] 340 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 341 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 342 | x = x + self.positional_embedding 343 | x = self.layer_norm(x) 344 | return x 345 | 346 | class PatchShiftModule(nn.Module): 347 | def __init__(self, net, video_frame, n_div): 348 | super().__init__() 349 | self.net = net 350 | self.video_frame = video_frame 351 | self.n_div = n_div 352 | logger.warning('Using patch shift!') 353 | 354 | def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): 355 | # here q == k == v, psm means patch shift output 356 | x = query # shape here is LND, not NLD (50, 384, 768) 357 | # print("query.shape: ", query.shape) 358 | x = x.permute(1, 0, 2) # LND -> NLD 359 | patch_len = x.shape[-2] 360 | fold = patch_len // self.n_div 361 | x = x.reshape(-1, self.video_frame, x.shape[-2], x.shape[-1]) # shape = [bs, frame, grid ** 2, width] 362 | psm = torch.zeros_like(x) # shape = [bs, frame, grid ** 2, width] 363 | psm[:,:,:,:] = x[:,:,:,:] 364 | 365 | # ##############channel shift################## 366 | # psm[:, 1:, :, :fold] = x[:, :-1, :, :fold] 367 | # psm[:, :-1, :, fold:2*fold] = x[:, 1:, :, fold:2*fold] 368 | # ##############channel shift################## 369 | 370 | # ##############patch channel shift################## 371 | # psm[:, 1:, 1:, :fold] = x[:, :-1, 1:, :fold] 372 | # psm[:, :-1, 1:, fold:2*fold] = x[:, 1:, 1:, fold:2*fold] 373 | # ##############patch channel shift################## 374 | 375 | # ##############cls channel shift################## 376 | # psm[:, 1:, :1, :fold] = x[:, :-1, :1, :fold] 377 | # psm[:, :-1, :1, fold:2*fold] = x[:, 1:, :1, fold:2*fold] 378 | # ##############cls channel shift################## 379 | 380 | 381 | ##############left and right shift############## 382 | lshift_indices = torch.arange(start=1, end=patch_len, step=fold) 383 | psm[:, 1:, lshift_indices, :] = x[:, :-1, lshift_indices, :] # f_t = f_t-1 384 | rshift_indices = torch.arange(start=1+3, end=patch_len, step=fold) 385 | psm[:, :-1, rshift_indices, :] = x[:, 1:, rshift_indices, :] # f_t = f_t+1 386 | ##############left and right shift############## 387 | x = psm.reshape(-1, patch_len, x.shape[-1]) 388 | x = x.permute(1, 0, 2) # NLD -> LND 389 | 390 | return self.net(x, x, x, need_weights=need_weights, attn_mask=attn_mask) 391 | 392 | def make_patch_shift(net, video_frame=12, shift_layers=4, n_div=7): 393 | ''' 394 | Args: 395 | net: CLIP 396 | video_frame: need predefine here 397 | shift_layers: layers to be shift 398 | ''' 399 | 400 | def make_trans_patch_shift(stage, shift_layers): 401 | # net.clip.visual.transformer.resblocks[i] is a ResidualAttentionBlock type, contain net.attn -- a nn.MultiheadAttention 402 | # make a shift before net.attn, so it is a residual attn 403 | blocks = list(stage.children()) 404 | for i, b in enumerate(blocks): 405 | # b is a ResidualAttentionBlock type, contain self.attn 406 | # if i==4 or i==6 or i==8 or i==10: 407 | # blocks[i].attn = TokenShiftModule(b.attn, video_frame=video_frame, n_div=n_div) 408 | if i>=10 and i<=11: 409 | blocks[i].attn = PatchShiftModule(b.attn, video_frame=video_frame, n_div=n_div) 410 | return nn.Sequential(*blocks) 411 | 412 | net.clip.visual.transformer.resblocks = 
make_trans_patch_shift(net.clip.visual.transformer.resblocks, shift_layers=shift_layers) 413 | 414 | class TokenShuffleModule(nn.Module): 415 | def __init__(self, net, video_frame): 416 | super().__init__() 417 | self.net = net 418 | self.video_frame = video_frame 419 | logger.warning('Using token shuffle!') 420 | 421 | def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): 422 | # here q == k == v; shuffle the spatial tokens before attention 423 | x = query # shape here is LND (50, 384, 768), not NLD 424 | L, N, D = x.shape 425 | bs = N // self.video_frame 426 | 427 | cls_x = x[:1,:,:] # shape is (1, 384, 768) 428 | x = x.reshape(L, bs, self.video_frame, D) # shape is (50, 32, 12, 768) 429 | spatial_x = x[1:,:,:,:] 430 | spatial_x = spatial_x.permute(0,2,1,3) # shuffle 431 | spatial_x = spatial_x.reshape(L-1, N, D) # shape is (49, 384, 768) 432 | x = torch.cat((cls_x, spatial_x), dim=0) # shape is LND 433 | x = self.net(x, x, x, need_weights=need_weights, attn_mask=attn_mask)[0] 434 | 435 | # reshape 436 | cls_x = x[:1,:,:] # shape is (1, 384, 768) 437 | x = x.reshape(L, bs, self.video_frame, D) # shape is (50, 32, 12, 768) 438 | spatial_x = x[1:,:,:,:] 439 | spatial_x = spatial_x.permute(0,2,1,3) # shuffle 440 | spatial_x = spatial_x.reshape(L-1, N, D) # shape is (49, 384, 768) 441 | x = torch.cat((cls_x, spatial_x), dim=0) # shape is LND 442 | 443 | return (x,) 444 | 445 | def make_token_shuffle(net, video_frame=12): 446 | ''' 447 | Args: 448 | net: CLIP 449 | video_frame: number of frames per video (must be predefined) 450 | 451 | ''' 452 | 453 | def make_trans_token_shuffle(stage): 454 | # net.clip.visual.transformer.resblocks[i] is a ResidualAttentionBlock type, which contains self.attn -- a nn.MultiheadAttention 455 | # shuffle the tokens before self.attn, so the residual connection is kept intact 456 | blocks = list(stage.children()) 457 | for i, b in enumerate(blocks): 458 | # b is a ResidualAttentionBlock type, which contains self.attn 459 | if i==6 or i==8 or i==10: # apply the shuffle at these layers 460 | blocks[i].attn = TokenShuffleModule(b.attn, video_frame=video_frame) 461 | return nn.Sequential(*blocks) 462 | 463 | net.clip.visual.transformer.resblocks = make_trans_token_shuffle(net.clip.visual.transformer.resblocks) 464 | 465 | class AttenVisual(nn.Module): 466 | def __init__(self, net): 467 | super().__init__() 468 | self.net = net 469 | logger.warning('Visualizing!') 470 | 471 | def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): 472 | # here q == k == v; capture the attention scores for visualization 473 | x = query # shape here is LND, shape is (32, N, 512) 474 | attn_output, attn_scores = self.net(x, x, x, need_weights=need_weights, attn_mask=attn_mask) 475 | # attn_scores here is (length, length) a.k.a (32, 32) 476 | print('attn_scores: ', attn_scores[0,7,:]) 477 | print('attn_scores: ', attn_scores[0,15,:]) 478 | print('attn_scores: ', attn_scores.shape) 479 | return attn_output, attn_scores # return the (output, weights) tuple, matching nn.MultiheadAttention 480 | 481 | def make_attn_visual(net): 482 | ''' 483 | Args: 484 | net: CLIP 485 | ''' 486 | def make_trans_patch_shift(stage): 487 | # net.clip.transformer.resblocks[i] is a ResidualAttentionBlock type, which contains self.attn -- a nn.MultiheadAttention 488 | blocks = list(stage.children()) 489 | for i, b in enumerate(blocks): 490 | # b is a ResidualAttentionBlock type, which contains self.attn 491 | if i>= 11 and i < 12: 492 | blocks[i].attn = AttenVisual(b.attn) 493 | return nn.Sequential(*blocks) 494 | 495 | net.clip.transformer.resblocks = make_trans_patch_shift(net.clip.transformer.resblocks)
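
The loss utilities above (CrossEn in particular) operate on a square text-to-video similarity matrix whose matched pairs lie on the diagonal. Below is a minimal usage sketch: the toy tensors, the logit scale of 100 and the symmetric text-to-video / video-to-text averaging follow common CLIP4Clip-style practice and are illustrative assumptions, not necessarily the exact TeachCLIP training loss; the import path assumes these classes live in `modules/until_module.py` as listed in the repository tree.

```python
import torch
import torch.nn.functional as F

from modules.until_module import CrossEn

# toy batch: 8 text and 8 video embeddings, L2-normalised as CLIP features usually are
text_feat = F.normalize(torch.randn(8, 512), dim=-1)
video_feat = F.normalize(torch.randn(8, 512), dim=-1)

# scaled cosine-similarity matrix; matched text-video pairs sit on the diagonal
sim_matrix = 100.0 * text_feat @ video_feat.t()

ce = CrossEn()
# symmetric InfoNCE: average the text-to-video and video-to-text directions
loss = (ce(sim_matrix) + ce(sim_matrix.t())) / 2
print(loss.item())
```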
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3==1.26.14 2 | botocore==1.29.14 3 | ftfy==6.2.0 4 | numpy==1.23.4 5 | opencv_python==4.6.0.66 6 | pandas==1.5.1 7 | Pillow==10.3.0 8 | regex==2022.10.31 9 | Requests==2.31.0 10 | scipy==1.13.0 11 | torch==1.13.0 12 | torchvision==0.14.0 13 | tqdm==4.64.1 14 | transformers==4.25.1 15 | PyYAML==6.0 -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import threading 4 | from torch._utils import ExceptionWrapper 5 | import logging 6 | 7 | def get_a_var(obj): 8 | if isinstance(obj, torch.Tensor): 9 | return obj 10 | 11 | if isinstance(obj, list) or isinstance(obj, tuple): 12 | for result in map(get_a_var, obj): 13 | if isinstance(result, torch.Tensor): 14 | return result 15 | if isinstance(obj, dict): 16 | for result in map(get_a_var, obj.items()): 17 | if isinstance(result, torch.Tensor): 18 | return result 19 | return None 20 | 21 | def parallel_apply(fct, model, inputs, device_ids): 22 | modules = nn.parallel.replicate(model, device_ids) 23 | assert len(modules) == len(inputs) 24 | lock = threading.Lock() 25 | results = {} 26 | grad_enabled = torch.is_grad_enabled() 27 | 28 | def _worker(i, module, input): 29 | torch.set_grad_enabled(grad_enabled) 30 | device = get_a_var(input).get_device() 31 | try: 32 | with torch.cuda.device(device): 33 | # this also avoids accidental slicing of `input` if it is a Tensor 34 | if not isinstance(input, (list, tuple)): 35 | input = (input,) 36 | output = fct(module, *input) 37 | with lock: 38 | results[i] = output 39 | except Exception: 40 | with lock: 41 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) 42 | 43 | if len(modules) > 1: 44 | threads = [threading.Thread(target=_worker, args=(i, module, input)) 45 | for i, (module, input) in enumerate(zip(modules, inputs))] 46 | 47 | for thread in threads: 48 | thread.start() 49 | for thread in threads: 50 | thread.join() 51 | else: 52 | _worker(0, modules[0], inputs[0]) 53 | 54 | outputs = [] 55 | for i in range(len(inputs)): 56 | output = results[i] 57 | if isinstance(output, ExceptionWrapper): 58 | output.reraise() 59 | outputs.append(output) 60 | return outputs 61 | 62 | def get_logger(filename=None): 63 | logger = logging.getLogger('logger') 64 | logger.setLevel(logging.DEBUG) 65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 66 | datefmt='%m/%d/%Y %H:%M:%S', 67 | level=logging.INFO) 68 | if filename is not None: 69 | handler = logging.FileHandler(filename) 70 | handler.setLevel(logging.DEBUG) 71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 72 | logging.getLogger().addHandler(handler) 73 | return logger -------------------------------------------------------------------------------- /utils/bigfile.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os, sys, array 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from itertools import tee 9 | import multiprocessing as mp 10 | 11 | 12 | class BigFile: 13 | 14 | def __init__(self, datadir, bin_file="feature.bin"): 15 | self.nr_of_images, self.ndims = list(map(int, open(os.path.join(datadir, 'shape.txt')).readline().split())) 16 | 
id_file = os.path.join(datadir, "id.txt") 17 | self.names = open(id_file, 'r').read().strip().split('\n') # all video file names 18 | if len(self.names) != self.nr_of_images: 19 | self.names = open(id_file, 'r').read().strip().split(' ') 20 | assert(len(self.names) == self.nr_of_images) 21 | self.name2index = dict(list(zip(self.names, list(range(self.nr_of_images))))) # assign an index to each file name 22 | self.binary_file = os.path.join(datadir, bin_file) 23 | print(("[%s] %dx%d instances loaded from %s" % (self.__class__.__name__, self.nr_of_images, self.ndims, datadir))) 24 | self.torch_array = None 25 | 26 | def read_all_and_store(self): 27 | def readall(self, ndims): 28 | torch_array = torch.zeros(ndims, dtype=torch.half) 29 | 30 | index_name_array = [(self.name2index[x], x) for x in set(self.names) if x in self.name2index] 31 | 32 | index_name_array.sort(key=lambda v: v[0]) 33 | sorted_index = [x[0] for x in index_name_array] 34 | 35 | nr_of_images = len(index_name_array) 36 | 37 | offset = np.float32(1).nbytes * self.ndims 38 | 39 | res1 = array.array('f') 40 | fr = open(self.binary_file, 'rb') 41 | fr.seek(index_name_array[0][0] * offset) 42 | res1.fromfile(fr, self.ndims) 43 | previous = index_name_array[0][0] 44 | torch_array[previous] = torch.tensor(res1) 45 | 46 | for next in sorted_index[1:]: 47 | res1 = array.array('f') 48 | move = (next - 1 - previous) * offset 49 | # print next, move 50 | fr.seek(move, 1) 51 | res1.fromfile(fr, self.ndims) 52 | previous = next 53 | torch_array[previous] = torch.tensor(res1) 54 | 55 | return torch_array 56 | self.torch_array = readall(self, self.shape()) 57 | 58 | def readall(self, isname=True): 59 | index_name_array = [(self.name2index[x], x) for x in set(self.names) if x in self.name2index] 60 | 61 | index_name_array.sort(key=lambda v:v[0]) 62 | sorted_index = [x[0] for x in index_name_array] 63 | 64 | nr_of_images = len(index_name_array) 65 | vecs = [None] * nr_of_images 66 | offset = np.float32(1).nbytes * self.ndims 67 | 68 | res = array.array('f') 69 | fr = open(self.binary_file, 'rb') 70 | fr.seek(index_name_array[0][0] * offset) 71 | res.fromfile(fr, self.ndims) 72 | previous = index_name_array[0][0] 73 | 74 | for next in sorted_index[1:]: 75 | move = (next-1-previous) * offset 76 | #print next, move 77 | fr.seek(move, 1) 78 | res.fromfile(fr, self.ndims) 79 | previous = next 80 | 81 | 82 | return [x[1] for x in index_name_array], [ res[i*self.ndims:(i+1)*self.ndims].tolist() for i in range(nr_of_images) ] 83 | 84 | def _read_from_ram(self, requested, isname=True): 85 | """ 86 | Read the features directly from memory. 87 | :param requested: 88 | :param isname: 89 | :return: video names and feature vectors, typically returned as lists 90 | """ 91 | requested = set(requested) 92 | if isname: 93 | index_name_array = [(self.name2index[x], x) for x in requested if x in self.name2index] 94 | else: 95 | assert(min(requested)>=0) 96 | assert(max(requested)=0) 125 | assert(max(requested)= len(self.fr_list): 136 | index = len(self.fr_list)-1 137 | 138 | # acquire the semaphore 139 | signal = True 140 | while signal: 141 | with self.mp_signal.get_lock(): # call get_lock() directly to acquire the lock 142 | for signal_index in range(len(self.fr_list[index])): 143 | if self.mp_signal[index*len(self.fr_list[0]) + signal_index] == 1: 144 | self.mp_signal[index*len(self.fr_list[0]) + signal_index] = 0 145 | signal = False 146 | break 147 | if signal: 148 | time.sleep(0.0001) 149 | 150 | fr = self.fr_list[index][signal_index]['fr'] 151 | move = index_name_array[0][0] * offset - fr.tell() 152 | fr.seek(move, 1) 153 | res.fromfile(fr, self.ndims) 154 | fr.seek(-move - 
offset, 1) 155 | self.mp_signal[index*len(self.fr_list[0]) + signal_index] = 1 156 | 157 | # with open(self.binary_file, 'rb') as fr: 158 | # move = index_name_array[0][0] * offset 159 | # fr.seek(move) 160 | # res.fromfile(fr, self.ndims) 161 | 162 | except Exception as e: 163 | print(e) 164 | 165 | # print([ res.tolist() ]) 166 | # print(self.read(requested)[1]) 167 | 168 | return [index_name_array[0][1]], [ res.tolist() ] 169 | 170 | def read(self, requested, isname=True): 171 | """ 172 | Read features by file name, i.e. load the corresponding vectors from the bin file; the results are the video names and their feature vectors. 173 | :param requested: [] 174 | :param isname: 175 | :return: video names and feature vectors 176 | """ 177 | requested = set(requested) 178 | if isname: 179 | index_name_array = [(self.name2index[x], x) for x in requested if x in self.name2index] 180 | else: 181 | assert(min(requested)>=0) 182 | assert(max(requested)= self.nr_of_images: 257 | self.close() 258 | raise StopIteration 259 | else: 260 | res = array.array('f') 261 | res.fromfile(self.fr, self.ndims) 262 | _id = self.names[self.current] 263 | self.current += 1 264 | return _id, res.tolist() 265 | 266 | 267 | if __name__ == '__main__': 268 | feat_dir = "/data2/hf/VisualSearch/toydata/FeatureData/f1" 269 | bigfile = BigFile(feat_dir) 270 | 271 | imset = str.split('b z a a b c') 272 | renamed, vectors = bigfile.read(imset) 273 | 274 | 275 | for name,vec in zip(renamed, vectors): 276 | print(name, vec) 277 | 278 | bigfile = StreamFile(feat_dir) 279 | bigfile.open() 280 | for name, vec in bigfile: 281 | print(name, vec) 282 | bigfile.close() -------------------------------------------------------------------------------- /utils/txt2bin.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # convert one or multiple feature files from txt format to binary (float32) format 3 | ''' 4 | 5 | import os, sys, math 6 | import numpy as np 7 | from optparse import OptionParser 8 | 9 | def checkToSkip(filename, overwrite): 10 | if os.path.exists(filename): 11 | print ("%s exists." 
% filename), 12 | if overwrite: 13 | print ("overwrite") 14 | return 0 15 | else: 16 | print ("skip") 17 | return 1 18 | return 0 19 | 20 | 21 | def process(feat_dim, inputTextFiles, resultdir, overwrite): 22 | res_binary_file = os.path.join(resultdir, 'feature.bin') 23 | res_id_file = os.path.join(resultdir, 'id.txt') 24 | 25 | if checkToSkip(res_binary_file, overwrite): 26 | return 0 27 | 28 | if os.path.isdir(resultdir) is False: 29 | os.makedirs(resultdir) 30 | 31 | fw = open(res_binary_file, 'wb') 32 | processed = set() 33 | imset = [] 34 | count_line = 0 35 | failed = 0 36 | 37 | for filename in inputTextFiles: 38 | print ('***** Processing %s *****' % filename) 39 | for line in open(filename): 40 | count_line += 1 41 | elems = line.strip().split() 42 | if not elems: 43 | continue 44 | name = elems[0] 45 | if name in processed: 46 | continue 47 | processed.add(name) 48 | 49 | del elems[0] 50 | vec = np.array(list(map(float, elems)), dtype=np.float32) 51 | okay = True 52 | for x in vec: 53 | if math.isnan(x): 54 | okay = False 55 | break 56 | if not okay: 57 | failed += 1 58 | continue 59 | 60 | if feat_dim == 0: 61 | feat_dim = len(vec) 62 | else: 63 | assert(len(vec) == feat_dim), "dimensionality mismatch: required %d, input %d, id=%s, inputfile=%s" % (feat_dim, len(vec), name, filename) 64 | vec.tofile(fw) 65 | #print name, vec 66 | imset.append(name) 67 | fw.close() 68 | 69 | fw = open(res_id_file, 'w') 70 | fw.write(' '.join(imset)) 71 | fw.close() 72 | fw = open(os.path.join(resultdir,'shape.txt'), 'w') 73 | fw.write('%d %d' % (len(imset), feat_dim)) 74 | fw.close() 75 | print ('%d lines parsed, %d ids, %d failed -> %d unique ids' % (count_line, len(processed), failed, len(imset))) 76 | 77 | 78 | def main(argv=None): 79 | if argv is None: 80 | argv = sys.argv[1:] 81 | 82 | parser = OptionParser(usage="""usage: %prog [options] nDims inputTextFile isFileList resultDir""") 83 | parser.add_option("--overwrite", default=0, type="int", help="overwrite existing file (default=0)") 84 | 85 | (options, args) = parser.parse_args(argv) 86 | if len(args) < 4: 87 | parser.print_help() 88 | return 1 89 | 90 | fea_dim = int(args[0]) 91 | inputTextFile = args[1] 92 | if int(args[2]) == 1: 93 | inputTextFiles = [x.strip() for x in open(inputTextFile).readlines() if x.strip() and not x.strip().startswith('#')] 94 | else: 95 | inputTextFiles = [inputTextFile] 96 | return process(fea_dim, inputTextFiles, args[3], options.overwrite) 97 | 98 | 99 | if __name__ == "__main__": 100 | sys.exit(main()) --------------------------------------------------------------------------------
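
To make the feature I/O path concrete, here is a small round-trip sketch: it writes a toy whitespace-separated feature file, converts it with `process` from `utils/txt2bin.py`, and reads the vectors back through `BigFile`. The file and directory names (`toy_feat.txt`, `toy_feature_dir`) are made up for illustration, and the imports assume the repository root is on `PYTHONPATH`.

```python
import numpy as np

from utils.txt2bin import process
from utils.bigfile import BigFile

# write a tiny text feature file: one "<video_id> f1 f2 ... fD" line per video
dim = 4
with open('toy_feat.txt', 'w') as fw:
    for name in ('video0', 'video1'):
        vec = np.random.rand(dim).astype(np.float32)
        fw.write(name + ' ' + ' '.join(map(str, vec)) + '\n')

# convert to the binary layout (feature.bin + id.txt + shape.txt) expected by BigFile
process(dim, ['toy_feat.txt'], 'toy_feature_dir', overwrite=1)

# read the vectors back by video id, as bigfile.py's own __main__ does
bigfile = BigFile('toy_feature_dir')
names, vectors = bigfile.read(['video0', 'video1'])
for name, vec in zip(names, vectors):
    print(name, len(vec))
```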