├── data └── .gitkeep ├── pretrained └── .gitkeep ├── assets └── pipeline.png ├── datasets ├── files │ ├── NCars │ │ ├── 10shot-repeat=False.pkl │ │ ├── 20shot-repeat=False.pkl │ │ ├── 30shot-repeat=False.pkl │ │ ├── 50shot-repeat=False.pkl │ │ ├── 100shot-repeat=False.pkl │ │ ├── 200shot-repeat=False.pkl │ │ └── 500shot-repeat=False.pkl │ ├── NCaltech101 │ │ ├── 10shot-repeat=True.pkl │ │ ├── 1shot-repeat=False.pkl │ │ ├── 1shot-repeat=True.pkl │ │ ├── 20shot-repeat=True.pkl │ │ ├── 2shot-repeat=False.pkl │ │ ├── 2shot-repeat=True.pkl │ │ ├── 3shot-repeat=False.pkl │ │ ├── 3shot-repeat=True.pkl │ │ ├── 50shot-repeat=True.pkl │ │ ├── 5shot-repeat=False.pkl │ │ ├── 5shot-repeat=True.pkl │ │ ├── 10shot-repeat=False.pkl │ │ ├── 20shot-repeat=False.pkl │ │ └── 50shot-repeat=False.pkl │ ├── NImageNet │ │ ├── 10shot-repeat=False.pkl │ │ ├── 1shot-repeat=False.pkl │ │ ├── 20shot-repeat=False.pkl │ │ ├── 2shot-repeat=False.pkl │ │ ├── 3shot-repeat=False.pkl │ │ ├── 50shot-repeat=False.pkl │ │ └── 5shot-repeat=False.pkl │ └── NImageNetMini │ │ ├── 10shot-repeat=True.pkl │ │ ├── 1shot-repeat=True.pkl │ │ ├── 20shot-repeat=True.pkl │ │ ├── 2shot-repeat=True.pkl │ │ ├── 3shot-repeat=True.pkl │ │ ├── 50shot-repeat=True.pkl │ │ └── 5shot-repeat=True.pkl ├── __init__.py ├── cars.py ├── utils.py ├── imagenet.py ├── vis.py ├── imagenet_mini.py ├── event2img.py ├── augment.py └── caltech.py ├── scripts ├── test_all_subset.sh ├── train_all_shots.sh ├── test_all_arch.sh ├── dup_run_sbatch.sh ├── resubmit_failed_job.sh └── sbatch_run.sh ├── models ├── __init__.py ├── adapter.py ├── clip_cls_ft.py ├── clip_cls.py └── lora.py ├── configs ├── zsclip │ ├── zsclip_ncars_params.py │ ├── zsclip_nin_params.py │ ├── zsclip_ncaltech_params.py │ └── zsclip_nin_mini_params-vitb32.py ├── fsclip │ ├── joint_adapter │ │ ├── joint_fsclip_nin_params.py │ │ ├── joint_fsclip_ncaltech_params.py │ │ ├── joint_fsclip_ncars_params.py │ │ └── joint_fsclip_nin_mini_params-vitb32.py │ └── text_adapter │ │ ├── text_fsclip_nin_params.py │ │ ├── text_fsclip_ncaltech_params.py │ │ └── text_fsclip_ncars_params.py └── ftclip │ ├── ft_text_fsclip_nin_params.py │ ├── ft_text_fsclip_nin_params-lora16.py │ ├── ft_text_fsclip_ncaltech_params-vitb16.py │ └── ft_text_fsclip_nin_params-vitb16.py ├── LICENSE ├── docs ├── install.md ├── data.md └── benchmark.md ├── .gitignore ├── README.md ├── train.py ├── environment.yml ├── method.py ├── test.py └── gen_data.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /datasets/files/NCars/10shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/10shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/20shot-repeat=False.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/20shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/30shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/30shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/50shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/50shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/100shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/100shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/200shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/200shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCars/500shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/500shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/10shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/10shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/1shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/1shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/1shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/1shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/20shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/20shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/2shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/2shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/2shot-repeat=True.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/2shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/3shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/3shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/3shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/3shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/50shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/50shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/5shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/5shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/5shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/5shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/10shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/10shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/1shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/1shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/20shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/20shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/2shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/2shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/3shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/3shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/50shot-repeat=False.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/50shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNet/5shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/5shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/10shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/10shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/20shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/20shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NCaltech101/50shot-repeat=False.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/50shot-repeat=False.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/10shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/10shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/1shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/1shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/20shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/20shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/2shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/2shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/3shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/3shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/50shot-repeat=True.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/50shot-repeat=True.pkl -------------------------------------------------------------------------------- /datasets/files/NImageNetMini/5shot-repeat=True.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/5shot-repeat=True.pkl -------------------------------------------------------------------------------- /scripts/test_all_subset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run `test.py` with different N-IN robustness variants 4 | 5 | CMD=$1 6 | 7 | for subset in -1 1 2 3 4 5 6 7 8 9 8 | do 9 | cmd="$CMD --subset $subset" 10 | echo $cmd 11 | eval $cmd 12 | done 13 | -------------------------------------------------------------------------------- /scripts/train_all_shots.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run `train.py` with different number of shots 4 | CMD=$1 5 | 6 | shot1=${2:-20} 7 | shot2=${3:-10} 8 | shot3=${4:-5} 9 | shot4=${5:-3} 10 | shot5=${6:-1} 11 | 12 | for shot in $shot1 $shot2 $shot3 $shot4 $shot5 13 | do 14 | cmd="$CMD --num_shots $shot" 15 | echo $cmd 16 | eval $cmd 17 | done 18 | -------------------------------------------------------------------------------- /scripts/test_all_arch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run `test.py` on all CLIP arches 4 | CMD=$1 5 | 6 | for arch in 'RN50' 'RN101' 'RN50x4' 'RN50x16' 'RN50x64' 'ViT-B/32' 'ViT-B/16' 'ViT-L/14' 7 | do 8 | if [ "$arch" = "RN50x64" ]; then 9 | bs=32 10 | else 11 | bs=64 12 | fi 13 | 14 | echo "Testing $arch" 15 | cmd="$CMD --arch $arch --bs $bs" 16 | echo $cmd 17 | eval $cmd 18 | done 19 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_cls import ZSCLIPClassifier, FSCLIPClassifier 2 | from .clip_cls_ft import FTCLIPClassifier 3 | 4 | 5 | def build_model(params): 6 | if params.model == 'ZSCLIP': 7 | return ZSCLIPClassifier(clip_dict=params.clip_dict, ) 8 | elif params.model == 'FSCLIP': 9 | return FSCLIPClassifier( 10 | adapter_dict=params.adapter_dict, 11 | clip_dict=params.clip_dict, 12 | loss_dict=params.loss_dict, 13 | ) 14 | elif params.model == 'FTCLIP': 15 | return FTCLIPClassifier( 16 | adapter_dict=params.adapter_dict, 17 | clip_dict=params.clip_dict, 18 | loss_dict=params.loss_dict, 19 | ) 20 | else: 21 | raise NotImplementedError(f'{params.model} is not implemented.') 22 | -------------------------------------------------------------------------------- /configs/zsclip/zsclip_ncars_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | 10 | # data settings 11 | dataset = 'n_cars' 12 | data_root = './data/N-Cars/' 13 | train_batch_size = 32 // gpus 14 | val_batch_size = train_batch_size * 2 15 | num_workers = 8 16 | 17 | # event2img conversion 18 | quantize_args = dict( 19 | max_imgs=2, 20 | N=30000, 21 | split_method='event_count', 22 | convert_method='event_histogram', 23 | grayscale=True, 24 | count_non_zero=True, 25 | background_mask=False, 26 | ) 27 | 28 | # model configs 29 | model = 'ZSCLIP' 30 | clip_dict = dict( 31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 33 | arch='ViT-L/14', 34 | prompt='a point cloud image 
of a {}', 35 | agg_func='mean', # aggregate the logits over views 36 | ) 37 | -------------------------------------------------------------------------------- /configs/zsclip/zsclip_nin_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | 10 | # data settings 11 | dataset = 'n_imagenet' 12 | data_root = './data/N_Imagenet/' 13 | train_batch_size = 32 // gpus 14 | val_batch_size = train_batch_size * 2 15 | num_workers = 16 16 | 17 | # event2img conversion 18 | quantize_args = dict( 19 | max_imgs=2, 20 | N=70000, 21 | split_method='event_count', 22 | convert_method='event_histogram', 23 | grayscale=True, 24 | count_non_zero=False, 25 | background_mask=True, 26 | ) 27 | 28 | # model configs 29 | model = 'ZSCLIP' 30 | clip_dict = dict( 31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 33 | arch='ViT-L/14', 34 | prompt='a point cloud image of a {}', 35 | agg_func='mean', # aggregate the logits over views 36 | ) 37 | -------------------------------------------------------------------------------- /configs/zsclip/zsclip_ncaltech_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | 10 | # data settings 11 | dataset = 'n_caltech' 12 | data_root = './data/N-Caltech101/' 13 | train_batch_size = 32 // gpus 14 | val_batch_size = train_batch_size * 2 15 | num_workers = 8 16 | 17 | # event2img conversion 18 | quantize_args = dict( 19 | max_imgs=2, 20 | N=20000, 21 | split_method='event_count', 22 | convert_method='event_histogram', 23 | grayscale=True, 24 | count_non_zero=False, 25 | background_mask=True, 26 | ) 27 | 28 | # model configs 29 | model = 'ZSCLIP' 30 | clip_dict = dict( 31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 33 | arch='ViT-L/14', 34 | prompt='a point cloud image of a {}', 35 | agg_func='mean', # aggregate the logits over views 36 | ) 37 | -------------------------------------------------------------------------------- /configs/zsclip/zsclip_nin_mini_params-vitb32.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | 10 | # data settings 11 | dataset = 'n_imagenet_mini' 12 | data_root = './data/N_Imagenet/' 13 | train_batch_size = 32 // gpus 14 | val_batch_size = train_batch_size * 2 15 | num_workers = 16 16 | 17 | # event2img conversion 18 | quantize_args = dict( 19 | max_imgs=2, 20 | N=70000, 21 | split_method='event_count', 22 | convert_method='event_histogram', 23 | grayscale=True, 24 | count_non_zero=False, 25 | background_mask=True, 26 | ) 27 | 28 | # model configs 29 | model = 'ZSCLIP' 30 | clip_dict = dict( 31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 33 | arch='ViT-B/32', 34 | prompt='a sketch image of a {}', 35 | agg_func='mean', # aggregate the logits over views 36 | ) 37 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ziyi Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from .caltech import build_n_caltech_dataset, NCaltech101 4 | from .cars import build_n_cars_dataset, NCars 5 | from .imagenet import build_n_imagenet_dataset, NImageNet 6 | from .imagenet_mini import build_n_imagenet_mini_dataset, NImageNetMini 7 | from .event2img import build_event2img_dataset, Event2ImageDataset 8 | from .vis import events2frames 9 | 10 | 11 | def build_dataset(params, val_only=False, **kwargs): 12 | # `gen_data` means doing pseudo label generation for self-training 13 | gen_data = kwargs.pop('gen_data', False) 14 | tta = kwargs.pop('tta', False) 15 | 16 | dst = params.dataset 17 | ev_dst = eval(f'build_{dst}_dataset')( 18 | params, val_only=val_only, gen_data=gen_data, **kwargs) 19 | 20 | # adjust max-views for event2img conversion 21 | train_params = copy.deepcopy(params) 22 | val_params = copy.deepcopy(params) 23 | val_params.quantize_args['max_imgs'] = 10 # load all views for testing 24 | 25 | # only build one dataset in these cases 26 | if val_only or gen_data: 27 | return build_event2img_dataset(val_params, ev_dst, tta=tta) 28 | 29 | # build both train and val datasets 30 | return build_event2img_dataset( 31 | train_params, ev_dst[0], augment=params.get('img_aug', False)), \ 32 | build_event2img_dataset(val_params, ev_dst[1], augment=False) 33 | -------------------------------------------------------------------------------- /datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .caltech import NCaltech101 4 | 5 | NEW_CNAMES = { 6 | "cars": "car", 7 | "background": "background", 8 | } 9 | 10 | 11 | class NCars(NCaltech101): 12 | """Dataset class for N-Cars dataset.""" 13 | 14 | def __init__( 15 | self, 16 | root, 17 | augmentation=False, 18 | num_shots=None, 19 | new_cnames=None, 20 | ): 21 | super().__init__( 22 | root=root, 23 | augmentation=augmentation, 24 | num_shots=num_shots, 25 | repeat=False, 26 | new_cnames=new_cnames, 27 | ) 28 | 29 | # data stats 30 | self.resolution = (100, 120) 31 | self.max_t = 0.1 # max 32 | self.max_n = 12500 # 95th percentile 33 | 34 | 
# data augmentation 35 | self.max_shift = 10 # resolution is ~half as N-Caltech101 36 | 37 | 38 | def build_n_cars_dataset(params, val_only=False, gen_data=False): 39 | """Build the N-Cars dataset.""" 40 | # only build the test set 41 | test_set = NCars( 42 | root=os.path.join(params.data_root, 'test'), 43 | augmentation=False, 44 | new_cnames=NEW_CNAMES, 45 | ) 46 | if val_only: 47 | assert not gen_data 48 | return test_set 49 | # build the training set for pseudo label generation 50 | if gen_data: 51 | return NCars( 52 | root=os.path.join(params.data_root, 'train'), 53 | augmentation=False, 54 | new_cnames=NEW_CNAMES, 55 | ) 56 | 57 | # build the training set 58 | train_set = NCars( 59 | root=os.path.join(params.data_root, 'train'), 60 | augmentation=True, 61 | num_shots=params.get('num_shots', None), 62 | new_cnames=NEW_CNAMES, 63 | ) 64 | return train_set, test_set 65 | -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install 2 | 3 | We recommend using [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) for environment setup: 4 | 5 | ``` 6 | conda create -n eventclip python=3.9.17 7 | conda activate eventclip 8 | ``` 9 | 10 | Then install PyTorch which is compatible with your CUDA setting. 11 | In our experiments, we use PyTorch 1.12.1 + CUDA 11.3 (the CUDA version is fine as long as it meets the requirement [here](https://pytorch.org/get-started/previous-versions/)). 12 | PyTorch 2.0.0 is also compatible and can speed up model training: 13 | 14 | ``` 15 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 16 | pip install pytorch-lightning==1.8.6 torchmetrics==0.11.4 17 | ``` 18 | 19 | The codebase heavily relies on [nerv](https://github.com/Wuziyi616/nerv) for project template and Trainer. 20 | You can easily install it by: 21 | 22 | ``` 23 | git clone git@github.com:Wuziyi616/nerv.git 24 | cd nerv 25 | git checkout v0.3.1 # tested with v0.3.1 release 26 | pip install -e . 27 | ``` 28 | 29 | This will automatically install packages necessary for the project. 30 | Additional packages are listed as follows: 31 | 32 | ``` 33 | pip install ftfy regex tqdm # packages required by CLIP 34 | pip install git+https://github.com/openai/CLIP.git # install CLIP from OpenAI 35 | ``` 36 | 37 | Finally, clone this project by: 38 | 39 | ``` 40 | cd .. # move out from nerv/ 41 | git clone git@github.com:Wuziyi616/EventCLIP.git 42 | cd EventCLIP 43 | ``` 44 | 45 | We use [wandb](https://wandb.ai/) for logging, please run `wandb login` to log in. 46 | 47 | ## Possible Issues 48 | 49 | - In case you encounter any environmental issues, you can refer to the conda env file exported from my server [environment.yml](../environment.yml). 50 | You can install the same environment by `conda env create -f environment.yml`. 
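- To quickly confirm that the CLIP installation itself works, you can run a minimal sanity check (it only lists the available architectures and loads a small backbone; the weights are downloaded to `~/.cache/clip/` on first use):

```
python -c "import torch, clip; print(clip.available_models()); clip.load('ViT-B/32')"
```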
51 | -------------------------------------------------------------------------------- /scripts/dup_run_sbatch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This is a wrapper for `sbatch_run.sh` to run repeated experiments 4 | # It will duplicate the same params file for several times and run them all 5 | 6 | ####################################################################### 7 | # An example usage: 8 | # GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=scavenger REPEAT=3 ./scripts/dup_run_sbatch.sh \ 9 | # rtx6000 test-sbatch ./train.py ddp params.py --fp16 --ddp --cudnn 10 | ####################################################################### 11 | 12 | # read args from command line 13 | REPEAT=${REPEAT:-3} 14 | GPUS=${GPUS:-1} 15 | CPUS_PER_GPU=${CPUS_PER_GPU:-8} 16 | MEM_PER_CPU=${MEM_PER_CPU:-5} 17 | QOS=${QOS:-scavenger} 18 | TIME=${TIME:-96:00:00} 19 | 20 | PY_ARGS=${@:6} 21 | PARTITION=$1 22 | JOB_NAME=$2 23 | PY_FILE=$3 24 | DDP=$4 25 | PARAMS=$5 26 | 27 | for repeat_idx in $(seq 1 $REPEAT) 28 | do 29 | params="${PARAMS:0:(-3)}-dup${repeat_idx}.py" 30 | cp $PARAMS $params 31 | job_name="${JOB_NAME}-dup${repeat_idx}" 32 | # if `$PY_ARGS` contains "--N X", then append "-N_X" to `job_name` 33 | if [[ $PY_ARGS == *"--N"* ]]; then 34 | N=$(echo $PY_ARGS | grep -oP "(?<=--N )\d+") 35 | # only modify when `X` is positive 36 | if [[ $N -gt 0 ]]; then 37 | job_name="${job_name}-N_${N}" 38 | fi 39 | fi 40 | # if `$PY_ARGS` contains "--num_shots X", then append "-Xshot" to `job_name` 41 | if [[ $PY_ARGS == *"--num_shots"* ]]; then 42 | num_shots=$(echo $PY_ARGS | grep -oP "(?<=--num_shots )\d+") 43 | # only modify when `X` is positive 44 | if [[ $num_shots -gt 0 ]]; then 45 | job_name="${job_name}-${num_shots}shot" 46 | fi 47 | fi 48 | cmd="./scripts/sbatch_run.sh $PARTITION $job_name $PY_FILE $DDP --params $params $PY_ARGS" 49 | echo $cmd 50 | eval $cmd 51 | done 52 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def random_shift_events(events, max_shift=20, resolution=(180, 240)): 5 | """Spatially shift events by a random offset.""" 6 | H, W = resolution 7 | x_shift, y_shift = np.random.randint(-max_shift, max_shift + 1, size=(2, )) 8 | events[:, 0] += x_shift 9 | events[:, 1] += y_shift 10 | 11 | valid_events = (events[:, 0] >= 0) & (events[:, 0] < W) & \ 12 | (events[:, 1] >= 0) & (events[:, 1] < H) 13 | events = events[valid_events] 14 | 15 | return events 16 | 17 | 18 | def random_flip_events_along_x(events, resolution=(180, 240), p=0.5): 19 | """Flip events along horizontally with probability p.""" 20 | H, W = resolution 21 | if np.random.random() < p: 22 | events[:, 0] = W - 1 - events[:, 0] 23 | return events 24 | 25 | 26 | def random_time_flip_events(events, p=0.5): 27 | """Flip events over time with probability p.""" 28 | if np.random.random() < p: 29 | events = np.flip(events, axis=0) 30 | events = np.ascontiguousarray(events) 31 | # reverse the time 32 | events[:, 2] = events[0, 2] - events[:, 2] 33 | # reverse the polarity 34 | events[:, 3] = -events[:, 3] 35 | return events 36 | 37 | 38 | def center_events(events, resolution=(180, 240)): 39 | """Center the temporal & spatial coordinates of events. 40 | Make min_t == 0. 41 | Make (max_x + min_x + 1) / 2 == W / 2 and (max_y + min_y + 1) / 2 == H / 2. 
42 | 43 | Args: 44 | events: [N, 4 (x,y,t,p)] 45 | resolution: (H, W) 46 | """ 47 | # temporal 48 | events[:, 2] -= events[:, 2].min() 49 | # spatial 50 | H, W = resolution 51 | x_min, x_max = events[:, 0].min(), events[:, 0].max() 52 | y_min, y_max = events[:, 1].min(), events[:, 1].max() 53 | x_shift = ((x_max + x_min + 1.) - W) // 2. 54 | y_shift = ((y_max + y_min + 1.) - H) // 2. 55 | events[:, 0] -= x_shift 56 | events[:, 1] -= y_shift 57 | return events 58 | -------------------------------------------------------------------------------- /configs/fsclip/joint_adapter/joint_fsclip_nin_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 10 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-5 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_imagenet' 23 | data_root = './data/N_Imagenet/' 24 | num_shots = None 25 | img_aug = True 26 | train_batch_size = 128 // gpus 27 | val_batch_size = train_batch_size * 2 28 | num_workers = 16 29 | 30 | # event2img conversion 31 | quantize_args = dict( 32 | max_imgs=2, 33 | N=70000, 34 | split_method='event_count', 35 | convert_method='event_histogram', 36 | grayscale=True, 37 | count_non_zero=False, 38 | background_mask=True, 39 | ) 40 | 41 | # model configs 42 | model = 'FSCLIP' 43 | clip_dict = dict( 44 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 45 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 46 | arch='ViT-L/14', 47 | prompt='a point cloud image of a {}', 48 | agg_func='mean', # aggregate the logits over views 49 | ) 50 | 51 | # adapter configs 52 | d_model = 256 53 | adapter_dict = dict( 54 | adapter_type='text-trans', 55 | in_dim=512, 56 | d_model=d_model, 57 | num_heads=d_model // 64, 58 | ffn_dim=d_model * 4, 59 | norm_first=True, 60 | num_layers=2, 61 | residual=0.95, 62 | ) 63 | 64 | # loss configs 65 | loss_dict = dict( 66 | use_logits_loss=True, # CE over mean logits 67 | use_probs_loss=False, # CE over mean probs 68 | ) 69 | 70 | ce_loss_w = 1. 
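  # `ce_loss_w` weights the cross-entropy classification loss; the flags in `loss_dict` above select whether the CE is computed over the view-averaged logits or the view-averaged probabilities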
71 | 72 | # save the model with the highest acc 73 | ckp_monitor = 'val/probs_acc' 74 | ckp_monitor_type = 'max' # 'max' or 'min' 75 | -------------------------------------------------------------------------------- /configs/fsclip/text_adapter/text_fsclip_nin_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 10 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-5 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_imagenet' 23 | data_root = './data/N_Imagenet/' 24 | num_shots = None 25 | img_aug = True 26 | train_batch_size = 128 // gpus 27 | val_batch_size = train_batch_size * 2 28 | num_workers = 16 29 | 30 | # event2img conversion 31 | quantize_args = dict( 32 | max_imgs=2, 33 | N=70000, 34 | split_method='event_count', 35 | convert_method='event_histogram', 36 | grayscale=True, 37 | count_non_zero=False, 38 | background_mask=True, 39 | ) 40 | 41 | # model configs 42 | model = 'FSCLIP' 43 | clip_dict = dict( 44 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 45 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 46 | arch='ViT-L/14', 47 | prompt='a point cloud image of a {}', 48 | agg_func='mean', # aggregate the logits over views 49 | ) 50 | 51 | # adapter configs 52 | d_model = 256 53 | adapter_dict = dict( 54 | adapter_type='text-identity', 55 | in_dim=512, 56 | d_model=d_model, 57 | num_heads=d_model // 64, 58 | ffn_dim=d_model * 4, 59 | norm_first=True, 60 | num_layers=2, 61 | residual=0.95, 62 | ) 63 | 64 | # loss configs 65 | loss_dict = dict( 66 | use_logits_loss=True, # CE over mean logits 67 | use_probs_loss=False, # CE over mean probs 68 | ) 69 | 70 | ce_loss_w = 1. 
71 | 72 | # save the model with the highest acc 73 | ckp_monitor = 'val/probs_acc' 74 | ckp_monitor_type = 'max' # 'max' or 'min' 75 | -------------------------------------------------------------------------------- /configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 1e-4 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_caltech' 23 | data_root = './data/N-Caltech101/' 24 | num_shots = None 25 | repeat_data = True 26 | img_aug = True 27 | train_batch_size = 32 // gpus 28 | val_batch_size = train_batch_size * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=20000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=False, 39 | background_mask=True, 40 | ) 41 | 42 | # model configs 43 | model = 'FSCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | ) 51 | 52 | # adapter configs 53 | d_model = 256 54 | adapter_dict = dict( 55 | adapter_type='text-trans', 56 | in_dim=512, 57 | d_model=d_model, 58 | num_heads=d_model // 64, 59 | ffn_dim=d_model * 4, 60 | norm_first=True, 61 | num_layers=2, 62 | residual=0.8, 63 | ) 64 | 65 | # loss configs 66 | loss_dict = dict( 67 | use_logits_loss=True, # CE over mean logits 68 | use_probs_loss=False, # CE over mean probs 69 | ) 70 | 71 | ce_loss_w = 1. 
72 | 73 | # save the model with the highest acc 74 | ckp_monitor = 'val/probs_acc' 75 | ckp_monitor_type = 'max' # 'max' or 'min' 76 | -------------------------------------------------------------------------------- /configs/fsclip/text_adapter/text_fsclip_ncaltech_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 1e-4 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_caltech' 23 | data_root = './data/N-Caltech101/' 24 | num_shots = None 25 | repeat_data = True 26 | img_aug = True 27 | train_batch_size = 32 // gpus 28 | val_batch_size = train_batch_size * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=20000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=False, 39 | background_mask=True, 40 | ) 41 | 42 | # model configs 43 | model = 'FSCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | ) 51 | 52 | # adapter configs 53 | d_model = 256 54 | adapter_dict = dict( 55 | adapter_type='text-identity', 56 | in_dim=512, 57 | d_model=d_model, 58 | num_heads=d_model // 64, 59 | ffn_dim=d_model * 4, 60 | norm_first=True, 61 | num_layers=2, 62 | residual=0.8, 63 | ) 64 | 65 | # loss configs 66 | loss_dict = dict( 67 | use_logits_loss=True, # CE over mean logits 68 | use_probs_loss=False, # CE over mean probs 69 | ) 70 | 71 | ce_loss_w = 1. 
72 | 73 | # save the model with the highest acc 74 | ckp_monitor = 'val/probs_acc' 75 | ckp_monitor_type = 'max' # 'max' or 'min' 76 | -------------------------------------------------------------------------------- /configs/fsclip/joint_adapter/joint_fsclip_ncars_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 50 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-4 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_cars' 23 | data_root = './data/N-Cars/' 24 | num_shots = None 25 | img_aug = False 26 | train_batch_size = 32 // gpus if \ 27 | num_shots is None else min(num_shots * 2, 32) // gpus 28 | val_batch_size = max(train_batch_size, 32 // gpus) * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=30000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=True, 39 | background_mask=False, 40 | ) 41 | 42 | # model configs 43 | model = 'FSCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | ) 51 | 52 | # adapter configs 53 | d_model = 256 54 | adapter_dict = dict( 55 | adapter_type='text-trans', 56 | in_dim=512, 57 | d_model=d_model, 58 | num_heads=d_model // 64, 59 | ffn_dim=d_model * 4, 60 | norm_first=True, 61 | num_layers=2, 62 | residual=0.8, 63 | ) 64 | 65 | # loss configs 66 | loss_dict = dict( 67 | use_logits_loss=True, # CE over mean logits 68 | use_probs_loss=False, # CE over mean probs 69 | ) 70 | 71 | ce_loss_w = 1. 
72 | 73 | # save the model with the highest acc 74 | ckp_monitor = 'val/probs_acc' 75 | ckp_monitor_type = 'max' # 'max' or 'min' 76 | -------------------------------------------------------------------------------- /configs/fsclip/text_adapter/text_fsclip_ncars_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 50 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-4 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_cars' 23 | data_root = './data/N-Cars/' 24 | num_shots = None 25 | img_aug = False 26 | train_batch_size = 32 // gpus if \ 27 | num_shots is None else min(num_shots * 2, 32) // gpus 28 | val_batch_size = max(train_batch_size, 32 // gpus) * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=30000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=True, 39 | background_mask=False, 40 | ) 41 | 42 | # model configs 43 | model = 'FSCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | ) 51 | 52 | # adapter configs 53 | d_model = 256 54 | adapter_dict = dict( 55 | adapter_type='text-identity', 56 | in_dim=512, 57 | d_model=d_model, 58 | num_heads=d_model // 64, 59 | ffn_dim=d_model * 4, 60 | norm_first=True, 61 | num_layers=2, 62 | residual=0.8, 63 | ) 64 | 65 | # loss configs 66 | loss_dict = dict( 67 | use_logits_loss=True, # CE over mean logits 68 | use_probs_loss=False, # CE over mean probs 69 | ) 70 | 71 | ce_loss_w = 1. 
72 | 73 | # save the model with the highest acc 74 | ckp_monitor = 'val/probs_acc' 75 | ckp_monitor_type = 'max' # 'max' or 'min' 76 | -------------------------------------------------------------------------------- /configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-5 19 | warmup_steps_pct = 0.05 20 | 21 | # data settings 22 | dataset = 'n_imagenet_mini' 23 | data_root = './data/N_Imagenet/' # change to 'data/pseudo-N_Imagenet/xxx' 24 | num_shots = None # set this to train with highest confident pseudo labels 25 | repeat_data = True 26 | img_aug = True 27 | train_batch_size = 32 // gpus 28 | val_batch_size = train_batch_size * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=70000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=False, 39 | background_mask=True, 40 | ) 41 | 42 | # model configs 43 | model = 'FSCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-B/32', 48 | prompt='a sketch image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | ) 51 | 52 | # adapter configs 53 | d_model = 256 54 | adapter_dict = dict( 55 | adapter_type='text-trans', 56 | in_dim=512, 57 | d_model=d_model, 58 | num_heads=d_model // 64, 59 | ffn_dim=d_model * 4, 60 | norm_first=True, 61 | num_layers=2, 62 | residual=0.95, 63 | ) 64 | 65 | # loss configs 66 | loss_dict = dict( 67 | use_logits_loss=True, # CE over mean logits 68 | use_probs_loss=False, # CE over mean probs 69 | ) 70 | 71 | ce_loss_w = 1. 72 | 73 | # save the model with the highest acc 74 | ckp_monitor = 'val/probs_acc' 75 | ckp_monitor_type = 'max' # 'max' or 'min' 76 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | All datasets should be downloaded or soft-linked to `./data/`. 4 | Or you can modify the `data_root` value in the config files. 5 | 6 | ## N-Caltech101 7 | 8 | We adopt the N-Caltech101 dataset from [EST repo](https://github.com/uzh-rpg/rpg_event_representation_learning#training). 9 | Please download and unzip the data and put it under `./data/N-Caltech101`. 10 | 11 | ## N-Cars 12 | 13 | We also use the N-Cars dataset processed by the EST repo. 14 | Please use [this link](http://rpg.ifi.uzh.ch/datasets/gehrig_et_al_iccv19/N-Cars.zip) to download it and unzip it to `./data/N-Cars`. 15 | 16 | ## N-ImageNet 17 | 18 | Please follow the instructions on the [n_imagenet repo](https://github.com/82magnolia/n_imagenet#n-imagenet-towards-robust-fine-grained-object-recognition-with-event-cameras) to download N-ImageNet and put it under `./data/N_Imagenet`. 19 | 20 | The `N-ImageNet Variants (~150GB)` are not required for training. 21 | But if you want to test the robustness of our method, you can download them as we do provide evaluation code on them. 
22 | 23 | The `Mini N-ImageNet (~45 GB)` subset is not used in this project. 24 | But you can modify the dataloader if you want to experiment at a smaller scale. 25 | 26 | ## Summary 27 | 28 | **The final `data` directory should look like this:** 29 | 30 | ``` 31 | data/ 32 | ├── N-Caltech101/ 33 | │ ├── training/ 34 | │ │ ├── accordion/ # folder with events in '.npy' format 35 | │ │ ├── airplanes/ 36 | • • • 37 | • • • 38 | │ │ └── yin_yang/ 39 | │ ├── validation/ # same as 'training' 40 | │ ├── testing/ # same as 'training' 41 | ├── N-Cars/ 42 | │ ├── train/ 43 | │ │ ├── background/ # folder with events in '.npy' format 44 | │ │ ├── cars/ 45 | │ ├── test/ # same as 'train' 46 | ├── N_Imagenet/ 47 | │ ├── extracted_train/ 48 | │ │ ├── n01440764/ # folder with events in '.npy' format 49 | │ │ ├── n01443537/ 50 | • • • 51 | • • • 52 | │ │ └── n15075141/ 53 | │ ├── extracted_val/ # same as 'extracted_train' 54 | │ ├── extracted_val_brightness_4/ # these are robustness variants 55 | │ ├── extracted_val_brightness_5/ # brightness change 56 | │ ├── extracted_val_brightness_6/ 57 | │ ├── extracted_val_brightness_7/ 58 | │ ├── extracted_val_mode_1/ # trajectory change 59 | │ ├── extracted_val_mode_3/ 60 | │ ├── extracted_val_mode_5/ 61 | │ ├── extracted_val_mode_6/ 62 | └ └── extracted_val_mode_7/ 63 | ``` 64 | -------------------------------------------------------------------------------- /scripts/resubmit_failed_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this is a script called by `sbatch_run.sh` internally 4 | # the goal is to re-submit the job if it doesn't end normally, i.e., `COMPLETED` or `CANCELLED` 5 | # currently we consider `FAILED`, `OUT_OF_MEMORY`, and `TIMEOUT` as abnormal ends 6 | # for other states like `PENDING`, `NONE`, we will do nothing 7 | 8 | # read args from command line 9 | JOB_ID=$1 10 | SLRM_NAME=$2 11 | LOG_FILE=$3 12 | 13 | # we first copy the sbatch file to "./sbatch/" 14 | slrm_file="run-${SLRM_NAME}.slrm" 15 | mkdir -p "./sbatch/" 16 | mv "./${slrm_file}" "./sbatch/${slrm_file}" 17 | 18 | # util function to check if string1 contains string2 19 | check_contain() { 20 | local string1="$1" 21 | local string2="$2" 22 | 23 | if [[ $string1 == *"$string2"* ]]; then 24 | return 0 # true 25 | else 26 | return 1 # false 27 | fi 28 | } 29 | 30 | # periodically check the job status 31 | while true; do 32 | read -ra arr <<< "$(sacct -j "$JOB_ID" --format State --noheader)" 33 | status="${arr[0]}" 34 | # re-submit it if it failed or OOM 35 | if check_contain "$status" "FAIL" || check_contain "$status" "OUT_OF_M" || check_contain "$status" "TIMEOUT"; then 36 | # the sbatch file is saved under "./sbatch/run-${SLRM_NAME}.slrm" 37 | # we copy it to "./", run it, and delete it 38 | cp "./sbatch/${slrm_file}" "./${slrm_file}" 39 | # should also update the JOB_ID! 
40 | JOB_ID=$(sbatch --parsable $slrm_file) 41 | rm -f $slrm_file 42 | echo "Job $SLRM_NAME failed with status $status, resubmitted with JOB_ID $JOB_ID" >> $LOG_FILE 43 | # exit the loop/this script if it's 1) completed 2) cancelled 44 | # also delete the sbatch file 45 | elif check_contain "$status" "COMPLE" || check_contain "$status" "CANCEL"; then 46 | echo "Job $SLRM_NAME finished with status $status" >> $LOG_FILE 47 | rm -f "./sbatch/${slrm_file}" 48 | exit 0 49 | # do nothing if it's 1) running 2) waiting 50 | else 51 | echo "Job $SLRM_NAME, ID $JOB_ID is good with status $status" >> $LOG_FILE 52 | fi 53 | sleep 600 # check every 10 minutes 54 | done & # run in background 55 | 56 | # detach the background process with the current shell 57 | disown 58 | 59 | # ways to check if there are duplicated runs 60 | # names = str(subprocess.check_output("squeue -u jiaqixi -o '%.100j' --noheader", shell=True))[2:-1] 61 | # names = [n.strip() for n in names.split('\\n')][:-1] 62 | # [n for n in names if names.count(n) > 1] 63 | -------------------------------------------------------------------------------- /configs/ftclip/ft_text_fsclip_nin_params.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 4 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 10 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-5 19 | clip_lr = lr / 10. 20 | warmup_steps_pct = 0.05 21 | 22 | # data settings 23 | dataset = 'n_imagenet' 24 | data_root = './data/N_Imagenet/' 25 | num_shots = None 26 | img_aug = True 27 | train_batch_size = 128 // gpus 28 | val_batch_size = train_batch_size * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=70000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=False, 39 | background_mask=True, 40 | ) 41 | 42 | # model configs 43 | model = 'FTCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16 51 | only_conv1=False, # only tune the first conv layer 52 | only_bias=False, # only tune the bias terms 53 | only_ln=False, # only tune the LayerNorm layers 54 | only_cls_fc=False, # only tune the embedding projection head 55 | only_cls_token=False, # only tune the CLS token 56 | ) 57 | 58 | # adapter configs 59 | d_model = 256 60 | adapter_dict = dict( 61 | adapter_type='text-identity', 62 | in_dim=512, 63 | d_model=d_model, 64 | num_heads=d_model // 64, 65 | ffn_dim=d_model * 4, 66 | norm_first=True, 67 | num_layers=2, 68 | residual=0.95, 69 | ) 70 | 71 | # loss configs 72 | loss_dict = dict( 73 | use_logits_loss=True, # CE over mean logits 74 | use_probs_loss=False, # CE over mean probs 75 | ) 76 | 77 | ce_loss_w = 1. 
78 | 79 | # save the model with the highest acc 80 | ckp_monitor = 'val/probs_acc' 81 | ckp_monitor_type = 'max' # 'max' or 'min' 82 | -------------------------------------------------------------------------------- /configs/ftclip/ft_text_fsclip_nin_params-lora16.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 4 9 | max_epochs = 100 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 10 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 2e-5 19 | clip_lr = lr 20 | warmup_steps_pct = 0.05 21 | 22 | # data settings 23 | dataset = 'n_imagenet' 24 | data_root = './data/N_Imagenet/' 25 | num_shots = None 26 | img_aug = True 27 | train_batch_size = 128 // gpus 28 | val_batch_size = train_batch_size * 2 29 | num_workers = 8 30 | 31 | # event2img conversion 32 | quantize_args = dict( 33 | max_imgs=2, 34 | N=70000, 35 | split_method='event_count', 36 | convert_method='event_histogram', 37 | grayscale=True, 38 | count_non_zero=False, 39 | background_mask=True, 40 | ) 41 | 42 | # model configs 43 | model = 'FTCLIP' 44 | clip_dict = dict( 45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 47 | arch='ViT-L/14', 48 | prompt='a point cloud image of a {}', 49 | agg_func='mean', # aggregate the logits over views 50 | lora='qkvo-16', # LoRA fine-tuning, 'qv-16', 'qkv-16' (int), 'qkvo-16' 51 | only_conv1=False, # only tune the first conv layer 52 | only_bias=False, # only tune the bias terms 53 | only_ln=False, # only tune the LayerNorm layers 54 | only_cls_fc=False, # only tune the embedding projection head 55 | only_cls_token=False, # only tune the CLS token 56 | ) 57 | 58 | # adapter configs 59 | d_model = 256 60 | adapter_dict = dict( 61 | adapter_type='text-identity', 62 | in_dim=512, 63 | d_model=d_model, 64 | num_heads=d_model // 64, 65 | ffn_dim=d_model * 4, 66 | norm_first=True, 67 | num_layers=2, 68 | residual=0.95, 69 | ) 70 | 71 | # loss configs 72 | loss_dict = dict( 73 | use_logits_loss=True, # CE over mean logits 74 | use_probs_loss=False, # CE over mean probs 75 | ) 76 | 77 | ce_loss_w = 1. 78 | 79 | # save the model with the highest acc 80 | ckp_monitor = 'val/probs_acc' 81 | ckp_monitor_type = 'max' # 'max' or 'min' 82 | -------------------------------------------------------------------------------- /configs/ftclip/ft_text_fsclip_ncaltech_params-vitb16.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 50 10 | save_interval = 1 11 | eval_interval = 5 12 | save_epoch_end = False 13 | n_samples = 5 14 | 15 | # optimizer settings 16 | # Adam optimizer, Cosine decay with Warmup 17 | optimizer = 'Adam' 18 | lr = 1e-4 19 | clip_lr = lr / 10. 
20 | warmup_steps_pct = 0.05 21 | 22 | # data settings 23 | dataset = 'n_caltech' 24 | data_root = './data/N-Caltech101/' 25 | num_shots = None 26 | repeat_data = True 27 | img_aug = True 28 | train_batch_size = 32 // gpus 29 | val_batch_size = train_batch_size * 2 30 | num_workers = 8 31 | 32 | # event2img conversion 33 | quantize_args = dict( 34 | max_imgs=2, 35 | N=20000, 36 | split_method='event_count', 37 | convert_method='event_histogram', 38 | grayscale=True, 39 | count_non_zero=False, 40 | background_mask=True, 41 | ) 42 | 43 | # model configs 44 | model = 'FTCLIP' 45 | clip_dict = dict( 46 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 47 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 48 | arch='ViT-B/16', # to compare with E-CLIP 49 | prompt='a point cloud image of a {}', 50 | agg_func='mean', # aggregate the logits over views 51 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16 52 | only_conv1=False, # only tune the first conv layer 53 | only_bias=False, # only tune the bias terms 54 | only_ln=False, # only tune the LayerNorm layers 55 | only_cls_fc=False, # only tune the embedding projection head 56 | only_cls_token=False, # only tune the CLS token 57 | # lora >> bias > conv > fc > ln > CLS 58 | ) 59 | 60 | # adapter configs 61 | d_model = 256 62 | adapter_dict = dict( 63 | adapter_type='text-identity', 64 | in_dim=512, 65 | d_model=d_model, 66 | num_heads=d_model // 64, 67 | ffn_dim=d_model * 4, 68 | norm_first=True, 69 | num_layers=2, 70 | residual=0.8, 71 | ) 72 | 73 | # loss configs 74 | loss_dict = dict( 75 | use_logits_loss=True, # CE over mean logits 76 | use_probs_loss=False, # CE over mean probs 77 | ) 78 | 79 | ce_loss_w = 1. 80 | 81 | # save the model with the highest acc 82 | ckp_monitor = 'val/probs_acc' 83 | ckp_monitor_type = 'max' # 'max' or 'min' 84 | -------------------------------------------------------------------------------- /configs/ftclip/ft_text_fsclip_nin_params-vitb16.py: -------------------------------------------------------------------------------- 1 | from nerv.training import BaseParams 2 | 3 | 4 | class EventCLIPParams(BaseParams): 5 | project = 'EventCLIP' 6 | 7 | # training settings 8 | gpus = 1 9 | max_epochs = 50 10 | # max_epochs = 3 # following E-CLIP, only train 3 epochs on full dataset 11 | # also change `save_interval=0.05`, `eval_interval = 1` 12 | save_interval = 1 13 | eval_interval = 5 14 | save_epoch_end = False 15 | n_samples = 10 16 | 17 | # optimizer settings 18 | # Adam optimizer, Cosine decay with Warmup 19 | optimizer = 'Adam' 20 | lr = 2e-5 21 | clip_lr = lr / 10. 
22 | warmup_steps_pct = 0.05 23 | 24 | # data settings 25 | dataset = 'n_imagenet' 26 | data_root = './data/N_Imagenet/' 27 | num_shots = None 28 | img_aug = True 29 | train_batch_size = 128 // gpus 30 | val_batch_size = train_batch_size * 2 31 | num_workers = 8 32 | 33 | # event2img conversion 34 | quantize_args = dict( 35 | max_imgs=2, 36 | N=70000, 37 | split_method='event_count', 38 | convert_method='event_histogram', 39 | grayscale=True, 40 | count_non_zero=False, 41 | background_mask=True, 42 | ) 43 | 44 | # model configs 45 | model = 'FTCLIP' 46 | clip_dict = dict( 47 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32' 48 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px' 49 | arch='ViT-B/16', # to compare with E-CLIP 50 | prompt='a point cloud image of a {}', 51 | agg_func='mean', # aggregate the logits over views 52 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16 53 | only_conv1=False, # only tune the first conv layer 54 | only_bias=False, # only tune the bias terms 55 | only_ln=False, # only tune the LayerNorm layers 56 | only_cls_fc=False, # only tune the embedding projection head 57 | only_cls_token=False, # only tune the CLS token 58 | ) 59 | 60 | # adapter configs 61 | d_model = 256 62 | adapter_dict = dict( 63 | adapter_type='text-identity', 64 | in_dim=512, 65 | d_model=d_model, 66 | num_heads=d_model // 64, 67 | ffn_dim=d_model * 4, 68 | norm_first=True, 69 | num_layers=2, 70 | residual=0.95, 71 | ) 72 | 73 | # loss configs 74 | loss_dict = dict( 75 | use_logits_loss=True, # CE over mean logits 76 | use_probs_loss=False, # CE over mean probs 77 | ) 78 | 79 | ce_loss_w = 1. 80 | 81 | # save the model with the highest acc 82 | ckp_monitor = 'val/probs_acc' 83 | ckp_monitor_type = 'max' # 'max' or 'min' 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | vis/ 2 | wandb/ 3 | sbatch/ 4 | pretrained/ 5 | checkpoint/ 6 | checkpoints/ 7 | data/*/ 8 | data/N-Caltech101 9 | data/N-Cars 10 | data/N_Imagenet 11 | data/pseudo-N_Imagenet 12 | .idea/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | -------------------------------------------------------------------------------- /models/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Adapter(nn.Module): 6 | """Base Adapter class. 7 | 8 | Handle common operations such as residual connection. 9 | """ 10 | 11 | def __init__(self, residual=True): 12 | super().__init__() 13 | 14 | assert isinstance(residual, (bool, float)) 15 | if isinstance(residual, bool): 16 | residual = 0.5 if residual else 0. 17 | if isinstance(residual, float): 18 | assert 0. <= residual <= 1. 19 | 20 | self.residual = residual 21 | 22 | def residual_add(self, in_feats, new_feats): 23 | """Perform residual connection.""" 24 | assert isinstance(self.residual, float) 25 | return in_feats * self.residual + new_feats * (1. 
- self.residual) 26 | 27 | def forward(self, *args, **kwargs): 28 | raise NotImplementedError 29 | 30 | @property 31 | def dtype(self): 32 | raise NotImplementedError 33 | 34 | 35 | class IdentityAdapter(Adapter): 36 | """Trivial Adapter that does nothing.""" 37 | 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(residual=False) 40 | 41 | # dummy parameter to record the dtype & device 42 | self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False) 43 | 44 | def forward(self, feats, valid_masks): 45 | """feats: [B, num_views, C].""" 46 | return feats 47 | 48 | @property 49 | def dtype(self): 50 | return self.dummy.dtype 51 | 52 | 53 | class TransformerAdapter(Adapter): 54 | """Transformer Adapter which is order-invariant.""" 55 | 56 | def __init__( 57 | self, 58 | in_dim, 59 | d_model=256, 60 | num_heads=4, 61 | ffn_dim=256 * 4, 62 | norm_first=True, 63 | num_layers=2, 64 | residual=False, 65 | ): 66 | super().__init__(residual=residual) 67 | 68 | self.d_model = d_model 69 | enc_layer = nn.TransformerEncoderLayer( 70 | d_model=d_model, 71 | nhead=num_heads, 72 | dim_feedforward=ffn_dim, 73 | norm_first=norm_first, 74 | batch_first=True, 75 | ) 76 | self.transformer_encoder = nn.TransformerEncoder( 77 | encoder_layer=enc_layer, num_layers=num_layers) 78 | 79 | self.in_proj = nn.Linear(in_dim, d_model) 80 | self.out_proj = nn.Linear(d_model, in_dim) 81 | 82 | def forward(self, feats, valid_masks): 83 | """Inter-view interaction via Attention. 84 | 85 | Args: 86 | feats: [B, num_views, C] 87 | valid_masks: [B, num_views], True for valid views. 88 | Should mask the Attention in Transformer accordingly. 89 | """ 90 | in_feats = feats 91 | 92 | # [B, num_views, d_model] 93 | feats = self.in_proj(feats) 94 | 95 | # [B, num_views, d_model] 96 | pad_masks = (~valid_masks) # True --> padded 97 | feats = self.transformer_encoder(feats, src_key_padding_mask=pad_masks) 98 | 99 | # [B, num_views, C] 100 | feats = self.out_proj(feats) 101 | 102 | # residual connection 103 | feats = self.residual_add(in_feats, feats) 104 | 105 | return feats 106 | 107 | @property 108 | def dtype(self): 109 | return self.in_proj.weight.dtype 110 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from .caltech import NCaltech101 6 | 7 | 8 | def load_event(event_path): 9 | """Load event data from npz file.""" 10 | event = np.load(event_path)['event_data'] 11 | event = np.stack([ 12 | event['x'], 13 | event['y'], 14 | event['t'], 15 | event['p'].astype(np.uint8), 16 | ], 1) # [N, 4] 17 | 18 | event = event.astype(float) 19 | 20 | # Account for int-type timestamp 21 | event[:, 2] /= 1e6 22 | 23 | # Account for zero polarity 24 | if event[:, 3].min() >= -0.5: 25 | event[:, 3][event[:, 3] <= 0.5] = -1 26 | 27 | return event 28 | 29 | 30 | class NImageNet(NCaltech101): 31 | """Dataset class for N-ImageNet dataset.""" 32 | 33 | def __init__( 34 | self, 35 | root, 36 | augmentation=False, 37 | num_shots=None, 38 | ): 39 | super().__init__( 40 | root=root, 41 | augmentation=augmentation, 42 | num_shots=num_shots, 43 | repeat=False, 44 | new_cnames=None, 45 | ) 46 | 47 | # data stats 48 | self.resolution = (480, 640) 49 | self.max_t = 0.055 # max 50 | self.max_n = 135000 # 95th percentile 51 | 52 | # data augmentation 53 | self.flip_time = True 54 | 55 | # load folder name to class name mapping 56 | cur_dir = 
os.path.dirname(os.path.abspath(__file__)) 57 | label_map = os.path.join(cur_dir, 'files/CLIP-IN_ClassNames.txt') 58 | with open(label_map, 'r') as f: 59 | lines = f.readlines()[:1000] 60 | lines = [line.strip() for line in lines] 61 | """ 62 | n01440764 tench 63 | n01443537 goldfish 64 | n01484850 great white shark 65 | n01491361 tiger shark 66 | n01494475 hammerhead shark 67 | """ 68 | folder2name = { 69 | s.split(' ')[0]: ' '.join(s.split(' ')[1:]) 70 | for s in lines 71 | } 72 | self.folder2name = folder2name 73 | self.name2folder = {v: k for k, v in folder2name.items()} 74 | self.classes = [folder2name[c] for c in self.classes] 75 | 76 | @staticmethod 77 | def _load_events(event_path): 78 | """Load events from a file.""" 79 | return load_event(event_path).astype(np.float32) 80 | 81 | 82 | def build_n_imagenet_dataset( 83 | params, 84 | val_only=False, 85 | gen_data=False, 86 | subset=-1, 87 | ): 88 | """Build the N-ImageNet dataset.""" 89 | val_names = { 90 | 1: 'val_mode_1', 91 | 2: 'val_mode_5', 92 | 3: 'val_mode_6', 93 | 4: 'val_mode_7', 94 | 5: 'val_mode_3', 95 | 6: 'val_brightness_4', 96 | 7: 'val_brightness_5', 97 | 8: 'val_brightness_6', 98 | 9: 'val_brightness_7', 99 | } 100 | if subset > 0: 101 | val_set = val_names[subset] 102 | val_root = os.path.join(params.data_root, f'extracted_{val_set}') 103 | print('Using N-ImageNet subset:', val_set) 104 | else: 105 | val_root = os.path.join(params.data_root, 'extracted_val') 106 | print('Using normal N-ImageNet val set') 107 | 108 | # only build the test set 109 | test_set = NImageNet( 110 | root=val_root, 111 | augmentation=False, 112 | ) 113 | if val_only: 114 | assert not gen_data 115 | return test_set 116 | # build the training set for pseudo label generation 117 | if gen_data: 118 | return NImageNet( 119 | root=os.path.join(params.data_root, 'extracted_train'), 120 | augmentation=False, 121 | ) 122 | 123 | # build the training set 124 | train_set = NImageNet( 125 | root=os.path.join(params.data_root, 'extracted_train'), 126 | augmentation=True, 127 | num_shots=params.get('num_shots', None), 128 | ) 129 | return train_set, test_set 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EventCLIP 2 | 3 | [**EventCLIP: Adapting CLIP for Event-based Object Recognition**](https://github.com/Wuziyi616/EventCLIP)
4 | [Ziyi Wu](https://wuziyi616.github.io/), 5 | [Xudong Liu](https://www.linkedin.com/in/xudong-frank-liu-566513198/), 6 | [Igor Gilitschenski](https://tisl.cs.utoronto.ca/author/igor-gilitschenski/)
7 | _[arXiv'23](https://arxiv.org/abs/2306.06354) | 8 | [GitHub](https://github.com/Wuziyi616/EventCLIP) | 9 | [arXiv](https://arxiv.org/abs/2306.06354)_ 10 | 11 | ## Introduction 12 | 13 | This is the official PyTorch implementation of the paper: [EventCLIP: Adapting CLIP for Event-based Object Recognition](https://arxiv.org/abs/2306.06354). 14 | The code contains: 15 | 16 | - Zero-shot EventCLIP inference on the N-Caltech, N-Cars, and N-ImageNet datasets 17 | - Few-shot adaptation of EventCLIP on the three datasets, with SOTA results in the low-data regime 18 | - Data-efficient fine-tuning of EventCLIP on N-Caltech & N-ImageNet, achieving superior accuracy over fully-trained baselines 19 | - Learning from unlabeled data on N-Caltech & N-ImageNet, including both fully unsupervised and semi-supervised settings 20 | 21 | ### Motivation 22 | 23 | [Event cameras](https://tub-rip.github.io/eventvision2023/#null) are bio-inspired, low-latency, and energy-efficient sensors, which have gained significant interest recently. 24 | However, due to the lack of large-scale datasets, the event-based vision community cannot enjoy the recent success of foundation models in RGB vision. 25 | This paper thus seeks to adapt one of the most impactful VLMs, [CLIP](https://openai.com/research/clip), to recognize event data. 26 | We study common practices in data-efficient model adaptation and propose a general framework named EventCLIP. 27 | The overall pipeline is shown below: 28 | 29 | 

![EventCLIP pipeline](./assets/pipeline.png)

30 | 31 | ## Update 32 | 33 | - 2023.9.14: Release code for learning with unlabeled data 34 | - 2023.7.17: Release fine-tuning code 35 | - 2023.5.17: Initial code release! 36 | 37 | ## Installation 38 | 39 | Please refer to [install.md](docs/install.md) for step-by-step guidance on how to install the packages. 40 | 41 | ## Experiments 42 | 43 | **This codebase is tailored to [Slurm](https://slurm.schedmd.com/documentation.html) GPU clusters with preemption mechanism.** 44 | For the configs, we mainly use A40 with 40GB memory (though many experiments don't require so much memory). 45 | Please modify the code accordingly if you are using other hardware settings: 46 | 47 | - Please go through `train.py` and change the fields marked by `TODO:` 48 | - Please read the config file for the model you want to train. 49 | We use DDP with multiple GPUs to accelerate training. 50 | You can use less GPUs to achieve a better memory-speed trade-off 51 | 52 | ### Dataset Preparation 53 | 54 | Please refer to [data.md](docs/data.md) for dataset downloading and pre-processing. 55 | 56 | ### Reproduce Results 57 | 58 | Please see [benchmark.md](docs/benchmark.md) for detailed instructions on how to reproduce our results in the paper. 59 | 60 | ## Possible Issues 61 | 62 | See the troubleshooting section of [nerv](https://github.com/Wuziyi616/nerv#possible-issues) for potential issues. 63 | 64 | Please open an issue if you encounter any errors running the code. 65 | 66 | ## Citation 67 | 68 | Please cite our paper if you find it useful in your research: 69 | 70 | ``` 71 | @article{wu2023eventclip, 72 | title={{EventCLIP}: Adapting CLIP for Event-based Object Recognition}, 73 | author={Wu, Ziyi and Liu, Xudong and Gilitschenski, Igor}, 74 | journal={arXiv preprint arXiv:2306.06354}, 75 | year={2023} 76 | } 77 | ``` 78 | 79 | ## Acknowledgement 80 | 81 | We thank the authors of [CLIP](https://github.com/openai/CLIP), [EST](https://github.com/uzh-rpg/rpg_event_representation_learning), [n_imagenet](https://github.com/82magnolia/n_imagenet), [PointCLIP](https://github.com/ZrrSkywalker/PointCLIP), [LoRA](https://github.com/microsoft/LoRA) for opening source their wonderful works. 82 | 83 | ## License 84 | 85 | EventCLIP is released under the MIT License. See the LICENSE file for more details. 86 | 87 | ## Contact 88 | 89 | If you have any questions about the code, please contact Ziyi Wu dazitu616@gmail.com 90 | -------------------------------------------------------------------------------- /datasets/vis.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | 6 | def make_event_histogram(x, y, p, red, blue, shape, thresh=10., **kwargs): 7 | """Event polarity histogram.""" 8 | # count the number of positive and negative events per pixel 9 | H, W = shape 10 | pos_x, pos_y = x[p > 0].astype(np.int32), y[p > 0].astype(np.int32) 11 | pos_count = np.bincount(pos_x + pos_y * W, minlength=H * W).reshape(H, W) 12 | neg_x, neg_y = x[p < 0].astype(np.int32), y[p < 0].astype(np.int32) 13 | neg_count = np.bincount(neg_x + neg_y * W, minlength=H * W).reshape(H, W) 14 | hist = np.stack([pos_count, neg_count], axis=-1) # [H, W, 2] 15 | 16 | # remove hotpixels, i.e. 
pixels with event num > thresh * std + mean 17 | if thresh > 0: 18 | if kwargs.get('count_non_zero', False): 19 | mean = hist[hist > 0].mean() 20 | std = hist[hist > 0].std() 21 | else: 22 | mean = hist.mean() 23 | std = hist.std() 24 | hist[hist > thresh * std + mean] = 0 25 | 26 | # normalize 27 | hist = hist.astype(np.float32) / hist.max() # [H, W, 2] 28 | 29 | # colorize 30 | cmap = np.stack([red, blue], axis=0).astype(np.float32) # [2, 3] 31 | img = hist @ cmap # [H, W, 3] 32 | 33 | # alpha-masking with pure white background 34 | if kwargs.get('background_mask', True): 35 | weights = np.clip(hist.sum(-1, keepdims=True), a_min=0, a_max=1) 36 | background = np.ones_like(img) * 255. 37 | img = img * weights + background * (1. - weights) 38 | 39 | img = np.round(img).astype(np.uint8) # [H, W, 3], np.uint8 in (0, 255) 40 | 41 | return img 42 | 43 | 44 | def parse_events(events): 45 | """Read (x,y,t,p) from input events (can be np.array or dict).""" 46 | if isinstance(events, dict): 47 | x, y, t, p = events['x'], events['y'], events['t'], events['p'] 48 | else: 49 | x, y, t, p = events[:, 0], events[:, 1], events[:, 2], events[:, 3] 50 | x, y, p = x.astype(np.int32), y.astype(np.int32), p.astype(np.int32) 51 | t_us = t * 1e6 # convert to us unit 52 | return x, y, t_us, p 53 | 54 | 55 | def split_event_count(t, N=30000): 56 | """Split the events according to event count.""" 57 | tot_cnt = len(t) 58 | 59 | # if the event count is too small, just return the whole chunk 60 | if tot_cnt < N: 61 | return [0], [tot_cnt], [t[0]], [t[-1]] 62 | 63 | # find the start and end time of each event chunk w.r.t event index 64 | idx = np.arange(0, tot_cnt, N).tolist() 65 | idx1, idx0 = idx[1:], idx[:-1] 66 | # add the last index if the last chunk of events is not so small 67 | if tot_cnt - idx[-1] > N * 0.5: 68 | idx0.append(tot_cnt - N) 69 | idx1.append(tot_cnt) 70 | t0, t1 = t[idx0], t[np.array(idx1) - 1] 71 | 72 | return idx0, idx1, t0, t1 73 | 74 | 75 | def events2frames( 76 | events, # [N, 4 (x,y,t,p)] 77 | split_method, # 'event_count' 78 | convert_method, # 'event_histogram' 79 | shape=(180, 240), 80 | **kwargs, 81 | ): 82 | """Convert events to 2D frames.""" 83 | # some additional arguments 84 | grayscale = kwargs.pop('grayscale', True) # True, False 85 | 86 | # parse different input formats 87 | x, y, t, p = parse_events(events) 88 | 89 | # split the events into different chunks 90 | assert split_method == 'event_count' 91 | N = int(kwargs['N']) 92 | idx0, idx1, t0, t1 = split_event_count(t, N) 93 | 94 | # color map for pos and neg events 95 | if grayscale: 96 | if isinstance(grayscale, bool): 97 | v = 127 98 | else: 99 | v = np.array(grayscale) # values in addition to 127 100 | red = np.round(np.ones(3) * v).astype(np.uint8) 101 | blue = np.round(np.ones(3) * v).astype(np.uint8) 102 | else: 103 | red = np.array([255, 0, 0], dtype=np.uint8) 104 | blue = np.array([0, 0, 255], dtype=np.uint8) 105 | 106 | frames = [] 107 | for t_idx, (i0, i1) in enumerate(zip(idx0, idx1)): 108 | xx, yy, pp, tt = x[i0:i1], y[i0:i1], p[i0:i1], t[i0:i1] 109 | if convert_method == 'event_histogram': 110 | frame = make_event_histogram(xx, yy, pp, red, blue, shape, 111 | **kwargs) 112 | else: 113 | raise NotImplementedError(f'{convert_method} not implemented!') 114 | frames.append(copy.deepcopy(frame)) 115 | frames = np.stack(frames) # [N, H, W, 3] 116 | 117 | return frames # [N, H, W, 3] 118 | -------------------------------------------------------------------------------- /scripts/sbatch_run.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SBATCH file can't directly take command args 4 | # as a workaround, I first use a sh script to read in args 5 | # and then create a new .slrm file for SBATCH execution 6 | 7 | ####################################################################### 8 | # An example usage: 9 | # GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=scavenger ./scripts/sbatch_run.sh rtx6000 \ 10 | # test-sbatch ./train.py ddp --params params.py --fp16 --ddp --cudnn 11 | ####################################################################### 12 | 13 | # read args from command line 14 | GPUS=${GPUS:-1} 15 | CPUS_PER_GPU=${CPUS_PER_GPU:-8} 16 | MEM_PER_CPU=${MEM_PER_CPU:-5} 17 | QOS=${QOS:-scavenger} 18 | TIME=${TIME:-96:00:00} 19 | if [[ $QOS == "cpu" ]]; then 20 | QOS="cpu_qos" 21 | GPUS=0 22 | CPUS_PER_TASK=$CPUS_PER_GPU 23 | else 24 | CPUS_PER_TASK=$((GPUS * CPUS_PER_GPU)) 25 | fi 26 | 27 | # python args 28 | PY_ARGS=${@:5} 29 | PARTITION=$1 30 | JOB_NAME=$2 31 | PY_FILE=$3 32 | DDP=$4 33 | 34 | # create log files 35 | SLRM_NAME="${JOB_NAME/\//"_"}" 36 | LOG_DIR=checkpoint/"$(basename -- $JOB_NAME)" 37 | DATETIME=$(date "+%Y-%m-%d_%H:%M:%S") 38 | LOG_FILE=$LOG_DIR/${DATETIME}.log 39 | SLRM_LOG="${LOG_DIR}/slurm.log" 40 | 41 | # set up log output folder 42 | mkdir -p $LOG_DIR 43 | 44 | # create new .slrm file 45 | slrm_file="run-${SLRM_NAME}.slrm" 46 | 47 | # python runner for DDP 48 | if [[ $DDP == "ddp" ]]; then 49 | PORT=$((29501 + $RANDOM % 100)) # randomly select a port 50 | PYTHON="python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT" 51 | else 52 | PYTHON="python" 53 | fi 54 | 55 | # get the max possible time limit from QOS 56 | # please refer to https://support.vectorinstitute.ai/Vaughan_slurm_changes 57 | get_time() { 58 | local req_time="$1" 59 | if [[ $req_time == "0" ]]; then 60 | req_time="96:00:00" 61 | fi 62 | local qos="$2" # make it lower case 63 | qos="${qos,,}" 64 | if [[ $qos == "cpu_qos" ]]; then 65 | max_time="96:00:00" 66 | elif [[ $qos == "normal" ]]; then 67 | max_time="16:00:00" 68 | elif [[ $qos == "m" ]]; then 69 | max_time="12:00:00" 70 | elif [[ $qos == "m2" ]]; then 71 | max_time="08:00:00" 72 | elif [[ $qos == "m3" ]]; then 73 | max_time="04:00:00" 74 | elif [[ $qos == "m4" ]]; then 75 | max_time="02:00:00" 76 | elif [[ $qos == "m5" ]]; then 77 | max_time="01:00:00" 78 | elif [[ $qos == "long" ]]; then 79 | max_time="48:00:00" 80 | elif [[ $qos == "deadline" ]]; then 81 | max_time="00:00:00" 82 | elif [[ $qos == "high" ]]; then 83 | max_time="08:00:00" 84 | elif [[ $qos == "scavenger" ]]; then 85 | max_time="96:00:00" 86 | else 87 | echo "Invalid QOS $qos" 88 | return # this will trigger `Invalid --time specification` and fail the job 89 | fi 90 | # return the smaller one 91 | # Remove colons and compare as numbers 92 | num_req_time=$(echo "${req_time//:/}" | sed 's/^0*//') 93 | num_max_time=$(echo "${max_time//:/}" | sed 's/^0*//') 94 | if [[ $num_req_time -lt $num_max_time ]]; then 95 | echo $req_time 96 | else 97 | echo $max_time 98 | fi 99 | } 100 | TIME=$(get_time $TIME $QOS) 101 | echo "Run with QOS=$QOS, TIME=$TIME" 102 | 103 | # write to new file 104 | echo "#!/bin/bash 105 | 106 | # set up SBATCH args 107 | #SBATCH --job-name=$SLRM_NAME 108 | #SBATCH --output=$LOG_FILE 109 | #SBATCH --error=$LOG_FILE 110 | #SBATCH --open-mode=append 111 | #SBATCH --partition=$PARTITION # self-explanatory, set to your preference (e.g. 
gpu or cpu on MaRS, p100, t4, or cpu on Vaughan) 112 | #SBATCH --cpus-per-task=$CPUS_PER_TASK # self-explanatory, set to your preference 113 | #SBATCH --ntasks=1 114 | #SBATCH --ntasks-per-node=1 115 | #SBATCH --mem-per-cpu=${MEM_PER_CPU}G # self-explanatory, set to your preference 116 | #SBATCH --gres=gpu:$GPUS # NOTE: you need a GPU for CUDA support; self-explanatory, set to your preference 117 | #SBATCH --nodes=1 118 | #SBATCH --qos=$QOS # for 'high' and 'deadline' QoS, refer to https://support.vectorinstitute.ai/AboutVaughan2 119 | #SBATCH --time=$TIME # running time limit, 0 as unlimited 120 | 121 | # log some necessary environment params 122 | echo \$SLURM_JOB_ID >> $LOG_FILE # log the job id 123 | echo \$SLURM_JOB_PARTITION >> $LOG_FILE # log the job partition 124 | 125 | echo $CONDA_PREFIX >> $LOG_FILE # log the active conda environment 126 | 127 | python --version >> $LOG_FILE # log Python version 128 | gcc --version >> $LOG_FILE # log GCC version 129 | nvcc --version >> $LOG_FILE # log NVCC version 130 | 131 | # run python file 132 | $PYTHON $PY_FILE $PY_ARGS >> $LOG_FILE # the script above, with its standard output appended log file 133 | 134 | " >> ./$slrm_file 135 | 136 | # run the created file 137 | job_id=$(sbatch --parsable $slrm_file) 138 | echo "Submitted batch job $job_id" 139 | 140 | sleep 0.5 141 | if [[ $job_id ]]; then # successfully submitted the job 142 | ./scripts/resubmit_failed_job.sh $job_id $SLRM_NAME $SLRM_LOG 143 | else # failed to submit the job 144 | rm -f run-${SLRM_NAME}.slrm 145 | echo "Failed to submit job $SLRM_NAME" 146 | fi 147 | -------------------------------------------------------------------------------- /datasets/imagenet_mini.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .caltech import get_real_path 4 | from .imagenet import NImageNet 5 | 6 | # N-ImageNet (Mini) subset, taken from https://arxiv.org/pdf/2308.09383.pdf 7 | # Note that this is slightly different from the Mini-ImageNet used in e.g. 
MAML 8 | MINI_NAMES = [ 9 | "hamster", "academic gown", "airship", "jackfruit", "barbershop", 10 | "cocktail shaker", "Komodo dragon", "sunglasses", "grey fox", "cello", 11 | "comic book", "goldfish", "Bloodhound", "porcupine", "jaguar", "kingsnake", 12 | "altar", "water buffalo", "chiton", "scarf", "storage chest", "tool kit", 13 | "sea anemone", "Border Terrier", "menu", "picket fence", "forklift", 14 | "yellow lady's slipper", "chameleon", "dragonfly", "Pomeranian", 15 | "European garden spider", "Airedale Terrier", "frilled-necked lizard", 16 | "black stork", "valley", "radio telescope", "leopard", "crossword", 17 | "Australian Terrier", "Shih Tzu", "husky", "can opener", "artichoke", 18 | "assault rifle", "fountain pen", "harvestman", "parallel bars", 19 | "harmonica", "half-track", "snoek fish", "pencil sharpener", "submarine", 20 | "muzzle", "eastern diamondback rattlesnake", "Miniature Schnauzer", 21 | "missile", "Komondor", "grand piano", "website", "king penguin", "canoe", 22 | "red-breasted merganser", "trolleybus", "quail", "poke bonnet", 23 | "King Charles Spaniel", "race car", "Malinois", "solar thermal collector", 24 | "slug", "bucket", "dung beetle", "Asian elephant", "window screen", 25 | "Flat-Coated Retriever", "steel drum", "snowplow", "handkerchief", 26 | "tailed frog", "church", "Chesapeake Bay Retriever", "Christmas stocking", 27 | "hatchet", "hair clip", "vulture", "sidewinder rattlesnake", 28 | "oscilloscope", "worm snake", "eel", "wok", "planetarium", 29 | "Old English Sheepdog", "platypus", "Pembroke Welsh Corgi", 30 | "alligator lizard", "consomme", "African rock python", "hot tub", 31 | "Tibetan Mastiff" 32 | ] 33 | 34 | 35 | class NImageNetMini(NImageNet): 36 | """Dataset class for N-ImageNet (Mini) dataset.""" 37 | 38 | def __init__( 39 | self, 40 | root, 41 | augmentation=False, 42 | num_shots=None, 43 | repeat=True, 44 | ): 45 | root = get_real_path(root) 46 | self.root = root 47 | # TODO: a hack for identifying generated pseudo labeled datasets 48 | self.is_pseudo = 'pseudo' in root 49 | if self.is_pseudo: 50 | print('Using pseudo labeled dataset!') 51 | 52 | # data stats 53 | self.resolution = (480, 640) 54 | self.max_t = 0.055 # max 55 | self.max_n = 135000 # 95th percentile 56 | 57 | # data augmentation 58 | self.augmentation = augmentation 59 | self.flip_time = True 60 | self.max_shift = 20 61 | 62 | # load folder name to class name mapping 63 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 64 | label_map = os.path.join(cur_dir, 'files/CLIP-IN_ClassNames.txt') 65 | with open(label_map, 'r') as f: 66 | lines = f.readlines()[:1000] 67 | lines = [line.strip() for line in lines] 68 | """ 69 | n01440764 tench 70 | n01443537 goldfish 71 | n01484850 great white shark 72 | n01491361 tiger shark 73 | n01494475 hammerhead shark 74 | """ 75 | 76 | folder2name = { 77 | s.split(' ')[0]: ' '.join(s.split(' ')[1:]) 78 | for s in lines 79 | } 80 | # only take a subset of 100 classes 81 | folder2name = {k: v for k, v in folder2name.items() if v in MINI_NAMES} 82 | assert len(folder2name) == 100 == len(MINI_NAMES) 83 | self.classes = list(folder2name.keys()) 84 | self.folder2name = folder2name 85 | self.name2folder = {v: k for k, v in folder2name.items()} 86 | 87 | # few-shot cls 88 | self.num_shots = num_shots # number of labeled data per class 89 | self.few_shot = (num_shots is not None and num_shots > 0) 90 | if self.few_shot: 91 | assert 'train' in root.lower(), 'Only sample data in training set' 92 | self.repeat = repeat 93 | 94 | self.labeled_files, 
self.labels = self._get_sample_idx() 95 | assert len(self.labeled_files) == len(self.labels) 96 | 97 | # finally, get semantically meaningful class names 98 | self.classes = [folder2name[c] for c in self.classes] 99 | assert all(c in self.classes for c in MINI_NAMES) and \ 100 | len(self.classes) == 100 101 | self.new_cnames = None 102 | 103 | 104 | def build_n_imagenet_mini_dataset(params, val_only=False, gen_data=False): 105 | """Build the N-ImageNet (Mini) dataset.""" 106 | # only build the test set 107 | test_set = NImageNetMini( 108 | root=os.path.join(params.data_root, 'extracted_val'), 109 | augmentation=False, 110 | ) 111 | if val_only: 112 | assert not gen_data, 'Only generate pseudo labels on the training set' 113 | return test_set 114 | # build the training set for pseudo label generation 115 | if gen_data: 116 | return NImageNetMini( 117 | root=os.path.join(params.data_root, 'extracted_train'), 118 | augmentation=False, 119 | ) 120 | 121 | # build the training set 122 | train_set = NImageNetMini( 123 | root=os.path.join(params.data_root, 'extracted_train'), 124 | augmentation=True, 125 | num_shots=params.get('num_shots', None), 126 | repeat=params.get('repeat_data', True), 127 | ) 128 | return train_set, test_set 129 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """EventCLIP training script""" 2 | 3 | import os 4 | import sys 5 | import pwd 6 | import importlib 7 | import argparse 8 | import wandb 9 | 10 | import torch 11 | 12 | import clip 13 | 14 | from nerv.utils import mkdir_or_exist 15 | from nerv.training import BaseDataModule, find_old_slurm_id 16 | 17 | from models import build_model 18 | from method import build_method 19 | from datasets import build_dataset 20 | 21 | 22 | def main(params): 23 | # have to load CLIP model first 24 | arch = params.clip_dict.pop('arch') 25 | device = 'cuda' 26 | model, preprocess = clip.load(arch, device=device) 27 | # cast weights to FP32 28 | for p in model.parameters(): 29 | p.data = p.data.float() 30 | 31 | # build dataset 32 | params.data_transforms = preprocess 33 | train_set, val_set = build_dataset(params) 34 | datamodule = BaseDataModule( 35 | params, train_set=train_set, val_set=val_set, use_ddp=params.ddp) 36 | 37 | # build model 38 | params.clip_dict['clip_model'] = model 39 | params.clip_dict['class_names'] = train_set.classes 40 | params.resolution = train_set.resolution 41 | params.class_names = train_set.classes 42 | params.adapter_dict['in_dim'] = model.visual.output_dim 43 | model = build_model(params) 44 | 45 | # create checkpoint dir 46 | exp_name = os.path.basename(args.params) 47 | ckp_path = os.path.join('checkpoint', exp_name, 'models') 48 | if args.local_rank == 0: 49 | mkdir_or_exist(os.path.dirname(ckp_path)) 50 | 51 | # on clusters, quota under user dir is usually limited 52 | # soft link to save the weights in temp space for checkpointing 53 | # e.g. 
on our cluster, the temp dir is /checkpoint/$USR/$SLURM_JOB_ID/ 54 | # TODO: modify this if you are not running on clusters 55 | SLURM_JOB_ID = os.environ.get('SLURM_JOB_ID') 56 | if os.path.exists(ckp_path): 57 | SLURM_JOB_ID = find_old_slurm_id(ckp_path) 58 | else: 59 | if SLURM_JOB_ID: 60 | os.system(r'ln -s /checkpoint/{}/{}/ {}'.format( 61 | pwd.getpwuid(os.getuid())[0], SLURM_JOB_ID, ckp_path)) 62 | else: 63 | os.makedirs(ckp_path, exist_ok=True) 64 | 65 | # it's not good to hard-code the wandb id 66 | # but on preemption clusters, we want the job to resume the same wandb 67 | # process after resuming training (i.e. drawing the same graph) 68 | # so we have to keep the same wandb id 69 | # TODO: modify this if you are not running on preemption clusters 70 | preemption = True 71 | if SLURM_JOB_ID and preemption: 72 | logger_id = logger_name = f'{exp_name}-{SLURM_JOB_ID}' 73 | else: 74 | logger_name = exp_name 75 | logger_id = None 76 | 77 | wandb.init( 78 | project=params.project, 79 | name=logger_name, 80 | id=logger_id, 81 | dir=ckp_path, 82 | ) 83 | 84 | method = build_method( 85 | model=model, 86 | datamodule=datamodule, 87 | params=params, 88 | ckp_path=ckp_path, 89 | local_rank=args.local_rank, 90 | use_ddp=args.ddp, 91 | use_fp16=args.fp16, 92 | ) 93 | 94 | method.fit( 95 | resume_from=args.weight, san_check_val_step=params.san_check_val_step) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description='EventCLIP') 100 | parser.add_argument('--params', type=str, required=True) 101 | parser.add_argument('--num_shots', type=int, default=-1) 102 | parser.add_argument('--N', type=int, default=-1) 103 | parser.add_argument('--weight', type=str, default='', help='load weight') 104 | parser.add_argument('--fp16', action='store_true') 105 | parser.add_argument('--ddp', action='store_true') 106 | parser.add_argument('--cudnn', action='store_true') 107 | parser.add_argument('--local_rank', type=int, default=0) 108 | parser.add_argument('--local-rank', type=int, default=0) 109 | args = parser.parse_args() 110 | 111 | if args.params.endswith('.py'): 112 | args.params = args.params[:-3] 113 | sys.path.append(os.path.dirname(args.params)) 114 | params = importlib.import_module(os.path.basename(args.params)) 115 | params = params.EventCLIPParams() 116 | params.ddp = args.ddp 117 | 118 | assert params.model != 'ZSCLIP', \ 119 | 'zero-shot EventCLIP does not require training' 120 | 121 | if args.N > 0: 122 | params.quantize_args['N'] = int(args.N * 1000) 123 | args.params = args.params + f'-N_{args.N}' 124 | 125 | if args.num_shots > 0: 126 | params.num_shots = args.num_shots 127 | args.params = args.params + f'-{args.num_shots}shot' 128 | 129 | # adjust the batch size since N-Cars only have 2 classes 130 | if params.dataset == 'n_cars': 131 | params.train_batch_size = min( 132 | params.num_shots * 2, # 2 classes 133 | params.train_batch_size) 134 | print(f'Set {params.train_batch_size=} for N-Cars') 135 | if params.dataset == 'n_imagenet_mini': 136 | params.train_batch_size = min( 137 | params.num_shots * 100, # 100 classes 138 | params.train_batch_size) 139 | print(f'Set {params.train_batch_size=} for N-ImageNet (Mini)') 140 | 141 | if args.fp16: 142 | print('INFO: using FP16 training!') 143 | if args.ddp: 144 | print('INFO: using DDP training!') 145 | if args.cudnn: 146 | torch.backends.cudnn.benchmark = True 147 | print('INFO: using cudnn benchmark!') 148 | 149 | main(params) 150 | 
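# ---------------------------------------------------------------------------
# Added usage sketch (not part of the original script; config paths and shot
# counts are only illustrative, see docs/benchmark.md for the exact commands):
#
#   Single-GPU few-shot fine-tuning:
#     python train.py --params configs/ftclip/ft_text_fsclip_ncaltech_params-vitb16.py \
#         --num_shots 20 --fp16 --cudnn
#
#   Multi-GPU DDP training submitted through the Slurm wrapper:
#     GPUS=4 CPUS_PER_GPU=8 ./scripts/sbatch_run.sh <partition> ft-nin-lora ./train.py ddp \
#         --params configs/ftclip/ft_text_fsclip_nin_params-lora16.py --fp16 --ddp --cudnn
# ---------------------------------------------------------------------------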
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: eventclip 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotlipy=0.7.0=py39h27cfd23_1003 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2023.08.22=h06a4308_0 12 | - certifi=2023.7.22=py39h06a4308_0 13 | - cffi=1.15.1=py39h5eee18b_3 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cryptography=41.0.3=py39hdda0065_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - ffmpeg=4.3=hf484d3e_0 18 | - freetype=2.12.1=h4a9f257_0 19 | - giflib=5.2.1=h5eee18b_3 20 | - gmp=6.2.1=h295c915_3 21 | - gnutls=3.6.15=he1e5248_0 22 | - idna=3.4=py39h06a4308_0 23 | - intel-openmp=2023.1.0=hdb19cb5_46305 24 | - jpeg=9e=h5eee18b_1 25 | - lame=3.100=h7b6447c_0 26 | - lcms2=2.12=h3be6417_0 27 | - ld_impl_linux-64=2.38=h1181459_1 28 | - lerc=3.0=h295c915_0 29 | - libdeflate=1.17=h5eee18b_0 30 | - libffi=3.4.4=h6a678d5_0 31 | - libgcc-ng=11.2.0=h1234567_1 32 | - libgomp=11.2.0=h1234567_1 33 | - libiconv=1.16=h7f8727e_2 34 | - libidn2=2.3.4=h5eee18b_0 35 | - libpng=1.6.39=h5eee18b_0 36 | - libstdcxx-ng=11.2.0=h1234567_1 37 | - libtasn1=4.19.0=h5eee18b_0 38 | - libtiff=4.5.1=h6a678d5_0 39 | - libunistring=0.9.10=h27cfd23_0 40 | - libwebp=1.3.2=h11a3e52_0 41 | - libwebp-base=1.3.2=h5eee18b_0 42 | - lz4-c=1.9.4=h6a678d5_0 43 | - mkl=2023.1.0=h213fc3f_46343 44 | - mkl-service=2.4.0=py39h5eee18b_1 45 | - mkl_fft=1.3.8=py39h5eee18b_0 46 | - mkl_random=1.2.4=py39hdb19cb5_0 47 | - ncurses=6.4=h6a678d5_0 48 | - nettle=3.7.3=hbbd107a_1 49 | - numpy=1.25.2=py39h5f9d8c6_0 50 | - numpy-base=1.25.2=py39hb5e798b_0 51 | - openh264=2.1.1=h4ff587b_0 52 | - openssl=3.0.11=h7f8727e_2 53 | - pillow=9.4.0=py39h6a678d5_1 54 | - pip=23.2.1=py39h06a4308_0 55 | - pycparser=2.21=pyhd3eb1b0_0 56 | - pyopenssl=23.2.0=py39h06a4308_0 57 | - pysocks=1.7.1=py39h06a4308_0 58 | - python=3.9.17=h955ad1f_0 59 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 60 | - pytorch-mutex=1.0=cuda 61 | - readline=8.2=h5eee18b_0 62 | - requests=2.31.0=py39h06a4308_0 63 | - setuptools=68.0.0=py39h06a4308_0 64 | - sqlite=3.41.2=h5eee18b_0 65 | - tbb=2021.8.0=hdb19cb5_0 66 | - tk=8.6.12=h1ccaba5_0 67 | - torchaudio=0.12.1=py39_cu113 68 | - torchvision=0.13.1=py39_cu113 69 | - typing_extensions=4.7.1=py39h06a4308_0 70 | - urllib3=1.26.16=py39h06a4308_0 71 | - wheel=0.38.4=py39h06a4308_0 72 | - xz=5.4.2=h5eee18b_0 73 | - zlib=1.2.13=h5eee18b_0 74 | - zstd=1.5.5=hc292b87_0 75 | - pip: 76 | - addict==2.4.0 77 | - aiohttp==3.8.5 78 | - aiosignal==1.3.1 79 | - ansi2html==1.8.0 80 | - appdirs==1.4.4 81 | - asttokens==2.4.0 82 | - async-timeout==4.0.3 83 | - attrs==23.1.0 84 | - backcall==0.2.0 85 | - beautifulsoup4==4.12.2 86 | - click==8.1.7 87 | - clip==1.0 88 | - comm==0.1.4 89 | - configargparse==1.7 90 | - contourpy==1.1.1 91 | - cycler==0.11.0 92 | - dash==2.13.0 93 | - dash-core-components==2.0.0 94 | - dash-html-components==2.0.0 95 | - dash-table==5.0.0 96 | - decorator==4.4.2 97 | - docker-pycreds==0.4.0 98 | - exceptiongroup==1.1.3 99 | - executing==1.2.0 100 | - fastjsonschema==2.18.0 101 | - filelock==3.12.4 102 | - flask==2.2.5 103 | - fonttools==4.42.1 104 | - frozenlist==1.4.0 105 | - fsspec==2023.9.2 106 | - ftfy==6.1.1 107 | - gdown==4.7.1 108 | - gitdb==4.0.10 109 | - gitpython==3.1.37 110 | - imageio==2.31.4 111 | - imageio-ffmpeg==0.4.9 112 | - importlib-metadata==6.8.0 
113 | - importlib-resources==6.1.0 114 | - ipython==8.15.0 115 | - ipywidgets==8.1.1 116 | - itsdangerous==2.1.2 117 | - jedi==0.19.0 118 | - jinja2==3.1.2 119 | - joblib==1.3.2 120 | - jsonschema==4.19.1 121 | - jsonschema-specifications==2023.7.1 122 | - jupyter-core==5.3.1 123 | - jupyterlab-widgets==3.0.9 124 | - kiwisolver==1.4.5 125 | - lightning-utilities==0.9.0 126 | - markupsafe==2.1.3 127 | - matplotlib==3.8.0 128 | - matplotlib-inline==0.1.6 129 | - moviepy==1.0.3 130 | - multidict==6.0.4 131 | - nbformat==5.7.0 132 | - nest-asyncio==1.5.8 133 | - open3d==0.17.0 134 | - opencv-python==4.8.0.76 135 | - packaging==23.1 136 | - pandas==2.1.1 137 | - parso==0.8.3 138 | - pathtools==0.1.2 139 | - pexpect==4.8.0 140 | - pickleshare==0.7.5 141 | - platformdirs==3.10.0 142 | - plotly==5.17.0 143 | - proglog==0.1.10 144 | - prompt-toolkit==3.0.39 145 | - protobuf==4.24.3 146 | - psutil==5.9.5 147 | - ptyprocess==0.7.0 148 | - pure-eval==0.2.2 149 | - pygments==2.16.1 150 | - pyparsing==3.1.1 151 | - pyquaternion==0.9.9 152 | - python-dateutil==2.8.2 153 | - pytorch-lightning==1.8.6 154 | - pytz==2023.3.post1 155 | - pyyaml==6.0.1 156 | - referencing==0.30.2 157 | - regex==2023.8.8 158 | - retrying==1.3.4 159 | - rpds-py==0.10.3 160 | - scikit-learn==1.3.1 161 | - scipy==1.11.2 162 | - sentry-sdk==1.31.0 163 | - setproctitle==1.3.2 164 | - six==1.16.0 165 | - smmap==5.0.1 166 | - soupsieve==2.5 167 | - stack-data==0.6.2 168 | - tenacity==8.2.3 169 | - tensorboardx==2.6.2.2 170 | - threadpoolctl==3.2.0 171 | - torchmetrics==0.11.4 172 | - tqdm==4.66.1 173 | - traitlets==5.10.0 174 | - tzdata==2023.3 175 | - wandb==0.15.11 176 | - wcwidth==0.2.6 177 | - werkzeug==2.2.3 178 | - widgetsnbextension==4.0.9 179 | - yarl==1.9.2 180 | - zipp==3.17.0 181 | -------------------------------------------------------------------------------- /datasets/event2img.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from .vis import events2frames 9 | from .augment import RandAugment, InterpolationMode 10 | from .utils import random_time_flip_events as tflip_events 11 | from .utils import random_flip_events_along_x as hflip_events 12 | 13 | 14 | class Event2ImageDataset(Dataset): 15 | """A wrapper for EventDataset that converts events to 2D images.""" 16 | 17 | def __init__( 18 | self, 19 | transforms, 20 | event_dataset, 21 | quantize_args=dict( 22 | max_imgs=2, 23 | split_method='event_count', 24 | convert_method='event_histogram', 25 | N=30000, 26 | grayscale=True, 27 | count_non_zero=False, # hotpixel statistics 28 | background_mask=True, # apply white background via alpha-masking 29 | ), 30 | augment=False, 31 | tta=False, 32 | ): 33 | 34 | # data augmentation 35 | self.augment = augment 36 | if augment: 37 | self.augmentation = RandAugment( 38 | num_ops=2, # follow common practice 39 | interpolation=InterpolationMode.BICUBIC, # CLIP uses bicubic 40 | fill=[255, 255, 255] # pad with white pixels 41 | if quantize_args['background_mask'] else [0, 0, 0], 42 | ) 43 | 44 | # transforms to apply to the 2D images 45 | self.transforms = transforms 46 | 47 | # dataset that loads raw events in shape [N, 4 (x, y, t, p)] 48 | self.event_dataset = event_dataset 49 | self.classes = event_dataset.classes 50 | self.resolution = event_dataset.resolution 51 | self.max_t = event_dataset.max_t # timestamp 52 | self.max_n = event_dataset.max_n # number of events 53 | self.tta = 
tta 54 | if tta: 55 | assert not event_dataset.augmentation, \ 56 | 'Do not augment events in pseudo label generation' 57 | assert not augment, 'Do not augment twice' 58 | assert event_dataset.num_shots is None, 'Should sample all data' 59 | assert 'train' in event_dataset.root, \ 60 | 'Generate pseudo labels only on training set' 61 | print('Apply h- and t-flip TTA in pseudo label generation') 62 | 63 | # arguments for mapping events to 2D images 64 | self.quantize_args = copy.deepcopy(quantize_args) 65 | self.quantize_args['shape'] = self.resolution 66 | 67 | self.split_method = quantize_args['split_method'] 68 | self.event_rep = quantize_args['convert_method'] 69 | assert self.split_method == 'event_count' 70 | max_imgs = round(self.max_n / quantize_args['N']) 71 | max_max_imgs = quantize_args.pop('max_imgs', 10) # hard limit 72 | self.max_imgs = max(min(max_imgs, max_max_imgs), 1) 73 | 74 | # a hack in visualization to also load the raw events data 75 | self.keep_events = False 76 | 77 | def __len__(self): 78 | return len(self.event_dataset) 79 | 80 | def _subsample_imgs(self, imgs): 81 | """Randomly select a subset of images or pad with zeros.""" 82 | valid_mask = torch.zeros(self.max_imgs).bool() 83 | if len(imgs) > self.max_imgs: 84 | valid_mask[:] = True 85 | idxs = torch.randperm(len(imgs))[:self.max_imgs] 86 | imgs = imgs[idxs] 87 | else: 88 | valid_mask[:len(imgs)] = True 89 | pad = torch.zeros( 90 | (self.max_imgs - len(imgs), *imgs.shape[1:])).type_as(imgs) 91 | imgs = torch.cat([imgs, pad], dim=0) 92 | return imgs, valid_mask 93 | 94 | def _load_tta_data(self, idx): 95 | """Apply h- and t-flip to the loaded events, then convert to images.""" 96 | data_dict = self.event_dataset[idx] 97 | events = data_dict.pop('events') 98 | assert not self.keep_events, 'val dataset should not be TTA' 99 | h_events = hflip_events( 100 | copy.deepcopy(events), resolution=self.resolution, p=1.) 101 | t_events = tflip_events(copy.deepcopy(events), p=1.) 102 | h_t_events = tflip_events(copy.deepcopy(h_events), p=1.) 
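# added note (not in the original code): the four TTA variants (original events, horizontal flip, time flip, and both flips combined) are each converted to image frames below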
103 | tta_events = [events, h_events, t_events, h_t_events] 104 | tta_imgs, tta_valid_mask = [], [] 105 | for events in tta_events: 106 | imgs, valid_mask = self._event2img(events) 107 | tta_imgs.append(imgs) 108 | tta_valid_mask.append(valid_mask) 109 | data_dict['img'] = torch.stack(tta_imgs, dim=0) # [4, N, 3, H, W] 110 | data_dict['valid_mask'] = torch.stack(tta_valid_mask, dim=0) # [4, N] 111 | # `label` is still just an integer 112 | return data_dict 113 | 114 | def _event2img(self, events): 115 | """Convert events to 2D images.""" 116 | # events: [N, 4 (x, y, t, p)] 117 | # get [N, H, W, 3] images with dtype np.uint8 118 | imgs = events2frames(events, **self.quantize_args) 119 | imgs = [Image.fromarray(img) for img in imgs] 120 | if self.augment: 121 | imgs = self.augmentation(imgs) 122 | imgs = torch.stack([self.transforms(img) for img in imgs]) 123 | # to [N, 3, H, W] torch.Tensor as model inputs 124 | 125 | # randomly select a subset of images or pad with zeros 126 | imgs, valid_mask = self._subsample_imgs(imgs) 127 | 128 | return imgs, valid_mask 129 | 130 | def __getitem__(self, idx): 131 | if self.tta: 132 | return self._load_tta_data(idx) 133 | 134 | data_dict = self.event_dataset[idx] 135 | events = data_dict.pop('events') 136 | 137 | if self.keep_events: 138 | data_dict['events'] = copy.deepcopy(events) 139 | 140 | imgs, valid_mask = self._event2img(events) 141 | 142 | data_dict['img'] = imgs 143 | data_dict['valid_mask'] = valid_mask 144 | 145 | return data_dict 146 | 147 | 148 | def build_event2img_dataset(params, event_dataset, augment=False, tta=False): 149 | """Wrap an event dataset with a Event2Image processing pipeline.""" 150 | return Event2ImageDataset( 151 | transforms=params.data_transforms, 152 | event_dataset=event_dataset, 153 | quantize_args=params.quantize_args, 154 | augment=augment, 155 | tta=tta, 156 | ) 157 | -------------------------------------------------------------------------------- /method.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import wandb 3 | import numpy as np 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.cuda.amp as amp 8 | 9 | from nerv.utils import AverageMeter, MeanMetric 10 | from nerv.training import BaseMethod, CosineAnnealingWarmupRestarts 11 | 12 | from datasets import events2frames 13 | 14 | 15 | def denormalize(x): 16 | """Reverse the input image normalization.""" 17 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).type_as(x) 18 | std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).type_as(x) 19 | return x * std[None, :, None, None] + mean[None, :, None, None] 20 | 21 | 22 | def build_method(**kwargs): 23 | params = kwargs['params'] 24 | if params.model in ['ZSCLIP', 'FSCLIP', 'FTCLIP']: 25 | return EventCLIPMethod(**kwargs) 26 | else: 27 | raise NotImplementedError(f'{params.model} method is not implemented.') 28 | 29 | 30 | class EventBaseMethod(BaseMethod): 31 | """Base method in this project.""" 32 | 33 | def _round(self, v, n): 34 | if isinstance(v, (float, int)): 35 | return float(round(v, n)) 36 | return type(v)([self._round(i, n) for i in v]) 37 | 38 | @staticmethod 39 | def _convert_video(video, caption=None): 40 | """Convert torch.FloatTensor video to wandb.Video.""" 41 | assert isinstance(video, torch.Tensor) 42 | video = denormalize(video) 43 | video = (video * 255.).cpu().numpy() 44 | video = np.round(video).clip(0, 255).astype(np.uint8) 45 | return wandb.Video(video, fps=2, caption=caption) 46 | 47 | @staticmethod 48 
| def _get_sample_idx(N, dst): 49 | """Load data uniformly from the dataset.""" 50 | dst_len = len(dst) 51 | N = N - 1 if dst_len % N != 0 else N 52 | sampled_idx = torch.arange(0, dst_len, dst_len // N) 53 | return sampled_idx.numpy() 54 | 55 | @torch.no_grad() 56 | @amp.autocast() 57 | def validation_epoch(self, model, san_check_step=-1, sample_events=True): 58 | """Validate one epoch. 59 | We aggregate the avg of all statistics and only log once. 60 | """ 61 | out_dict = super().validation_epoch( 62 | model, san_check_step=san_check_step) 63 | if self.local_rank != 0: 64 | return 65 | # visualization after every epoch 66 | if sample_events: 67 | self._sample_events(model) 68 | 69 | return out_dict 70 | 71 | @staticmethod 72 | def event2video(events, caption=None, **quantize_args): 73 | """Convert events to wandb loggable videos.""" 74 | imgs = events2frames(events, **quantize_args).astype(np.uint8) 75 | imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2)) 76 | # add a black border to the video 77 | T, C, H, W = imgs.shape 78 | video = np.zeros((T, C, H + 8, W + 8), dtype=np.uint8) 79 | video[:, :, 4:-4, 4:-4] = imgs 80 | return wandb.Video(video, fps=2, caption=caption) 81 | 82 | def _configure_optimizers(self): 83 | """Returns an optimizer, a scheduler and its frequency (step/epoch).""" 84 | optimizer = super()._configure_optimizers()[0] 85 | 86 | lr = self.params.lr 87 | total_steps = self.params.max_epochs * len(self.train_loader) 88 | warmup_steps = self.params.warmup_steps_pct * total_steps 89 | 90 | scheduler = CosineAnnealingWarmupRestarts( 91 | optimizer, 92 | total_steps, 93 | max_lr=lr, 94 | min_lr=lr / 100., 95 | warmup_steps=warmup_steps, 96 | ) 97 | 98 | return optimizer, (scheduler, 'step') 99 | 100 | 101 | class EventCLIPMethod(EventBaseMethod): 102 | 103 | @torch.no_grad() 104 | def _sample_events(self, model): 105 | """model is a simple nn.Module, not warpped in e.g. 
DataParallel.""" 106 | model.eval() 107 | dst = self.val_loader.dataset 108 | dst.keep_events = True 109 | classes = dst.classes 110 | quantize_args = copy.deepcopy(dst.quantize_args) 111 | quantize_args['background_mask'] = True 112 | 113 | sampled_idx = self._get_sample_idx(self.params.n_samples, dst) 114 | log_dict = {} 115 | for i, idx in enumerate(sampled_idx): 116 | data_dict = dst[idx] 117 | events, label = data_dict.pop('events'), data_dict.pop('label') 118 | img, valid_mask = data_dict['img'], data_dict['valid_mask'] 119 | # [N, 3, H, W], [N] 120 | in_dict = { 121 | 'img': img[None].to(model.device), 122 | 'valid_mask': valid_mask[None].to(model.device), 123 | } 124 | probs = model(in_dict)['probs'][0] # [n_cls] 125 | 126 | # keep the topk predictions 127 | k = min(3, probs.shape[-1]) 128 | topk = probs.topk(k, dim=-1) 129 | idxs, probs = \ 130 | topk.indices.cpu().numpy(), topk.values.cpu().numpy() 131 | caption = f'GT: {classes[label]}\n' + '\t'.join([ 132 | f'{classes[idx]}: {prob:.4f}' 133 | for idx, prob in zip(idxs, probs) 134 | ]) 135 | 136 | # visualize the events 137 | # raw events 138 | raw_events = self.event2video( 139 | events, caption=caption, **quantize_args) 140 | log_dict[f'val/raw_events_{i}'] = raw_events 141 | # model inputs 142 | img = img[valid_mask] 143 | video = self._convert_video(img, caption=caption) 144 | log_dict[f'val/video_{i}'] = video 145 | 146 | wandb.log(log_dict, step=self.it) 147 | torch.cuda.empty_cache() 148 | dst.keep_events = False 149 | 150 | def _configure_optimizers(self): 151 | """Returns an optimizer, a scheduler and its frequency (step/epoch).""" 152 | if self.params.model != 'FTCLIP': 153 | return super()._configure_optimizers() 154 | 155 | # use smaller lr for finetuning CLIP 156 | if self.params.optimizer.lower() == 'adam': 157 | opt = optim.Adam 158 | elif self.params.optimizer.lower() == 'adamw': 159 | opt = optim.AdamW 160 | else: 161 | raise ValueError('Should use Adam or AdamW optimizer!') 162 | assert self.params.weight_decay == 0. 
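# added note (not in the original code): two parameter groups are built below; adapter/text-side parameters use 'lr', and the CLIP visual encoder ('model.visual') uses its own 'clip_lr'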
163 | lr = self.params.lr 164 | clip_lr = self.params.clip_lr 165 | name = 'model.visual' 166 | 167 | adapter_params = list( 168 | filter(lambda kv: name not in kv[0] and kv[1].requires_grad, 169 | self.model.named_parameters())) 170 | clip_params = list( 171 | filter(lambda kv: name in kv[0] and kv[1].requires_grad, 172 | self.model.named_parameters())) 173 | # assert len(adapter_params) > 0 and len(clip_params) > 0 174 | params_list = [{ 175 | 'params': [kv[1] for kv in adapter_params], 176 | 'lr': lr, 177 | }, { 178 | 'params': [kv[1] for kv in clip_params], 179 | 'lr': clip_lr, 180 | }] 181 | 182 | optimizer = opt(params_list) 183 | total_steps = self.params.max_epochs * len(self.train_loader) 184 | warmup_steps = self.params.warmup_steps_pct * total_steps 185 | scheduler = CosineAnnealingWarmupRestarts( 186 | optimizer, 187 | total_steps, 188 | max_lr=(lr, clip_lr), 189 | min_lr=(lr / 100., clip_lr / 100.), 190 | warmup_steps=warmup_steps, 191 | ) 192 | 193 | return optimizer, (scheduler, 'step') 194 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """EventCLIP testing script""" 2 | 3 | import os 4 | import sys 5 | import importlib 6 | import argparse 7 | 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | import clip 13 | 14 | from nerv.training import BaseDataModule 15 | from nerv.utils import AverageMeter 16 | 17 | from models import build_model 18 | from datasets import build_dataset 19 | 20 | 21 | @torch.no_grad() 22 | def main(params, printing=True): 23 | # have to load CLIP model first 24 | arch = params.clip_dict['arch'] 25 | device = 'cuda' 26 | model, preprocess = clip.load(arch, device=device) 27 | 28 | # build dataset 29 | params.data_transforms = preprocess 30 | if args.subset > 0: 31 | test_set = build_dataset(params, val_only=True, subset=args.subset) 32 | else: 33 | test_set = build_dataset(params, val_only=True) 34 | is_nin = (params.dataset == 'n_imagenet') 35 | 36 | datamodule = BaseDataModule( 37 | params, train_set=None, val_set=test_set, use_ddp=False) 38 | test_loader = datamodule.val_loader 39 | 40 | # build model 41 | params.clip_dict['clip_model'] = model 42 | params.clip_dict['class_names'] = test_set.classes 43 | if not is_zs: 44 | params.adapter_dict['in_dim'] = model.visual.output_dim 45 | model = build_model(params) 46 | 47 | # load weight 48 | # don't load for zero-shot models 49 | if args.weight and not is_zs: 50 | model.load_weight(args.weight) 51 | print(f'Loading weight: {args.weight}') 52 | model = model.cuda().eval() 53 | 54 | # test 55 | probs_acc_meter, logits_acc_meter = AverageMeter(), AverageMeter() 56 | if is_nin: 57 | probs_acc5_meter, logits_acc5_meter = AverageMeter(), AverageMeter() 58 | 59 | for data_dict in tqdm(test_loader): 60 | data_dict = {k: v.cuda() for k, v in data_dict.items()} 61 | out_dict = model(data_dict) 62 | labels = data_dict['label'] 63 | 64 | # based on aggregated probs 65 | probs = out_dict['probs'] 66 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean().item() 67 | probs_acc_meter.update(probs_acc, labels.shape[0]) 68 | 69 | # based on aggregated logits 70 | logits = out_dict['logits'] 71 | logits_acc = (logits.argmax(dim=-1) == labels).float().mean().item() 72 | logits_acc_meter.update(logits_acc, labels.shape[0]) 73 | 74 | # top5 accuracy 75 | if is_nin: 76 | probs_acc5 = (probs.topk(5, dim=-1).indices == labels[:, None]).\ 77 | float().sum(dim=-1).mean().item() 78 | 
probs_acc5_meter.update(probs_acc5, labels.shape[0]) 79 | logits_acc5 = (logits.topk(5, dim=-1).indices == labels[:, None]).\ 80 | float().sum(dim=-1).mean().item() 81 | logits_acc5_meter.update(logits_acc5, labels.shape[0]) 82 | 83 | if not printing: 84 | return probs_acc_meter.avg, logits_acc_meter.avg 85 | 86 | print(f'\n\nTesting {args.params}') 87 | print(f'Model weight: {args.weight}') 88 | print(f'\tProbs-based accuracy@1: {probs_acc_meter.avg * 100.:.2f}%') 89 | print(f'\tLogits-based accuracy@1: {logits_acc_meter.avg * 100.:.2f}%\n') 90 | if not is_nin: 91 | return 92 | print(f'\tProbs-based accuracy@5: {probs_acc5_meter.avg * 100.:.2f}%') 93 | print(f'\tLogits-based accuracy@5: {logits_acc5_meter.avg * 100.:.2f}%\n') 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser(description='EventCLIP') 98 | parser.add_argument('--params', type=str, required=True) 99 | parser.add_argument('--weight', type=str, default='', help='load weight') 100 | parser.add_argument('--N', type=int, default=-1) 101 | parser.add_argument('--arch', type=str, default='') 102 | parser.add_argument('--prompt', type=str, default='') 103 | parser.add_argument('--bs', type=int, default=-1) 104 | parser.add_argument('--subset', type=int, default=-1) 105 | parser.add_argument('--train_shots', nargs='+', default=[-1], type=int) 106 | args = parser.parse_args() 107 | 108 | if args.params.endswith('.py'): 109 | args.params = args.params[:-3] 110 | sys.path.append(os.path.dirname(args.params)) 111 | params = importlib.import_module(os.path.basename(args.params)) 112 | params = params.EventCLIPParams() 113 | 114 | # adjust params 115 | is_zs = (params.model == 'ZSCLIP') 116 | if args.N > 0: 117 | params.quantize_args['N'] = int(args.N * 1e3) 118 | assert is_zs, 'can only change N in zero-shot testing' 119 | if args.arch: 120 | params.clip_dict['arch'] = args.arch 121 | assert is_zs, 'can only change ViT arch in zero-shot testing' 122 | if args.prompt: 123 | params.clip_dict['prompt'] = args.prompt 124 | assert is_zs, 'can only change text prompt in zero-shot testing' 125 | if args.bs > 0: 126 | params.val_batch_size = args.bs 127 | if args.subset > 0: 128 | assert params.dataset == 'n_imagenet', 'only N-ImageNet has subsets' 129 | 130 | # automatically find the model weight if `train_shots` is provided 131 | # we will ignore the provided `args.weight` 132 | # instead search for 'checkpoint/$PARAMS-${NUM}shot/models/model_xxx.pth' 133 | if args.train_shots[0] <= 0: 134 | main(params) 135 | exit(-1) 136 | 137 | all_probs_acc, all_logits_acc = [], [] 138 | for num_shot in args.train_shots: 139 | # first, find all dup-run dirs 140 | dup_weight_dir = os.path.join('checkpoint', 141 | os.path.basename(args.params)) 142 | all_weight_dirs = [f'{dup_weight_dir}-{num_shot}shot'] 143 | for i in range(1, 11, 1): # at most dup 10 times 144 | weight_dir = f'{dup_weight_dir}-dup{i}-{num_shot}shot' 145 | if os.path.exists(weight_dir): 146 | all_weight_dirs.append(weight_dir) 147 | 148 | # average the accuracy over all weights 149 | probs_acc_avg, logits_acc_avg = AverageMeter(), AverageMeter() 150 | 151 | # now, for each weight_dir, find a weight to test 152 | for weight_dir in all_weight_dirs: 153 | if not os.path.exists(weight_dir): 154 | continue 155 | weight_dir = os.path.join(weight_dir, 'models') 156 | 157 | # load the best weight if it is saved 158 | if os.path.exists(os.path.join(weight_dir, 'best.pth')): 159 | args.weight = os.path.join(weight_dir, 'best.pth') 160 | # find the latest one 161 | else: 
162 | all_weights = [ 163 | w for w in os.listdir(weight_dir) if w.endswith('.pth') 164 | ] 165 | all_weights = sorted( 166 | all_weights, key=lambda x: int(x[:-4].split('_')[1])) 167 | args.weight = os.path.join(weight_dir, all_weights[-1]) 168 | 169 | probs_acc, logits_acc = main(params, printing=False) 170 | probs_acc_avg.update(probs_acc, 1) 171 | logits_acc_avg.update(logits_acc, 1) 172 | 173 | # print the results for this `num_shot` 174 | print(f'\n\nTesting {args.params}-{num_shot}shot') 175 | print(f'Average accuracy over {probs_acc_avg.count} runs:') 176 | print(f'\tProbs-based accuracy@1: {probs_acc_avg.avg * 100.:.2f}%') 177 | print(f'\tLogits-based accuracy@1: {logits_acc_avg.avg * 100.:.2f}%\n') 178 | all_probs_acc.append(round(probs_acc_avg.avg * 100., 2)) 179 | all_logits_acc.append(round(logits_acc_avg.avg * 100., 2)) 180 | 181 | # print the results for recording & LaTeX 182 | print('\n\n') 183 | print(f'Probs-based accuracy@1: {all_probs_acc}') 184 | print('\t', ' & '.join([str(acc) for acc in all_probs_acc])) 185 | print(f'Logits-based accuracy@1: {all_logits_acc}') 186 | print('\t', ' & '.join([str(acc) for acc in all_logits_acc])) 187 | -------------------------------------------------------------------------------- /datasets/augment.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List, Optional, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | from torchvision.transforms import functional as F 7 | from torchvision.transforms import InterpolationMode 8 | 9 | 10 | def _apply_op( 11 | img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]] 12 | ): 13 | if op_name == "ShearX": 14 | # magnitude should be arctan(magnitude) 15 | # official autoaug: (1, level, 0, 0, 1, 0) 16 | # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290 17 | # compared to 18 | # torchvision: (1, tan(level), 0, 0, 1, 0) 19 | # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976 20 | img = F.affine( 21 | img, 22 | angle=0.0, 23 | translate=[0, 0], 24 | scale=1.0, 25 | shear=[math.degrees(math.atan(magnitude)), 0.0], 26 | interpolation=interpolation, 27 | fill=fill, 28 | center=[0, 0], 29 | ) 30 | elif op_name == "ShearY": 31 | # magnitude should be arctan(magnitude) 32 | # See above 33 | img = F.affine( 34 | img, 35 | angle=0.0, 36 | translate=[0, 0], 37 | scale=1.0, 38 | shear=[0.0, math.degrees(math.atan(magnitude))], 39 | interpolation=interpolation, 40 | fill=fill, 41 | center=[0, 0], 42 | ) 43 | elif op_name == "TranslateX": 44 | img = F.affine( 45 | img, 46 | angle=0.0, 47 | translate=[int(magnitude), 0], 48 | scale=1.0, 49 | interpolation=interpolation, 50 | shear=[0.0, 0.0], 51 | fill=fill, 52 | ) 53 | elif op_name == "TranslateY": 54 | img = F.affine( 55 | img, 56 | angle=0.0, 57 | translate=[0, int(magnitude)], 58 | scale=1.0, 59 | interpolation=interpolation, 60 | shear=[0.0, 0.0], 61 | fill=fill, 62 | ) 63 | elif op_name == "Rotate": 64 | img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) 65 | elif op_name == "Brightness": 66 | img = F.adjust_brightness(img, 1.0 + magnitude) 67 | elif op_name == "Color": 68 | img = F.adjust_saturation(img, 1.0 + magnitude) 69 | elif op_name == "Contrast": 70 | img = F.adjust_contrast(img, 1.0 + magnitude) 71 | elif op_name == "Sharpness": 72 | 
img = F.adjust_sharpness(img, 1.0 + magnitude) 73 | elif op_name == "Posterize": 74 | img = F.posterize(img, int(magnitude)) 75 | elif op_name == "Solarize": 76 | img = F.solarize(img, magnitude) 77 | elif op_name == "AutoContrast": 78 | img = F.autocontrast(img) 79 | elif op_name == "Equalize": 80 | img = F.equalize(img) 81 | elif op_name == "Invert": 82 | img = F.invert(img) 83 | elif op_name == "Identity": 84 | pass 85 | else: 86 | raise ValueError(f"The provided operator {op_name} is not recognized.") 87 | return img 88 | 89 | 90 | class RandAugment(torch.nn.Module): 91 | r"""RandAugment data augmentation method based on 92 | `"RandAugment: Practical automated data augmentation with a reduced search space" 93 | `_. 94 | If the image is torch Tensor, it should be of type torch.uint8, and it is expected 95 | to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. 96 | If img is PIL Image, it is expected to be in mode "L" or "RGB". 97 | 98 | Args: 99 | num_ops (int): Number of augmentation transformations to apply sequentially. 100 | magnitude (int): Magnitude for all the transformations. 101 | num_magnitude_bins (int): The number of different magnitude values. 102 | interpolation (InterpolationMode): Desired interpolation enum defined by 103 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. 104 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. 105 | fill (sequence or number, optional): Pixel fill value for the area outside the transformed 106 | image. If given a number, the value is used for all bands respectively. 107 | """ 108 | 109 | def __init__( 110 | self, 111 | num_ops: int = 2, 112 | interpolation: InterpolationMode = InterpolationMode.NEAREST, 113 | fill: Optional[List[float]] = None, 114 | ) -> None: 115 | super().__init__() 116 | 117 | self.num_ops = num_ops 118 | self.interpolation = interpolation 119 | self.fill = fill 120 | 121 | self.cur_ops = None 122 | 123 | def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: 124 | return { 125 | # op_name: (magnitudes, signed) 126 | "Identity": (torch.tensor(0.0), False), 127 | "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), 128 | "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), 129 | "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), 130 | "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), 131 | "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), 132 | "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), 133 | "Color": (torch.linspace(0.0, 0.9, num_bins), True), 134 | "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), 135 | "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), 136 | "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), 137 | "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), 138 | "AutoContrast": (torch.tensor(0.0), False), 139 | "Equalize": (torch.tensor(0.0), False), 140 | } 141 | 142 | def randomize_ops(self, resolution: Tuple[int, int]) -> None: 143 | """Randomly select `self.num_ops` augmentations to apply.""" 144 | assert self.cur_ops is None, 'Unused RandAugment ops' 145 | self.cur_ops = [] 146 | num_magnitude_bins = 30 147 | cur_magnitude = int(torch.randint(num_magnitude_bins, (1,)).item()) 148 | op_meta = self._augmentation_space(num_magnitude_bins, resolution) 149 | for _ in 
range(self.num_ops): 150 | op_index = int(torch.randint(len(op_meta), (1,)).item()) 151 | op_name = list(op_meta.keys())[op_index] 152 | magnitudes, signed = op_meta[op_name] 153 | magnitude = float(magnitudes[cur_magnitude].item()) if \ 154 | magnitudes.ndim > 0 else 0.0 155 | if signed and torch.randint(2, (1,)): 156 | magnitude *= -1.0 157 | self.cur_ops.append((op_name, magnitude)) 158 | 159 | def forward(self, imgs: List[Tensor]) -> Tensor: 160 | """Apply the same RandAugment ops to all images. 161 | 162 | Args: 163 | imgs (List[PIL Image or Tensor]): Image to be transformed. 164 | 165 | Returns: 166 | List[PIL Image or Tensor]: Transformed image. 167 | """ 168 | channels, height, width = F.get_dimensions(imgs[0]) 169 | self.randomize_ops((height, width)) 170 | fill = self.fill 171 | if isinstance(imgs[0], Tensor): 172 | if isinstance(fill, (int, float)): 173 | fill = [float(fill)] * channels 174 | elif fill is not None: 175 | fill = [float(f) for f in fill] 176 | imgs = [self._forward(img, fill) for img in imgs] 177 | self.cur_ops = None 178 | return imgs 179 | 180 | def _forward(self, img: Tensor, fill: Optional[List[float]]) -> Tensor: 181 | """Apply augmentation to one image. 182 | 183 | Args: 184 | img (PIL Image or Tensor): Image to be transformed. 185 | 186 | Returns: 187 | PIL Image or Tensor: Transformed image. 188 | """ 189 | assert len(self.cur_ops) == self.num_ops, 'Wrong number of RandAugment ops' 190 | for i in range(self.num_ops): 191 | op_name, magnitude = self.cur_ops[i] 192 | img = _apply_op(img, op_name, magnitude, 193 | interpolation=self.interpolation, fill=fill) 194 | return img 195 | -------------------------------------------------------------------------------- /datasets/caltech.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import listdir 3 | from os.path import join 4 | import random 5 | 6 | import numpy as np 7 | 8 | from torch.utils.data import Dataset 9 | 10 | from nerv.utils import load_obj, dump_obj 11 | 12 | from .utils import random_time_flip_events, random_shift_events, \ 13 | random_flip_events_along_x, center_events 14 | 15 | # from https://github.com/KaiyangZhou/CoOp/blob/main/datasets/caltech101.py 16 | NEW_CNAMES = { 17 | "airplanes": "airplane", 18 | "Faces": "face", # actually doesn't exist 19 | "Faces_easy": "face", 20 | "Leopards": "leopard", 21 | "Motorbikes": "motorbike", 22 | "BACKGROUND_Google": "background", # random images, hard to categorize 23 | } 24 | 25 | 26 | def get_real_path(path): 27 | while os.path.islink(path): 28 | path = os.readlink(path) 29 | return path 30 | 31 | 32 | class NCaltech101(Dataset): 33 | """Dataset class for N-Caltech101 dataset.""" 34 | 35 | def __init__( 36 | self, 37 | root, 38 | augmentation=False, 39 | num_shots=None, 40 | repeat=True, 41 | new_cnames=None, 42 | ): 43 | root = get_real_path(root) 44 | self.root = root 45 | self.classes = sorted(listdir(root)) 46 | # TODO: a hack for identifying generated pseudo labeled datasets 47 | self.is_pseudo = 'pseudo' in root 48 | if self.is_pseudo: 49 | print('Using pseudo labeled dataset!') 50 | 51 | # data stats (computed from the test set) 52 | self.resolution = (180, 240) 53 | # t is very uniform, i.e. 
different samples have similar max_t 54 | # so just take the max (unit: second) 55 | self.max_t = 0.325 56 | # the number of events are VERY unbalanced, so instead of taking 57 | # the max, we take the 95th percentile 58 | self.max_n = 225000 59 | 60 | # data augmentation 61 | self.augmentation = augmentation 62 | self.flip_time = False 63 | self.max_shift = 20 64 | 65 | # few-shot cls 66 | self.num_shots = num_shots # number of labeled data per class 67 | self.few_shot = (num_shots is not None and num_shots > 0) 68 | if self.few_shot: 69 | assert 'train' in root.lower(), 'Only sample data in training set' 70 | self.repeat = repeat 71 | 72 | self.labeled_files, self.labels = self._get_sample_idx() 73 | assert len(self.labeled_files) == len(self.labels) 74 | 75 | # change some class names 76 | self.new_cnames = new_cnames 77 | if new_cnames is None: 78 | return 79 | for i in range(len(self.classes)): 80 | if self.classes[i] in new_cnames: 81 | new_name = new_cnames[self.classes[i]] 82 | print(f'Rename {self.classes[i]} to {new_name}') 83 | self.classes[i] = new_name 84 | 85 | def _get_sample_idx(self): 86 | """Load event file_name and label pairs.""" 87 | # load pre-generated splits if available 88 | if self.few_shot and not self.is_pseudo: 89 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 90 | split_fn = os.path.join( 91 | cur_dir, 'files', self.__class__.__name__, 92 | f'{self.num_shots}shot-repeat={self.repeat}.pkl') 93 | if os.path.exists(split_fn): 94 | print(f'Loading pre-generated split from {split_fn}') 95 | splits = load_obj(split_fn) # Dict[event_fn: label] 96 | labeled_files = np.array(list(splits.keys())) 97 | labels = np.array(list(splits.values())) 98 | return labeled_files, labels 99 | 100 | labeled_files, labels = [], [] 101 | 102 | # fix the random seed since we'll sample data 103 | random.seed(0) 104 | for i, c in enumerate(self.classes): 105 | cls_files = [ 106 | get_real_path(join(self.root, c, f)) 107 | for f in sorted(listdir(join(self.root, c))) 108 | ] 109 | if len(cls_files) == 0: 110 | print(f'Warning: class {c} has no data!') 111 | continue 112 | 113 | # randomly sample `num_shots` labeled data for each class 114 | if self.few_shot: 115 | if self.num_shots <= len(cls_files): 116 | lbl_files = random.sample(cls_files, k=self.num_shots) 117 | else: 118 | if self.repeat: 119 | lbl_files = random.choices(cls_files, k=self.num_shots) 120 | else: 121 | lbl_files = cls_files 122 | elif self.num_shots is None: 123 | lbl_files = cls_files 124 | else: 125 | raise ValueError(f'Invalid num_shots: {self.num_shots}') 126 | labeled_files += lbl_files 127 | labels += [i] * len(lbl_files) 128 | 129 | # save the splits for future use 130 | if self.few_shot and not self.is_pseudo: 131 | splits = {fn: lbl for fn, lbl in zip(labeled_files, labels)} 132 | os.makedirs(os.path.dirname(split_fn), exist_ok=True) 133 | dump_obj(splits, split_fn) 134 | print(f'Saving split file to {split_fn}') 135 | 136 | labeled_files = np.array(labeled_files) 137 | labels = np.array(labels) 138 | return labeled_files, labels 139 | 140 | def __len__(self): 141 | return len(self.labeled_files) 142 | 143 | def _rand_another(self): 144 | """Randomly sample another data.""" 145 | idx = np.random.randint(0, len(self)) 146 | return self.__getitem__(idx) 147 | 148 | @staticmethod 149 | def _load_events(event_path): 150 | """Load events from a file.""" 151 | return np.load(event_path).astype(np.float32) 152 | 153 | def _augment_events(self, events): 154 | """Data augmentation on events.""" 155 | if 
self.flip_time: 156 | events = random_time_flip_events(events) 157 | events = random_shift_events( 158 | events, max_shift=self.max_shift, resolution=self.resolution) 159 | events = random_flip_events_along_x(events, resolution=self.resolution) 160 | # not using time flip on N-Caltech and N-Cars dataset 161 | return events 162 | 163 | def __getitem__(self, idx): 164 | """ 165 | returns events and label, potentially with augmentation 166 | :param idx: data_idx 167 | :return: [N, (x,y,t,p)], label, data_idx 168 | """ 169 | f = str(self.labeled_files[idx]) 170 | label = int(self.labels[idx]) 171 | events = self._load_events(f) 172 | # the spatial resolution of N-Caltech events is 180x240 173 | # we should center the spatial coordinates of events 174 | # some events only reside in e.g. [0, 0] x [100, 160] 175 | # which will be largely removed after center crop! 176 | events = center_events(events, resolution=self.resolution) 177 | 178 | if self.augmentation: 179 | events = self._augment_events(events) 180 | 181 | if events.shape[0] == 0: 182 | return self._rand_another() 183 | 184 | # events: [N, 4 (x, y, t, p)], label: int 185 | # N is usually 1e5 ~ 1e6 186 | return { 187 | 'events': events, 188 | # 't': events[:, 2], 189 | 'label': label, 190 | 'data_idx': idx, 191 | } 192 | 193 | 194 | def build_n_caltech_dataset(params, val_only=False, gen_data=False): 195 | """Build the N-Caltech101 dataset.""" 196 | # only build the test set 197 | if val_only: 198 | assert not gen_data, 'Only generate pseudo labels on the training set' 199 | return NCaltech101( 200 | root=os.path.join(params.data_root, 'testing'), 201 | augmentation=False, 202 | new_cnames=NEW_CNAMES, 203 | ) 204 | # build the training set for pseudo label generation 205 | if gen_data: 206 | return NCaltech101( 207 | root=os.path.join(params.data_root, 'training'), 208 | augmentation=False, 209 | new_cnames=NEW_CNAMES, 210 | ) 211 | 212 | # build the training set 213 | train_set = NCaltech101( 214 | root=os.path.join(params.data_root, 'training'), 215 | augmentation=True, 216 | num_shots=params.get('num_shots', None), 217 | repeat=params.get('repeat_data', True), 218 | new_cnames=NEW_CNAMES, 219 | ) 220 | val_set = NCaltech101( 221 | root=os.path.join(params.data_root, 'testing'), 222 | augmentation=False, 223 | new_cnames=NEW_CNAMES, 224 | ) 225 | return train_set, val_set 226 | -------------------------------------------------------------------------------- /docs/benchmark.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | ## Overview 4 | 5 | We provide instructions on reproducing the results reported in the paper, including: 6 | 7 | - Zero-shot EventCLIP on N-Caltech, N-Cars, and N-ImageNet datasets 8 | - Few-shot EventCLIP with text adapter (most stable), or joint adapter (best performing) on 3 datasets 9 | - Fine-tuning EventCLIP on N-Caltech and N-ImageNet to achieve SOTA performance 10 | - Learning with unlabeled data by self-training on generated pseudo labels on the N-ImageNet (Mini) dataset 11 | 12 | In the following instructions, we will mostly use **EventCLIP with joint adapter on N-Caltech under the 5-shot setting** as example. 13 | Other settings are easily replicable by changing the config file or other flags. 14 | 15 | ### Pre-trained Weights 16 | 17 | Since most of the experiments in the paper can be trained within 1-2 hours, we only provide pre-trained weights for long-running experiments, or those involving multi-step training. 
18 | Please download the pre-trained weights from [Google Drive](https://drive.google.com/file/d/1QW7sn5BYjRdUe6xD_jQUQgSa9oIveq0s/view?usp=sharing) and unzip them under [pretrained/](../pretrained/). 19 | 20 | ## Training EventCLIP Feature Adapter 21 | 22 | **We provide a unified script [train.py](../train.py) to train all models used in this project.** 23 | You should always call it from the **root directory** of this repo (i.e. calling `python train.py xxx`). 24 | 25 | **All of the model training can be done by specifying a config file (called `params` here), and adding other args (e.g. `--num_shots`).** 26 | Please check the config file for the number of GPUs and other resources (e.g. `num_workers` CPUs) before launching a training run. 27 | 28 | Here is one example: 29 | 30 | ``` 31 | python train.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --num_shots 5 --fp16 --cudnn 32 | ``` 33 | 34 | Other arguments include: 35 | 36 | - `--weight`: resume training from this weight 37 | - `--ddp`: use DDP multi-GPU training (needed when using `>=2` GPUs) 38 | - `--fp16`: enable half-precision training (highly recommended) 39 | - `--cudnn`: enable cudnn benchmark (highly recommended) 40 | - `--local_rank`/`--local-rank`: required by DDP, don't change it 41 | 42 | During training, model checkpoints and visualizations will be saved under `./checkpoint/$PARAMS/models/`. 43 | 44 | We provide config files for training EventCLIP under the few-shot classification setting: 45 | 46 | - [Text Adapter configs](../configs/fsclip/text_adapter/) 47 | - [Joint Adapter configs](../configs/fsclip/joint_adapter/) 48 | 49 | ## Testing 50 | 51 | Testing can be done with [test.py](../test.py). 52 | To test the model trained above, simply run: 53 | 54 | ``` 55 | python test.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --weight $WEIGHT 56 | ``` 57 | 58 | Or you can use these configs to test the zero-shot classification performance: 59 | 60 | - [Zero-shot configs](../configs/zsclip/) 61 | 62 | Other arguments in `test.py` include: 63 | 64 | - `--bs`: testing batch size 65 | - `--subset`: used to specify the N-ImageNet robustness variants to test. See the N-ImageNet paper's Appendix for a conversion between subset IDs and the actual data variations 66 | - `--arch`: change CLIP's image encoder backbone in zero-shot testing 67 | - `--prompt`: change the text prompt in zero-shot testing 68 | 69 | We also provide a `--train_shots` argument to automatically gather results over different shots. 70 | If you train the same model with different `--num_shots` values in `train.py`, you can put all numbers of shots here to test them together. 71 | For example, you can run: 72 | 73 | ``` 74 | python test.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --train_shots 20 10 5 3 1 75 | ``` 76 | 77 | **Note that testing is always conducted over the entire test set without few-shot filtering.** 78 | 79 | ## Fine-tuning EventCLIP Full Model 80 | 81 | Fine-tuning EventCLIP is similar to training an adapter model. 82 | But it requires more GPU memory and training time. 83 | We provide config files for fine-tuning EventCLIP under the few-shot and full-data settings. 84 | Please refer to them for detailed training requirements: 85 | 86 | - [Fine-tuning configs](../configs/ftclip/) 87 | 88 | We provide the weight for our fine-tuned EventCLIP (with ViT-B/16 backbone) on the N-ImageNet dataset.
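For reference, fine-tuning reuses the exact same `train.py` interface as adapter training. The sketch below is illustrative only: the config file name is an assumption, so pick the file matching your dataset and backbone from the fine-tuning configs above.

```
# Illustrative sketch: fine-tune EventCLIP on N-ImageNet with a ViT-B/16 backbone.
# The config name is an assumption -- use the matching file under configs/ftclip/.
python train.py --params configs/ftclip/ft_text_fsclip_nin_params-vitb16.py \
    --num_shots 20 --fp16 --cudnn
```

The resulting checkpoints can then be evaluated with `test.py` exactly as in the Testing section above (e.g. gathering results over several shot settings via `--train_shots`).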
89 | 90 | ## Learning with Unlabeled Data 91 | 92 | To generate pseudo labels on unlabeled data, please use [gen_data.py](../gen_data.py). 93 | For example, you can use the zero-shot EventCLIP to generate pseudo labels on the N-ImageNet (Mini) dataset by: 94 | 95 | ``` 96 | python gen_data.py --params configs/zsclip/zsclip_nin_mini_params-vitb32.py --weight '' \ 97 | --conf_thresh 0.999 --tta --tta_min_prob --tta_consistent --topk 30 \ 98 | --save_path data/pseudo-N_Imagenet/vitb32_zs-tta-thresh_0999-top30 \ 99 | --gt_shots -1 100 | ``` 101 | 102 | Here, we use a very high confidence threshold of `0.999` to filter predictions. 103 | This is because the pre-trained CLIP model always makes over-confident predictions (likely due to the learned temperature parameter `\tau`). 104 | Other arguments include: 105 | 106 | - `--tta`, `--tta_min_prob`, and `--tta_consistent` are the techniques introduced in the paper to further improve the label quality 107 | - `--topk 30` means we only select the top-30 most confident predictions for each class 108 | - `--save_path` indicates the path to save the generated dataset 109 | - `--gt_shots` specifies the number of labeled data per class used to train the model loaded via `--weight` 110 | Since we are using the zero-shot model here, we set it to `-1` and `--weight` is empty 111 | 112 | If you want to study the semi-supervised setting, where we have `X` labeled samples per class plus all the remaining unlabeled data, you can first pre-train an EventCLIP with joint adapter using the [provided config file](../configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py). 113 | We provide the 1-, 3-, 5-, 10-, and 20-shot pre-trained weights in this setting. 114 | Then, run `gen_data.py` again, but use the joint adapter's config file as `--params`, `--gt_shots X`, and `--weight` pointing to the pre-trained model's weight. 115 | Also, please use a lower confidence threshold `--conf_thresh 0.5`, as the few-shot EventCLIP is now better calibrated. 116 | An example command is: 117 | 118 | ``` 119 | python gen_data.py --params configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py \ 120 | --weight pretrained/joint_fsclip_nin_mini_params-vitb32-1shot-pretrain.pth \ 121 | --conf_thresh 0.5 --tta --tta_min_prob --tta_consistent --topk 30 \ 122 | --save_path data/pseudo-N_Imagenet/vitb32_1shot-tta-thresh_05-top30 \ 123 | --gt_shots 1 124 | ``` 125 | 126 | Finally, to train on this generated dataset (i.e. self-training), please modify the [config file](../configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py)'s `data_root` field to the `save_path` above and run `train.py`. 127 | **Note that** you should set `--num_shots` to `X + topk`. 128 | This is because we select the `topk` most confident predictions per class as pseudo labels, plus the `X` GT labels per class to train the model (e.g. with `X = 1` GT shot and `--topk 30`, set `--num_shots 31`). 129 | 130 | We provide the weight for our EventCLIP (with ViT-B/32 backbone) trained on zero-shot generated pseudo labels on the N-ImageNet (Mini) dataset. 131 | 132 | ## Useful Scripts 133 | 134 | We provide helper scripts for Slurm cluster job submission, and for training/testing over multiple settings. 135 | 136 | - You can use [sbatch_run.sh](../scripts/sbatch_run.sh) to automatically generate an sbatch file and submit the job to Slurm. 137 | Simply run: 138 | 139 | ``` 140 | GPUS=$NUM_GPU CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=$QOS \ 141 | ./scripts/sbatch_run.sh $PARTITION $JOB_NAME \ 142 | train.py none (if DDP then change `none` to `ddp`) --py_args...
143 | ``` 144 | 145 | Using the same example again, we can set `--py_args...` as follows (see the config file for the number of GPUs/CPUs to use): 146 | 147 | ``` 148 | --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py \ 149 | --num_shots 5 --fp16 --cudnn 150 | ``` 151 | 152 | This is then equivalent to running the above `python train.py xxx` command in the CLI. 153 | 154 | Note that `sbatch_run.sh` internally calls the `resubmit_failed_job.sh` script, which monitors the job status and resubmits the job if it fails. 155 | 156 | - We provide a script to submit multiple runs of the same experiment with different random seeds to Slurm. 157 | To use the script [dup_run_sbatch.sh](../scripts/dup_run_sbatch.sh), simply run: 158 | 159 | ``` 160 | GPUS=$NUM_GPU CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=$QOS REPEAT=$NUM_REPEAT \ 161 | ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME \ 162 | train.py none $PARAMS --py_args... 163 | ``` 164 | 165 | Everything else is the same as `sbatch_run.sh`. 166 | The only difference is that we need to input the config file `$PARAMS` separately, so that the script can make several copies of it and submit different jobs. 167 | 168 | To train the same model again, duplicated `3` times, on the `rtx6000` partition under the job name `joint_fsclip_ncaltech_params`, simply run: 169 | 170 | ``` 171 | GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 \ 172 | ./scripts/dup_run_sbatch.sh rtx6000 joint_fsclip_ncaltech_params \ 173 | train.py none \ 174 | configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py \ 175 | --num_shots 5 --fp16 --cudnn 176 | ``` 177 | 178 | The model weights will be saved under `./checkpoint/joint_fsclip_ncaltech_params-dup$X-5shot/`. 179 | 180 | - We provide a script, [train_all_shots.sh](../scripts/train_all_shots.sh), to train one EventCLIP under all five numbers of shots used in the paper. 181 | For example, if you want to run the above `dup_run_sbatch.sh` command with `20, 10, 5, 3, 1` shots, simply wrap that command with this script by: 182 | 183 | ``` 184 | ./scripts/train_all_shots.sh "GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 joint_fsclip_ncaltech_params train.py none configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --fp16 --cudnn" 20 10 5 3 1 185 | ``` 186 | 187 | The model weights will be saved under `./checkpoint/joint_fsclip_ncaltech_params-dup$X-$Yshot/`. 188 | See the `Testing` section above on how to efficiently test all these models with the `--train_shots` flag. 189 | 190 | - In zero-shot testing, we provide a script, [test_all_arch.sh](../scripts/test_all_arch.sh), to test EventCLIP with different ViT image encoder architectures. 191 | Simply wrap your testing command with this script, and add the arches you want to try. 192 | 193 | - For N-ImageNet testing, we provide a script, [test_all_subset.sh](../scripts/test_all_subset.sh), to test EventCLIP over all robustness variants.
194 | Simply wrap your testing command with this script, and add the subsets you want to try 195 | -------------------------------------------------------------------------------- /models/clip_cls_ft.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | import clip 8 | 9 | from nerv.training import BaseModel 10 | 11 | from .adapter import IdentityAdapter, TransformerAdapter 12 | from .lora import inject_trainable_lora 13 | 14 | 15 | class FTCLIPClassifier(BaseModel): 16 | """Finetune CLIP model for **few-shot** classification.""" 17 | 18 | def __init__( 19 | self, 20 | adapter_dict=dict( 21 | adapter_type='text-identity', 22 | residual=True, 23 | ), 24 | clip_dict=dict( 25 | clip_model=None, 26 | prompt='a point cloud image of a {}', 27 | class_names=None, 28 | agg_func='sum', 29 | ), 30 | loss_dict=dict( 31 | use_logits_loss=True, 32 | use_probs_loss=False, 33 | ), 34 | ): 35 | super().__init__() 36 | 37 | self.clip_dict = clip_dict 38 | self.loss_dict = loss_dict 39 | self.adapter_dict = copy.deepcopy(adapter_dict) 40 | 41 | self._build_clip() 42 | self._build_loss() 43 | self._build_adapter() 44 | 45 | def _build_clip(self): 46 | # freeze the CLIP model 47 | model = self.clip_dict['clip_model'] 48 | for p in model.parameters(): 49 | p.requires_grad = False 50 | # LoRA fine-tuning 51 | lora = self.clip_dict.get('lora', -1) 52 | if isinstance(lora, str) or lora > 0: 53 | model.visual = inject_trainable_lora(model.visual, r=lora) 54 | # finetune CLIP.visual or its sub-layers 55 | conv1 = self.clip_dict['only_conv1'] 56 | bias = self.clip_dict['only_bias'] 57 | ln = self.clip_dict['only_ln'] 58 | cls_fc = self.clip_dict.get('only_cls_fc', False) 59 | cls_token = self.clip_dict.get('only_cls_token', False) 60 | if conv1: # only tune the first conv layer 61 | for p in model.visual.conv1.parameters(): 62 | p.requires_grad = True 63 | if bias: # only tune the bias terms 64 | for name, p in model.visual.named_parameters(): 65 | if 'bias' in name and p is not None: 66 | p.requires_grad = True 67 | if ln: # only tune the LayerNorm layers 68 | for m in model.visual.modules(): 69 | if isinstance(m, nn.LayerNorm): 70 | for p in m.parameters(): 71 | p.requires_grad = True 72 | if cls_fc: # only tune the final projection head 73 | model.visual.proj.requires_grad = True 74 | if cls_token: # only tune the CLS token 75 | model.visual.class_embedding.requires_grad = True 76 | # tune all 77 | if (isinstance(lora, int) and lora <= 0) and \ 78 | not (conv1 or bias or ln or cls_fc or cls_token): 79 | for p in model.visual.parameters(): 80 | p.requires_grad = True 81 | # set as eval 82 | self.model = model.eval() 83 | self.logit_scale = model.logit_scale.exp().item() 84 | 85 | # text prompt for zero-shot cls 86 | self.prompt = self.clip_dict['prompt'] 87 | self.class_names = self.clip_dict['class_names'] 88 | self.text_feats = None 89 | 90 | # aggregation function 91 | self.agg_func = self.clip_dict['agg_func'] 92 | assert self.agg_func in ['sum', 'mean', 'max'] 93 | 94 | def _build_loss(self): 95 | self.use_logits_loss = self.loss_dict['use_logits_loss'] 96 | self.use_probs_loss = self.loss_dict['use_probs_loss'] 97 | assert int(self.use_logits_loss) + int(self.use_probs_loss) == 1 98 | 99 | def _build_prompts(self, adapter_type): 100 | """Build the text features for prompt tuning.""" 101 | with torch.no_grad(): 102 | text_feats = self._get_text_feats().float() # 
[n_classes, C] 103 | self.text_feats = nn.Parameter(text_feats, requires_grad=True) 104 | adapter_type = adapter_type[5:] 105 | return adapter_type 106 | 107 | def _build_adapter(self): 108 | # whether to tune the text features as well 109 | adapter_type = self.adapter_dict.pop('adapter_type').lower() 110 | if adapter_type.startswith('text-'): 111 | print('Tune text features as well!') 112 | self.prompt_tuning = True 113 | adapter_type = self._build_prompts(adapter_type) 114 | else: 115 | self.prompt_tuning = False 116 | 117 | # image feature adapter 118 | self.adapter_type = adapter_type 119 | assert adapter_type == 'identity' 120 | if adapter_type == 'identity': # not tuning image features 121 | model = IdentityAdapter 122 | elif adapter_type == 'trans': # Transformer to fuse image features 123 | model = TransformerAdapter 124 | else: 125 | raise NotImplementedError(f'adapter {adapter_type} not supported!') 126 | self.adapter = model(**self.adapter_dict) 127 | 128 | def _same_class_names(self, class_names): 129 | """Check if the input `class_names` matches `self.class_names`.""" 130 | return all([c1 == c2 for c1, c2 in zip(class_names, self.class_names)]) 131 | 132 | def _get_text_feats(self, class_names=None): 133 | """Compute the text prompt features using CLIP text encoder.""" 134 | # no `class_names` provided 135 | if class_names is None: 136 | no_cls_flag = True 137 | class_names = self.class_names 138 | # with cached `text_feats` 139 | if self.text_feats is not None: 140 | return self.text_feats 141 | # `class_names` matches 142 | elif self._same_class_names(class_names): 143 | # with cached `text_feats` 144 | if self.text_feats is not None: 145 | return self.text_feats 146 | 147 | # compute the text prompt features 148 | class_names = [c.lower().replace('_', ' ') for c in class_names] 149 | prompts = torch.cat([ 150 | clip.tokenize(self.prompt.format(c)) for c in class_names 151 | ]).to(self.device) 152 | text_feats = self.model.encode_text(prompts) 153 | text_feats = F.normalize(text_feats, p=2, dim=-1) 154 | 155 | # cache the `text_feats` if 156 | # 1) the `class_names` matches 157 | # 2) the `class_names` is not provided 158 | if no_cls_flag or self._same_class_names(class_names): 159 | self.text_feats = text_feats 160 | 161 | return text_feats # [n_classes, C] 162 | 163 | def get_text_feats(self, class_names=None): 164 | # finetune the text features (i.e. prompt tuning) 165 | if self.prompt_tuning: 166 | assert self.text_feats.requires_grad, 'prompt should be trainable!' 167 | text_feats = F.normalize(self.text_feats, p=2, dim=-1) 168 | # otherwise, we use fixed text features 169 | else: 170 | with torch.no_grad(): 171 | text_feats = self._get_text_feats(class_names) 172 | return self._adjust_dtype(text_feats) 173 | 174 | def _get_img_feats(self, imgs): 175 | """Compute the image features using CLIP image encoder. 176 | 177 | Args: 178 | imgs (torch.Tensor): [B, C, H, W] 179 | """ 180 | img_feats = self.model.encode_image(imgs) 181 | return img_feats # [B, C] 182 | 183 | def get_img_feats(self, imgs): 184 | img_feats = self._get_img_feats(imgs) 185 | return self._adjust_dtype(img_feats) 186 | 187 | def _aggregate_logits(self, logits, valid_masks): 188 | """Aggregate logits for each data. 
189 | 190 | Args: 191 | logits (torch.Tensor): [B, T, n_classes] 192 | valid_masks (torch.Tensor): [B, T] 193 | """ 194 | if self.agg_func == 'sum': 195 | logits = logits.sum(1) 196 | elif self.agg_func == 'mean': 197 | logits = logits.sum(1) / valid_masks.float().sum(1, keepdim=True) 198 | elif self.agg_func == 'max': 199 | # make invalid logits very small 200 | logits = logits - (1. - valid_masks.float()) * 1e6 201 | logits = logits.max(1)[0] 202 | else: 203 | raise NotImplementedError 204 | return logits 205 | 206 | def _aggregate_probs(self, logits, valid_masks): 207 | """This one always take the mean.""" 208 | valid_masks = valid_masks.detach().float() 209 | probs = logits.softmax(dim=-1) 210 | probs = probs * valid_masks[..., None] 211 | probs = probs.sum(1) / valid_masks.sum(1, keepdim=True) 212 | return probs 213 | 214 | def forward(self, data_dict): 215 | """Forward function.""" 216 | imgs = data_dict['img'] # [B, T, C, H, W], `T` is number of views 217 | valid_masks = data_dict['valid_mask'] # [B, T] 218 | B, T = valid_masks.shape 219 | 220 | # compute image features 221 | valid_imgs = imgs[valid_masks] # [N, C, H, W] 222 | img_feats = self.get_img_feats(valid_imgs) # [N, C] 223 | 224 | # update image features using adapter 225 | C = img_feats.shape[-1] 226 | full_img_feats = torch.zeros(B, T, C).type_as(img_feats) 227 | full_img_feats[valid_masks] = img_feats 228 | # full_img_feats = self.adapter(full_img_feats, valid_masks) 229 | # [B, T, C], multi-view image features 230 | # normalize the output features 231 | # all zeros vector will still be zeros after F.normalize() 232 | full_img_feats = F.normalize( 233 | full_img_feats, p=2, dim=-1).type_as(full_img_feats) 234 | # make invalid features zeros 235 | full_img_feats = full_img_feats * valid_masks.float().unsqueeze(-1) 236 | 237 | # compute text features 238 | # we may need to compute gradients w.r.t. text features 239 | # so we can't use torch.no_grad() here 240 | text_feats = self.get_text_feats() # [n_classes, C] 241 | 242 | # compute logits 243 | full_logits = (self.logit_scale * full_img_feats @ text_feats.T) 244 | # [B, T, n_cls], multi-view logits 245 | 246 | # convert to [B, n_cls] for loss computation! 
247 | logits = self._aggregate_logits(full_logits, valid_masks) 248 | probs = self._aggregate_probs(full_logits, valid_masks) 249 | 250 | out_dict = { 251 | 'full_logits': full_logits, # [B, T, n_classes] 252 | 'valid_masks': valid_masks, # [B, T] 253 | 'logits': logits, # [B, n_classes] 254 | 'probs': probs, # [B, n_classes] 255 | } 256 | return out_dict 257 | 258 | def calc_train_loss(self, data_dict, out_dict): 259 | """Compute training loss.""" 260 | labels = data_dict['label'] # [B] 261 | logits = out_dict['logits'] # [B, n_classes] 262 | probs = out_dict['probs'] # [B, n_classes] 263 | loss_dict = {} 264 | if self.use_logits_loss: 265 | loss_dict['ce_loss'] = F.cross_entropy(logits, labels) 266 | if self.use_probs_loss: 267 | probs = probs + 1e-6 # avoid nan 268 | loss_dict['ce_loss'] = F.nll_loss(probs.log(), labels) 269 | return loss_dict 270 | 271 | @torch.no_grad() 272 | def calc_eval_loss(self, data_dict, out_dict): 273 | """Loss computation in eval.""" 274 | loss_dict = self.calc_train_loss(data_dict, out_dict) 275 | 276 | # also compute the cls accuracy 277 | labels = data_dict['label'] # [B] 278 | # based on aggregated probs 279 | probs = out_dict['probs'] # [B, n_classes] 280 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean() 281 | loss_dict['probs_acc'] = probs_acc 282 | # based on aggregated logits 283 | logits = out_dict['logits'] # [B, n_classes] 284 | logits_acc = (logits.argmax(dim=-1) == labels).float().mean() 285 | loss_dict['logits_acc'] = logits_acc 286 | return loss_dict 287 | 288 | def _adjust_dtype(self, x): 289 | """CLIP model returns features in FP16. 290 | During training, torch.amp will help us handle this. 291 | However, during inference, we need to manually convert them to FP32. 292 | """ 293 | if self.training: 294 | return x 295 | return x.type(self.dtype) 296 | 297 | @property 298 | def dtype(self): 299 | return self.adapter.dtype 300 | 301 | @property 302 | def device(self): 303 | return self.model.logit_scale.device 304 | 305 | def train(self, mode=True): 306 | nn.Module.train(self, mode) 307 | # keep CLIP in eval mode 308 | self.model.eval() 309 | # but adjust CLIP.visual 310 | self.model.visual.train(mode) 311 | return self 312 | 313 | def state_dict(self): 314 | """Remove CLIP weight (keep `model.visual`) from the state dict.""" 315 | w = super().state_dict() 316 | w = { 317 | k: v 318 | for k, v in w.items() 319 | if ((not k.startswith('model.')) or k.startswith('model.visual.')) 320 | } 321 | return w 322 | 323 | def load_state_dict(self, state_dict, strict=True): 324 | """Don't load CLIP weight (load `model.visual`) from the state dict.""" 325 | # load CLIP weight from the state dict 326 | clip_w = { 327 | f'model.{k}': v 328 | for k, v in self.model.state_dict().items() 329 | if not k.startswith('visual.') 330 | } 331 | assert all(k not in state_dict for k in clip_w) 332 | state_dict = {**clip_w, **state_dict} 333 | super().load_state_dict(state_dict, strict=strict) 334 | -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | """EventCLIP testing script""" 2 | 3 | import os 4 | import os.path as osp 5 | import sys 6 | import importlib 7 | import argparse 8 | 9 | from tqdm import tqdm 10 | 11 | import torch 12 | 13 | import clip 14 | 15 | from nerv.training import BaseDataModule 16 | from nerv.utils import AverageMeter, load_obj 17 | 18 | from models import build_model 19 | from datasets import build_dataset 20 
| 21 | 22 | def get_real_path(path): 23 | while osp.islink(path): 24 | path = os.readlink(path) 25 | return path 26 | 27 | 28 | def get_folder_and_fn(path): 29 | return osp.join(osp.basename(osp.dirname(path)), osp.basename(path)) 30 | 31 | 32 | def find_key_from_value(d, v): 33 | for k, v_ in d.items(): 34 | if v_ == v: 35 | return k 36 | return None 37 | 38 | 39 | def print_stats(class_names, gt_class_cnt, sel_class_cnt, 40 | sel_correct_class_cnt): 41 | print('\nClass stats:') 42 | for k in class_names: 43 | gt_num, sel_num, correct_num = \ 44 | gt_class_cnt[k], sel_class_cnt[k], sel_correct_class_cnt[k] 45 | print(f'\t{k}: GT {gt_num}, select {sel_num}, {correct_num} correct') 46 | print('Not accurate classes') 47 | less_accurate_cnt = 0 48 | for k in class_names: 49 | gt_num, sel_num, correct_num = \ 50 | gt_class_cnt[k], sel_class_cnt[k], sel_correct_class_cnt[k] 51 | ratio = correct_num / sel_num if sel_num > 0 else 0. 52 | if ratio < 0.5: 53 | print(f'\t{k}: GT {gt_num}, select {correct_num}/{sel_num} -- {ratio:.2f}') 54 | less_accurate_cnt += 1 55 | print(f'Not accurate classes: {less_accurate_cnt}/{len(class_names)}') 56 | 57 | total_num = sum(gt_class_cnt.values()) 58 | select_num = sum(sel_class_cnt.values()) 59 | select_correct_num = sum(sel_correct_class_cnt.values()) 60 | sel_acc = select_correct_num / select_num * 100. if select_num > 0 else 0. 61 | print(f'\nUsing {args.conf_thresh=}') 62 | if args.topk > 0: 63 | print(f'Using {args.topk=}') 64 | print(f'\tSelect {select_num} from {total_num}, Acc={sel_acc:.2f}%') 65 | if args.tta: 66 | print(f'Using TTA with {args.tta_consistent=} + {args.tta_min_prob=}') 67 | 68 | 69 | @torch.no_grad() 70 | def main(params): 71 | # have to load CLIP model first 72 | arch = params.clip_dict['arch'] 73 | device = 'cuda' 74 | model, preprocess = clip.load(arch, device=device) 75 | 76 | # build training dataset for generating pseudo labels 77 | params.data_transforms = preprocess 78 | tta = args.tta 79 | is_nin = ('n_imagenet' in params.dataset) 80 | if not is_nin: 81 | assert params.dataset == 'n_caltech', f'{params.dataset} not supported' 82 | print(f'Generate pseudo labels for {params.dataset}') 83 | test_set = build_dataset(params, val_only=False, gen_data=True, tta=tta) 84 | ev_dst = test_set.event_dataset 85 | class_names, labels = test_set.classes, ev_dst.labels 86 | 87 | datamodule = BaseDataModule( 88 | params, train_set=None, val_set=test_set, use_ddp=False) 89 | test_loader = datamodule.val_loader 90 | 91 | # build model 92 | params.clip_dict['clip_model'] = model 93 | params.clip_dict['class_names'] = test_set.classes 94 | if not is_zs: 95 | params.adapter_dict['in_dim'] = model.visual.output_dim 96 | model = build_model(params) 97 | 98 | # load weight 99 | gt_data = {} # we might have some labeled data 100 | if args.weight: 101 | assert not is_zs, 'Zero-shot models should not have pre-trained weight' 102 | model.load_weight(args.weight) 103 | print(f'Loading weight: {args.weight}') 104 | # load labeled data, we won't generate pseudo labels for them 105 | assert args.gt_shots > 0, \ 106 | 'Should specify the num_shots used to pre-train the model' 107 | assert f'{args.gt_shots}shot' in args.weight or \ 108 | f'{args.gt_shots}-shot' in args.weight, \ 109 | f'Weight {args.weight} does not match `{args.gt_shots}-shot`' 110 | assert f'{args.gt_shots}shot' in save_path or \ 111 | f'{args.gt_shots}-shot' in save_path, \ 112 | f'Should put `{args.gt_shots}shot` in `save_path`' 113 | split_fn = osp.join('./datasets/files', 
ev_dst.__class__.__name__, 114 | f'{args.gt_shots}shot-repeat=True.pkl') 115 | gt_split = load_obj(split_fn) # Dict[event_fn (str): label (int)] 116 | # convert to Dict[event_fn (str): class_name (str)] 117 | gt_data = {k: class_names[v] for k, v in gt_split.items()} 118 | gt_data_paths = [get_folder_and_fn(k) for k in gt_data.keys()] 119 | model = model.cuda().eval() 120 | 121 | # test 122 | all_acc_meter = AverageMeter() 123 | gt_class_cnt = {k: (labels == i).sum() for i, k in enumerate(class_names)} 124 | sel_class_cnt = {k: 0 for k in class_names} 125 | sel_correct_class_cnt = {k: 0 for k in class_names} 126 | pred_path2cls = {} 127 | 128 | conf_thresh, topk = args.conf_thresh, args.topk 129 | for data_dict in tqdm(test_loader): 130 | data_idx = data_dict.pop('data_idx').numpy() # [B] 131 | data_dict = {k: v.cuda() for k, v in data_dict.items()} 132 | if tta: # loaded data in shape [B, 4, N, ...] 133 | data_dict['img'] = data_dict['img'].flatten(0, 1) 134 | data_dict['valid_mask'] = data_dict['valid_mask'].flatten(0, 1) 135 | out_dict = model(data_dict) 136 | labels = data_dict['label'] # [B] 137 | 138 | # based on aggregated probs 139 | pred_probs = out_dict['probs'] 140 | # aggregate probs from multi-view TTA predictions 141 | if tta: 142 | probs = pred_probs.unflatten(0, (-1, 4)) # [B, 4, n_cls] 143 | tta_mask = torch.ones_like(labels).bool() # [B] 144 | # predictions over 4 views should be consistent 145 | if args.tta_consistent: 146 | pred_cls = probs.argmax(dim=-1) # [B, 4] 147 | tta_mask &= (pred_cls[:, 0] == pred_cls[:, 1]) & \ 148 | (pred_cls[:, 0] == pred_cls[:, 2]) & \ 149 | (pred_cls[:, 0] == pred_cls[:, 3]) 150 | # the minimum confidence should be larger than conf_thresh 151 | if args.tta_min_prob: 152 | min_probs = probs.max(-1).values.min(-1).values 153 | tta_mask &= (min_probs > conf_thresh) 154 | probs = probs.mean(dim=1) # [B, n_cls] 155 | else: 156 | probs = pred_probs 157 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean().item() 158 | all_acc_meter.update(probs_acc, labels.shape[0]) 159 | 160 | # only trust probs > conf_thresh 161 | max_probs, pred_labels = probs.max(dim=-1) 162 | sel_mask = (max_probs > conf_thresh) 163 | if tta: 164 | sel_mask &= tta_mask 165 | # update class cnt 166 | for i, (lbl, pred_lbl) in enumerate(zip(labels, pred_labels)): 167 | ev_path = str(ev_dst.labeled_files[data_idx[i]]) 168 | # skip labeled data as they will be included later anyway 169 | if get_folder_and_fn(ev_path) in gt_data_paths: 170 | continue 171 | pred_cls_name = class_names[pred_lbl.item()] 172 | if sel_mask[i].item(): 173 | sel_class_cnt[pred_cls_name] += 1 174 | if pred_lbl.item() == lbl.item(): 175 | sel_correct_class_cnt[pred_cls_name] += 1 176 | if sel_mask[i].item(): 177 | if topk > 0: # also record the probs, take top-k later 178 | pred_path2cls[ev_path] = { 179 | 'cls': pred_cls_name, 180 | 'prob': max_probs[i].item(), 181 | } 182 | else: 183 | pred_path2cls[ev_path] = pred_cls_name 184 | 185 | print_stats(class_names, gt_class_cnt, sel_class_cnt, 186 | sel_correct_class_cnt) 187 | print(f'\n\nTesting {args.params}') 188 | if args.weight: 189 | print(f'Model weight: {args.weight}') 190 | print(f'\tProbs-based accuracy@1: {all_acc_meter.avg * 100.:.2f}%') 191 | 192 | if not save_path: 193 | return 194 | # save pseudo labels to a new dataset 195 | train_path = osp.join(save_path, 'extracted_train') if \ 196 | is_nin else osp.join(save_path, 'training') 197 | assert not osp.exists(save_path), f'{save_path} already exists!' 
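    # The block below materializes the pseudo-labeled dataset on disk:
    # every selected event file is symlinked under `train_path/<class_folder>/`,
    # optionally keeping only the top-k most confident predictions per class
    # (plus any ground-truth labeled data); the original val/test splits are
    # then symlinked as well, so `save_path` can be used directly as the
    # `data_root` for self-training.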
198 | os.makedirs(train_path, exist_ok=True) 199 | # some classes might be renamed 200 | new_cnames = ev_dst.new_cnames 201 | # only take `topk` predictions for each class 202 | if topk > 0: 203 | topk_pred_path2cls, sel_class_cnt, sel_correct_class_cnt = {}, {}, {} 204 | for cls_name in class_names: 205 | sel_correct_class_cnt[cls_name] = 0 206 | cls_pred_paths, cls_pred_probs = [], [] 207 | # find data that are classified as this class 208 | for path, pred in pred_path2cls.items(): 209 | if pred['cls'] == cls_name: 210 | cls_pred_paths.append(path) 211 | cls_pred_probs.append(pred['prob']) 212 | cls_pred_probs = torch.tensor(cls_pred_probs) 213 | k = min(topk, cls_pred_probs.shape[0]) 214 | _, topk_idx = cls_pred_probs.topk(k) 215 | for i in topk_idx: 216 | path = cls_pred_paths[i] # get the GT label from path 217 | gt_cls_name = osp.basename(osp.dirname(path)) 218 | if is_nin: 219 | gt_cls_name = ev_dst.folder2name[gt_cls_name] 220 | if new_cnames is not None: 221 | gt_cls_name = new_cnames.get(gt_cls_name, gt_cls_name) 222 | if gt_cls_name == cls_name: 223 | sel_correct_class_cnt[cls_name] += 1 224 | topk_pred_path2cls[path] = cls_name 225 | sel_class_cnt[cls_name] = k 226 | pred_path2cls = topk_pred_path2cls 227 | print_stats(class_names, gt_class_cnt, sel_class_cnt, 228 | sel_correct_class_cnt) 229 | # update data with GT labels 230 | pred_path2cls.update(gt_data) 231 | # save pseudo labels 232 | for path, pred_cls in pred_path2cls.items(): 233 | path = get_real_path(path) 234 | # path: xxx/N-Caltech101/training/airplanes/airplanes_150.npy 235 | # xxx/N_Imagenet/extracted_train/n02114855/n02114855_15515.npz 236 | # pred_cls is a semantic class name 237 | # some class names might have been altered 238 | if new_cnames is not None: 239 | ori_cls = find_key_from_value(new_cnames, pred_cls) 240 | if ori_cls is not None: 241 | pred_cls = ori_cls 242 | if is_nin: 243 | folder_name = ev_dst.name2folder[pred_cls] 244 | else: 245 | folder_name = pred_cls 246 | ev_name = osp.basename(path) 247 | # save to train_path/folder_name/ev_name 248 | # use soft link to save disk space 249 | new_path = osp.join(train_path, folder_name, ev_name) 250 | os.makedirs(osp.dirname(new_path), exist_ok=True) 251 | os.symlink(path, new_path) 252 | # also soft-link val/test set if they exist 253 | if is_nin: 254 | val_path = osp.join(save_path, 'extracted_val') 255 | ori_val_path = osp.join(osp.dirname(ev_dst.root), 'extracted_val') 256 | ori_val_path = get_real_path(ori_val_path) 257 | os.symlink(ori_val_path, val_path) 258 | else: 259 | val_path = osp.join(save_path, 'validation') 260 | test_path = osp.join(save_path, 'testing') 261 | ori_val_path = osp.join(osp.dirname(ev_dst.root), 'validation') 262 | ori_test_path = osp.join(osp.dirname(ev_dst.root), 'testing') 263 | ori_val_path = get_real_path(ori_val_path) 264 | ori_test_path = get_real_path(ori_test_path) 265 | os.symlink(ori_val_path, val_path) 266 | os.symlink(ori_test_path, test_path) 267 | print(f'\nSaved pseudo labels to {save_path}') 268 | # some classes don't have any pseudo labels 269 | # we still create the folder for consistency 270 | for k in class_names: 271 | if new_cnames is not None: 272 | ori_cls = find_key_from_value(new_cnames, k) 273 | if ori_cls is not None: 274 | k = ori_cls 275 | if is_nin: 276 | folder_name = ev_dst.name2folder[k] 277 | else: 278 | folder_name = k 279 | folder_path = osp.join(train_path, folder_name) 280 | os.makedirs(folder_path, exist_ok=True) 281 | 282 | 283 | if __name__ == "__main__": 284 | parser = 
argparse.ArgumentParser(description='EventCLIP') 285 | parser.add_argument('--params', type=str, required=True) 286 | parser.add_argument('--save_path', type=str, default='') 287 | parser.add_argument('--weight', type=str, default='', help='load weight') 288 | parser.add_argument('--conf_thresh', type=float, default=-1.) 289 | parser.add_argument('--tta', action='store_true') 290 | parser.add_argument('--tta_consistent', action='store_true') 291 | parser.add_argument('--tta_min_prob', action='store_true') 292 | parser.add_argument('--topk', type=int, default=-1) 293 | parser.add_argument('--gt_shots', type=int, default=-1) # labeled data 294 | args = parser.parse_args() 295 | 296 | if args.params.endswith('.py'): 297 | args.params = args.params[:-3] 298 | sys.path.append(osp.dirname(args.params)) 299 | params = importlib.import_module(osp.basename(args.params)) 300 | params = params.EventCLIPParams() 301 | 302 | # adjust params 303 | is_zs = (params.model == 'ZSCLIP') 304 | save_path = args.save_path 305 | if save_path: 306 | assert not osp.exists(save_path), f'{save_path} already exists!' 307 | 308 | main(params) 309 | exit(-1) 310 | -------------------------------------------------------------------------------- /models/clip_cls.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | import clip 8 | 9 | from nerv.training import BaseModel 10 | 11 | from .adapter import IdentityAdapter, TransformerAdapter 12 | 13 | 14 | class ZSCLIPClassifier(BaseModel): 15 | """CLIP model for **zero-shot** classification.""" 16 | 17 | def __init__( 18 | self, 19 | clip_dict=dict( 20 | clip_model=None, 21 | prompt='a point cloud image of a {}', 22 | class_names=None, 23 | agg_func='sum', 24 | ), 25 | loss_dict=dict( 26 | use_logits_loss=True, 27 | use_probs_loss=False, 28 | ), 29 | ): 30 | super().__init__() 31 | 32 | self.clip_dict = clip_dict 33 | self.loss_dict = loss_dict 34 | 35 | self._build_clip() 36 | self._build_loss() 37 | 38 | def _build_clip(self): 39 | # freeze the CLIP model 40 | model = self.clip_dict['clip_model'] 41 | for p in model.parameters(): 42 | p.requires_grad = False 43 | self.model = model.eval() 44 | self.logit_scale = model.logit_scale.exp().item() 45 | 46 | # text prompt for zero-shot cls 47 | self.prompt = self.clip_dict['prompt'] 48 | self.class_names = self.clip_dict['class_names'] 49 | self.text_feats = None 50 | 51 | # aggregation function 52 | self.agg_func = self.clip_dict['agg_func'] 53 | assert self.agg_func in ['sum', 'mean', 'max'] 54 | 55 | def _build_loss(self): 56 | self.use_logits_loss = self.loss_dict['use_logits_loss'] 57 | self.use_probs_loss = self.loss_dict['use_probs_loss'] 58 | assert int(self.use_logits_loss) + int(self.use_probs_loss) == 1 59 | 60 | def _same_class_names(self, class_names): 61 | """Check if the input `class_names` matches `self.class_names`.""" 62 | return all([c1 == c2 for c1, c2 in zip(class_names, self.class_names)]) 63 | 64 | def get_text_feats(self, class_names=None): 65 | """Compute the text prompt features using CLIP text encoder.""" 66 | # no `class_names` provided 67 | if class_names is None: 68 | no_cls_flag = True 69 | class_names = self.class_names 70 | # with cached `text_feats` 71 | if self.text_feats is not None: 72 | return self.text_feats 73 | # `class_names` matches 74 | elif self._same_class_names(class_names): 75 | # with cached `text_feats` 76 | if self.text_feats is not None: 
77 | return self.text_feats 78 | 79 | # compute the text prompt features 80 | class_names = [c.lower().replace('_', ' ') for c in class_names] 81 | prompts = torch.cat([ 82 | clip.tokenize(self.prompt.format(c)) for c in class_names 83 | ]).to(self.device) 84 | text_feats = self.model.encode_text(prompts) 85 | text_feats = F.normalize(text_feats, p=2, dim=-1) 86 | 87 | # cache the `text_feats` if 88 | # 1) the `class_names` matches 89 | # 2) the `class_names` is not provided 90 | if no_cls_flag or self._same_class_names(class_names): 91 | self.text_feats = text_feats 92 | 93 | return text_feats # [n_classes, C] 94 | 95 | def get_img_feats(self, imgs): 96 | """Compute the image features using CLIP image encoder. 97 | 98 | Args: 99 | imgs (torch.Tensor): [B, C, H, W] 100 | """ 101 | img_feats = self.model.encode_image(imgs) 102 | return img_feats # [B, C] 103 | 104 | def _aggregate_logits(self, logits, valid_masks): 105 | """Aggregate logits for each data. 106 | 107 | Args: 108 | logits (torch.Tensor): [B, T, n_classes] 109 | valid_masks (torch.Tensor): [B, T] 110 | """ 111 | if self.agg_func == 'sum': 112 | logits = logits.sum(1) 113 | elif self.agg_func == 'mean': 114 | logits = logits.sum(1) / valid_masks.float().sum(1, keepdim=True) 115 | elif self.agg_func == 'max': 116 | # make invalid logits very small 117 | logits = logits - (1. - valid_masks.float()) * 1e6 118 | logits = logits.max(1)[0] 119 | else: 120 | raise NotImplementedError 121 | return logits 122 | 123 | def _aggregate_probs(self, logits, valid_masks): 124 | """This one always take the mean.""" 125 | valid_masks = valid_masks.detach().float() 126 | probs = logits.softmax(dim=-1) 127 | probs = probs * valid_masks[..., None] 128 | probs = probs.sum(1) / valid_masks.sum(1, keepdim=True) 129 | return probs 130 | 131 | def forward(self, data_dict): 132 | """Forward function.""" 133 | imgs = data_dict['img'] # [B, T, C, H, W] 134 | valid_masks = data_dict['valid_mask'] # [B, T] 135 | B, T = valid_masks.shape 136 | 137 | # compute image features 138 | with torch.no_grad(): 139 | valid_imgs = imgs[valid_masks] # [N, C, H, W] 140 | img_feats = self.get_img_feats(valid_imgs) # [N, C] 141 | 142 | # compute text features 143 | with torch.no_grad(): 144 | text_feats = self.get_text_feats() # [n_classes, C] 145 | n_cls = text_feats.shape[0] 146 | 147 | # compute logits 148 | logits = (self.logit_scale * img_feats @ text_feats.T) # [N, n_cls] 149 | 150 | # map logits back to [B, T, n_cls] 151 | full_logits = torch.zeros(B, T, n_cls).type_as(logits) 152 | full_logits[valid_masks] = logits 153 | logits = self._aggregate_logits(full_logits, valid_masks) 154 | probs = self._aggregate_probs(full_logits, valid_masks) 155 | 156 | out_dict = { 157 | 'full_logits': full_logits, # [B, T, n_classes] 158 | 'valid_masks': valid_masks, # [B, T] 159 | 'logits': logits, # [B, n_classes] 160 | 'probs': probs, # [B, n_classes] 161 | } 162 | return out_dict 163 | 164 | def calc_train_loss(self, data_dict, out_dict): 165 | """Compute training loss.""" 166 | labels = data_dict['label'] # [B] 167 | logits = out_dict['logits'] # [B, n_classes] 168 | probs = out_dict['probs'] # [B, n_classes] 169 | loss_dict = {} 170 | if self.use_logits_loss: 171 | loss_dict['ce_loss'] = F.cross_entropy(logits, labels) 172 | if self.use_probs_loss: 173 | probs = probs + 1e-6 # avoid nan 174 | loss_dict['ce_loss'] = F.nll_loss(probs.log(), labels) 175 | return loss_dict 176 | 177 | @torch.no_grad() 178 | def calc_eval_loss(self, data_dict, out_dict): 179 | """Loss computation in 


class FSCLIPClassifier(ZSCLIPClassifier):
    """CLIP model for **few-shot** classification."""

    def __init__(
        self,
        adapter_dict=dict(
            # 'trans', 'identity', or 'text-' + one of them, in which case we
            # also tune the text features as FC weights (prompt tuning)
            adapter_type='trans',
            residual=True,
        ),
        clip_dict=dict(
            clip_model=None,
            prompt='a point cloud image of a {}',
            class_names=None,
            agg_func='sum',
        ),
        loss_dict=dict(
            use_logits_loss=False,
            use_probs_loss=True,
        ),
    ):
        super().__init__(
            clip_dict=clip_dict,
            loss_dict=loss_dict,
        )

        self.adapter_dict = copy.deepcopy(adapter_dict)

        self._build_adapter()

    def _build_prompts(self, adapter_type):
        """Build the text features for prompt tuning."""
        with torch.no_grad():
            text_feats = super().get_text_feats().float()  # [n_classes, C]
        self.text_feats = nn.Parameter(text_feats, requires_grad=True)
        adapter_type = adapter_type[5:]  # strip the 'text-' prefix
        return adapter_type

    def _build_adapter(self):
        # whether to tune the text features as well
        adapter_type = self.adapter_dict.pop('adapter_type').lower()
        if adapter_type.startswith('text-'):
            print('Tune text features as well!')
            self.prompt_tuning = True
            adapter_type = self._build_prompts(adapter_type)
        else:
            self.prompt_tuning = False

        # image feature adapter
        self.adapter_type = adapter_type
        if adapter_type == 'identity':  # not tuning image features
            model = IdentityAdapter
        elif adapter_type == 'trans':  # Transformer to fuse image features
            model = TransformerAdapter
        else:
            raise NotImplementedError(f'adapter {adapter_type} not supported!')
        self.adapter = model(**self.adapter_dict)

    def _adjust_dtype(self, x):
        """CLIP model returns features in FP16.

        During training, torch.amp will help us handle this.
        However, during inference, we need to manually convert them to FP32.
        """
        if self.training:
            return x
        return x.type(self.dtype)

    def get_text_feats(self, class_names=None):
        # finetune the text features (i.e. prompt tuning)
        if self.prompt_tuning:
            assert self.text_feats.requires_grad or not self.training, \
                'prompt should be trainable!'
            text_feats = F.normalize(self.text_feats, p=2, dim=-1)
        # otherwise, we use fixed text features
        else:
            with torch.no_grad():
                text_feats = super().get_text_feats(class_names)
        return self._adjust_dtype(text_feats)

    def get_img_feats(self, imgs):
        # fixed CLIP image features
        with torch.no_grad():
            img_feats = super().get_img_feats(imgs)
        return self._adjust_dtype(img_feats)

    def forward(self, data_dict):
        """Forward function."""
        imgs = data_dict['img']  # [B, T, C, H, W], `T` is number of views
        valid_masks = data_dict['valid_mask']  # [B, T]
        B, T = valid_masks.shape

        # compute image features
        valid_imgs = imgs[valid_masks]  # [N, C, H, W]
        img_feats = self.get_img_feats(valid_imgs)  # [N, C]

        # update image features using adapter
        C = img_feats.shape[-1]
        full_img_feats = torch.zeros(B, T, C).type_as(img_feats)
        full_img_feats[valid_masks] = img_feats
        full_img_feats = self.adapter(full_img_feats, valid_masks)
        # [B, T, C], multi-view image features
        # normalize the output features
        # an all-zeros vector will still be all zeros after F.normalize()
        full_img_feats = F.normalize(
            full_img_feats, p=2, dim=-1).type_as(full_img_feats)
        # make invalid features zeros
        full_img_feats = full_img_feats * valid_masks.float().unsqueeze(-1)

        # compute text features
        # we may need to compute gradients w.r.t. text features
        # so we can't use torch.no_grad() here
        text_feats = self.get_text_feats()  # [n_classes, C]

        # compute logits
        full_logits = (self.logit_scale * full_img_feats @ text_feats.T)
        # [B, T, n_cls], multi-view logits

        # convert to [B, n_cls] for loss computation!
        logits = self._aggregate_logits(full_logits, valid_masks)
        probs = self._aggregate_probs(full_logits, valid_masks)

        out_dict = {
            'full_logits': full_logits,  # [B, T, n_classes]
            'valid_masks': valid_masks,  # [B, T]
            'logits': logits,  # [B, n_classes]
            'probs': probs,  # [B, n_classes]
        }
        return out_dict

    @property
    def dtype(self):
        return self.adapter.dtype
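

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repo code): how `_build_adapter()`
# above interprets the `adapter_type` string in `adapter_dict`. Any other keys
# (e.g. `residual` in the default config) are forwarded to the adapter class
# in models/adapter.py.
#
#   'identity'    -> IdentityAdapter, image features are not tuned
#   'trans'       -> TransformerAdapter, fuses the T per-view image features
#   'text-trans'  -> TransformerAdapter + prompt tuning: `self.text_feats`
#                    becomes a trainable nn.Parameter, and the 'text-' prefix
#                    is stripped in `_build_prompts()`
def _fsclip_adapter_config_sketch():
    # a joint image + text adapter config mirroring the defaults of
    # FSCLIPClassifier.__init__ above (purely illustrative)
    return dict(adapter_type='text-trans', residual=True)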
--------------------------------------------------------------------------------
/models/lora.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import MultiheadAttention


def lora_w_init_(lora_down, lora_up, r):
    """Initialize the LoRA weights."""
    nn.init.normal_(lora_down, std=1. / r)
    nn.init.zeros_(lora_up)


class LoraInjectedLinear(nn.Module):
    """Apply LoRA to a standard Linear layer."""

    def __init__(self, linear, r=4):
        super().__init__()

        in_features = linear.in_features
        out_features = linear.out_features
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            )

        self.r = r
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)

        lora_w_init_(self.lora_down.weight, self.lora_up.weight, r)

    def forward(self, input):
        return (self.linear(input) + self.lora_up(self.lora_down(input)))

    @property
    def dtype(self):
        """Return the dtype of the projection weight."""
        return self.linear.weight.dtype

    @property
    def device(self):
        """Return the device of the projection weight."""
        return self.linear.weight.device

    @property
    def weight(self):
        """Return the LoRA adjusted weight."""
        return self.linear.weight + self.lora_up.weight @ self.lora_down.weight

    @property
    def bias(self):
        """Return the bias of the projection layer."""
        return self.linear.bias
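

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repo code): the layer above learns
# the rank-`r` update `W_eff = W + lora_up @ lora_down` on top of the frozen
# weight `W`. Since `lora_up` is zero-initialized in `lora_w_init_()`, the
# wrapped layer starts out numerically identical to the original one.
def _lora_linear_sketch():
    base = nn.Linear(8, 16)
    lora = LoraInjectedLinear(base, r=4)
    x = torch.randn(2, 8)
    # identical outputs at initialization (the LoRA branch is all zeros)
    assert torch.allclose(lora(x), base(x))
    # only the two low-rank matrices are trainable
    trainable = sorted(n for n, p in lora.named_parameters() if p.requires_grad)
    assert trainable == ['lora_down.weight', 'lora_up.weight']
    # `.weight` exposes the merged, rank-corrected matrix
    assert lora.weight.shape == base.weight.shape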


class LoraInjectedProj(nn.Module):
    """Apply LoRA on the projection head in MultiHeadAttention.

    The Q/K/V projection weight W is nn.Parameter(d_model, in_dim).
    We learn a lora_down [r, in_dim] and a lora_up [d_model, r].
    """

    def __init__(self, proj, r=4):
        super().__init__()

        d_model, in_dim = proj.shape
        if r > min(d_model, in_dim):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(d_model, in_dim)}"
            )

        self.d_model = d_model
        self.r = r
        self.proj = proj  # original projection weight, nn.Parameter
        self.proj.requires_grad = False  # freeze the original weight

        self.lora_down = nn.Parameter(torch.empty(r, in_dim))
        self.lora_up = nn.Parameter(torch.empty(d_model, r))

        lora_w_init_(self.lora_down, self.lora_up, r)

    def forward(self):
        """Return the LoRA updated weight."""
        return self.proj + self.lora_up @ self.lora_down

    @property
    def dtype(self):
        """Return the dtype of the projection weight."""
        return self.proj.dtype

    @property
    def device(self):
        """Return the device of the projection weight."""
        return self.proj.device
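

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repo code): unlike
# LoraInjectedLinear, this module does not transform activations -- calling it
# simply returns the rank-corrected weight tensor. This is why LoraInjectedMHA
# below invokes `self.q_proj_weight()` / `self.in_proj_weight()` with
# parentheses where a plain Parameter used to be.
def _lora_proj_sketch():
    proj = nn.Parameter(torch.randn(16, 16))  # e.g. a Q projection weight
    lora_proj = LoraInjectedProj(proj, r=4)
    w = lora_proj()  # [d_model, in_dim]
    # equal to the frozen weight at init, since `lora_up` is all zeros
    assert torch.allclose(w, proj)
    return w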


class LoraInjectedMergedProj(nn.Module):
    """Apply LoRA on the **merged** projection head in MultiHeadAttention.

    The merged projection weight W is nn.Parameter(3*d_model, in_dim).
    We learn one (lora_down [r, in_dim], lora_up [d_model, r]) pair for each
    of the Q/K/V projections.
    """

    def __init__(self, merged_proj, r=4, lora_k=True):
        super().__init__()

        d_model_3, in_dim = merged_proj.shape
        assert d_model_3 % 3 == 0, "MergedProj's dim must be divisible by 3"
        d_model = d_model_3 // 3
        if r > min(d_model, in_dim):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(d_model, in_dim)}"
            )

        self.d_model = d_model
        self.r = r
        self.merged_proj = merged_proj  # original projection weight
        self.merged_proj.requires_grad = False  # freeze the original weight

        self.lora_down_q = nn.Parameter(torch.empty(r, in_dim))
        self.lora_up_q = nn.Parameter(torch.empty(d_model, r))
        self.lora_down_v = nn.Parameter(torch.empty(r, in_dim))
        self.lora_up_v = nn.Parameter(torch.empty(d_model, r))

        lora_w_init_(self.lora_down_q, self.lora_up_q, r)
        lora_w_init_(self.lora_down_v, self.lora_up_v, r)

        self.lora_k = lora_k
        if lora_k:
            self.lora_down_k = nn.Parameter(torch.empty(r, in_dim))
            self.lora_up_k = nn.Parameter(torch.empty(d_model, r))
            lora_w_init_(self.lora_down_k, self.lora_up_k, r)

    def forward(self):
        """Return the LoRA updated weight."""
        return torch.cat([
            self.merged_proj[:self.d_model] +
            self.lora_up_q @ self.lora_down_q,
            self.merged_proj[self.d_model:2 * self.d_model] +
            self.lora_up_k @ self.lora_down_k if self.lora_k else
            self.merged_proj[self.d_model:2 * self.d_model],
            self.merged_proj[2 * self.d_model:] +
            self.lora_up_v @ self.lora_down_v,
        ],
                         dim=0)

    @property
    def dtype(self):
        """Return the dtype of the projection weight."""
        return self.merged_proj.dtype

    @property
    def device(self):
        """Return the device of the projection weight."""
        return self.merged_proj.device
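

# -----------------------------------------------------------------------------
# Editorial sketch (not part of the original repo code): nn.MultiheadAttention
# stacks its Q, K and V weights into one [3*d_model, in_dim] tensor, so the
# forward() above rebuilds that tensor third by third. With `lora_k=False`,
# only the Q and V blocks receive a low-rank update and the K block stays a
# plain slice of the frozen weight.
def _lora_merged_proj_sketch():
    d_model = 8
    merged = nn.Parameter(torch.randn(3 * d_model, d_model))
    lora_merged = LoraInjectedMergedProj(merged, r=2, lora_k=False)
    w = lora_merged()  # [3*d_model, in_dim]
    assert w.shape == merged.shape
    # at init every LoRA "up" matrix is zero, so nothing has changed yet
    assert torch.allclose(w, merged)
    # the K block is untouched by construction when lora_k=False
    assert w[d_model:2 * d_model].equal(merged[d_model:2 * d_model])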


class LoraInjectedMHA(MultiheadAttention):
    """MultiHeadAttention with LoRA fine-tuning."""

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
    ):
        is_batched = query.dim() == 3
        why_not_fast_path = ''
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif query.is_nested and (key_padding_mask is not None
                                  or attn_mask is not None):
            why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
        elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
            why_not_fast_path = "key_padding_mask and attn_mask were both supplied"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight(),
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all([(x.is_cuda or 'cpu' in str(x.device))
                          for x in tensor_args]):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                    [x.requires_grad for x in tensor_args]):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad")
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query, key, value, self.embed_dim, self.num_heads,
                    self.in_proj_weight(), self.in_proj_bias,
                    self.out_proj.weight, self.out_proj.bias, key_padding_mask
                    if key_padding_mask is not None else attn_mask,
                    need_weights, average_attn_weights)
        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight(),
                k_proj_weight=self.k_proj_weight(),
                v_proj_weight=self.v_proj_weight(),
                average_attn_weights=average_attn_weights)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight(),
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def __setattr__(self, name, value):
        # special case hack for LoRA: `in_proj_weight` etc. are registered as
        # nn.Parameter, and nn.Module raises a TypeError when a registered
        # Parameter is overwritten with an nn.Module (the LoRA proj). In that
        # case, drop the old Parameter entry and register the LoRA module
        # under the same name instead.
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

        try:
            super().__setattr__(name, value)
            return
        except TypeError:
            pass

        assert isinstance(getattr(self, name), nn.Parameter) and isinstance(
            value, nn.Module)
        remove_from(self.__dict__, self._buffers, self._parameters,
                    self._modules, self._non_persistent_buffers_set)
        modules = self.__dict__.get('_modules')
        modules[name] = value


def build_lora_proj(proj, r):
    """Given an old projection, build a LoRA projection."""
    lora_proj = LoraInjectedProj(proj=proj, r=r)
    return lora_proj


def build_lora_merged_proj(merged_proj, r, lora_k=True):
    """Given an old merged projection, build a LoRA merged projection."""
    lora_merged_proj = LoraInjectedMergedProj(
        merged_proj=merged_proj, r=r, lora_k=lora_k)
    return lora_merged_proj


def build_lora_mha(mha, r):
    """Given an old MHA, build a LoRA-MHA."""
    lora_mha = LoraInjectedMHA(
        embed_dim=mha.embed_dim,
        num_heads=mha.num_heads,
        dropout=mha.dropout,
        kdim=mha.kdim,
        vdim=mha.vdim,
        batch_first=mha.batch_first,
        device=mha.out_proj.weight.device,
        dtype=mha.out_proj.weight.dtype,
    )
    lora_mha.load_state_dict(mha.state_dict())
    for p in lora_mha.parameters():
        p.requires_grad = False
    # parse LoRA arguments
    # if `r` is an int, it is the LoRA rank
    # otherwise it can be a spec string 'qv-$RANK', 'qkv-$RANK', 'qkvo-$RANK'
    # by default we apply LoRA to the q,k,v projections, not the output projection
    if not isinstance(r, int):
        assert 'q' in r and 'v' in r
        lora_k = ('k' in r)
        lora_o = ('o' in r)
        r = int(r.split('-')[-1])
    else:
        lora_k = True
        lora_o = False
    assert r > 0
    # inject LoRA to the projection head weights
    if mha._qkv_same_embed_dim:  # replace `in_proj_weight`
        lora_mha.in_proj_weight = build_lora_merged_proj(
            mha.in_proj_weight, r=r, lora_k=lora_k)
    else:  # replace `q_proj_weight`, `k_proj_weight`, `v_proj_weight`
        lora_mha.q_proj_weight = build_lora_proj(mha.q_proj_weight, r)
        lora_mha.v_proj_weight = build_lora_proj(mha.v_proj_weight, r)
        if lora_k:
            lora_mha.k_proj_weight = build_lora_proj(mha.k_proj_weight, r)
    # inject LoRA to the output projection head weights
    if lora_o:
        lora_mha.out_proj = LoraInjectedLinear(mha.out_proj, r=r)
    return lora_mha


def inject_trainable_lora(model, r=4):
    """Replace all the MHA in `model` with LoRA-MHA."""
    to_replace = []
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            lora_mha = build_lora_mha(module, r)
            to_replace.append((name, lora_mha))
    for name, lora_mha in to_replace:
        # cannot directly do `setattr(model, name, lora_mha)`
        # name like `transformer.resblocks.23.attn`, containing `.` and numbers
        # workaround: make it `transformer.resblocks[23].attn`
        # then use `eval()`
        for s in name.split('.'):
            if s.isdigit():
                name = name.replace(f'.{s}', f'[{s}]')
        assert name[-4:] == 'attn'
        m = eval(f'model.{name[:-5]}')
        m.attn = lora_mha
    return model
--------------------------------------------------------------------------------
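
# -----------------------------------------------------------------------------
# Editorial appendix (not part of the repository files above): a minimal,
# hedged sketch of applying `inject_trainable_lora` to a CLIP image encoder.
# It assumes the repo root is on PYTHONPATH and that the OpenAI CLIP ViT keeps
# its attention layers at `visual.transformer.resblocks[i].attn`; the repo's
# actual LoRA fine-tuning wiring lives in models/clip_cls_ft.py and the
# configs/ftclip configs.
#
#   import clip
#
#   from models.lora import LoraInjectedMergedProj, inject_trainable_lora
#
#   model, _ = clip.load('ViT-B/16', device='cpu')
#   visual = model.visual
#   # freeze the whole image encoder first -- after injection, only the
#   # newly added LoRA matrices remain trainable
#   for p in visual.parameters():
#       p.requires_grad = False
#   # 'qkvo-16': rank-16 LoRA on the Q/K/V projections and the output proj
#   visual = inject_trainable_lora(visual, r='qkvo-16')
#   assert isinstance(
#       visual.transformer.resblocks[0].attn.in_proj_weight,
#       LoraInjectedMergedProj)
#   trainable = [n for n, p in visual.named_parameters() if p.requires_grad]
#   assert all('lora' in n for n in trainable)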