├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── banner.png
│   ├── brain.png
│   ├── omnimodal_pretraining.png
│   ├── paradigm.png
│   └── scaling_laws.png
├── data
│   ├── README.md
│   ├── caption_config
│   │   ├── caption-generation-audio.json
│   │   ├── caption-generation-vision.json
│   │   ├── default_model_cfg.json
│   │   └── default_run_cfg.json
│   ├── config.yaml
│   ├── data
│   │   ├── IndexAnno.py
│   │   ├── IndexSrc.py
│   │   ├── __init__.py
│   │   ├── audio_mapper.py
│   │   ├── loader.py
│   │   └── vision_mapper.py
│   ├── download_hdvila.sh
│   ├── makeparquet.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── audio_encoders
│   │   │   ├── ast
│   │   │   │   └── ast.py
│   │   │   └── beats
│   │   │       └── beats.py
│   │   ├── general_module.py
│   │   ├── text_encoders
│   │   │   └── bert
│   │   │       └── bert.py
│   │   ├── vast.py
│   │   └── vision_encoders
│   │       ├── clip
│   │       │   ├── clip.py
│   │       │   └── clip_tokenizer.py
│   │       ├── evaclip
│   │       │   ├── __init__.py
│   │       │   ├── bpe_simple_vocab_16e6.txt.gz
│   │       │   ├── constants.py
│   │       │   ├── eva_vit_model.py
│   │       │   ├── factory.py
│   │       │   ├── hf_configs.py
│   │       │   ├── hf_model.py
│   │       │   ├── loss.py
│   │       │   ├── model.py
│   │       │   ├── model_configs
│   │       │   │   ├── EVA01-CLIP-B-16.json
│   │       │   │   ├── EVA01-CLIP-g-14-plus.json
│   │       │   │   ├── EVA01-CLIP-g-14.json
│   │       │   │   ├── EVA02-CLIP-B-16.json
│   │       │   │   ├── EVA02-CLIP-L-14-336.json
│   │       │   │   ├── EVA02-CLIP-L-14.json
│   │       │   │   ├── EVA02-CLIP-bigE-14-plus.json
│   │       │   │   └── EVA02-CLIP-bigE-14.json
│   │       │   ├── modified_resnet.py
│   │       │   ├── openai.py
│   │       │   ├── pretrained.py
│   │       │   ├── rope.py
│   │       │   ├── timm_model.py
│   │       │   ├── tokenizer.py
│   │       │   ├── transform.py
│   │       │   ├── transformer.py
│   │       │   └── utils.py
│   │       ├── swin
│   │       │   ├── swin.py
│   │       │   └── swin_config.py
│   │       └── videoswin
│   │           └── videoswin.py
│   ├── run.py
│   ├── scripts
│   │   ├── run_audio_captioner.sh
│   │   └── run_vision_captioner.sh
│   ├── setup_env.sh
│   └── utils
│       ├── __init__.py
│       ├── args.py
│       ├── build_dataloader.py
│       ├── build_model.py
│       ├── build_optimizer.py
│       ├── distributed.py
│       ├── initialize.py
│       ├── logger.py
│       ├── offline_process_data.py
│       ├── pipeline.py
│       ├── save.py
│       ├── sched.py
│       └── tool.py
├── example
│   ├── test.flac
│   ├── test.jpeg
│   └── test.mp4
├── inference_demo.py
├── model
│   ├── .DS_Store
│   ├── audioprocessor.py
│   ├── bert-base-uncased-crossattn
│   │   ├── config.json
│   │   └── generation_config.json
│   ├── bert.py
│   ├── clip
│   │   ├── clip.py
│   │   └── clip_tokenizer.py
│   ├── evaclip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── constants.py
│   │   ├── eva_vit_model.py
│   │   ├── factory.py
│   │   ├── hf_configs.py
│   │   ├── hf_model.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   ├── model_configs
│   │   │   ├── EVA01-CLIP-B-16.json
│   │   │   ├── EVA01-CLIP-g-14-plus.json
│   │   │   ├── EVA01-CLIP-g-14.json
│   │   │   ├── EVA02-CLIP-B-16.json
│   │   │   ├── EVA02-CLIP-L-14-336.json
│   │   │   ├── EVA02-CLIP-L-14.json
│   │   │   ├── EVA02-CLIP-bigE-14-plus.json
│   │   │   └── EVA02-CLIP-bigE-14.json
│   │   ├── modified_resnet.py
│   │   ├── openai.py
│   │   ├── pretrained.py
│   │   ├── rope.py
│   │   ├── timm_model.py
│   │   ├── tokenizer.py
│   │   ├── transform.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── imageprocessor.py
│   ├── mico.py
│   ├── swin.py
│   ├── swin_base_patch4_window7_224_22k.yaml
│   ├── swin_config.py
│   ├── tokenizer
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.txt
│   ├── transformer.py
│   └── videoprocessor.py
└── set_env.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | /.history
2 | examples/
3 | *.pyc
4 | *.pt
5 | wandb/
6 |
7 | .history
8 | /cococaption
9 |
10 | Mico-g/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | [](https://arxiv.org/abs/2406.09412)
6 | [](https://invictus717.github.io/MiCo/)
7 | [](https://huggingface.co/Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K)
8 |
9 |
10 |
11 |
13 |
14 |
15 |
16 | ### ✨ Inspiration of Multimodal Context: Multimedia Brain Cognition
17 |
18 |
19 |
20 |
21 |
 22 | ***How does the human brain perform coherent multimodal cognition?***
23 |
 24 | As outlined in Richard Mayer's Cognitive Theory of Multimedia Learning, our brain processes multimedia signals through two distinct channels—auditory and visual—in sensory memory, as depicted in Figure (a). The sensory memory integrates these signals with prior knowledge through words, transforming new multimedia information into long-term memory. Notably, **1**) multimedia signals in the brain share channels, and **2**) words function as the reasoning interface in our brain.
25 |
26 | Inspired by these insights, we categorize diverse modalities into two types: ``knowledge modality`` and ``interface modality``. *Knowledge modalities*, primarily derived from raw sensors, contribute knowledge in diverse formats. For example, images and depth maps offer visual knowledge, while audio and video provide auditory and spatiotemporal knowledge. The language modality, developed by humans, is inherently more abstract and naturally functions as the *interface modality*, facilitating learning, reasoning, and the coordination of knowledge. To this end, we design an omni-modal learning architecture, illustrated in Figure (b), with two distinct branches: one for knowledge modalities and one for the interface modality, *i.e.* natural language. The knowledge and interface modalities are aligned through a novel generative reasoning method.
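To make the two-branch design concrete, below is a minimal, illustrative PyTorch sketch (not the actual MiCo implementation, which lives under `data/model/` and `model/`): knowledge modalities are encoded by one shared branch, and the language interface branch generates text conditioned on those knowledge tokens, so a captioning-style objective aligns the two. All sizes and names are arbitrary placeholders, and details such as causal masking are omitted.

```python
import torch
import torch.nn as nn

class TwoBranchSketch(nn.Module):
    """Illustrative only: a shared knowledge branch plus a text interface branch."""
    def __init__(self, dim=256, vocab_size=30522):
        super().__init__()
        # Knowledge branch: one encoder shared by all knowledge modalities
        # (image/video/audio/depth inputs are assumed to be pre-tokenized to width `dim`).
        self.knowledge_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True), num_layers=2)
        # Interface branch: a text decoder that reasons over the knowledge tokens.
        self.text_embed = nn.Embedding(vocab_size, dim)
        self.text_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=dim, nhead=8, batch_first=True), num_layers=2)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, knowledge_tokens, caption_ids, caption_labels):
        memory = self.knowledge_encoder(knowledge_tokens)
        hidden = self.text_decoder(self.text_embed(caption_ids), memory)
        logits = self.lm_head(hidden)
        # Generative (captioning-style) objective aligning the two branches.
        return nn.functional.cross_entropy(logits.flatten(0, 1), caption_labels.flatten())
```

For example, `TwoBranchSketch()(torch.randn(2, 196, 256), torch.randint(0, 30522, (2, 16)), torch.randint(0, 30522, (2, 16)))` returns a scalar loss.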
27 |
28 | ### 🚀 MiCo, An omni-modal and scalable pretraining paradigm
29 |
30 |
31 |
32 |
33 |
34 | We propose collecting large-scale omni-modal paired data, including text,
35 | image, video, depth, and normal maps, to learn universal representations.
36 |
37 |
38 |
39 |
40 |
 41 | **🚀 Evolution of Pretraining Paradigms**. Masked modeling (a) has shown great success in single-modality, general-purpose understanding. Contrastive learning (b) learns discriminative, transferable features from modality tuples (such as text-image, text-video, and text-audio).
42 |
43 | *🚀🚀🚀 We aim to achieve general-purpose omni-modal understanding and learn transferable, universal representations in (c).*
44 |
 45 | ### 🌟🌟🌟 The Multimodal Scaling Laws with MiCo: Modalities Help Modalities!
46 |
47 |
48 |
49 |
50 |
51 | ### 🔓 Pretrained Omni-Modal Models
52 |
 53 | **We will continue to update this model zoo with ViTs of all scales and highly efficient ConvNets pretrained with the MiCo paradigm.**
54 |
55 | Current Checkpoints
56 |
57 |
58 |
 59 | | Model | Pretraining | Scale | Modality | #Param | Google Drive | Hugging Face |
 60 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
 61 | | MiCo | 300k steps | ViT-g | Omni-modal | 1.3B | [ckpt](https://drive.google.com/drive/folders/1AIQjV1KU8K4OXiO-4gFirxkoxt3twWIq?usp=sharing) | [ckpt](https://huggingface.co/Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K) |
62 |
63 |
64 |
65 |
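The Hugging Face checkpoint listed above can also be fetched programmatically. A minimal sketch with `huggingface_hub` (the `local_dir` value is just an example path, not one the repo requires):

```python
from huggingface_hub import snapshot_download

# Download the MiCo ViT-g omni-modal checkpoint listed in the table above.
snapshot_download(
    repo_id="Yiyuan/MiCo-ViT-g-14-omnimodal-300k-b64K",
    local_dir="./MiCo-g",  # example target directory
)
```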
66 | ### 🔓 Omni-Modal Dataset Collection
67 |
 68 | We provide a detailed [doc](data/README.md) for preparing the omni-modal dataset step by step.
69 |
70 | ### ⚡ Quick Start
71 | ```bash
72 | pip install gdown
73 | gdown 1AIQjV1KU8K4OXiO-4gFirxkoxt3twWIq --folder
74 | python inference_demo.py
75 | ```
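Optionally, before launching the demo, the snippet below (not part of the repo) checks that a checkpoint file was actually downloaded and can be deserialized. The folder and file patterns are assumptions about what `gdown` produces and may need adjusting:

```python
import glob
import torch

# Hypothetical sanity check: adjust "MiCo-g" to the folder name gdown created.
ckpts = sorted(glob.glob("MiCo-g/**/*.pt", recursive=True)
               + glob.glob("MiCo-g/**/*.bin", recursive=True))
print("found checkpoints:", ckpts)
state = torch.load(ckpts[0], map_location="cpu")
print("top-level entries:", len(state))
```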
76 | # Citation
77 | If the code and paper help your research, please kindly cite:
78 | ```
79 | @article{zhang2024explore,
80 | title={Explore the Limits of Omni-modal Pretraining at Scale},
81 | author={Zhang, Yiyuan and Li, Handong and Liu, Jing and Yue, Xiangyu},
82 | journal={arXiv preprint arXiv:2406.09412},
83 | year={2024}
84 | }
85 | ```
86 | # License
87 | This project is released under the [Apache 2.0 license](LICENSE).
88 | # Acknowledgement
 89 | We appreciate [Dr. Xiaohan Ding](https://dingxiaohan.xyz/) for the valuable discussions and suggestions. This code is developed based on [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [VAST](https://github.com/TXH-mercury/VAST), [DPT](https://github.com/EPFL-VILAB/omnidata), and [GeoWizard](https://github.com/fuxiao0719/GeoWizard).
--------------------------------------------------------------------------------
/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/banner.png
--------------------------------------------------------------------------------
/assets/brain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/brain.png
--------------------------------------------------------------------------------
/assets/omnimodal_pretraining.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/omnimodal_pretraining.png
--------------------------------------------------------------------------------
/assets/paradigm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/paradigm.png
--------------------------------------------------------------------------------
/assets/scaling_laws.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/assets/scaling_laws.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ### Data Preparation
2 |
3 | #### Install Download Tools
4 | ```bash
5 | pip install video2dataset
6 | ```
7 | or
8 | ```bash
9 | git clone https://github.com/iejMac/video2dataset
10 | cd video2dataset
11 | pip install -e .
12 | ```
13 | #### Download Metadata
14 | ```bash
 15 | wget -O hdvila100m.zip "https://hdvila.blob.core.windows.net/dataset/hdvila100m.zip?sp=r&st=2022-06-28T03:33:11Z&se=2026-01-01T11:33:11Z&spr=https&sv=2021-06-08&sr=b&sig=VaqQkLFDqKinfkaPNs1jJ1EQIYCB%2FUPYiqFqmjWye6Y%3D"
16 | ```
17 | Then unzip the metadata zip file.
18 | ```bash
 19 | unzip hdvila100m.zip
20 | ```
 21 | With the metadata in hand, convert it into a parquet file by running:
22 | ```bash
23 | python makeparquet.py
24 | ```
25 | Once you run this, you should have a file `hd_vila.parquet` with all the relevant metadata. The files are organized as:
26 | ```bash
27 | data
28 | ├── caption_config
29 | ├── model
30 | ├── scripts
31 | ├── utils
32 | ├── makeparquet.py
33 | ├── config.yaml
34 | ├── download_hdvila.sh
35 | ├── hdvila
36 | │ ├── hdvila_part0.jsonl
37 | │ ├── hdvila_part1.jsonl
38 | │ ├── hdvila_part2.jsonl
39 | │ ├── hdvila_part3.jsonl
40 | │ ├── hdvila_part4.jsonl
41 | │ ├── hdvila_part5.jsonl
42 | │ ├── hdvila_part6.jsonl
43 | │ ├── hdvila_part7.jsonl
44 | │ ├── hdvila_part8.jsonl
45 | │ ├── hdvila_part9.jsonl
46 | │ ├── hdvila_part10.jsonl
47 | │ ├── hd_vila.parquet
48 | ```
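To sanity-check the result, the parquet can be inspected with pandas; per `makeparquet.py`, it contains the columns `video_id`, `url`, and `clips` (clip spans converted to seconds):

```python
import pandas as pd

df = pd.read_parquet("hd_vila.parquet")
print(len(df), "videos")
print(df.columns.tolist())  # ['video_id', 'url', 'clips']
print(df.head())
```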
49 | #### Download HDVILA-100M Source Data
 50 | Please check the paths in `download_hdvila.sh` before running the script to download the dataset:
51 | ```bash
52 | bash download_hdvila.sh
53 | ```
54 | #### Annotate Your Videos
55 | 1. Download Pretrained Captioners for Videos (Images) and Audio.
56 | ```bash
57 | pip install gdown
58 | gdown https://drive.google.com/file/d/1vYqb0Lb_3sQ5bo6XV-FQ4n7k_0M9UMU3/view?usp=sharing
59 | tar -xvf audio_captioner.tar.gz
60 | gdown https://drive.google.com/file/d/1ZFCWZ8csMWLYsg9CWt71PJmKYpSn-FMt/view?usp=sharing
61 | tar -xvf vision_captioner.tar.gz
62 | ```
63 | 2. Deploy captioners for data annotation
 64 | Set up the Python environment for the captioners.
65 | ```bash
66 | bash setup_env.sh
67 | ```
68 |
69 | Video Annotation with Captions
70 | ```bash
71 | bash scripts/run_vision_captioner.sh
72 | ```
73 |
74 | Audio Annotation with Captions
75 | ```bash
76 | bash scripts/run_audio_captioner.sh
77 | ```
 78 | 3. (Optional) Deploy a depth estimator to annotate 3D content
 79 | *We highly recommend using [GeoWizard](https://github.com/fuxiao0719/GeoWizard) to generate high-quality 3D content.*
 80 | However, the drawback of *GeoWizard* is the slow inference speed typical of generative models, so in our practice we use [DPT](https://github.com/EPFL-VILAB/omnidata) to annotate the majority of the data; an illustrative depth-annotation sketch is given below.
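The snippet below is an illustrative stand-in rather than the repo's actual annotation script: it runs a DPT depth model via `torch.hub` using the MiDaS weights (instead of the omnidata DPT checkpoints) on a single frame. Paths and file names are placeholders.

```python
import cv2
import torch

# Illustrative depth annotation with a DPT model (MiDaS weights via torch.hub),
# used here as a stand-in for the omnidata DPT checkpoints referenced above.
device = "cuda" if torch.cuda.is_available() else "cpu"
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device).eval()
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

img = cv2.cvtColor(cv2.imread("example_frame.jpg"), cv2.COLOR_BGR2RGB)  # placeholder frame
with torch.no_grad():
    pred = midas(transform(img).to(device))
    depth = torch.nn.functional.interpolate(
        pred.unsqueeze(1), size=img.shape[:2], mode="bicubic", align_corners=False
    ).squeeze().cpu()
depth = (255 * depth / depth.max()).numpy().astype("uint8")
cv2.imwrite("example_frame_depth.png", depth)
```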
--------------------------------------------------------------------------------
/data/caption_config/caption-generation-audio.json:
--------------------------------------------------------------------------------
1 | {"run_cfg":
2 | {"default":"default_run_cfg.json",
3 | "mode":"testing"},
4 |
5 | "model_cfg":
6 | {"default":"default_model_cfg.json"},
7 |
8 | "data_cfg":
9 |
10 | {"train":{},
11 |
12 | "val":
13 | [
14 | {
15 | "type":"annoindexed",
16 | "training":false,
17 | "name": "yourdata",
18 | "txt": "datasets/annotations/yourdata/meta.json",
19 | "audio": "datasets/srcdata/yourdata/audios",
20 | "audio_sample_num": 3,
21 | "task" : "cap%ta",
22 | "n_workers": 8,
23 | "batch_size": 64 }
24 | ]}}
25 |
26 |
27 |
--------------------------------------------------------------------------------
/data/caption_config/caption-generation-vision.json:
--------------------------------------------------------------------------------
1 |
2 | {"run_cfg":
3 | {"default":"default_run_cfg.json",
4 | "mode":"testing"},
5 |
6 | "model_cfg":
7 | {"default":"default_model_cfg.json"},
8 |
9 | "data_cfg":
10 |
11 | {"train":{},
12 |
13 | "val":
14 | [{
15 | "type":"annoindexed",
16 | "training":false,
17 | "name": "yourdata",
18 | "txt": "datasets/annotations/yourdata/meta.json",
19 | "vision": "datasets/srcdata/yourdata/videos",
20 | "vision_format": "video_rawvideo",
21 | "vision_sample_num": 8,
22 | "task" : "cap%tv",
23 | "n_workers": 8,
24 | "batch_size": 64
25 | }]}}
--------------------------------------------------------------------------------
/data/caption_config/default_model_cfg.json:
--------------------------------------------------------------------------------
1 | {"model_type": "vast",
2 | "itm_ratio":0.1,
3 | "frozen_vision":false,
4 | "frozen_audio":false,
5 | "checkpointing":false,
6 | "max_caption_len":40,
7 | "max_omni_caption_len":70,
8 | "max_subtitle_len":70,
9 | "contra_dim":512,
10 | "inherit_keys":["vision_encoder_type","audio_encoder_type","audio_melbins","audio_target_length"],
11 | "frame_embedding_type":"adaptive",
12 | "vision_resolution":224,
13 | "vision_encoder_type":"evaclip01_giant",
14 | "audio_encoder_type":"beats",
15 | "audio_melbins":64,
16 | "audio_target_length": 1024,
17 | "beam_size":3,
18 | "captioner_mode":false,
19 | "generate_nums":1,
20 | "ret_bidirection_evaluation":false,
21 | "itm_rerank_num":50,
22 | "evaluation_type":"evaluation_mm"}
23 |
24 |
--------------------------------------------------------------------------------
/data/caption_config/default_run_cfg.json:
--------------------------------------------------------------------------------
1 | {"checkpoint":"",
2 | "output_dir":"none",
3 | "gradient_accumulation_steps":1,
4 | "clip_lr":5e-7,
5 | "optim":"adamw",
6 | "learning_rate":1e-4,
7 | "betas":[0.9, 0.98],
8 | "weight_decay":0.01,
9 | "grad_norm":2.0,
10 | "warmup_ratio":0.1,
11 | "resume":false,
12 | "seed":50,
13 | "fp16":true,
14 | "bf16":false,
15 | "zero_shot":false,
16 | "scheduler":"warmup_linear",
17 | "new_lr":0,
18 | "new_params_name":[],
19 | "valid_freq":10,
20 | "dataset_mix_type":"random",
21 | "remove_before_ckpt":true,
22 | "first_eval":true,
23 | "pretrain_dir":"",
24 | "num_train_steps":0,
25 | "save_best":false,
26 | "pin_mem":true,
27 | "vision_resolution":224,
28 | "use_ddp":true,
29 | "mode":"training",
30 | "log_steps":100
31 | }
--------------------------------------------------------------------------------
/data/config.yaml:
--------------------------------------------------------------------------------
1 | subsampling:
2 | CutDetectionSubsampler:
3 | args:
4 | cut_detection_mode: "all"
5 | framerates: null
6 | threshold: 11.5
7 | min_scene_len: 15
8 | reading:
9 | yt_args:
10 | download_size: 360
11 | download_audio_rate: 44100
12 | yt_metadata_args:
13 | writesubtitles: 'all'
14 | subtitleslangs: ['en']
15 | writeautomaticsub: True
16 | get_info: True
17 | timeout: 180
18 | sampler: null
19 |
20 | storage:
21 | number_sample_per_shard: 100
22 | captions_are_subtitles: False
23 | oom_shard_count: 5
24 |
25 | distribution:
26 | processes_count: 2
27 | thread_count: 8
28 | subjob_size: 1000
29 | distributor: "multiprocessing"
--------------------------------------------------------------------------------
/data/data/IndexAnno.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import random
5 | import torch
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from toolz.sandbox import unzip
9 | from torch.utils.data import Dataset
10 | from utils.logger import LOGGER
11 | from .vision_mapper import VisionMapper
12 | from .audio_mapper import AudioMapper
13 |
14 | from torch.utils.data import ConcatDataset
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | class AnnoIndexedDataset(Dataset):
24 | def __init__(self, d_cfg, args):
25 | self.vision_mapper = VisionMapper(d_cfg, args) if 'vision' in d_cfg else None
26 | self.audio_mapper = AudioMapper(d_cfg, args) if 'audio' in d_cfg else None
27 | self.annos = json.load(open(d_cfg['txt']))
28 | self.idx = list(range(len(self.annos)))
29 | self.dataset_name = d_cfg['name']
30 | self.training = d_cfg.training
31 |
32 | self.worker_init_fn = None
33 | self.use_sampler = True
34 | self.collate_fn = annoindexedcollate
35 |
36 | self.annfile = getattr(d_cfg,'annfile',None)
37 | self.make_submission = getattr(d_cfg,'make_submission',False)
38 | self.multi_evaluation = getattr(d_cfg,'multi_evaluation',False)
39 | self.vqa_anno_file = getattr(d_cfg,'vqa_anno_file',None)
40 | self.vqa_question_file = getattr(d_cfg,'vqa_question_file',None)
41 |
42 |
43 | def __len__(self):
44 | return len(self.annos)
45 |
46 | def __getitem__(self, i):
47 | anno = self.annos[i]
48 |
49 | for key in ['video_id','image_id','image','id']:
50 | if key in anno:
51 | id_ = anno[key]
52 | break
53 |
54 | raw_captions = None
55 | raw_subtitles = None
56 | question_id = None
57 | question = None
58 | answer = None
59 | id_txt = None
60 | vision_pixels = None
61 | audio_spectrograms = None
62 |
63 |
64 |
65 |
66 |
67 |
68 | raw_captions = anno['desc'] if 'desc' in anno else anno['caption']
69 | num_samples = len(raw_captions) if isinstance(raw_captions, list) else 1
70 | id_txt = [id_] * num_samples
71 |
72 |
73 | if 'subtitle' in anno:
74 | raw_subtitles = anno['subtitle']
75 |
76 | if 'question' in anno:
77 |
78 | if self.training:
79 | question = anno['question']
80 | if isinstance(anno['answer'],list): #### vqav2
81 | answer = random.choice(anno['answer'])
82 | else:
83 | answer = anno['answer']
84 |
85 | else:
86 | question = anno['question']
87 | answer = anno['answer']
88 | if 'question_id' in anno:
89 | question_id = anno['question_id']
90 |
91 |
92 | if self.vision_mapper:
93 | if self.vision_mapper.vision_format == 'video_feats':
94 | vision_feats = self.vision_mapper.read(id_)
95 |
96 | else:
97 | vision_pixels = self.vision_mapper.read(id_)
98 | if vision_pixels is None: ###wrong img/video, resample when training and raise error when testing
99 | if self.training:
100 | resample_idx = random.choice(self.idx)
101 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video, use {resample_idx} instead.')
102 | return self.__getitem__(resample_idx)
103 | else:
104 | resample_idx = random.choice(self.idx)
105 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong image/video,!!!!!!!!!!!!!!!!!!!!!!!! use {resample_idx} instead.')
106 | return self.__getitem__(resample_idx)
107 | # raise ValueError
108 |
109 | if self.audio_mapper:
110 | audio_spectrograms = self.audio_mapper.read(id_)
111 | if audio_spectrograms is None: ### wrong audio, resample when training and raise error when testing
112 | if self.training:
113 | resample_idx = random.choice(self.idx)
114 | LOGGER.info(f'current idx {id_} from {self.dataset_name} returns wrong audio, use {resample_idx} instead.')
115 | return self.__getitem__(resample_idx)
116 | else:
117 | raise ValueError
118 |
119 | return id_, raw_captions, vision_pixels, id_txt, question, answer, question_id, \
120 | audio_spectrograms, raw_subtitles
121 |
122 |
123 |
124 | def annoindexedcollate(inputs):
125 |
126 | batch = {}
127 | all_data = map(list, unzip(inputs))
128 | keys = ['ids',
129 | 'raw_captions',
130 | 'vision_pixels',
131 | 'ids_txt',
132 | 'raw_questions',
133 | 'raw_answers',
134 | 'question_ids',
135 | 'audio_spectrograms',
136 | 'raw_subtitles']
137 |
138 | for key, data in zip(keys, all_data):
139 |
140 | if data[0] is None:
141 | continue
142 | elif isinstance(data[0], torch.Tensor):
143 | batch[key] = torch.stack(data, dim=0).float()
144 |
145 | else:
146 | batch[key] = data
147 |
148 |
149 |
150 | return batch
151 |
152 |
153 |
--------------------------------------------------------------------------------
/data/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .IndexAnno import AnnoIndexedDataset
3 | from .IndexSrc import SrcIndexedDataset
4 |
5 | data_registry={
6 | 'annoindexed':AnnoIndexedDataset,
7 | 'srcindexed':SrcIndexedDataset,
8 |
9 | }
10 |
--------------------------------------------------------------------------------
/data/data/audio_mapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torchaudio
5 | from utils.logger import LOGGER
6 | from utils.tool import split
7 |
8 |
9 | class AudioMapper(object):
10 | # def __init__(self, audio_dir, opts, sample_num, check_exists=True):
11 | def __init__(self, d_cfg, args):
12 | self.audio_dir = d_cfg.audio
13 | self.melbins = args.model_cfg.audio_melbins
14 | self.target_length = args.model_cfg.audio_target_length
15 | self.training = d_cfg.training
16 | self.frame_shift = 10
17 | self.sample_num = d_cfg.audio_sample_num
18 | self.audio_encoder_type = args.model_cfg.audio_encoder_type
19 | if self.audio_encoder_type == 'ast':
20 | self.mean = -4.2677393
21 | self.std = 4.5689974
22 | elif self.audio_encoder_type == 'beats':
23 | self.mean = 15.41663
24 | self.std = 6.55582
25 | else:
26 | raise NotImplementedError
27 |
28 |
29 |
30 | def read(self, id_):
31 |
32 | wav_file = os.path.join(self.audio_dir, id_)
33 |
34 | if not os.path.exists(wav_file):
35 | wav_file = os.path.join(self.audio_dir, id_+'.wav')
36 | if not os.path.exists(wav_file):
37 | wav_file = wav_file.replace('wav','mp3')
38 | if not os.path.exists(wav_file):
39 | wav_file = wav_file.replace('mp3','mkv')
40 | if not os.path.exists(wav_file):
41 | print('not have audios', id_)
42 | return torch.zeros(self.sample_num, self.target_length, self.melbins)
43 | try:
44 | if self.audio_encoder_type == 'ast':
45 |
46 | waveform, sr = torchaudio.load(wav_file)
47 |
48 | waveform = waveform - waveform.mean()
49 | fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
50 | window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=self.frame_shift)
51 |
52 |
53 |
54 | elif self.audio_encoder_type == 'beats':
55 |
56 | waveform, sr = torchaudio.load(wav_file)
57 | if sr != 16000:
58 | trans = torchaudio.transforms.Resample(sr, 16000)
59 | waveform = trans(waveform)
60 |
61 | waveform = waveform * 2 ** 15
62 | fbank = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=self.melbins, sample_frequency=16000, frame_length=25, frame_shift=10)
63 |
64 | else:
65 | raise NotImplementedError
66 |
67 | # ### normalization
68 | fbank = (fbank - self.mean) / (self.std * 2)
69 | src_length = fbank.shape[0]
70 | # #### sample
71 | output_slices = []
72 | pad_len = max(self.target_length * self.sample_num -src_length, self.target_length - src_length%self.target_length)
73 | fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
74 | total_slice_num = fbank.shape[0] // self.target_length
75 | total_slice_num = list(range(total_slice_num))
76 | total_slice_num = split(total_slice_num, self.sample_num)
77 |
78 | if self.training:
79 | sample_idx = [random.choice(i) for i in total_slice_num]
80 | else:
81 | sample_idx = [i[(len(i)+1)//2-1] for i in total_slice_num]
82 |
83 |
84 | for i in sample_idx:
85 | cur_bank = fbank[i*self.target_length : (i+1)*self.target_length]
86 | output_slices.append(cur_bank)
87 |
 88 |             fbank = torch.stack(output_slices, dim=0)  ### (sample_num, target_length, melbins)
89 | return fbank
90 |
91 | except Exception as e:
92 | print(e)
93 | return
94 |
95 |
--------------------------------------------------------------------------------
/data/data/loader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from utils.distributed import any_broadcast
5 | from torch.utils.data.distributed import DistributedSampler
6 |
7 |
8 | class MetaLoader(object):
9 | """ wraps multiple data loaders """
10 | def __init__(self, loaders, accum_steps=1, distributed=False):
11 | assert isinstance(loaders, dict)
12 | self.name2loader = {}
13 | self.name2iter = {}
14 | self.sampling_pools = []
15 | self.name2labelname={}
16 | for idx, (n, l) in enumerate(loaders.items()):
17 | if isinstance(l, tuple):
18 | l, r = l
19 | elif isinstance(l, DataLoader):
20 | r = 1
21 | else:
22 | raise ValueError()
23 | self.name2loader[n] = l
24 | self.name2iter[n] = iter(l)
25 | self.sampling_pools.extend([n]*r)
26 | # import ipdb
27 | # ipdb.set_trace()
28 |
29 |
30 | self.accum_steps = accum_steps
31 | self.distributed = distributed
32 | self.step = 0
33 | self.epoch = 0
34 |
35 |
36 | def __iter__(self):
37 | """ this iterator will run indefinitely """
38 | task = self.sampling_pools[0]
39 | while True:
40 | if self.step % self.accum_steps == 0:
41 | task = random.choice(self.sampling_pools)
42 | if self.distributed:
43 | # make sure all process is training same task
44 | task = any_broadcast(task, 0)
45 | self.step += 1
46 | iter_ = self.name2iter[task]
47 | try:
48 | batch = next(iter_)
49 | except StopIteration:
50 | self.epoch = self.epoch + 1
51 | if isinstance(self.name2loader[task].sampler, DistributedSampler):
52 | self.name2loader[task].sampler.set_epoch(self.epoch)
53 | else:
54 | pass
55 | iter_ = iter(self.name2loader[task])
56 | batch = next(iter_)
57 | self.name2iter[task] = iter_
58 |
59 |
60 | yield task, batch
61 |
62 |
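# Example usage of MetaLoader (illustrative; the loader names below are hypothetical):
#
#   loaders = {
#       "cap": (caption_loader, 2),   # (DataLoader, ratio): sampled twice as often
#       "vqa": vqa_loader,            # a bare DataLoader defaults to ratio 1
#   }
#   meta_loader = MetaLoader(loaders, accum_steps=4, distributed=True)
#   for task, batch in meta_loader:   # iterates indefinitely; stop on your own step budget
#       ...
#
# With accum_steps=4 the sampled task is kept for 4 consecutive steps, so gradient
# accumulation always mixes batches from the same dataset.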
63 | def move_to_cuda(batch):
64 | if isinstance(batch, torch.Tensor):
65 | return batch.cuda(non_blocking=True)
66 | elif isinstance(batch, list):
67 | new_batch = [move_to_cuda(t) for t in batch]
68 | elif isinstance(batch, tuple):
69 | new_batch = tuple(move_to_cuda(t) for t in batch)
70 | elif isinstance(batch, dict):
71 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
72 | else:
73 | return batch
74 | return new_batch
75 |
76 |
77 | def record_cuda_stream(batch):
78 | if isinstance(batch, torch.Tensor):
79 | batch.record_stream(torch.cuda.current_stream())
80 | elif isinstance(batch, list) or isinstance(batch, tuple):
81 | for t in batch:
82 | record_cuda_stream(t)
83 | elif isinstance(batch, dict):
84 | for t in batch.values():
85 | record_cuda_stream(t)
86 | else:
87 | pass
88 |
89 |
90 | class PrefetchLoader(object):
91 | """
92 | overlap compute and cuda data transfer
93 | (copied and then modified from nvidia apex)
94 | """
95 | def __init__(self, loader):
96 | self.loader = loader
97 | self.stream = torch.cuda.Stream()
98 |
99 | def __iter__(self):
100 | loader_it = iter(self.loader)
101 | self.preload(loader_it)
102 | batch = self.next(loader_it)
103 | while batch is not None:
104 | yield batch
105 | batch = self.next(loader_it)
106 |
107 | def __len__(self):
108 | return len(self.loader)
109 |
110 | def preload(self, it):
111 | try:
112 | self.batch = next(it)
113 | except StopIteration:
114 | self.batch = None
115 | return
116 | # if record_stream() doesn't work, another option is to make sure
117 | # device inputs are created on the main stream.
118 | # self.next_input_gpu = torch.empty_like(self.next_input,
119 | # device='cuda')
120 | # self.next_target_gpu = torch.empty_like(self.next_target,
121 | # device='cuda')
122 | # Need to make sure the memory allocated for next_* is not still in use
123 | # by the main stream at the time we start copying to next_*:
124 | # self.stream.wait_stream(torch.cuda.current_stream())
125 | with torch.cuda.stream(self.stream):
126 | self.batch = move_to_cuda(self.batch)
127 | # more code for the alternative if record_stream() doesn't work:
128 | # copy_ will record the use of the pinned source tensor in this
129 | # side stream.
130 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
131 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
132 | # self.next_input = self.next_input_gpu
133 | # self.next_target = self.next_target_gpu
134 |
135 | def next(self, it):
136 |
137 | torch.cuda.current_stream().wait_stream(self.stream)
138 | batch = self.batch
139 | if batch is not None:
140 | record_cuda_stream(batch)
141 | self.preload(it)
142 | return batch
143 |
144 |
145 |
146 | def __getattr__(self, name):
147 | method = self.loader.__getattribute__(name)
148 | return method
149 |
--------------------------------------------------------------------------------
/data/download_hdvila.sh:
--------------------------------------------------------------------------------
1 | video2dataset \
2 | --url_list='hd_vila.parquet' \
3 | --input_format='parquet' \
4 | --output_format='files' \
5 | --output_folder="./hdvila" \
6 | --url_col="url" \
7 | --enable_wandb=False \
8 | --encode_formats="{'video': 'mp4', 'audio': 'mp3'}" \
9 | --config="config.yaml" \
10 | --max_shard_retry=3
--------------------------------------------------------------------------------
/data/makeparquet.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import glob
3 | import json
4 | import os
5 | import time
6 | from datetime import datetime
7 |
8 | def time_string_to_seconds(timestamp):
9 | hh,mm,s = timestamp.split(':')
10 | ss,ms = s.split('.')
11 | time = 3600*int(hh) + 60*int(mm) + int(ss) + int(ms)/1000
12 | return time
13 |
14 | def convert_clip_list(clip_list):
15 | return [[time_string_to_seconds(x) for x in clip] for clip in clip_list]
16 |
17 | ###### Change your path
18 | parquet_dir = "/path/to/my/metadata/dir/"
19 |
20 | data = []
 21 | for jsonl in sorted(glob.glob(os.path.join(parquet_dir, "*.jsonl"))):
 22 |     path = jsonl  # glob already returns the full path; no need to re-join with parquet_dir
23 | with open(path, "r") as f:
24 | for line in f:
25 | json_obj = json.loads(line)
26 | clips = [
27 | json_obj['clip'][i]['span']
28 | for i in range(len(json_obj['clip']))
29 | ]
30 |
31 | out = {
32 | 'video_id': json_obj['video_id'],
33 | 'url': json_obj['url'],
34 | 'clips': clips
35 | }
36 | data.append(out)
37 |
38 | df = pd.DataFrame(data)
39 | df['clips'] = df['clips'].map(lambda x: convert_clip_list(x))
40 | df.to_parquet("hd_vila.parquet")
--------------------------------------------------------------------------------
/data/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .vast import VAST
2 | model_registry = {
3 | 'vast':VAST
4 | }
--------------------------------------------------------------------------------
/data/model/audio_encoders/ast/ast.py:
--------------------------------------------------------------------------------
1 | """
2 | BERT layers from the huggingface implementation
3 | (https://github.com/huggingface/transformers)
4 | """
5 | # coding=utf-8
6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | import logging
21 | import math
22 | import copy
23 | import torch
24 | from torch import nn
25 | from torch.nn import LayerNorm as LayerNorm
26 | import torch.nn.functional as F
27 | import ipdb
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | def gelu(x):
33 | """Implementation of the gelu activation function.
34 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
35 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
36 | Also see https://arxiv.org/abs/1606.08415
37 | """
38 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
39 |
40 |
41 | def swish(x):
42 | return x * torch.sigmoid(x)
43 |
44 |
45 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
46 |
47 |
48 | class GELU(nn.Module):
49 | def forward(self, input_):
50 | output = gelu(input_)
51 | return output
52 |
53 |
54 |
55 |
56 | class TransformerLayer(nn.Module):
57 | def __init__(self, config, mode):
58 | super().__init__()
59 | self.attention = MultiHeadAttention(config)
60 | self.ff_layer = FeedForward(config)
61 | self.dropout = nn.Dropout(config.hidden_dropout)
62 | self.layernorm1 = LayerNorm(config.hidden_size, eps=1e-12)
63 | self.layernorm2 = LayerNorm(config.hidden_size, eps=1e-12)
64 | self.mode = mode
65 |
66 | def forward(self, hidden_states, attention_mask):
67 | if self.mode == 'prenorm':
68 | return self.forward_prenorm(hidden_states, attention_mask)
69 | elif self.mode == 'postnorm':
70 | return self.forward_postnorm(hidden_states, attention_mask)
71 | else:
72 | raise NotImplementedError
73 |
74 | def forward_prenorm(self, hidden_states, attention_mask):
75 | residual = hidden_states
76 | hidden_states = self.layernorm1(hidden_states)
77 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask)
78 | hidden_states = residual + self.dropout(attention_output)
79 |
80 | residual = hidden_states
81 | hidden_states = self.layernorm2(hidden_states)
82 | ff_output = self.ff_layer(hidden_states)
83 | hidden_states = residual + self.dropout(ff_output)
84 |
85 | return hidden_states
86 |
87 | def forward_postnorm(self, hidden_states, attention_mask):
88 | residual = hidden_states
89 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask)
90 | hidden_states = residual + self.dropout(attention_output)
91 | hidden_states = self.layernorm1(hidden_states)
92 |
93 | residual = hidden_states
94 | ff_output = self.ff_layer(hidden_states)
95 | hidden_states = residual + self.dropout(ff_output)
96 | hidden_states = self.layernorm2(hidden_states)
97 |
98 | return hidden_states
99 |
100 |
101 | def clones(x,times):
102 | return nn.ModuleList([copy.deepcopy(x) for i in range(times)])
103 |
104 |
105 |
106 | class MultiHeadAttention(nn.Module):
107 | def __init__(self, config):
108 | super().__init__()
109 | self.linears = clones(nn.Linear(config.hidden_size, config.hidden_size), 4)
110 | self.head_num = config.num_attention_heads
111 | self.hidden_size = config.hidden_size
112 | self.dropout=nn.Dropout(config.attention_dropout)
113 |
114 |
115 | def forward(self,q,k,v,mask=None):
116 | batch_size=q.shape[0]
117 | q,k,v=[layer(x).view(batch_size,-1,self.head_num, self.hidden_size//self.head_num).transpose(1,2) \
118 | for layer,x in zip(self.linears,(q,k,v))]
119 | norm_d=q.shape[-1]
120 | att_map=torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(norm_d)
121 | if mask is not None:
122 | att_map=att_map + mask
123 | att_map=F.softmax(att_map,dim=-1)
124 | # import ipdb
125 | # if att_map.shape[-1] == 45:
126 | # ipdb.set_trace()
127 |
128 | att_map=self.dropout(att_map)
129 | attn_output = self.linears[-1](torch.matmul(att_map,v).transpose(1,2).contiguous().view(batch_size,-1,self.hidden_size))
130 | return attn_output
131 |
132 |
133 | class FeedForward(nn.Module):
134 | def __init__(self, config):
135 | super().__init__()
136 | self.linear1=nn.Linear(config.hidden_size, config.intermediate_size)
137 | self.linear2=nn.Linear(config.intermediate_size, config.hidden_size)
138 | self.activation = GELU()
139 |
140 |
141 | def forward(self,x):
142 | return self.linear2((self.activation(self.linear1(x))))
143 |
144 |
145 |
146 | class TransformerEncoder(nn.Module):
147 | def __init__(self, config, mode = 'prenorm'):
148 | super().__init__()
149 | layer = TransformerLayer(config, mode)
150 | self.mode = mode
151 | self.layer = nn.ModuleList([copy.deepcopy(layer)
152 | for _ in range(config.num_hidden_layers)])
153 | if self.mode == 'prenorm':
154 | self.last_layernorm = LayerNorm(config.hidden_size, eps=1e-12)
155 | self.checkpointing = config.checkpointing
156 | def forward(self, input_, attention_mask=None, cross_hidden_states=None,
157 | use_cache=False,
158 | cache=None,
159 | cache_first=False,
160 | cache_type=None):
161 | hidden_states = input_
162 | for layer_module in self.layer:
163 | if self.checkpointing:
164 | hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask)
165 | else:
166 | hidden_states = layer_module(hidden_states, attention_mask)
167 |
168 | if self.mode == 'prenorm':
169 | hidden_states = self.last_layernorm(hidden_states)
170 | return hidden_states, cache
171 |
172 |
173 |
174 |
175 | class AudioEmbeddings(nn.Module):
176 | def __init__(self, model_cfg_audio):
177 | super().__init__()
178 |
179 | self.patch_size = 16
180 |
181 |
182 | self.token_length_per_frame = (model_cfg_audio.audio_melbins // self.patch_size) * (model_cfg_audio.audio_target_length // self.patch_size)
183 | self.first_conv = nn.Conv2d(1, model_cfg_audio.hidden_size, kernel_size = self.patch_size,
184 | stride = self.patch_size, padding=0)
185 | self.position_embeddings = nn.Embedding(self.token_length_per_frame + 1, model_cfg_audio.hidden_size)
186 | self.dropout = nn.Dropout(model_cfg_audio.hidden_dropout)
187 | self.cls_token = nn.Parameter(0.02 * torch.randn(1, 1, model_cfg_audio.hidden_size))
188 |
189 | def forward(self, audio_spectrograms): ### shape Bxn_sample_128x512
190 |
191 | audio_spectrograms = self.first_conv(audio_spectrograms.unsqueeze(1))
192 | b,c,_,_=audio_spectrograms.shape
193 | audio_tokens = audio_spectrograms.permute(0,2,3,1).reshape(b,-1,c)
194 | cls_token = self.cls_token.expand(b,-1,-1)
195 | audio_tokens = torch.cat((cls_token,audio_tokens),dim=1)
196 | audio_pos_ids = list(range(self.token_length_per_frame + 1))
197 | audio_pos_ids = torch.tensor(audio_pos_ids, dtype=torch.long, device=audio_spectrograms.device).unsqueeze(0)
198 | position_embeddings = self.position_embeddings(audio_pos_ids)
199 | embeddings = audio_tokens + position_embeddings
200 | embeddings = self.dropout(embeddings)
201 | return embeddings
202 |
--------------------------------------------------------------------------------
/data/model/vision_encoders/clip/clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 | import ftfy
8 | import regex as re
9 | import torch
10 |
11 |
12 | @lru_cache()
13 | def default_bpe():
14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
15 |
16 |
17 | @lru_cache()
18 | def bytes_to_unicode():
19 | """
20 | Returns list of utf-8 byte and a corresponding list of unicode strings.
21 | The reversible bpe codes work on unicode strings.
22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
24 | This is a signficant percentage of your normal, say, 32K bpe vocab.
25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
26 | And avoids mapping to whitespace/control characters the bpe code barfs on.
27 | """
28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
29 | cs = bs[:]
30 | n = 0
31 | for b in range(2**8):
32 | if b not in bs:
33 | bs.append(b)
34 | cs.append(2**8+n)
35 | n += 1
36 | cs = [chr(n) for n in cs]
37 | return dict(zip(bs, cs))
38 |
39 |
40 | def get_pairs(word):
41 | """Return set of symbol pairs in a word.
42 | Word is represented as tuple of symbols (symbols being variable-length strings).
43 | """
44 | pairs = set()
45 | prev_char = word[0]
46 | for char in word[1:]:
47 | pairs.add((prev_char, char))
48 | prev_char = char
49 | return pairs
50 |
51 |
52 | def basic_clean(text):
53 | text = ftfy.fix_text(text)
54 | text = html.unescape(html.unescape(text))
55 | return text.strip()
56 |
57 |
58 | def whitespace_clean(text):
59 | text = re.sub(r'\s+', ' ', text)
60 | text = text.strip()
61 | return text
62 |
63 |
64 | class SimpleTokenizer(object):
65 | def __init__(self, bpe_path: str = default_bpe()):
66 | self.byte_encoder = bytes_to_unicode()
67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
69 | merges = merges[1:49152-256-2+1]
70 | merges = [tuple(merge.split()) for merge in merges]
71 | vocab = list(bytes_to_unicode().values())
 72 |         vocab = vocab + [v+'</w>' for v in vocab]
73 | for merge in merges:
74 | vocab.append(''.join(merge))
75 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
76 | self.encoder = dict(zip(vocab, range(len(vocab))))
77 | self.decoder = {v: k for k, v in self.encoder.items()}
78 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
 85 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
 89 |             return token+'</w>'
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | try:
100 | j = word.index(first, i)
101 | new_word.extend(word[i:j])
102 | i = j
103 | except:
104 | new_word.extend(word[i:])
105 | break
106 |
107 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
108 | new_word.append(first+second)
109 | i += 2
110 | else:
111 | new_word.append(word[i])
112 | i += 1
113 | new_word = tuple(new_word)
114 | word = new_word
115 | if len(word) == 1:
116 | break
117 | else:
118 | pairs = get_pairs(word)
119 | word = ' '.join(word)
120 | self.cache[token] = word
121 | return word
122 |
123 | def encode(self, text):
124 | bpe_tokens = []
125 | text = whitespace_clean(basic_clean(text)).lower()
126 | for token in re.findall(self.pat, text):
127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129 | return bpe_tokens
130 |
131 | def decode(self, tokens):
132 | text = ''.join([self.decoder[token] for token in tokens])
 133 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 | return text
135 |
136 |
137 |
138 | _tokenizer = SimpleTokenizer()
139 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
140 | """
141 | Returns the tokenized representation of given input string(s)
142 |
143 | Parameters
144 | ----------
145 | texts : Union[str, List[str]]
146 | An input string or a list of input strings to tokenize
147 |
148 | context_length : int
149 | The context length to use; all CLIP models use 77 as the context length
150 |
151 | truncate: bool
152 | Whether to truncate the text in case its encoding is longer than the context length
153 |
154 | Returns
155 | -------
156 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
157 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
158 | """
159 | if isinstance(texts, str):
160 | texts = [texts]
161 |
162 | sot_token = _tokenizer.encoder["<|startoftext|>"]
163 | eot_token = _tokenizer.encoder["<|endoftext|>"]
164 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
165 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
166 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
167 | else:
168 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
169 |
170 | for i, tokens in enumerate(all_tokens):
171 | if len(tokens) > context_length:
172 | if truncate:
173 | tokens = tokens[:context_length]
174 | tokens[-1] = eot_token
175 | else:
176 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
177 | result[i, :len(tokens)] = torch.tensor(tokens)
178 |
179 | return result
180 |
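# Example usage (illustrative, mirroring the docstring above):
#
#   tokens = tokenize(["a photo of a cat", "a photo of a dog"], truncate=True)
#   tokens.shape  # torch.Size([2, 77])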
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | # from .loss import ClipLoss
5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7 | from .openai import load_openai_model, list_openai_models
8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10 | from .tokenizer import SimpleTokenizer, tokenize
11 | from .transform import image_transform
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/data/model/vision_encoders/evaclip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | "bert": {
46 | "config_names": {
47 | "context_length": "max_position_embeddings",
48 | "vocab_size": "vocab_size",
49 | "width": "hidden_size",
50 | "heads": "num_attention_heads",
51 | "layers": "num_hidden_layers",
52 | "layer_attr": "layer",
53 | "token_embeddings_attr": "embeddings"
54 | },
55 | "pooler": "mean_pooler",
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/loss.py:
--------------------------------------------------------------------------------
1 | # import math
2 | # import torch
3 | # import torch.nn as nn
4 | # from torch.nn import functional as F
5 |
6 | # try:
7 | # import torch.distributed.nn
8 | # from torch import distributed as dist
9 | # has_distributed = True
10 | # except ImportError:
11 | # has_distributed = False
12 |
13 | # try:
14 | # import horovod.torch as hvd
15 | # except ImportError:
16 | # hvd = None
17 |
18 | # from timm.loss import LabelSmoothingCrossEntropy
19 |
20 |
21 | # def gather_features(
22 | # image_features,
23 | # text_features,
24 | # local_loss=False,
25 | # gather_with_grad=False,
26 | # rank=0,
27 | # world_size=1,
28 | # use_horovod=False
29 | # ):
30 | # assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31 | # if use_horovod:
32 | # assert hvd is not None, 'Please install horovod'
33 | # if gather_with_grad:
34 | # all_image_features = hvd.allgather(image_features)
35 | # all_text_features = hvd.allgather(text_features)
36 | # else:
37 | # with torch.no_grad():
38 | # all_image_features = hvd.allgather(image_features)
39 | # all_text_features = hvd.allgather(text_features)
40 | # if not local_loss:
41 | # # ensure grads for local rank when all_* features don't have a gradient
42 | # gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43 | # gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44 | # gathered_image_features[rank] = image_features
45 | # gathered_text_features[rank] = text_features
46 | # all_image_features = torch.cat(gathered_image_features, dim=0)
47 | # all_text_features = torch.cat(gathered_text_features, dim=0)
48 | # else:
49 | # # We gather tensors from all gpus
50 | # if gather_with_grad:
51 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53 | # # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54 | # # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55 | # else:
56 | # gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57 | # gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58 | # dist.all_gather(gathered_image_features, image_features)
59 | # dist.all_gather(gathered_text_features, text_features)
60 | # if not local_loss:
61 | # # ensure grads for local rank when all_* features don't have a gradient
62 | # gathered_image_features[rank] = image_features
63 | # gathered_text_features[rank] = text_features
64 | # all_image_features = torch.cat(gathered_image_features, dim=0)
65 | # all_text_features = torch.cat(gathered_text_features, dim=0)
66 |
67 | # return all_image_features, all_text_features
68 |
69 |
70 | # class ClipLoss(nn.Module):
71 |
72 | # def __init__(
73 | # self,
74 | # local_loss=False,
75 | # gather_with_grad=False,
76 | # cache_labels=False,
77 | # rank=0,
78 | # world_size=1,
79 | # use_horovod=False,
80 | # smoothing=0.,
81 | # ):
82 | # super().__init__()
83 | # self.local_loss = local_loss
84 | # self.gather_with_grad = gather_with_grad
85 | # self.cache_labels = cache_labels
86 | # self.rank = rank
87 | # self.world_size = world_size
88 | # self.use_horovod = use_horovod
89 | # self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90 |
91 | # # cache state
92 | # self.prev_num_logits = 0
93 | # self.labels = {}
94 |
95 | # def forward(self, image_features, text_features, logit_scale=1.):
96 | # device = image_features.device
97 | # if self.world_size > 1:
98 | # all_image_features, all_text_features = gather_features(
99 | # image_features, text_features,
100 | # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101 |
102 | # if self.local_loss:
103 | # logits_per_image = logit_scale * image_features @ all_text_features.T
104 | # logits_per_text = logit_scale * text_features @ all_image_features.T
105 | # else:
106 | # logits_per_image = logit_scale * all_image_features @ all_text_features.T
107 | # logits_per_text = logits_per_image.T
108 | # else:
109 | # logits_per_image = logit_scale * image_features @ text_features.T
110 | # logits_per_text = logit_scale * text_features @ image_features.T
111 | # # calculated ground-truth and cache if enabled
112 | # num_logits = logits_per_image.shape[0]
113 | # if self.prev_num_logits != num_logits or device not in self.labels:
114 | # labels = torch.arange(num_logits, device=device, dtype=torch.long)
115 | # if self.world_size > 1 and self.local_loss:
116 | # labels = labels + num_logits * self.rank
117 | # if self.cache_labels:
118 | # self.labels[device] = labels
119 | # self.prev_num_logits = num_logits
120 | # else:
121 | # labels = self.labels[device]
122 |
123 | # if self.label_smoothing_cross_entropy:
124 | # total_loss = (
125 | # self.label_smoothing_cross_entropy(logits_per_image, labels) +
126 | # self.label_smoothing_cross_entropy(logits_per_text, labels)
127 | # ) / 2
128 | # else:
129 | # total_loss = (
130 | # F.cross_entropy(logits_per_image, labels) +
131 | # F.cross_entropy(logits_per_text, labels)
132 | # ) / 2
133 |
134 | # acc = None
135 | # i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136 | # t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137 | # acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138 | # return total_loss, acc
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
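183 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
184 | # A minimal smoke test assuming RN50-like hyperparameters (layers/width/heads
185 | # chosen only for illustration): embed_dim = width * 32 = 2048 feeds the
186 | # attention pool, which projects to output_dim.
187 | if __name__ == "__main__":
188 |     model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
189 |     feats = model(torch.randn(2, 3, 224, 224))
190 |     print(feats.shape)  # expected: torch.Size([2, 1024])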
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 | Returns
46 | -------
47 | model : torch.nn.Module
48 | The CLIP model
49 | preprocess : Callable[[PIL.Image], torch.Tensor]
50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51 | """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 |
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith('amp') or precision == 'fp32':
87 | model.float()
88 | elif precision == 'bf16':
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == 'fp32':
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
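146 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
147 | # Loads an OpenAI CLIP checkpoint through this adapter. "ViT-B-32" is an
148 | # assumption: the name must appear under the 'openai' tag in pretrained.py
149 | # (see list_openai_models()); the first call downloads the checkpoint.
150 | if __name__ == "__main__":
151 |     print(list_openai_models())
152 |     model = load_openai_model("ViT-B-32", device="cpu", jit=False)
153 |     print(f"loaded {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")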
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/rope.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, repeat
5 | import logging
6 |
7 | def broadcat(tensors, dim = -1):
8 | num_tensors = len(tensors)
9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11 | shape_len = list(shape_lens)[0]
12 | dim = (dim + shape_len) if dim < 0 else dim
13 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18 | expanded_dims.insert(dim, (dim, dims[dim]))
19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21 | return torch.cat(tensors, dim = dim)
22 |
23 | def rotate_half(x):
24 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
25 | x1, x2 = x.unbind(dim = -1)
26 | x = torch.stack((-x2, x1), dim = -1)
27 | return rearrange(x, '... d r -> ... (d r)')
28 |
29 |
30 | class VisionRotaryEmbedding(nn.Module):
31 | def __init__(
32 | self,
33 | dim,
34 | pt_seq_len,
35 | ft_seq_len=None,
36 | custom_freqs = None,
37 | freqs_for = 'lang',
38 | theta = 10000,
39 | max_freq = 10,
40 | num_freqs = 1,
41 | ):
42 | super().__init__()
43 | if custom_freqs:
44 | freqs = custom_freqs
45 | elif freqs_for == 'lang':
46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47 | elif freqs_for == 'pixel':
48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49 | elif freqs_for == 'constant':
50 | freqs = torch.ones(num_freqs).float()
51 | else:
52 | raise ValueError(f'unknown modality {freqs_for}')
53 |
54 | if ft_seq_len is None: ft_seq_len = pt_seq_len
55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56 |
57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59 |
60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62 |
63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64 |
65 | self.register_buffer("freqs_cos", freqs.cos())
66 | self.register_buffer("freqs_sin", freqs.sin())
67 |
68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69 |
70 | def forward(self, t, start_index = 0):
71 | rot_dim = self.freqs_cos.shape[-1]
72 | end_index = start_index + rot_dim
73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76 |
77 | return torch.cat((t_left, t, t_right), dim = -1)
78 |
79 | class VisionRotaryEmbeddingFast(nn.Module):
80 | def __init__(
81 | self,
82 | dim,
83 | pt_seq_len,
84 | ft_seq_len=None,
85 | custom_freqs = None,
86 | freqs_for = 'lang',
87 | theta = 10000,
88 | max_freq = 10,
89 | num_freqs = 1,
90 | patch_dropout = 0.
91 | ):
92 | super().__init__()
93 | if custom_freqs:
94 | freqs = custom_freqs
95 | elif freqs_for == 'lang':
96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97 | elif freqs_for == 'pixel':
98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99 | elif freqs_for == 'constant':
100 | freqs = torch.ones(num_freqs).float()
101 | else:
102 | raise ValueError(f'unknown modality {freqs_for}')
103 |
104 | if ft_seq_len is None: ft_seq_len = pt_seq_len
105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106 |
107 | freqs = torch.einsum('..., f -> ... f', t, freqs)
108 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110 |
111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113 |
114 | self.patch_dropout = patch_dropout
115 |
116 | self.register_buffer("freqs_cos", freqs_cos)
117 | self.register_buffer("freqs_sin", freqs_sin)
118 |
119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120 |
121 | def forward(self, t, patch_indices_keep=None):
122 | if patch_indices_keep is not None:
123 | batch = t.size()[0]
124 | batch_indices = torch.arange(batch)
125 | batch_indices = batch_indices[..., None]
126 |
127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129 |
130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134 |
135 | return t * freqs_cos + rotate_half(t) * freqs_sin
136 |
137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
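138 |
139 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
140 | # In the EVA ViT this rotary embedding is typically built with dim = head_dim // 2,
141 | # so the cached cos/sin tables have last dim == head_dim and broadcast over
142 | # (batch, heads, num_patches, head_dim) query/key tensors. Numbers below are illustrative.
143 | if __name__ == "__main__":
144 |     rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=14)
145 |     q = torch.randn(2, 12, 14 * 14, 64)  # (batch, heads, patches, head_dim)
146 |     print(rope(q).shape)  # expected: torch.Size([2, 12, 196, 64])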
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
31 | """
32 |
33 | def __init__(
34 | self,
35 | model_name,
36 | embed_dim,
37 | image_size=224,
38 | pool='avg',
39 | proj='linear',
40 | proj_bias=False,
41 | drop=0.,
42 | pretrained=False):
43 | super().__init__()
44 | if timm is None:
45 | raise RuntimeError("Please `pip install timm` to use timm models.")
46 |
47 | self.image_size = to_2tuple(image_size)
48 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
49 | feat_size = self.trunk.default_cfg.get('pool_size', None)
50 | feature_ndim = 1 if not feat_size else 2
51 | if pool in ('abs_attn', 'rot_attn'):
52 | assert feature_ndim == 2
53 | # if attn pooling used, remove both classifier and default pool
54 | self.trunk.reset_classifier(0, global_pool='')
55 | else:
56 | # reset global pool if pool config set, otherwise leave as network default
57 | reset_kwargs = dict(global_pool=pool) if pool else {}
58 | self.trunk.reset_classifier(0, **reset_kwargs)
59 | prev_chs = self.trunk.num_features
60 |
61 | head_layers = OrderedDict()
62 | if pool == 'abs_attn':
63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64 | prev_chs = embed_dim
65 | elif pool == 'rot_attn':
66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67 | prev_chs = embed_dim
68 | else:
69 | assert proj, 'projection layer needed if non-attention pooling is used.'
70 |
71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72 | if proj == 'linear':
73 | head_layers['drop'] = nn.Dropout(drop)
74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75 | elif proj == 'mlp':
76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77 |
78 | self.head = nn.Sequential(head_layers)
79 |
80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81 | """ lock modules
82 | Args:
83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84 | """
85 | if not unlocked_groups:
86 | # lock full model
87 | for param in self.trunk.parameters():
88 | param.requires_grad = False
89 | if freeze_bn_stats:
90 | freeze_batch_norm_2d(self.trunk)
91 | else:
92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93 | try:
94 | # FIXME import here until API stable and in an official release
95 | from timm.models.helpers import group_parameters, group_modules
96 | except ImportError:
97 | raise RuntimeError(
98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99 | matcher = self.trunk.group_matcher()
100 | gparams = group_parameters(self.trunk, matcher)
101 | max_layer_id = max(gparams.keys())
102 | max_layer_id = max_layer_id - unlocked_groups
103 | for group_idx in range(max_layer_id + 1):
104 | group = gparams[group_idx]
105 | for param in group:
106 | self.trunk.get_parameter(param).requires_grad = False
107 | if freeze_bn_stats:
108 | gmodules = group_modules(self.trunk, matcher, reverse=True)
109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110 | freeze_batch_norm_2d(self.trunk, gmodules)
111 |
112 | @torch.jit.ignore
113 | def set_grad_checkpointing(self, enable=True):
114 | try:
115 | self.trunk.set_grad_checkpointing(enable)
116 | except Exception as e:
117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118 |
119 | def forward(self, x):
120 | x = self.trunk(x)
121 | x = self.head(x)
122 | return x
123 |
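124 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
125 | # Wraps an arbitrary timm backbone as a CLIP-style vision tower; "resnet50" is
126 | # only an example model id, and the default pretrained=False avoids any download.
127 | if __name__ == "__main__":
128 |     tower = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
129 |     print(tower(torch.randn(2, 3, 224, 224)).shape)  # expected: torch.Size([2, 512])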
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 | Returns list of utf-8 byte and a corresponding list of unicode strings.
29 | The reversible bpe codes work on unicode strings.
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 | This is a significant percentage of your normal, say, 32K bpe vocab.
33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 | And avoids mapping to whitespace/control characters the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | text = ftfy.fix_text(text)
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 |
66 | def whitespace_clean(text):
67 | text = re.sub(r'\s+', ' ', text)
68 | text = text.strip()
69 | return text
70 |
71 |
72 | class SimpleTokenizer(object):
73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 | self.byte_encoder = bytes_to_unicode()
75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 | merges = merges[1:49152-256-2+1]
78 | merges = [tuple(merge.split()) for merge in merges]
79 | vocab = list(bytes_to_unicode().values())
80 | vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 | if not special_tokens:
84 | special_tokens = ['<start_of_text>', '<end_of_text>']
85 | else:
86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 | vocab.extend(special_tokens)
88 | self.encoder = dict(zip(vocab, range(len(vocab))))
89 | self.decoder = {v: k for k, v in self.encoder.items()}
90 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 | self.cache = {t:t for t in special_tokens}
92 | special = "|".join(special_tokens)
93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 |
95 | self.vocab_size = len(self.encoder)
96 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 |
98 | def bpe(self, token):
99 | if token in self.cache:
100 | return self.cache[token]
101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 | pairs = get_pairs(word)
103 |
104 | if not pairs:
105 | return token+'</w>'
106 |
107 | while True:
108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 | if bigram not in self.bpe_ranks:
110 | break
111 | first, second = bigram
112 | new_word = []
113 | i = 0
114 | while i < len(word):
115 | try:
116 | j = word.index(first, i)
117 | new_word.extend(word[i:j])
118 | i = j
119 | except:
120 | new_word.extend(word[i:])
121 | break
122 |
123 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 | new_word.append(first+second)
125 | i += 2
126 | else:
127 | new_word.append(word[i])
128 | i += 1
129 | new_word = tuple(new_word)
130 | word = new_word
131 | if len(word) == 1:
132 | break
133 | else:
134 | pairs = get_pairs(word)
135 | word = ' '.join(word)
136 | self.cache[token] = word
137 | return word
138 |
139 | def encode(self, text):
140 | bpe_tokens = []
141 | text = whitespace_clean(basic_clean(text)).lower()
142 | for token in re.findall(self.pat, text):
143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 | return bpe_tokens
146 |
147 | def decode(self, tokens):
148 | text = ''.join([self.decoder[token] for token in tokens])
149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 | return text
151 |
152 |
153 | _tokenizer = SimpleTokenizer()
154 |
155 |
156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157 | """
158 | Returns the tokenized representation of given input string(s)
159 |
160 | Parameters
161 | ----------
162 | texts : Union[str, List[str]]
163 | An input string or a list of input strings to tokenize
164 | context_length : int
165 | The context length to use; all CLIP models use 77 as the context length
166 |
167 | Returns
168 | -------
169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170 | """
171 | if isinstance(texts, str):
172 | texts = [texts]
173 |
174 | sot_token = _tokenizer.encoder["<start_of_text>"]
175 | eot_token = _tokenizer.encoder["<end_of_text>"]
176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178 |
179 | for i, tokens in enumerate(all_tokens):
180 | if len(tokens) > context_length:
181 | tokens = tokens[:context_length] # Truncate
182 | tokens[-1] = eot_token
183 | result[i, :len(tokens)] = torch.tensor(tokens)
184 |
185 | return result
186 |
187 |
188 | class HFTokenizer:
189 | "HuggingFace tokenizer wrapper"
190 | def __init__(self, tokenizer_name:str):
191 | from transformers import AutoTokenizer
192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193 |
194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195 | # same cleaning as for default tokenizer, except lowercasing
196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197 | if isinstance(texts, str):
198 | texts = [texts]
199 | texts = [whitespace_clean(basic_clean(text)) for text in texts]
200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201 | return input_ids
202 |
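203 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
204 | # tokenize() wraps each caption with the <start_of_text>/<end_of_text> ids and
205 | # pads/truncates to the CLIP context length of 77 (requires ftfy and regex).
206 | if __name__ == "__main__":
207 |     ids = tokenize(["a photo of a cat", "a diagram"])
208 |     print(ids.shape)  # expected: torch.Size([2, 77])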
--------------------------------------------------------------------------------
/data/model/vision_encoders/evaclip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 | CenterCrop
9 |
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 |
12 |
13 | class ResizeMaxSize(nn.Module):
14 |
15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16 | super().__init__()
17 | if not isinstance(max_size, int):
18 | raise TypeError(f"Size should be int. Got {type(max_size)}")
19 | self.max_size = max_size
20 | self.interpolation = interpolation
21 | self.fn = min if fn == 'min' else max
22 | self.fill = fill
23 |
24 | def forward(self, img):
25 | if isinstance(img, torch.Tensor):
26 | height, width = img.shape[:2]
27 | else:
28 | width, height = img.size
29 | scale = self.max_size / float(max(height, width))
30 | if scale != 1.0:
31 | new_size = tuple(round(dim * scale) for dim in (height, width))
32 | img = F.resize(img, new_size, self.interpolation)
33 | pad_h = self.max_size - new_size[0]
34 | pad_w = self.max_size - new_size[1]
35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36 | return img
37 |
38 |
39 | def _convert_to_rgb(image):
40 | return image.convert('RGB')
41 |
42 |
43 | # class CatGen(nn.Module):
44 | # def __init__(self, num=4):
45 | # self.num = num
46 | # def mixgen_batch(image, text):
47 | # batch_size = image.shape[0]
48 | # index = np.random.permutation(batch_size)
49 |
50 | # cat_images = []
51 | # for i in range(batch_size):
52 | # # image mixup
53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54 | # # text concat
55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56 | # text = torch.stack(text)
57 | # return image, text
58 |
59 |
60 | def image_transform(
61 | image_size: int,
62 | is_train: bool,
63 | mean: Optional[Tuple[float, ...]] = None,
64 | std: Optional[Tuple[float, ...]] = None,
65 | resize_longest_max: bool = False,
66 | fill_color: int = 0,
67 | ):
68 | mean = mean or OPENAI_DATASET_MEAN
69 | if not isinstance(mean, (list, tuple)):
70 | mean = (mean,) * 3
71 |
72 | std = std or OPENAI_DATASET_STD
73 | if not isinstance(std, (list, tuple)):
74 | std = (std,) * 3
75 |
76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78 | image_size = image_size[0]
79 |
80 | normalize = Normalize(mean=mean, std=std)
81 | if is_train:
82 | return Compose([
83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84 | _convert_to_rgb,
85 | ToTensor(),
86 | normalize,
87 | ])
88 | else:
89 | if resize_longest_max:
90 | transforms = [
91 | ResizeMaxSize(image_size, fill=fill_color)
92 | ]
93 | else:
94 | transforms = [
95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96 | CenterCrop(image_size),
97 | ]
98 | transforms.extend([
99 | _convert_to_rgb,
100 | ToTensor(),
101 | normalize,
102 | ])
103 | return Compose(transforms)
104 |
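105 | # --- Illustrative usage sketch (added; not part of the upstream file) ---
106 | # Builds the eval-time pipeline (resize -> center crop -> normalize with the
107 | # OpenAI CLIP statistics); the synthetic PIL image below is a stand-in for real data.
108 | if __name__ == "__main__":
109 |     from PIL import Image
110 |     preprocess = image_transform(image_size=224, is_train=False)
111 |     dummy = Image.new("RGB", (640, 480))
112 |     print(preprocess(dummy).shape)  # expected: torch.Size([3, 224, 224])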
--------------------------------------------------------------------------------
/data/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import torch.distributed as dist
5 | from utils.args import get_args,logging_cfgs
6 | from utils.initialize import initialize
7 | from utils.build_model import build_model
8 | from utils.build_optimizer import build_optimizer
9 | from utils.build_dataloader import create_train_dataloaders, create_val_dataloaders
10 | from utils.pipeline import train, test
11 |
12 |
13 | def main():
14 |
15 | ### init
16 |
17 | args = get_args()
18 | initialize(args)
19 |
20 | ### logging cfgs
21 | logging_cfgs(args)
22 |
23 |
24 | if args.run_cfg.mode == 'training':
25 |
26 | ### create datasets and dataloader
27 | train_loader = create_train_dataloaders(args)
28 | val_loaders = create_val_dataloaders(args)
29 |
30 | ### build model and optimizer
31 |
32 | model, optimizer_ckpt, start_step = build_model(args)
33 |
34 | optimizer = build_optimizer(model, args, optimizer_ckpt)
35 |
36 |
37 | ### start evaluation
38 | if args.run_cfg.first_eval or args.run_cfg.zero_shot:
39 | test(model, val_loaders, args.run_cfg)
40 | if args.run_cfg.zero_shot:
41 | return
42 |
43 | ### start training
44 |
45 |
46 | train(model, optimizer, train_loader, val_loaders, args.run_cfg, start_step = start_step, verbose_time=False)
47 |
48 | elif args.run_cfg.mode == 'testing':
49 | ### build model
50 | model,_,_ = build_model(args)
51 |
52 | ### create datasets and dataloader
53 | val_loaders = create_val_dataloaders(args)
54 |
55 | ### start evaluation
56 | test(model, val_loaders, args.run_cfg)
57 |
58 | else:
59 | raise NotImplementedError
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
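65 | # --- Launch note (added; not part of the upstream file) ---
66 | # main() expects a distributed launch; see scripts/run_vision_captioner.sh and
67 | # scripts/run_audio_captioner.sh for complete torchrun invocations, e.g.
68 | #   torchrun --nnodes 1 --nproc_per_node 8 run.py --config ./caption_config/caption-generation-vision.json ...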
--------------------------------------------------------------------------------
/data/scripts/run_audio_captioner.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes 1 \
2 | --node_rank 0 \
3 | --nproc_per_node 8 \
4 | --master_port 9814 \
5 | run.py \
6 | --config ./caption_config/caption-generation-audio.json \
7 | --pretrain_dir './audio_captioner' \
8 | --output_dir './output/audio_caption' \
9 | --test_batch_size 128 \
10 | --generate_nums 3 \
11 | --captioner_mode true \
--------------------------------------------------------------------------------
/data/scripts/run_vision_captioner.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes 1 \
2 | --node_rank 0 \
3 | --nproc_per_node 8 \
4 | --master_port 9814 \
5 | run.py \
6 | --config ./caption_config/caption-generation-vision.json \
7 | --pretrain_dir './vision_captioner' \
8 | --output_dir './output/vision_caption' \
9 | --test_batch_size 64 \
10 | --test_vision_sample_num 8 \
11 | --generate_nums 3 \
12 | --captioner_mode true \
13 |
--------------------------------------------------------------------------------
/data/setup_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.0.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
2 | pip install torchvision==0.15.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
3 | pip install torchaudio==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
4 | pip install decord -i https://pypi.tuna.tsinghua.edu.cn/simple
5 | pip install h5py -i https://pypi.tuna.tsinghua.edu.cn/simple
6 | pip install ffmpeg-python -i https://pypi.tuna.tsinghua.edu.cn/simple
7 | pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple
8 | pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple
9 | pip install ipdb -i https://pypi.tuna.tsinghua.edu.cn/simple
10 | pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple
11 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple
12 | pip install transformers==4.31.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
13 | pip install webdataset -i https://pypi.tuna.tsinghua.edu.cn/simple
14 | pip install SentencePiece -i https://pypi.tuna.tsinghua.edu.cn/simple
15 |
--------------------------------------------------------------------------------
/data/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/data/utils/__init__.py
--------------------------------------------------------------------------------
/data/utils/build_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | from torch.utils.data.distributed import DistributedSampler
3 | from torch.utils.data import DataLoader
4 | from data.loader import MetaLoader, PrefetchLoader
5 | from data import data_registry
6 | from utils.distributed import DistributedSampler_wopadding
7 | from .logger import LOGGER
8 |
9 |
10 |
11 | def create_train_dataloaders(args):
12 | data_cfg = args.data_cfg.train
13 | dataloaders = []
14 | dataloaders_dict={}
15 | train_steps = []
16 | loader_names = []
17 |
18 | if len(data_cfg) == 0:
19 | return
20 |
21 | for d_cfg in data_cfg:
22 |
23 | dataset_ls = []
24 |
25 | name = d_cfg['name']
26 | dataset = data_registry[d_cfg.type](d_cfg, args)
27 |
28 | # print(dataset[0])
29 |
30 | collate_fn = dataset.collate_fn
31 | worker_init_fn = dataset.worker_init_fn
32 | use_sampler = dataset.use_sampler
33 |
34 |
35 | LOGGER.info("Create Dataset {} Success".format(name))
36 | task = d_cfg['task']
37 | batch_size = d_cfg['batch_size']
38 | n_workers = d_cfg['n_workers']
39 |
40 | if 'steps' in d_cfg:
41 | train_steps.append(d_cfg['steps'])
42 | elif 'epoch' in d_cfg:
43 | epoch = d_cfg['epoch']
44 | train_steps.append(int((len(dataset) // batch_size) * epoch))
45 |
46 | loader = build_dataloader(dataset, collate_fn, True, batch_size // args.run_cfg.gradient_accumulation_steps , n_workers, worker_init_fn, use_sampler)
47 |
48 | dataloaders.append(loader)
49 | loader_names.append(f'{task}--{name}')
50 |
51 |
52 | for i in range(len(dataloaders)):
53 | ratio = train_steps[i]
54 | dataloaders_dict[loader_names[i]] = (dataloaders[i], ratio)
55 |
56 | n_gpu = dist.get_world_size()
57 | for name, (loader, ratio) in dataloaders_dict.items():
58 | # epoch = (ratio * loader.batch_size * n_gpu ) // len(loader.dataset)
59 | LOGGER.info(f" loader {name} , ratio {ratio} , bs_pergpu {loader.batch_size}, n_workers {loader.num_workers}" )
60 |
61 |
62 | meta_loader = MetaLoader(dataloaders_dict,
63 | accum_steps=args.run_cfg.gradient_accumulation_steps,
64 | distributed=n_gpu > 1)
65 |
66 | if args.run_cfg.num_train_steps == 0:
67 | total_train_steps = sum(train_steps)
68 | args.run_cfg.num_train_steps = total_train_steps
69 |
70 |
71 |
72 | meta_loader = PrefetchLoader(meta_loader)
73 | meta_loader.ndata = len(dataloaders_dict)
74 | args.run_cfg.valid_steps = args.run_cfg.num_train_steps // args.run_cfg.valid_freq -1
75 |
76 |
77 |
78 | return meta_loader
79 |
80 |
81 | def create_val_dataloaders(args):
82 | data_cfg = args.data_cfg.val
83 | dataloaders = {}
84 | for d_cfg in data_cfg:
85 | name = d_cfg['name']
86 | dataset = data_registry[d_cfg.type](d_cfg, args)
87 | collate_fn = dataset.collate_fn
88 | worker_init_fn = dataset.worker_init_fn
89 | use_sampler = dataset.use_sampler
90 | # task = d_cfg['task'].split('_')
91 | # if 'qa' in task:
92 | # dataset.make_submission = d_cfg.get('make_submission', False)
93 |
94 | # if 'cap' in task:
95 | # dataset.annfile = d_cfg['annfile']
96 |
97 | # dataset.data_type = data_type
98 | dataset.name = name
99 | LOGGER.info("Create Dataset {} Success".format(name))
100 | task = d_cfg['task']
101 | batch_size = d_cfg['batch_size']
102 | n_workers = d_cfg['n_workers']
103 | loader = build_dataloader(dataset, collate_fn, False, batch_size, n_workers, worker_init_fn, use_sampler)
104 | task_name = f'{task}--{name}'
105 | dataloaders[task_name] = PrefetchLoader(loader)
106 | return dataloaders
107 |
108 |
109 | def build_dataloader(dataset, collate_fn, is_train, batch_size, n_workers=None, worker_init_fn=None, use_sampler=True):
110 | batch_size = batch_size // dist.get_world_size()
111 | if use_sampler:
112 | if is_train:
113 | sampler = DistributedSampler(dataset)
114 | else:
115 | sampler = DistributedSampler_wopadding(dataset)
116 | loader = DataLoader(dataset, sampler = sampler, batch_size = batch_size,
117 | num_workers=n_workers, pin_memory=True,
118 | collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn)
119 | else:
120 |
121 | loader = DataLoader(dataset, batch_size = batch_size,
122 | num_workers=n_workers, pin_memory=True,
123 | collate_fn=collate_fn, drop_last=is_train,worker_init_fn=worker_init_fn)
124 |
125 | return loader
126 |
127 |
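128 | # --- Illustrative config sketch (added; not part of the upstream file) ---
129 | # create_train_dataloaders() reads one entry per dataset from args.data_cfg.train;
130 | # judging from the lookups above, each entry needs at least the keys below
131 | # (names and values here are hypothetical):
132 | #   {"name": "my_dataset", "type": "<registered dataset type>", "task": "<task tag>",
133 | #    "batch_size": 512, "n_workers": 8, "epoch": 10}   # or "steps": <int>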
--------------------------------------------------------------------------------
/data/utils/build_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 |
5 | from model import model_registry
6 | from torch.nn.parallel import DistributedDataParallel as DDP
7 | from .logger import LOGGER
8 | from .build_optimizer import build_optimizer
9 |
10 |
11 | class DDP_modify(DDP):
12 | def __getattr__(self, name):
13 | try:
14 | return super().__getattr__(name)
15 | except:
16 | return getattr(self.module,name)
17 |
18 |
19 | def build_model(args):
20 |
21 | model = model_registry[args.model_cfg.model_type](args.model_cfg)
22 | checkpoint = {}
23 |
24 | ### load ckpt from a pretrained_dir
25 | if args.run_cfg.pretrain_dir:
26 | checkpoint = load_from_pretrained_dir(args)
27 | LOGGER.info("Load from pretrained dir {}".format(args.run_cfg.pretrain_dir))
28 |
29 | ### load ckpt from specific path
30 | if args.run_cfg.checkpoint:
31 | checkpoint = torch.load(args.run_cfg.checkpoint, map_location = 'cpu')
32 |
33 | ### resume training
34 | if args.run_cfg.resume:
35 | checkpoint, checkpoint_optim, start_step = load_from_resume(args.run_cfg)
36 | else:
37 | checkpoint_optim, start_step = None , 0
38 |
39 |
40 | checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()}
41 |
42 | if checkpoint != {}:
43 |
44 | checkpoint = model.modify_checkpoint(checkpoint)
45 | if "model" in checkpoint.keys():
46 | checkpoint = checkpoint["model"]
47 |
48 | missing_keys,unexpected_keys = model.load_state_dict(checkpoint,strict=False)
49 | LOGGER.info(f"Unexpected keys {unexpected_keys}")
50 | LOGGER.info(f"missing_keys {missing_keys}")
51 |
52 |
53 | local_rank = args.local_rank
54 | device = torch.device("cuda", local_rank)
55 | model.to(device)
56 | if args.run_cfg.use_ddp:
57 | model = DDP_modify(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
58 | else:
59 | pass
60 |
61 | return model, checkpoint_optim, start_step
62 |
63 |
64 |
65 | def load_from_pretrained_dir(args):
66 |
67 |
68 | try: ### huggingface trainer
69 | checkpoint_dir = args.run_cfg.pretrain_dir
70 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('checkpoint')]
71 | checkpoint_ls = [int(i.split('-')[1]) for i in checkpoint_ls]
72 | checkpoint_ls.sort()
73 | step = checkpoint_ls[-1]
74 |
75 | try:
76 | checkpoint_name = f'checkpoint-{step}/pytorch_model.bin'
77 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name)
78 | checkpoint = torch.load(ckpt_file, map_location = 'cpu')
79 | except:
80 | checkpoint_name1 = f'checkpoint-{step}/pytorch_model-00001-of-00002.bin'
81 | ckpt_file1 = torch.load(os.path.join(checkpoint_dir, checkpoint_name1), map_location = 'cpu')
82 | checkpoint_name2 = f'checkpoint-{step}/pytorch_model-00002-of-00002.bin'
83 | ckpt_file2 = torch.load(os.path.join(checkpoint_dir, checkpoint_name2), map_location = 'cpu')
84 | ckpt_file1.update(ckpt_file2)
85 | checkpoint = ckpt_file1
86 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()}
87 | LOGGER.info(f'load_from_pretrained: {ckpt_file}')
88 |
89 | except:
90 | checkpoint_dir = os.path.join(args.run_cfg.pretrain_dir,'ckpt')
91 | checkpoint_ls = [ i for i in os.listdir(checkpoint_dir) if i.startswith('model_step')]
92 | checkpoint_ls = [int(i.split('_')[2].split('.')[0]) for i in checkpoint_ls]
93 | checkpoint_ls.sort()
94 | step = checkpoint_ls[-1]
95 |
96 | checkpoint_name = 'model_step_'+str(step)+'.pt'
97 | ckpt_file = os.path.join(checkpoint_dir, checkpoint_name)
98 | checkpoint = torch.load(ckpt_file, map_location = 'cpu')
99 | # checkpoint = {k.replace('module.',''):v for k,v in checkpoint.items()}
100 | LOGGER.info(f'load_from_pretrained: {ckpt_file}')
101 |
102 |
103 | return checkpoint
104 |
105 |
106 | def load_from_resume(run_cfg):
107 | ckpt_dir = os.path.join(run_cfg.output_dir,'ckpt')
108 | previous_optimizer_state = [i for i in os.listdir(ckpt_dir) if i.startswith('optimizer')]
109 | steps = [i.split('.pt')[0].split('_')[-1] for i in previous_optimizer_state]
110 | steps = [ int(i) for i in steps]
111 | steps.sort()
112 | previous_step = steps[-1]
113 | previous_optimizer_state = f'optimizer_step_{previous_step}.pt'
114 | previous_model_state = f'model_step_{previous_step}.pt'
115 | previous_step = int(previous_model_state.split('.')[0].split('_')[-1])
116 | previous_optimizer_state = os.path.join(ckpt_dir, previous_optimizer_state)
117 | previous_model_state = os.path.join(ckpt_dir, previous_model_state)
118 |
119 | assert os.path.exists(previous_optimizer_state) and os.path.exists(previous_model_state)
120 | LOGGER.info("choose previous model: {}".format(previous_model_state))
121 | LOGGER.info("choose previous optimizer: {}".format(previous_optimizer_state))
122 | previous_model_state = torch.load(previous_model_state,map_location='cpu')
123 | previous_optimizer_state = torch.load(previous_optimizer_state,map_location='cpu')
124 | return previous_model_state, previous_optimizer_state, previous_step
125 |
126 |
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/data/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 | import torch
4 | import torch.distributed as dist
5 | from torch.autograd import Function
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 |
9 |
10 |
11 |
12 | class GatherLayer(torch.autograd.Function):
13 | """
14 | Gather tensors from all workers with support for backward propagation:
15 | This implementation does not cut the gradients as torch.distributed.all_gather does.
16 | """
17 |
18 | @staticmethod
19 | def forward(ctx, x):
20 | output = [
21 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
22 | ]
23 | torch.distributed.all_gather(output, x)
24 | return tuple(output)
25 |
26 | @staticmethod
27 | def backward(ctx, *grads):
28 | all_gradients = torch.stack(grads)
29 | torch.distributed.all_reduce(all_gradients)
30 | return all_gradients[torch.distributed.get_rank()]
31 |
32 |
33 | def all_gather_with_grad(tensors):
34 | """
35 | Performs all_gather operation on the provided tensors.
36 | Graph remains connected for backward grad computation.
37 | """
38 | # Queue the gathered tensors
39 | world_size = torch.distributed.get_world_size()
40 | # There is no need for reduction in the single-proc case
41 | if world_size == 1:
42 | return tensors
43 |
44 | # tensor_all = GatherLayer.apply(tensors)
45 | tensor_all = GatherLayer.apply(tensors)
46 |
47 | return torch.cat(tensor_all, dim=0)
48 |
49 |
50 | @torch.no_grad()
51 | def concat_all_gather(tensor):
52 | """
53 | Performs all_gather operation on the provided tensors.
54 | *** Warning ***: torch.distributed.all_gather has no gradient.
55 | """
56 | # if use distributed training
57 | # if not is_dist_avail_and_initialized():
58 | # return tensor
59 |
60 | tensors_gather = [
61 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
62 | ]
63 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
64 |
65 | output = torch.cat(tensors_gather, dim=0)
66 | return output
67 |
68 |
69 |
70 | def _encode(enc, max_size, use_max_size=False):
71 | enc_size = len(enc)
72 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
73 | if use_max_size:
74 | # this is used for broadcasting
75 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte)
76 | else:
77 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
78 | remainder = enc_size
79 | for i in range(enc_byte):
80 | base = 256 ** (enc_byte-i-1)
81 | buffer_[i] = remainder // base
82 | remainder %= base
83 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
84 | return buffer_, enc_byte
85 |
86 |
87 | def _decode(buffer_, enc_byte):
88 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item()
89 | for i in range(enc_byte))
90 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
91 | shift = size + enc_byte
92 | return bytes_list, shift
93 |
94 |
95 |
96 |
97 |
98 | def all_gather_list(data):
99 | """Gathers arbitrary data from all nodes into a list."""
100 | enc = pickle.dumps(data)
101 |
102 | enc_size = len(enc)
103 | max_size = ddp_allgather(torch.tensor([enc_size]).cuda()).max().item()
104 | in_buffer, enc_byte = _encode(enc, max_size)
105 |
106 | out_buffer = ddp_allgather(in_buffer[:enc_byte+enc_size])
107 |
108 | results = []
109 | for _ in range(dist.get_world_size()):
110 | bytes_list, shift = _decode(out_buffer, enc_byte)
111 | out_buffer = out_buffer[shift:]
112 | result = pickle.loads(bytes_list)
113 | results.append(result)
114 | return results
115 |
116 |
117 | def any_broadcast(data, root_rank):
118 | """broadcast arbitrary data from root_rank to all nodes."""
119 | enc = pickle.dumps(data)
120 |
121 | max_size = ddp_allgather(torch.tensor([len(enc)]).cuda()).max().item()
122 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True)
123 |
124 | dist.broadcast(buffer_, root_rank)
125 |
126 | bytes_list, _ = _decode(buffer_, enc_byte)
127 | result = pickle.loads(bytes_list)
128 | return result
129 |
130 |
131 |
132 | ###### with different batch_size ~
133 | def ddp_allgather(input):
134 | tmp_input = input.cuda()
135 | size = torch.tensor(tmp_input.shape[0]).cuda()
136 | size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())]
137 | dist.all_gather(size_list, size)
138 | max_size = max(size_list).item()
139 | padding_size = max_size - size
140 | if padding_size > 0 :
141 | padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input)
142 | tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0)
143 | tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())]
144 | dist.all_gather(tmp_list, tmp_input)
145 | output = []
146 | for t, s in zip(tmp_list, size_list):
147 | output.append(t[:s])
148 | output = torch.cat(output,dim=0)
149 | return output
150 |
151 |
152 |
153 | class DistributedSampler_wopadding(DistributedSampler):
154 |
155 | def __iter__(self):
156 | if self.shuffle:
157 | # deterministically shuffle based on epoch and seed
158 | g = torch.Generator()
159 | g.manual_seed(self.seed + self.epoch)
160 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
161 | else:
162 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
163 |
164 | # if not self.drop_last:
165 | # # add extra samples to make it evenly divisible
166 | # padding_size = self.total_size - len(indices)
167 | # if padding_size <= len(indices):
168 | # indices += indices[:padding_size]
169 | # else:
170 | # indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
171 | # else:
172 | # remove tail of data to make it evenly divisible.
173 | if self.drop_last:
174 | indices = indices[:self.total_size]
175 | #assert len(indices) == self.total_size
176 |
177 | # subsample
178 | indices = indices[self.rank:len(indices):self.num_replicas]
179 | # assert len(indices) == self.num_samples
180 |
181 | return iter(indices)
182 |
183 |
184 |
185 |
186 |
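187 | # Illustrative usage sketch (not part of the original file). Assumes a
188 | # torch.distributed process group is already initialized (e.g. via torchrun)
189 | # and a CUDA device is set; the tensor and dict below are made-up examples.
190 | #
191 | #   feats = torch.randn(local_batch_size, 512).cuda()
192 | #   all_feats = ddp_allgather(feats)          # per-rank batch sizes may differ
193 | #   metas = all_gather_list({'rank': dist.get_rank(), 'n': local_batch_size})
194 | #   cfg = any_broadcast({'seed': 42}, root_rank=0)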
--------------------------------------------------------------------------------
/data/utils/initialize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import torch.distributed as dist
6 | from .logger import LOGGER,add_log_to_file
7 |
8 | def initialize(opts):
9 |
10 | # if not os.path.exists(opts.run_cfg.output_dir):
11 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'log'), exist_ok=True)
12 | os.makedirs(os.path.join(opts.run_cfg.output_dir, 'ckpt'), exist_ok=True)
13 |
14 | local_rank = opts.local_rank
15 | torch.cuda.set_device(local_rank)
16 | dist.init_process_group(backend='nccl')
17 | if opts.run_cfg.gradient_accumulation_steps < 1:
18 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
19 | "should be >= 1".format(
20 | opts.run_cfg.gradient_accumulation_steps))
21 | set_random_seed(opts.run_cfg.seed)
22 | torch.backends.cudnn.benchmark = True
23 | torch.backends.cudnn.enabled = True
24 | if dist.get_rank() == 0:
25 | # TB_LOGGER.create(os.path.join(opts.output_dir, 'log'))
26 | add_log_to_file(os.path.join(opts.run_cfg.output_dir, 'log', 'log.txt'))
27 | else:
28 | LOGGER.disabled = True
29 |
30 |
31 | def set_random_seed(seed):
32 | random.seed(seed)
33 | np.random.seed(seed)
34 | torch.manual_seed(seed)
35 | torch.cuda.manual_seed_all(seed)
36 |
37 |
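38 | # Usage note (illustrative, not part of the original file): init_process_group
39 | # with the 'nccl' backend reads the standard torch.distributed environment
40 | # variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), so initialize() is
41 | # expected to be called from processes started by a launcher such as torchrun,
42 | # e.g. `torchrun --nproc_per_node=<num_gpus> <entry_script>` where <entry_script>
43 | # is a placeholder for the training entry point.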
--------------------------------------------------------------------------------
/data/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 |
5 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
6 | _DATE_FMT = '%m/%d/%Y %H:%M:%S'
7 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
8 | LOGGER = logging.getLogger('__main__') # this is the global logger
9 |
10 |
11 | def add_log_to_file(log_path):
12 | fh = logging.FileHandler(log_path)
13 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
14 | fh.setFormatter(formatter)
15 | LOGGER.addHandler(fh)
16 |
17 |
18 | class RunningMeter(object):
19 |     """ running meter of a scalar value
20 | (useful for monitoring training loss)
21 | """
22 | def __init__(self, name=None, val=None, smooth=0.99):
23 | self._name = name
24 | self._sm = smooth
25 | self._val = val
26 |
27 | def __call__(self, value):
28 | val = (value if self._val is None
29 | else value*(1-self._sm) + self._val*self._sm)
30 | if not math.isnan(val):
31 | self._val = val
32 |
33 | def __str__(self):
34 | return f'{self._name}: {self._val:.4f}'
35 |
36 | @property
37 | def val(self):
38 | if self._val is None:
39 | return 0
40 | return self._val
41 |
42 | @property
43 | def name(self):
44 | return self._name
45 |
46 |
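47 | # Illustrative example (not part of the original file): RunningMeter keeps an
48 | # exponential moving average of a scalar, which is how per-step losses are
49 | # smoothed in the training loop.
50 | if __name__ == '__main__':
51 |     meter = RunningMeter('loss', smooth=0.9)
52 |     for v in [1.0, 0.5, 0.25]:
53 |         meter(v)
54 |     print(meter)  # loss: 0.8800  (1.0, then 0.95, then 0.88)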
--------------------------------------------------------------------------------
/data/utils/offline_process_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import ffmpeg
4 | import subprocess
5 | import multiprocessing
6 | import numpy as np
7 |
8 | from multiprocessing import Pool
9 |
10 |
11 | input_path = '/public/chensihan/datasets/tgif/gifs_used'
12 | output_path = '/public/chensihan/datasets/tgif/'
13 |
14 | data_list = os.listdir(input_path)
15 |
16 | def execCmd(cmd):
17 | r = os.popen(cmd)
18 | text = r.read()
19 | r.close()
20 | return text
21 |
22 | def pipline(video_path, video_probe, output_dir, fps, sr, duration_target):
23 | video_name = os.path.basename(video_path)
24 |
25 | video_name = video_name.replace(".mp4", "")
26 |
27 |
28 | # extract video frames fps
29 | fps_frame_dir = os.path.join(output_dir, f"frames_fps{fps}", video_name)
30 | os.makedirs(fps_frame_dir, exist_ok=True)
31 | cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vf fps=fps={:.02f} -qscale:v 2 {}/frame_%04d.jpg".format(
32 | video_path, fps, fps_frame_dir)
33 |     subprocess.call(cmd, shell=True)  # run the frame-extraction command before cmd is reused for audio below
34 | ## extract fixed number frames
35 | # fps_frame_dir = os.path.join(output_dir, f"frames_32", video_name)
36 | # os.makedirs(fps_frame_dir, exist_ok=True)
37 | # cmd = "ffmpeg -loglevel error -i {} -vsync 0 -f image2 -vframes 32 -qscale:v 2 {}/frame_%04d.jpg".format(
38 | # video_path, fps_frame_dir)
39 |
40 |
41 |     ## extract audios (audio_file_path must be defined before the command below uses it)
42 |     sr_audio_dir = os.path.join(output_dir, f"audios")
43 |     os.makedirs(sr_audio_dir, exist_ok=True)
44 |     # print(sr_audio_dir)
45 |     audio_name = video_name + '.wav'
46 |     audio_file_path = os.path.join(sr_audio_dir, audio_name)
47 |
48 |
49 | cmd = "ffmpeg -i {} -loglevel error -f wav -vn -ac 1 -ab 16k -ar {} -y {}".format(
50 | video_path, sr, audio_file_path)
51 |
52 |
53 | subprocess.call(cmd, shell=True)
54 |
55 |
56 |
57 | def extract_thread(video_id):
58 |
59 | video_name = os.path.join(input_path, video_id)
60 |
61 | if not os.path.exists(video_name):
62 |
63 | return
64 | try:
65 | # print(1)
66 | probe = ffmpeg.probe(video_name)
67 | # print(1)
68 | pipline(video_name, probe, output_path, fps=1, sr=22050, duration_target=10)
69 | except Exception as e:
70 | print(e)
71 | return
72 |
73 |
74 | def extract_all(video_ids, thread_num, start):
75 | length = len(video_ids)
76 | print(length)
77 | with Pool(thread_num) as p:
78 | list(tqdm.tqdm(p.imap(extract_thread, video_ids), total=length))
79 |
80 | if __name__=='__main__':
81 | thread_num = 20
82 | start = 0
83 |
84 | print(len(data_list))
85 | extract_all(data_list, thread_num, start)
86 |
87 |
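88 | # For reference (illustrative, not part of the original script): with fps=1 and
89 | # sr=22050 the two commands assembled in pipline() look roughly like
90 | #   ffmpeg -loglevel error -i <video>.mp4 -vsync 0 -f image2 -vf fps=fps=1.00 -qscale:v 2 <output>/frames_fps1/<name>/frame_%04d.jpg
91 | #   ffmpeg -i <video>.mp4 -loglevel error -f wav -vn -ac 1 -ab 16k -ar 22050 -y <output>/audios/<name>.wav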
--------------------------------------------------------------------------------
/data/utils/pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | # from apex import amp
7 | from collections import defaultdict
8 | from torch.nn.utils import clip_grad_norm_
9 | from evaluation import evaluation_registry
10 | from .save import ModelSaver
11 | from .tool import NoOp
12 | from .logger import LOGGER, RunningMeter
13 | from .sched import get_lr_sched
14 | from torch.cuda.amp import autocast, GradScaler
15 |
16 |
17 | def train(model, optimizer, train_loader, val_loaders, run_cfg, start_step=0, verbose_time=False):
18 |
19 | if dist.get_rank() == 0:
20 | pbar = tqdm(total=run_cfg.num_train_steps, initial=start_step)
21 | model_saver = ModelSaver(os.path.join(run_cfg.output_dir, 'ckpt'),remove_before_ckpt=run_cfg.remove_before_ckpt)
22 | else:
23 | pbar = NoOp()
24 | model_saver = NoOp()
25 |
26 | loss_moving_averagetors ={}
27 | metric_logger_dict = defaultdict(dict)
28 | global_step = start_step
29 |
30 | scaler = GradScaler()
31 |
32 | best_indicator = {}
33 | evaluate_fn = evaluation_registry[model.config.evaluation_type]
34 |
35 | for step, (name, batch) in enumerate(train_loader):
36 |
37 | ndata = train_loader.ndata
38 | task = name.split('--')[0]
39 |
40 |
41 |
42 | if run_cfg.fp16:
43 | with autocast():
44 | loss_dict = model(batch, task=task, compute_loss=True)
45 | loss = sum(list(loss_dict.values()))
46 | loss_dict['total_loss'] = loss
47 | loss_dict = {k:v.item() for k,v in loss_dict.items()}
48 |
49 | else:
50 | loss_dict = model(batch, task=task, compute_loss=True)
51 | loss = sum(list(loss_dict.values()))
52 | loss_dict['total_loss'] = loss
53 | loss_dict = {k:v.item() for k,v in loss_dict.items()}
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |         if name not in loss_moving_averagetors:
64 | ### first time initialize
65 | for k in loss_dict.keys():
66 | loss_moving_averagetors[f'loss_{name}/{k}'] = RunningMeter()
67 | ####accumulate loss
68 |
69 | for k,v in loss_dict.items():
70 | loss_moving_averagetors[f'loss_{name}/{k}'](v)
71 |
72 |
73 | global_step += 1
74 | # learning rate scheduling
75 | lr_ratio = get_lr_sched(global_step, run_cfg)
76 |
77 | for param_group in optimizer.param_groups:
78 | param_group['lr'] = param_group['init_lr'] * lr_ratio
79 |
80 | if global_step % 50 == 0:
81 | LOGGER.info({name : averagetor.val for name, averagetor in loss_moving_averagetors.items()})
82 |
83 | # update model params
84 |
85 |
86 | if run_cfg.fp16:
87 | optimizer.zero_grad()
88 | scaler.scale(loss).backward()
89 | else:
90 | loss.backward()
91 |
92 | if not run_cfg.use_ddp:
93 | works = []
94 | for p in model.parameters():
95 | # to speed it up, you can also organize grads to larger buckets to make allreduce more efficient
96 | if p.grad is not None:
97 | works.append(dist.all_reduce(p.grad, async_op=True))
98 | for work in works:
99 | work.wait()
100 |
101 |
102 | # if run_cfg.grad_norm != -1:
103 | # grad_norm = clip_grad_norm_(model.parameters(), run_cfg.grad_norm)
104 |
105 | if run_cfg.fp16:
106 | scaler.step(optimizer)
107 | scaler.update()
108 | else:
109 | optimizer.step()
110 | optimizer.zero_grad()
111 | pbar.update(1)
112 |
113 |
114 |
115 | if (global_step+1) % run_cfg.valid_steps == 0:
116 | eval_log = evaluate_fn(model, val_loaders, run_cfg, global_step)
117 |
118 | if dist.get_rank() == 0:
119 | for task_name, val_log in eval_log.items():
120 | for eval_name, metric in val_log.items():
121 | eval_name = task_name +'_' +eval_name
122 | metric_logger_dict[eval_name][str(global_step)] = metric
123 | LOGGER.info(f"====-evaluation--{eval_name}=====step {global_step}--===========\n")
124 | LOGGER.info(metric)
125 | best_name = get_best_name(eval_name, metric)
126 | if best_name is not None:
127 | if ('best_step' not in metric_logger_dict[eval_name]) or \
128 | (metric[best_name] >= metric_logger_dict[eval_name]['best_value']):
129 | metric_logger_dict[eval_name]['best_step'] = global_step
130 | metric_logger_dict[eval_name]['best_value'] = metric[best_name]
131 | best_indicator[eval_name] = True
132 | else:
133 | best_indicator[eval_name] = False
134 | best_step = metric_logger_dict[eval_name]['best_step']
135 | LOGGER.info(f"======evaluation--{eval_name}====history best step: {best_step}=======\n")
136 | LOGGER.info(metric_logger_dict[eval_name][str(best_step)])
137 |
138 | model_saver.save(model, global_step, optimizer,best_indicator, run_cfg.save_best)
139 |
140 |
141 | if global_step >= run_cfg.num_train_steps:
142 | break
143 | pbar.close()
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 | def test(model, test_loader, run_cfg):
153 |
154 | evaluate_fn = evaluation_registry[model.config.evaluation_type]
155 | eval_log = evaluate_fn(model, test_loader, run_cfg, global_step=0)
156 | if dist.get_rank()==0:
157 | for task_name, val_log in eval_log.items():
158 | for eval_name, metric in val_log.items():
159 | eval_name = task_name +'_' +eval_name
160 | # TB_LOGGER.log_scaler_dict({f"eval/{eval_name}/test_{k}": v
161 | # for k, v in metric.items() if not isinstance(v,str)})
162 | LOGGER.info(f"==== evaluation--{eval_name}========\n")
163 | LOGGER.info(metric)
164 |
165 |
166 |
167 |
168 | def get_best_name(eval_name, metric):
169 | if eval_name.startswith('cap'):
170 | return 'CIDEr'
171 | elif eval_name.startswith('qa'):
172 | return 'accuracy'
173 | elif eval_name.startswith('ret'):
174 | if 'video_r1' in metric:
175 | return 'video_r1'
176 | elif eval_name.startswith('pt'):
177 | return None
178 | else:
179 | raise NotImplementedError
180 |
181 |
--------------------------------------------------------------------------------
/data/utils/save.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from os.path import join
4 | from utils.logger import LOGGER
5 |
6 |
7 |
8 |
9 | class ModelSaver(object):
10 | def __init__(self, output_dir, prefix='model_step', suffix='pt',remove_before_ckpt=True):
11 | self.output_dir = output_dir
12 | self.prefix = prefix
13 | self.suffix = suffix
14 | self.remove_before_ckpt = remove_before_ckpt
15 | def save(self, model, step, optimizer=None, best_indicator=None, save_best=False):
16 | ###remove previous model
17 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('model')]
18 | # if not self.pretraining:
19 | if self.remove_before_ckpt:
20 | for p in previous_state:
21 | os.remove(os.path.join(self.output_dir,p))
22 | output_model_file = join(self.output_dir,
23 | f"{self.prefix}_{step}.{self.suffix}")
24 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
25 | for k, v in model.state_dict().items()}
26 | torch.save(state_dict, output_model_file)
27 |
28 | if save_best:
29 | for k in best_indicator:
30 | if best_indicator[k]:
31 | torch.save(state_dict, join(self.output_dir,
32 | f"best_{k}.{self.suffix}"))
33 |
34 | if optimizer is not None:
35 | if hasattr(optimizer, '_amp_stash'):
36 | pass # TODO fp16 optimizer
37 | previous_state = [i for i in os.listdir(self.output_dir) if i.startswith('optimizer')]
38 | if self.remove_before_ckpt:
39 | for p in previous_state:
40 | os.remove(os.path.join(self.output_dir,p))
41 | torch.save(optimizer.state_dict(), f'{self.output_dir}/optimizer_step_{step}.pt')
42 |
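43 | # Illustrative usage sketch (not part of the original file); `model` and
44 | # `optimizer` stand for objects built elsewhere in the codebase.
45 | #
46 | #   saver = ModelSaver('./output/ckpt', remove_before_ckpt=True)
47 | #   saver.save(model, step=1000, optimizer=optimizer)
48 | #   # -> ./output/ckpt/model_step_1000.pt and ./output/ckpt/optimizer_step_1000.pt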
--------------------------------------------------------------------------------
/data/utils/sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def warmup_cosine(x, warmup_ratio):
4 | if x < warmup_ratio:
5 | return x/warmup_ratio
6 | return 0.5 * (1.0 + math.cos(math.pi * x))
7 |
8 | def warmup_constant(x, warmup_ratio):
9 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
10 | Learning rate is 1. afterwards. """
11 | if x < warmup_ratio:
12 | return x/warmup_ratio
13 | return 1.0
14 |
15 | def warmup_linear(x, warmup_ratio):
16 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
17 | After `t_total`-th training step, learning rate is zero. """
18 | if x < warmup_ratio:
19 | return x/warmup_ratio
20 | return max((x-1.)/(warmup_ratio-1.), 0)
21 |
22 | scheduler_dict = {'warmup_linear' : warmup_linear,
23 | 'warmup_cosine' : warmup_cosine}
24 |
25 | def get_lr_sched(global_step, opts):
26 | warmup_ratio = opts.warmup_ratio
27 | current_ratio = global_step / opts.num_train_steps
28 | lr_ratio = scheduler_dict[opts.scheduler](current_ratio, warmup_ratio)
29 | return lr_ratio
30 |
31 |
32 |
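33 | # Illustrative example (not part of the original file): the schedulers map the
34 | # training progress ratio to a multiplier that the training loop applies to each
35 | # param group's init_lr; SimpleNamespace stands in for the run config object.
36 | if __name__ == '__main__':
37 |     from types import SimpleNamespace
38 |     opts = SimpleNamespace(warmup_ratio=0.1, num_train_steps=1000, scheduler='warmup_linear')
39 |     print(get_lr_sched(50, opts))   # ~0.5 (halfway through warmup)
40 |     print(get_lr_sched(550, opts))  # ~0.5 (linear decay: (0.55-1)/(0.1-1))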
--------------------------------------------------------------------------------
/data/utils/tool.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class NoOp(object):
4 | """ useful for distributed training No-Ops """
5 | def __getattr__(self, name):
6 | return self.noop
7 |
8 | def noop(self, *args, **kwargs):
9 | return
10 |
11 |
12 |
13 |
14 | def split(frame_name_lists, sample_num):
15 | if len(frame_name_lists) < sample_num: ###padding with the last frame
16 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists))
17 | k, m = divmod(len(frame_name_lists), sample_num)
18 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))]
19 |
20 |
21 |
22 |
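23 | # Illustrative example (not part of the original file): split() partitions frame
24 | # indices into `sample_num` nearly equal chunks, padding with the last element
25 | # when there are too few, so that one frame can then be sampled from each chunk.
26 | if __name__ == '__main__':
27 |     print(split(list(range(10)), 4))  # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
28 |     print(split([0, 1], 4))           # [[0], [1], [1], [1]]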
--------------------------------------------------------------------------------
/example/test.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.flac
--------------------------------------------------------------------------------
/example/test.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.jpeg
--------------------------------------------------------------------------------
/example/test.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/example/test.mp4
--------------------------------------------------------------------------------
/model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/model/.DS_Store
--------------------------------------------------------------------------------
/model/audioprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import torch
5 | import torchaudio
6 |
7 |
8 | def split(frame_name_lists, sample_num):
9 | if len(frame_name_lists) < sample_num: ###padding with the last frame
10 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists))
11 | k, m = divmod(len(frame_name_lists), sample_num)
12 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))]
13 |
14 |
15 | class AudioProcessor(object):
16 | def __init__(self, melbins, target_length, sample_num, frame_shift=10, resize_melbin_num=224, mean=15.41663, std=6.55582, training=True):
17 | self.melbins = melbins
18 | self.target_length = target_length
19 | self.training = training
20 | self.frame_shift = frame_shift
21 | self.sample_num = sample_num
22 | self.resize_melbin_num = resize_melbin_num
23 |
24 | self.mean = mean
25 | self.std = std
26 |
27 | def __call__(self, wav_file):
28 |
29 | if not os.path.exists(wav_file):
30 |             print('audio file not found:', wav_file)
31 | return torch.zeros(self.sample_num, self.target_length, self.melbins)
32 |
33 | try:
34 | waveform, sr = torchaudio.load(wav_file)
35 | if sr != 16000:
36 | trans = torchaudio.transforms.Resample(sr, 16000)
37 | waveform = trans(waveform)
38 |
39 | waveform = waveform * 2 ** 15
40 |             fbank = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=self.melbins, sample_frequency=16000, frame_length=25, frame_shift=self.frame_shift)
41 |
42 | if fbank.size(1) != self.resize_melbin_num:
43 | fbank = torch.nn.functional.interpolate(fbank.reshape(1, 1, *fbank.shape[-2:]), size=(fbank.size(0), self.resize_melbin_num), mode='bilinear').reshape(fbank.size(0), self.resize_melbin_num)
44 |
45 | # ### normalization
46 | fbank = (fbank - self.mean) / (self.std * 2)
47 | src_length = fbank.shape[0]
48 | # #### sample
49 |
50 | output_slices = []
51 | pad_len = max(self.target_length * self.sample_num -src_length, self.target_length - src_length%self.target_length)
52 | fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
53 |
54 | total_slice_num = fbank.shape[0] // self.target_length
55 | total_slice_num = list(range(total_slice_num))
56 | total_slice_num = split(total_slice_num, self.sample_num)
57 |
58 | if self.training:
59 | sample_idx = [random.choice(i) for i in total_slice_num]
60 | else:
61 | sample_idx = [i[(len(i)+1)//2-1] for i in total_slice_num]
62 |
63 |
64 | for i in sample_idx:
65 | cur_bank = fbank[i*self.target_length : (i+1)*self.target_length]
66 | output_slices.append(cur_bank)
67 |
68 |
69 |
70 | fbank = torch.stack(output_slices,dim=0) ### n, 1024, 128
71 |
72 |
73 | return fbank
74 |
75 |
76 | except Exception as e:
77 | print(e)
78 | return
79 |
80 |
81 | if __name__ == "__main__":
82 | wav_file = "./data/test.flac"
83 | proc = AudioProcessor(melbins=224, target_length=224, sample_num=4, training=True)
84 | audio_input = proc(wav_file)
85 | print(audio_input.size())
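86 |     # With these settings (melbins=224, target_length=224, sample_num=4, and the
87 |     # default resize_melbin_num=224) the printed shape should be
88 |     # torch.Size([4, 224, 224]): one (target_length, melbins) slice per sampled segment.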
--------------------------------------------------------------------------------
/model/bert-base-uncased-crossattn/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "bert-base-uncased",
3 | "add_cross_attention": true,
4 | "architectures": [
5 | "BertForMaskedLM"
6 | ],
7 | "attention_probs_dropout_prob": 0.1,
8 | "classifier_dropout": null,
9 | "gradient_checkpointing": false,
10 | "hidden_act": "gelu",
11 | "hidden_dropout_prob": 0.1,
12 | "hidden_size": 768,
13 | "initializer_range": 0.02,
14 | "intermediate_size": 3072,
15 | "is_decoder": true,
16 | "layer_norm_eps": 1e-12,
17 | "max_position_embeddings": 512,
18 | "model_type": "bert",
19 | "num_attention_heads": 12,
20 | "num_hidden_layers": 12,
21 | "pad_token_id": 0,
22 | "position_embedding_type": "absolute",
23 | "torch_dtype": "float32",
24 | "transformers_version": "4.26.1",
25 | "type_vocab_size": 2,
26 | "use_cache": true,
27 | "vocab_size": 30522
28 | }
29 |
--------------------------------------------------------------------------------
/model/bert-base-uncased-crossattn/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "pad_token_id": 0,
4 | "transformers_version": "4.26.1"
5 | }
6 |
--------------------------------------------------------------------------------
/model/clip/clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 | import ftfy
8 | import regex as re
9 | import torch
10 |
11 |
12 | @lru_cache()
13 | def default_bpe():
14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
15 |
16 |
17 | @lru_cache()
18 | def bytes_to_unicode():
19 | """
20 | Returns list of utf-8 byte and a corresponding list of unicode strings.
21 | The reversible bpe codes work on unicode strings.
22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
24 |     This is a significant percentage of your normal, say, 32K bpe vocab.
25 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
26 |     This also avoids mapping to whitespace/control characters the bpe code barfs on.
27 | """
28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
29 | cs = bs[:]
30 | n = 0
31 | for b in range(2**8):
32 | if b not in bs:
33 | bs.append(b)
34 | cs.append(2**8+n)
35 | n += 1
36 | cs = [chr(n) for n in cs]
37 | return dict(zip(bs, cs))
38 |
39 |
40 | def get_pairs(word):
41 | """Return set of symbol pairs in a word.
42 | Word is represented as tuple of symbols (symbols being variable-length strings).
43 | """
44 | pairs = set()
45 | prev_char = word[0]
46 | for char in word[1:]:
47 | pairs.add((prev_char, char))
48 | prev_char = char
49 | return pairs
50 |
51 |
52 | def basic_clean(text):
53 | text = ftfy.fix_text(text)
54 | text = html.unescape(html.unescape(text))
55 | return text.strip()
56 |
57 |
58 | def whitespace_clean(text):
59 | text = re.sub(r'\s+', ' ', text)
60 | text = text.strip()
61 | return text
62 |
63 |
64 | class SimpleTokenizer(object):
65 | def __init__(self, bpe_path: str = default_bpe()):
66 | self.byte_encoder = bytes_to_unicode()
67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
69 | merges = merges[1:49152-256-2+1]
70 | merges = [tuple(merge.split()) for merge in merges]
71 | vocab = list(bytes_to_unicode().values())
72 |         vocab = vocab + [v+'</w>' for v in vocab]
73 | for merge in merges:
74 | vocab.append(''.join(merge))
75 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
76 | self.encoder = dict(zip(vocab, range(len(vocab))))
77 | self.decoder = {v: k for k, v in self.encoder.items()}
78 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
85 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
89 |             return token+'</w>'
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | try:
100 | j = word.index(first, i)
101 | new_word.extend(word[i:j])
102 | i = j
103 | except:
104 | new_word.extend(word[i:])
105 | break
106 |
107 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
108 | new_word.append(first+second)
109 | i += 2
110 | else:
111 | new_word.append(word[i])
112 | i += 1
113 | new_word = tuple(new_word)
114 | word = new_word
115 | if len(word) == 1:
116 | break
117 | else:
118 | pairs = get_pairs(word)
119 | word = ' '.join(word)
120 | self.cache[token] = word
121 | return word
122 |
123 | def encode(self, text):
124 | bpe_tokens = []
125 | text = whitespace_clean(basic_clean(text)).lower()
126 | for token in re.findall(self.pat, text):
127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129 | return bpe_tokens
130 |
131 | def decode(self, tokens):
132 | text = ''.join([self.decoder[token] for token in tokens])
133 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 | return text
135 |
136 |
137 |
138 | _tokenizer = SimpleTokenizer()
139 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
140 | """
141 | Returns the tokenized representation of given input string(s)
142 |
143 | Parameters
144 | ----------
145 | texts : Union[str, List[str]]
146 | An input string or a list of input strings to tokenize
147 |
148 | context_length : int
149 | The context length to use; all CLIP models use 77 as the context length
150 |
151 | truncate: bool
152 | Whether to truncate the text in case its encoding is longer than the context length
153 |
154 | Returns
155 | -------
156 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
157 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
158 | """
159 | if isinstance(texts, str):
160 | texts = [texts]
161 |
162 | sot_token = _tokenizer.encoder["<|startoftext|>"]
163 | eot_token = _tokenizer.encoder["<|endoftext|>"]
164 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
165 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
166 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
167 | else:
168 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
169 |
170 | for i, tokens in enumerate(all_tokens):
171 | if len(tokens) > context_length:
172 | if truncate:
173 | tokens = tokens[:context_length]
174 | tokens[-1] = eot_token
175 | else:
176 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
177 | result[i, :len(tokens)] = torch.tensor(tokens)
178 |
179 | return result
180 |
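181 | # Illustrative example (not part of the original file). Note that constructing the
182 | # module-level SimpleTokenizer requires bpe_simple_vocab_16e6.txt.gz next to this
183 | # file (see default_bpe() above).
184 | if __name__ == '__main__':
185 |     tokens = tokenize(["a diagram", "a dog", "a cat"])
186 |     print(tokens.shape)  # torch.Size([3, 77])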
--------------------------------------------------------------------------------
/model/evaclip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | # from .loss import ClipLoss
5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7 | from .openai import load_openai_model, list_openai_models
8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10 | from .tokenizer import SimpleTokenizer, tokenize
11 | from .transform import image_transform
--------------------------------------------------------------------------------
/model/evaclip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/invictus717/MiCo/831847ff066296d185c7f304563326224b9dbc8b/model/evaclip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/model/evaclip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
/model/evaclip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | "bert": {
46 | "config_names": {
47 | "context_length": "max_position_embeddings",
48 | "vocab_size": "vocab_size",
49 | "width": "hidden_size",
50 | "heads": "num_attention_heads",
51 | "layers": "num_hidden_layers",
52 | "layer_attr": "layer",
53 | "token_embeddings_attr": "embeddings"
54 | },
55 | "pooler": "mean_pooler",
56 | }
57 | }
58 |
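59 | # Illustrative example (not part of the original file): the names above are meant
60 | # to be resolved against a Hugging Face config object, e.g.
61 | #
62 | #   from transformers import AutoConfig
63 | #   cfg = AutoConfig.from_pretrained('bert-base-uncased')
64 | #   width = getattr(cfg, arch_dict['bert']['config_names']['width'])  # hidden_size -> 768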
--------------------------------------------------------------------------------
/model/evaclip/loss.py:
--------------------------------------------------------------------------------
1 | # import math
2 | # import torch
3 | # import torch.nn as nn
4 | # from torch.nn import functional as F
5 |
6 | # try:
7 | # import torch.distributed.nn
8 | # from torch import distributed as dist
9 | # has_distributed = True
10 | # except ImportError:
11 | # has_distributed = False
12 |
13 | # try:
14 | # import horovod.torch as hvd
15 | # except ImportError:
16 | # hvd = None
17 |
18 | # from timm.loss import LabelSmoothingCrossEntropy
19 |
20 |
21 | # def gather_features(
22 | # image_features,
23 | # text_features,
24 | # local_loss=False,
25 | # gather_with_grad=False,
26 | # rank=0,
27 | # world_size=1,
28 | # use_horovod=False
29 | # ):
30 | # assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31 | # if use_horovod:
32 | # assert hvd is not None, 'Please install horovod'
33 | # if gather_with_grad:
34 | # all_image_features = hvd.allgather(image_features)
35 | # all_text_features = hvd.allgather(text_features)
36 | # else:
37 | # with torch.no_grad():
38 | # all_image_features = hvd.allgather(image_features)
39 | # all_text_features = hvd.allgather(text_features)
40 | # if not local_loss:
41 | # # ensure grads for local rank when all_* features don't have a gradient
42 | # gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43 | # gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44 | # gathered_image_features[rank] = image_features
45 | # gathered_text_features[rank] = text_features
46 | # all_image_features = torch.cat(gathered_image_features, dim=0)
47 | # all_text_features = torch.cat(gathered_text_features, dim=0)
48 | # else:
49 | # # We gather tensors from all gpus
50 | # if gather_with_grad:
51 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53 | # # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54 | # # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55 | # else:
56 | # gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57 | # gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58 | # dist.all_gather(gathered_image_features, image_features)
59 | # dist.all_gather(gathered_text_features, text_features)
60 | # if not local_loss:
61 | # # ensure grads for local rank when all_* features don't have a gradient
62 | # gathered_image_features[rank] = image_features
63 | # gathered_text_features[rank] = text_features
64 | # all_image_features = torch.cat(gathered_image_features, dim=0)
65 | # all_text_features = torch.cat(gathered_text_features, dim=0)
66 |
67 | # return all_image_features, all_text_features
68 |
69 |
70 | # class ClipLoss(nn.Module):
71 |
72 | # def __init__(
73 | # self,
74 | # local_loss=False,
75 | # gather_with_grad=False,
76 | # cache_labels=False,
77 | # rank=0,
78 | # world_size=1,
79 | # use_horovod=False,
80 | # smoothing=0.,
81 | # ):
82 | # super().__init__()
83 | # self.local_loss = local_loss
84 | # self.gather_with_grad = gather_with_grad
85 | # self.cache_labels = cache_labels
86 | # self.rank = rank
87 | # self.world_size = world_size
88 | # self.use_horovod = use_horovod
89 | # self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90 |
91 | # # cache state
92 | # self.prev_num_logits = 0
93 | # self.labels = {}
94 |
95 | # def forward(self, image_features, text_features, logit_scale=1.):
96 | # device = image_features.device
97 | # if self.world_size > 1:
98 | # all_image_features, all_text_features = gather_features(
99 | # image_features, text_features,
100 | # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101 |
102 | # if self.local_loss:
103 | # logits_per_image = logit_scale * image_features @ all_text_features.T
104 | # logits_per_text = logit_scale * text_features @ all_image_features.T
105 | # else:
106 | # logits_per_image = logit_scale * all_image_features @ all_text_features.T
107 | # logits_per_text = logits_per_image.T
108 | # else:
109 | # logits_per_image = logit_scale * image_features @ text_features.T
110 | # logits_per_text = logit_scale * text_features @ image_features.T
111 | # # calculated ground-truth and cache if enabled
112 | # num_logits = logits_per_image.shape[0]
113 | # if self.prev_num_logits != num_logits or device not in self.labels:
114 | # labels = torch.arange(num_logits, device=device, dtype=torch.long)
115 | # if self.world_size > 1 and self.local_loss:
116 | # labels = labels + num_logits * self.rank
117 | # if self.cache_labels:
118 | # self.labels[device] = labels
119 | # self.prev_num_logits = num_logits
120 | # else:
121 | # labels = self.labels[device]
122 |
123 | # if self.label_smoothing_cross_entropy:
124 | # total_loss = (
125 | # self.label_smoothing_cross_entropy(logits_per_image, labels) +
126 | # self.label_smoothing_cross_entropy(logits_per_text, labels)
127 | # ) / 2
128 | # else:
129 | # total_loss = (
130 | # F.cross_entropy(logits_per_image, labels) +
131 | # F.cross_entropy(logits_per_text, labels)
132 | # ) / 2
133 |
134 | # acc = None
135 | # i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136 | # t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137 | # acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138 | # return total_loss, acc
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/model/evaclip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/model/evaclip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
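183 | # Illustrative instantiation (not part of the original file); kept as a comment
184 | # because of the relative `.utils` import above. A ResNet-50-like CLIP visual
185 | # tower with the default width=64 (embed_dim = 64 * 32 = 2048):
186 | #
187 | #   model = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=1024, heads=32, image_size=224)
188 | #   out = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 1024)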
--------------------------------------------------------------------------------
/model/evaclip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 | Returns
46 | -------
47 | model : torch.nn.Module
48 | The CLIP model
49 | preprocess : Callable[[PIL.Image], torch.Tensor]
50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51 | """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 |
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith('amp') or precision == 'fp32':
87 | model.float()
88 | elif precision == 'bf16':
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == 'fp32':
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
--------------------------------------------------------------------------------
/model/evaclip/rope.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, repeat
5 | import logging
6 |
7 | def broadcat(tensors, dim = -1):
8 | num_tensors = len(tensors)
9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11 | shape_len = list(shape_lens)[0]
12 | dim = (dim + shape_len) if dim < 0 else dim
13 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18 | expanded_dims.insert(dim, (dim, dims[dim]))
19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21 | return torch.cat(tensors, dim = dim)
22 |
23 | def rotate_half(x):
24 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
25 | x1, x2 = x.unbind(dim = -1)
26 | x = torch.stack((-x2, x1), dim = -1)
27 | return rearrange(x, '... d r -> ... (d r)')
28 |
29 |
30 | class VisionRotaryEmbedding(nn.Module):
31 | def __init__(
32 | self,
33 | dim,
34 | pt_seq_len,
35 | ft_seq_len=None,
36 | custom_freqs = None,
37 | freqs_for = 'lang',
38 | theta = 10000,
39 | max_freq = 10,
40 | num_freqs = 1,
41 | ):
42 | super().__init__()
43 | if custom_freqs:
44 | freqs = custom_freqs
45 | elif freqs_for == 'lang':
46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47 | elif freqs_for == 'pixel':
48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49 | elif freqs_for == 'constant':
50 | freqs = torch.ones(num_freqs).float()
51 | else:
52 | raise ValueError(f'unknown modality {freqs_for}')
53 |
54 | if ft_seq_len is None: ft_seq_len = pt_seq_len
55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56 |
57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59 |
60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62 |
63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64 |
65 | self.register_buffer("freqs_cos", freqs.cos())
66 | self.register_buffer("freqs_sin", freqs.sin())
67 |
68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69 |
70 | def forward(self, t, start_index = 0):
71 | rot_dim = self.freqs_cos.shape[-1]
72 | end_index = start_index + rot_dim
73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76 |
77 | return torch.cat((t_left, t, t_right), dim = -1)
78 |
79 | class VisionRotaryEmbeddingFast(nn.Module):
80 | def __init__(
81 | self,
82 | dim,
83 | pt_seq_len,
84 | ft_seq_len=None,
85 | custom_freqs = None,
86 | freqs_for = 'lang',
87 | theta = 10000,
88 | max_freq = 10,
89 | num_freqs = 1,
90 | patch_dropout = 0.
91 | ):
92 | super().__init__()
93 | if custom_freqs:
94 | freqs = custom_freqs
95 | elif freqs_for == 'lang':
96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97 | elif freqs_for == 'pixel':
98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99 | elif freqs_for == 'constant':
100 | freqs = torch.ones(num_freqs).float()
101 | else:
102 | raise ValueError(f'unknown modality {freqs_for}')
103 |
104 | if ft_seq_len is None: ft_seq_len = pt_seq_len
105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106 |
107 | freqs = torch.einsum('..., f -> ... f', t, freqs)
108 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110 |
111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113 |
114 | self.patch_dropout = patch_dropout
115 |
116 | self.register_buffer("freqs_cos", freqs_cos)
117 | self.register_buffer("freqs_sin", freqs_sin)
118 |
119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120 |
121 | def forward(self, t, patch_indices_keep=None):
122 | if patch_indices_keep is not None:
123 | batch = t.size()[0]
124 | batch_indices = torch.arange(batch)
125 | batch_indices = batch_indices[..., None]
126 |
127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129 |
130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134 |
135 | return t * freqs_cos + rotate_half(t) * freqs_sin
136 |
137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
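138 |
139 | if __name__ == "__main__":
140 |     # Minimal smoke test (a sketch added for illustration, not part of the upstream
141 |     # EVA-CLIP file). The shapes are assumptions: a 16x16 patch grid and a head
142 |     # dimension of 64; dim=32 rotates 64 channels in total because the height and
143 |     # width frequency bands are concatenated.
144 |     rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)
145 |     q = torch.randn(2, 12, 16 * 16, 64)   # (batch, heads, num_patches, head_dim)
146 |     print(rope(q).shape)                  # torch.Size([2, 12, 256, 64])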
--------------------------------------------------------------------------------
/model/evaclip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
31 | """
32 |
33 | def __init__(
34 | self,
35 | model_name,
36 | embed_dim,
37 | image_size=224,
38 | pool='avg',
39 | proj='linear',
40 | proj_bias=False,
41 | drop=0.,
42 | pretrained=False):
43 | super().__init__()
44 | if timm is None:
45 | raise RuntimeError("Please `pip install timm` to use timm models.")
46 |
47 | self.image_size = to_2tuple(image_size)
48 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
49 | feat_size = self.trunk.default_cfg.get('pool_size', None)
50 | feature_ndim = 1 if not feat_size else 2
51 | if pool in ('abs_attn', 'rot_attn'):
52 | assert feature_ndim == 2
53 | # if attn pooling used, remove both classifier and default pool
54 | self.trunk.reset_classifier(0, global_pool='')
55 | else:
56 | # reset global pool if pool config set, otherwise leave as network default
57 | reset_kwargs = dict(global_pool=pool) if pool else {}
58 | self.trunk.reset_classifier(0, **reset_kwargs)
59 | prev_chs = self.trunk.num_features
60 |
61 | head_layers = OrderedDict()
62 | if pool == 'abs_attn':
63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64 | prev_chs = embed_dim
65 | elif pool == 'rot_attn':
66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67 | prev_chs = embed_dim
68 | else:
69 | assert proj, 'projection layer needed if non-attention pooling is used.'
70 |
71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72 | if proj == 'linear':
73 | head_layers['drop'] = nn.Dropout(drop)
74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75 | elif proj == 'mlp':
76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77 |
78 | self.head = nn.Sequential(head_layers)
79 |
80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81 | """ lock modules
82 | Args:
83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84 | """
85 | if not unlocked_groups:
86 | # lock full model
87 | for param in self.trunk.parameters():
88 | param.requires_grad = False
89 | if freeze_bn_stats:
90 | freeze_batch_norm_2d(self.trunk)
91 | else:
92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93 | try:
94 | # FIXME import here until API stable and in an official release
95 | from timm.models.helpers import group_parameters, group_modules
96 | except ImportError:
97 | raise RuntimeError(
98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99 | matcher = self.trunk.group_matcher()
100 | gparams = group_parameters(self.trunk, matcher)
101 | max_layer_id = max(gparams.keys())
102 | max_layer_id = max_layer_id - unlocked_groups
103 | for group_idx in range(max_layer_id + 1):
104 | group = gparams[group_idx]
105 | for param in group:
106 | self.trunk.get_parameter(param).requires_grad = False
107 | if freeze_bn_stats:
108 | gmodules = group_modules(self.trunk, matcher, reverse=True)
109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110 | freeze_batch_norm_2d(self.trunk, gmodules)
111 |
112 | @torch.jit.ignore
113 | def set_grad_checkpointing(self, enable=True):
114 | try:
115 | self.trunk.set_grad_checkpointing(enable)
116 | except Exception as e:
117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118 |
119 | def forward(self, x):
120 | x = self.trunk(x)
121 | x = self.head(x)
122 | return x
123 |
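124 | if __name__ == "__main__":
125 |     # Usage sketch (added for illustration, not part of the upstream file). It needs
126 |     # `pip install timm`; the backbone name 'resnet50' and embed_dim=512 are purely
127 |     # illustrative choices, not values used elsewhere in this repo.
128 |     tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
129 |     feats = tower(torch.randn(1, 3, 224, 224))
130 |     print(feats.shape)  # torch.Size([1, 512])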
--------------------------------------------------------------------------------
/model/evaclip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | # import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 | Returns list of utf-8 byte and a corresponding list of unicode strings.
29 | The reversible bpe codes work on unicode strings.
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 | This is a significant percentage of your normal, say, 32K bpe vocab.
33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 | And avoids mapping to whitespace/control characters the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | # ftfy.fix_text(text) is skipped here: ftfy is not imported above, so only HTML entities are unescaped
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 |
66 | def whitespace_clean(text):
67 | text = re.sub(r'\s+', ' ', text)
68 | text = text.strip()
69 | return text
70 |
71 |
72 | class SimpleTokenizer(object):
73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 | self.byte_encoder = bytes_to_unicode()
75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 | merges = merges[1:49152-256-2+1]
78 | merges = [tuple(merge.split()) for merge in merges]
79 | vocab = list(bytes_to_unicode().values())
80 | vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 | if not special_tokens:
84 | special_tokens = ['<start_of_text>', '<end_of_text>']
85 | else:
86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 | vocab.extend(special_tokens)
88 | self.encoder = dict(zip(vocab, range(len(vocab))))
89 | self.decoder = {v: k for k, v in self.encoder.items()}
90 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 | self.cache = {t:t for t in special_tokens}
92 | special = "|".join(special_tokens)
93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 |
95 | self.vocab_size = len(self.encoder)
96 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 |
98 | def bpe(self, token):
99 | if token in self.cache:
100 | return self.cache[token]
101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 | pairs = get_pairs(word)
103 |
104 | if not pairs:
105 | return token+'</w>'
106 |
107 | while True:
108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 | if bigram not in self.bpe_ranks:
110 | break
111 | first, second = bigram
112 | new_word = []
113 | i = 0
114 | while i < len(word):
115 | try:
116 | j = word.index(first, i)
117 | new_word.extend(word[i:j])
118 | i = j
119 | except ValueError:
120 | new_word.extend(word[i:])
121 | break
122 |
123 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 | new_word.append(first+second)
125 | i += 2
126 | else:
127 | new_word.append(word[i])
128 | i += 1
129 | new_word = tuple(new_word)
130 | word = new_word
131 | if len(word) == 1:
132 | break
133 | else:
134 | pairs = get_pairs(word)
135 | word = ' '.join(word)
136 | self.cache[token] = word
137 | return word
138 |
139 | def encode(self, text):
140 | bpe_tokens = []
141 | text = whitespace_clean(basic_clean(text)).lower()
142 | for token in re.findall(self.pat, text):
143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 | return bpe_tokens
146 |
147 | def decode(self, tokens):
148 | text = ''.join([self.decoder[token] for token in tokens])
149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 | return text
151 |
152 |
153 | _tokenizer = SimpleTokenizer()
154 |
155 |
156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157 | """
158 | Returns the tokenized representation of given input string(s)
159 |
160 | Parameters
161 | ----------
162 | texts : Union[str, List[str]]
163 | An input string or a list of input strings to tokenize
164 | context_length : int
165 | The context length to use; all CLIP models use 77 as the context length
166 |
167 | Returns
168 | -------
169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170 | """
171 | if isinstance(texts, str):
172 | texts = [texts]
173 |
174 | sot_token = _tokenizer.encoder["<start_of_text>"]
175 | eot_token = _tokenizer.encoder["<end_of_text>"]
176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178 |
179 | for i, tokens in enumerate(all_tokens):
180 | if len(tokens) > context_length:
181 | tokens = tokens[:context_length] # Truncate
182 | tokens[-1] = eot_token
183 | result[i, :len(tokens)] = torch.tensor(tokens)
184 |
185 | return result
186 |
187 |
188 | class HFTokenizer:
189 | "HuggingFace tokenizer wrapper"
190 | def __init__(self, tokenizer_name:str):
191 | from transformers import AutoTokenizer
192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193 |
194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195 | # same cleaning as for default tokenizer, except lowercasing
196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197 | if isinstance(texts, str):
198 | texts = [texts]
199 | texts = [whitespace_clean(basic_clean(text)) for text in texts]
200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201 | return input_ids
202 |
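203 | if __name__ == "__main__":
204 |     # Usage sketch (added for illustration, not part of the upstream file):
205 |     # tokenize() pads or truncates every caption to the standard CLIP context length of 77.
206 |     ids = tokenize(["a diagram", "a photo of a cat"])
207 |     print(ids.shape)                                    # torch.Size([2, 77])
208 |     print(_tokenizer.decode(ids[1][ids[1] != 0].tolist()))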
--------------------------------------------------------------------------------
/model/evaclip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 | CenterCrop
9 |
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 |
12 |
13 | class ResizeMaxSize(nn.Module):
14 |
15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16 | super().__init__()
17 | if not isinstance(max_size, int):
18 | raise TypeError(f"Size should be int. Got {type(max_size)}")
19 | self.max_size = max_size
20 | self.interpolation = interpolation
21 | self.fn = min if fn == 'min' else max
22 | self.fill = fill
23 |
24 | def forward(self, img):
25 | if isinstance(img, torch.Tensor):
26 | height, width = img.shape[:2]
27 | else:
28 | width, height = img.size
29 | scale = self.max_size / float(max(height, width))
30 | if scale != 1.0:
31 | new_size = tuple(round(dim * scale) for dim in (height, width))
32 | img = F.resize(img, new_size, self.interpolation)
33 | pad_h = self.max_size - new_size[0]
34 | pad_w = self.max_size - new_size[1]
35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36 | return img
37 |
38 |
39 | def _convert_to_rgb(image):
40 | return image.convert('RGB')
41 |
42 |
43 | # class CatGen(nn.Module):
44 | # def __init__(self, num=4):
45 | # self.num = num
46 | # def mixgen_batch(image, text):
47 | # batch_size = image.shape[0]
48 | # index = np.random.permutation(batch_size)
49 |
50 | # cat_images = []
51 | # for i in range(batch_size):
52 | # # image mixup
53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54 | # # text concat
55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56 | # text = torch.stack(text)
57 | # return image, text
58 |
59 |
60 | def image_transform(
61 | image_size: int,
62 | is_train: bool,
63 | mean: Optional[Tuple[float, ...]] = None,
64 | std: Optional[Tuple[float, ...]] = None,
65 | resize_longest_max: bool = False,
66 | fill_color: int = 0,
67 | ):
68 | mean = mean or OPENAI_DATASET_MEAN
69 | if not isinstance(mean, (list, tuple)):
70 | mean = (mean,) * 3
71 |
72 | std = std or OPENAI_DATASET_STD
73 | if not isinstance(std, (list, tuple)):
74 | std = (std,) * 3
75 |
76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78 | image_size = image_size[0]
79 |
80 | normalize = Normalize(mean=mean, std=std)
81 | if is_train:
82 | return Compose([
83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84 | _convert_to_rgb,
85 | ToTensor(),
86 | normalize,
87 | ])
88 | else:
89 | if resize_longest_max:
90 | transforms = [
91 | ResizeMaxSize(image_size, fill=fill_color)
92 | ]
93 | else:
94 | transforms = [
95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96 | CenterCrop(image_size),
97 | ]
98 | transforms.extend([
99 | _convert_to_rgb,
100 | ToTensor(),
101 | normalize,
102 | ])
103 | return Compose(transforms)
104 |
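105 | if __name__ == "__main__":
106 |     # Usage sketch (added for illustration, not part of the upstream file): build the
107 |     # evaluation-time preprocessing pipeline and run it on a synthetic PIL image.
108 |     from PIL import Image
109 |     preprocess = image_transform(image_size=224, is_train=False)
110 |     dummy = Image.new('RGB', (320, 256), color=(127, 127, 127))
111 |     print(preprocess(dummy).shape)  # torch.Size([3, 224, 224])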
--------------------------------------------------------------------------------
/model/imageprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from PIL import Image
4 |
5 | import torch
6 | from torchvision.transforms.transforms import *
7 | from torchvision import transforms
8 |
9 |
10 | class ImageProcessor(object):
11 | def __init__(self, image_resolution, image_encoder_type, image_transforms='none', training=True):
12 | self.training = training
13 |
14 | self.resolution = image_resolution
15 | self.image_encoder_type = image_encoder_type
16 |
17 | if image_encoder_type.startswith('clip') or image_encoder_type.startswith('evaclip'):
18 | self.mean = [0.48145466, 0.4578275, 0.40821073]
19 | self.std = [0.26862954, 0.26130258, 0.27577711]
20 | else:
21 | self.mean = [0.485, 0.456, 0.406]
22 | self.std = [0.229, 0.224, 0.225]
23 |
24 | self.image_transforms = image_transforms
25 | if image_transforms == 'none':
26 | self.train_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
27 | Normalize(self.mean,self.std)])
28 |
29 | self.test_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
30 | Normalize(self.mean,self.std)])
31 | elif image_transforms == 'crop_flip':
32 | self.train_transforms = transforms.Compose([RandomResizedCrop(self.resolution, [0.8,1.0],[1.0,1.0]),
33 | RandomHorizontalFlip(),
34 | Normalize(self.mean,self.std)])
35 |
36 | self.test_transforms = transforms.Compose([Resize(self.resolution),
37 | CenterCrop(self.resolution),
38 | Normalize(self.mean,self.std)])
39 | else:
40 | raise NotImplementedError
41 |
42 | def __call__(self, image_file):
43 |
44 | try:
45 | img_path = image_file
46 | if not os.path.exists(img_path):
47 | print('image file not found:', image_file)
48 | return None
49 | img = Image.open(img_path)
50 | img = img.convert('RGB') #### convert 1-channel gray image and 4-channel CMYK image to RGB image
51 | img = transforms.ToTensor()(img)
52 | if self.training:
53 | img = self.train_transforms(img)
54 | else:
55 | img = self.test_transforms(img)
56 |
57 | img = img.unsqueeze(0)
58 |
59 | return img
60 |
61 | except Exception as e:
62 | print(e)
63 | return None
64 |
65 | if __name__ == "__main__":
66 | image_file = "./data/test.jpeg"
67 | proc = ImageProcessor(image_resolution=224, image_encoder_type="swin", training=True)
68 | image_input = proc(image_file)
69 | print(image_input.size())
--------------------------------------------------------------------------------
/model/swin_base_patch4_window7_224_22k.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | DATASET: imagenet22K
3 | MODEL:
4 | TYPE: swin
5 | NAME: swin_base_patch4_window7_224_22k
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 128
9 | DEPTHS: [ 2, 2, 18, 2 ]
10 | NUM_HEADS: [ 4, 8, 16, 32 ]
11 | WINDOW_SIZE: 7
12 | TRAIN:
13 | EPOCHS: 90
14 | WARMUP_EPOCHS: 5
15 | WEIGHT_DECAY: 0.05
16 | BASE_LR: 1.25e-4 # 4096 batch-size
17 | WARMUP_LR: 1.25e-7
18 | MIN_LR: 1.25e-6
--------------------------------------------------------------------------------
/model/tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "mask_token": "[MASK]",
4 | "pad_token": "[PAD]",
5 | "sep_token": "[SEP]",
6 | "unk_token": "[UNK]"
7 | }
8 |
--------------------------------------------------------------------------------
/model/tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "do_basic_tokenize": true,
4 | "do_lower_case": true,
5 | "mask_token": "[MASK]",
6 | "model_max_length": 512,
7 | "name_or_path": "bert-base-uncased",
8 | "never_split": null,
9 | "pad_token": "[PAD]",
10 | "sep_token": "[SEP]",
11 | "special_tokens_map_file": null,
12 | "strip_accents": null,
13 | "tokenize_chinese_chars": true,
14 | "tokenizer_class": "BertTokenizer",
15 | "unk_token": "[UNK]"
16 | }
17 |
--------------------------------------------------------------------------------
/model/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | BERT layers from the huggingface implementation
3 | (https://github.com/huggingface/transformers)
4 | """
5 | # coding=utf-8
6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | import logging
21 | import math
22 | import copy
23 | import torch
24 | from torch import nn
25 | # from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
26 | from torch.nn import LayerNorm
27 | import torch.nn.functional as F
28 | import ipdb
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | def gelu(x):
34 | """Implementation of the gelu activation function.
35 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
36 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
37 | Also see https://arxiv.org/abs/1606.08415
38 | """
39 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
40 |
41 |
42 | def swish(x):
43 | return x * torch.sigmoid(x)
44 |
45 |
46 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
47 |
48 |
49 | class GELU(nn.Module):
50 | def forward(self, input_):
51 | output = gelu(input_)
52 | return output
53 |
54 |
55 |
56 |
57 | class TransformerLayer(nn.Module):
58 | def __init__(self, config, mode):
59 | super().__init__()
60 | self.attention = MultiHeadAttention(config)
61 | self.ff_layer = FeedForward(config)
62 | self.dropout = nn.Dropout(config.hidden_dropout)
63 | self.layernorm1 = LayerNorm(config.hidden_size, eps=1e-12)
64 | self.layernorm2 = LayerNorm(config.hidden_size, eps=1e-12)
65 | self.mode = mode
66 |
67 | def forward(self, hidden_states, attention_mask):
68 | if self.mode == 'prenorm':
69 | return self.forward_prenorm(hidden_states, attention_mask)
70 | elif self.mode == 'postnorm':
71 | return self.forward_postnorm(hidden_states, attention_mask)
72 | else:
73 | raise NotImplementedError
74 |
75 | def forward_prenorm(self, hidden_states, attention_mask):
76 | residual = hidden_states
77 | hidden_states = self.layernorm1(hidden_states)
78 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask)
79 | hidden_states = residual + self.dropout(attention_output)
80 |
81 | residual = hidden_states
82 | hidden_states = self.layernorm2(hidden_states)
83 | ff_output = self.ff_layer(hidden_states)
84 | hidden_states = residual + self.dropout(ff_output)
85 |
86 | return hidden_states
87 |
88 | def forward_postnorm(self, hidden_states, attention_mask):
89 | residual = hidden_states
90 | attention_output = self.attention(hidden_states, hidden_states, hidden_states, attention_mask)
91 | hidden_states = residual + self.dropout(attention_output)
92 | hidden_states = self.layernorm1(hidden_states)
93 |
94 | residual = hidden_states
95 | ff_output = self.ff_layer(hidden_states)
96 | hidden_states = residual + self.dropout(ff_output)
97 | hidden_states = self.layernorm2(hidden_states)
98 |
99 | return hidden_states
100 |
101 |
102 | def clones(x,times):
103 | return nn.ModuleList([copy.deepcopy(x) for i in range(times)])
104 |
105 |
106 |
107 | class MultiHeadAttention(nn.Module):
108 | def __init__(self, config):
109 | super().__init__()
110 | self.linears = clones(nn.Linear(config.hidden_size, config.hidden_size), 4)
111 | self.head_num = config.num_attention_heads
112 | self.hidden_size = config.hidden_size
113 | self.dropout=nn.Dropout(config.attention_dropout)
114 |
115 |
116 | def forward(self,q,k,v,mask=None):
117 | batch_size=q.shape[0]
118 | q,k,v=[layer(x).view(batch_size,-1,self.head_num, self.hidden_size//self.head_num).transpose(1,2) \
119 | for layer,x in zip(self.linears,(q,k,v))]
120 | norm_d=q.shape[-1]
121 | att_map=torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(norm_d)
122 | if mask is not None:
123 | att_map=att_map + mask
124 | att_map=F.softmax(att_map,dim=-1)
125 | # import ipdb
126 | # if att_map.shape[-1] == 45:
127 | # ipdb.set_trace()
128 |
129 | att_map=self.dropout(att_map)
130 | attn_output = self.linears[-1](torch.matmul(att_map,v).transpose(1,2).contiguous().view(batch_size,-1,self.hidden_size))
131 | return attn_output
132 |
133 |
134 | class FeedForward(nn.Module):
135 | def __init__(self, config):
136 | super().__init__()
137 | self.linear1=nn.Linear(config.hidden_size, config.intermediate_size)
138 | self.linear2=nn.Linear(config.intermediate_size, config.hidden_size)
139 | self.activation = GELU()
140 |
141 |
142 | def forward(self,x):
143 | return self.linear2((self.activation(self.linear1(x))))
144 |
145 |
146 |
147 | class TransformerEncoder(nn.Module):
148 | def __init__(self, config, mode = 'prenorm'):
149 | super().__init__()
150 | layer = TransformerLayer(config, mode)
151 | self.mode = mode
152 | self.layer = nn.ModuleList([copy.deepcopy(layer)
153 | for _ in range(config.num_hidden_layers)])
154 | if self.mode == 'prenorm':
155 | self.last_layernorm = LayerNorm(config.hidden_size, eps=1e-12)
156 | self.checkpointing = config.checkpointing
157 | def forward(self, input_, attention_mask=None, cross_hidden_states=None,
158 | use_cache=False,
159 | cache=None,
160 | cache_first=False,
161 | cache_type=None):
162 | hidden_states = input_
163 | for layer_module in self.layer:
164 | if self.checkpointing:
165 | hidden_states = torch.utils.checkpoint.checkpoint(layer_module, hidden_states, attention_mask)
166 | else:
167 | hidden_states = layer_module(hidden_states, attention_mask)
168 |
169 | if self.mode == 'prenorm':
170 | hidden_states = self.last_layernorm(hidden_states)
171 | return hidden_states, cache
172 |
173 |
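174 | if __name__ == "__main__":
175 |     # Smoke test (added for illustration, not part of the original file). The config
176 |     # fields below are assumptions that simply mirror the attributes accessed above.
177 |     from types import SimpleNamespace
178 |     cfg = SimpleNamespace(hidden_size=768, num_attention_heads=12, intermediate_size=3072,
179 |                           hidden_dropout=0.1, attention_dropout=0.1,
180 |                           num_hidden_layers=2, checkpointing=False)
181 |     encoder = TransformerEncoder(cfg, mode='prenorm')
182 |     hidden, _ = encoder(torch.randn(2, 16, 768))
183 |     print(hidden.shape)  # torch.Size([2, 16, 768])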
--------------------------------------------------------------------------------
/model/videoprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from decord import VideoReader
4 | from PIL import Image
5 |
6 | import torch
7 | from torchvision.transforms.transforms import *
8 | from torchvision import transforms
9 |
10 |
11 | def split(frame_name_lists, sample_num):
12 | if len(frame_name_lists) < sample_num: ###padding with the last frame
13 | frame_name_lists += [frame_name_lists[-1]]*(sample_num - len(frame_name_lists))
14 | k, m = divmod(len(frame_name_lists), sample_num)
15 | return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(sample_num))]
16 |
17 | class VideoProcessor(object):
18 | def __init__(self, video_resolution, video_encoder_type, sample_num = 4, video_transforms='none', data_format="frame", training=True):
19 | self.frame_syncaug = True
20 | self.training = training
21 | self.sample_num = sample_num
22 | self.data_format = data_format
23 |
24 | self.resolution = video_resolution
25 | self.video_encoder_type = video_encoder_type
26 |
27 | if video_encoder_type.startswith('clip') or video_encoder_type.startswith('evaclip'):
28 | self.mean = [0.48145466, 0.4578275, 0.40821073]
29 | self.std = [0.26862954, 0.26130258, 0.27577711]
30 | else:
31 | self.mean = [0.485, 0.456, 0.406]
32 | self.std = [0.229, 0.224, 0.225]
33 |
34 | self.video_transforms = video_transforms
35 | if video_transforms == 'none':
36 | self.train_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
37 | Normalize(self.mean,self.std)])
38 |
39 | self.test_transforms = transforms.Compose([Resize((self.resolution,self.resolution)),
40 | Normalize(self.mean,self.std)])
41 | elif video_transforms == 'crop_flip':
42 | self.train_transforms = transforms.Compose([RandomResizedCrop(self.resolution, [0.8,1.0],[1.0,1.0]),
43 | RandomHorizontalFlip(),
44 | Normalize(self.mean,self.std)])
45 |
46 | self.test_transforms = transforms.Compose([Resize(self.resolution),
47 | CenterCrop(self.resolution),
48 | Normalize(self.mean,self.std)])
49 | else:
50 | raise NotImplementedError
51 |
52 | def __call__(self, video_file):
53 |
54 | video_pixels = []
55 | sample_num = self.sample_num
56 | try:
57 | if self.data_format == 'frame':
58 | frame_path = video_file
59 | if not os.path.exists(frame_path):
60 | print('frame directory not found:', video_file)
61 | return None
62 | frames = os.listdir(frame_path)
63 | frames.sort() ### ['img_0001.jpg','img_0002.jpg',...]
64 | sample_num = self.sample_num
65 | frames_splited = split(frames,sample_num)
66 | if self.training:
67 | sample_idx = [random.choice(i) for i in frames_splited]
68 | else:
69 | sample_idx = [i[(len(i)+1)//2-1] for i in frames_splited]
70 | for i in range(sample_num):
71 | frame = Image.open(os.path.join(frame_path,sample_idx[i]))
72 | frame = transforms.ToTensor()(frame) ## frame: 3XhXw
73 | video_pixels.append(frame.unsqueeze(0))
74 |
75 | elif self.data_format == 'raw':
76 | video_path = video_file
77 | if not os.path.exists(video_path):
78 | print('video file not found:', video_file)
79 | return None
80 | container = VideoReader(uri=video_path)  # use the VideoReader imported from decord above
81 | frames_ids = list(range(len(container)))
82 |
83 | frames_splited = split(frames_ids, sample_num)
84 | if self.training:
85 | sample_idx = [random.choice(i) for i in frames_splited]
86 | else:
87 | sample_idx = [i[(len(i)+1)//2-1] for i in frames_splited]
88 |
89 | frames = container.get_batch(sample_idx).asnumpy()
90 | # print(len(frames), type(frames),sample_idx)
91 |
92 | for i in frames:
93 | frame = Image.fromarray(i)
94 | frame = transforms.ToTensor()(frame) ## frame: 3XhXw
95 | video_pixels.append(frame.unsqueeze(0))
96 |
97 |
98 | video_pixels = torch.cat(video_pixels,dim=0) ### nX3xHxW
99 | if self.training:
100 | video_pixels = self.train_transforms(video_pixels)
101 | else:
102 | video_pixels = self.test_transforms(video_pixels)
103 | return video_pixels
104 |
105 | except Exception as e:
106 | print(e)
107 | print(video_file)
108 | return None
109 |
110 | if __name__ == "__main__":
111 | video_file = "./data/test.mp4"
112 | proc = VideoProcessor(video_resolution=224, video_encoder_type="swin", sample_num=4, data_format="raw", training=True)
113 | video_input = proc(video_file)
114 | print(video_input.size())
--------------------------------------------------------------------------------
/set_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.0.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
2 | pip install torchvision==0.15.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
3 | pip install torchaudio==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
4 | pip install decord -i https://pypi.tuna.tsinghua.edu.cn/simple
5 | pip install h5py -i https://pypi.tuna.tsinghua.edu.cn/simple
6 | pip install ffmpeg-python -i https://pypi.tuna.tsinghua.edu.cn/simple
7 | pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple
8 | pip install toolz -i https://pypi.tuna.tsinghua.edu.cn/simple
9 | pip install ipdb -i https://pypi.tuna.tsinghua.edu.cn/simple
10 | pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple
11 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple
12 | pip install transformers==4.31.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
13 | pip install webdataset -i https://pypi.tuna.tsinghua.edu.cn/simple
14 | pip install SentencePiece -i https://pypi.tuna.tsinghua.edu.cn/simple
--------------------------------------------------------------------------------