├── .gitignore ├── LICENSE ├── README.md ├── assets ├── banner.png ├── brain.png ├── omnimodal_pretraining.png ├── paradigm.png └── scaling_laws.png ├── data ├── README.md ├── caption_config │ ├── caption-generation-audio.json │ ├── caption-generation-vision.json │ ├── default_model_cfg.json │ └── default_run_cfg.json ├── config.yaml ├── data │ ├── IndexAnno.py │ ├── IndexSrc.py │ ├── __init__.py │ ├── audio_mapper.py │ ├── loader.py │ └── vision_mapper.py ├── download_hdvila.sh ├── makeparquet.py ├── model │ ├── __init__.py │ ├── audio_encoders │ │ ├── ast │ │ │ └── ast.py │ │ └── beats │ │ │ └── beats.py │ ├── general_module.py │ ├── text_encoders │ │ └── bert │ │ │ └── bert.py │ ├── vast.py │ └── vision_encoders │ │ ├── clip │ │ ├── clip.py │ │ └── clip_tokenizer.py │ │ ├── evaclip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── constants.py │ │ ├── eva_vit_model.py │ │ ├── factory.py │ │ ├── hf_configs.py │ │ ├── hf_model.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ └── EVA02-CLIP-bigE-14.json │ │ ├── modified_resnet.py │ │ ├── openai.py │ │ ├── pretrained.py │ │ ├── rope.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── transformer.py │ │ └── utils.py │ │ ├── swin │ │ ├── swin.py │ │ └── swin_config.py │ │ └── videoswin │ │ └── videoswin.py ├── run.py ├── scripts │ ├── run_audio_captioner.sh │ └── run_vision_captioner.sh ├── setup_env.sh └── utils │ ├── __init__.py │ ├── args.py │ ├── build_dataloader.py │ ├── build_model.py │ ├── build_optimizer.py │ ├── distributed.py │ ├── initialize.py │ ├── logger.py │ ├── offline_process_data.py │ ├── pipeline.py │ ├── save.py │ ├── sched.py │ └── tool.py ├── example ├── test.flac ├── test.jpeg └── test.mp4 ├── inference_demo.py ├── model ├── .DS_Store ├── audioprocessor.py ├── bert-base-uncased-crossattn │ ├── config.json │ └── generation_config.json ├── bert.py ├── clip │ ├── clip.py │ └── clip_tokenizer.py ├── evaclip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── constants.py │ ├── eva_vit_model.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-CLIP-B-16.json │ │ ├── EVA01-CLIP-g-14-plus.json │ │ ├── EVA01-CLIP-g-14.json │ │ ├── EVA02-CLIP-B-16.json │ │ ├── EVA02-CLIP-L-14-336.json │ │ ├── EVA02-CLIP-L-14.json │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ └── EVA02-CLIP-bigE-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── rope.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ └── utils.py ├── imageprocessor.py ├── mico.py ├── swin.py ├── swin_base_patch4_window7_224_22k.yaml ├── swin_config.py ├── tokenizer │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.txt ├── transformer.py └── videoprocessor.py └── set_env.sh /.gitignore: -------------------------------------------------------------------------------- 1 | /.history 2 | examples/ 3 | *.pyc 4 | *.pt 5 | wandb/ 6 | 7 | .history 8 | /cococaption 9 | 10 | Mico-g/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | [![arXiv](https://img.shields.io/badge/arxiv-2406.09412-b31b1b?style=plastic&color=b31b1b&link=https%3A%2F%2Farxiv.org%2Fabs%2F2406.09412)](https://arxiv.org/abs/2406.09412) 6 | [![website](https://img.shields.io/badge/Project-Website-purple)](https://invictus717.github.io/MiCo/) 7 | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K) 8 | 9 | License: Apache2.0 10 | 11 | 13 | 14 | 15 | 16 | ### ✨ Inspiration of Multimodal Context: Multimedia Brain Cognition 17 | 18 |

19 | 20 |

21 | 22 | ***How does the human brain perform coherent multimodal cognition?*** 23 | 24 | As outlined in Richard Mayer's Cognitive Theory of Multimedia Learning, our brain processes multimedia signals through two distinct channels—auditory and visual—in sensory memory, as depicted in Figure (a). The sensory memory integrates these signals with prior knowledge through words, transforming new multimedia information into long-term memory. Notably, **1**) multimedia signals in the brain share channels, and **2**) words function as the reasoning interface in our brain. 25 | 26 | Inspired by these insights, we categorize diverse modalities into two types: ``knowledge modality`` and ``interface modality``. *Knowledge modalities*, primarily derived from raw sensors, contribute knowledge in diverse formats. For example, images and depth maps offer visual knowledge, while audio and video provide auditory and spatiotemporal knowledge. The language modality, developed by humans, is inherently more abstract and naturally functions as the *interface modality*, facilitating learning, reasoning, and the coordination of knowledge. To this end, we design an omni-modal learning architecture, illustrated in Figure (b), with two distinct branches: one for knowledge modalities and one for the interface modality, *i.e.*, natural language. The knowledge and interface modalities are aligned through a novel generative reasoning method. 27 | 28 | ### 🚀 MiCo, an omni-modal and scalable pretraining paradigm 29 | 30 | 

31 | 32 |

33 | 34 | We propose collecting large-scale omni-modal paired data, including text, 35 | image, video, depth, and normal maps, to learn universal representations. 36 | 37 |

38 | 39 |

40 | 41 | **🚀 Evolution of Pretraining Paradigms**. Masked modeling (a) has shown great success in single-modality, general-purpose understanding. Contrastive learning (b) learns transferable features from modality tuples (such as text-image, text-video, text-audio, etc.). 42 | 43 | *🚀🚀🚀 We aim to achieve general-purpose omni-modal understanding and learn transferable, universal representations in (c).* 44 | 45 | ### 🌟🌟🌟 The Multimodal Scaling Laws with MiCo: Modalities Help Modalities! 46 | 47 | 

48 | 49 |

50 | 51 | ### 🔓 Pretrained Omni-Modal Models 52 | 53 | **We will continue to update this model zoo, including ViTs of all scales and highly efficient ConvNets pretrained with the MiCo paradigm.** 54 | 55 | Current Checkpoints 56 | 
57 |
58 | 59 | | Model | Pretraining | Scale | Modality | #Param | Google Drive | Hugging Face 60 | | :------------: | :----------: | :----------------------: | :----: | :---------------------------------------------------------------------------------------------------: |:----: | :----: | 61 | | MiCo | 300k steps | ViT-g | Omni-modal | 1.3B | [ckpt](https://drive.google.com/drive/folders/1AIQjV1KU8K4OXiO-4gFirxkoxt3twWIq?usp=sharing) | [ckpt](https://huggingface.co/Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K) 62 | 63 | 64 |
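The checkpoint in the table above can also be fetched programmatically from Hugging Face. Below is a minimal sketch using `huggingface_hub`; only the repo id is taken from this README, while the local directory name is an arbitrary choice and not part of this repository:

```python
# Minimal sketch: download the MiCo ViT-g omni-modal checkpoint from Hugging Face.
# Assumption: only the repo id comes from this README; local_dir is an arbitrary folder name.
from huggingface_hub import snapshot_download  # pip install huggingface_hub

ckpt_dir = snapshot_download(
    repo_id="Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K",
    local_dir="MiCo-ViT-g-14",  # any local folder works
)
print("Checkpoint files downloaded to:", ckpt_dir)
```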
65 | 66 | ### 🔓 Omni-Modal Dataset Collection 67 | 68 | We provide a detailed [doc](data/README.md) for preparing the omni-modal dataset step by step. 69 | 70 | ### ⚡ Quick Start 71 | ```bash 72 | pip install gdown 73 | gdown 1AIQjV1KU8K4OXiO-4gFirxkoxt3twWIq --folder 74 | python inference_demo.py 75 | ``` 76 | # Citation 77 | If the code and paper help your research, please kindly cite: 78 | ``` 79 | @article{zhang2024explore, 80 | title={Explore the Limits of Omni-modal Pretraining at Scale}, 81 | author={Zhang, Yiyuan and Li, Handong and Liu, Jing and Yue, Xiangyu}, 82 | journal={arXiv preprint arXiv:2406.09412}, 83 | year={2024} 84 | } 85 | ``` 86 | # License 87 | This project is released under the [Apache 2.0 license](LICENSE). 88 | # Acknowledgement 89 | We appreciate [Dr. Xiaohan Ding](https://dingxiaohan.xyz/) for the valuable discussion and suggestions. This code is developed based on [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [VAST](https://github.com/TXH-mercury/VAST), [DPT](https://github.com/EPFL-VILAB/omnidata), and [GeoWizard](https://github.com/fuxiao0719/GeoWizard). -------------------------------------------------------------------------------- /assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/banner.png -------------------------------------------------------------------------------- /assets/brain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/brain.png -------------------------------------------------------------------------------- /assets/omnimodal_pretraining.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/omnimodal_pretraining.png -------------------------------------------------------------------------------- /assets/paradigm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/paradigm.png -------------------------------------------------------------------------------- /assets/scaling_laws.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/scaling_laws.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Data Preparation 2 | 3 | #### Install Download Tools 4 | ```bash 5 | pip install video2dataset 6 | ``` 7 | or 8 | ```bash 9 | git clone https://github.com/iejMac/video2dataset 10 | cd video2dataset 11 | pip install -e . 12 | ``` 13 | #### Download Metadata 14 | ```bash 15 | wget -O hdvila100m.zip "https://hdvila.blob.core.windows.net/dataset/hdvila100m.zip?sp=r&st=2022-06-28T03:33:11Z&se=2026-01-01T11:33:11Z&spr=https&sv=2021-06-08&sr=b&sig=VaqQkLFDqKinfkaPNs1jJ1EQIYCB%2FUPYiqFqmjWye6Y%3D" 16 | ``` 17 | Then unzip the metadata zip file. 
18 | ```bash 19 | unzip hdvila100m.zip 20 | ``` 21 | With the metadata, we convert these JSONL files into a parquet file by running: 22 | ```bash 23 | python makeparquet.py 24 | ``` 25 | Once you run this, you should have a file `hd_vila.parquet` with all the relevant metadata. The files are organized as: 26 | ```bash 27 | data 28 | ├── caption_config 29 | ├── model 30 | ├── scripts 31 | ├── utils 32 | ├── makeparquet.py 33 | ├── config.yaml 34 | ├── download_hdvila.sh 35 | ├── hdvila 36 | │   ├── hdvila_part0.jsonl 37 | │   ├── hdvila_part1.jsonl 38 | │   ├── hdvila_part2.jsonl 39 | │   ├── hdvila_part3.jsonl 40 | │   ├── hdvila_part4.jsonl 41 | │   ├── hdvila_part5.jsonl 42 | │   ├── hdvila_part6.jsonl 43 | │   ├── hdvila_part7.jsonl 44 | │   ├── hdvila_part8.jsonl 45 | │   ├── hdvila_part9.jsonl 46 | │   ├── hdvila_part10.jsonl 47 | │   ├── hd_vila.parquet 48 | ``` 49 | #### Download HDVILA-100M Source Data 50 | Please check your path in `download_hdvila.sh` before running the script to download the dataset: 51 | ```bash 52 | bash download_hdvila.sh 53 | ``` 54 | #### Annotate Your Videos 55 | 1. Download Pretrained Captioners for Videos (Images) and Audio. 56 | ```bash 57 | pip install gdown 58 | gdown https://drive.google.com/file/d/1vYqb0Lb_3sQ5bo6XV-FQ4n7k_0M9UMU3/view?usp=sharing 59 | tar -xvf audio_captioner.tar.gz 60 | gdown https://drive.google.com/file/d/1ZFCWZ8csMWLYsg9CWt71PJmKYpSn-FMt/view?usp=sharing 61 | tar -xvf vision_captioner.tar.gz 62 | ``` 63 | 2. Deploy captioners for data annotation. 64 | Set up the Python environment for the captioners. 65 | ```bash 66 | bash setup_env.sh 67 | ``` 68 | 69 | Video Annotation with Captions 70 | ```bash 71 | bash scripts/run_vision_captioner.sh 72 | ``` 73 | 74 | Audio Annotation with Captions 75 | ```bash 76 | bash scripts/run_audio_captioner.sh 77 | ``` 78 | 3. (Optional) Deploy a Depth Estimator to annotate 3D content. 79 | *We highly recommend using [GeoWizard](https://github.com/fuxiao0719/GeoWizard) to generate high-quality 3D content*. 80 | However, the drawback of *GeoWizard* is the slow inference speed of generative models. Therefore, in our practice, we use [DPT](https://github.com/EPFL-VILAB/omnidata) to annotate the majority of the data. 
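As a quick sanity check of the preparation steps above, the generated metadata can be inspected with pandas before launching large download or annotation jobs. A minimal sketch (column names follow `makeparquet.py`; reading parquet requires `pyarrow` or `fastparquet`):

```python
# Minimal sketch: verify the metadata parquet produced by makeparquet.py.
import pandas as pd  # plus pyarrow or fastparquet for parquet support

df = pd.read_parquet("hd_vila.parquet")
print(df.columns.tolist())      # expected: ['video_id', 'url', 'clips']
print(len(df), "videos in the metadata")
print(df.iloc[0]["clips"][:2])  # first clip spans of the first video, in seconds
```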
-------------------------------------------------------------------------------- /data/caption_config/caption-generation-audio.json: -------------------------------------------------------------------------------- 1 | {"run_cfg": 2 | {"default":"default_run_cfg.json", 3 | "mode":"testing"}, 4 | 5 | "model_cfg": 6 | {"default":"default_model_cfg.json"}, 7 | 8 | "data_cfg": 9 | 10 | {"train":{}, 11 | 12 | "val": 13 | [ 14 | { 15 | "type":"annoindexed", 16 | "training":false, 17 | "name": "yourdata", 18 | "txt": "datasets/annotations/yourdata/meta.json", 19 | "audio": "datasets/srcdata/yourdata/audios", 20 | "audio_sample_num": 3, 21 | "task" : "cap%ta", 22 | "n_workers": 8, 23 | "batch_size": 64 } 24 | ]}} 25 | 26 | 27 | -------------------------------------------------------------------------------- /data/caption_config/caption-generation-vision.json: -------------------------------------------------------------------------------- 1 | 2 | {"run_cfg": 3 | {"default":"default_run_cfg.json", 4 | "mode":"testing"}, 5 | 6 | "model_cfg": 7 | {"default":"default_model_cfg.json"}, 8 | 9 | "data_cfg": 10 | 11 | {"train":{}, 12 | 13 | "val": 14 | [{ 15 | "type":"annoindexed", 16 | "training":false, 17 | "name": "yourdata", 18 | "txt": "datasets/annotations/yourdata/meta.json", 19 | "vision": "datasets/srcdata/yourdata/videos", 20 | "vision_format": "video_rawvideo", 21 | "vision_sample_num": 8, 22 | "task" : "cap%tv", 23 | "n_workers": 8, 24 | "batch_size": 64 25 | }]}} -------------------------------------------------------------------------------- /data/caption_config/default_model_cfg.json: -------------------------------------------------------------------------------- 1 | {"model_type": "vast", 2 | "itm_ratio":0.1, 3 | "frozen_vision":false, 4 | "frozen_audio":false, 5 | "checkpointing":false, 6 | "max_caption_len":40, 7 | "max_omni_caption_len":70, 8 | "max_subtitle_len":70, 9 | "contra_dim":512, 10 | "inherit_keys":["vision_encoder_type","audio_encoder_type","audio_melbins","audio_target_length"], 11 | "frame_embedding_type":"adaptive", 12 | "vision_resolution":224, 13 | "vision_encoder_type":"evaclip01_giant", 14 | "audio_encoder_type":"beats", 15 | "audio_melbins":64, 16 | "audio_target_length": 1024, 17 | "beam_size":3, 18 | "captioner_mode":false, 19 | "generate_nums":1, 20 | "ret_bidirection_evaluation":false, 21 | "itm_rerank_num":50, 22 | "evaluation_type":"evaluation_mm"} 23 | 24 | -------------------------------------------------------------------------------- /data/caption_config/default_run_cfg.json: -------------------------------------------------------------------------------- 1 | {"checkpoint":"", 2 | "output_dir":"none", 3 | "gradient_accumulation_steps":1, 4 | "clip_lr":5e-7, 5 | "optim":"adamw", 6 | "learning_rate":1e-4, 7 | "betas":[0.9, 0.98], 8 | "weight_decay":0.01, 9 | "grad_norm":2.0, 10 | "warmup_ratio":0.1, 11 | "resume":false, 12 | "seed":50, 13 | "fp16":true, 14 | "bf16":false, 15 | "zero_shot":false, 16 | "scheduler":"warmup_linear", 17 | "new_lr":0, 18 | "new_params_name":[], 19 | "valid_freq":10, 20 | "dataset_mix_type":"random", 21 | "remove_before_ckpt":true, 22 | "first_eval":true, 23 | "pretrain_dir":"", 24 | "num_train_steps":0, 25 | "save_best":false, 26 | "pin_mem":true, 27 | "vision_resolution":224, 28 | "use_ddp":true, 29 | "mode":"training", 30 | "log_steps":100 31 | } -------------------------------------------------------------------------------- /data/config.yaml: -------------------------------------------------------------------------------- 1 | 
subsampling: 2 | CutDetectionSubsampler: 3 | args: 4 | cut_detection_mode: "all" 5 | framerates: null 6 | threshold: 11.5 7 | min_scene_len: 15 8 | reading: 9 | yt_args: 10 | download_size: 360 11 | download_audio_rate: 44100 12 | yt_metadata_args: 13 | writesubtitles: 'all' 14 | subtitleslangs: ['en'] 15 | writeautomaticsub: True 16 | get_info: True 17 | timeout: 180 18 | sampler: null 19 | 20 | storage: 21 | number_sample_per_shard: 100 22 | captions_are_subtitles: False 23 | oom_shard_count: 5 24 | 25 | distribution: 26 | processes_count: 2 27 | thread_count: 8 28 | subjob_size: 1000 29 | distributor: "multiprocessing" -------------------------------------------------------------------------------- /data/data/IndexAnno.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import random 5 | import torch 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from toolz.sandbox import unzip 9 | from torch.utils.data import Dataset 10 | from utils.logger import LOGGER 11 | from .vision_mapper import VisionMapper 12 | from .audio_mapper import AudioMapper 13 | 14 | from torch.utils.data import ConcatDataset 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | class AnnoIndexedDataset(Dataset): 24 | def __init__(self, d_cfg, args): 25 | self.vision_mapper = VisionMapper(d_cfg, args) if 'vision' in d_cfg else None 26 | self.audio_mapper = AudioMapper(d_cfg, args) if 'audio' in d_cfg else None 27 | self.annos = json.load(open(d_cfg['txt'])) 28 | self.idx = list(range(len(self.annos))) 29 | self.dataset_name = d_cfg['name'] 30 | self.training = d_cfg.training 31 | 32 | self.worker_init_fn = None 33 | self.use_sampler = True 34 | self.collate_fn = annoindexedcollate 35 | 36 | self.annfile = getattr(d_cfg,'annfile',None) 37 | self.make_submission = getattr(d_cfg,'make_submission',False) 38 | self.multi_evaluation = getattr(d_cfg,'multi_evaluation',False) 39 | self.vqa_anno_file = getattr(d_cfg,'vqa_anno_file',None) 40 | self.vqa_question_file = getattr(d_cfg,'vqa_question_file',None) 41 | 42 | 43 | def __len__(self): 44 | return len(self.annos) 45 | 46 | def __getitem__(self, i): 47 | anno = self.annos[i] 48 | 49 | for key in ['video_id','image_id','image','id']: 50 | if key in anno: 51 | id_ = anno[key] 52 | break 53 | 54 | raw_captions = None 55 | raw_subtitles = None 56 | question_id = None 57 | question = None 58 | answer = None 59 | id_txt = None 60 | vision_pixels = None 61 | audio_spectrograms = None 62 | 63 | 64 | 65 | 66 | 67 | 68 | raw_captions = anno['desc'] if 'desc' in anno else anno['caption'] 69 | num_samples = len(raw_captions) if isinstance(raw_captions, list) else 1 70 | id_txt = [id_] * num_samples 71 | 72 | 73 | if 'subtitle' in anno: 74 | raw_subtitles = anno['subtitle'] 75 | 76 | if 'question' in anno: 77 | 78 | if self.training: 79 | question = anno['question'] 80 | if isinstance(anno['answer'],list): #### vqav2 81 | answer = random.choice(anno['answer']) 82 | else: 83 | answer = anno['answer'] 84 | 85 | else: 86 | question = anno['question'] 87 | answer = anno['answer'] 88 | if 'question_id' in anno: 89 | question_id = anno['question_id'] 90 | 91 | 92 | if self.vision_mapper: 93 | if self.vision_mapper.vision_format == 'video_feats': 94 | vision_feats = self.vision_mapper.read(id_) 95 | 96 | else: 97 | vision_pixels = self.vision_mapper.read(id_) 98 | if vision_pixels is None: ###wrong img/video, resample when training and raise error when testing 99 | if self.training: 100 | resample_idx = 
random.choice(self.idx) 101 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video, use {resample_idx} instead.') 102 | return self.__getitem__(resample_idx) 103 | else: 104 | resample_idx = random.choice(self.idx) 105 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video,!!!!!!!!!!!!!!!!!!!!!!!! use {resample_idx} instead.') 106 | return self.__getitem__(resample_idx) 107 | # raise ValueError 108 | 109 | if self.audio_mapper: 110 | audio_spectrograms = self.audio_mapper.read(id_) 111 | if audio_spectrograms is None: ### wrong audio, resample when training and raise error when testing 112 | if self.training: 113 | resample_idx = random.choice(self.idx) 114 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong audio, use {resample_idx} instead.') 115 | return self.__getitem__(resample_idx) 116 | else: 117 | raise ValueError 118 | 119 | return id_, raw_captions, vision_pixels, id_txt, question, answer, question_id, \ 120 | audio_spectrograms, raw_subtitles 121 | 122 | 123 | 124 | def annoindexedcollate(inputs): 125 | 126 | batch = {} 127 | all_data = map(list, unzip(inputs)) 128 | keys = ['ids', 129 | 'raw_captions', 130 | 'vision_pixels', 131 | 'ids_txt', 132 | 'raw_questions', 133 | 'raw_answers', 134 | 'question_ids', 135 | 'audio_spectrograms', 136 | 'raw_subtitles'] 137 | 138 | for key, data in zip(keys, all_data): 139 | 140 | if data[0] is None: 141 | continue 142 | elif isinstance(data[0], torch.Tensor): 143 | batch[key] = torch.stack(data, dim=0).float() 144 | 145 | else: 146 | batch[key] = data 147 | 148 | 149 | 150 | return batch 151 | 152 | 153 | -------------------------------------------------------------------------------- /data/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .IndexAnno import AnnoIndexedDataset 3 | from .IndexSrc import SrcIndexedDataset 4 | 5 | data_registry={ 6 | 'annoindexed':AnnoIndexedDataset, 7 | 'srcindexed':SrcIndexedDataset, 8 | 9 | } 10 | -------------------------------------------------------------------------------- /data/data/audio_mapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchaudio 5 | from utils.logger import LOGGER 6 | from utils.tool import split 7 | 8 | 9 | class AudioMapper(object): 10 | # def __init__(self, audio_dir, opts, sample_num, check_exists=True): 11 | def __init__(self, d_cfg, args): 12 | self.audio_dir = d_cfg.audio 13 | self.melbins = args.model_cfg.audio_melbins 14 | self.target_length = args.model_cfg.audio_target_length 15 | self.training = d_cfg.training 16 | self.frame_shift = 10 17 | self.sample_num = d_cfg.audio_sample_num 18 | self.audio_encoder_type = args.model_cfg.audio_encoder_type 19 | if self.audio_encoder_type == 'ast': 20 | self.mean = -4.2677393 21 | self.std = 4.5689974 22 | elif self.audio_encoder_type == 'beats': 23 | self.mean = 15.41663 24 | self.std = 6.55582 25 | else: 26 | raise NotImplementedError 27 | 28 | 29 | 30 | def read(self, id_): 31 | 32 | wav_file = os.path.join(self.audio_dir, id_) 33 | 34 | if not os.path.exists(wav_file): 35 | wav_file = os.path.join(self.audio_dir, id_+'.wav') 36 | if not os.path.exists(wav_file): 37 | wav_file = wav_file.replace('wav','mp3') 38 | if not os.path.exists(wav_file): 39 | wav_file = wav_file.replace('mp3','mkv') 40 | if not os.path.exists(wav_file): 41 | print('not have audios', id_) 42 | return 
torch.zeros(self.sample_num, self.target_length, self.melbins) 43 | try: 44 | if self.audio_encoder_type == 'ast': 45 | 46 | waveform, sr = torchaudio.load(wav_file) 47 | 48 | waveform = waveform - waveform.mean() 49 | fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False, 50 | window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=self.frame_shift) 51 | 52 | 53 | 54 | elif self.audio_encoder_type == 'beats': 55 | 56 | waveform, sr = torchaudio.load(wav_file) 57 | if sr != 16000: 58 | trans = torchaudio.transforms.Resample(sr, 16000) 59 | waveform = trans(waveform) 60 | 61 | waveform = waveform * 2 ** 15 62 | fbank = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=self.melbins, sample_frequency=16000, frame_length=25, frame_shift=10) 63 | 64 | else: 65 | raise NotImplementedError 66 | 67 | # ### normalization 68 | fbank = (fbank - self.mean) / (self.std * 2) 69 | src_length = fbank.shape[0] 70 | # #### sample 71 | output_slices = [] 72 | pad_len = max(self.target_length * self.sample_num -src_length, self.target_length - src_length%self.target_length) 73 | fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank) 74 | total_slice_num = fbank.shape[0] // self.target_length 75 | total_slice_num = list(range(total_slice_num)) 76 | total_slice_num = split(total_slice_num, self.sample_num) 77 | 78 | if self.training: 79 | sample_idx = [random.choice(i) for i in total_slice_num] 80 | else: 81 | sample_idx = [i[(len(i)+1)//2-1] for i in total_slice_num] 82 | 83 | 84 | for i in sample_idx: 85 | cur_bank = fbank[i*self.target_length : (i+1)*self.target_length] 86 | output_slices.append(cur_bank) 87 | 88 | fbank = torch.stack(output_slices,dim=0) ### n, 1024, 128 89 | return fbank 90 | 91 | except Exception as e: 92 | print(e) 93 | return 94 | 95 | -------------------------------------------------------------------------------- /data/data/loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from utils.distributed import any_broadcast 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | 8 | class MetaLoader(object): 9 | """ wraps multiple data loaders """ 10 | def __init__(self, loaders, accum_steps=1, distributed=False): 11 | assert isinstance(loaders, dict) 12 | self.name2loader = {} 13 | self.name2iter = {} 14 | self.sampling_pools = [] 15 | self.name2labelname={} 16 | for idx, (n, l) in enumerate(loaders.items()): 17 | if isinstance(l, tuple): 18 | l, r = l 19 | elif isinstance(l, DataLoader): 20 | r = 1 21 | else: 22 | raise ValueError() 23 | self.name2loader[n] = l 24 | self.name2iter[n] = iter(l) 25 | self.sampling_pools.extend([n]*r) 26 | # import ipdb 27 | # ipdb.set_trace() 28 | 29 | 30 | self.accum_steps = accum_steps 31 | self.distributed = distributed 32 | self.step = 0 33 | self.epoch = 0 34 | 35 | 36 | def __iter__(self): 37 | """ this iterator will run indefinitely """ 38 | task = self.sampling_pools[0] 39 | while True: 40 | if self.step % self.accum_steps == 0: 41 | task = random.choice(self.sampling_pools) 42 | if self.distributed: 43 | # make sure all process is training same task 44 | task = any_broadcast(task, 0) 45 | self.step += 1 46 | iter_ = self.name2iter[task] 47 | try: 48 | batch = next(iter_) 49 | except StopIteration: 50 | self.epoch = self.epoch + 1 51 | if isinstance(self.name2loader[task].sampler, DistributedSampler): 52 | 
self.name2loader[task].sampler.set_epoch(self.epoch) 53 | else: 54 | pass 55 | iter_ = iter(self.name2loader[task]) 56 | batch = next(iter_) 57 | self.name2iter[task] = iter_ 58 | 59 | 60 | yield task, batch 61 | 62 | 63 | def move_to_cuda(batch): 64 | if isinstance(batch, torch.Tensor): 65 | return batch.cuda(non_blocking=True) 66 | elif isinstance(batch, list): 67 | new_batch = [move_to_cuda(t) for t in batch] 68 | elif isinstance(batch, tuple): 69 | new_batch = tuple(move_to_cuda(t) for t in batch) 70 | elif isinstance(batch, dict): 71 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 72 | else: 73 | return batch 74 | return new_batch 75 | 76 | 77 | def record_cuda_stream(batch): 78 | if isinstance(batch, torch.Tensor): 79 | batch.record_stream(torch.cuda.current_stream()) 80 | elif isinstance(batch, list) or isinstance(batch, tuple): 81 | for t in batch: 82 | record_cuda_stream(t) 83 | elif isinstance(batch, dict): 84 | for t in batch.values(): 85 | record_cuda_stream(t) 86 | else: 87 | pass 88 | 89 | 90 | class PrefetchLoader(object): 91 | """ 92 | overlap compute and cuda data transfer 93 | (copied and then modified from nvidia apex) 94 | """ 95 | def __init__(self, loader): 96 | self.loader = loader 97 | self.stream = torch.cuda.Stream() 98 | 99 | def __iter__(self): 100 | loader_it = iter(self.loader) 101 | self.preload(loader_it) 102 | batch = self.next(loader_it) 103 | while batch is not None: 104 | yield batch 105 | batch = self.next(loader_it) 106 | 107 | def __len__(self): 108 | return len(self.loader) 109 | 110 | def preload(self, it): 111 | try: 112 | self.batch = next(it) 113 | except StopIteration: 114 | self.batch = None 115 | return 116 | # if record_stream() doesn't work, another option is to make sure 117 | # device inputs are created on the main stream. 118 | # self.next_input_gpu = torch.empty_like(self.next_input, 119 | # device='cuda') 120 | # self.next_target_gpu = torch.empty_like(self.next_target, 121 | # device='cuda') 122 | # Need to make sure the memory allocated for next_* is not still in use 123 | # by the main stream at the time we start copying to next_*: 124 | # self.stream.wait_stream(torch.cuda.current_stream()) 125 | with torch.cuda.stream(self.stream): 126 | self.batch = move_to_cuda(self.batch) 127 | # more code for the alternative if record_stream() doesn't work: 128 | # copy_ will record the use of the pinned source tensor in this 129 | # side stream. 
130 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 131 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 132 | # self.next_input = self.next_input_gpu 133 | # self.next_target = self.next_target_gpu 134 | 135 | def next(self, it): 136 | 137 | torch.cuda.current_stream().wait_stream(self.stream) 138 | batch = self.batch 139 | if batch is not None: 140 | record_cuda_stream(batch) 141 | self.preload(it) 142 | return batch 143 | 144 | 145 | 146 | def __getattr__(self, name): 147 | method = self.loader.__getattribute__(name) 148 | return method 149 | -------------------------------------------------------------------------------- /data/download_hdvila.sh: -------------------------------------------------------------------------------- 1 | video2dataset \ 2 | --url_list='hd_vila.parquet' \ 3 | --input_format='parquet' \ 4 | --output_format='files' \ 5 | --output_folder="./hdvila" \ 6 | --url_col="url" \ 7 | --enable_wandb=False \ 8 | --encode_formats="{'video': 'mp4', 'audio': 'mp3'}" \ 9 | --config="config.yaml" \ 10 | --max_shard_retry=3 -------------------------------------------------------------------------------- /data/makeparquet.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import glob 3 | import json 4 | import os 5 | import time 6 | from datetime import datetime 7 | 8 | def time_string_to_seconds(timestamp): 9 | hh,mm,s = timestamp.split(':') 10 | ss,ms = s.split('.') 11 | time = 3600*int(hh) + 60*int(mm) + int(ss) + int(ms)/1000 12 | return time 13 | 14 | def convert_clip_list(clip_list): 15 | return [[time_string_to_seconds(x) for x in clip] for clip in clip_list] 16 | 17 | ###### Change your path 18 | parquet_dir = "/path/to/my/metadata/dir/" 19 | 20 | data = [] 21 | for jsonl in sorted(glob.glob(f"{parquet_dir}*.jsonl")): 22 | path = os.path.join(parquet_dir, jsonl) 23 | with open(path, "r") as f: 24 | for line in f: 25 | json_obj = json.loads(line) 26 | clips = [ 27 | json_obj['clip'][i]['span'] 28 | for i in range(len(json_obj['clip'])) 29 | ] 30 | 31 | out = { 32 | 'video_id': json_obj['video_id'], 33 | 'url': json_obj['url'], 34 | 'clips': clips 35 | } 36 | data.append(out) 37 | 38 | df = pd.DataFrame(data) 39 | df['clips'] = df['clips'].map(lambda x: convert_clip_list(x)) 40 | df.to_parquet("hd_vila.parquet") -------------------------------------------------------------------------------- /data/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .vast import VAST 2 | model_registry = { 3 | 'vast':VAST 4 | } -------------------------------------------------------------------------------- /data/model/audio_encoders/ast/ast.py: -------------------------------------------------------------------------------- 1 | """ 2 | BERT layers from the huggingface implementation 3 | (https://github.com/huggingface/transformers) 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import logging 21 | import math 22 | import copy 23 | import torch 24 | from torch import nn 25 | from torch.nn import LayerNorm as LayerNorm 26 | import torch.nn.functional as F 27 | import ipdb 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def gelu(x): 33 | """Implementation of the gelu activation function. 34 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 35 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 36 | Also see https://arxiv.org/abs/1606.08415 37 | """ 38 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 39 | 40 | 41 | def swish(x): 42 | return x * torch.sigmoid(x) 43 | 44 | 45 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 46 | 47 | 48 | class GELU(nn.Module): 49 | def forward(self, input_): 50 | output = gelu(input_) 51 | return output 52 | 53 | 54 | 55 | 56 | class TransformerLayer(nn.Module): 57 | def __init__(self, config, mode): 58 | super().__init__() 59 | self.attention = MultiHeadAttention(config) 60 | self.ff_layer = FeedForward(config) 61 | self.dropout = nn.Dropout(config.hidden_dropout) 62 | self.layernorm1 = LayerNorm(config.hidden_size, eps=1e-12) 63 | self.layernorm2 = LayerNorm(config.hidden_size, eps=1e-12) 64 | self.mode = mode 65 | 66 | def forward(self, hidden_states, attention_mask): 67 | if self.mode == 'prenorm': 68 | return self.forward_prenorm(hidden_states, attention_mask) 69 | elif self.mode == 'postnorm': 70 | return self.forward_postnorm(hidden_states, attention_mask) 71 | else: 72 | raise NotImplementedError 73 | 74 | def forward_prenorm(self, hidden_states, attention_mask): 75 | residual = hidden_states 76 | hidden_states = self.layernorm1(hidden_states) 77 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask) 78 | hidden_states = residual + self.dropout(attention_output) 79 | 80 | residual = hidden_states 81 | hidden_states = self.layernorm2(hidden_states) 82 | ff_output = self.ff_layer(hidden_states) 83 | hidden_states = residual + self.dropout(ff_output) 84 | 85 | return hidden_states 86 | 87 | def forward_postnorm(self, hidden_states, attention_mask): 88 | residual = hidden_states 89 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask) 90 | hidden_states = residual + self.dropout(attention_output) 91 | hidden_states = self.layernorm1(hidden_states) 92 | 93 | residual = hidden_states 94 | ff_output = self.ff_layer(hidden_states) 95 | hidden_states = residual + self.dropout(ff_output) 96 | hidden_states = self.layernorm2(hidden_states) 97 | 98 | return hidden_states 99 | 100 | 101 | def clones(x,times): 102 | return nn.ModuleList([copy.deepcopy(x) for i in range(times)]) 103 | 104 | 105 | 106 | class MultiHeadAttention(nn.Module): 107 | def __init__(self, config): 108 | super().__init__() 109 | self.linears = clones(nn.Linear(config.hidden_size, config.hidden_size), 4) 110 | self.head_num = config.num_attention_heads 111 | self.hidden_size = config.hidden_size 112 | 
self.dropout=nn.Dropout(config.attention_dropout) 113 | 114 | 115 | def forward(self,q,k,v,mask=None): 116 | batch_size=q.shape[0] 117 | q,k,v=[layer(x).view(batch_size,-1,self.head_num, self.hidden_size//self.head_num).transpose(1,2) \ 118 | for layer,x in zip(self.linears,(q,k,v))] 119 | norm_d=q.shape[-1] 120 | att_map=torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(norm_d) 121 | if mask is not None: 122 | att_map=att_map + mask 123 | att_map=F.softmax(att_map,dim=-1) 124 | # import ipdb 125 | # if att_map.shape[-1] == 45: 126 | # ipdb.set_trace() 127 | 128 | att_map=self.dropout(att_map) 129 | attn_output = self.linears[-1](torch.matmul(att_map,v).transpose(1,2).contiguous().view(batch_size,-1,self.hidden_size)) 130 | return attn_output 131 | 132 | 133 | class FeedForward(nn.Module): 134 | def __init__(self, config): 135 | super().__init__() 136 | self.linear1=nn.Linear(config.hidden_size, config.intermediate_size) 137 | self.linear2=nn.Linear(config.intermediate_size, config.hidden_size) 138 | self.activation = GELU() 139 | 140 | 141 | def forward(self,x): 142 | return self.linear2((self.activation(self.linear1(x)))) 143 | 144 | 145 | 146 | class TransformerEncoder(nn.Module): 147 | def __init__(self, config, mode = 'prenorm'): 148 | super().__init__() 149 | layer = TransformerLayer(config, mode) 150 | self.mode = mode 151 | self.layer = nn.ModuleList([copy.deepcopy(layer) 152 | for _ in range(config.num_hidden_layers)]) 153 | if self.mode == 'prenorm': 154 | self.last_layernorm = LayerNorm(config.hidden_size, eps=1e-12) 155 | self.checkpointing = config.checkpointing 156 | def forward(self, input_, attention_mask=None, cross_hidden_states=None, 157 | use_cache=False, 158 | cache=None, 159 | cache_first=False, 160 | cache_type=None): 161 | hidden_states = input_ 162 | for layer_module in self.layer: 163 | if self.checkpointing: 164 | hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask) 165 | else: 166 | hidden_states = layer_module(hidden_states, attention_mask) 167 | 168 | if self.mode == 'prenorm': 169 | hidden_states = self.last_layernorm(hidden_states) 170 | return hidden_states, cache 171 | 172 | 173 | 174 | 175 | class AudioEmbeddings(nn.Module): 176 | def __init__(self, model_cfg_audio): 177 | super().__init__() 178 | 179 | self.patch_size = 16 180 | 181 | 182 | self.token_length_per_frame = (model_cfg_audio.audio_melbins // self.patch_size) * (model_cfg_audio.audio_target_length // self.patch_size) 183 | self.first_conv = nn.Conv2d(1, model_cfg_audio.hidden_size, kernel_size = self.patch_size, 184 | stride = self.patch_size, padding=0) 185 | self.position_embeddings = nn.Embedding(self.token_length_per_frame + 1, model_cfg_audio.hidden_size) 186 | self.dropout = nn.Dropout(model_cfg_audio.hidden_dropout) 187 | self.cls_token = nn.Parameter(0.02 * torch.randn(1, 1, model_cfg_audio.hidden_size)) 188 | 189 | def forward(self, audio_spectrograms): ### shape Bxn_sample_128x512 190 | 191 | audio_spectrograms = self.first_conv(audio_spectrograms.unsqueeze(1)) 192 | b,c,_,_=audio_spectrograms.shape 193 | audio_tokens = audio_spectrograms.permute(0,2,3,1).reshape(b,-1,c) 194 | cls_token = self.cls_token.expand(b,-1,-1) 195 | audio_tokens = torch.cat((cls_token,audio_tokens),dim=1) 196 | audio_pos_ids = list(range(self.token_length_per_frame + 1)) 197 | audio_pos_ids = torch.tensor(audio_pos_ids, dtype=torch.long, device=audio_spectrograms.device).unsqueeze(0) 198 | position_embeddings = self.position_embeddings(audio_pos_ids) 199 | embeddings = 
audio_tokens + position_embeddings 200 | embeddings = self.dropout(embeddings) 201 | return embeddings 202 | -------------------------------------------------------------------------------- /data/model/vision_encoders/clip/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | import ftfy 8 | import regex as re 9 | import torch 10 | 11 | 12 | @lru_cache() 13 | def default_bpe(): 14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings). 
43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class SimpleTokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | 137 | 138 | _tokenizer = SimpleTokenizer() 139 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 140 | """ 141 | Returns the tokenized representation of given input string(s) 142 | 143 | Parameters 144 | ---------- 145 | texts : Union[str, List[str]] 146 | An input string or a list of input strings to tokenize 147 | 148 | context_length : int 149 | The context length to use; all CLIP models use 77 as the context length 150 | 151 | truncate: bool 152 | 
Whether to truncate the text in case its encoding is longer than the context length 153 | 154 | Returns 155 | ------- 156 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 157 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 158 | """ 159 | if isinstance(texts, str): 160 | texts = [texts] 161 | 162 | sot_token = _tokenizer.encoder["<|startoftext|>"] 163 | eot_token = _tokenizer.encoder["<|endoftext|>"] 164 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 165 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 166 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 167 | else: 168 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 169 | 170 | for i, tokens in enumerate(all_tokens): 171 | if len(tokens) > context_length: 172 | if truncate: 173 | tokens = tokens[:context_length] 174 | tokens[-1] = eot_token 175 | else: 176 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 177 | result[i, :len(tokens)] = torch.tensor(tokens) 178 | 179 | return result 180 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | # from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/data/model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": 
"num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/loss.py: -------------------------------------------------------------------------------- 1 | # import math 2 | # import torch 3 | # import torch.nn as nn 4 | # from torch.nn import functional as F 5 | 6 | # try: 7 | # import torch.distributed.nn 8 | # from torch import distributed as dist 9 | # has_distributed = True 10 | # except ImportError: 11 | # has_distributed = False 12 | 13 | # try: 14 | # import horovod.torch as hvd 15 | # except ImportError: 16 | # hvd = None 17 | 18 | # from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | # def gather_features( 22 | # image_features, 23 | # text_features, 24 | # local_loss=False, 25 | # gather_with_grad=False, 26 | # rank=0, 27 | # world_size=1, 28 | # use_horovod=False 29 | # ): 30 | # assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
31 | # if use_horovod: 32 | # assert hvd is not None, 'Please install horovod' 33 | # if gather_with_grad: 34 | # all_image_features = hvd.allgather(image_features) 35 | # all_text_features = hvd.allgather(text_features) 36 | # else: 37 | # with torch.no_grad(): 38 | # all_image_features = hvd.allgather(image_features) 39 | # all_text_features = hvd.allgather(text_features) 40 | # if not local_loss: 41 | # # ensure grads for local rank when all_* features don't have a gradient 42 | # gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | # gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | # gathered_image_features[rank] = image_features 45 | # gathered_text_features[rank] = text_features 46 | # all_image_features = torch.cat(gathered_image_features, dim=0) 47 | # all_text_features = torch.cat(gathered_text_features, dim=0) 48 | # else: 49 | # # We gather tensors from all gpus 50 | # if gather_with_grad: 51 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | # # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 54 | # # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 55 | # else: 56 | # gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 57 | # gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 58 | # dist.all_gather(gathered_image_features, image_features) 59 | # dist.all_gather(gathered_text_features, text_features) 60 | # if not local_loss: 61 | # # ensure grads for local rank when all_* features don't have a gradient 62 | # gathered_image_features[rank] = image_features 63 | # gathered_text_features[rank] = text_features 64 | # all_image_features = torch.cat(gathered_image_features, dim=0) 65 | # all_text_features = torch.cat(gathered_text_features, dim=0) 66 | 67 | # return all_image_features, all_text_features 68 | 69 | 70 | # class ClipLoss(nn.Module): 71 | 72 | # def __init__( 73 | # self, 74 | # local_loss=False, 75 | # gather_with_grad=False, 76 | # cache_labels=False, 77 | # rank=0, 78 | # world_size=1, 79 | # use_horovod=False, 80 | # smoothing=0., 81 | # ): 82 | # super().__init__() 83 | # self.local_loss = local_loss 84 | # self.gather_with_grad = gather_with_grad 85 | # self.cache_labels = cache_labels 86 | # self.rank = rank 87 | # self.world_size = world_size 88 | # self.use_horovod = use_horovod 89 | # self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 90 | 91 | # # cache state 92 | # self.prev_num_logits = 0 93 | # self.labels = {} 94 | 95 | # def forward(self, image_features, text_features, logit_scale=1.): 96 | # device = image_features.device 97 | # if self.world_size > 1: 98 | # all_image_features, all_text_features = gather_features( 99 | # image_features, text_features, 100 | # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 101 | 102 | # if self.local_loss: 103 | # logits_per_image = logit_scale * image_features @ all_text_features.T 104 | # logits_per_text = logit_scale * text_features @ all_image_features.T 105 | # else: 106 | # logits_per_image = logit_scale * all_image_features @ all_text_features.T 107 | # logits_per_text = logits_per_image.T 108 | # else: 109 | # logits_per_image = logit_scale * 
image_features @ text_features.T 110 | # logits_per_text = logit_scale * text_features @ image_features.T 111 | # # calculated ground-truth and cache if enabled 112 | # num_logits = logits_per_image.shape[0] 113 | # if self.prev_num_logits != num_logits or device not in self.labels: 114 | # labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | # if self.world_size > 1 and self.local_loss: 116 | # labels = labels + num_logits * self.rank 117 | # if self.cache_labels: 118 | # self.labels[device] = labels 119 | # self.prev_num_logits = num_logits 120 | # else: 121 | # labels = self.labels[device] 122 | 123 | # if self.label_smoothing_cross_entropy: 124 | # total_loss = ( 125 | # self.label_smoothing_cross_entropy(logits_per_image, labels) + 126 | # self.label_smoothing_cross_entropy(logits_per_text, labels) 127 | # ) / 2 128 | # else: 129 | # total_loss = ( 130 | # F.cross_entropy(logits_per_image, labels) + 131 | # F.cross_entropy(logits_per_text, labels) 132 | # ) / 2 133 | 134 | # acc = None 135 | # i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 136 | # t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 137 | # acc = {"i2t": i2t_acc, "t2i": t2i_acc} 138 | # return total_loss, acc -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 
2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | 
"fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | 
out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = 
self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | def broadcat(tensors, dim = -1): 8 | num_tensors = len(tensors) 9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 11 | shape_len = list(shape_lens)[0] 12 | dim = (dim + shape_len) if dim < 0 else dim 13 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 14 | expandable_dims = [(i, 
val) for i, val in enumerate(dims) if i != dim] 15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 18 | expanded_dims.insert(dim, (dim, dims[dim])) 19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 21 | return torch.cat(tensors, dim = dim) 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 25 | x1, x2 = x.unbind(dim = -1) 26 | x = torch.stack((-x2, x1), dim = -1) 27 | return rearrange(x, '... d r -> ... (d r)') 28 | 29 | 30 | class VisionRotaryEmbedding(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | pt_seq_len, 35 | ft_seq_len=None, 36 | custom_freqs = None, 37 | freqs_for = 'lang', 38 | theta = 10000, 39 | max_freq = 10, 40 | num_freqs = 1, 41 | ): 42 | super().__init__() 43 | if custom_freqs: 44 | freqs = custom_freqs 45 | elif freqs_for == 'lang': 46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 47 | elif freqs_for == 'pixel': 48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 49 | elif freqs_for == 'constant': 50 | freqs = torch.ones(num_freqs).float() 51 | else: 52 | raise ValueError(f'unknown modality {freqs_for}') 53 | 54 | if ft_seq_len is None: ft_seq_len = pt_seq_len 55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 56 | 57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 59 | 60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 62 | 63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 64 | 65 | self.register_buffer("freqs_cos", freqs.cos()) 66 | self.register_buffer("freqs_sin", freqs.sin()) 67 | 68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 69 | 70 | def forward(self, t, start_index = 0): 71 | rot_dim = self.freqs_cos.shape[-1] 72 | end_index = start_index + rot_dim 73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 76 | 77 | return torch.cat((t_left, t, t_right), dim = -1) 78 | 79 | class VisionRotaryEmbeddingFast(nn.Module): 80 | def __init__( 81 | self, 82 | dim, 83 | pt_seq_len, 84 | ft_seq_len=None, 85 | custom_freqs = None, 86 | freqs_for = 'lang', 87 | theta = 10000, 88 | max_freq = 10, 89 | num_freqs = 1, 90 | patch_dropout = 0. 91 | ): 92 | super().__init__() 93 | if custom_freqs: 94 | freqs = custom_freqs 95 | elif freqs_for == 'lang': 96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | elif freqs_for == 'pixel': 98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 99 | elif freqs_for == 'constant': 100 | freqs = torch.ones(num_freqs).float() 101 | else: 102 | raise ValueError(f'unknown modality {freqs_for}') 103 | 104 | if ft_seq_len is None: ft_seq_len = pt_seq_len 105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 106 | 107 | freqs = torch.einsum('..., f -> ... f', t, freqs) 108 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 110 | 111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 113 | 114 | self.patch_dropout = patch_dropout 115 | 116 | self.register_buffer("freqs_cos", freqs_cos) 117 | self.register_buffer("freqs_sin", freqs_sin) 118 | 119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 120 | 121 | def forward(self, t, patch_indices_keep=None): 122 | if patch_indices_keep is not None: 123 | batch = t.size()[0] 124 | batch_indices = torch.arange(batch) 125 | batch_indices = batch_indices[..., None] 126 | 127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 129 | 130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') 132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') 134 | 135 | return t * freqs_cos + rotate_half(t) * freqs_sin 136 | 137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 
| prev_chs = embed_dim 65 | elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | # import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text)  # NOTE: relies on ftfy, whose import at the top of this file is commented out 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text =
whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | 156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 157 | """ 158 | Returns the tokenized representation of given input string(s) 159 | 160 | Parameters 161 | ---------- 162 | texts : Union[str, List[str]] 163 | An input string or a list of input strings to tokenize 164 | context_length : int 165 | The context length to use; all CLIP models use 77 as the context length 166 | 167 | Returns 168 | ------- 169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 170 | """ 171 | if isinstance(texts, str): 172 | texts = [texts] 173 | 174 | sot_token = _tokenizer.encoder["<start_of_text>"] 175 | eot_token = _tokenizer.encoder["<end_of_text>"] 176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 178 | 179 | for i, tokens in enumerate(all_tokens): 180 | if len(tokens) > context_length: 181 | tokens = tokens[:context_length] # Truncate 182 | tokens[-1] = eot_token 183 | result[i, :len(tokens)] = torch.tensor(tokens) 184 | 185 | return result 186 | 187 | 188 | class HFTokenizer: 189 | "HuggingFace tokenizer wrapper" 190 | def __init__(self, tokenizer_name:str): 191 | from transformers import AutoTokenizer 192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 193 | 194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: 195 | # same cleaning as for default tokenizer, except lowercasing 196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 197 | if isinstance(texts, str): 198 | texts = [texts] 199 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids 201 | return input_ids 202 | -------------------------------------------------------------------------------- /data/model/vision_encoders/evaclip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | 13 | class ResizeMaxSize(nn.Module): 14 | 15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 16 | super().__init__() 17 | if not isinstance(max_size, int): 18 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /data/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import torch.distributed as dist 5 | from utils.args import get_args,logging_cfgs 6 | from utils.initialize import initialize 7 | from utils.build_model import build_model 8 | from utils.build_optimizer import build_optimizer 9 | from utils.build_dataloader import create_train_dataloaders, create_val_dataloaders 10 | from utils.pipeline import train, test 11 | 12 | 13 | def main(): 14 | 15 | ### init 16 | 17 | args = get_args() 18 | initialize(args) 19 | 20 | ### logging cfgs 21 | logging_cfgs(args) 22 | 23 | 24 | if args.run_cfg.mode == 'training': 25 | 26 | ### create datasets and dataloader 27 | train_loader = create_train_dataloaders(args) 28 | val_loaders = create_val_dataloaders(args) 29 | 30 | 
### build model and optimizer 31 | 32 | model, optimizer_ckpt, start_step = build_model(args) 33 | 34 | optimizer = build_optimizer(model, args, optimizer_ckpt) 35 | 36 | 37 | ### start evaluation 38 | if args.run_cfg.first_eval or args.run_cfg.zero_shot: 39 | test(model, val_loaders, args.run_cfg) 40 | if args.run_cfg.zero_shot: 41 | return 42 | 43 | ### start training 44 | 45 | 46 | train(model, optimizer, train_loader, val_loaders, args.run_cfg, start_step = start_step, verbose_time=False) 47 | 48 | elif args.run_cfg.mode == 'testing': 49 | ### build model 50 | model,_,_ = build_model(args) 51 | 52 | ### create datasets and dataloader 53 | val_loaders = create_val_dataloaders(args) 54 | 55 | ### start evaluation 56 | test(model, val_loaders, args.run_cfg) 57 | 58 | else: 59 | raise NotImplementedError 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /data/scripts/run_audio_captioner.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes 1 \ 2 | --node_rank 0 \ 3 | --nproc_per_node 8 \ 4 | --master_port 9814 \ 5 | run.py \ 6 | --config ./caption_config/caption-generation-audio.json \ 7 | --pretrain_dir './audio_captioner' \ 8 | --output_dir './output/audio_caption' \ 9 | --test_batch_size 128 \ 10 | --generate_nums 3 \ 11 | --captioner_mode true \ -------------------------------------------------------------------------------- /data/scripts/run_vision_captioner.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes 1 \ 2 | --node_rank 0 \ 3 | --nproc_per_node 8 \ 4 | --master_port 9814 \ 5 | run.py \ 6 | --config ./caption_config/caption-generation-vision.json \ 7 | --pretrain_dir './vision_captioner' \ 8 | --output_dir './output/vision_caption' \ 9 | --test_batch_size 64 \ 10 | --test_vision_sample_num 8 \ 11 | --generate_nums 3 \ 12 | --captioner_mode true \ 13 | -------------------------------------------------------------------------------- /data/setup_env.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.0.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 2 | pip install torchvision==0.15.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 3 | pip install torchaudio==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 4 | pip install decord -i https://pypi.tuna.tsinghua.edu.cn/simple 5 | pip install h5py -i https://pypi.tuna.tsinghua.edu.cn/simple 6 | pip install ffmpeg-python -i https://pypi.tuna.tsinghua.edu.cn/simple 7 | pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple 9 | pip install ipdb -i https://pypi.tuna.tsinghua.edu.cn/simple 10 | pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple 11 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | pip install transformers==4.31.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 13 | pip install webdataset -i https://pypi.tuna.tsinghua.edu.cn/simple 14 | pip install SentencePiece -i https://pypi.tuna.tsinghua.edu.cn/simple 15 | -------------------------------------------------------------------------------- /data/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/data/utils/__init__.py 
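For readers skimming the dump, here is a minimal usage sketch of the `image_transform` helper from `evaclip/transform.py` shown a little earlier. It is an illustration only, not repository code: the dotted import path assumes the repository root is on `PYTHONPATH`, and the in-memory image is a stand-in for a real photo. With `is_train=False` the helper composes Resize → CenterCrop → RGB conversion → ToTensor → Normalize with the OpenAI CLIP statistics.

```python
# Illustrative sketch only -- the import path and the synthetic image are assumptions.
from PIL import Image

from data.model.vision_encoders.evaclip.transform import image_transform

# Deterministic eval-side pipeline: Resize(224) -> CenterCrop(224) -> RGB -> ToTensor -> Normalize
preprocess = image_transform(image_size=224, is_train=False)

img = Image.new("RGB", (640, 480))      # stand-in for a real PIL image
pixel_values = preprocess(img)          # float tensor of shape [3, 224, 224]
batch = pixel_values.unsqueeze(0)       # [1, 3, 224, 224], ready for a vision encoder
print(batch.shape)
```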
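And a small worked example for `ResizeMaxSize` from the same file, which is only used when `resize_longest_max=True` is passed to `image_transform`: the longer side is scaled down to `max_size` and the shorter side is padded symmetrically with `fill`, so the aspect ratio is preserved instead of center-cropping. Again a sketch under the same import-path assumption, with made-up dimensions.

```python
# Illustrative sketch only -- same import-path assumption as above.
from PIL import Image

from data.model.vision_encoders.evaclip.transform import ResizeMaxSize

resize = ResizeMaxSize(max_size=224, fill=0)
img = Image.new("RGB", (640, 480))      # width x height
out = resize(img)
# scale = 224 / max(480, 640) = 0.35, so the image is resized to 224 x 168 (w x h)
# and then padded by 28 pixels on top and bottom to reach 224 x 224.
print(out.size)                         # (224, 224)
```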
-------------------------------------------------------------------------------- /data/utils/build_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.utils.data.distributed import DistributedSampler 3 | from torch.utils.data import DataLoader 4 | from data.loader import MetaLoader, PrefetchLoader 5 | from data import data_registry 6 | from utils.distributed import DistributedSampler_wopadding 7 | from .logger import LOGGER 8 | 9 | 10 | 11 | def create_train_dataloaders(args): 12 | data_cfg = args.data_cfg.train 13 | dataloaders = [] 14 | dataloaders_dict={} 15 | train_steps = [] 16 | loader_names = [] 17 | 18 | if len(data_cfg) == 0: 19 | return 20 | 21 | for d_cfg in data_cfg: 22 | 23 | dataset_ls = [] 24 | 25 | name = d_cfg['name'] 26 | dataset = data_registry[d_cfg.type](d_cfg, args) 27 | 28 | # print(dataset[0]) 29 | 30 | collate_fn = dataset.collate_fn 31 | worker_init_fn = dataset.worker_init_fn 32 | use_sampler = dataset.use_sampler 33 | 34 | 35 | LOGGER.info("Create Dataset {} Success".format(name)) 36 | task = d_cfg['task'] 37 | batch_size = d_cfg['batch_size'] 38 | n_workers = d_cfg['n_workers'] 39 | 40 | if 'steps' in d_cfg: 41 | train_steps.append(d_cfg['steps']) 42 | elif 'epoch' in d_cfg: 43 | epoch = d_cfg['epoch'] 44 | train_steps.append(int((len(dataset) // batch_size) * epoch)) 45 | 46 | loader = build_dataloader(dataset, collate_fn, True, batch_size // args.run_cfg.gradient_accumulation_steps , n_workers, worker_init_fn, use_sampler) 47 | 48 | dataloaders.append(loader) 49 | loader_names.append(f'{task}--{name}') 50 | 51 | 52 | for i in range(len(dataloaders)): 53 | ratio = train_steps[i] 54 | dataloaders_dict[loader_names[i]] = (dataloaders[i], ratio) 55 | 56 | n_gpu = dist.get_world_size() 57 | for name, (loader, ratio) in dataloaders_dict.items(): 58 | # epoch = (ratio * loader.batch_size * n_gpu ) // len(loader.dataset) 59 | LOGGER.info(f" loader {name} , ratio {ratio} , bs_pergpu {loader.batch_size}, n_workers {loader.num_workers}" ) 60 | 61 | 62 | meta_loader = MetaLoader(dataloaders_dict, 63 | accum_steps=args.run_cfg.gradient_accumulation_steps, 64 | distributed=n_gpu > 1) 65 | 66 | if args.run_cfg.num_train_steps == 0: 67 | total_train_steps = sum(train_steps) 68 | args.run_cfg.num_train_steps = total_train_steps 69 | 70 | 71 | 72 | meta_loader = PrefetchLoader(meta_loader) 73 | meta_loader.ndata = len(dataloaders_dict) 74 | args.run_cfg.valid_steps = args.run_cfg.num_train_steps // args.run_cfg.valid_freq -1 75 | 76 | 77 | 78 | return meta_loader 79 | 80 | 81 | def create_val_dataloaders(args): 82 | data_cfg = args.data_cfg.val 83 | dataloaders = {} 84 | for d_cfg in data_cfg: 85 | name = d_cfg['name'] 86 | dataset = data_registry[d_cfg.type](d_cfg, args) 87 | collate_fn = dataset.collate_fn 88 | worker_init_fn = dataset.worker_init_fn 89 | use_sampler = dataset.use_sampler 90 | # task = d_cfg['task'].split('_') 91 | # if 'qa' in task: 92 | # dataset.make_submission = d_cfg.get('make_submission', False) 93 | 94 | # if 'cap' in task: 95 | # dataset.annfile = d_cfg['annfile'] 96 | 97 | # dataset.data_type = data_type 98 | dataset.name = name 99 | LOGGER.info("Create Dataset {} Success".format(name)) 100 | task = d_cfg['task'] 101 | batch_size = d_cfg['batch_size'] 102 | n_workers = d_cfg['n_workers'] 103 | loader = build_dataloader(dataset, collate_fn, False, batch_size, n_workers, worker_init_fn, use_sampler) 104 | task_name = f'{task}--{name}' 105 | 
dataloaders[task_name] = PrefetchLoader(loader) 106 | return dataloaders 107 | 108 | 109 | def build_dataloader(dataset, collate_fn, is_train, batch_size, n_workers=None, worker_init_fn=None, use_sampler=True): 110 | batch_size = batch_size // dist.get_world_size() 111 | if use_sampler: 112 | if is_train: 113 | sampler = DistributedSampler(dataset) 114 | else: 115 | sampler = DistributedSampler_wopadding(dataset) 116 | loader = DataLoader(dataset, sampler = sampler, batch_size = batch_size, 117 | num_workers=n_workers, pin_memory=True, 118 | collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn) 119 | else: 120 | 121 | loader = DataLoader(dataset, batch_size = batch_size, 122 | num_workers=n_workers, pin_memory=True, 123 | collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn) 124 | 125 | return loader 126 | 127 | -------------------------------------------------------------------------------- /data/utils/build_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from model import model_registry 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from .logger import LOGGER 8 | from .build_optimizer import build_optimizer 9 | 10 | 11 | class DDP_modify(DDP): 12 | def __getattr__(self, name): 13 | try: 14 | return super().__getattr__(name) 15 | except: 16 | return getattr(self.module,name) 17 | 18 | 19 | def build_model(args): 20 | 21 | model = model_registry[args.model_cfg.model_type](args.model_cfg) 22 | checkpoint = {} 23 | 24 | ### load ckpt from a pretrained_dir 25 | if args.run_cfg.pretrain_dir: 26 | checkpoint = load_from_pretrained_dir(args) 27 | LOGGER.info("Load from pretrained dir {}".format(args.run_cfg.pretrain_dir)) 28 | 29 | ### load ckpt from specific path 30 | if args.run_cfg.checkpoint: 31 | checkpoint = torch.load(args.run_cfg.checkpoint, map_location = 'cpu') 32 | 33 | ### resume training 34 | if args.run_cfg.resume: 35 | checkpoint, checkpoint_optim, start_step = load_from_resume(args.run_cfg) 36 | else: 37 | checkpoint_optim, start_step = None , 0 38 | 39 | 40 | checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 41 | 42 | if checkpoint != {}: 43 | 44 | checkpoint = model.modify_checkpoint(checkpoint) 45 | if "model" in checkpoint.keys(): 46 | checkpoint = checkpoint["model"] 47 | 48 | missing_keys,unexpected_keys = model.load_state_dict(checkpoint,strict=False) 49 | LOGGER.info(f"Unexpected keys {unexpected_keys}") 50 | LOGGER.info(f"missing_keys {missing_keys}") 51 | 52 | 53 | local_rank = args.local_rank 54 | device = torch.device("cuda", local_rank) 55 | model.to(device) 56 | if args.run_cfg.use_ddp: 57 | model = DDP_modify(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 58 | else: 59 | pass 60 | 61 | return model, checkpoint_optim, start_step 62 | 63 | 64 | 65 | def load_from_pretrained_dir(args): 66 | 67 | 68 | try: ### huggingface trainer 69 | checkpoint_dir = args.run_cfg.pretrain_dir 70 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('checkpoint')] 71 | checkpoint_ls = [int(i.split('-')[1]) for i in checkpoint_ls] 72 | checkpoint_ls.sort() 73 | step = checkpoint_ls[-1] 74 | 75 | try: 76 | checkpoint_name = f'checkpoint-{step}/pytorch_model.bin' 77 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name) 78 | checkpoint = torch.load(ckpt_file, map_location = 'cpu') 79 | except: 80 | checkpoint_name1 = 
f'checkpoint-{step}/pytorch_model-00001-of-00002.bin' 81 | ckpt_file1 = torch.load(os.path.join(checkpoint_dir, checkpoint_name1), map_location = 'cpu') 82 | checkpoint_name2 = f'checkpoint-{step}/pytorch_model-00002-of-00002.bin' 83 | ckpt_file2 = torch.load(os.path.join(checkpoint_dir, checkpoint_name2), map_location = 'cpu') 84 | ckpt_file1.update(ckpt_file2) 85 | checkpoint = ckpt_file1 86 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 87 | LOGGER.info(f'load_from_pretrained: {ckpt_file}') 88 | 89 | except: 90 | checkpoint_dir = os.path.join(args.run_cfg.pretrain_dir,'ckpt') 91 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('model_step')] 92 | checkpoint_ls = [int(i.split('_')[2].split('.')[0]) for i in checkpoint_ls] 93 | checkpoint_ls.sort() 94 | step = checkpoint_ls[-1] 95 | 96 | checkpoint_name = 'model_step_'+str(step)+'.pt' 97 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name) 98 | checkpoint = torch.load(ckpt_file, map_location = 'cpu') 99 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()} 100 | LOGGER.info(f'load_from_pretrained: {ckpt_file}') 101 | 102 | 103 | return checkpoint 104 | 105 | 106 | def load_from_resume(run_cfg): 107 | ckpt_dir = os.path.join(run_cfg.output_dir,'ckpt') 108 | previous_optimizer_state = [i for i in os.listdir(ckpt_dir) if i.startswith('optimizer')] 109 | steps = [i.split('.pt')[0].split('_')[-1] for i in previous_optimizer_state] 110 | steps = [ int(i) for i in steps] 111 | steps.sort() 112 | previous_step = steps[-1] 113 | previous_optimizer_state = f'optimizer_step_{previous_step}.pt' 114 | previous_model_state = f'model_step_{previous_step}.pt' 115 | previous_step = int(previous_model_state.split('.')[0].split('_')[-1]) 116 | previous_optimizer_state = os.path.join(ckpt_dir, previous_optimizer_state) 117 | previous_model_state = os.path.join(ckpt_dir, previous_model_state) 118 | 119 | assert os.path.exists(previous_optimizer_state) and os.path.exists(previous_model_state) 120 | LOGGER.info("choose previous model: {}".format(previous_model_state)) 121 | LOGGER.info("choose previous optimizer: {}".format(previous_optimizer_state)) 122 | previous_model_state = torch.load(previous_model_state,map_location='cpu') 123 | previous_optimizer_state = torch.load(previous_optimizer_state,map_location='cpu') 124 | return previous_model_state, previous_optimizer_state, previous_step 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /data/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import torch 4 | import torch.distributed as dist 5 | from torch.autograd import Function 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | 9 | 10 | 11 | 12 | class GatherLayer(torch.autograd.Function): 13 | """ 14 | Gather tensors from all workers with support for backward propagation: 15 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, x): 20 | output = [ 21 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 22 | ] 23 | torch.distributed.all_gather(output, x) 24 | return tuple(output) 25 | 26 | @staticmethod 27 | def backward(ctx, *grads): 28 | all_gradients = torch.stack(grads) 29 | torch.distributed.all_reduce(all_gradients) 30 | return all_gradients[torch.distributed.get_rank()] 31 | 32 | 33 | def all_gather_with_grad(tensors): 34 | """ 35 | Performs all_gather operation on the provided tensors. 36 | Graph remains connected for backward grad computation. 37 | """ 38 | # Queue the gathered tensors 39 | world_size = torch.distributed.get_world_size() 40 | # There is no need for reduction in the single-proc case 41 | if world_size == 1: 42 | return tensors 43 | 44 | # tensor_all = GatherLayer.apply(tensors) 45 | tensor_all = GatherLayer.apply(tensors) 46 | 47 | return torch.cat(tensor_all, dim=0) 48 | 49 | 50 | @torch.no_grad() 51 | def concat_all_gather(tensor): 52 | """ 53 | Performs all_gather operation on the provided tensors. 54 | *** Warning ***: torch.distributed.all_gather has no gradient. 55 | """ 56 | # if use distributed training 57 | # if not is_dist_avail_and_initialized(): 58 | # return tensor 59 | 60 | tensors_gather = [ 61 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 62 | ] 63 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 64 | 65 | output = torch.cat(tensors_gather, dim=0) 66 | return output 67 | 68 | 69 | 70 | def _encode(enc, max_size, use_max_size=False): 71 | enc_size = len(enc) 72 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 73 | if use_max_size: 74 | # this is used for broadcasting 75 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) 76 | else: 77 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 78 | remainder = enc_size 79 | for i in range(enc_byte): 80 | base = 256 ** (enc_byte-i-1) 81 | buffer_[i] = remainder // base 82 | remainder %= base 83 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 84 | return buffer_, enc_byte 85 | 86 | 87 | def _decode(buffer_, enc_byte): 88 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 89 | for i in range(enc_byte)) 90 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 91 | shift = size + enc_byte 92 | return bytes_list, shift 93 | 94 | 95 | 96 | 97 | 98 | def all_gather_list(data): 99 | """Gathers arbitrary data from all nodes into a list.""" 100 | enc = pickle.dumps(data) 101 | 102 | enc_size = len(enc) 103 | max_size = ddp_allgather(torch.tensor([enc_size]).cuda()).max().item() 104 | in_buffer, enc_byte = _encode(enc, max_size) 105 | 106 | out_buffer = ddp_allgather(in_buffer[:enc_byte+enc_size]) 107 | 108 | results = [] 109 | for _ in range(dist.get_world_size()): 110 | bytes_list, shift = _decode(out_buffer, enc_byte) 111 | out_buffer = out_buffer[shift:] 112 | result = pickle.loads(bytes_list) 113 | results.append(result) 114 | return results 115 | 116 | 117 | def any_broadcast(data, root_rank): 118 | """broadcast arbitrary data from root_rank to all nodes.""" 119 | enc = pickle.dumps(data) 120 | 121 | max_size = ddp_allgather(torch.tensor([len(enc)]).cuda()).max().item() 122 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) 123 | 124 | dist.broadcast(buffer_, root_rank) 125 | 126 | bytes_list, _ = _decode(buffer_, enc_byte) 127 | result = pickle.loads(bytes_list) 128 | return result 129 | 130 | 131 | 132 | ###### with different batch_size ~ 133 | def 
ddp_allgather(input): 134 | tmp_input = input.cuda() 135 | size = torch.tensor(tmp_input.shape[0]).cuda() 136 | size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())] 137 | dist.all_gather(size_list, size) 138 | max_size = max(size_list).item() 139 | padding_size = max_size - size 140 | if padding_size > 0 : 141 | padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input) 142 | tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0) 143 | tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())] 144 | dist.all_gather(tmp_list, tmp_input) 145 | output = [] 146 | for t, s in zip(tmp_list, size_list): 147 | output.append(t[:s]) 148 | output = torch.cat(output,dim=0) 149 | return output 150 | 151 | 152 | 153 | class DistributedSampler_wopadding(DistributedSampler): 154 | 155 | def __iter__(self): 156 | if self.shuffle: 157 | # deterministically shuffle based on epoch and seed 158 | g = torch.Generator() 159 | g.manual_seed(self.seed + self.epoch) 160 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 161 | else: 162 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 163 | 164 | # if not self.drop_last: 165 | # # add extra samples to make it evenly divisible 166 | # padding_size = self.total_size - len(indices) 167 | # if padding_size <= len(indices): 168 | # indices += indices[:padding_size] 169 | # else: 170 | # indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 171 | # else: 172 | # remove tail of data to make it evenly divisible. 173 | if self.drop_last: 174 | indices = indices[:self.total_size] 175 | #assert len(indices) == self.total_size 176 | 177 | # subsample 178 | indices = indices[self.rank:len(indices):self.num_replicas] 179 | # assert len(indices) == self.num_samples 180 | 181 | return iter(indices) 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /data/utils/initialize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import torch.distributed as dist 6 | from .logger import LOGGER,add_log_to_file 7 | 8 | def initialize(opts): 9 | 10 | # if not os.path.exists(opts.run_cfg.output_dir): 11 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'log'), exist_ok=True) 12 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'ckpt'), exist_ok=True) 13 | 14 | local_rank = opts.local_rank 15 | torch.cuda.set_device(local_rank) 16 | dist.init_process_group(backend='nccl') 17 | if opts.run_cfg.gradient_accumulation_steps < 1: 18 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " 19 | "should be >= 1".format( 20 | opts.run_cfg.gradient_accumulation_steps)) 21 | set_random_seed(opts.run_cfg.seed) 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cudnn.enabled = True 24 | if dist.get_rank() == 0: 25 | # TB_LOGGER.create(os.path.join(opts.output_dir, 'log')) 26 | add_log_to_file(os.path.join(opts.run_cfg.output_dir, 'log', 'log.txt')) 27 | else: 28 | LOGGER.disabled = True 29 | 30 | 31 | def set_random_seed(seed): 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | 37 | -------------------------------------------------------------------------------- /data/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import 
math 3 | 4 | 5 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 6 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 7 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 8 | LOGGER = logging.getLogger('__main__') # this is the global logger 9 | 10 | 11 | def add_log_to_file(log_path): 12 | fh = logging.FileHandler(log_path) 13 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 14 | fh.setFormatter(formatter) 15 | LOGGER.addHandler(fh) 16 | 17 | 18 | class RunningMeter(object): 19 | """ running meteor of a scalar value 20 | (useful for monitoring training loss) 21 | """ 22 | def __init__(self, name=None, val=None, smooth=0.99): 23 | self._name = name 24 | self._sm = smooth 25 | self._val = val 26 | 27 | def __call__(self, value): 28 | val = (value if self._val is None 29 | else value*(1-self._sm) + self._val*self._sm) 30 | if not math.isnan(val): 31 | self._val = val 32 | 33 | def __str__(self): 34 | return f'{self._name}: {self._val:.4f}' 35 | 36 | @property 37 | def val(self): 38 | if self._val is None: 39 | return 0 40 | return self._val 41 | 42 | @property 43 | def name(self): 44 | return self._name 45 | 46 | -------------------------------------------------------------------------------- /data/utils/offline_process_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import ffmpeg 4 | import subprocess 5 | import multiprocessing 6 | import numpy as np 7 | 8 | from multiprocessing import Pool 9 | 10 | 11 | input_path = '/public/chensihan/datasets/tgif/gifs_used' 12 | output_path = '/public/chensihan/datasets/tgif/' 13 | 14 | data_list = os.listdir(input_path) 15 | 16 | def execCmd(cmd): 17 | r = os.popen(cmd) 18 | text = r.read() 19 | r.close() 20 | return text 21 | 22 | def pipline(video_path, video_probe, output_dir, fps, sr, duration_target): 23 | video_name = os.path.basename(video_path) 24 | 25 | video_name = video_name.replace(".mp4", "") 26 | 27 | 28 | # extract video frames fps 29 | fps_frame_dir = os.path.join(output_dir, f"frames_fps{fps}", video_name) 30 | os.makedirs(fps_frame_dir, exist_ok=True) 31 | cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vf fps=fps={:.02f} -qscale:v 2 {}/frame_%04d.jpg".format( 32 | video_path, fps, fps_frame_dir) 33 | 34 | ## extract fixed number frames 35 | # fps_frame_dir = os.path.join(output_dir, f"frames_32", video_name) 36 | # os.makedirs(fps_frame_dir, exist_ok=True) 37 | # cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vframes 32 -qscale:v 2 {}/frame_%04d.jpg".format( 38 | # video_path, fps_frame_dir) 39 | 40 | 41 | # ## extract audios 42 | # sr_audio_dir = os.path.join(output_dir,f"audios") 43 | # os.makedirs(sr_audio_dir, exist_ok=True) 44 | # # print(sr_audio_dir) 45 | # audio_name = video_name+'.wav' 46 | # audio_file_path = os.path.join(sr_audio_dir, audio_name) 47 | 48 | 49 | cmd = "ffmpeg -i {} -loglevel error -f wav -vn -ac 1 -ab 16k -ar {} -y {}".format( 50 | video_path, sr, audio_file_path) 51 | 52 | 53 | subprocess.call(cmd, shell=True) 54 | 55 | 56 | 57 | def extract_thread(video_id): 58 | 59 | video_name = os.path.join(input_path, video_id) 60 | 61 | if not os.path.exists(video_name): 62 | 63 | return 64 | try: 65 | # print(1) 66 | probe = ffmpeg.probe(video_name) 67 | # print(1) 68 | pipline(video_name, probe, output_path, fps=1, sr=22050, duration_target=10) 69 | except Exception as e: 70 | print(e) 71 | return 72 | 73 | 74 | def extract_all(video_ids, thread_num, start): 75 | length = 
len(video_ids) 76 | print(length) 77 | with Pool(thread_num) as p: 78 | list(tqdm.tqdm(p.imap(extract_thread, video_ids), total=length)) 79 | 80 | if __name__=='__main__': 81 | thread_num = 20 82 | start = 0 83 | 84 | print(len(data_list)) 85 | extract_all(data_list, thread_num, start) 86 | 87 | -------------------------------------------------------------------------------- /data/utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | # from apex import amp 7 | from collections import defaultdict 8 | from torch.nn.utils import clip_grad_norm_ 9 | from evaluation import evaluation_registry 10 | from .save import ModelSaver 11 | from .tool import NoOp 12 | from .logger import LOGGER, RunningMeter 13 | from .sched import get_lr_sched 14 | from torch.cuda.amp import autocast, GradScaler 15 | 16 | 17 | def train(model, optimizer, train_loader, val_loaders, run_cfg, start_step=0, verbose_time=False): 18 | 19 | if dist.get_rank() == 0: 20 | pbar = tqdm(total=run_cfg.num_train_steps, initial=start_step) 21 | model_saver = ModelSaver(os.path.join(run_cfg.output_dir, 'ckpt'),remove_before_ckpt=run_cfg.remove_before_ckpt) 22 | else: 23 | pbar = NoOp() 24 | model_saver = NoOp() 25 | 26 | loss_moving_averagetors ={} 27 | metric_logger_dict = defaultdict(dict) 28 | global_step = start_step 29 | 30 | scaler = GradScaler() 31 | 32 | best_indicator = {} 33 | evaluate_fn = evaluation_registry[model.config.evaluation_type] 34 | 35 | for step, (name, batch) in enumerate(train_loader): 36 | 37 | ndata = train_loader.ndata 38 | task = name.split('--')[0] 39 | 40 | 41 | 42 | if run_cfg.fp16: 43 | with autocast(): 44 | loss_dict = model(batch, task=task, compute_loss=True) 45 | loss = sum(list(loss_dict.values())) 46 | loss_dict['total_loss'] = loss 47 | loss_dict = {k:v.item() for k,v in loss_dict.items()} 48 | 49 | else: 50 | loss_dict = model(batch, task=task, compute_loss=True) 51 | loss = sum(list(loss_dict.values())) 52 | loss_dict['total_loss'] = loss 53 | loss_dict = {k:v.item() for k,v in loss_dict.items()} 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | if not name in loss_moving_averagetors: 64 | ### first time initialize 65 | for k in loss_dict.keys(): 66 | loss_moving_averagetors[f'loss_{name}/{k}'] = RunningMeter() 67 | ####accumulate loss 68 | 69 | for k,v in loss_dict.items(): 70 | loss_moving_averagetors[f'loss_{name}/{k}'](v) 71 | 72 | 73 | global_step += 1 74 | # learning rate scheduling 75 | lr_ratio = get_lr_sched(global_step, run_cfg) 76 | 77 | for param_group in optimizer.param_groups: 78 | param_group['lr'] = param_group['init_lr'] * lr_ratio 79 | 80 | if global_step % 50 == 0: 81 | LOGGER.info({name : averagetor.val for name, averagetor in loss_moving_averagetors.items()}) 82 | 83 | # update model params 84 | 85 | 86 | if run_cfg.fp16: 87 | optimizer.zero_grad() 88 | scaler.scale(loss).backward() 89 | else: 90 | loss.backward() 91 | 92 | if not run_cfg.use_ddp: 93 | works = [] 94 | for p in model.parameters(): 95 | # to speed it up, you can also organize grads to larger buckets to make allreduce more efficient 96 | if p.grad is not None: 97 | works.append(dist.all_reduce(p.grad, async_op=True)) 98 | for work in works: 99 | work.wait() 100 | 101 | 102 | # if run_cfg.grad_norm != -1: 103 | # grad_norm = clip_grad_norm_(model.parameters(), run_cfg.grad_norm) 104 | 105 | if run_cfg.fp16: 106 | 
scaler.step(optimizer) 107 | scaler.update() 108 | else: 109 | optimizer.step() 110 | optimizer.zero_grad() 111 | pbar.update(1) 112 | 113 | 114 | 115 | if (global_step+1) % run_cfg.valid_steps == 0: 116 | eval_log = evaluate_fn(model, val_loaders, run_cfg, global_step) 117 | 118 | if dist.get_rank() == 0: 119 | for task_name, val_log in eval_log.items(): 120 | for eval_name, metric in val_log.items(): 121 | eval_name = task_name +'_' +eval_name 122 | metric_logger_dict[eval_name][str(global_step)] = metric 123 | LOGGER.info(f"====-evaluation--{eval_name}=====step {global_step}--===========\n") 124 | LOGGER.info(metric) 125 | best_name = get_best_name(eval_name, metric) 126 | if best_name is not None: 127 | if ('best_step' not in metric_logger_dict[eval_name]) or \ 128 | (metric[best_name] >= metric_logger_dict[eval_name]['best_value']): 129 | metric_logger_dict[eval_name]['best_step'] = global_step 130 | metric_logger_dict[eval_name]['best_value'] = metric[best_name] 131 | best_indicator[eval_name] = True 132 | else: 133 | best_indicator[eval_name] = False 134 | best_step = metric_logger_dict[eval_name]['best_step'] 135 | LOGGER.info(f"======evaluation--{eval_name}====history best step: {best_step}=======\n") 136 | LOGGER.info(metric_logger_dict[eval_name][str(best_step)]) 137 | 138 | model_saver.save(model, global_step, optimizer,best_indicator, run_cfg.save_best) 139 | 140 | 141 | if global_step >= run_cfg.num_train_steps: 142 | break 143 | pbar.close() 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | def test(model, test_loader, run_cfg): 153 | 154 | evaluate_fn = evaluation_registry[model.config.evaluation_type] 155 | eval_log = evaluate_fn(model, test_loader, run_cfg, global_step=0) 156 | if dist.get_rank()==0: 157 | for task_name, val_log in eval_log.items(): 158 | for eval_name, metric in val_log.items(): 159 | eval_name = task_name +'_' +eval_name 160 | # TB_LOGGER.log_scaler_dict({f"eval/{eval_name}/test_{k}": v 161 | # for k, v in metric.items() if not isinstance(v,str)}) 162 | LOGGER.info(f"==== evaluation--{eval_name}========\n") 163 | LOGGER.info(metric) 164 | 165 | 166 | 167 | 168 | def get_best_name(eval_name, metric): 169 | if eval_name.startswith('cap'): 170 | return 'CIDEr' 171 | elif eval_name.startswith('qa'): 172 | return 'accuracy' 173 | elif eval_name.startswith('ret'): 174 | if 'video_r1' in metric: 175 | return 'video_r1' 176 | elif eval_name.startswith('pt'): 177 | return None 178 | else: 179 | raise NotImplementedError 180 | 181 | -------------------------------------------------------------------------------- /data/utils/save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from os.path import join 4 | from utils.logger import LOGGER 5 | 6 | 7 | 8 | 9 | class ModelSaver(object): 10 | def __init__(self, output_dir, prefix='model_step', suffix='pt',remove_before_ckpt=True): 11 | self.output_dir = output_dir 12 | self.prefix = prefix 13 | self.suffix = suffix 14 | self.remove_before_ckpt = remove_before_ckpt 15 | def save(self, model, step, optimizer=None, best_indicator=None, save_best=False): 16 | ###remove previous model 17 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('model')] 18 | # if not self.pretraining: 19 | if self.remove_before_ckpt: 20 | for p in previous_state: 21 | os.remove(os.path.join(self.output_dir,p)) 22 | output_model_file = join(self.output_dir, 23 | f"{self.prefix}_{step}.{self.suffix}") 24 | state_dict = {k: v.cpu() if isinstance(v, 
torch.Tensor) else v 25 | for k, v in model.state_dict().items()} 26 | torch.save(state_dict, output_model_file) 27 | 28 | if save_best: 29 | for k in best_indicator: 30 | if best_indicator[k]: 31 | torch.save(state_dict, join(self.output_dir, 32 | f"best_{k}.{self.suffix}")) 33 | 34 | if optimizer is not None: 35 | if hasattr(optimizer, '_amp_stash'): 36 | pass # TODO fp16 optimizer 37 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('optimizer')] 38 | if self.remove_before_ckpt: 39 | for p in previous_state: 40 | os.remove(os.path.join(self.output_dir,p)) 41 | torch.save(optimizer.state_dict(), f'{self.output_dir}/optimizer_step_{step}.pt') 42 | -------------------------------------------------------------------------------- /data/utils/sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def warmup_cosine(x, warmup_ratio): 4 | if x < warmup_ratio: 5 | return x/warmup_ratio 6 | return 0.5 * (1.0 + math.cos(math.pi * x)) 7 | 8 | def warmup_constant(x, warmup_ratio): 9 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 10 | Learning rate is 1. afterwards. """ 11 | if x < warmup_ratio: 12 | return x/warmup_ratio 13 | return 1.0 14 | 15 | def warmup_linear(x, warmup_ratio): 16 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 17 | After `t_total`-th training step, learning rate is zero. """ 18 | if x < warmup_ratio: 19 | return x/warmup_ratio 20 | return max((x-1.)/(warmup_ratio-1.), 0) 21 | 22 | scheduler_dict = {'warmup_linear' : warmup_linear, 23 | 'warmup_cosine' : warmup_cosine} 24 | 25 | def get_lr_sched(global_step, opts): 26 | warmup_ratio = opts.warmup_ratio 27 | current_ratio = global_step / opts.num_train_steps 28 | lr_ratio = scheduler_dict[opts.scheduler](current_ratio, warmup_ratio) 29 | return lr_ratio 30 | 31 | 32 | -------------------------------------------------------------------------------- /data/utils/tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class NoOp(object): 4 | """ useful for distributed training No-Ops """ 5 | def __getattr__(self, name): 6 | return self.noop 7 | 8 | def noop(self, *args, **kwargs): 9 | return 10 | 11 | 12 | 13 | 14 | def split(frame_name_lists, sample_num): 15 | if len(frame_name_lists) < sample_num: ###padding with the last frame 16 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists)) 17 | k, m = divmod(len(frame_name_lists), sample_num) 18 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))] 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /example/test.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.flac -------------------------------------------------------------------------------- /example/test.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.jpeg -------------------------------------------------------------------------------- /example/test.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.mp4 -------------------------------------------------------------------------------- /model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/model/.DS_Store -------------------------------------------------------------------------------- /model/audioprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | import torchaudio 6 | 7 | 8 | def split(frame_name_lists, sample_num): 9 | if len(frame_name_lists) < sample_num: ###padding with the last frame 10 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists)) 11 | k, m = divmod(len(frame_name_lists), sample_num) 12 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))] 13 | 14 | 15 | class AudioProcessor(object): 16 | def __init__(self, melbins, target_length, sample_num, frame_shift=10, resize_melbin_num=224, mean=15.41663, std=6.55582, training=True): 17 | self.melbins = melbins 18 | self.target_length = target_length 19 | self.training = training 20 | self.frame_shift = frame_shift 21 | self.sample_num = sample_num 22 | self.resize_melbin_num = resize_melbin_num 23 | 24 | self.mean = mean 25 | self.std = std 26 | 27 | def __call__(self, wav_file): 28 | 29 | if not os.path.exists(wav_file): 30 | print('not have audios', wav_file) 31 | return torch.zeros(self.sample_num, self.target_length, self.melbins) 32 | 33 | try: 34 | waveform, sr = torchaudio.load(wav_file) 35 | if sr != 16000: 36 | trans = torchaudio.transforms.Resample(sr, 16000) 37 | waveform = trans(waveform) 38 | 39 | waveform = waveform * 2 ** 15 40 | fbank = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=self.melbins, sample_frequency=16000, frame_length=25, frame_shift=10) 41 | 42 | if fbank.size(1) != self.resize_melbin_num: 43 | fbank = torch.nn.functional.interpolate(fbank.reshape(1, 1, *fbank.shape[-2:]), size=(fbank.size(0), self.resize_melbin_num), mode='bilinear').reshape(fbank.size(0), self.resize_melbin_num) 44 | 45 | # ### normalization 46 | fbank = (fbank - self.mean) / (self.std * 2) 47 | src_length = fbank.shape[0] 48 | # #### sample 49 | 50 | output_slices = [] 51 | pad_len = max(self.target_length * self.sample_num -src_length, self.target_length - src_length%self.target_length) 52 | fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank) 53 | 54 | total_slice_num = fbank.shape[0] // self.target_length 55 | total_slice_num = list(range(total_slice_num)) 56 | total_slice_num = split(total_slice_num, self.sample_num) 57 | 58 | if self.training: 59 | sample_idx = [random.choice(i) for i in total_slice_num] 60 | else: 61 | sample_idx = [i[(len(i)+1)//2-1] for i in total_slice_num] 62 | 63 | 64 | for i in sample_idx: 65 | cur_bank = fbank[i*self.target_length : (i+1)*self.target_length] 66 | output_slices.append(cur_bank) 67 | 68 | 69 | 70 | fbank = torch.stack(output_slices,dim=0) ### n, 1024, 128 71 | 72 | 73 | return fbank 74 | 75 | 76 | except Exception as e: 77 | print(e) 78 | return 79 | 80 | 81 | if __name__ == "__main__": 82 | wav_file = "./data/test.flac" 83 | proc = AudioProcessor(melbins=224, target_length=224, sample_num=4, training=True) 84 | audio_input = proc(wav_file) 85 | print(audio_input.size()) 
-------------------------------------------------------------------------------- /model/bert-base-uncased-crossattn/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "bert-base-uncased", 3 | "add_cross_attention": true, 4 | "architectures": [ 5 | "BertForMaskedLM" 6 | ], 7 | "attention_probs_dropout_prob": 0.1, 8 | "classifier_dropout": null, 9 | "gradient_checkpointing": false, 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 768, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "is_decoder": true, 16 | "layer_norm_eps": 1e-12, 17 | "max_position_embeddings": 512, 18 | "model_type": "bert", 19 | "num_attention_heads": 12, 20 | "num_hidden_layers": 12, 21 | "pad_token_id": 0, 22 | "position_embedding_type": "absolute", 23 | "torch_dtype": "float32", 24 | "transformers_version": "4.26.1", 25 | "type_vocab_size": 2, 26 | "use_cache": true, 27 | "vocab_size": 30522 28 | } 29 | -------------------------------------------------------------------------------- /model/bert-base-uncased-crossattn/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "pad_token_id": 0, 4 | "transformers_version": "4.26.1" 5 | } 6 | -------------------------------------------------------------------------------- /model/clip/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | import ftfy 8 | import regex as re 9 | import torch 10 | 11 | 12 | @lru_cache() 13 | def default_bpe(): 14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings). 
43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class SimpleTokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | 137 | 138 | _tokenizer = SimpleTokenizer() 139 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 140 | """ 141 | Returns the tokenized representation of given input string(s) 142 | 143 | Parameters 144 | ---------- 145 | texts : Union[str, List[str]] 146 | An input string or a list of input strings to tokenize 147 | 148 | context_length : int 149 | The context length to use; all CLIP models use 77 as the context length 150 | 151 | truncate: bool 152 | 
Whether to truncate the text in case its encoding is longer than the context length 153 | 154 | Returns 155 | ------- 156 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 157 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 158 | """ 159 | if isinstance(texts, str): 160 | texts = [texts] 161 | 162 | sot_token = _tokenizer.encoder["<|startoftext|>"] 163 | eot_token = _tokenizer.encoder["<|endoftext|>"] 164 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 165 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 166 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 167 | else: 168 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 169 | 170 | for i, tokens in enumerate(all_tokens): 171 | if len(tokens) > context_length: 172 | if truncate: 173 | tokens = tokens[:context_length] 174 | tokens[-1] = eot_token 175 | else: 176 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 177 | result[i, :len(tokens)] = torch.tensor(tokens) 178 | 179 | return result 180 | -------------------------------------------------------------------------------- /model/evaclip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | # from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /model/evaclip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/model/evaclip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/evaclip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /model/evaclip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": 
"mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /model/evaclip/loss.py: -------------------------------------------------------------------------------- 1 | # import math 2 | # import torch 3 | # import torch.nn as nn 4 | # from torch.nn import functional as F 5 | 6 | # try: 7 | # import torch.distributed.nn 8 | # from torch import distributed as dist 9 | # has_distributed = True 10 | # except ImportError: 11 | # has_distributed = False 12 | 13 | # try: 14 | # import horovod.torch as hvd 15 | # except ImportError: 16 | # hvd = None 17 | 18 | # from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | # def gather_features( 22 | # image_features, 23 | # text_features, 24 | # local_loss=False, 25 | # gather_with_grad=False, 26 | # rank=0, 27 | # world_size=1, 28 | # use_horovod=False 29 | # ): 30 | # assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
31 | # if use_horovod: 32 | # assert hvd is not None, 'Please install horovod' 33 | # if gather_with_grad: 34 | # all_image_features = hvd.allgather(image_features) 35 | # all_text_features = hvd.allgather(text_features) 36 | # else: 37 | # with torch.no_grad(): 38 | # all_image_features = hvd.allgather(image_features) 39 | # all_text_features = hvd.allgather(text_features) 40 | # if not local_loss: 41 | # # ensure grads for local rank when all_* features don't have a gradient 42 | # gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | # gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | # gathered_image_features[rank] = image_features 45 | # gathered_text_features[rank] = text_features 46 | # all_image_features = torch.cat(gathered_image_features, dim=0) 47 | # all_text_features = torch.cat(gathered_text_features, dim=0) 48 | # else: 49 | # # We gather tensors from all gpus 50 | # if gather_with_grad: 51 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | # # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 54 | # # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 55 | # else: 56 | # gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 57 | # gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 58 | # dist.all_gather(gathered_image_features, image_features) 59 | # dist.all_gather(gathered_text_features, text_features) 60 | # if not local_loss: 61 | # # ensure grads for local rank when all_* features don't have a gradient 62 | # gathered_image_features[rank] = image_features 63 | # gathered_text_features[rank] = text_features 64 | # all_image_features = torch.cat(gathered_image_features, dim=0) 65 | # all_text_features = torch.cat(gathered_text_features, dim=0) 66 | 67 | # return all_image_features, all_text_features 68 | 69 | 70 | # class ClipLoss(nn.Module): 71 | 72 | # def __init__( 73 | # self, 74 | # local_loss=False, 75 | # gather_with_grad=False, 76 | # cache_labels=False, 77 | # rank=0, 78 | # world_size=1, 79 | # use_horovod=False, 80 | # smoothing=0., 81 | # ): 82 | # super().__init__() 83 | # self.local_loss = local_loss 84 | # self.gather_with_grad = gather_with_grad 85 | # self.cache_labels = cache_labels 86 | # self.rank = rank 87 | # self.world_size = world_size 88 | # self.use_horovod = use_horovod 89 | # self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 90 | 91 | # # cache state 92 | # self.prev_num_logits = 0 93 | # self.labels = {} 94 | 95 | # def forward(self, image_features, text_features, logit_scale=1.): 96 | # device = image_features.device 97 | # if self.world_size > 1: 98 | # all_image_features, all_text_features = gather_features( 99 | # image_features, text_features, 100 | # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 101 | 102 | # if self.local_loss: 103 | # logits_per_image = logit_scale * image_features @ all_text_features.T 104 | # logits_per_text = logit_scale * text_features @ all_image_features.T 105 | # else: 106 | # logits_per_image = logit_scale * all_image_features @ all_text_features.T 107 | # logits_per_text = logits_per_image.T 108 | # else: 109 | # logits_per_image = logit_scale * 
image_features @ text_features.T 110 | # logits_per_text = logit_scale * text_features @ image_features.T 111 | # # calculated ground-truth and cache if enabled 112 | # num_logits = logits_per_image.shape[0] 113 | # if self.prev_num_logits != num_logits or device not in self.labels: 114 | # labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | # if self.world_size > 1 and self.local_loss: 116 | # labels = labels + num_logits * self.rank 117 | # if self.cache_labels: 118 | # self.labels[device] = labels 119 | # self.prev_num_logits = num_logits 120 | # else: 121 | # labels = self.labels[device] 122 | 123 | # if self.label_smoothing_cross_entropy: 124 | # total_loss = ( 125 | # self.label_smoothing_cross_entropy(logits_per_image, labels) + 126 | # self.label_smoothing_cross_entropy(logits_per_text, labels) 127 | # ) / 2 128 | # else: 129 | # total_loss = ( 130 | # F.cross_entropy(logits_per_image, labels) + 131 | # F.cross_entropy(logits_per_text, labels) 132 | # ) / 2 133 | 134 | # acc = None 135 | # i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 136 | # t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 137 | # acc = {"i2t": i2t_acc, "t2i": t2i_acc} 138 | # return total_loss, acc -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 
6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /model/evaclip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": 
false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /model/evaclip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class 
ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | 
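For reference, a minimal forward-pass sketch for the `ModifiedResNet` tower defined above. The RN50-style hyperparameters (`layers=[3, 4, 6, 3]`, `width=64`, `output_dim=1024`, `heads=32`) and the import path are assumptions for illustration, not values read from a config in this repository.

```python
import torch

from model.evaclip.modified_resnet import ModifiedResNet  # assumed import path

# RN50-style configuration (illustrative values)
model = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=1024, heads=32,
                       image_size=224, width=64)
model.eval()

with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)   # dummy NCHW batch
    feats = model(images)                  # stem -> 4 residual stages -> attention pool

print(feats.shape)  # torch.Size([2, 1024]) == (batch, output_dim)
```

Because the final pooling layer is the QKV attention pool rather than global average pooling, the output dimension is set by `output_dim` instead of the last stage width (`width * 32 = 2048` here).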
-------------------------------------------------------------------------------- /model/evaclip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /model/evaclip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | def broadcat(tensors, dim = -1): 8 | num_tensors = len(tensors) 9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 11 | shape_len = list(shape_lens)[0] 12 | dim = (dim + shape_len) if dim < 0 else dim 13 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 14 | expandable_dims = [(i, val) for i, val in 
enumerate(dims) if i != dim] 15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 18 | expanded_dims.insert(dim, (dim, dims[dim])) 19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 21 | return torch.cat(tensors, dim = dim) 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 25 | x1, x2 = x.unbind(dim = -1) 26 | x = torch.stack((-x2, x1), dim = -1) 27 | return rearrange(x, '... d r -> ... (d r)') 28 | 29 | 30 | class VisionRotaryEmbedding(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | pt_seq_len, 35 | ft_seq_len=None, 36 | custom_freqs = None, 37 | freqs_for = 'lang', 38 | theta = 10000, 39 | max_freq = 10, 40 | num_freqs = 1, 41 | ): 42 | super().__init__() 43 | if custom_freqs: 44 | freqs = custom_freqs 45 | elif freqs_for == 'lang': 46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 47 | elif freqs_for == 'pixel': 48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 49 | elif freqs_for == 'constant': 50 | freqs = torch.ones(num_freqs).float() 51 | else: 52 | raise ValueError(f'unknown modality {freqs_for}') 53 | 54 | if ft_seq_len is None: ft_seq_len = pt_seq_len 55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 56 | 57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 59 | 60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 62 | 63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 64 | 65 | self.register_buffer("freqs_cos", freqs.cos()) 66 | self.register_buffer("freqs_sin", freqs.sin()) 67 | 68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 69 | 70 | def forward(self, t, start_index = 0): 71 | rot_dim = self.freqs_cos.shape[-1] 72 | end_index = start_index + rot_dim 73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 76 | 77 | return torch.cat((t_left, t, t_right), dim = -1) 78 | 79 | class VisionRotaryEmbeddingFast(nn.Module): 80 | def __init__( 81 | self, 82 | dim, 83 | pt_seq_len, 84 | ft_seq_len=None, 85 | custom_freqs = None, 86 | freqs_for = 'lang', 87 | theta = 10000, 88 | max_freq = 10, 89 | num_freqs = 1, 90 | patch_dropout = 0. 91 | ): 92 | super().__init__() 93 | if custom_freqs: 94 | freqs = custom_freqs 95 | elif freqs_for == 'lang': 96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | elif freqs_for == 'pixel': 98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 99 | elif freqs_for == 'constant': 100 | freqs = torch.ones(num_freqs).float() 101 | else: 102 | raise ValueError(f'unknown modality {freqs_for}') 103 | 104 | if ft_seq_len is None: ft_seq_len = pt_seq_len 105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 106 | 107 | freqs = torch.einsum('..., f -> ... f', t, freqs) 108 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 110 | 111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 113 | 114 | self.patch_dropout = patch_dropout 115 | 116 | self.register_buffer("freqs_cos", freqs_cos) 117 | self.register_buffer("freqs_sin", freqs_sin) 118 | 119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 120 | 121 | def forward(self, t, patch_indices_keep=None): 122 | if patch_indices_keep is not None: 123 | batch = t.size()[0] 124 | batch_indices = torch.arange(batch) 125 | batch_indices = batch_indices[..., None] 126 | 127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 129 | 130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') 132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') 134 | 135 | return t * freqs_cos + rotate_half(t) * freqs_sin 136 | 137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /model/evaclip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 | prev_chs = 
embed_dim 65 | elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /model/evaclip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | # import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
31 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 |     This is a significant percentage of your normal, say, 32K bpe vocab.
33 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 |     And avoids mapping to whitespace/control characters the bpe code barfs on.
35 |     """
36 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 |     cs = bs[:]
38 |     n = 0
39 |     for b in range(2**8):
40 |         if b not in bs:
41 |             bs.append(b)
42 |             cs.append(2**8+n)
43 |             n += 1
44 |     cs = [chr(n) for n in cs]
45 |     return dict(zip(bs, cs))
46 | 
47 | 
48 | def get_pairs(word):
49 |     """Return set of symbol pairs in a word.
50 |     Word is represented as tuple of symbols (symbols being variable-length strings).
51 |     """
52 |     pairs = set()
53 |     prev_char = word[0]
54 |     for char in word[1:]:
55 |         pairs.add((prev_char, char))
56 |         prev_char = char
57 |     return pairs
58 | 
59 | 
60 | def basic_clean(text):
61 |     # ftfy.fix_text() is not applied here: the ftfy import at the top of this file is commented out
62 |     text = html.unescape(html.unescape(text))
63 |     return text.strip()
64 | 
65 | 
66 | def whitespace_clean(text):
67 |     text = re.sub(r'\s+', ' ', text)
68 |     text = text.strip()
69 |     return text
70 | 
71 | 
72 | class SimpleTokenizer(object):
73 |     def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 |         self.byte_encoder = bytes_to_unicode()
75 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 |         merges = merges[1:49152-256-2+1]
78 |         merges = [tuple(merge.split()) for merge in merges]
79 |         vocab = list(bytes_to_unicode().values())
80 |         vocab = vocab + [v+'</w>' for v in vocab]
81 |         for merge in merges:
82 |             vocab.append(''.join(merge))
83 |         if not special_tokens:
84 |             special_tokens = ['<start_of_text>', '<end_of_text>']
85 |         else:
86 |             special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 |         vocab.extend(special_tokens)
88 |         self.encoder = dict(zip(vocab, range(len(vocab))))
89 |         self.decoder = {v: k for k, v in self.encoder.items()}
90 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 |         self.cache = {t:t for t in special_tokens}
92 |         special = "|".join(special_tokens)
93 |         self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 | 
95 |         self.vocab_size = len(self.encoder)
96 |         self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 | 
98 |     def bpe(self, token):
99 |         if token in self.cache:
100 |             return self.cache[token]
101 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 |         pairs = get_pairs(word)
103 | 
104 |         if not pairs:
105 |             return token+'</w>'
106 | 
107 |         while True:
108 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 |             if bigram not in self.bpe_ranks:
110 |                 break
111 |             first, second = bigram
112 |             new_word = []
113 |             i = 0
114 |             while i < len(word):
115 |                 try:
116 |                     j = word.index(first, i)
117 |                     new_word.extend(word[i:j])
118 |                     i = j
119 |                 except:
120 |                     new_word.extend(word[i:])
121 |                     break
122 | 
123 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 |                     new_word.append(first+second)
125 |                     i += 2
126 |                 else:
127 |                     new_word.append(word[i])
128 |                     i += 1
129 |             new_word = tuple(new_word)
130 |             word = new_word
131 |             if len(word) == 1:
132 |                 break
133 |             else:
134 |                 pairs = get_pairs(word)
135 |         word = ' '.join(word)
136 |         self.cache[token] = word
137 |         return word
138 | 
139 |     def encode(self, text):
140 |         bpe_tokens = []
141 |         text = whitespace_clean(basic_clean(text)).lower()
142 |         for token in re.findall(self.pat, text):
143 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 |         return bpe_tokens
146 | 
147 |     def decode(self, tokens):
148 |         text = ''.join([self.decoder[token] for token in tokens])
149 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 |         return text
151 | 
152 | 
153 | _tokenizer = SimpleTokenizer()
154 | 
155 | 
156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157 |     """
158 |     Returns the tokenized representation of given input string(s)
159 | 
160 |     Parameters
161 |     ----------
162 |     texts : Union[str, List[str]]
163 |         An input string or a list of input strings to tokenize
164 |     context_length : int
165 |         The context length to use; all CLIP models use 77 as the context length
166 | 
167 |     Returns
168 |     -------
169 |     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170 |     """
171 |     if isinstance(texts, str):
172 |         texts = [texts]
173 | 
174 |     sot_token = _tokenizer.encoder["<start_of_text>"]
175 |     eot_token = _tokenizer.encoder["<end_of_text>"]
176 |     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177 |     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178 | 
179 |     for i, tokens in enumerate(all_tokens):
180 |         if len(tokens) > context_length:
181 |             tokens = tokens[:context_length] # Truncate
182 |             tokens[-1] = eot_token
183 |         result[i, :len(tokens)] = torch.tensor(tokens)
184 | 
185 |     return result
186 | 
187 | 
188 | class HFTokenizer:
189 |     "HuggingFace tokenizer wrapper"
190 |     def __init__(self, tokenizer_name:str):
191 |         from transformers import AutoTokenizer
192 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193 | 
194 |     def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195 |         # same cleaning as for default tokenizer, except lowercasing
196 |         # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197 |         if isinstance(texts, str):
198 |             texts = [texts]
199 |         texts = [whitespace_clean(basic_clean(text)) for text in texts]
200 |         input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201 |         return input_ids
202 | 
--------------------------------------------------------------------------------
/model/evaclip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 | 
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 |     CenterCrop
9 | 
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 | 
12 | 
13 | class ResizeMaxSize(nn.Module):
14 | 
15 |     def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16 |         super().__init__()
17 |         if not isinstance(max_size, int):
18 |             raise TypeError(f"Size should be int. 
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /model/imageprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | 5 | import torch 6 | from torchvision.transforms.transforms import * 7 | from torchvision import transforms 8 | 9 | 10 | class ImageProcessor(object): 11 | def __init__(self, image_resolution, image_encoder_type, image_transforms='none', training=True): 12 | self.training = training 13 | 14 | self.resolution = image_resolution 15 | self.image_encoder_type = image_encoder_type 16 | 17 | if image_encoder_type.startswith('clip') or image_encoder_type.startswith('evaclip'): 18 | self.mean = [0.48145466, 0.4578275, 0.40821073] 19 | self.std = [0.26862954, 0.26130258, 0.27577711] 20 | else: 21 | self.mean = [0.485, 0.456, 0.406] 22 | self.std = [0.229, 0.224, 0.225] 23 | 24 | 
self.image_transforms = image_transforms 25 | if image_transforms == 'none': 26 | self.train_transforms = transforms.Compose([Resize((self.resolution,self.resolution)), 27 | Normalize(self.mean,self.std)]) 28 | 29 | self.test_transforms = transforms.Compose([Resize((self.resolution,self.resolution)), 30 | Normalize(self.mean,self.std)]) 31 | elif image_transforms == 'crop_flip': 32 | self.train_transforms = transforms.Compose([RandomResizedCrop(self.resolution, [0.8,1.0],[1.0,1.0]), 33 | RandomHorizontalFlip(), 34 | Normalize(self.mean,self.std)]) 35 | 36 | self.test_transforms = transforms.Compose([Resize(self.resolution), 37 | CenterCrop(self.resolution), 38 | Normalize(self.mean,self.std)]) 39 | else: 40 | raise NotImplementedError 41 | 42 | def __call__(self, image_file): 43 | 44 | try: 45 | img_path = image_file 46 | if not os.path.exists(img_path): 47 | print('not have image', image_file) 48 | return None 49 | img = Image.open(img_path) 50 | img = img.convert('RGB') #### convert 1-channel gray image and 4-channel CMYK image to RGB image 51 | img = transforms.ToTensor()(img) 52 | if self.training: 53 | img = self.train_transforms(img) 54 | else: 55 | img = self.test_transforms(img) 56 | 57 | img = img.unsqueeze(0) 58 | 59 | return img 60 | 61 | except Exception as e: 62 | print(e) 63 | return None 64 | 65 | if __name__ == "__main__": 66 | image_file = "./data/test.jpeg" 67 | proc = ImageProcessor(image_resolution=224, image_encoder_type="swin", training=True) 68 | image_input = proc(image_file) 69 | print(image_input.size()) -------------------------------------------------------------------------------- /model/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /model/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /model/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "do_basic_tokenize": true, 4 | "do_lower_case": true, 5 | "mask_token": "[MASK]", 6 | "model_max_length": 512, 7 | "name_or_path": "bert-base-uncased", 8 | "never_split": null, 9 | "pad_token": "[PAD]", 10 | "sep_token": "[SEP]", 11 | "special_tokens_map_file": null, 12 | "strip_accents": null, 13 | "tokenize_chinese_chars": true, 14 | "tokenizer_class": "BertTokenizer", 15 | "unk_token": "[UNK]" 16 | } 17 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | BERT layers from the huggingface implementation 3 | (https://github.com/huggingface/transformers) 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors and The 
HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import logging 21 | import math 22 | import copy 23 | import torch 24 | from torch import nn 25 | # from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 26 | from torch.nn import LayerNorm 27 | import torch.nn.functional as F 28 | import ipdb 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | def gelu(x): 34 | """Implementation of the gelu activation function. 35 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 36 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 37 | Also see https://arxiv.org/abs/1606.08415 38 | """ 39 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 40 | 41 | 42 | def swish(x): 43 | return x * torch.sigmoid(x) 44 | 45 | 46 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 47 | 48 | 49 | class GELU(nn.Module): 50 | def forward(self, input_): 51 | output = gelu(input_) 52 | return output 53 | 54 | 55 | 56 | 57 | class TransformerLayer(nn.Module): 58 | def __init__(self, config, mode): 59 | super().__init__() 60 | self.attention = MultiHeadAttention(config) 61 | self.ff_layer = FeedForward(config) 62 | self.dropout = nn.Dropout(config.hidden_dropout) 63 | self.layernorm1 = LayerNorm(config.hidden_size, eps=1e-12) 64 | self.layernorm2 = LayerNorm(config.hidden_size, eps=1e-12) 65 | self.mode = mode 66 | 67 | def forward(self, hidden_states, attention_mask): 68 | if self.mode == 'prenorm': 69 | return self.forward_prenorm(hidden_states, attention_mask) 70 | elif self.mode == 'postnorm': 71 | return self.forward_postnorm(hidden_states, attention_mask) 72 | else: 73 | raise NotImplementedError 74 | 75 | def forward_prenorm(self, hidden_states, attention_mask): 76 | residual = hidden_states 77 | hidden_states = self.layernorm1(hidden_states) 78 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask) 79 | hidden_states = residual + self.dropout(attention_output) 80 | 81 | residual = hidden_states 82 | hidden_states = self.layernorm2(hidden_states) 83 | ff_output = self.ff_layer(hidden_states) 84 | hidden_states = residual + self.dropout(ff_output) 85 | 86 | return hidden_states 87 | 88 | def forward_postnorm(self, hidden_states, attention_mask): 89 | residual = hidden_states 90 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask) 91 | hidden_states = residual + self.dropout(attention_output) 92 | hidden_states = self.layernorm1(hidden_states) 93 | 94 | residual = hidden_states 95 | ff_output = self.ff_layer(hidden_states) 96 | hidden_states = residual + self.dropout(ff_output) 97 | hidden_states = self.layernorm2(hidden_states) 98 | 99 | return hidden_states 100 | 101 | 102 | def clones(x,times): 103 | return nn.ModuleList([copy.deepcopy(x) for i in 
range(times)]) 104 | 105 | 106 | 107 | class MultiHeadAttention(nn.Module): 108 | def __init__(self, config): 109 | super().__init__() 110 | self.linears = clones(nn.Linear(config.hidden_size, config.hidden_size), 4) 111 | self.head_num = config.num_attention_heads 112 | self.hidden_size = config.hidden_size 113 | self.dropout=nn.Dropout(config.attention_dropout) 114 | 115 | 116 | def forward(self,q,k,v,mask=None): 117 | batch_size=q.shape[0] 118 | q,k,v=[layer(x).view(batch_size,-1,self.head_num, self.hidden_size//self.head_num).transpose(1,2) \ 119 | for layer,x in zip(self.linears,(q,k,v))] 120 | norm_d=q.shape[-1] 121 | att_map=torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(norm_d) 122 | if mask is not None: 123 | att_map=att_map + mask 124 | att_map=F.softmax(att_map,dim=-1) 125 | # import ipdb 126 | # if att_map.shape[-1] == 45: 127 | # ipdb.set_trace() 128 | 129 | att_map=self.dropout(att_map) 130 | attn_output = self.linears[-1](torch.matmul(att_map,v).transpose(1,2).contiguous().view(batch_size,-1,self.hidden_size)) 131 | return attn_output 132 | 133 | 134 | class FeedForward(nn.Module): 135 | def __init__(self, config): 136 | super().__init__() 137 | self.linear1=nn.Linear(config.hidden_size, config.intermediate_size) 138 | self.linear2=nn.Linear(config.intermediate_size, config.hidden_size) 139 | self.activation = GELU() 140 | 141 | 142 | def forward(self,x): 143 | return self.linear2((self.activation(self.linear1(x)))) 144 | 145 | 146 | 147 | class TransformerEncoder(nn.Module): 148 | def __init__(self, config, mode = 'prenorm'): 149 | super().__init__() 150 | layer = TransformerLayer(config, mode) 151 | self.mode = mode 152 | self.layer = nn.ModuleList([copy.deepcopy(layer) 153 | for _ in range(config.num_hidden_layers)]) 154 | if self.mode == 'prenorm': 155 | self.last_layernorm = LayerNorm(config.hidden_size, eps=1e-12) 156 | self.checkpointing = config.checkpointing 157 | def forward(self, input_, attention_mask=None, cross_hidden_states=None, 158 | use_cache=False, 159 | cache=None, 160 | cache_first=False, 161 | cache_type=None): 162 | hidden_states = input_ 163 | for layer_module in self.layer: 164 | if self.checkpointing: 165 | hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask) 166 | else: 167 | hidden_states = layer_module(hidden_states, attention_mask) 168 | 169 | if self.mode == 'prenorm': 170 | hidden_states = self.last_layernorm(hidden_states) 171 | return hidden_states, cache 172 | 173 | -------------------------------------------------------------------------------- /model/videoprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from decord import VideoReader 4 | from PIL import Image 5 | 6 | import torch 7 | from torchvision.transforms.transforms import * 8 | from torchvision import transforms 9 | 10 | 11 | def split(frame_name_lists, sample_num): 12 | if len(frame_name_lists) < sample_num: ###padding with the last frame 13 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists)) 14 | k, m = divmod(len(frame_name_lists), sample_num) 15 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))] 16 | 17 | class VideoProcessor(object): 18 | def __init__(self, video_resolution, video_encoder_type, sample_num = 4, video_transforms='none', data_format="frame", training=True): 19 | self.frame_syncaug = True 20 | self.training = training 21 | self.sample_num = sample_num 22 | 
self.data_format = data_format
23 | 
24 |         self.resolution = video_resolution
25 |         self.video_encoder_type = video_encoder_type
26 | 
27 |         if video_encoder_type.startswith('clip') or video_encoder_type.startswith('evaclip'):
28 |             self.mean = [0.48145466, 0.4578275, 0.40821073]
29 |             self.std = [0.26862954, 0.26130258, 0.27577711]
30 |         else:
31 |             self.mean = [0.485, 0.456, 0.406]
32 |             self.std = [0.229, 0.224, 0.225]
33 | 
34 |         self.video_transforms = video_transforms
35 |         if video_transforms == 'none':
36 |             self.train_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
37 |                                                         Normalize(self.mean,self.std)])
38 | 
39 |             self.test_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
40 |                                                        Normalize(self.mean,self.std)])
41 |         elif video_transforms == 'crop_flip':
42 |             self.train_transforms = transforms.Compose([RandomResizedCrop(self.resolution, [0.8,1.0],[1.0,1.0]),
43 |                                                         RandomHorizontalFlip(),
44 |                                                         Normalize(self.mean,self.std)])
45 | 
46 |             self.test_transforms = transforms.Compose([Resize(self.resolution),
47 |                                                        CenterCrop(self.resolution),
48 |                                                        Normalize(self.mean,self.std)])
49 |         else:
50 |             raise NotImplementedError
51 | 
52 |     def __call__(self, video_file):
53 | 
54 |         video_pixels = []
55 |         sample_num = self.sample_num
56 |         try:
57 |             if self.data_format == 'frame':
58 |                 frame_path = video_file
59 |                 if not os.path.exists(frame_path):
60 |                     print('not have videos', video_file)
61 |                     return None
62 |                 frames = os.listdir(frame_path)
63 |                 frames.sort() ### ['img_0001.jpg','img_0002.jpg',...]
64 |                 sample_num = self.sample_num
65 |                 frames_splited = split(frames,sample_num)
66 |                 if self.training:
67 |                     sample_idx = [random.choice(i) for i in frames_splited]
68 |                 else:
69 |                     sample_idx = [i[(len(i)+1)//2-1] for i in frames_splited]
70 |                 for i in range(sample_num):
71 |                     frame = Image.open(os.path.join(frame_path,sample_idx[i]))
72 |                     frame = transforms.ToTensor()(frame) ## frame: 3XhXw
73 |                     video_pixels.append(frame.unsqueeze(0))
74 | 
75 |             elif self.data_format == 'raw':
76 |                 video_path = video_file
77 |                 if not os.path.exists(video_path):
78 |                     print('not have videos', video_file)
79 |                     return None
80 |                 container = VideoReader(uri=video_path) ## decord.VideoReader, imported at the top of this file
81 |                 frames_ids = list(range(len(container)))
82 | 
83 |                 frames_splited = split(frames_ids, sample_num)
84 |                 if self.training:
85 |                     sample_idx = [random.choice(i) for i in frames_splited]
86 |                 else:
87 |                     sample_idx = [i[(len(i)+1)//2-1] for i in frames_splited]
88 | 
89 |                 frames = container.get_batch(sample_idx).asnumpy()
90 |                 # print(len(frames), type(frames),sample_idx)
91 | 
92 |                 for i in frames:
93 |                     frame = Image.fromarray(i)
94 |                     frame = transforms.ToTensor()(frame) ## frame: 3XhXw
95 |                     video_pixels.append(frame.unsqueeze(0))
96 | 
97 | 
98 |             video_pixels = torch.cat(video_pixels,dim=0) ### nX3xHxW
99 |             if self.training:
100 |                 video_pixels = self.train_transforms(video_pixels)
101 |             else:
102 |                 video_pixels = self.test_transforms(video_pixels)
103 |             return video_pixels
104 | 
105 |         except Exception as e:
106 |             print(e)
107 |             print(video_file)
108 |             return None
109 | 
110 | if __name__ == "__main__":
111 |     video_file = "./data/test.mp4"
112 |     proc = VideoProcessor(video_resolution=224, video_encoder_type="swin", sample_num=4, data_format="raw", training=True)
113 |     video_input = proc(video_file)
114 |     print(video_input.size())
--------------------------------------------------------------------------------
/set_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.0.1 
-i https://pypi.tuna.tsinghua.edu.cn/simple 2 | pip install torchvision==0.15.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 3 | pip install torchaudio==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 4 | pip install decord -i https://pypi.tuna.tsinghua.edu.cn/simple 5 | pip install h5py -i https://pypi.tuna.tsinghua.edu.cn/simple 6 | pip install ffmpeg-python -i https://pypi.tuna.tsinghua.edu.cn/simple 7 | pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple 9 | pip install ipdb -i https://pypi.tuna.tsinghua.edu.cn/simple 10 | pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple 11 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | pip install transformers==4.31.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 13 | pip install webdataset -i https://pypi.tuna.tsinghua.edu.cn/simple 14 | pip install SentencePiece -i https://pypi.tuna.tsinghua.edu.cn/simple --------------------------------------------------------------------------------
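
The preprocessing utilities listed above (`ImageProcessor`, `VideoProcessor`, and the EVA-CLIP `tokenize` function) can be exercised together before wiring up the full model. The snippet below is a minimal sketch, not part of the repository: it assumes the packages from `set_env.sh` are installed, that it is run from the repository root so the `model/` directory is importable, and that the sample paths and prompt text (borrowed from the `__main__` blocks of `imageprocessor.py` and `videoprocessor.py`) point at real files; adjust them to your own data.

```python
# Illustrative smoke test: import paths, media paths, and the prompt below are assumptions.
from model.imageprocessor import ImageProcessor
from model.videoprocessor import VideoProcessor
from model.evaclip.tokenizer import tokenize

# Image branch: resize + CLIP-style normalization, returns a [1, 3, 224, 224] tensor (or None on failure).
image_proc = ImageProcessor(image_resolution=224, image_encoder_type="evaclip", training=False)
image_input = image_proc("./data/test.jpeg")

# Video branch: samples 4 frames from a raw .mp4 via decord, returns [4, 3, 224, 224] (or None).
video_proc = VideoProcessor(video_resolution=224, video_encoder_type="evaclip",
                            sample_num=4, data_format="raw", training=False)
video_input = video_proc("./data/test.mp4")

# Text branch: CLIP BPE tokenization, padded/truncated to the 77-token context length -> [1, 77].
text_input = tokenize(["a person playing guitar on stage"])

# Both processors return None when a file is missing or unreadable, so the shape
# printout doubles as a quick sanity check of the local setup.
for name, tensor in (("image", image_input), ("video", video_input), ("text", text_input)):
    if tensor is not None:
        print(name, tuple(tensor.shape), tensor.dtype)
```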