├── README.md
├── __pycache__
│   └── engine_pretrain.cpython-310.pyc
├── assets
│   ├── method.png
│   └── teaser.png
├── configs
│   ├── config
│   │   ├── clip_base.yml
│   │   ├── clip_base_eval.yml
│   │   ├── clip_large.yml
│   │   └── clip_large_eval.yml
│   ├── default_clip_base.yml
│   └── default_clip_large.yml
├── dataset
│   ├── __pycache__
│   │   ├── data_utils.cpython-310.pyc
│   │   ├── egodataset.cpython-310.pyc
│   │   └── ek100dataset.cpython-310.pyc
│   ├── data_utils.py
│   ├── egodataset.py
│   └── ek100dataset.py
├── engine_pretrain.py
├── environment.yml
├── evaluation
│   ├── eval_egomcq.py
│   ├── eval_egtea.py
│   ├── eval_ekcls.py
│   └── eval_mir.py
├── exps
│   ├── eval_egomcq.sh
│   ├── eval_egtea.sh
│   ├── eval_ekcls.sh
│   ├── eval_mir.sh
│   ├── pretrain.sh
│   └── pretrain_large.sh
├── main_pretrain.py
├── model
│   ├── __pycache__
│   │   ├── clip.cpython-310.pyc
│   │   ├── loss.cpython-310.pyc
│   │   ├── timesformer.cpython-310.pyc
│   │   └── transformer.cpython-310.pyc
│   ├── clip.py
│   ├── loss.py
│   ├── timesformer.py
│   └── transformer.py
├── output_dir
│   ├── events.out.tfevents.1740572442.SH-IDC1-10-140-37-2.156563.0
│   ├── events.out.tfevents.1740573180.SH-IDC1-10-140-37-41.259551.0
│   ├── events.out.tfevents.1740573310.SH-IDC1-10-140-37-41.261269.0
│   └── events.out.tfevents.1740573743.SH-IDC1-10-140-37-41.119535.0
├── requirements.txt
└── util
    ├── __pycache__
    │   ├── config.cpython-310.pyc
    │   ├── dist_utils.cpython-310.pyc
    │   ├── lr_sched.cpython-310.pyc
    │   ├── meter.cpython-310.pyc
    │   ├── misc.cpython-310.pyc
    │   └── pos_embed.cpython-310.pyc
    ├── config.py
    ├── crop.py
    ├── datasets.py
    ├── dist_utils.py
    ├── lars.py
    ├── lr_decay.py
    ├── lr_sched.py
    ├── meter.py
    ├── misc.py
    └── pos_embed.py

/README.md:
--------------------------------------------------------------------------------
1 | # EgoHOD
2 |
3 | This repository is the official implementation of EgoHOD (ICLR 2025).
4 |
5 | > **["Modeling Fine-Grained Hand-Object Dynamics for Egocentric Video Representation Learning"](https://openreview.net/forum?id=P6G1Z6jkf3)**
6 | > [Baoqi Pei](https://scholar.google.com/citations?user=sTCkd54AAAAJ), [Yifei Huang](https://scholar.google.com/citations?user=RU8gNcgAAAAJ), [Jilan Xu](https://scholar.google.com/citations?user=mf2U64IAAAAJ), [Guo Chen](https://scholar.google.com/citations?user=lRj3moAAAAAJ), Yuping He, [Lijin Yang](https://scholar.google.com/citations?user=ppR-rpkAAAAJ),
7 | > [Yali Wang](https://scholar.google.com/citations?user=hD948dkAAAAJ), [Weidi Xie](https://scholar.google.com/citations?user=Vtrqj4gAAAAJ), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ), [Fei Wu](https://scholar.google.com/citations?user=XJLn4MYAAAAJ), [Limin Wang](https://scholar.google.com/citations?user=HEuN8PcAAAAJ)
12 |
13 | ## Todo
14 |
15 | - [x] HOD data release
16 | - [x] Pretraining code release
17 | - [ ] Fine-tuning code release
18 | - [x] Pretrained model checkpoints release
19 | - [ ] Fine-tuned model checkpoints release
20 | - [x] Evaluation code release
21 |
22 | ## Introduction
23 |
24 | In egocentric video understanding, the motion of hands and objects, as well as their interactions, naturally plays a significant role. However, existing egocentric video representation learning methods mainly focus on aligning video representations with high-level narrations, overlooking the intricate dynamics between hands and objects. In this work, we aim to integrate the modeling of fine-grained hand-object dynamics into the video representation learning process.
25 | Since no suitable data is available, we introduce HOD, a novel pipeline that employs a hand-object detector and a large language model to generate high-quality narrations with detailed descriptions of hand-object dynamics.
26 | To learn these fine-grained dynamics, we propose EgoVideo, a model with a new lightweight motion adapter that captures fine-grained hand-object motion information. Through our co-training strategy, EgoVideo effectively and efficiently leverages the fine-grained hand-object dynamics in the HOD data. Extensive experiments demonstrate that our method achieves state-of-the-art performance across multiple egocentric downstream tasks, including zero-shot improvements of 6.3% in EK-100 multi-instance retrieval, 5.7% in EK-100 classification, and 16.3% in EGTEA classification. Furthermore, our model exhibits robust generalization capabilities in hand-object interaction and robot manipulation tasks.
27 |
28 |
31 |
32 |
33 | ## Installation
34 | ```shell
35 | git clone https://github.com/OpenRobotLab/EgoHOD.git && cd EgoHOD
36 | conda env create -f environment.yml
37 | conda activate hod
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ## Datasets
42 |
43 | The HOD annotations are available at this [Hugging Face link](https://huggingface.co/datasets/Jazzcharles/EgoHOD).
44 |
45 | ## Pretraining
46 |
47 | To pretrain the EgoVideo model without the motion adapter, run:
48 | ```shell
49 | bash ./exps/pretrain.sh
50 | ```
51 | **Notes:**
52 | 1. Before running the scripts, update the data, checkpoint, and output paths in the YAML files under `./configs`; the released configs point to the authors' cluster storage (`s3://` buckets and `/mnt/petrelfs/...` paths).
53 | 2. To train without the Slurm script, run:
54 | ```shell
55 | python main_pretrain.py --config_file configs/config/clip_base.yml
56 | ```
57 | 3. The pretraining code for the model with the motion adapter will be released soon.
58 |
59 |
60 | ## Pretrained Model
61 |
62 | The pretrained checkpoints can be downloaded from [this link](https://huggingface.co/Jazzcharles/EgoVideo).
63 |
64 | ## Fine-tuning
65 |
66 | The fine-tuning code will be released soon.
67 |
68 | ## Zero-shot Evaluation
69 |
70 | For zero-shot evaluation, run the corresponding script in `exps`, for example:
71 | ```shell
72 | bash exps/eval_ekcls.sh
73 | ```
74 | We provide evaluation code for EK100-MIR (`eval_mir.sh`), EK100-CLS (`eval_ekcls.sh`), EGTEA (`eval_egtea.sh`), and EgoMCQ (`eval_egomcq.sh`).
75 | ## Cite
76 |
77 | If you find this repository useful, please cite our paper using the following BibTeX entry.
78 |
79 | ```bibtex
80 | @misc{pei2025modeling,
81 |   title={Modeling Fine-Grained Hand-Object Dynamics for Egocentric Video Representation Learning},
82 |   author={Baoqi Pei and Yifei Huang and Jilan Xu and Guo Chen and Yuping He and Lijin Yang and Yali Wang and Weidi Xie and Yu Qiao and Fei Wu and Limin Wang},
83 |   year={2025},
84 |   eprint={2503.00986},
85 |   archivePrefix={arXiv},
86 |   primaryClass={cs.CV}
87 | }
88 | ```
89 |
90 | ## Acknowledgement
91 |
92 | This repository is built on [mae](https://github.com/facebookresearch/mae) and [AVION](https://github.com/zhaoyue-zephyrus/AVION). We thank the contributors of these great codebases.
93 |
--------------------------------------------------------------------------------
/__pycache__/engine_pretrain.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/__pycache__/engine_pretrain.cpython-310.pyc
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/assets/method.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/assets/teaser.png
--------------------------------------------------------------------------------
/configs/config/clip_base.yml:
--------------------------------------------------------------------------------
1 | _base_: '../default_clip_base.yml'
2 | data:
3 |   dataset: ego4d_htego
4 |   # ego4d_root: pssd:s3://ego4d/all_videos_fps30_short320_chunked/
5 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
6 |   ego4d_metadata: /mnt/petrelfs/peibaoqi/robot/annotation/ego4d_new_recaption_final.csv
7 |
8 |   howto_metadata: /mnt/petrelfs/share_data/xujilan/annotations/howto100m/Howto-Interlink-egoonly_egolabel.csv
9 |
10 |   clip_length: 4
11 |   clip_stride: 16
12 |
13 | train:
14 |   task: vlp
15 |   batch_size: 128
16 |   epochs: 15
17 |   lr: 4e-5
18 |   fix_lr: true
19 |
20 | model:
21 |   name: CLIP_VITB16
22 |   ckpt_path: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/ViT-B-16.pt
23 |   freeze_temperature: true
24 |   lavila_path: dont_use
25 | wandb: true
26 |
27 | #resume: checkpoint_best.pt
28 |
29 | output_dir: output/
30 | use_eva: False
31 | use_bert: False
32 |
33 |
--------------------------------------------------------------------------------
/configs/config/clip_base_eval.yml:
--------------------------------------------------------------------------------
1 | _base_: '../default_clip_base.yml'
2 | data:
3 |   dataset: ego4d_htego
4 |
5 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
6 |   ego4d_metadata: /mnt/petrelfs/peibaoqi/robot/annotation/ego4d_new_recaption_final.csv
7 |
8 |   #egoexo4d_root: cluster1:s3://howto_ego/htm370k_cooking_clips_15s_short256/
9 |   exoego4d_root: sssdpbq:s3://video_pub/howto100m
10 |   # egoexo4d_metadata: /mnt/petrelfs/xujilan/data/howto100/Howto-Interlink7M/Howto-Interlink-cookingclips.csv
11 |   egoexo4d_metadata: /mnt/petrelfs/share_data/xujilan/annotations/howto100m/Howto-Interlink-egoonly_egolabel.csv
12 |   howto_metadata: /mnt/petrelfs/share_data/xujilan/annotations/howto100m/Howto-Interlink-egoonly_egolabel.csv
13 |   # howto_metadata: /mnt/petrelfs/peibaoqi/robot/ego_annotation/howto_caption.csv
14 |   how2_traj_metadata: /mnt/petrelfs/peibaoqi/EgoVideo/how2_traj
15 |   ego4d_traj_metadata: /mnt/petrelfs/peibaoqi/EgoVideo/ego4d_traj
16 |   clip_length: 4
17 |   clip_stride: 16
18 |
19 | train:
20 |   task: vlp
21 |   batch_size: 128
22 |   epochs: 15
23 |   lr: 4e-5
24 |   fix_lr: true
25 |
26 | model:
27 |   name: CLIP_VITB16
28 |   ckpt_path: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/ViT-B-16.pt
29 |   freeze_temperature: true
30 |   lavila_path: /mnt/petrelfs/peibaoqi/AVION/clip_ckpt/avion_pretrain_lavila_vitb
31 | wandb: true
32 | resume: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/base_best.pt
33 |
34 | output_dir: /mnt/petrelfs/peibaoqi/exps/avion3/
35 | use_eva: False
36 | use_bert: False
37 |
38 |
--------------------------------------------------------------------------------
/configs/config/clip_large.yml:
--------------------------------------------------------------------------------
1 | _base_: '../default_clip_large.yml'
2 | data:
3 |   dataset: ego4d_htego
4 |   # ego4d_root: pssd:s3://ego4d/all_videos_fps30_short320_chunked/
5 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
6 |   ego4d_metadata: /mnt/petrelfs/peibaoqi/robot/annotation/ego4d_new_recaption_final.csv
7 |
8 |   egoexo4d_root: cluster1:s3://howto_ego/htm370k_cooking_clips_15s_short256/
9 |   # egoexo4d_metadata: /mnt/petrelfs/xujilan/data/howto100/Howto-Interlink7M/Howto-Interlink-cookingclips.csv
10 |   egoexo4d_metadata: /mnt/petrelfs/share_data/xujilan/avion_metadata/Howto-Interlink-egoonly_egolabel.csv
11 |   howto_metadata: /mnt/petrelfs/share_data/xujilan/annotations/howto100m/Howto-Interlink-egoonly_egolabel.csv
12 |   clip_length: 4
13 |   clip_stride: 16
14 |
15 |
16 | train:
17 |   task: vlp
18 |   batch_size: 8
19 |   epochs: 15
20 |   lr: 1e-5
21 |   fix_lr: true
22 |
23 |
24 | model:
25 |   name: CLIP_VITL14_336PX
26 |   ckpt_path: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/ViT-L-14-336px.pt
27 |   lavila_path: /mnt/petrelfs/peibaoqi/AVION/clip_ckpt/avion_pretrain_lavila_vitb
28 | wandb: true
29 | # resume: /mnt/petrelfs/peibaoqi/exps/avion_large/
30 | output_dir: /mnt/petrelfs/peibaoqi/exps/avion_large/
31 | use_eva: False
32 | use_bert: False
33 |
--------------------------------------------------------------------------------
/configs/config/clip_large_eval.yml:
--------------------------------------------------------------------------------
1 | _base_: '../default_clip_large.yml'
2 | data:
3 |   dataset: ego4d_htego
4 |   # ego4d_root: pssd:s3://ego4d/all_videos_fps30_short320_chunked/
5 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
6 |   ego4d_metadata: /mnt/petrelfs/peibaoqi/robot/annotation/ego4d_new_recaption_final.csv
7 |
8 |   egoexo4d_root: cluster1:s3://howto_ego/htm370k_cooking_clips_15s_short256/
9 |   # egoexo4d_metadata: /mnt/petrelfs/xujilan/data/howto100/Howto-Interlink7M/Howto-Interlink-cookingclips.csv
10 |   egoexo4d_metadata: /mnt/petrelfs/share_data/xujilan/avion_metadata/Howto-Interlink-egoonly_egolabel.csv
11 |   howto_metadata: /mnt/petrelfs/share_data/xujilan/annotations/howto100m/Howto-Interlink-egoonly_egolabel.csv
12 |   clip_length: 4
13 |   clip_stride: 16
14 |
15 |
16 | train:
17 |   task: vlp
18 |   batch_size: 8
19 |   epochs: 15
20 |   lr: 1e-5
21 |   fix_lr: true
22 |
23 |
24 | model:
25 |   name: CLIP_VITL14_336PX
26 |   ckpt_path: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/ViT-L-14-336px.pt
27 |   lavila_path: /mnt/petrelfs/peibaoqi/AVION/clip_ckpt/avion_pretrain_lavila_vitb
28 | wandb: true
29 | resume: /mnt/petrelfs/peibaoqi/robot/mae/ckpt/large_best.pt
30 | output_dir: /mnt/petrelfs/peibaoqi/exps/avion_large/
31 | use_eva: False
32 | use_bert: False
33 |
--------------------------------------------------------------------------------
/configs/default_clip_base.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   dataset: ego4d_htego
3 |   type: normal
4 |   # ego4d_root: pssd:s3://ego4d/all_videos_fps30_short320_chunked/
5 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
6 |   ego4d_metadata: /mnt/petrelfs/share_data/xujilan/avion_metadata/ego4d_train.csv
7 |   ego4d_metadata_aux: null
8 |   ego4d_video_chunk_len: 15
9 |   ego4d_fps: 30
10 |
11 |   howto_root: cluster1:s3://howto_ego/htm_clips_15s_25fps_short128/
12 |   howto_metadata: /mnt/hwfile/internvideo/share_data/xujilan/howto100/Howto-Interlink-egoonly_egolabel.csv
13 |   howto_video_chunk_len: 15 # -1
14 |   howto_fps: 25 # -1
15 |
16 |   clip_length: 16
17 |   clip_stride: 4
18 |   input_size: 224
19 |   patch_size: 16
20 |   is_trimmed: true
21 |
22 |   context_length: 77 # 128
23 |   vocab_size: 49408
24 |   norm_style: clip
25 |   fused_decode_crop: true
26 |   decode_threads: 1
27 |
28 |   multiview: false
29 |   clear_narration: false
30 |   return_uid: false
31 |
32 | model:
33 |   name: CLIP_VITB16
34 |   # norm_embed: true
35 |   clip_length: ${data.clip_length}
36 |   # contrastive_use_vissl: true #use contrastive implementation in vissl
37 |   temperature_init: 0.07
38 |
39 |   freeze_temperature: true
40 |   grad_checkpointing: true
41 |   use_fast_conv1: true
42 |   use_flash_attn: true
43 |   patch_dropout: 0.0
44 |   drop_path_rate: 0.0
45 |   pretrain_zoo: intern
46 |   pretrain_path: null
47 |   project_embed_dim: 512
48 |
49 |   multiview: ${data.multiview}
50 |
51 | train:
52 |   task: vlp
53 |   batch_size: 16
54 |   epochs: 10
55 |   warmup_epochs: 1
56 |   lr: 1e-5
57 |   fix_lr: true
58 |   lr_start: 6e-8
59 |   lr_end: 6e-7
60 |   grad_clip_norm: null
61 |   update_freq: 1
62 |   seed: 0
63 |   workers: 10
64 |
65 |   optimizer:
66 |     name: adamw
67 |     wd: 0.01
68 |     betas: [0.9, 0.999]
69 |     eps: 1e-8
70 |
71 |   eval_freq: 1
72 |   print_freq: 10
73 |   save_freq: 1
74 |   disable_amp: false
75 |
76 |   use_half: false
77 |   find_unused_parameters: false
78 |
79 |   local_loss: false
80 |   gather_with_grad: false
81 |   use_zero: false
82 |
83 |   use_multi_epochs_loader: false
84 |
85 | test:
86 |   # batch_size: ${train.batch_size}
87 |   batch_size: 16
88 |   workers: 10
89 |   testonly: false
90 |   savemetaonly: false
91 |   ek100_mir:
92 |     metapath: /mnt/petrelfs/peibaoqi/robot/epic_kitchen/
93 |     metadata: ${test.ek100_mir.metapath}EPIC_100_retrieval_test.csv
94 |     relevancy_path: /mnt/petrelfs/peibaoqi/robot/epic_kitchen/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
95 |
96 |
97 |     root: cluster2:s3://epic/epic_video_320p/
98 |     video_chunk_len: -1
99 |     fps: -1
100 |     clip_length: 4
101 |     clip_stride: 16
102 |     num_clips: 1
103 |     num_crops: 1
104 |     sparse_sample: false
105 |     decode_threads: ${data.decode_threads}
106 |     fused_decode_crop: ${data.fused_decode_crop}
107 |
108 |
109 | wandb: false
110 | resume: null
111 | resume_pretrain: null
112 | output_dir: /mnt/petrelfs/peibaoqi/exps/avion/debug/
113 | local_rank: 0
114 |
115 | visualisation: false
116 |
117 |
--------------------------------------------------------------------------------
/configs/default_clip_large.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   dataset: ego4d_htego
3 |
4 |   type: normal
5 |   # ego4d_root: pssd:s3://ego4d/all_videos_fps30_short320_chunked/
6 |   ego4d_root: cluster1:s3://videos/ego4d/videos_short320_chunked_15s/
7 |   ego4d_metadata: /mnt/petrelfs/share_data/xujilan/avion_metadata/ego4d_train.csv
8 |   ego4d_metadata_aux: null
9 |   ego4d_video_chunk_len: 15
10 |   ego4d_fps: 30
11 |
12 |   howto_root: cluster1:s3://howto_ego/htm_clips_15s_25fps_short128/
13 |   # howto_root: phdd:s3://howto100m/
14 |   howto_metadata: /mnt/hwfile/internvideo/share_data/xujilan/howto100/Howto-Interlink-egoonly_egolabel.csv
15 |   howto_video_chunk_len: 15 # -1
16 |   howto_fps: 25 # -1
17 |
18 |   recap: /mnt/petrelfs/peibaoqi/videorecap/goalstep+recap.csv
19 |   egoexo: /mnt/petrelfs/peibaoqi/videorecap/egoexo_final_video.csv
20 |   egoexoroot: shddnew:s3://huangyifei/ego_trimmed_clips/
21 |   clip_length: 16
22 |   clip_stride: 4
23 |   input_size: 224
24 |   patch_size: 16
25 |   is_trimmed: true
26 |
27 |   context_length: 77 # 128
28 |   vocab_size: 49408
29 |   norm_style: clip
30 |   fused_decode_crop: true
31 |   decode_threads: 1
32 |
33 |   multiview: false
34 |   clear_narration: false
35 |   return_uid: false
36 |
37 | model:
38 |   name: CLIP_VITL14_336PX
39 |   # norm_embed: true
40 |   clip_length: ${data.clip_length}
41 |   # contrastive_use_vissl: true #use contrastive implementation in vissl
42 |   temperature_init: 0.07
43 |
44 |   freeze_temperature: true
45 |   grad_checkpointing: true
46 |   use_fast_conv1: true
47 |   use_flash_attn: true
48 |   patch_dropout: 0.0
49 |   drop_path_rate: 0.0
50 |   pretrain_zoo: intern
51 |   pretrain_path: null
52 |   project_embed_dim: 512
53 |
54 |   multiview: ${data.multiview}
55 |
56 | train:
57 |   task: vlp
58 |   batch_size: 16
59 |   epochs: 10
60 |   warmup_epochs: 1
61 |   lr: 1e-5
62 |   fix_lr: true
63 |   lr_start: 6e-8
64 |   lr_end: 6e-7
65 |   grad_clip_norm: null
66 |   update_freq: 1
67 |   seed: 0
68 |   workers: 10
69 |
70 |   optimizer:
71 |     name: adamw
72 |     wd: 0.01
73 |     betas: [0.9, 0.999]
74 |     eps: 1e-8
75 |
76 |   eval_freq: 1
77 |   print_freq: 10
78 |   save_freq: 1
79 |   disable_amp: false
80 |
81 |   use_half: false
82 |   find_unused_parameters: false
83 |
84 |   local_loss: false
85 |   gather_with_grad: false
86 |   use_zero: false
87 |
88 |   use_multi_epochs_loader: false
89 |
90 | test:
91 |   # batch_size: ${train.batch_size}
92 |   batch_size: 16
93 |   workers: 10
94 |   testonly: false
95 |   savemetaonly: false
96 |   ek100_mir:
97 |
98 |     metapath: /mnt/petrelfs/peibaoqi/robot/epic_kitchen/
99 |     metadata: ${test.ek100_mir.metapath}EPIC_100_retrieval_test.csv
100 |     relevancy_path: /mnt/petrelfs/peibaoqi/robot/epic_kitchen/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
101 |
102 |     root: cluster2:s3://epic/epic_video_320p/
103 |     video_chunk_len: -1
104 |     fps: -1
105 |
106 |     # root: cluster1:s3://downstream/epic_clips_300s_30fps_short320/
107 |     # video_chunk_len: 300
108 |     # fps: 30
109 |
110 |     clip_length: 16
111 |     clip_stride: 4
112 |     num_clips: 1
113 |     num_crops: 1
114 |     sparse_sample: false
115 |     decode_threads: ${data.decode_threads}
116 |     fused_decode_crop: ${data.fused_decode_crop}
117 |
118 |
119 | wandb: false
120 | resume: null
121 | resume_pretrain: null
122 | output_dir: /mnt/petrelfs/peibaoqi/exps/avion/debug/
123 | local_rank: 0
124 |
125 | visualisation: false
126 |
127 |
128 |
129 | # myshdd:s3://egoschema/videos/001934bb-81bd-4cd8-a574-0472ef3f6678.mp4
--------------------------------------------------------------------------------
/dataset/__pycache__/data_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/dataset/__pycache__/data_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/egodataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/dataset/__pycache__/egodataset.cpython-310.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/ek100dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/dataset/__pycache__/ek100dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/dataset/data_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import csv
3 | import glob
4 | import json
5 | import numpy as np
6 | import os.path as osp
7 | import pickle
8 | import random
9 |
10 | import decord
11 | import pandas as pd
12 | import torch
13 | from decord import cpu
14 | import cv2
15 | import io,os
16 | import argparse
17 | import time
18 | import func_timeout
19 | from func_timeout import func_set_timeout
20 | try:
21 |     from petrel_client.client import Client
22 |     client = Client()
23 |
24 |     # Disable boto logger
25 |     import logging
26 |     logging.getLogger('boto3').setLevel(logging.WARNING)
27 |     logging.getLogger('botocore').setLevel(logging.WARNING)
28 |     logging.getLogger('nose').setLevel(logging.WARNING)
29 | except:
30 |     client = None
31 |
32 |
33 | def datetime2sec(timestamp):
34 |     hh, mm, ss = timestamp.split(':')
35 |     return int(hh) * 3600 + int(mm) * 60 + float(ss)
36 |
37 | # def get_vr(video_path):
38 | #     video_bytes = client.get(video_path)
39 | #     assert video_bytes is not None, "Get video failed from {}".format(video_path)
40 | #     video_path = video_bytes
41 | #     if isinstance(video_path, bytes):
42 | #         video_path = io.BytesIO(video_bytes)
43 | #     vreader = decord.VideoReader(video_path, ctx=cpu(0))
44 | #     return vreader
45 |
46 |
47 | def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):  # NOTE: shadowed by the redefinition at line 290, which is the version actually called
48 |     frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')
49 |     if jitter:
50 |         seg_size = float(end_frame - start_frame - 1) / num_segments
51 |         shift = (np.random.rand(num_segments) - 0.5) * seg_size
52 |         frame_ids += shift
53 |     return frame_ids.astype(int).tolist()
54 |
55 | def get_videobytesIO(video_path):
56 |     # video_bytes = client.get(video_path, enable_stream=True)
57 |     video_bytes = client.get(video_path)
58 |     assert video_bytes is not None, "Get video failed from {}".format(video_path)
59 |     video_path = video_bytes
60 |     if isinstance(video_path, bytes):
61 |         video_path = io.BytesIO(video_bytes)
62 |     return video_path
63 |
64 |
65 | def get_video_reader(videoname, num_threads, fast_rrc, rrc_params, fast_rcc, rcc_params):
66 |     if '/mnt/petrelfs' in videoname:
67 |         if fast_rrc:
68 |             video_reader = decord.VideoReader(
69 |                 videoname,
70 |                 num_threads=num_threads,
71 |             )
72 |         elif fast_rcc:
73 |             video_reader = decord.VideoReader(
74 |                 videoname,
75 |                 num_threads=num_threads,
76 |             )
77 |         else:
78 |             video_reader = decord.VideoReader(videoname, num_threads=num_threads)
79 |         #print(video_reader)
80 |         return video_reader
81 |     else:
82 |         video_reader = None
83 |         video_bytes = client.get(videoname)
84 |         assert video_bytes is not None, "Get video failed from {}".format(videoname)
85 |         videoname = video_bytes
86 |         if isinstance(videoname, bytes):
87 |             videoname = io.BytesIO(video_bytes)
88 |
89 |         if fast_rrc:
90 |             video_reader = decord.VideoReader(
91 |                 videoname,
92 |                 num_threads=num_threads,
93 |             )
94 |         elif fast_rcc:
95 |             video_reader = decord.VideoReader(
96 |                 videoname,
97 |                 num_threads=num_threads,
98 |             )
99 |         else:
100 |             video_reader = decord.VideoReader(videoname, num_threads=num_threads)
101 |         #print(video_reader)
102 |         return video_reader
103 |
104 | def video_loader(root, vid, ext, second, end_second,
105 |                  chunk_len=300, fps=-1, clip_length=32,
106 |                  threads=1,
107 |                  fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
108 |                  fast_rcc=False, rcc_params=(224, ),
109 |                  jitter=False,use_crop=False):
110 |     # assert fps > 0, 'fps should be greater than 0'
111 |
112 |     if chunk_len == -1:
113 |
114 |         if root == '':
115 |             vr = get_video_reader(
116 |                 '{}.{}'.format(vid, ext),
117 |                 num_threads=threads,
118 |                 fast_rrc=fast_rrc, rrc_params=rrc_params,
119 |                 fast_rcc=fast_rcc, rcc_params=rcc_params,
120 |             )
121 |         else:
122 |             vr = get_video_reader(
123 |                 osp.join(root, '{}.{}'.format(vid, ext)),
124 |                 num_threads=threads,
125 |                 fast_rrc=fast_rrc, rrc_params=rrc_params,
126 |                 fast_rcc=fast_rcc, rcc_params=rcc_params,
127 |             )
128 |         fps = vr.get_avg_fps() if fps == -1 else fps
129 |
130 |         ### howto_crop
131 |
132 |         if use_crop:
133 |             end_second = len(vr) / fps
134 |             second = 0
135 |         else:
136 |             end_second = min(end_second, len(vr) / fps)
137 |         if end_second == 0:
138 |             end_second = len(vr) / fps
139 |
140 |         # calculate frame_ids
141 |         frame_offset = int(np.round(second * fps))
142 |         total_duration = max(int((end_second - second) * fps), clip_length)
143 |
144 |         frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
145 |
146 |         # load frames
147 |         assert max(frame_ids) < len(vr)
148 |         try:
149 |
150 |             frames = vr.get_batch(frame_ids).asnumpy()
151 |         except decord.DECORDError as error:
152 |             print(error)
153 |             frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
154 |
155 |         return torch.from_numpy(frames.astype(np.float32))
156 |     else:
157 |         assert fps > 0, 'fps should be greater than 0'
158 |
159 |         ## test broader
160 |         # if end_second - second < 1:
161 |         #     if second >= 1:
162 |         #         second -= 1
163 |         #         end_second += 1
164 |         #     else:
165 |         #         end_second += 1.5
166 |
167 |         ## sanity check, for those who have start >= end ##
168 |         end_second = max(end_second, second + 1)
169 |
170 |         chunk_start = int(second) // chunk_len * chunk_len
171 |         chunk_end = int(end_second) // chunk_len * chunk_len
172 |
173 |         # print(f'Vid={vid}, begin_sec={second}, end_sec={end_second}, \t, st_frame={int(np.round(second * fps))}, ed_frame={int(np.round(end_second * fps))}')
174 |         # calculate frame_ids
175 |         frame_ids = get_frame_ids(
176 |             int(np.round(second * fps)),
177 |             int(np.round(end_second * fps)),
178 |             num_segments=clip_length, jitter=jitter
179 |         )
180 |         # print(f'Frames: {frame_ids}')
181 |
182 |         all_frames = []
183 |         total_frames = []
184 |         # allocate absolute frame-ids into the relative ones
185 |         for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
186 |             # print(f'Chunk: {chunk}, \t, Rel_frame_ids={rel_frame_ids}')
187 |             vr = get_video_reader(
188 |                 # osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk, ext)),
189 |                 osp.join(root, vid, '{}.{}'.format(str(chunk // chunk_len).zfill(4), ext)),
190 |                 num_threads=threads,
191 |                 fast_rrc=fast_rrc, rrc_params=rrc_params,
192 |                 fast_rcc=fast_rcc, rcc_params=rcc_params,
193 |             )
194 |             rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
195 |             # rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in rel_frame_ids]
196 |             rel_frame_ids = [min(len(vr) - 1, int(frame_id - chunk * fps)) for frame_id in rel_frame_ids]
197 |
198 |             try:
199 |                 frames = vr.get_batch(rel_frame_ids).asnumpy()
200 |             except decord.DECORDError as error:
201 |                 # print(error)
202 |                 frames = vr.get_batch([0] * len(rel_frame_ids)).asnumpy()
203 |             except IndexError:
204 |                 print(root, vid, str(chunk // chunk_len).zfill(4), second, end_second)
205 |                 print(len(vr), rel_frame_ids)
206 |
207 |             # try:
208 |             #     sub_frames = vr.get_batch([temp for temp in range(len(vr))]).asnumpy()
209 |             #     total_frames.append(sub_frames)
210 |             except:
211 |                 pass
212 |             all_frames.append(frames)
213 |             if sum(map(lambda x: x.shape[0], all_frames)) == clip_length:
214 |                 break
215 |
216 |         res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
217 |
218 |         # if res.shape[0] != clip_length and (len(total_frames) > 0):
219 |         #     all_frames = []
220 |         #     res_new = np.concatenate(total_frames, axis=0).astype(np.float32) # n h w c
221 |         #     indexs = np.linspace(0,len(res_new)-1,clip_length,dtype=int)
222 |         #     for index in indexs:
223 |         #         all_frames.append(res_new[index:index+1])
224 |         #     res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
225 |         #     print(f'total: {len(res_new)}, indexs{indexs}, ')
226 |         assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
227 |         return res
228 |
229 |
230 |
231 | def video_loader_by_frames(root, vid, frame_ids):
232 |     '''
233 |     args:
234 |         root: root directory of the video
235 |         vid: the unique vid of the video, e.g. hello.mp4
236 |         frame_ids: the sampled frame indices
237 |     return:
238 |         frames: torch tensor with shape: [T, H, W, C]
239 |     '''
240 |     vr = get_video_reader(osp.join(root, vid), num_threads=1, fast_rrc=False, rrc_params=None, fast_rcc=False, rcc_params=None)  # get_vr() is commented out above
241 |     try:
242 |         frames = vr.get_batch(frame_ids).asnumpy()
243 |         frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
244 |     except (IndexError, decord.DECORDError) as error:
245 |         print(error)
246 |         print("Erroneous video: ", vid)
247 |         frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
248 |     return torch.stack(frames, dim=0)
249 |
250 | def video_loader_by_timestamp(root, vid, start_timestamp=0, end_timestamp=0,
251 |                               clip_length=4, is_training=False, threads=1,
252 |                               fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
253 |                               fast_rcc=False, rcc_params=(224, ),):
254 |     '''
255 |     args:
256 |         root: root directory of the video
257 |         vid: the unique vid of the video, e.g. hello.mp4
258 |         start_timestamp: the start second of the clip/video
259 |         end_timestamp: the end second of the clip/video
260 |         clip_length: the number of frames to be sampled
261 |         is_training: whether it is training, jitter=True/False for train/test
262 |     return:
263 |         frames: torch tensor with shape: [T, H, W, C]
264 |     '''
265 |     vr = get_video_reader(osp.join(root, vid),
266 |                           num_threads=threads,
267 |                           fast_rrc=fast_rrc, rrc_params=rrc_params,
268 |                           fast_rcc=fast_rcc, rcc_params=rcc_params,
269 |                           )
270 |     fps = vr.get_avg_fps()
271 |
272 |     start_frame = int(np.round(fps * start_timestamp)) if start_timestamp else 0
273 |     end_frame = int(np.ceil(fps * end_timestamp)) if end_timestamp else len(vr) - 1
274 |     end_frame = min(end_frame, len(vr) - 1)
275 |
276 |     frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
277 |
278 |     try:
279 |         frames = vr.get_batch(frame_ids).asnumpy()
280 |         frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
281 |     except (IndexError, decord.DECORDError) as error:
282 |         print(error)
283 |         print("Erroneous video: ", vid, start_timestamp, end_timestamp, start_frame, end_frame, fps, frame_ids, len(vr) - 1)
284 |         frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
285 |
286 |     return torch.stack(frames, dim=0)
287 |
288 |
289 |
290 | def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
291 |     '''
292 |     args:
293 |         start_frame: the first frame index
294 |         end_frame: the last frame index
295 |         num_segments: number of frames to be sampled
296 |         jitter: if True, sample a random frame within each segment; if False, take the segment center
297 |     return:
298 |         seq: a list of the sampled frame indices
299 |     '''
300 |     assert start_frame <= end_frame
301 |     seg_size = float(end_frame - start_frame - 1) / num_segments
302 |     seq = []
303 |     for i in range(num_segments):
304 |         start = int(np.round(seg_size * i) + start_frame)
305 |         end = int(np.round(seg_size * (i + 1)) + start_frame)
306 |
307 |         ### added here to avoid out-of-boundary of frame_id, as np.random.randint ###
308 |         start = min(start, end_frame-1)
309 |         end = min(end, end_frame)
310 |
311 |         if jitter:
312 |             frame_id = np.random.randint(low=start, high=(end + 1))
313 |         else:
314 |             frame_id = (start + end) // 2
315 |
316 |         seq.append(frame_id)
317 |     return seq
318 |
319 | def generate_label_map(dataset, metapath):
320 |     if dataset == 'ek100_cls':
321 |         print("Preprocess ek100 action label space")
322 |         vn_list = []
323 |         mapping_vn2narration = {}
324 |         for f in [
325 |             f'{metapath}epic-kitchens-100-annotations/EPIC_100_train.csv',
326 |             f'{metapath}epic-kitchens-100-annotations/EPIC_100_validation.csv',
327 |         ]:
328 |             csv_reader = csv.reader(open(f))
329 |             _ = next(csv_reader)  # skip the header
330 |             for row in csv_reader:
331 |                 vn = '{}:{}'.format(int(row[10]), int(row[12]))
332 |                 narration = row[8]
333 |                 if vn not in vn_list:
334 |                     vn_list.append(vn)
335 |                 if vn not in mapping_vn2narration:
336 |                     mapping_vn2narration[vn] = [narration]
337 |                 else:
338 |                     mapping_vn2narration[vn].append(narration)
339 |                 # mapping_vn2narration[vn] = [narration]
340 |         vn_list = sorted(vn_list)
341 |         print('# of action= {}'.format(len(vn_list)))
342 |         mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
343 |         labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
344 |         print(labels[:5])
345 |     elif dataset == 'charades_ego':
346 |         print("=> preprocessing charades_ego action label space")
347 |         vn_list = []
348 |         labels = []
349 |         with open(f'{metapath}Charades_v1_classes.txt') as f:
350 |             csv_reader = csv.reader(f)
351 |             for row in csv_reader:
352 |                 vn = row[0][:4]
353 |                 vn_list.append(vn)
354 |                 narration = row[0][5:]
355 |                 labels.append(narration)
356 |         mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
357 |         print(labels[:5])
358 |     elif dataset == 'egtea':
359 |         print("=> preprocessing egtea action label space")
360 |         labels = []
361 |         with open(f'{metapath}action_idx.txt') as f:
362 |             for row in f:
363 |                 row = row.strip()
364 |                 narration = ' '.join(row.split(' ')[:-1])
365 |                 labels.append(narration.replace('_', ' ').lower())
366 |                 # labels.append(narration)
367 |         mapping_vn2act = {label: i for i, label in enumerate(labels)}
368 |         print(len(labels), labels[:5])
369 |     else:
370 |         raise NotImplementedError
371 |     return labels, mapping_vn2act
372 |
373 |
--------------------------------------------------------------------------------
/dataset/egodataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import glob
3 | import os.path as osp
4 | import pickle
5 | import random
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | import os
10 |
11 | import decord
12 | from decord import cpu
13 | import io
14 | from ipdb import set_trace
15 | from .data_utils import video_loader
16 | from petrel_client.client import Client
17 | import ast
18 | import clip
19 | client = Client()
20 |
21 | class EgoExoDataset(torch.utils.data.Dataset):
22 |     def __init__(self, cfg, transform=None, is_training=True, tokenizer=None, crop_size=224,
23 |                  subsample_stride=None):
24 |         self.cfg = cfg
25 |         self.dataset = cfg.dataset
26 |         self.ego4d_root = cfg.ego4d_root
27 |         self.ego4d_metadata = cfg.ego4d_metadata
28 |         self.ego4d_chunk_len = cfg.ego4d_video_chunk_len
29 |         self.ego4d_fps = cfg.ego4d_fps
30 |
31 |         self.howto_root = cfg.howto_root
32 |         self.howto_metadata = cfg.howto_metadata
33 |         self.howto_chunk_len = cfg.howto_video_chunk_len
34 |         self.howto_fps = cfg.howto_fps
35 |
36 |         self.is_trimmed = cfg.is_trimmed
37 |         ### hardcode this for now ###
38 |         self.narration_selection = 'random'
39 |
40 |         if self.dataset == 'ego4d':
41 |             self.samples = pd.read_csv(self.ego4d_metadata)
42 |             if cfg.ego4d_metadata_aux is not None:
43 |                 self.aux_samples = pd.read_csv(cfg.ego4d_metadata_aux)
44 |                 self.samples = pd.concat([self.samples, self.aux_samples])
45 |
46 |         elif self.dataset == 'htego':
47 |             self.samples = pd.read_csv(self.howto_metadata)
48 |         elif self.dataset == 'ego4d_htego':
49 |             self.ego4d_samples = pd.read_csv(self.ego4d_metadata)
50 |             if cfg.ego4d_metadata_aux is not None:
51 |                 self.aux_samples = pd.read_csv(cfg.ego4d_metadata_aux)
52 |                 self.ego4d_samples = pd.concat([self.ego4d_samples, self.aux_samples])
53 |
54 |             self.htego_samples = pd.read_csv(self.howto_metadata)
55 |             self.samples = pd.concat([self.ego4d_samples, self.htego_samples])
56 |         else:
57 |             raise NotImplementedError
58 |         print(len(self.samples))
59 |         self.full_samples = self.samples.copy()
60 |         if isinstance(subsample_stride, int):
61 |             self.samples = self.samples[::subsample_stride]
62 |
63 |
64 |         self.transform = transform
65 |         self.is_training = is_training
66 |         self.tokenizer = tokenizer
67 |         self.clip_length = cfg.clip_length
68 |         self.clip_stride = cfg.clip_stride
69 |         self.threads = cfg.decode_threads
70 |         self.context_length = cfg.context_length
71 |         print(f'sentence length {self.context_length}')
72 |         self.multiview = cfg.multiview
73 |
74 |         self.fast_rrc = cfg.fused_decode_crop
75 |         self.rrc_params = (crop_size, (0.5, 1.0))
76 |
77 |
78 |     def __len__(self):
79 |         return len(self.samples)
80 |
81 |     def process_text(self, narration):
82 |         ### this is a list of narrations ###
83 |         if narration[0] == '[' and narration[-1] == ']':
84 |             narration = ast.literal_eval(narration)
85 |             if self.narration_selection == 'random':
86 |                 narration = random.choice(narration)
87 |             elif self.narration_selection == 'concat':
88 |                 narration = '. 
'.join(narration) 89 | else: 90 | raise NotImplementedError 91 | 92 | return narration 93 | 94 | def __getitem__(self, i): 95 | try: 96 | ### get indicator ### 97 | curr = self.samples.iloc[i] 98 | curr_dataset = curr['dataset'] if 'dataset' in curr else 'howto_ego' 99 | exo_vid_path = '' 100 | #print(curr['video_id'],curr_dataset) 101 | ### get data ### 102 | 103 | if curr_dataset == 'ego4d': 104 | 105 | vid, start_second, end_second, narration = curr['video_id'], curr['start_second'], curr['end_second'], curr['text'] 106 | # print(f'Getting ego video {vid} from {start_second} to {end_second}') 107 | 108 | frames = video_loader(self.ego4d_root, vid, 'mp4', start_second, end_second, 109 | chunk_len=self.ego4d_chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.ego4d_fps, 110 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 111 | 112 | 113 | narration = self.process_text(narration) 114 | frames_slow = frames 115 | exo_frames = torch.zeros_like(frames) 116 | 117 | else: 118 | vid = vid_path = curr['video_id'] 119 | start_second, end_second, narration = curr['start_second'], curr['end_second'], curr['text'] 120 | uid = curr['uid'] if 'uid' in curr else '{}_{}'.format(vid, start_second) 121 | 122 | frames = video_loader(self.howto_root, vid_path, 'mp4', start_second, end_second, 123 | chunk_len=self.howto_chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.howto_fps, 124 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 125 | frames_slow = frames 126 | 127 | raw_caption = narration 128 | 129 | if self.transform is not None: 130 | frames = frames.float() / 255.0 131 | frames = self.transform(frames.permute(0, 3, 1, 2)) 132 | frames_slow = self.transform(frames_slow.permute(0, 3, 1, 2)) 133 | 134 | if self.tokenizer is not None: 135 | narration = narration.replace('\n','') 136 | caption = self.tokenizer(narration) 137 | else: 138 | narration = narration.replace('\n','') 139 
| caption = clip.tokenize(narration,context_length=77, truncate=True) 140 | return frames, frames_slow,caption 141 | 142 | except Exception as e: 143 | print(f'Error with sample {i}: {exo_vid_path} dataset:{curr_dataset} error {e}') 144 | ids = np.random.randint(0, len(self.samples)) 145 | return self.__getitem__(ids) 146 | -------------------------------------------------------------------------------- /dataset/ek100dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import csv 3 | import glob 4 | import json 5 | import numpy as np 6 | import os.path as osp 7 | import pickle 8 | import random 9 | 10 | import decord 11 | import pandas as pd 12 | import torch 13 | from ipdb import set_trace 14 | import cv2 15 | import io,os 16 | 17 | from nltk.stem import WordNetLemmatizer 18 | from .data_utils import datetime2sec, get_frame_ids 19 | from .data_utils import video_loader_by_frames, video_loader_by_timestamp, video_loader 20 | from .data_utils import generate_label_map 21 | from petrel_client.client import Client 22 | import clip 23 | 24 | class EK100Dataset(torch.utils.data.Dataset): 25 | def __init__(self, config, transform=None, is_training=False, tokenizer=None, crop_size=224): 26 | ### common setups ### 27 | self.config = config 28 | self.root = config.root 29 | self.metadata = config.metadata 30 | self.clip_length = config.clip_length 31 | self.clip_stride = config.clip_stride 32 | ### maybe customized ### 33 | self.transform = transform 34 | self.is_training = is_training 35 | self.tokenizer = tokenizer 36 | 37 | self.chunk_len = config.video_chunk_len 38 | self.fps = config.fps 39 | self.threads = config.decode_threads 40 | 41 | if is_training: 42 | self.fast_rrc = config.fused_decode_crop 43 | self.rrc_params = (crop_size, (0.5, 1.0)) 44 | else: 45 | self.fast_rcc = config.fused_decode_crop 46 | self.rcc_params = (crop_size,) 47 | 48 | self.samples = [] 49 | with open(self.metadata) as f: 50 | csv_reader = 
csv.reader(f) 51 | _ = next(csv_reader) # skip the header 52 | for row in csv_reader: 53 | pid, vid = row[1:3] 54 | # start_frame, end_frame = int(row[6]), int(row[7]) 55 | # Deprecated: some videos might have fps mismatch issue 56 | start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5]) 57 | narration = row[8] 58 | verb, noun = int(row[10]), int(row[12]) 59 | 60 | vid_path = '{}.mp4'.format(vid) 61 | self.samples.append((vid_path, start_timestamp, end_timestamp, narration, verb, noun)) 62 | 63 | # if self.dataset == 'ek100_mir': 64 | self.metadata_sentence = pd.read_csv(self.metadata[:self.metadata.index('.csv')] + '_sentence.csv') 65 | if 'train' in self.metadata: 66 | self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(self.metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb')) 67 | elif 'test' in self.metadata: 68 | self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(self.metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb')) 69 | else: 70 | raise ValueError('{} should contain either "train" or "test"!'.format(self.metadata)) 71 | self.relevancy = .1 72 | 73 | print(self.threads) 74 | def __len__(self): 75 | return len(self.samples) 76 | 77 | def get_raw_item(self, i): 78 | vid_path, start_timestamp, end_timestamp, narration, verb, noun = self.samples[i] 79 | # frames = video_loader_by_timestamp(self.root, vid_path, 80 | # start_timestamp=start_frame, end_timestamp=end_frame, 81 | # clip_length=self.clip_length, is_training=self.is_training, 82 | # threads=self.threads, fast_rcc=self.fast_rcc, rcc_params=self.rcc_params 83 | # ) 84 | 85 | if self.is_training: 86 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 87 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 88 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 89 | frames_slow = 
video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 90 | chunk_len=self.chunk_len, clip_length=4, threads=self.threads, fps=self.fps, 91 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 92 | else: 93 | while True: 94 | try: 95 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 96 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 97 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 98 | frames_slow = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 99 | chunk_len=self.chunk_len, clip_length=4, threads=self.threads, fps=self.fps, 100 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 101 | break 102 | except Exception: # retry on decode errors without swallowing KeyboardInterrupt 103 | continue 104 | 105 | if self.transform is not None: 106 | frames = frames.float() / 255.0 107 | frames = self.transform(frames.permute(0, 3, 1, 2)) 108 | frames_slow = self.transform(frames_slow.permute(0, 3, 1, 2)) 109 | 110 | if self.is_training: 111 | positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist() 112 | if positive_list != []: 113 | pos = random.sample(positive_list, min(len(positive_list), 1))[0] 114 | if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]: 115 | return frames, frames_slow, self.metadata_sentence.iloc[pos, 1], self.relevancy_mat[i][pos] 116 | else: 117 | return frames, frames_slow,narration, 1 118 | 119 | 120 | def __getitem__(self, i): 121 | ### for record info only ### 122 | vid_path, start_timestamp, end_timestamp, narration, verb, noun = self.samples[i] 123 | uid = vid_path 124 | raw_caption = narration 125 | 126 | frames, frames_slow,narration, relevancy = self.get_raw_item(i) 127 | 128 | #### this is for ek100_cls ### 129 | # if self.config.dataset == 'ek100_cls': 130 | # return frames, '{}:{}'.format(verb, noun) 131 |
132 | #### this is for ek100_mir ### 133 | caption = clip.tokenize(narration,context_length=77, truncate=True) 134 | 135 | return frames,frames_slow, caption,relevancy 136 | 137 | 138 | 139 | class EK100Dataset_CLS(torch.utils.data.Dataset): 140 | def __init__(self, config, transform=None, is_training=False, tokenizer=None, crop_size=224,use_bert=False): 141 | ### common setups ### 142 | self.root = config.root 143 | self.metadata = config.metadata 144 | self.clip_length = config.clip_length 145 | self.clip_stride = config.clip_stride 146 | 147 | ### maybe customized ### 148 | self.transform = transform 149 | self.is_training = is_training 150 | self.use_bert = use_bert 151 | self.tokenizer = tokenizer 152 | 153 | self.chunk_len = config.video_chunk_len 154 | self.fps = config.fps 155 | self.threads = config.decode_threads 156 | 157 | if is_training: 158 | self.fast_rrc = config.fused_decode_crop 159 | self.rrc_params = (crop_size, (0.5, 1.0)) 160 | else: 161 | self.fast_rcc = config.fused_decode_crop 162 | self.rcc_params = (crop_size,) 163 | 164 | 165 | self.samples = [] 166 | with open(self.metadata) as f: 167 | csv_reader = csv.reader(f) 168 | _ = next(csv_reader) # skip the header 169 | for row in csv_reader: 170 | pid, vid = row[1:3] 171 | # start_frame, end_frame = int(row[6]), int(row[7]) 172 | # Deprecated: some videos might have fps mismatch issue 173 | start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5]) 174 | narration = row[8] 175 | verb, noun = int(row[10]), int(row[12]) 176 | 177 | vid_path = '{}.mp4'.format(vid) 178 | self.samples.append((vid_path, start_timestamp, end_timestamp, narration, verb, noun)) 179 | 180 | self.labels, self.label_mapping = generate_label_map('ek100_cls', config.metapath) 181 | 182 | def __len__(self): 183 | return len(self.samples) 184 | 185 | def get_raw_item(self, i): 186 | # vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] 187 | # frames = 
video_loader_by_timestamp(self.root, vid_path, 188 | # start_timestamp=start_frame, end_timestamp=end_frame, 189 | # clip_length=self.clip_length, is_training=self.is_training, 190 | # threads=self.threads, fast_rcc=self.fast_rcc, rcc_params=self.rcc_params 191 | # ) 192 | vid_path, start_timestamp, end_timestamp, narration, verb, noun = self.samples[i] 193 | 194 | if self.is_training: 195 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 196 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 197 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 198 | # frames_slow = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 199 | # chunk_len=self.chunk_len, clip_length=4, threads=self.threads, fps=self.fps, 200 | # fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 201 | frames_slow = frames 202 | else: 203 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 204 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 205 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 206 | frames_slow = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 207 | chunk_len=self.chunk_len, clip_length=4, threads=self.threads, fps=self.fps, 208 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 209 | return frames,frames_slow, f'{verb}:{noun}', narration 210 | 211 | def __getitem__(self, i): 212 | ### for record info only ### 213 | 214 | vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] 215 | 216 | frames, frames_slow, label, narration = self.get_raw_item(i) 217 | raw_caption = narration 218 | 219 | frames = self.transform(frames) if self.transform is not None else frames # keep the raw frames when no transform is given 220 | 221 | if
isinstance(label, list): 222 | # multi-label case 223 | res_array = np.zeros(len(self.label_mapping)) 224 | for lbl in label: 225 | res_array[self.label_mapping[lbl]] = 1. 226 | label = res_array 227 | else: 228 | raw_label = label 229 | label = self.label_mapping[label] 230 | 231 | return frames, frames_slow, label 232 | 233 | 234 | if __name__ == "__main__": 235 | import os 236 | import time 237 | from pathlib import Path 238 | 239 | import torch 240 | import torch.backends.cudnn as cudnn 241 | from torch.utils.tensorboard import SummaryWriter 242 | import torchvision.transforms as transforms 243 | import torchvision.datasets as datasets 244 | mean, std = (0.48145466, 0.4578275, 0.40821073),(0.26862954, 0.26130258, 0.27577711) 245 | transform_val = transforms.Compose([ 246 | transforms.Resize(256, interpolation=3), 247 | transforms.CenterCrop(224), 248 | transforms.Normalize(mean=mean, std=std)]) 249 | val_dataset = EK100Dataset('epic_kitchen/', transform=transform_val, is_training=False, tokenizer=None, crop_size=224) 250 | 251 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | from einops import rearrange 20 | import numpy as np 21 | import pandas as pd 22 | from util.meter import * 23 | import torch 24 | import torch.backends.cudnn as cudnn 25 | from torch.utils.tensorboard import SummaryWriter 26 | import torchvision.transforms as transforms 27 | import torchvision.datasets as datasets 28 | 29 | def build_transform(model_name, mode): 30 | 31 | mean, std = (0.48145466, 0.4578275, 0.40821073),(0.26862954, 0.26130258, 0.27577711) 32 | input_size = 336 if model_name.endswith("_336PX") else 224 33 | # simple augmentation 34 | if mode == 'train': 35 | transform = transforms.Compose([ 36 | transforms.RandomResizedCrop(input_size, scale=(0.5, 1.0), interpolation=3), # 3 is bicubic 37 | transforms.RandomHorizontalFlip(), 38 | transforms.Normalize(mean=mean, std=std)]) 39 | else: 40 | transform = transforms.Compose([ 41 | transforms.Resize(input_size, interpolation=3), # match the training resolution 42 | transforms.CenterCrop(input_size), 43 | transforms.Normalize(mean=mean, std=std)]) 44 | return transform 45 | 46 | def train_one_epoch(model: torch.nn.Module, 47 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 48 | device: torch.device, epoch: int, scaler, 49 | log_writer=None, 50 | args=None,criterion=None): 51 | model.train(True) 52 | metric_logger = misc.MetricLogger(delimiter=" ") 53 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 54 | header = 'Epoch: [{}]'.format(epoch) 55 | print_freq = 20 56 | 57 | accum_iter = args.accum_iter 58 | 59 | optimizer.zero_grad() 60 | 61 | if log_writer is not None: 62 |
print('log_dir: {}'.format(log_writer.log_dir)) 63 | 64 | for data_iter_step, inputs in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 65 | 66 | # per-iteration (instead of per-epoch) lr schedule; currently disabled, so accum_iter is unused 67 | # if data_iter_step % accum_iter == 0: 68 | # lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 69 | 70 | optimizer.zero_grad() 71 | inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] 72 | 73 | with torch.cuda.amp.autocast(): 74 | inputs[0] = inputs[0].permute(0, 2, 1, 3, 4) # [b t c h w -> b c t h w] 75 | inputs[1] = inputs[1].permute(0, 2, 1, 3, 4) # [b t c h w -> b c t h w] 76 | image_features, text_features, logit_scale = model(*inputs) 77 | loss_dict = criterion(image_features, text_features, logit_scale) 78 | loss = loss_dict['loss'] 79 | 80 | scaler.scale(loss).backward() 81 | metric_logger.update(loss=loss.item()) 82 | lr = optimizer.param_groups[0]["lr"] 83 | metric_logger.update(lr=lr) 84 | 85 | if log_writer is not None and (data_iter_step + 1) % 10 == 0: 86 | """ We use epoch_1000x as the x-axis in tensorboard. 87 | This calibrates different curves when batch size changes.
88 | """ 89 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 90 | log_writer.add_scalar('train_loss', loss.item(), epoch_1000x) 91 | log_writer.add_scalar('lr', lr, epoch_1000x) 92 | scaler.step(optimizer) 93 | scaler.update() 94 | model.zero_grad(set_to_none=True) 95 | # gather the stats from all processes 96 | metric_logger.synchronize_between_processes() 97 | print("Averaged stats:", metric_logger) 98 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 99 | 100 | def validate_ek100_mir_zeroshot(val_loader, model, criterion, args, config,split): 101 | from torch import nn 102 | import torch.nn.functional as F 103 | import torch 104 | # switch to eval mode 105 | model.eval() 106 | 107 | all_video_embed = [[] for _ in range(args.world_size)] 108 | all_text_embed = [[] for _ in range(args.world_size)] 109 | total_num = 0 110 | with torch.cuda.amp.autocast(enabled=not config.train.disable_amp): 111 | with torch.no_grad(): 112 | for i, inputs in enumerate(val_loader): 113 | 114 | inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] 115 | _ = inputs.pop() # drop the "relevancy" value returned by the EK100 loader; it is only needed for ek100_mir training 116 | 117 | inputs[0] = inputs[0].permute(0, 2, 1, 3, 4) # [b t c h w -> b c t h w] 118 | inputs[1] = inputs[1].permute(0, 2, 1, 3, 4) # [b t c h w -> b c t h w] 119 | 120 | image_features, text_features, logit_scale = model(*inputs) 121 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(args.world_size)] 122 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(args.world_size)] 123 | torch.distributed.all_gather(gathered_image_features, image_features) 124 | torch.distributed.all_gather(gathered_text_features, text_features) 125 | for j in range(args.world_size): 126 | all_video_embed[j].append(gathered_image_features[j].detach().cpu()) 127 | all_text_embed[j].append(gathered_text_features[j].detach().cpu()) 128 | 129 | total_num
+= image_features.shape[0] * args.world_size 130 | if i % 10 == 0: 131 | print(f'step {i}/{len(val_loader)}') 132 | for j in range(args.world_size): 133 | all_video_embed[j] = torch.cat(all_video_embed[j], dim=0).numpy() 134 | all_text_embed[j] = torch.cat(all_text_embed[j], dim=0).numpy() 135 | all_text_embed_reorg, all_video_embed_reorg = [], [] # restore dataset order: sample i was processed by rank i % world_size 136 | for i in range(total_num): 137 | all_video_embed_reorg.append(all_video_embed[i % args.world_size][i // args.world_size]) 138 | all_text_embed_reorg.append(all_text_embed[i % args.world_size][i // args.world_size]) 139 | all_text_embed = np.vstack(all_text_embed_reorg) 140 | all_video_embed = np.vstack(all_video_embed_reorg) 141 | all_text_embed = all_text_embed[:9668, :] # 9668 = number of EK100 MIR test clips; drops padded duplicates from distributed sampling 142 | all_video_embed = all_video_embed[:9668, :] 143 | similarity_matrix = np.matmul(all_video_embed, all_text_embed.T) 144 | similarity_matrix = (similarity_matrix + 1) / 2 145 | 146 | video_id = pd.read_csv(config.test.ek100_mir.metadata).values[:, 0] 147 | text_id = pd.read_csv(config.test.ek100_mir.metadata.replace('test', 'test_sentence')).values[:, 0] 148 | indexes = [video_id.tolist().index(elem) for elem in text_id] 149 | similarity_matrix = similarity_matrix[:, indexes] 150 | # similarity_matrix = torch.from_numpy(similarity_matrix) 151 | # similarity_matrix = similarity_matrix * F.softmax(similarity_matrix, dim=0)*len(similarity_matrix) 152 | # similarity_matrix = similarity_matrix.numpy() 153 | print(similarity_matrix.shape,text_id.shape,video_id.shape) 154 | rel_matrix = pd.read_pickle(config.test.ek100_mir.relevancy_path) 155 | vis_map, txt_map, avg_map = get_mAP(similarity_matrix, rel_matrix) 156 | print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map)) 157 | vis_nDCG, txt_nDCG, avg_nDCG = get_nDCG(similarity_matrix, rel_matrix) 158 | print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, avg_nDCG)) 159 | return {'vis_map': vis_map, 'txt_map': txt_map, 'avg_map': avg_map, 160 |
'vis_ndcg': vis_nDCG, 'txt_ndcg': txt_nDCG, 'avg_ndcg': avg_nDCG} -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hod 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - bzip2=1.0.8=h5eee18b_5 9 | - ca-certificates=2024.3.11=h06a4308_0 10 | - ld_impl_linux-64=2.38=h1181459_1 11 | - libffi=3.4.4=h6a678d5_0 12 | - libgcc-ng=11.2.0=h1234567_1 13 | - libgomp=11.2.0=h1234567_1 14 | - libstdcxx-ng=11.2.0=h1234567_1 15 | - libuuid=1.41.5=h5eee18b_0 16 | - ncurses=6.4=h6a678d5_0 17 | - openssl=3.0.13=h7f8727e_0 18 | - python=3.10.14=h955ad1f_0 19 | - readline=8.2=h5eee18b_0 20 | - sqlite=3.41.2=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - xz=5.4.6=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - decord==0.6.0 26 | - ffmpeg==1.4 27 | - ffmpeg-python==0.2.0 28 | - numpy==1.25.2 29 | - pillow==9.3.0 30 | -------------------------------------------------------------------------------- /evaluation/eval_egomcq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import json 3 | 4 | import time 5 | import func_timeout 6 | from func_timeout import func_set_timeout 7 | import os 8 | import torch 9 | import numpy as np 10 | import csv 11 | import glob 12 | import os.path as osp 13 | import pickle 14 | import random 15 | import numpy as np 16 | import pandas as pd 17 | import torch 18 | import os 19 | import decord 20 | from model.clip import * 21 | from util.config import get_config 22 | from dataset.data_utils import video_loader 23 | import argparse 24 | from tqdm import tqdm 25 | import decord 26 | from decord import cpu 27 | import io 28 | from ipdb import set_trace 29 | from petrel_client.client import Client 30 | import ast 31 | client = Client() 32 | 
import torchvision.transforms as transforms 33 | import torchvision.datasets as datasets 34 | 35 | 36 | chunk_sec = 600 # Each segment is up to 600s 37 | noun_dim = 582 # number of noun classes in the Ego4D taxonomy dictionary 38 | verb_dim = 118 # number of verb classes in the Ego4D taxonomy dictionary 39 | 40 | def get_args_parser(): 41 | parser = argparse.ArgumentParser('egomcq eval', add_help=False) 42 | 43 | parser.add_argument('--config_file', default='configs/config/clip_base_eval.yml', type=str,help='config file') 44 | parser.add_argument('--device', default='cuda',help='device to use for training / testing') 45 | parser.add_argument('--root', default='ego4d/videos_short320_chunked_15s/', type=str,help='root of the chunked Ego4D video clips') 46 | parser.add_argument('--metadata', default='egomcq.json', type=str,help='path to the EgoMCQ annotation file') 47 | parser.add_argument('--crop_size', default=224, type=int,help='spatial crop size of the input frames') 48 | return parser 49 | 50 | def _get_caption(sample): 51 | noun_vec = torch.zeros(noun_dim) 52 | verb_vec = torch.zeros(verb_dim) 53 | noun_idx = ast.literal_eval(sample['tag_noun']) # safer than eval() for parsing the stored index list 54 | verb_idx = ast.literal_eval(sample['tag_verb']) 55 | for i in noun_idx: 56 | noun_vec[i] = 1 57 | for i in verb_idx: 58 | verb_vec[i] = 1 59 | 60 | return sample['clip_text'], noun_vec, verb_vec 61 | 62 | def build_transform(model_name): 63 | 64 | mean, std = (0.48145466, 0.4578275, 0.40821073),(0.26862954, 0.26130258, 0.27577711) 65 | input_size = 336 if model_name.endswith("_336PX") else 224 66 | # deterministic evaluation preprocessing (no augmentation) 67 | 68 | transform = transforms.Compose([ 69 | transforms.Resize(input_size, interpolation=3), 70 | transforms.CenterCrop(input_size), 71 | transforms.Normalize(mean=mean, std=std)]) 72 | return transform 73 | 74 | def main(args): 75 | config = get_config(args) 76 | ego4d_root = args.root 77 | crop_size = args.crop_size 78 | mean = (0.48145466 * 255,0.4578275 * 255,0.40821073 * 255) 79 | std = (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255) 80 | with
open(args.metadata,'r') as f: 81 | json_data = json.load(f) 82 | print(len(json_data)) 83 | # mean, std = [0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255] 84 | import kornia as K 85 | gpu_val_transform_ls = [K.enhance.Normalize(mean=mean, std=std)] 86 | transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls) 87 | from tqdm import tqdm 88 | model_name = config.model.name 89 | 90 | if model_name == 'CLIP_VITB16': 91 | model = CLIP_VITB16( 92 | config=config.model, 93 | freeze_temperature=config.model.freeze_temperature, 94 | use_grad_checkpointing=config.model.grad_checkpointing, 95 | context_length=config.data.context_length, 96 | vocab_size=config.data.vocab_size, 97 | patch_dropout=config.model.patch_dropout, 98 | num_frames=config.data.clip_length, 99 | drop_path_rate=config.model.drop_path_rate, 100 | use_fast_conv1=config.model.use_fast_conv1, 101 | use_flash_attn=config.model.use_flash_attn, 102 | use_quick_gelu=True, 103 | project_embed_dim=config.model.project_embed_dim, 104 | pretrain_zoo=config.model.pretrain_zoo, 105 | pretrain_path=config.model.pretrain_path, 106 | ) 107 | elif model_name == 'CLIP_VITL14_336PX': 108 | model = CLIP_VITL14_336PX( 109 | config=config.model, 110 | freeze_temperature=config.model.freeze_temperature, 111 | use_grad_checkpointing=config.model.grad_checkpointing, 112 | context_length=config.data.context_length, 113 | vocab_size=config.data.vocab_size, 114 | patch_dropout=config.model.patch_dropout, 115 | num_frames=config.data.clip_length, 116 | drop_path_rate=config.model.drop_path_rate, 117 | use_fast_conv1=config.model.use_fast_conv1, 118 | use_flash_attn=config.model.use_flash_attn, 119 | use_quick_gelu=True, 120 | project_embed_dim=config.model.project_embed_dim, 121 | pretrain_zoo=config.model.pretrain_zoo, 122 | pretrain_path=config.model.pretrain_path, 123 | ) 124 | elif model_name == 'CLIP_VITL14_336PX_Slowfast': 125 | model = CLIP_VITL14_336PX_Slowfast( 126 | config=config.model, 127 
| freeze_temperature=config.model.freeze_temperature, 128 | use_grad_checkpointing=config.model.grad_checkpointing, 129 | context_length=config.data.context_length, 130 | vocab_size=config.data.vocab_size, 131 | patch_dropout=config.model.patch_dropout, 132 | num_frames=config.data.clip_length, 133 | drop_path_rate=config.model.drop_path_rate, 134 | use_fast_conv1=config.model.use_fast_conv1, 135 | use_flash_attn=config.model.use_flash_attn, 136 | use_quick_gelu=True, 137 | project_embed_dim=config.model.project_embed_dim, 138 | pretrain_zoo=config.model.pretrain_zoo, 139 | pretrain_path=config.model.pretrain_path, 140 | ) 141 | elif model_name == 'CLIP_VITB16_Slowfast': 142 | model = CLIP_VITB16_Slowfast( 143 | config=config.model, 144 | freeze_temperature=config.model.freeze_temperature, 145 | use_grad_checkpointing=config.model.grad_checkpointing, 146 | context_length=config.data.context_length, 147 | vocab_size=config.data.vocab_size, 148 | patch_dropout=config.model.patch_dropout, 149 | num_frames=config.data.clip_length, 150 | drop_path_rate=config.model.drop_path_rate, 151 | use_fast_conv1=config.model.use_fast_conv1, 152 | use_flash_attn=config.model.use_flash_attn, 153 | use_quick_gelu=True, 154 | project_embed_dim=config.model.project_embed_dim, 155 | pretrain_zoo=config.model.pretrain_zoo, 156 | pretrain_path=config.model.pretrain_path, 157 | ) 158 | if config.resume: 159 | print("=> loading resume checkpoint '{}'".format(config.resume)) 160 | curr_checkpoint = torch.load(config.resume, map_location='cpu') 161 | new_ckpt = {} 162 | 163 | for key,value in curr_checkpoint['state_dict'].items(): 164 | new_key = key.replace('module.','') 165 | new_ckpt[new_key] = value 166 | result = model.load_state_dict(new_ckpt) 167 | print(result) 168 | model = model.to('cuda') 169 | 170 | model = model.eval().cuda().half() 171 | 172 | transform = build_transform(model_name) 173 | total0 = 0 174 | total1 = 0 175 | ac0 = 0 176 | ac1 = 0 177 | 178 | def sim_matrix(a, b, 
eps=1e-8): 179 | """ 180 | added eps for numerical stability 181 | """ 182 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 183 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 184 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 185 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) 186 | return sim_mt 187 | 188 | for key,item in tqdm(json_data.items()): 189 | 190 | itemMCQ = json_data[key] 191 | 192 | answerIndex = itemMCQ['answer'] 193 | sampleQuery = itemMCQ['query'] 194 | 195 | textQuery, _, _ = _get_caption(sampleQuery) 196 | 197 | sampleOptions = itemMCQ['choices'] 198 | num_options = len(sampleOptions) 199 | textOptions = [] 200 | videoOptions = torch.zeros([num_options, 16,crop_size,crop_size,3]) 201 | 202 | for id, option in enumerate(sampleOptions): 203 | si = sampleOptions[option] 204 | video_id = si['video_uid'] 205 | start = float(si['clip_start']) 206 | end = float(si['clip_end']) 207 | 208 | caption, _, _ = _get_caption(si) 209 | textOptions.append(caption) 210 | 211 | frames = video_loader(ego4d_root, video_id, 'mp4', start, end, 212 | chunk_len=15, clip_length=16, threads=1, fps=30, 213 | fast_rcc=True, rcc_params=(crop_size,), jitter=False) 214 | frames = frames.float() / 255.0 215 | frames = transform(frames.permute(0, 3, 1, 2)) 216 | frames = frames.permute(0, 2, 3, 1) 217 | videoOptions[id] = frames 218 | 219 | type = itemMCQ['types'] 220 | 221 | videoOptions = rearrange(videoOptions,'b t h w c->b c t h w') 222 | 223 | videoOptions = videoOptions.to(torch.float16).to('cuda') 224 | 225 | data = {'video': videoOptions, 'text': textQuery, 'text_ops':textOptions, 'correct': answerIndex, 'type': type} 226 | 227 | data['text'] = data['text'] 228 | 229 | text = clip.tokenize(data['text'],truncate=True).to('cuda') 230 | 231 | text_embed = model.encode_text(text) 232 | text_embed = F.normalize(text_embed, dim=-1) 233 | 234 | vid_embed = model.encode_image(videoOptions)[0] 235 | vid_embed = F.normalize(vid_embed,dim=-1) 236 
| # text_embed = model.encode_text(text,mask) 237 | # text_embed = F.normalize(text_embed, dim=-1) 238 | 239 | data_gt = data['correct'] 240 | data_pred = sim_matrix(text_embed, vid_embed) 241 | 242 | index = torch.argmax(data_pred) 243 | data_type = data['type'] 244 | if data_type == 1: 245 | total0 += 1 246 | if index == data_gt: 247 | ac0 += 1 248 | else: 249 | total1 += 1 250 | if index == data_gt: 251 | ac1 += 1 252 | 253 | acc0 = round(ac0 / max(1,total0),2) 254 | acc1 = round(ac1 / max(1,total1),2) 255 | print(f'number of sample {total0 + total1}') 256 | print(f'inter acc: {acc0} acc0:{ac0} total:{total0}') 257 | print(f'intra acc: {acc1} acc1:{ac1} total:{total1}') 258 | print('---------------------------------------') 259 | 260 | 261 | if __name__ == '__main__': 262 | args = get_args_parser() 263 | args = args.parse_args() 264 | main(args) 265 | -------------------------------------------------------------------------------- /evaluation/eval_egtea.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | from torchvision import transforms 6 | from torchvision.transforms import InterpolationMode 7 | import torch.nn.functional as F 8 | import argparse 9 | from model.clip import * 10 | from util.config import get_config 11 | import numpy as np 12 | import os.path as osp 13 | import decord 14 | import clip 15 | 16 | def get_args_parser(): 17 | parser = argparse.ArgumentParser('EGTEA eval', add_help=False) 18 | 19 | parser.add_argument('--config_file', default='clip_base_eval.yml', type=str,help='config file') 20 | parser.add_argument('--device', default='cuda', 21 | help='device to use for training / testing') 22 | parser.add_argument('--root', default='egtea_gaze/cropped_clips', type=str,help='root of egtea video clips') 23 | parser.add_argument('--metadata', default='./egtea', type=str,help='root of egtea annotations') 24 | 
parser.add_argument('--crop_size', default=224, type=int,help='crop size (pixels) of the input frames') 25 | return parser 26 | 27 | def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True): 28 | frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid') 29 | return frame_ids.astype(int).tolist() 30 | 31 | 32 | class VideoClassyDataset(torch.utils.data.Dataset): 33 | def __init__( 34 | self, root, metadata,crop_size=224, transform=None, 35 | is_training=True, label_mapping=None, 36 | num_clips=1, 37 | clip_length=32, clip_stride=2, 38 | sparse_sample=False, 39 | is_trimmed=True, 40 | anno_dir='' 41 | ): 42 | super().__init__() 43 | 44 | metadata = metadata 45 | 46 | # mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 47 | mean = (0.48145466,0.4578275,0.40821073) 48 | std = (0.26862954, 0.26130258, 0.27577711) 49 | transform = transforms.Compose( 50 | [ 51 | T.Resize((crop_size), interpolation=InterpolationMode.BICUBIC), 52 | T.CenterCrop(crop_size), 53 | transforms.Lambda(lambda x: x.float().div(255.0)), 54 | T.Normalize(mean=mean, std=std) 55 | ] 56 | ) 57 | self.transform = transform 58 | self.is_training = is_training 59 | self.label_mapping = label_mapping 60 | self.num_clips = num_clips 61 | self.clip_length = clip_length 62 | self.clip_stride = clip_stride 63 | self.sparse_sample = sparse_sample 64 | self.anno_dir = anno_dir 65 | self.root = root 66 | vn_list = [] 67 | labels = [] 68 | for row in open(f'{metadata}/action_idx.txt'): 69 | row = row.strip() 70 | vn = int(row.split(' ')[-1]) 71 | vn_list.append(vn) 72 | narration = ' '.join(row.split(' ')[:-1]) 73 | labels.append(narration.replace('_', ' ').lower()) 74 | self.mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)} 75 | self.samples = [] 76 | with open(f'{metadata}/test_split1.txt') as f: 77 | for row in f: 78 | clip_id, action_idx = row.strip().split(' ')[:2] 79 | video_id = '-'.join(clip_id.split('-')[:3]) 80 |
vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id)) 81 | vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id)) 82 | self.samples.append((vid_relpath, 0, self.mapping_act2narration[int(action_idx)])) 83 | 84 | def __len__(self): 85 | return len(self.samples) 86 | 87 | def __getitem__(self, i): 88 | 89 | vid,_,label = self.samples[i] 90 | 91 | vr = decord.VideoReader(os.path.join(self.root,vid),num_threads=1) 92 | fps = vr.get_avg_fps() 93 | if len(vr) > 12000000: 94 | frame_ids = get_frame_ids(0,int(len(vr) * 0.75),16) 95 | frame_ids_slow = get_frame_ids(0,int(len(vr) * 0.75),4) 96 | else: 97 | frame_ids = get_frame_ids(0,int(len(vr)),16) 98 | frame_ids_slow = get_frame_ids(0,int(len(vr)),4) 99 | frames = torch.from_numpy(vr.get_batch(frame_ids).asnumpy()) 100 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 101 | frames = self.transform(frames) 102 | 103 | 104 | frames_slow = torch.from_numpy(vr.get_batch(frame_ids_slow).asnumpy()) 105 | frames_slow = frames_slow.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 106 | frames_slow = self.transform(frames_slow) 107 | 108 | return frames,frames_slow,label 109 | 110 | 111 | 112 | def main(args): 113 | config = get_config(args) 114 | 115 | root = args.root 116 | metadata = args.metadata 117 | crop_size = args.crop_size 118 | 119 | dataset = VideoClassyDataset(root,metadata,crop_size) 120 | 121 | from tqdm import tqdm 122 | model_name = config.model.name 123 | 124 | if model_name == 'CLIP_VITB16': 125 | model = CLIP_VITB16( 126 | config=config.model, 127 | freeze_temperature=config.model.freeze_temperature, 128 | use_grad_checkpointing=config.model.grad_checkpointing, 129 | context_length=config.data.context_length, 130 | vocab_size=config.data.vocab_size, 131 | patch_dropout=config.model.patch_dropout, 132 | num_frames=config.data.clip_length, 133 | drop_path_rate=config.model.drop_path_rate, 134 | use_fast_conv1=config.model.use_fast_conv1, 135 | 
use_flash_attn=config.model.use_flash_attn, 136 | use_quick_gelu=True, 137 | project_embed_dim=config.model.project_embed_dim, 138 | pretrain_zoo=config.model.pretrain_zoo, 139 | pretrain_path=config.model.pretrain_path, 140 | ) 141 | elif model_name == 'CLIP_VITL14_336PX': 142 | model = CLIP_VITL14_336PX( 143 | config=config.model, 144 | freeze_temperature=config.model.freeze_temperature, 145 | use_grad_checkpointing=config.model.grad_checkpointing, 146 | context_length=config.data.context_length, 147 | vocab_size=config.data.vocab_size, 148 | patch_dropout=config.model.patch_dropout, 149 | num_frames=config.data.clip_length, 150 | drop_path_rate=config.model.drop_path_rate, 151 | use_fast_conv1=config.model.use_fast_conv1, 152 | use_flash_attn=config.model.use_flash_attn, 153 | use_quick_gelu=True, 154 | project_embed_dim=config.model.project_embed_dim, 155 | pretrain_zoo=config.model.pretrain_zoo, 156 | pretrain_path=config.model.pretrain_path, 157 | ) 158 | elif model_name == 'CLIP_VITL14_336PX_Slowfast': 159 | model = CLIP_VITL14_336PX_Slowfast( 160 | config=config.model, 161 | freeze_temperature=config.model.freeze_temperature, 162 | use_grad_checkpointing=config.model.grad_checkpointing, 163 | context_length=config.data.context_length, 164 | vocab_size=config.data.vocab_size, 165 | patch_dropout=config.model.patch_dropout, 166 | num_frames=config.data.clip_length, 167 | drop_path_rate=config.model.drop_path_rate, 168 | use_fast_conv1=config.model.use_fast_conv1, 169 | use_flash_attn=config.model.use_flash_attn, 170 | use_quick_gelu=True, 171 | project_embed_dim=config.model.project_embed_dim, 172 | pretrain_zoo=config.model.pretrain_zoo, 173 | pretrain_path=config.model.pretrain_path, 174 | ) 175 | elif model_name == 'CLIP_VITB16_Slowfast': 176 | model = CLIP_VITB16_Slowfast( 177 | config=config.model, 178 | freeze_temperature=config.model.freeze_temperature, 179 | use_grad_checkpointing=config.model.grad_checkpointing, 180 | 
context_length=config.data.context_length, 181 | vocab_size=config.data.vocab_size, 182 | patch_dropout=config.model.patch_dropout, 183 | num_frames=config.data.clip_length, 184 | drop_path_rate=config.model.drop_path_rate, 185 | use_fast_conv1=config.model.use_fast_conv1, 186 | use_flash_attn=config.model.use_flash_attn, 187 | use_quick_gelu=True, 188 | project_embed_dim=config.model.project_embed_dim, 189 | pretrain_zoo=config.model.pretrain_zoo, 190 | pretrain_path=config.model.pretrain_path, 191 | ) 192 | 193 | model = model.to('cuda') 194 | 195 | if config.resume: 196 | print("=> loading resume checkpoint '{}'".format(config.resume)) 197 | curr_checkpoint = torch.load(config.resume, map_location='cpu') 198 | new_ckpt = {} 199 | 200 | for key,value in curr_checkpoint['state_dict'].items(): 201 | new_key = key.replace('module.','') 202 | new_ckpt[new_key] = value 203 | result = model.load_state_dict(new_ckpt) 204 | print(result) 205 | 206 | model = model.eval().cuda().half() 207 | 208 | words = [] 209 | words_origin = [] 210 | narration2act = {} 211 | for i in range(1,107): 212 | word = dataset.mapping_act2narration[i] 213 | narration2act[word] = i 214 | text = clip.tokenize(word).to('cuda') 215 | text_embed = model.encode_text(text) 216 | words.append(F.normalize(text_embed, dim=-1)) 217 | 218 | words = torch.stack(words) 219 | words = words.squeeze() 220 | 221 | ans = [] 222 | total = 0 223 | acc = 0 224 | 225 | 226 | 227 | acc_total = [0 for i in range(106)] 228 | acc_acc = [0 for i in range(106)] 229 | 230 | for i in range(len(dataset)): 231 | with torch.no_grad(): 232 | frames,frames_slow,label = dataset[i] 233 | frames = frames.to('cuda').unsqueeze(0).to(torch.float16) 234 | frames = frames.permute(0, 2, 1, 3, 4) 235 | 236 | frames_slow = frames_slow.to('cuda').unsqueeze(0).to(torch.float16) 237 | frames_slow = frames_slow.permute(0, 2, 1, 3, 4) 238 | 239 | image_embed = model.encode_image(frames)[0] 240 | image_embed = F.normalize(image_embed, dim=-1) 241 
| 242 | similarities = F.cosine_similarity(image_embed, words, dim=1) 243 | 244 | most_similar_index = torch.argmax(similarities) 245 | index2word = dataset.mapping_act2narration[most_similar_index.item() + 1] 246 | #label2word = dataset.mapping_act2narration[label] 247 | print(f"ans: {label} our: {index2word}") 248 | 249 | ans.append(most_similar_index.item()) 250 | total += 1 251 | if most_similar_index.item() + 1 == narration2act[label]: 252 | acc += 1 253 | acc_acc[narration2act[label] - 1] += 1 254 | 255 | acc_total[narration2act[label] - 1] += 1 256 | 257 | print(f'acc: {acc / total}') 258 | print('---------------------------------') 259 | 260 | mean_acc = 0 261 | for k in range(106): 262 | mean_acc += acc_acc[k] / acc_total[k] 263 | mean_acc = mean_acc / 106.0 264 | print(f'mean acc is {mean_acc}') 265 | 266 | if __name__ == '__main__': 267 | args = get_args_parser() 268 | args = args.parse_args() 269 | main(args) 270 | -------------------------------------------------------------------------------- /evaluation/eval_ekcls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | import csv 5 | import glob 6 | import json 7 | import numpy as np 8 | import os.path as osp 9 | import pickle 10 | import random 11 | 12 | import pandas as pd 13 | import torch 14 | from ipdb import set_trace 15 | import cv2 16 | import io,os 17 | 18 | import torch 19 | import os 20 | from PIL import Image 21 | import torchvision.transforms as T 22 | from torchvision import transforms 23 | from torchvision.transforms import InterpolationMode 24 | import torch.nn.functional as F 25 | 26 | import time 27 | import func_timeout 28 | from func_timeout import func_set_timeout 29 | 30 | import csv 31 | import glob 32 | import json 33 | import numpy as np 34 | import os.path as osp 35 | import pickle 36 | import random 37 | 38 | import pandas as pd 39 | import torch 40 | from decord import cpu 41 | import cv2 
42 | import io,os 43 | import argparse 44 | 45 | import decord 46 | from model.clip import * 47 | from util.config import get_config 48 | from dataset.data_utils import video_loader 49 | import torchvision.transforms as transforms 50 | import torchvision.datasets as datasets 51 | try: 52 | from petrel_client.client import Client 53 | client = Client() 54 | 55 | # Disable boto logger 56 | import logging 57 | logging.getLogger('boto3').setLevel(logging.WARNING) 58 | logging.getLogger('botocore').setLevel(logging.WARNING) 59 | logging.getLogger('nose').setLevel(logging.WARNING) 60 | except: 61 | client = None 62 | 63 | def get_args_parser(): 64 | parser = argparse.ArgumentParser('EK-CLS eval', add_help=False) 65 | 66 | parser.add_argument('--config_file', default='configs/no_decoder/clip_base_eval.yml', type=str,help='config file') 67 | parser.add_argument('--device', default='cuda', 68 | help='device to use for training / testing') 69 | parser.add_argument('--root', default='epic/epic_video_320p/', type=str,help='root of EPIC-KITCHENS-100 videos') 70 | parser.add_argument('--metadata', default='epic_kitchen/', type=str,help='root of EPIC-KITCHENS-100 annotations') 71 | parser.add_argument('--crop_size', default=224, type=int,help='crop size (pixels) of the input frames') 72 | return parser 73 | 74 | def generate_label_map(dataset, metapath): 75 | if dataset == 'ek100_cls': 76 | print("Preprocess ek100 action label space") 77 | vn_list = [] 78 | mapping_vn2narration = {} 79 | for f in [ 80 | f'{metapath}epic-kitchens-100-annotations/EPIC_100_train.csv', 81 | f'{metapath}epic-kitchens-100-annotations/EPIC_100_validation.csv', 82 | ]: 83 | csv_reader = csv.reader(open(f)) 84 | _ = next(csv_reader) # skip the header 85 | for row in csv_reader: 86 | vn = '{}:{}'.format(int(row[10]), int(row[12])) 87 | narration = row[8] 88 | if vn not in vn_list: 89 | vn_list.append(vn) 90 | if vn not in mapping_vn2narration: 91 | mapping_vn2narration[vn] = [narration] 92 | else: 93 | mapping_vn2narration[vn].append(narration) 94 | #
mapping_vn2narration[vn] = [narration] 95 | vn_list = sorted(vn_list) 96 | print('# of action= {}'.format(len(vn_list))) 97 | mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)} 98 | labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))] 99 | print(labels[:5]) 100 | elif dataset == 'charades_ego': 101 | print("=> preprocessing charades_ego action label space") 102 | vn_list = [] 103 | labels = [] 104 | with open(f'{metapath}Charades_v1_classes.txt') as f: 105 | csv_reader = csv.reader(f) 106 | for row in csv_reader: 107 | vn = row[0][:4] 108 | vn_list.append(vn) 109 | narration = row[0][5:] 110 | labels.append(narration) 111 | mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)} 112 | print(labels[:5]) 113 | elif dataset == 'egtea': 114 | print("=> preprocessing egtea action label space") 115 | labels = [] 116 | with open(f'{metapath}action_idx.txt') as f: 117 | for row in f: 118 | row = row.strip() 119 | narration = ' '.join(row.split(' ')[:-1]) 120 | labels.append(narration.replace('_', ' ').lower()) 121 | # labels.append(narration) 122 | mapping_vn2act = {label: i for i, label in enumerate(labels)} 123 | print(len(labels), labels[:5]) 124 | else: 125 | raise NotImplementedError 126 | return labels, mapping_vn2act 127 | 128 | def datetime2sec(str): 129 | hh, mm, ss = str.split(':') 130 | return int(hh) * 3600 + int(mm) * 60 + float(ss) 131 | 132 | class EK100Dataset_CLS(torch.utils.data.Dataset): 133 | def __init__(self,root,metadata,crop_size=224): 134 | ### common setups ### 135 | self.root = root 136 | self.metadata = f'{metadata}/EPIC_100_retrieval_test.csv' 137 | self.clip_length = 16 138 | self.clip_stride = 2 139 | 140 | ### maybe customized ### 141 | self.transform = None 142 | self.is_training = False 143 | 144 | self.chunk_len = -1 145 | self.fps = -1 146 | self.threads = True 147 | 148 | 149 | self.fast_rcc = True 150 | self.rcc_params = (crop_size,) 151 | mean, std = (0.48145466, 0.4578275, 
0.40821073),(0.26862954, 0.26130258, 0.27577711) 152 | self.transform = transforms.Compose([ 153 | transforms.Resize(crop_size, interpolation=3), 154 | transforms.CenterCrop(crop_size), 155 | transforms.Normalize(mean=mean, std=std)]) 156 | self.samples = [] 157 | with open(self.metadata) as f: 158 | csv_reader = csv.reader(f) 159 | _ = next(csv_reader) # skip the header 160 | for row in csv_reader: 161 | pid, vid = row[1:3] 162 | # start_frame, end_frame = int(row[6]), int(row[7]) 163 | # Deprecated: some videos might have fps mismatch issue 164 | start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5]) 165 | narration = row[8] 166 | verb, noun = int(row[10]), int(row[12]) 167 | 168 | vid_path = '{}.mp4'.format(vid) 169 | self.samples.append((vid_path, start_timestamp, end_timestamp, narration, verb, noun)) 170 | 171 | self.labels, self.label_mapping = generate_label_map('ek100_cls', metadata) 172 | 173 | def __len__(self): 174 | return len(self.samples) 175 | 176 | def get_raw_item(self, i): 177 | # vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] 178 | # frames = video_loader_by_timestamp(self.root, vid_path, 179 | # start_timestamp=start_frame, end_timestamp=end_frame, 180 | # clip_length=self.clip_length, is_training=self.is_training, 181 | # threads=self.threads, fast_rcc=self.fast_rcc, rcc_params=self.rcc_params 182 | # ) 183 | vid_path, start_timestamp, end_timestamp, narration, verb, noun = self.samples[i] 184 | 185 | if self.is_training: 186 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 187 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 188 | fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 189 | # frames_slow = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 190 | # chunk_len=self.chunk_len, clip_length=4, threads=self.threads, 
fps=self.fps, 191 | # fast_rrc=self.fast_rrc, rrc_params=self.rrc_params, jitter=self.is_training) 192 | frames_slow = frames 193 | else: 194 | frames = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 195 | chunk_len=self.chunk_len, clip_length=self.clip_length, threads=self.threads, fps=self.fps, 196 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 197 | frames_slow = video_loader(self.root, vid_path.replace('.mp4',''), 'mp4', start_timestamp, end_timestamp, 198 | chunk_len=self.chunk_len, clip_length=4, threads=self.threads, fps=self.fps, 199 | fast_rcc=self.fast_rcc, rcc_params=self.rcc_params, jitter=self.is_training) 200 | frames = frames.float() / 255.0 201 | frames = self.transform(frames.permute(0, 3, 1, 2)) 202 | 203 | return frames,frames_slow, f'{verb}:{noun}', narration 204 | 205 | def __getitem__(self, i): 206 | ### for record info only ### 207 | vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] 208 | 209 | frames, frames_slow, label, narration = self.get_raw_item(i) 210 | raw_caption = narration 211 | 212 | 213 | if isinstance(label, list): 214 | # multi-label case 215 | res_array = np.zeros(len(self.label_mapping)) 216 | for lbl in label: 217 | res_array[self.label_mapping[lbl]] = 1. 
218 | label = res_array 219 | else: 220 | raw_label = label 221 | label = self.label_mapping[label] 222 | 223 | return frames, frames_slow, label 224 | 225 | def main(args): 226 | config = get_config(args) 227 | root = args.root 228 | crop_size = args.crop_size 229 | metadata = args.metadata 230 | dataset = EK100Dataset_CLS(root,metadata,crop_size) 231 | 232 | 233 | model_name = config.model.name 234 | 235 | if model_name == 'CLIP_VITB16': 236 | model = CLIP_VITB16( 237 | config=config.model, 238 | freeze_temperature=config.model.freeze_temperature, 239 | use_grad_checkpointing=config.model.grad_checkpointing, 240 | context_length=config.data.context_length, 241 | vocab_size=config.data.vocab_size, 242 | patch_dropout=config.model.patch_dropout, 243 | num_frames=config.data.clip_length, 244 | drop_path_rate=config.model.drop_path_rate, 245 | use_fast_conv1=config.model.use_fast_conv1, 246 | use_flash_attn=config.model.use_flash_attn, 247 | use_quick_gelu=True, 248 | project_embed_dim=config.model.project_embed_dim, 249 | pretrain_zoo=config.model.pretrain_zoo, 250 | pretrain_path=config.model.pretrain_path, 251 | ) 252 | elif model_name == 'CLIP_VITL14_336PX': 253 | model = CLIP_VITL14_336PX( 254 | config=config.model, 255 | freeze_temperature=config.model.freeze_temperature, 256 | use_grad_checkpointing=config.model.grad_checkpointing, 257 | context_length=config.data.context_length, 258 | vocab_size=config.data.vocab_size, 259 | patch_dropout=config.model.patch_dropout, 260 | num_frames=config.data.clip_length, 261 | drop_path_rate=config.model.drop_path_rate, 262 | use_fast_conv1=config.model.use_fast_conv1, 263 | use_flash_attn=config.model.use_flash_attn, 264 | use_quick_gelu=True, 265 | project_embed_dim=config.model.project_embed_dim, 266 | pretrain_zoo=config.model.pretrain_zoo, 267 | pretrain_path=config.model.pretrain_path, 268 | ) 269 | elif model_name == 'CLIP_VITL14_336PX_Slowfast': 270 | model = CLIP_VITL14_336PX_Slowfast( 271 | config=config.model, 272 |
freeze_temperature=config.model.freeze_temperature, 273 | use_grad_checkpointing=config.model.grad_checkpointing, 274 | context_length=config.data.context_length, 275 | vocab_size=config.data.vocab_size, 276 | patch_dropout=config.model.patch_dropout, 277 | num_frames=config.data.clip_length, 278 | drop_path_rate=config.model.drop_path_rate, 279 | use_fast_conv1=config.model.use_fast_conv1, 280 | use_flash_attn=config.model.use_flash_attn, 281 | use_quick_gelu=True, 282 | project_embed_dim=config.model.project_embed_dim, 283 | pretrain_zoo=config.model.pretrain_zoo, 284 | pretrain_path=config.model.pretrain_path, 285 | ) 286 | elif model_name == 'CLIP_VITB16_Slowfast': 287 | model = CLIP_VITB16_Slowfast( 288 | config=config.model, 289 | freeze_temperature=config.model.freeze_temperature, 290 | use_grad_checkpointing=config.model.grad_checkpointing, 291 | context_length=config.data.context_length, 292 | vocab_size=config.data.vocab_size, 293 | patch_dropout=config.model.patch_dropout, 294 | num_frames=config.data.clip_length, 295 | drop_path_rate=config.model.drop_path_rate, 296 | use_fast_conv1=config.model.use_fast_conv1, 297 | use_flash_attn=config.model.use_flash_attn, 298 | use_quick_gelu=True, 299 | project_embed_dim=config.model.project_embed_dim, 300 | pretrain_zoo=config.model.pretrain_zoo, 301 | pretrain_path=config.model.pretrain_path, 302 | ) 303 | if config.resume: 304 | print("=> loading resume checkpoint '{}'".format(config.resume)) 305 | curr_checkpoint = torch.load(config.resume, map_location='cpu') 306 | new_ckpt = {} 307 | 308 | for key,value in curr_checkpoint['state_dict'].items(): 309 | new_key = key.replace('module.','') 310 | new_ckpt[new_key] = value 311 | result = model.load_state_dict(new_ckpt) 312 | print(result) 313 | model = model.to('cuda') 314 | 315 | model = model.eval().cuda().half() 316 | ans = [] 317 | 318 | 319 | text_features = [] 320 | labels = dataset.labels 321 | num_clips = 16 322 | templates = ['{}'] 323 | with 
torch.no_grad(): 324 | for label in labels: 325 | if isinstance(label, list): 326 | texts = [tmpl.format(lbl) for tmpl in templates for lbl in label] 327 | else: 328 | texts = [tmpl.format(label) for tmpl in templates] 329 | texts = clip.tokenize(texts).to('cuda') 330 | 331 | class_embeddings = model.encode_text(texts) 332 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 333 | 334 | class_embeddings = class_embeddings.mean(dim=0) 335 | class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 336 | 337 | text_features.append(class_embeddings) 338 | text_features = torch.stack(text_features, dim=0) 339 | 340 | mean, std = [0.485* 255, 0.456* 255, 0.406* 255], [0.229* 255, 0.224* 255, 0.225* 255] 341 | mean = (0.48145466 * 255,0.4578275 * 255,0.40821073 * 255) 342 | std = (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255) 343 | import kornia as K 344 | gpu_val_transform_ls = [K.enhance.Normalize(mean=mean, std=std)] 345 | transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls) 346 | 347 | top1ac = 0 348 | top1total = 0 349 | 350 | top5ac = 0 351 | top5total = 0 352 | 353 | for i in range(len(dataset)): 354 | with torch.no_grad(): 355 | frames,frames_slow,label = dataset[i] 356 | frames = frames.to('cuda').unsqueeze(0).to(torch.float16) 357 | frames = frames.permute(0, 2, 1, 3, 4) 358 | image_embed = model.encode_image(frames)[0] 359 | image_embed = F.normalize(image_embed, dim=-1) 360 | similarities = F.cosine_similarity(image_embed, text_features, dim=1) 361 | top1_values, top1_indices = torch.topk(similarities, k=1, dim=-1) 362 | top5_values, top5_indices = torch.topk(similarities, k=5, dim=-1) 363 | #label2word = dataset.mapping_act2narration[label] 364 | 365 | top1total += 1 366 | top5total += 1 367 | 368 | if label in top1_indices: 369 | top1ac += 1 370 | if label in top5_indices: 371 | top5ac += 1 372 | 373 | print(f'top1acc: {top1ac / top1total}') 374 | print(f'top5acc: {top5ac /
top5total}') 375 | print('---------------------------------') 376 | 377 | if __name__ == '__main__': 378 | args = get_args_parser() 379 | args = args.parse_args() 380 | main(args) 381 | -------------------------------------------------------------------------------- /evaluation/eval_mir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | import timm.optim.optim_factory as optim_factory 28 | 29 | import util.misc as misc 30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 31 | import util.dist_utils as dist_utils 32 | # import models_mae # not included in this repository and not used by this script 33 | 34 | from engine_pretrain import train_one_epoch,validate_ek100_mir_zeroshot,build_transform 35 | from util.config import get_config 36 | from dataset.egodataset import EgoExoDataset 37 | from dataset.ek100dataset import EK100Dataset 38 | from model.clip import * 39 | import torch 40 | import torch.cuda.amp as amp 41 | from model import loss 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('EK-MIR eval', add_help=False) 44 | 
parser.add_argument('--accum_iter', default=1, type=int, 45 | help='Accumulate
gradient iterations (for increasing the effective batch size under memory constraints)') 46 | parser.add_argument('--config_file', default='configs/no_decoder/debug_clip_base.yml', type=str,help='config file') 47 | 48 | parser.add_argument('--output_dir', default='./output_dir', 49 | help='path where to save, empty for no saving') 50 | parser.add_argument('--log_dir', default='./output_dir', 51 | help='path where to tensorboard log') 52 | parser.add_argument('--device', default='cuda', 53 | help='device to use for training / testing') 54 | parser.add_argument('--seed', default=0, type=int) 55 | parser.add_argument('--resume', default='', 56 | help='resume from checkpoint') 57 | 58 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 59 | help='start epoch') 60 | parser.add_argument('--num_workers', default=10, type=int) 61 | parser.add_argument('--pin_mem', action='store_true', 62 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 63 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 64 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 65 | parser.add_argument('--world_size', default=1, type=int, 66 | help='number of nodes for distributed training') 67 | parser.add_argument('--rank', default=0, type=int, 68 | help='node rank for distributed training') 69 | parser.add_argument("--local_rank", type=int, default=0) 70 | parser.set_defaults(pin_mem=True) 71 | 72 | return parser 73 | 74 | 75 | def main(args): 76 | 77 | dist_utils.init_distributed_mode(args) 78 | config = get_config(args) 79 | 80 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 81 | print("{}".format(args).replace(', ', ',\n')) 82 | 83 | device = torch.device(args.device) 84 | dist_utils.random_seed(args.seed, dist_utils.get_rank()) 85 | 86 | transform_train = build_transform(config.model.name,mode='train') 87 | transform_val = 
build_transform(config.model.name,mode='val') 88 | 89 | crop_size = 336 if "_336PX" in config.model.name else 224 90 | tokenizer = None 91 | 92 | train_dataset = EgoExoDataset( 93 | config.data, transform=transform_train, is_training=True, tokenizer=tokenizer, crop_size=crop_size 94 | ) 95 | val_dataset = EK100Dataset(config.test.ek100_mir, transform=transform_val, is_training=False, tokenizer=None, crop_size=crop_size) 96 | 97 | if dist_utils.get_rank() == 0 and args.log_dir is not None: 98 | os.makedirs(args.log_dir, exist_ok=True) 99 | log_writer = SummaryWriter(log_dir=args.log_dir) 100 | else: 101 | log_writer = None 102 | 103 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 104 | data_loader_train = torch.utils.data.DataLoader( 105 | train_dataset, batch_size=config.train.batch_size, shuffle=(train_sampler is None), 106 | # collate_fn=collect if config.data.dataset == 'htego_feat' else None, 107 | collate_fn = None, 108 | num_workers=config.train.workers, pin_memory=False, sampler=train_sampler, drop_last=True, 109 | ) 110 | 111 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) 112 | val_loader = torch.utils.data.DataLoader( 113 | val_dataset, batch_size=config.test.batch_size, shuffle=False, 114 | num_workers=config.train.workers, pin_memory=False, sampler=val_sampler, drop_last=False 115 | ) 116 | 117 | model_name = config.model.name 118 | if model_name == 'CLIP_VITB16': 119 | model = CLIP_VITB16( 120 | config=config.model, 121 | freeze_temperature=config.model.freeze_temperature, 122 | use_grad_checkpointing=config.model.grad_checkpointing, 123 | context_length=config.data.context_length, 124 | vocab_size=config.data.vocab_size, 125 | patch_dropout=config.model.patch_dropout, 126 | num_frames=config.data.clip_length, 127 | drop_path_rate=config.model.drop_path_rate, 128 | use_fast_conv1=config.model.use_fast_conv1, 129 | use_flash_attn=config.model.use_flash_attn, 130 | 
use_quick_gelu=True, 131 | project_embed_dim=config.model.project_embed_dim, 132 | pretrain_zoo=config.model.pretrain_zoo, 133 | pretrain_path=config.model.pretrain_path, 134 | ) 135 | elif model_name == 'CLIP_VITL14_336PX': 136 | model = CLIP_VITL14_336PX( 137 | config=config.model, 138 | freeze_temperature=config.model.freeze_temperature, 139 | use_grad_checkpointing=config.model.grad_checkpointing, 140 | context_length=config.data.context_length, 141 | vocab_size=config.data.vocab_size, 142 | patch_dropout=config.model.patch_dropout, 143 | num_frames=config.data.clip_length, 144 | drop_path_rate=config.model.drop_path_rate, 145 | use_fast_conv1=config.model.use_fast_conv1, 146 | use_flash_attn=config.model.use_flash_attn, 147 | use_quick_gelu=True, 148 | project_embed_dim=config.model.project_embed_dim, 149 | pretrain_zoo=config.model.pretrain_zoo, 150 | pretrain_path=config.model.pretrain_path, 151 | ) 152 | elif model_name == 'CLIP_VITL14_336PX_Slowfast': 153 | model = CLIP_VITL14_336PX_Slowfast( 154 | config=config.model, 155 | freeze_temperature=config.model.freeze_temperature, 156 | use_grad_checkpointing=config.model.grad_checkpointing, 157 | context_length=config.data.context_length, 158 | vocab_size=config.data.vocab_size, 159 | patch_dropout=config.model.patch_dropout, 160 | num_frames=config.data.clip_length, 161 | drop_path_rate=config.model.drop_path_rate, 162 | use_fast_conv1=config.model.use_fast_conv1, 163 | use_flash_attn=config.model.use_flash_attn, 164 | use_quick_gelu=True, 165 | project_embed_dim=config.model.project_embed_dim, 166 | pretrain_zoo=config.model.pretrain_zoo, 167 | pretrain_path=config.model.pretrain_path, 168 | ) 169 | elif model_name == 'CLIP_VITB16_Slowfast': 170 | model = CLIP_VITB16_Slowfast( 171 | config=config.model, 172 | freeze_temperature=config.model.freeze_temperature, 173 | use_grad_checkpointing=config.model.grad_checkpointing, 174 | context_length=config.data.context_length, 175 | 
vocab_size=config.data.vocab_size, 176 | patch_dropout=config.model.patch_dropout, 177 | num_frames=config.data.clip_length, 178 | drop_path_rate=config.model.drop_path_rate, 179 | use_fast_conv1=config.model.use_fast_conv1, 180 | use_flash_attn=config.model.use_flash_attn, 181 | use_quick_gelu=True, 182 | project_embed_dim=config.model.project_embed_dim, 183 | pretrain_zoo=config.model.pretrain_zoo, 184 | pretrain_path=config.model.pretrain_path, 185 | ) 186 | else: raise ValueError(f"Unsupported model name: {model_name}") 187 | model.to(device) 188 | 189 | model_without_ddp = model 190 | 191 | if args.distributed: 192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 193 | model_without_ddp = model.module 194 | 195 | 196 | # following timm: set wd as 0 for bias and norm layers 197 | param_groups = optim_factory.add_weight_decay(model_without_ddp, config.train.optimizer.wd) 198 | optimizer = torch.optim.AdamW(param_groups, lr=config.train.lr, betas=(0.9, 0.999)) 199 | print(optimizer) 200 | scaler = amp.GradScaler(enabled=not config.train.disable_amp) 201 | 202 | print("Start zero-shot EK-100 MIR evaluation") 203 | start_time = time.time() 204 | start_epoch = 0 205 | 206 | criterion = loss.ClipLoss( 207 | local_loss=config.train.local_loss, 208 | gather_with_grad=config.train.gather_with_grad, 209 | cache_labels=True, 210 | rank=args.rank, 211 | world_size=args.world_size, 212 | ).cuda(args.gpu) 213 | 214 | if config.resume: 215 | print("=> loading checkpoint for evaluation '{}'".format(config.resume)) 216 | curr_checkpoint = torch.load(config.resume, map_location='cpu') 217 | 218 | result = model.load_state_dict(curr_checkpoint['state_dict']) 219 | val_stats = validate_ek100_mir_zeroshot(val_loader, model=model, criterion=criterion, args=args, config=config,split=0) 220 | 221 | else: 222 | raise ValueError("Evaluation requires a checkpoint: set 'resume' in the config 
file") 223 | 224 | if __name__ == '__main__': 225 | args = get_args_parser() 226 | args = args.parse_args() 227 | main(args) 228 | 
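The evaluation script above calls `model.load_state_dict(...)` on a model that is wrapped in `DistributedDataParallel` only when `args.distributed` is set. Checkpoints written by main_pretrain.py come from the DDP-wrapped model, so their keys likely start with `module.`; loading them into an unwrapped model (for example a single-process run) may then fail with missing/unexpected keys. The helper below is a minimal sketch, not part of the repository, and its name `adapt_state_dict_keys` is hypothetical. It adds or strips the prefix so the keys match the target model.

```python
from collections import OrderedDict

DDP_PREFIX = "module."  # prefix DistributedDataParallel adds to parameter names


def adapt_state_dict_keys(state_dict, target_is_ddp):
    """Add or strip the DDP prefix so checkpoint keys match the target model.

    state_dict: mapping of parameter names to tensors (any values work here).
    target_is_ddp: True if the model being loaded into is DDP-wrapped.
    Hypothetical helper for illustration only.
    """
    adapted = OrderedDict()
    for key, value in state_dict.items():
        has_prefix = key.startswith(DDP_PREFIX)
        if target_is_ddp and not has_prefix:
            key = DDP_PREFIX + key
        elif not target_is_ddp and has_prefix:
            # Strip only the leading prefix, never an inner occurrence.
            key = key[len(DDP_PREFIX):]
        adapted[key] = value
    return adapted


# Possible usage in the evaluation script (torch objects assumed in scope):
# is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
# model.load_state_dict(adapt_state_dict_keys(curr_checkpoint['state_dict'], is_ddp))
```

Unlike the `k.replace('module.', '')` call in `remap_keys_from_open_clip_to_vit`, this only touches a leading prefix, so a key that happens to contain `module.` further in is left unchanged.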
-------------------------------------------------------------------------------- /exps/eval_egomcq.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='eval_egomcq' 14 | PARTITION='HOD' 15 | 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:1 \ 19 | -u python evaluation/eval_egomcq.py \ 20 | --config_file configs/config/clip_base_eval.yml \ 21 | --root ego4d/videos_short320_chunked_15s/ \ 22 | --metadata annotations/egomcq.json \ 23 | --crop_size 224 -------------------------------------------------------------------------------- /exps/eval_egtea.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='eval_egtea' 14 | PARTITION='HOD' 15 | 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:1 \ 19 | -u python evaluation/eval_egtea.py \ 20 | --config_file configs/config/clip_base_eval.yml \ 21 | --root egtea_gaze/cropped_clips \ 22 | --metadata egtea \ 23 | --crop_size 224 -------------------------------------------------------------------------------- /exps/eval_ekcls.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='eval_ekcls' 14 | PARTITION='HOD' 15 | 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:1 \ 19 | -u python evaluation/eval_ekcls.py \ 20 | --config_file configs/config/clip_base_eval.yml \ 21 | --root epic/epic_video_320p/ \ 22 | --metadata epic_kitchen/ \ 23 | --crop_size 224 -------------------------------------------------------------------------------- /exps/eval_mir.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='eval_mir' 14 | PARTITION='HOD' 15 | GPUS=16 16 | GPUS_PER_NODE=8 17 | CPUS_PER_TASK=12 18 | NNODE=1 19 | JOB_DIR='./log/' 20 | 21 | srun -p ${PARTITION} \ 22 | --job-name=${JOB_NAME} \ 23 | --gres=gpu:8 \ 24 | --ntasks=8 \ 25 | --ntasks-per-node=8 \ 26 | --cpus-per-task=${CPUS_PER_TASK} \ 27 | -u python evaluation/eval_mir.py \ 28 | --config_file configs/config/clip_base_eval.yml 29 | -------------------------------------------------------------------------------- /exps/pretrain.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:.
8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='pretrain' 14 | PARTITION='HOD' 15 | GPUS=16 16 | GPUS_PER_NODE=8 17 | CPUS_PER_TASK=12 18 | NNODE=1 19 | JOB_DIR='./log/' 20 | 21 | srun -p ${PARTITION} \ 22 | --job-name=${JOB_NAME} \ 23 | --gres=gpu:8 \ 24 | --ntasks=16 \ 25 | --ntasks-per-node=8 \ 26 | --cpus-per-task=${CPUS_PER_TASK} \ 27 | -u python main_pretrain.py \ 28 | --config_file configs/config/clip_base.yml 29 | -------------------------------------------------------------------------------- /exps/pretrain_large.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | which_python=$(which python) 5 | echo "which python: ${which_python}" 6 | export PYTHONPATH=${PYTHONPATH}:${which_python} 7 | export PYTHONPATH=${PYTHONPATH}:. 8 | echo "PYTHONPATH: ${PYTHONPATH}" 9 | MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | MASTER_PORT=$((10000 + $RANDOM % 100)) 12 | 13 | JOB_NAME='pretrain' 14 | PARTITION='HOD' 15 | GPUS=16 16 | GPUS_PER_NODE=8 17 | CPUS_PER_TASK=12 18 | NNODE=1 19 | JOB_DIR='./log/' 20 | 21 | srun -p ${PARTITION} \ 22 | --job-name=${JOB_NAME} \ 23 | --gres=gpu:8 \ 24 | --ntasks=16 \ 25 | --ntasks-per-node=8 \ 26 | --cpus-per-task=${CPUS_PER_TASK} \ 27 | -u python main_pretrain.py \ 28 | --config_file configs/config/clip_large.yml -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | import timm.optim.optim_factory as optim_factory 28 | 29 | import util.misc as misc 30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 31 | import util.dist_utils as dist_utils 32 | 33 | from engine_pretrain import train_one_epoch,validate_ek100_mir_zeroshot,build_transform 34 | from util.config import get_config 35 | from dataset.egodataset import EgoExoDataset 36 | from dataset.ek100dataset import EK100Dataset 37 | from model.clip import * 38 | import torch 39 | import torch.cuda.amp as amp 40 | from model import loss 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('HOD pre-training', add_help=False) 43 | parser.add_argument('--accum_iter', default=1, type=int, 44 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 45 | parser.add_argument('--config_file', default='configs/config/clip_base.yml', type=str, 46 | help='config file') 47 | 48 | # Output, logging and runtime parameters 49 | parser.add_argument('--output_dir', default='./output_dir', 50 | help='path where to save, empty for no saving') 51 | parser.add_argument('--log_dir', default='./output_dir', 52 | help='directory for TensorBoard logs') 53 | 
parser.add_argument('--device', default='cuda', 54 | help='device to use for training / testing') 55 | parser.add_argument('--seed', default=0, type=int) 56 | parser.add_argument('--resume', default='', 57 | help='resume from checkpoint') 58 | 59 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 60 | help='start epoch') 61 | parser.add_argument('--num_workers', default=10, type=int) 62 | parser.add_argument('--pin_mem', action='store_true', 63 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 64 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 65 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 66 | parser.add_argument('--world_size', default=1, type=int, 67 | help='number of nodes for distributed training') 68 | parser.add_argument('--rank', default=0, type=int, 69 | help='node rank for distributed training') 70 | parser.add_argument("--local_rank", type=int, default=0) 71 | parser.set_defaults(pin_mem=True) 72 | 73 | return parser 74 | 75 | 76 | def main(args): 77 | 78 | dist_utils.init_distributed_mode(args) 79 | config = get_config(args) 80 | 81 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 82 | print("{}".format(args).replace(', ', ',\n')) 83 | 84 | device = torch.device(args.device) 85 | dist_utils.random_seed(args.seed, dist_utils.get_rank()) 86 | 87 | transform_train = build_transform(config.model.name,mode='train') 88 | transform_val = build_transform(config.model.name,mode='val') 89 | 90 | crop_size = 336 if "_336PX" in config.model.name else 224 91 | tokenizer = None 92 | 93 | train_dataset = EgoExoDataset( 94 | config.data, transform=transform_train, is_training=True, tokenizer=tokenizer, crop_size=crop_size 95 | ) 96 | val_dataset = EK100Dataset(config.test.ek100_mir, transform=transform_val, is_training=False, tokenizer=None, crop_size=crop_size) 97 | 98 | if dist_utils.get_rank() == 0 and args.log_dir is not 
None: 99 | os.makedirs(args.log_dir, exist_ok=True) 100 | log_writer = SummaryWriter(log_dir=args.log_dir) 101 | else: 102 | log_writer = None 103 | 104 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 105 | data_loader_train = torch.utils.data.DataLoader( 106 | train_dataset, batch_size=config.train.batch_size, shuffle=(train_sampler is None), 107 | # collate_fn=collect if config.data.dataset == 'htego_feat' else None, 108 | collate_fn = None, 109 | num_workers=config.train.workers, pin_memory=False, sampler=train_sampler, drop_last=True, 110 | ) 111 | 112 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) 113 | val_loader = torch.utils.data.DataLoader( 114 | val_dataset, batch_size=config.test.batch_size, shuffle=False, 115 | num_workers=config.train.workers, pin_memory=False, sampler=val_sampler, drop_last=False 116 | ) 117 | 118 | model_name = config.model.name 119 | if model_name == 'CLIP_VITB16': 120 | model = CLIP_VITB16( 121 | config=config.model, 122 | freeze_temperature=config.model.freeze_temperature, 123 | use_grad_checkpointing=config.model.grad_checkpointing, 124 | context_length=config.data.context_length, 125 | vocab_size=config.data.vocab_size, 126 | patch_dropout=config.model.patch_dropout, 127 | num_frames=config.data.clip_length, 128 | drop_path_rate=config.model.drop_path_rate, 129 | use_fast_conv1=config.model.use_fast_conv1, 130 | use_flash_attn=config.model.use_flash_attn, 131 | use_quick_gelu=True, 132 | project_embed_dim=config.model.project_embed_dim, 133 | pretrain_zoo=config.model.pretrain_zoo, 134 | pretrain_path=config.model.pretrain_path, 135 | ) 136 | elif model_name == 'CLIP_VITL14_336PX': 137 | model = CLIP_VITL14_336PX( 138 | config=config.model, 139 | freeze_temperature=config.model.freeze_temperature, 140 | use_grad_checkpointing=config.model.grad_checkpointing, 141 | context_length=config.data.context_length, 142 | vocab_size=config.data.vocab_size, 143 | 
patch_dropout=config.model.patch_dropout, 144 | num_frames=config.data.clip_length, 145 | drop_path_rate=config.model.drop_path_rate, 146 | use_fast_conv1=config.model.use_fast_conv1, 147 | use_flash_attn=config.model.use_flash_attn, 148 | use_quick_gelu=True, 149 | project_embed_dim=config.model.project_embed_dim, 150 | pretrain_zoo=config.model.pretrain_zoo, 151 | pretrain_path=config.model.pretrain_path, 152 | ) 153 | elif model_name == 'CLIP_VITL14_336PX_Slowfast': 154 | model = CLIP_VITL14_336PX_Slowfast( 155 | config=config.model, 156 | freeze_temperature=config.model.freeze_temperature, 157 | use_grad_checkpointing=config.model.grad_checkpointing, 158 | context_length=config.data.context_length, 159 | vocab_size=config.data.vocab_size, 160 | patch_dropout=config.model.patch_dropout, 161 | num_frames=config.data.clip_length, 162 | drop_path_rate=config.model.drop_path_rate, 163 | use_fast_conv1=config.model.use_fast_conv1, 164 | use_flash_attn=config.model.use_flash_attn, 165 | use_quick_gelu=True, 166 | project_embed_dim=config.model.project_embed_dim, 167 | pretrain_zoo=config.model.pretrain_zoo, 168 | pretrain_path=config.model.pretrain_path, 169 | ) 170 | elif model_name == 'CLIP_VITB16_Slowfast': 171 | model = CLIP_VITB16_Slowfast( 172 | config=config.model, 173 | freeze_temperature=config.model.freeze_temperature, 174 | use_grad_checkpointing=config.model.grad_checkpointing, 175 | context_length=config.data.context_length, 176 | vocab_size=config.data.vocab_size, 177 | patch_dropout=config.model.patch_dropout, 178 | num_frames=config.data.clip_length, 179 | drop_path_rate=config.model.drop_path_rate, 180 | use_fast_conv1=config.model.use_fast_conv1, 181 | use_flash_attn=config.model.use_flash_attn, 182 | use_quick_gelu=True, 183 | project_embed_dim=config.model.project_embed_dim, 184 | pretrain_zoo=config.model.pretrain_zoo, 185 | pretrain_path=config.model.pretrain_path, 186 | ) 187 | else: raise ValueError(f"Unsupported model name: 
{model_name}") 188 | model.to(device) 189 | 190 | model_without_ddp = model 191 | 
192 | if args.distributed: 193 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 194 | model_without_ddp = model.module 195 | 196 | 197 | # following timm: set wd as 0 for bias and norm layers 198 | param_groups = optim_factory.add_weight_decay(model_without_ddp, config.train.optimizer.wd) 199 | optimizer = torch.optim.AdamW(param_groups, lr=config.train.lr, betas=(0.9, 0.999)) 200 | print(optimizer) 201 | scaler = amp.GradScaler(enabled=not config.train.disable_amp) 202 | 203 | print(f"Start training for {config.train.epochs} epochs") 204 | start_time = time.time() 205 | start_epoch = args.start_epoch 206 | 207 | criterion = loss.ClipLoss( 208 | local_loss=config.train.local_loss, 209 | gather_with_grad=config.train.gather_with_grad, 210 | cache_labels=True, 211 | rank=args.rank, 212 | world_size=args.world_size, 213 | ).cuda(args.gpu) 214 | 215 | if config.resume: 216 | print("=> loading resume checkpoint '{}'".format(config.resume)) 217 | curr_checkpoint = torch.load(config.resume, map_location='cpu') 218 | 219 | result = model.load_state_dict(curr_checkpoint['state_dict']) 220 | start_epoch = curr_checkpoint.get('epoch', start_epoch) # checkpoints store epoch + 1, i.e. the next epoch to run 221 | 222 | 223 | for epoch in range(start_epoch, config.train.epochs): 224 | train_stats = train_one_epoch( 225 | model, data_loader_train, 226 | optimizer, device, epoch, scaler, 227 | log_writer=log_writer, 228 | args=args,criterion=criterion 229 | ) 230 | 231 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 232 | 'epoch': epoch,} 233 | val_stats = validate_ek100_mir_zeroshot(val_loader, model=model, criterion=criterion, args=args, config=config,split=epoch) 234 | 235 | dist_utils.save_on_master({ 236 | 'epoch': epoch + 1, 237 | 'state_dict': model.state_dict(), 238 | 'val_state':val_stats, 239 | }, config.output_dir) 240 | 241 | 242 | if __name__ == '__main__': 243 | args = 
get_args_parser() 244 | args = args.parse_args() 245 | main(args) 246 | 
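main_pretrain.py optimizes `loss.ClipLoss`, which receives the L2-normalized video and text embeddings plus the exponentiated `logit_scale` that `CLIP.forward` returns. model/loss.py is not shown in this section, so the block below is only a sketch of the standard single-process symmetric contrastive (InfoNCE) objective that OpenCLIP-style `ClipLoss` computes; it ignores the cross-rank feature gathering controlled by `local_loss` and `gather_with_grad`. It is written in NumPy so it runs without a GPU, and the function names are hypothetical.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_contrastive_loss(image_embed, text_embed, logit_scale):
    """Symmetric InfoNCE loss for a batch of paired embeddings (sketch).

    image_embed, text_embed: (N, D) arrays, assumed L2-normalized; row i of
        each array is a positive pair, every other row acts as a negative.
    logit_scale: the already-exponentiated temperature multiplier.
    """
    logits_per_image = logit_scale * image_embed @ text_embed.T  # (N, N)
    logits_per_text = logits_per_image.T
    targets = np.arange(logits_per_image.shape[0])
    # Cross-entropy where the correct "class" for row i is column i.
    loss_image = -log_softmax(logits_per_image, axis=1)[targets, targets].mean()
    loss_text = -log_softmax(logits_per_text, axis=1)[targets, targets].mean()
    return (loss_image + loss_text) / 2


# Usage: four matched pairs with orthonormal embeddings; 1 / 0.07 mirrors the
# initial value of logit_scale.exp() in the CLIP module above.
embeds = np.eye(4)
print(clip_contrastive_loss(embeds, embeds, logit_scale=1 / 0.07))
```

Averaging the image-to-text and text-to-image terms keeps the objective symmetric, so neither modality is treated as the query by default.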
-------------------------------------------------------------------------------- /model/__pycache__/clip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/model/__pycache__/clip.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/model/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/timesformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/model/__pycache__/timesformer.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/model/__pycache__/transformer.cpython-310.pyc -------------------------------------------------------------------------------- /model/clip.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | 8 | from .transformer import TextTransformer, VisionTransformer, VisionTransformer_Slowfast 9 | from .timesformer import SpaceTimeTransformer 10 | from ipdb import set_trace 11 | from einops import rearrange 12 | import torch.cuda.amp as amp 13 | 14 | from easydict import 
EasyDict 15 | import sys 16 | import torch 17 | from PIL import Image 18 | import pickle 19 | from collections import OrderedDict 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from ipdb import set_trace 24 | import functools 25 | 26 | # util functions to convert OpenCLIP-style model keys to ViT-style 27 | def remap_keys_from_open_clip_to_vit( 28 | clip_state_dict, 29 | visual_transformer_layers=12, 30 | textual_transformer_layers=12, 31 | context_length=77, 32 | vocab_size=49408, 33 | use_fast_conv1=False, 34 | use_flash_attn=False, 35 | ): 36 | if 'state_dict' in clip_state_dict: 37 | clip_state_dict = clip_state_dict['state_dict'] 38 | if list(clip_state_dict.keys())[0].startswith('module.'): 39 | clip_state_dict = OrderedDict({ 40 | k.replace('module.', ''): v for k, v in clip_state_dict.items() 41 | }) 42 | remapped_state_dict = OrderedDict() 43 | key_mapping = { 44 | "logit_scale": "logit_scale", 45 | "visual.proj": "visual.image_projection", 46 | "positional_embedding": "textual.positional_embedding", 47 | "text_projection": "textual.text_projection", 48 | "token_embedding.weight": "textual.token_embedding.weight", 49 | "ln_final.weight": "textual.ln_final.weight", 50 | "ln_final.bias": "textual.ln_final.bias" 51 | } 52 | 53 | for layer in range(visual_transformer_layers): 54 | if use_flash_attn: 55 | for src_name, tgt_name in { 56 | 'attn.in_proj_weight': 'attn.Wqkv.weight', 'attn.in_proj_bias': 'attn.Wqkv.bias', 57 | 'attn.out_proj.weight': 'attn.out_proj.weight', 'attn.out_proj.bias': 'attn.out_proj.bias', 58 | 'mlp.c_fc.weight': 'mlp.fc1.weight', 'mlp.c_fc.bias': 'mlp.fc1.bias', 59 | 'mlp.c_proj.weight': 'mlp.fc2.weight', 'mlp.c_proj.bias': 'mlp.fc2.bias', 60 | }.items(): 61 | key_mapping[f"visual.transformer.resblocks.{layer}.{src_name}"] = f"visual.transformer.resblocks.{layer}.{tgt_name}" 62 | 63 | 64 | for layer in range(textual_transformer_layers): 65 | for name in [ 66 | 'attn.in_proj_weight', 
'attn.in_proj_bias', 'attn.out_proj.weight', 'attn.out_proj.bias', 67 | 'ln_1.weight', 'ln_1.bias', 'ln_2.weight', 'ln_2.bias', 68 | 'mlp.c_fc.weight', 'mlp.c_fc.bias', 'mlp.c_proj.weight', 'mlp.c_proj.bias', 69 | ]: 70 | key_mapping[f"transformer.resblocks.{layer}.{name}"] = f"textual.transformer.resblocks.{layer}.{name}" 71 | 72 | for key in clip_state_dict: 73 | if key in ["visual.proj", "text_projection", "logit_scale"]: 74 | continue 75 | if use_fast_conv1 and key == 'visual.conv1.weight': 76 | remapped_state_dict['visual.conv1.weight'] = clip_state_dict[key].flatten(1) 77 | elif key not in key_mapping: 78 | remapped_state_dict[key] = clip_state_dict[key] 79 | else: 80 | if key == 'positional_embedding': 81 | old_context_length, dim = clip_state_dict[key].shape 82 | old_dtype = clip_state_dict[key].dtype 83 | if context_length <= old_context_length: 84 | remapped_state_dict[key_mapping[key]] = clip_state_dict[key][:context_length, :] 85 | else: 86 | remapped_state_dict[key_mapping[key]] = torch.cat( 87 | (clip_state_dict[key], torch.zeros((context_length - old_context_length, dim), dtype=old_dtype)), dim=0 88 | ) 89 | elif key == 'token_embedding.weight': 90 | old_vocab_size, dim = clip_state_dict[key].shape 91 | old_dtype = clip_state_dict[key].dtype 92 | assert vocab_size >= old_vocab_size 93 | remapped_state_dict[key_mapping[key]] = torch.cat( 94 | (clip_state_dict[key], torch.zeros((vocab_size - old_vocab_size, dim), dtype=old_dtype)), dim=0 95 | ) 96 | else: 97 | remapped_state_dict[key_mapping[key]] = clip_state_dict[key] 98 | 99 | return remapped_state_dict 100 | 101 | def CLIP_VITB16( 102 | config, 103 | freeze_temperature=False, 104 | use_grad_checkpointing=False, 105 | use_bidirectional_lm=False, 106 | context_length=77, 107 | patch_dropout=0., 108 | drop_path_rate=0., 109 | num_frames=1, 110 | use_fast_conv1=False, 111 | use_flash_attn=False, 112 | project_embed_dim=512, 113 | pretrain_zoo='openai', 114 | pretrain_path=None, 115 | **kwargs 116 | ): 
117 | # vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 118 | vision_model = VisionTransformer( 119 | 224, 16, 768, 12, 12, 4, 120 | output_dim=project_embed_dim, patch_dropout=patch_dropout, 121 | drop_path_rate=drop_path_rate, 122 | num_frames=num_frames, 123 | use_fast_conv1=use_fast_conv1, 124 | use_flash_attn=use_flash_attn, 125 | ) 126 | 127 | text_model = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=12, output_dim=project_embed_dim, causal_mask=not use_bidirectional_lm) 128 | model = CLIP(embed_dim=project_embed_dim, vision_model=vision_model, text_model=text_model, freeze_temperature=freeze_temperature,ckpt_path=config.lavila_path) 129 | 130 | print("=> loading openai model") 131 | clip_model, preprocess = clip.load(config.ckpt_path, device='cpu') 132 | remapped_state_dict = remap_keys_from_open_clip_to_vit( 133 | clip_model.state_dict(), 134 | use_fast_conv1=use_fast_conv1, 135 | use_flash_attn=use_flash_attn, 136 | ) 137 | 138 | missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) 139 | print("missing_keys: ", missing_keys) 140 | print("unexpected_keys: ", unexpected_keys) 141 | 142 | return model 143 | 144 | def CLIP_VITL14_336PX( 145 | config, 146 | freeze_temperature=False, 147 | use_grad_checkpointing=False, 148 | use_bidirectional_lm=False, 149 | context_length=77, 150 | vocab_size=49408, 151 | patch_dropout=0., 152 | drop_path_rate=0., 153 | num_frames=1, 154 | use_fast_conv1=False, 155 | use_flash_attn=False, 156 | project_embed_dim=512, 157 | pretrain_zoo='openai', 158 | pretrain_path=None, 159 | **kwargs 160 | ): 161 | vision_model = VisionTransformer( 162 | 336, 14, 1024, 24, 16, 4, 163 | output_dim=project_embed_dim, patch_dropout=patch_dropout, 164 | drop_path_rate=drop_path_rate, 165 | num_frames=num_frames, 166 | use_fast_conv1=use_fast_conv1, 167 | use_flash_attn=use_flash_attn, 168 | ) 169 | text_model = 
TextTransformer(context_length=context_length, vocab_size=vocab_size, width=768, heads=12, layers=12, output_dim=project_embed_dim, causal_mask=not use_bidirectional_lm) 170 | model = CLIP(embed_dim=project_embed_dim, vision_model=vision_model, text_model=text_model, freeze_temperature=freeze_temperature,ckpt_path=config.lavila_path) 171 | 172 | print("=> loading openai model") 173 | clip_model, preprocess = clip.load(config.ckpt_path, device='cpu') 174 | remapped_state_dict = remap_keys_from_open_clip_to_vit( 175 | clip_model.state_dict(), 176 | visual_transformer_layers=24, 177 | use_fast_conv1=use_fast_conv1, 178 | use_flash_attn=use_flash_attn, 179 | ) 180 | missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) 181 | print("missing_keys: ", missing_keys) 182 | print("unexpected_keys: ", unexpected_keys) 183 | 184 | return model 185 | 186 | def CLIP_VITL14_336PX_Slowfast( 187 | config, 188 | freeze_temperature=False, 189 | use_grad_checkpointing=False, 190 | use_bidirectional_lm=False, 191 | context_length=77, 192 | vocab_size=49408, 193 | patch_dropout=0., 194 | drop_path_rate=0., 195 | num_frames=1, 196 | use_fast_conv1=False, 197 | use_flash_attn=False, 198 | project_embed_dim=512, 199 | pretrain_zoo='openai', 200 | pretrain_path=None, 201 | **kwargs 202 | ): 203 | # vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 204 | vision_model = VisionTransformer_Slowfast( 205 | 336, 14, 1024, 24, 16, 4, 206 | output_dim=project_embed_dim, patch_dropout=patch_dropout, 207 | drop_path_rate=drop_path_rate, 208 | num_frames=num_frames, 209 | use_fast_conv1=use_fast_conv1, 210 | use_flash_attn=use_flash_attn, 211 | ) 212 | text_model = TextTransformer(context_length=context_length, vocab_size=vocab_size, width=768, heads=12, layers=12, output_dim=project_embed_dim, causal_mask=not use_bidirectional_lm) 213 | model = CLIP_Slowfast(embed_dim=project_embed_dim, vision_model=vision_model, text_model=text_model, 
freeze_temperature=freeze_temperature,ckpt_path=config.lavila_path) 214 | 215 | print("=> loading openai model") 216 | clip_model, preprocess = clip.load(config.ckpt_path, device='cpu') 217 | remapped_state_dict = remap_keys_from_open_clip_to_vit( 218 | clip_model.state_dict(), 219 | visual_transformer_layers=24, 220 | use_fast_conv1=use_fast_conv1, 221 | use_flash_attn=use_flash_attn, 222 | ) 223 | missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) 224 | print("missing_keys: ", missing_keys) 225 | print("unexpected_keys: ", unexpected_keys) 226 | return model 227 | 228 | def CLIP_VITB16_Slowfast( 229 | config, 230 | freeze_temperature=False, 231 | use_grad_checkpointing=False, 232 | use_bidirectional_lm=False, 233 | context_length=77, 234 | patch_dropout=0., 235 | drop_path_rate=0., 236 | num_frames=1, 237 | use_fast_conv1=False, 238 | use_flash_attn=False, 239 | project_embed_dim=512, 240 | pretrain_zoo='openai', 241 | pretrain_path=None, 242 | **kwargs 243 | ): 244 | # vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) 245 | vision_model = VisionTransformer_Slowfast( 246 | 224, 16, 768, 12, 12, 4, 247 | output_dim=project_embed_dim, patch_dropout=patch_dropout, 248 | drop_path_rate=drop_path_rate, 249 | num_frames=num_frames, 250 | use_fast_conv1=use_fast_conv1, 251 | use_flash_attn=use_flash_attn, 252 | ) 253 | text_model = TextTransformer(context_length=context_length, vocab_size=49408, width=512, heads=8, layers=12, output_dim=project_embed_dim, causal_mask=not use_bidirectional_lm) 254 | model = CLIP_Slowfast(embed_dim=project_embed_dim, vision_model=vision_model, text_model=text_model, freeze_temperature=freeze_temperature,ckpt_path=config.lavila_path) 255 | 256 | print("=> loading openai model") 257 | clip_model, preprocess = clip.load(config.ckpt_path, device='cpu') 258 | remapped_state_dict = remap_keys_from_open_clip_to_vit( 259 | clip_model.state_dict(), 260 | use_fast_conv1=use_fast_conv1, 261 
        use_flash_attn=use_flash_attn,
    )
    missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
    print("missing_keys: ", missing_keys)
    print("unexpected_keys: ", unexpected_keys)

    return model

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 vision_model: nn.Module,
                 text_model: nn.Module,
                 vision_width: int = None,
                 text_width: int = None,
                 freeze_temperature=False,
                 ckpt_path=None,
                 **kwargs
                 ):
        super().__init__()

        self.visual = vision_model
        self.textual = text_model

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        if freeze_temperature:
            self.logit_scale.requires_grad_(False)

        if vision_width is not None:
            self.vision_width = vision_width
            self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        else:
            self.image_projection = None
        if text_width is not None:
            self.text_width = text_width
            self.text_projection = nn.Parameter(torch.empty(text_width, embed_dim))
        else:
            self.text_projection = None
        self.init_parameters()

    def init_parameters(self):
        if self.image_projection is not None:
            trunc_normal_(self.image_projection, std=self.vision_width ** -0.5)
        if self.text_projection is not None:
            trunc_normal_(self.text_projection, std=self.text_width ** -0.5)

    def encode_visual(self, image):
        return self.encode_image(image)

    def encode_image(self, image):

        x_pooling, x = self.visual(image)
        if self.image_projection is not None:
            x = x @ self.image_projection.to(x.dtype)
        return x_pooling, x

    def encode_text(self, text, cast_dtype=None):
        if len(text.shape) > 2:
            text = text.squeeze()
        x = self.textual(text)
        if self.text_projection is not None:
            x = x @ self.text_projection.to(x.dtype)
        return x

    def forward(self, image, slow, text, eval_mode=False):

        image_embed, _ = self.encode_image(image)
        #print(image_embed.dtype)
        text_embed = self.encode_text(text, cast_dtype=image_embed.dtype)
        return F.normalize(image_embed, dim=-1), F.normalize(text_embed, dim=-1), self.logit_scale.exp()


class CLIP_Slowfast(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 vision_model: nn.Module,
                 text_model: nn.Module,
                 vision_width: int = None,
                 text_width: int = None,
                 freeze_temperature=False,
                 ckpt_path=None,
                 **kwargs
                 ):
        super().__init__()

        self.visual = vision_model
        self.textual = text_model

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        if freeze_temperature:
            self.logit_scale.requires_grad_(False)

        if vision_width is not None:
            self.vision_width = vision_width
            self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        else:
            self.image_projection = None
        if text_width is not None:
            self.text_width = text_width
            self.text_projection = nn.Parameter(torch.empty(text_width, embed_dim))
        else:
            self.text_projection = None

        self.slowfast_projection = nn.Parameter(torch.empty(embed_dim*2, embed_dim))
        self.init_parameters()

        # Train only adapters, projections and the visual class/temporal embeddings; freeze everything else.
        for n, p in self.named_parameters():
            if ('adapter' in n) or ('projection' in n) or ('visual.class_embedding' in n) or ('visual.temporal_embedding' in n):
                p.requires_grad_(True)
                print(n, p.requires_grad)
            else:
                p.requires_grad_(False)


        n_trainable_params = 0
        for n, p in self.named_parameters():
            if p.requires_grad:
                n_trainable_params += p.numel()
        print('Total trainable params:', n_trainable_params, '(%.2f M)' % (n_trainable_params / 1000000))


    def init_parameters(self):
        if self.image_projection is not None:
            trunc_normal_(self.image_projection, std=self.vision_width ** -0.5)
        if self.text_projection is not None:
            trunc_normal_(self.text_projection, std=self.text_width ** -0.5)

        # 1024 == 2 * embed_dim for the default 512-d projection
        trunc_normal_(self.slowfast_projection, std=(1024) ** -0.5)

    def encode_visual(self, image):
        return self.encode_image(image)

    def encode_image(self, image):

        x_pooling, x = self.visual(image)
        if self.image_projection is not None:
            x = x @ self.image_projection.to(x.dtype)
        return x_pooling, x

    def encode_text(self, text, cast_dtype=None):
        text = text.squeeze()
        if len(text.shape) == 1:
            text = text.unsqueeze(0)
        x = self.textual(text)
        if self.text_projection is not None:
            x = x @ self.text_projection.to(x.dtype)
        return x

    def encode_slowfast(self, image, image_slow):
        image_embed_slow, all_embed_slow = self.encode_image(image_slow)
        #image_temp = rearrange(image,'b c (t1 t2) h w->(b t1) c t2 h w',t1=4,t2=4)
        image_embed, all_embed = self.encode_image(image)
        # image_embed = rearrange(image_embed,'(b t1) c->b t1 c',t1=4)
        # image_embed = image_embed.mean(dim=1)

        # Fuse the two pathways: concatenate (2 * embed_dim) and project back to embed_dim.
        image_embed = torch.cat((image_embed_slow, image_embed), dim=-1)
        image_embed = image_embed @ self.slowfast_projection.to(image_embed.dtype)

        return F.normalize(image_embed, dim=-1)

    def forward(self, image, image_slow, text, eval_mode=False):

        image_embed_slow, all_embed_slow = self.encode_image(image_slow)
        #image_temp = rearrange(image,'b c (t1 t2) h w->(b t1) c t2 h w',t1=4,t2=4)
        image_embed, all_embed = self.encode_image(image)
        # image_embed = rearrange(image_embed,'(b t1) c->b t1 c',t1=4)
        # image_embed = image_embed.mean(dim=1)

        image_embed = torch.cat((image_embed_slow, image_embed), dim=-1)
        image_embed = image_embed @ self.slowfast_projection.to(image_embed.dtype)


        text_embed = self.encode_text(text, cast_dtype=image_embed.dtype)

        return F.normalize(image_embed, dim=-1), F.normalize(text_embed, dim=-1), self.logit_scale.exp()

--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from ipdb import set_trace

def gather_hand_feature(hand_box, l_valid, r_valid, left_data, right_data,
                        local_loss=False,
                        gather_with_grad=False,
                        rank=0,
                        world_size=1,
                        use_horovod=False):
    # Gather hand boxes, validity masks and box targets from all ranks.
    if gather_with_grad:
        all_hand_box = torch.cat(torch.distributed.nn.all_gather(hand_box), dim=0)
        all_l_valid = torch.cat(torch.distributed.nn.all_gather(l_valid), dim=0)
        all_r_valid = torch.cat(torch.distributed.nn.all_gather(r_valid), dim=0)
        all_left_data = torch.cat(torch.distributed.nn.all_gather(left_data), dim=0)
        all_right_data = torch.cat(torch.distributed.nn.all_gather(right_data), dim=0)
    else:
        gathered_hand_box = [torch.zeros_like(hand_box) for _ in range(world_size)]
        gathered_l_valid = [torch.zeros_like(l_valid) for _ in range(world_size)]
        gathered_r_valid = [torch.zeros_like(r_valid) for _ in range(world_size)]
        gathered_left_data = [torch.zeros_like(left_data) for _ in range(world_size)]
        gathered_right_data = [torch.zeros_like(right_data) for _ in range(world_size)]
        dist.all_gather(gathered_hand_box, hand_box)
        dist.all_gather(gathered_l_valid, l_valid)
        dist.all_gather(gathered_r_valid, r_valid)
        dist.all_gather(gathered_left_data, left_data)
        dist.all_gather(gathered_right_data, right_data)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            gathered_hand_box[rank] = hand_box
            gathered_l_valid[rank] = l_valid
            gathered_r_valid[rank] = r_valid
            gathered_left_data[rank] = left_data
            gathered_right_data[rank] = right_data
        all_hand_box = torch.cat(gathered_hand_box, dim=0)
        all_l_valid = torch.cat(gathered_l_valid, dim=0)
        all_r_valid = torch.cat(gathered_r_valid, dim=0)
        all_left_data = torch.cat(gathered_left_data, dim=0)
        all_right_data = torch.cat(gathered_right_data, dim=0)
    return all_hand_box, all_l_valid, all_r_valid, all_left_data, all_right_data

def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features

def loss_boxes(outputs, targets, num_boxes):
    # Sum of per-coordinate L1 errors, normalised by the number of boxes.
    # max(num_boxes, 1) avoids 0 / 0 = NaN when a batch has no valid hand boxes.
    losses = F.l1_loss(outputs, targets, reduction='none').sum() / max(num_boxes, 1)

    # loss_giou = 1 - torch.diag(generalized_box_iou(
    #     box_cxcywh_to_xyxy(outputs),
    #     box_cxcywh_to_xyxy(targets)))
    # losses['loss_giou'] = loss_giou.sum() / num_boxes
    return losses

class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale, hand_box=None, l_valid=None, r_valid=None,
                left_data=None, right_data=None):
        device = image_features.device

        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
            if hand_box is not None:
                all_hand_box, all_l_valid, all_r_valid, all_left_data, all_right_data = gather_hand_feature(
                    hand_box, l_valid, r_valid, left_data, right_data, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            # Single process: nothing to gather, so the box loss below uses the local tensors.
            all_hand_box, all_l_valid, all_r_valid = hand_box, l_valid, r_valid
            all_left_data, all_right_data = left_data, right_data
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        if hand_box is not None:
            # Dim 1 of hand_box: indices :4 are left-hand boxes, 4: are right-hand boxes.
            all_left_hand = all_hand_box[all_l_valid > 0][:, :4, :]
            all_left_hand = rearrange(all_left_hand, 'b t c->b (t c)')
            all_left_target = all_left_data[all_l_valid > 0]
            all_right_hand = all_hand_box[all_r_valid > 0][:, 4:, :]  # B,4,4
            all_right_hand = rearrange(all_right_hand, 'b t c->b (t c)')
            all_right_target = all_right_data[all_r_valid > 0]  # [B,16]
            loss2 = loss_boxes(all_right_hand, all_right_target, all_right_hand.shape[0] * 4)
            loss3 = loss_boxes(all_left_hand, all_left_target, all_left_hand.shape[0] * 4)
            loss1 = (loss2 + loss3) / 2

        # compute ground-truth labels and cache them if enabled
        num_logits = logits_per_image.shape[0]
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]

        vlp_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct = pred.eq(labels).sum()
            acc = 100 * correct / logits_per_image.size(0)

        if hand_box is not None:
            total_loss = vlp_loss + loss1
            return {'vlp_loss': vlp_loss, 'clip_acc': acc, 'box_loss': loss1, 'loss': total_loss}
        else:
            return {'clip_acc': acc, 'loss': vlp_loss}

class Multiview_Cliploss(nn.Module):
    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        self.clip_loss = ClipLoss(local_loss=self.local_loss,
                                  gather_with_grad=self.gather_with_grad, cache_labels=self.cache_labels,
                                  rank=self.rank, world_size=self.world_size, use_horovod=self.use_horovod)

        self.multiview_loss = 'global'
        assert self.multiview_loss in ['global']

    def forward(self, image_features, text_features, logit_scale, ego_features, exo_features, multiview_logit_scale):
        vt_loss = self.clip_loss(image_features, text_features, logit_scale)
        if self.multiview_loss == 'global':
            # 'global' mode: mean-pool over dim 1 before the ego-exo contrastive loss.
            ego_features = ego_features.mean(1)
            exo_features = exo_features.mean(1)
            vv_loss = self.clip_loss(ego_features, exo_features, multiview_logit_scale)
        else:
            raise NotImplementedError(self.multiview_loss)

        loss_dict = {
            'loss': vt_loss['loss'] + vv_loss['loss'],
            'v2t': vt_loss['loss'], 'vt_acc': vt_loss['clip_acc'],
            'v2v': vv_loss['loss'], 'vv_acc': vv_loss['clip_acc'],
        }
        return loss_dict


def sim_matrix(a, b, eps=1e-8):
    """
    Cosine-similarity matrix between the rows of `a` and `b`;
    `eps` clamps the norms for numerical stability.
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt


class MaxMarginRankingLoss(nn.Module):

    def __init__(
            self,
            margin=0.2,
            fix_norm=True,
            local_loss=False,
            gather_with_grad=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.fix_norm = fix_norm
        self.margin = margin
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

    def forward(self, image_features, text_features, weight=None):
        # TODO: try gather_from_all in
        # https://github.com/facebookresearch/LaViLa/blob/main/lavila/models/distributed_utils.py
        # all_image_features = gather_from_all(image_features)
        # all_text_features = gather_from_all(text_features)
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
        else:
            # Single process: skip the collective calls, which need an initialised process group.
            all_image_features, all_text_features = image_features, text_features


        x = sim_matrix(all_text_features, all_image_features)

        n = x.size()[0]

        x1 = torch.diag(x)
        x1 = x1.unsqueeze(1)
        x1 = x1.expand(n, n)
        x1 = x1.contiguous().view(-1, 1)
        x1 = torch.cat((x1, x1), 0)

        x2 = x.view(-1, 1)
        x3 = x.transpose(0, 1).contiguous().view(-1, 1)

        x2 = torch.cat((x2, x3), 0)
        max_margin = F.relu(self.margin - (x1 - x2))

        if self.fix_norm:
            # remove the elements from the diagonal
            keep = torch.ones(x.shape) - torch.eye(x.shape[0])  # n x n
            keep1 = keep.view(-1, 1)
            keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
            keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten()
            keep_idx = keep_idx.to(x1.device)
            x1_ = torch.index_select(x1, dim=0, index=keep_idx)
            x2_ = torch.index_select(x2, dim=0, index=keep_idx)
            max_margin = F.relu(self.margin - (x1_ - x2_))

        return {
            'loss': max_margin.mean()
        }


class CaptionLoss(nn.Module):
    def __init__(self, pad_id=0, tokenizer=None):
        super().__init__()
        self.tokenizer = tokenizer
        # Prefer the tokenizer's pad token; fall back to `pad_id` when no tokenizer is given.
        self.pad_id = tokenizer.pad_token_id if tokenizer is not None else pad_id

    def forward(self, outputs):
        logits = outputs['text_tokens_logits']
        labels = outputs['labels']
        # loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)
        loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id, reduction='none')

        # compute accuracy
        with torch.no_grad():
            correct = 0.
            total = 0.
            ppls = []
            for i in range(logits.size(0)):
                pred = torch.argmax(logits[i], dim=0)
                nopad = labels[i].ne(self.pad_id)
                correct += (pred.eq(labels[i]) & nopad).sum()
                total += nopad.sum()
                ppl = torch.exp(loss[i].sum() / nopad.sum())
                ppls.append(ppl)
                # TODO: for debug only
                # sep_pos = labels[i].tolist().index(self.tokenizer.tokenizer.sep_token_id)
                # if self.tokenizer is not None:
                #     print('{} {} {}'.format(
                #         i, self.tokenizer.tokenizer.convert_ids_to_tokens(pred[:sep_pos]),
                #         self.tokenizer.tokenizer.convert_ids_to_tokens(labels[i, :sep_pos]),
                #     ))
            acc = 100 * correct / (total + 1e-8)
        return {'loss': loss.mean(), 'caption_loss': loss.mean(), 'caption_acc': acc, 'ppl': torch.tensor(ppls).mean()}


--------------------------------------------------------------------------------
/output_dir/events.out.tfevents.1740572442.SH-IDC1-10-140-37-2.156563.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/output_dir/events.out.tfevents.1740572442.SH-IDC1-10-140-37-2.156563.0

--------------------------------------------------------------------------------
/output_dir/events.out.tfevents.1740573180.SH-IDC1-10-140-37-41.259551.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/output_dir/events.out.tfevents.1740573180.SH-IDC1-10-140-37-41.259551.0

--------------------------------------------------------------------------------
/output_dir/events.out.tfevents.1740573310.SH-IDC1-10-140-37-41.261269.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/output_dir/events.out.tfevents.1740573310.SH-IDC1-10-140-37-41.261269.0

--------------------------------------------------------------------------------
/output_dir/events.out.tfevents.1740573743.SH-IDC1-10-140-37-41.119535.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/output_dir/events.out.tfevents.1740573743.SH-IDC1-10-140-37-41.119535.0

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
decord==0.6.0
easydict==1.9
einops==0.8.1
flash_attn==2.7.4.post1
func_timeout==4.3.5
horovod==0.28.1
ipdb==0.13.11
kornia==0.6.10
mmengine==0.10.3
nltk==3.8.1
numpy==1.25.2
omegaconf==2.2.3
opencv_python==4.5.4.58
pandas==1.5.1
petrel_oss_sdk==2.2.2.post10086
Pillow==9.3.0
Pillow==11.1.0
PyYAML==6.0
PyYAML==6.0.2
scikit_learn==1.2.2
submitit==1.5.2
timm==0.5.4
torch==1.13.1
torchvision==0.14.1
tqdm==4.64.1


--------------------------------------------------------------------------------
/util/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/config.cpython-310.pyc

--------------------------------------------------------------------------------
/util/__pycache__/dist_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/dist_utils.cpython-310.pyc
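
The two objectives that `ClipLoss` in `model/loss.py` combines can be sketched in a single process with NumPy (already pinned in `requirements.txt`): a symmetric image-text cross-entropy whose targets are the diagonal of the similarity matrix, plus an L1 hand-box term averaged over valid samples. This is an illustrative sketch under stated assumptions, not the repository's implementation: the helper names `log_softmax`, `symmetric_contrastive_loss` and `masked_box_l1`, the flattened `(N, 16)` box layout (4 frames x 4 coordinates), and the zero-valid-sample guard are all assumptions made here.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log-softmax: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def symmetric_contrastive_loss(image_feats: np.ndarray, text_feats: np.ndarray,
                               logit_scale: float) -> float:
    """Single-process sketch of the image-text term (hypothetical helper).

    image_feats, text_feats: (N, D) arrays, assumed already L2-normalised.
    Row i of each array forms the positive pair; all other rows are negatives.
    """
    logits_per_image = logit_scale * image_feats @ text_feats.T  # (N, N)
    logits_per_text = logits_per_image.T
    idx = np.arange(logits_per_image.shape[0])
    # Cross-entropy with the diagonal entry as the target class, in both directions.
    loss_i = -log_softmax(logits_per_image, axis=1)[idx, idx].mean()
    loss_t = -log_softmax(logits_per_text, axis=1)[idx, idx].mean()
    return float((loss_i + loss_t) / 2)


def masked_box_l1(pred: np.ndarray, target: np.ndarray, valid: np.ndarray) -> float:
    """Sketch of the masked box term (hypothetical helper).

    pred, target: (N, 16) flattened boxes (assumed 4 frames x 4 coords);
    valid: (N,) mask. Returns 0.0 when no sample is valid instead of NaN.
    """
    mask = valid > 0
    num_boxes = int(mask.sum()) * 4  # one box per frame, 4 frames per valid sample
    if num_boxes == 0:
        return 0.0
    return float(np.abs(pred[mask] - target[mask]).sum() / num_boxes)
```

As a sanity check, orthogonal one-to-one embeddings with a large `logit_scale` drive the contrastive term towards zero, while giving every sample the same embedding yields `log(N)`, because every pair is then equally likely.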

--------------------------------------------------------------------------------
/util/__pycache__/lr_sched.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/lr_sched.cpython-310.pyc

--------------------------------------------------------------------------------
/util/__pycache__/meter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/meter.cpython-310.pyc

--------------------------------------------------------------------------------
/util/__pycache__/misc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/misc.cpython-310.pyc

--------------------------------------------------------------------------------
/util/__pycache__/pos_embed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenRobotLab/EgoHOD/e874d04e42e716e3c9a85b2697450e61e696d48a/util/__pycache__/pos_embed.cpython-310.pyc

--------------------------------------------------------------------------------
/util/config.py:
--------------------------------------------------------------------------------

import os
import os.path as osp

from omegaconf import OmegaConf


def load_config(cfg_file):
    cfg = OmegaConf.load(cfg_file)
    if '_base_' in cfg:
        if isinstance(cfg._base_, str):
            base_cfg = OmegaConf.load(osp.join(osp.dirname(cfg_file), cfg._base_))
        else:
            # OmegaConf.merge takes the configs as separate positional arguments, not a generator.
            base_cfg = OmegaConf.merge(*[OmegaConf.load(f) for f in cfg._base_])
        cfg = OmegaConf.merge(base_cfg, cfg)
    return cfg

def get_config(args):
    cfg = load_config(args.config_file)
    OmegaConf.set_struct(cfg, True)
    OmegaConf.set_readonly(cfg, True)

    return cfg

--------------------------------------------------------------------------------
/util/crop.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w

--------------------------------------------------------------------------------
/util/datasets.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)

--------------------------------------------------------------------------------
/util/dist_utils.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import torch
import torch.distributed as dist
import subprocess
import sys
import random
import numpy as np
from ipdb import set_trace

def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def get_model(model):
    if isinstance(model, torch.nn.DataParallel) \
            or isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    else:
        return model


def setup_for_distributed(is_master):
    """
    Disable printing in every process except the master (pass force=True to override).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    else:
        return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def filter_checkpoint(model):
    # Keep only the trainable parameters. state_dict() returns detached tensors
    # (requires_grad=False), so iterate over named_parameters() instead.
    new_state_dict = {}
    for k, v in model.named_parameters():
        if v.requires_grad:
            new_state_dict[k] = v.detach()

    return new_state_dict

def save_on_master(state, output_dir, is_epoch=True):
    if is_main_process():
        epoch = state['epoch']
        ckpt_path = f'{output_dir}/checkpoint.pt'
        best_path = f'{output_dir}/checkpoint_epoch{epoch}.pt'
        torch.save(state, best_path)


def init_distributed_mode(args):
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        print('Init with torch.distributed.launch')
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        #args.rank = int(os.environ['SLURM_PROCID'])
        #args.gpu = args.rank % torch.cuda.device_count()
        print('Init with slurm cluster')

        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = os.environ['SLURM_NTASKS']
        if ntasks == 1:
            ntasks = 8
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list)
        )
        master_port = os.environ.get('MASTER_PORT', '29486')
        os.environ['MASTER_PORT'] = master_port
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = int(ntasks)

        args.rank = int(proc_id)
        args.gpu = int(proc_id % num_gpus)
        print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus:{num_gpus}, addr:{addr}, master port:{master_port}')

    # launched naively with `python main_pretrain.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29501'
    else:
        # print('Does not support training without GPU.')
        # sys.exit(1)
        print('Training without GPU')
        return


    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def scaled_all_reduce(tensors, is_scale=True):
    """Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse of the
    world size.
158 | """ 159 | world_size = get_world_size() 160 | # There is no need for reduction in the single-proc case 161 | if world_size == 1: 162 | return tensors 163 | # Queue the reductions 164 | reductions = [] 165 | for tensor in tensors: 166 | reduction = dist.all_reduce(tensor, async_op=True) 167 | reductions.append(reduction) 168 | # Wait for reductions to finish 169 | for reduction in reductions: 170 | reduction.wait() 171 | # Scale the results 172 | if is_scale: 173 | for tensor in tensors: 174 | tensor.mul_(1.0 / world_size) 175 | return tensors 176 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 
'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/meter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.distributed as dist 9 | import numpy as np 10 | 11 | from sklearn.metrics import recall_score, precision_score 12 | 13 | class AverageMeter(object): 14 | """Computes and stores the average and current value""" 15 | def __init__(self, name, fmt=':f'): 16 | self.name = name 17 | self.fmt = fmt 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def synchronize(self): 33 | if not (dist.is_available() and dist.is_initialized()): 34 | return 35 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda') 36 | dist.barrier() 37 | dist.all_reduce(t) 38 | t = t.tolist() 39 | self.sum = t[0] # keep as float: int() would truncate fractional sums 40 | self.count = int(t[1]) 41 | self.avg = self.sum / self.count 42 | 43 | def __str__(self): 44 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 45 | return fmtstr.format(**self.__dict__) 46 | 47 | 48 | class ProgressMeter(object): 49 | def __init__(self, num_batches, meters, prefix=""): 50 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 51 | self.meters = meters 52 | self.prefix = prefix 53 | 54 | def display(self, batch): 55 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 56 | entries += [str(meter) for meter in self.meters] 57 | print('\t'.join(entries)) 58 | 59 | def synchronize(self): 60 | for meter in self.meters: 61 | meter.synchronize() 62 | 63 | def _get_batch_fmtstr(self, num_batches): 64 | num_digits = len(str(num_batches // 1)) 65 | fmt = '{:' + str(num_digits) + 'd}' 66 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 67 | 68 | 69 | def egomcq_accuracy_metrics(preds, labels, types, group_list = ["Inter-video", "Intra-video"]): 70 | metrics = {} 71 | type_list = 
torch.unique(types) 72 | # group_list = ["Intra-video", "Inter-video"] 73 
| for type_i, group_i in zip(type_list, group_list): 74 | correct = 0 75 | total = 0 76 | for pred, label, typer in zip(preds, labels, types): 77 | if typer == type_i: 78 | pred_ = torch.argmax(pred) 79 | if pred_.item() == label.item(): 80 | correct += 1 81 | total += 1 82 | accuracy = correct/total 83 | metrics[group_i] = accuracy * 100 84 | return metrics 85 | 86 | 87 | 88 | def get_marginal_indexes(actions, mode): 89 | """For each verb/noun, retrieve the list of actions containing that verb/noun 90 | Input: 91 | mode: "verb" or "noun" 92 | Output: 93 | a list of numpy arrays of indexes. If verb/noun 3 is contained in actions 2,8,19, 94 | then output[3] will be np.array([2,8,19]) 95 | """ 96 | vi = [] 97 | for v in range(actions[mode].max()+1): 98 | vals = actions[actions[mode] == v].index.values 99 | if len(vals) > 0: 100 | vi.append(vals) 101 | else: 102 | vi.append(np.array([0])) 103 | return vi 104 | 105 | 106 | def marginalize(probs, indexes): 107 | mprobs = [] 108 | for ilist in indexes: 109 | mprobs.append(probs[:, ilist].sum(1)) 110 | return np.array(mprobs).T 111 | 112 | def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts): 113 | r""" 114 | Calculates the Discounted Cumulative Gain (DCG) between two modalities for 115 | the first modality. 116 | DCG = \sum_{i=1}^k \frac{rel_i}{log_2(i + 1)} 117 | i.e. the sum of the k relevant retrievals which is calculated as the scaled 118 | relevancy for the ith item. The scale is designed such that early 119 | retrievals are more important than later retrievals. 120 | Params: 121 | - similarity_matrix: matrix of size n1 x n2 where n1 is the number of 122 | items in the first modality and n2 is the number of items in the 123 | second modality. The [ith,jth] element is the predicted similarity 124 | between the ith item from the first modality and the jth item from 125 | the second modality. 126 | - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix 127 | above).
The [ith, jth] element is the semantic relevancy between the 128 | ith item from the first modality and the jth item from the second 129 | modality. 130 | - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which 131 | includes information on which items to use to calculate the DCG for 132 | (see calculate_k_counts for more info on this matrix). 133 | Returns: 134 | - The DCG for each item in the first modality, a n1 length vector. 135 | """ 136 | x_sz, y_sz = similarity_matrix.shape 137 | ranks = np.argsort(similarity_matrix)[:, ::-1] 138 | # Create vector of size (n,) where n is the length of the last dimension in 139 | # similarity matrix 140 | # This vector is of the form log(i+1) 141 | logs = np.log2(np.arange(y_sz) + 2) 142 | # Convert logs into the divisor for the DCG calculation, of size similarity 143 | # matrix 144 | divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0) 145 | 146 | # mask out the sorted relevancy matrix to only use the first k relevant 147 | # retrievals for each item. 148 | columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1) 149 | numerators = relevancy_matrix[columns, ranks] * k_counts 150 | # Calculate the final DCG score (note that this isn't expected to sum to 1) 151 | return np.sum(numerators / divisors, axis=1) 152 | 153 | 154 | def calculate_k_counts(relevancy_matrix): 155 | """ 156 | Works out the maximum number of allowed retrievals when working out the 157 | Discounted Cumulative Gain. For each query the DCG only uses the first k 158 | items retrieved which constitute the k relevant items for that query 159 | (otherwise the nDCG scores can be deceptively high for bad rankings). 160 | Params: 161 | - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of 162 | items in the first modality and n2 is the number of items in the 163 | second modality. 
The [ith, jth] element is the semantic relevancy 164 | between the ith item from the first modality and the jth item from 165 | the second modality. 166 | Returns: 167 | - Matrix of size n1 x n2 (see relevancy matrix for more info). This is 168 | created as a mask such that if the [ith, jth] element is 1 it 169 | represents a valid item to use for the calculation of DCG for the 170 | ith item after sorting. For example, if relevancy matrix of: 171 | [[1, 0.5, 0], 172 | [0, 0 , 1]] 173 | is given, then the k_counts matrix will be: 174 | [[1, 1, 0], 175 | [1, 0, 0]] 176 | i.e. the first row has 2 non-zero items, so the first two retrieved 177 | items should be used in the calculation. In the second row there is 178 | only 1 relevant item, therefore only the first retrieved item should 179 | be used for the DCG calculation. 180 | """ 181 | return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int) 182 | 183 | 184 | def calculate_IDCG(relevancy_matrix, k_counts): 185 | """ 186 | Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value 187 | of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the 188 | items in the second modality were retrieved in order of their descending 189 | relevancy. 190 | Params: 191 | - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of 192 | items in the first modality and n2 is the number of items in the 193 | second modality. The [ith, jth] element is the semantic relevancy 194 | between the ith item from the first modality and the jth item from 195 | the second modality. 196 | - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which 197 | includes information on which items to use to calculate the DCG for 198 | (see calculate_k_counts for more info on this matrix). 
199 | """ 200 | return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts) 201 | 202 | 203 | def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'): 204 | r""" 205 | Calculates the normalised Discounted Cumulative Gain (nDCG) between two 206 | modalities for the first modality using the Discounted Cumulative Gain 207 | (DCG) and the Ideal Discounted Cumulative Gain (IDCG). 208 | nDCG = \frac{DCG}{IDCG} 209 | Params: 210 | - similarity_matrix: matrix of size n1 x n2 where n1 is the number of 211 | items in the first modality and n2 is the number of items in the second 212 | modality. The [ith,jth] element is the predicted similarity between 213 | the ith item from the first modality and the jth item from the second 214 | modality. 215 | - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix 216 | above). The [ith, jth] element is the semantic relevancy between the 217 | ith item from the first modality and the jth item from the second 218 | modality. 219 | - k_counts: optional parameter: matrix of size n1 x n2 (see 220 | similarity_matrix above) which includes information on which items to 221 | use to calculate the DCG for (see calculate_k_counts for more info on 222 | this matrix). This will be calculated using calculate_k_counts if not 223 | present, but should be pre-processed for efficiency. 224 | - IDCG: Optional parameter which includes the pre-processed Ideal 225 | Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see 226 | similarity_matrix above) which contains the IDCG value for each item 227 | from the first modality. This will be calculated using calculate_IDCG 228 | if not present, but should be pre-processed for efficiency. 229 | - reduction: what to use to reduce the different nDCG scores. By 230 | default this applies np.mean across all different queries. 231 | Returns: 232 | - The nDCG values for the first modality.
233 | """ 234 | if k_counts is None: 235 | k_counts = calculate_k_counts(relevancy_matrix) 236 | DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts) 237 | if IDCG is None: 238 | IDCG = calculate_IDCG(relevancy_matrix, k_counts) 239 | if reduction == 'mean': 240 | return np.mean(DCG / IDCG) 241 | elif reduction is None: 242 | return DCG / IDCG 243 | 244 | 245 | def calculate_mAP(sim_mat, relevancy_matrix): 246 | r""" 247 | Computes the mean average precision according to the following formula of 248 | average precision: 249 | \frac{\sum_{k=1}^n p(k) x rel(k)}{num_rel_docs} 250 | where p(k) is the precision at k, rel(k) is an indicator function 251 | determining whether the kth returned item is relevant or not and 252 | num_rel_docs is the number of relevant items to find within the search. 253 | The mean average precision is the mean of the average precision for each 254 | query item (i.e. row in the matrix) 255 | This function takes in two parameters: 256 | - sim_mat: an NxM matrix which represents the similarity between two 257 | modalities (with modality 1 being of size N and modality 2 of size M). 258 | - relevancy_matrix: an NxM matrix which represents the relevancy between two 259 | modalities of items (with modality 1 being of size N and modality 2 of 260 | size M).
261 | """ 262 | # Find the order of the items in modality 2 according to modality 1 263 | ranked_order = (-sim_mat).argsort() 264 | ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order] 265 | # re-order the relevancy matrix to accommodate the proposals 266 | ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order] 267 | 268 | # find the number of relevant items found at each k 269 | cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1) 270 | # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above) 271 | cumulative_rel_mat[ranked_rel_mat != 1] = 0 272 | # find the divisor for p(k) 273 | divisor = np.arange(ranked_rel_mat.shape[1]) + 1 274 | 275 | # find the number of relevant docs per query item 276 | number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1) 277 | 278 | # find the average precision per query, within np.sum finds p(k) * rel(k) 279 | avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs 280 | mAP = np.mean(avg_precision) 281 | return mAP 282 | 283 | 284 | def get_mAP(similarity_matrix, rel_matrix): 285 | vis_map = calculate_mAP(similarity_matrix, rel_matrix) 286 | txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T) 287 | return vis_map, txt_map, (vis_map + txt_map) / 2 288 | 289 | 290 | def get_nDCG(similarity_matrix, rel_matrix): 291 | vis_k_counts = calculate_k_counts(rel_matrix) 292 | txt_k_counts = calculate_k_counts(rel_matrix.T) 293 | vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts) 294 | txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts) 295 | vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG) 296 | txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG) 297 | return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2 298 | 299 | 300 | def accuracy(output, target, topk=(1,)): 301 | """Computes the accuracy over the k top predictions for the specified 
values of k""" 302 | with torch.no_grad(): 303 | maxk = max(topk) 304 | batch_size = target.size(0) 305 | 306 | _, pred = output.topk(maxk, 1, True, True) 307 | pred = pred.t() 308 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 309 | 310 | res = [] 311 | for k in topk: 312 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 313 | res.append(correct_k.mul_(100.0 / batch_size)) 314 | return res 315 | 316 | def get_recall(output, target): 317 | pred = output.max(1)[1].cpu().numpy() 318 | target = target.cpu().numpy() 319 | return recall_score(target, pred) 320 | 321 | def get_mean_accuracy(cm): 322 | list_acc = [] 323 | for i in range(len(cm)): 324 | acc = 0 325 | if cm[i, :].sum() > 0: 326 | acc = cm[i, i] / cm[i, :].sum() 327 | list_acc.append(acc) 328 | 329 | return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm) 330 | 331 | 332 | def compute_map(submission_array, gt_array): 333 | """ Returns mAP, weighted mAP, and AP array """ 334 | m_aps = [] 335 | n_classes = submission_array.shape[1] 336 | for oc_i in range(n_classes): 337 | sorted_idxs = np.argsort(-submission_array[:, oc_i]) 338 | tp = gt_array[:, oc_i][sorted_idxs] == 1 339 | fp = np.invert(tp) 340 | n_pos = tp.sum() 341 | if n_pos < 0.1: 342 | m_aps.append(float('nan')) 343 | continue 344 | fp.sum() 345 | f_pcs = np.cumsum(fp) 346 | t_pcs = np.cumsum(tp) 347 | prec = t_pcs / (f_pcs+t_pcs).astype(float) 348 | avg_prec = 0 349 | for i in range(submission_array.shape[0]): 350 | if tp[i]: 351 | avg_prec += prec[i] 352 | m_aps.append(avg_prec / n_pos.astype(float)) 353 | m_aps = np.array(m_aps) 354 | m_ap = np.mean(m_aps) 355 | w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float)) 356 | return m_ap, w_ap, m_aps 357 | 358 | 359 | def charades_map(submission_array, gt_array): 360 | """ 361 | Approximate version of the charades evaluation function 362 | For precise numbers, use the submission file with the official matlab script 363 | """ 364 | fix = 
submission_array.copy() 365 | empty = np.sum(gt_array, axis=1) == 0 366 | fix[empty, :] = -np.inf # np.NINF was removed in NumPy 2.0 367 | return compute_map(fix, gt_array) 368 | 369 | 370 | def create_submission(video_list, predictions, out_file): 371 | assert len(video_list) == predictions.shape[0] 372 | with open(out_file, 'w') as f: 373 | for i, video_id in enumerate(video_list): 374 | pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist())) 375 | f.write('{} {}\n\n'.format(video_id, pred_str)) 376 | 377 | 378 | def compute_metrics(x): 379 | sx = np.sort(-x, axis=1) 380 | d = np.diag(-x) 381 | d = d[:, np.newaxis] 382 | ind = sx - d 383 | ind = np.where(ind == 0) 384 | ind = ind[1] 385 | metrics = {} 386 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 387 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 388 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 389 | metrics['MR'] = np.median(ind) + 1 390 | metrics["MedianR"] = metrics['MR'] 391 | metrics["MeanR"] = np.mean(ind) + 1 392 | metrics["cols"] = [int(i) for i in list(ind)] 393 | return metrics 394 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch import inf 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | def __init__(self, window_size=20, fmt=None): 29 | if fmt is None: 30 | fmt = "{median:.4f} ({global_avg:.4f})" 31 | self.deque = deque(maxlen=window_size) 32 | self.total = 0.0 33 | self.count = 0 34 | self.fmt = fmt 35 | 36 | def update(self, value, n=1): 37 | self.deque.append(value) 38 | self.count += n 39 | self.total += value * n 40 | 41 | def synchronize_between_processes(self): 42 | """ 43 | Warning: does not synchronize the deque! 
44 | """ 45 | if not is_dist_avail_and_initialized(): 46 | return 47 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 48 | dist.barrier() 49 | dist.all_reduce(t) 50 | t = t.tolist() 51 | self.count = int(t[0]) 52 | self.total = t[1] 53 | 54 | @property 55 | def median(self): 56 | d = torch.tensor(list(self.deque)) 57 | return d.median().item() 58 | 59 | @property 60 | def avg(self): 61 | d = torch.tensor(list(self.deque), dtype=torch.float32) 62 | return d.mean().item() 63 | 64 | @property 65 | def global_avg(self): 66 | return self.total / self.count 67 | 68 | @property 69 | def max(self): 70 | return max(self.deque) 71 | 72 | @property 73 | def value(self): 74 | return self.deque[-1] 75 | 76 | def __str__(self): 77 | return self.fmt.format( 78 | median=self.median, 79 | avg=self.avg, 80 | global_avg=self.global_avg, 81 | max=self.max, 82 | value=self.value) 83 | 84 | 85 | class MetricLogger(object): 86 | def __init__(self, delimiter="\t"): 87 | self.meters = defaultdict(SmoothedValue) 88 | self.delimiter = delimiter 89 | 90 | def update(self, **kwargs): 91 | for k, v in kwargs.items(): 92 | if v is None: 93 | continue 94 | if isinstance(v, torch.Tensor): 95 | v = v.item() 96 | assert isinstance(v, (float, int)) 97 | self.meters[k].update(v) 98 | 99 | def __getattr__(self, attr): 100 | if attr in self.meters: 101 | return self.meters[attr] 102 | if attr in self.__dict__: 103 | return self.__dict__[attr] 104 | raise AttributeError("'{}' object has no attribute '{}'".format( 105 | type(self).__name__, attr)) 106 | 107 | def __str__(self): 108 | loss_str = [] 109 | for name, meter in self.meters.items(): 110 | loss_str.append( 111 | "{}: {}".format(name, str(meter)) 112 | ) 113 | return self.delimiter.join(loss_str) 114 | 115 | def synchronize_between_processes(self): 116 | for meter in self.meters.values(): 117 | meter.synchronize_between_processes() 118 | 119 | def add_meter(self, name, meter): 120 | self.meters[name] = meter 121 | 
122 | def log_every(self, iterable, print_freq, header=None): 123 | i = 0 124 | if not header: 125 | header = '' 126 | start_time = time.time() 127 | end = time.time() 128 | iter_time = SmoothedValue(fmt='{avg:.4f}') 129 | data_time = SmoothedValue(fmt='{avg:.4f}') 130 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 131 | log_msg = [ 132 | header, 133 | '[{0' + space_fmt + '}/{1}]', 134 | 'eta: {eta}', 135 | '{meters}', 136 | 'time: {time}', 137 | 'data: {data}' 138 | ] 139 | if torch.cuda.is_available(): 140 | log_msg.append('max mem: {memory:.0f}') 141 | log_msg = self.delimiter.join(log_msg) 142 | MB = 1024.0 * 1024.0 143 | for obj in iterable: 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | if i % print_freq == 0 or i == len(iterable) - 1: 148 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 149 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 150 | if torch.cuda.is_available(): 151 | print(log_msg.format( 152 | i, len(iterable), eta=eta_string, 153 | meters=str(self), 154 | time=str(iter_time), data=str(data_time), 155 | memory=torch.cuda.max_memory_allocated() / MB)) 156 | else: 157 | print(log_msg.format( 158 | i, len(iterable), eta=eta_string, 159 | meters=str(self), 160 | time=str(iter_time), data=str(data_time))) 161 | i += 1 162 | end = time.time() 163 | total_time = time.time() - start_time 164 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 165 | print('{} Total time: {} ({:.4f} s / it)'.format( 166 | header, total_time_str, total_time / len(iterable))) 167 | 168 | 169 | def setup_for_distributed(is_master): 170 | """ 171 | This function disables printing when not in master process 172 | """ 173 | builtin_print = builtins.print 174 | 175 | def print(*args, **kwargs): 176 | force = kwargs.pop('force', False) 177 | force = force or (get_world_size() > 8) 178 | if is_master or force: 179 | now = datetime.datetime.now().time() 180 | 
builtin_print('[{}] '.format(now), end='') # print with time stamp 181 | builtin_print(*args, **kwargs) 182 | 183 | builtins.print = print 184 | 185 | 186 | def is_dist_avail_and_initialized(): 187 | if not dist.is_available(): 188 | return False 189 | if not dist.is_initialized(): 190 | return False 191 | return True 192 | 193 | 194 | def get_world_size(): 195 | if not is_dist_avail_and_initialized(): 196 | return 1 197 | return dist.get_world_size() 198 | 199 | 200 | def get_rank(): 201 | if not is_dist_avail_and_initialized(): 202 | return 0 203 | return dist.get_rank() 204 | 205 | 206 | def is_main_process(): 207 | return get_rank() == 0 208 | 209 | 210 | def save_on_master(*args, **kwargs): 211 | if is_main_process(): 212 | torch.save(*args, **kwargs) 213 | 214 | 215 | def init_distributed_mode(args): 216 | if args.dist_on_itp: 217 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 218 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 219 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 220 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 221 | os.environ['LOCAL_RANK'] = str(args.gpu) 222 | os.environ['RANK'] = str(args.rank) 223 | os.environ['WORLD_SIZE'] = str(args.world_size) 224 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 225 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 226 | args.rank = int(os.environ["RANK"]) 227 | args.world_size = int(os.environ['WORLD_SIZE']) 228 | args.gpu = int(os.environ['LOCAL_RANK']) 229 | elif 'SLURM_PROCID' in os.environ: 230 | args.rank = int(os.environ['SLURM_PROCID']) 231 | args.gpu = args.rank % torch.cuda.device_count() 232 | else: 233 | print('Not using distributed mode') 234 | setup_for_distributed(is_master=True) # hack 235 | args.distributed = False 236 | return 237 | 238 | args.distributed = True 239 | 240 | torch.cuda.set_device(args.gpu) 241 | args.dist_backend = 'nccl' 242 | print('| distributed init (rank {}): 
{}, gpu {}'.format( 243 | args.rank, args.dist_url, args.gpu), flush=True) 244 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 245 | world_size=args.world_size, rank=args.rank) 246 | torch.distributed.barrier() 247 | setup_for_distributed(args.rank == 0) 248 | 249 | 250 | class NativeScalerWithGradNormCount: 251 | state_dict_key = "amp_scaler" 252 | 253 | def __init__(self): 254 | self._scaler = torch.cuda.amp.GradScaler() 255 | 256 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 257 | self._scaler.scale(loss).backward(create_graph=create_graph) 258 | if update_grad: 259 | if clip_grad is not None: 260 | assert parameters is not None 261 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 262 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 263 | else: 264 | self._scaler.unscale_(optimizer) 265 | norm = get_grad_norm_(parameters) 266 | self._scaler.step(optimizer) 267 | self._scaler.update() 268 | else: 269 | norm = None 270 | return norm 271 | 272 | def state_dict(self): 273 | return self._scaler.state_dict() 274 | 275 | def load_state_dict(self, state_dict): 276 | self._scaler.load_state_dict(state_dict) 277 | 278 | 279 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 280 | if isinstance(parameters, torch.Tensor): 281 | parameters = [parameters] 282 | parameters = [p for p in parameters if p.grad is not None] 283 | norm_type = float(norm_type) 284 | if len(parameters) == 0: 285 | return torch.tensor(0.) 
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        # p-norm of the per-parameter p-norms, i.e. the p-norm of all gradients viewed as one vector
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        # optimizer/epoch/scaler state is restored only when resuming training, not for evaluation
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def all_reduce_mean(x):
    """Average a scalar across all processes; return x unchanged when not distributed."""
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions per axis; grid[0] holds the w (column) coordinates because w goes first in meshgrid
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
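`get_1d_sincos_pos_embed_from_grid` and `get_2d_sincos_pos_embed` in util/pos_embed.py build a fixed, non-learned table. Each spatial axis gets half of the channels. Those channels hold sines followed by cosines at frequencies 1/10000^(2i/d), where d is the per-axis width. The standalone NumPy sketch below re-derives the same layout so it can be checked on a tiny grid. The helper names `sincos_1d` and `sincos_2d` are hypothetical and do not exist in the repository.

```python
import numpy as np


def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """(M,) positions -> (M, embed_dim): embed_dim/2 sines followed by embed_dim/2 cosines."""
    assert embed_dim % 2 == 0
    # Frequencies 1 / 10000**(2i / embed_dim) for i = 0 .. embed_dim/2 - 1.
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    angles = np.outer(pos.reshape(-1), omega)  # (M, embed_dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
    """(grid_size**2 [+1], embed_dim) table, one row per patch in row-major order."""
    coords = np.arange(grid_size, dtype=np.float64)
    # Default 'xy' indexing: col[r, c] == c and row[r, c] == r (w goes first).
    col, row = np.meshgrid(coords, coords)
    emb = np.concatenate([sincos_1d(embed_dim // 2, col),   # first half: column index
                          sincos_1d(embed_dim // 2, row)],  # second half: row index
                         axis=1)
    if cls_token:
        emb = np.concatenate([np.zeros((1, embed_dim)), emb], axis=0)  # all-zero cls slot
    return emb


pe = sincos_2d(embed_dim=8, grid_size=4)
pe_cls = sincos_2d(embed_dim=8, grid_size=4, cls_token=True)
```

With `embed_dim=8` the first row is `[0, 0, 1, 1, 0, 0, 1, 1]` (sin 0 and cos 0 for both axes). The column index, not the row index, fills the first half of the channels. This matches the `# here w goes first` comment in `get_2d_sincos_pos_embed`.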
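`get_grad_norm_` in util/misc.py reduces gradients in two stages. It first takes a p-norm per parameter and then takes the p-norm of those values, or a plain max when `norm_type` is `inf`. For any finite p, this equals the p-norm of all gradient entries concatenated into one vector. That is the same quantity `torch.nn.utils.clip_grad_norm_` returns, so both branches of `NativeScalerWithGradNormCount.__call__` report comparable norms. Both branches also measure after `unscale_`, so the values reflect true rather than loss-scaled gradients. The NumPy sketch below checks the identity on a toy example. Plain arrays stand in for `p.grad`, and `total_grad_norm` and `p_norm` are hypothetical helper names.

```python
import numpy as np


def p_norm(x: np.ndarray, p: float) -> float:
    """p-norm of all entries of x, treated as one flat vector."""
    return float((np.abs(x) ** p).sum() ** (1.0 / p))


def total_grad_norm(grads, norm_type: float = 2.0) -> float:
    """Two-stage reduction in the style of get_grad_norm_, on NumPy arrays."""
    if len(grads) == 0:
        return 0.0  # get_grad_norm_ returns tensor(0.) when no parameter has a gradient
    if norm_type == np.inf:
        # inf-norm: the largest absolute gradient entry across all parameters
        return max(float(np.abs(g).max()) for g in grads)
    per_param = np.array([p_norm(g, norm_type) for g in grads])  # one norm per parameter
    return p_norm(per_param, norm_type)  # norm of the per-parameter norms


# Toy "gradients" for two parameters of different shapes.
grads = [np.array([3.0, 4.0]), np.array([[12.0]])]
flat = np.concatenate([g.ravel() for g in grads])
two_stage = total_grad_norm(grads)
one_shot = p_norm(flat, 2.0)
```

Here the per-parameter 2-norms are 5 and 12, and both `two_stage` and `one_shot` come out as 13. The two-stage form lets each parameter's norm be computed on its own device before a single small reduction.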
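`interpolate_pos_embed` in util/pos_embed.py adapts a checkpoint's `pos_embed` to a model with a different patch grid. It keeps the leading extra tokens (class and distillation) unchanged and infers the old side length from the remaining token count. It then reshapes those tokens into a square grid, resizes the grid, and flattens it back in row-major order. The NumPy sketch below reproduces only that token bookkeeping. For brevity, it substitutes a nearest-neighbour resize for the bicubic `torch.nn.functional.interpolate` call, so the values differ from the real function and only the shapes and token ordering are meant to match. `nearest_resize` and `resize_pos_embed` are hypothetical names.

```python
import numpy as np


def nearest_resize(grid: np.ndarray, new_size: int) -> np.ndarray:
    """Hypothetical stand-in for bicubic interpolation: nearest-neighbour on an (S, S, C) grid."""
    old_size = grid.shape[0]
    # Map each output cell centre back to the input cell it falls in.
    idx = np.floor((np.arange(new_size) + 0.5) * old_size / new_size).astype(int)
    return grid[idx][:, idx]


def resize_pos_embed(pos_embed: np.ndarray, num_patches: int, num_extra_tokens: int) -> np.ndarray:
    """(1, num_extra_tokens + S*S, C) -> (1, num_extra_tokens + num_patches, C)."""
    c = pos_embed.shape[-1]
    orig_size = int((pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return pos_embed  # nothing to do, same as the repo's early exit
    extra = pos_embed[:, :num_extra_tokens]  # cls/dist tokens are copied unchanged
    grid = pos_embed[0, num_extra_tokens:].reshape(orig_size, orig_size, c)  # row-major grid
    grid = nearest_resize(grid, new_size).reshape(1, new_size * new_size, c)
    return np.concatenate([extra, grid], axis=1)


# 1 cls token + a 2x2 grid of 2-d embeddings, upsampled to a 4x4 grid.
pe = np.arange(10, dtype=np.float64).reshape(1, 5, 2)
out = resize_pos_embed(pe, num_patches=16, num_extra_tokens=1)
```

Because the grid is rebuilt with a row-major reshape, the token at flattened index `1 + r * S + c` corresponds to row `r`, column `c`. Upsampling the 2x2 grid to 4x4 therefore maps output cell (3, 3) back to input token 4.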