├── data
│   └── .gitkeep
├── pretrained
│   └── .gitkeep
├── assets
│   └── pipeline.png
├── datasets
│   ├── files
│   │   ├── NCars
│   │   │   ├── 10shot-repeat=False.pkl
│   │   │   ├── 20shot-repeat=False.pkl
│   │   │   ├── 30shot-repeat=False.pkl
│   │   │   ├── 50shot-repeat=False.pkl
│   │   │   ├── 100shot-repeat=False.pkl
│   │   │   ├── 200shot-repeat=False.pkl
│   │   │   └── 500shot-repeat=False.pkl
│   │   ├── NCaltech101
│   │   │   ├── 10shot-repeat=True.pkl
│   │   │   ├── 1shot-repeat=False.pkl
│   │   │   ├── 1shot-repeat=True.pkl
│   │   │   ├── 20shot-repeat=True.pkl
│   │   │   ├── 2shot-repeat=False.pkl
│   │   │   ├── 2shot-repeat=True.pkl
│   │   │   ├── 3shot-repeat=False.pkl
│   │   │   ├── 3shot-repeat=True.pkl
│   │   │   ├── 50shot-repeat=True.pkl
│   │   │   ├── 5shot-repeat=False.pkl
│   │   │   ├── 5shot-repeat=True.pkl
│   │   │   ├── 10shot-repeat=False.pkl
│   │   │   ├── 20shot-repeat=False.pkl
│   │   │   └── 50shot-repeat=False.pkl
│   │   ├── NImageNet
│   │   │   ├── 10shot-repeat=False.pkl
│   │   │   ├── 1shot-repeat=False.pkl
│   │   │   ├── 20shot-repeat=False.pkl
│   │   │   ├── 2shot-repeat=False.pkl
│   │   │   ├── 3shot-repeat=False.pkl
│   │   │   ├── 50shot-repeat=False.pkl
│   │   │   └── 5shot-repeat=False.pkl
│   │   └── NImageNetMini
│   │       ├── 10shot-repeat=True.pkl
│   │       ├── 1shot-repeat=True.pkl
│   │       ├── 20shot-repeat=True.pkl
│   │       ├── 2shot-repeat=True.pkl
│   │       ├── 3shot-repeat=True.pkl
│   │       ├── 50shot-repeat=True.pkl
│   │       └── 5shot-repeat=True.pkl
│   ├── __init__.py
│   ├── cars.py
│   ├── utils.py
│   ├── imagenet.py
│   ├── vis.py
│   ├── imagenet_mini.py
│   ├── event2img.py
│   ├── augment.py
│   └── caltech.py
├── scripts
│   ├── test_all_subset.sh
│   ├── train_all_shots.sh
│   ├── test_all_arch.sh
│   ├── dup_run_sbatch.sh
│   ├── resubmit_failed_job.sh
│   └── sbatch_run.sh
├── models
│   ├── __init__.py
│   ├── adapter.py
│   ├── clip_cls_ft.py
│   ├── clip_cls.py
│   └── lora.py
├── configs
│   ├── zsclip
│   │   ├── zsclip_ncars_params.py
│   │   ├── zsclip_nin_params.py
│   │   ├── zsclip_ncaltech_params.py
│   │   └── zsclip_nin_mini_params-vitb32.py
│   ├── fsclip
│   │   ├── joint_adapter
│   │   │   ├── joint_fsclip_nin_params.py
│   │   │   ├── joint_fsclip_ncaltech_params.py
│   │   │   ├── joint_fsclip_ncars_params.py
│   │   │   └── joint_fsclip_nin_mini_params-vitb32.py
│   │   └── text_adapter
│   │       ├── text_fsclip_nin_params.py
│   │       ├── text_fsclip_ncaltech_params.py
│   │       └── text_fsclip_ncars_params.py
│   └── ftclip
│       ├── ft_text_fsclip_nin_params.py
│       ├── ft_text_fsclip_nin_params-lora16.py
│       ├── ft_text_fsclip_ncaltech_params-vitb16.py
│       └── ft_text_fsclip_nin_params-vitb16.py
├── LICENSE
├── docs
│   ├── install.md
│   ├── data.md
│   └── benchmark.md
├── .gitignore
├── README.md
├── train.py
├── environment.yml
├── method.py
├── test.py
└── gen_data.py
/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pretrained/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/datasets/files/NCars/10shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/10shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/20shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/20shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/30shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/30shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/50shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/50shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/100shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/100shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/200shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/200shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCars/500shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCars/500shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/10shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/10shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/1shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/1shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/1shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/1shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/20shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/20shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/2shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/2shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/2shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/2shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/3shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/3shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/3shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/3shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/50shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/50shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/5shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/5shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/5shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/5shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/10shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/10shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/1shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/1shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/20shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/20shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/2shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/2shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/3shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/3shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/50shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/50shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNet/5shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNet/5shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/10shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/10shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/20shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/20shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NCaltech101/50shot-repeat=False.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NCaltech101/50shot-repeat=False.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/10shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/10shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/1shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/1shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/20shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/20shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/2shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/2shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/3shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/3shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/50shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/50shot-repeat=True.pkl
--------------------------------------------------------------------------------
/datasets/files/NImageNetMini/5shot-repeat=True.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/EventCLIP/HEAD/datasets/files/NImageNetMini/5shot-repeat=True.pkl
--------------------------------------------------------------------------------
/scripts/test_all_subset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run `test.py` with different N-IN robustness variants
4 |
5 | CMD=$1
6 |
7 | for subset in -1 1 2 3 4 5 6 7 8 9
8 | do
9 | cmd="$CMD --subset $subset"
10 | echo $cmd
11 | eval $cmd
12 | done
13 |
--------------------------------------------------------------------------------
/scripts/train_all_shots.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run `train.py` with different numbers of shots
4 | CMD=$1
5 |
6 | shot1=${2:-20}
7 | shot2=${3:-10}
8 | shot3=${4:-5}
9 | shot4=${5:-3}
10 | shot5=${6:-1}
11 |
12 | for shot in $shot1 $shot2 $shot3 $shot4 $shot5
13 | do
14 | cmd="$CMD --num_shots $shot"
15 | echo $cmd
16 | eval $cmd
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/test_all_arch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run `test.py` on all CLIP arches
4 | CMD=$1
5 |
6 | for arch in 'RN50' 'RN101' 'RN50x4' 'RN50x16' 'RN50x64' 'ViT-B/32' 'ViT-B/16' 'ViT-L/14'
7 | do
8 | if [ "$arch" = "RN50x64" ]; then
9 | bs=32
10 | else
11 | bs=64
12 | fi
13 |
14 | echo "Testing $arch"
15 | cmd="$CMD --arch $arch --bs $bs"
16 | echo $cmd
17 | eval $cmd
18 | done
19 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_cls import ZSCLIPClassifier, FSCLIPClassifier
2 | from .clip_cls_ft import FTCLIPClassifier
3 |
4 |
5 | def build_model(params):
6 | if params.model == 'ZSCLIP':
7 | return ZSCLIPClassifier(clip_dict=params.clip_dict, )
8 | elif params.model == 'FSCLIP':
9 | return FSCLIPClassifier(
10 | adapter_dict=params.adapter_dict,
11 | clip_dict=params.clip_dict,
12 | loss_dict=params.loss_dict,
13 | )
14 | elif params.model == 'FTCLIP':
15 | return FTCLIPClassifier(
16 | adapter_dict=params.adapter_dict,
17 | clip_dict=params.clip_dict,
18 | loss_dict=params.loss_dict,
19 | )
20 | else:
21 | raise NotImplementedError(f'{params.model} is not implemented.')
22 |
--------------------------------------------------------------------------------
/configs/zsclip/zsclip_ncars_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 |
10 | # data settings
11 | dataset = 'n_cars'
12 | data_root = './data/N-Cars/'
13 | train_batch_size = 32 // gpus
14 | val_batch_size = train_batch_size * 2
15 | num_workers = 8
16 |
17 | # event2img conversion
18 | quantize_args = dict(
19 | max_imgs=2,
20 | N=30000,
21 | split_method='event_count',
22 | convert_method='event_histogram',
23 | grayscale=True,
24 | count_non_zero=True,
25 | background_mask=False,
26 | )
27 |
28 | # model configs
29 | model = 'ZSCLIP'
30 | clip_dict = dict(
31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
33 | arch='ViT-L/14',
34 | prompt='a point cloud image of a {}',
35 | agg_func='mean', # aggregate the logits over views
36 | )
37 |
--------------------------------------------------------------------------------
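
For reference, a minimal sketch (not part of the repo) of how a config class like the one above plugs into the `build_model` factory from `models/__init__.py`. Loading the config by file path via `importlib` and constructing `EventCLIPParams()` with no arguments are assumptions; the actual entry points (`train.py` / `test.py`, via nerv) may load params differently.

```
# Hypothetical usage sketch -- config loading by path is an assumption,
# not necessarily how train.py/test.py do it.
import importlib.util

from models import build_model

spec = importlib.util.spec_from_file_location(
    'cfg', 'configs/zsclip/zsclip_ncars_params.py')
cfg = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg)

params = cfg.EventCLIPParams()  # assumes nerv's BaseParams needs no ctor args
model = build_model(params)     # dispatches on params.model == 'ZSCLIP'
```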
/configs/zsclip/zsclip_nin_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 |
10 | # data settings
11 | dataset = 'n_imagenet'
12 | data_root = './data/N_Imagenet/'
13 | train_batch_size = 32 // gpus
14 | val_batch_size = train_batch_size * 2
15 | num_workers = 16
16 |
17 | # event2img conversion
18 | quantize_args = dict(
19 | max_imgs=2,
20 | N=70000,
21 | split_method='event_count',
22 | convert_method='event_histogram',
23 | grayscale=True,
24 | count_non_zero=False,
25 | background_mask=True,
26 | )
27 |
28 | # model configs
29 | model = 'ZSCLIP'
30 | clip_dict = dict(
31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
33 | arch='ViT-L/14',
34 | prompt='a point cloud image of a {}',
35 | agg_func='mean', # aggregate the logits over views
36 | )
37 |
--------------------------------------------------------------------------------
/configs/zsclip/zsclip_ncaltech_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 |
10 | # data settings
11 | dataset = 'n_caltech'
12 | data_root = './data/N-Caltech101/'
13 | train_batch_size = 32 // gpus
14 | val_batch_size = train_batch_size * 2
15 | num_workers = 8
16 |
17 | # event2img conversion
18 | quantize_args = dict(
19 | max_imgs=2,
20 | N=20000,
21 | split_method='event_count',
22 | convert_method='event_histogram',
23 | grayscale=True,
24 | count_non_zero=False,
25 | background_mask=True,
26 | )
27 |
28 | # model configs
29 | model = 'ZSCLIP'
30 | clip_dict = dict(
31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
33 | arch='ViT-L/14',
34 | prompt='a point cloud image of a {}',
35 | agg_func='mean', # aggregate the logits over views
36 | )
37 |
--------------------------------------------------------------------------------
/configs/zsclip/zsclip_nin_mini_params-vitb32.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 |
10 | # data settings
11 | dataset = 'n_imagenet_mini'
12 | data_root = './data/N_Imagenet/'
13 | train_batch_size = 32 // gpus
14 | val_batch_size = train_batch_size * 2
15 | num_workers = 16
16 |
17 | # event2img conversion
18 | quantize_args = dict(
19 | max_imgs=2,
20 | N=70000,
21 | split_method='event_count',
22 | convert_method='event_histogram',
23 | grayscale=True,
24 | count_non_zero=False,
25 | background_mask=True,
26 | )
27 |
28 | # model configs
29 | model = 'ZSCLIP'
30 | clip_dict = dict(
31 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
32 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
33 | arch='ViT-B/32',
34 | prompt='a sketch image of a {}',
35 | agg_func='mean', # aggregate the logits over views
36 | )
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ziyi Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | from .caltech import build_n_caltech_dataset, NCaltech101
4 | from .cars import build_n_cars_dataset, NCars
5 | from .imagenet import build_n_imagenet_dataset, NImageNet
6 | from .imagenet_mini import build_n_imagenet_mini_dataset, NImageNetMini
7 | from .event2img import build_event2img_dataset, Event2ImageDataset
8 | from .vis import events2frames
9 |
10 |
11 | def build_dataset(params, val_only=False, **kwargs):
12 | # `gen_data` means doing pseudo label generation for self-training
13 | gen_data = kwargs.pop('gen_data', False)
14 | tta = kwargs.pop('tta', False)
15 |
16 | dst = params.dataset
17 | ev_dst = eval(f'build_{dst}_dataset')(
18 | params, val_only=val_only, gen_data=gen_data, **kwargs)
19 |
20 | # adjust max-views for event2img conversion
21 | train_params = copy.deepcopy(params)
22 | val_params = copy.deepcopy(params)
23 | val_params.quantize_args['max_imgs'] = 10 # load all views for testing
24 |
25 | # only build one dataset in these cases
26 | if val_only or gen_data:
27 | return build_event2img_dataset(val_params, ev_dst, tta=tta)
28 |
29 | # build both train and val datasets
30 | return build_event2img_dataset(
31 | train_params, ev_dst[0], augment=params.get('img_aug', False)), \
32 | build_event2img_dataset(val_params, ev_dst[1], augment=False)
33 |
--------------------------------------------------------------------------------
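
A hedged sketch (not from the repo) of calling `build_dataset` above for evaluation. It reuses the config-loading approach shown after `configs/zsclip/zsclip_ncars_params.py` and assumes the N-Cars data is already prepared under `./data/N-Cars/` as described in `docs/data.md`.

```
# Hypothetical sketch: build the evaluation dataset from the N-Cars ZSCLIP config.
import importlib.util

from datasets import build_dataset

spec = importlib.util.spec_from_file_location(
    'cfg', 'configs/zsclip/zsclip_ncars_params.py')
cfg = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg)
params = cfg.EventCLIPParams()  # assumes no ctor args are required

# val_only=True -> a single Event2ImageDataset over the N-Cars test split
val_set = build_dataset(params, val_only=True)
# default call -> (train_set, val_set), with image augmentation if params.img_aug
# train_set, val_set = build_dataset(params)
```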
/datasets/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .caltech import NCaltech101
4 |
5 | NEW_CNAMES = {
6 | "cars": "car",
7 | "background": "background",
8 | }
9 |
10 |
11 | class NCars(NCaltech101):
12 | """Dataset class for N-Cars dataset."""
13 |
14 | def __init__(
15 | self,
16 | root,
17 | augmentation=False,
18 | num_shots=None,
19 | new_cnames=None,
20 | ):
21 | super().__init__(
22 | root=root,
23 | augmentation=augmentation,
24 | num_shots=num_shots,
25 | repeat=False,
26 | new_cnames=new_cnames,
27 | )
28 |
29 | # data stats
30 | self.resolution = (100, 120)
31 | self.max_t = 0.1 # max
32 | self.max_n = 12500 # 95th percentile
33 |
34 | # data augmentation
35 |         self.max_shift = 10  # resolution is ~half that of N-Caltech101
36 |
37 |
38 | def build_n_cars_dataset(params, val_only=False, gen_data=False):
39 | """Build the N-Cars dataset."""
40 | # only build the test set
41 | test_set = NCars(
42 | root=os.path.join(params.data_root, 'test'),
43 | augmentation=False,
44 | new_cnames=NEW_CNAMES,
45 | )
46 | if val_only:
47 | assert not gen_data
48 | return test_set
49 | # build the training set for pseudo label generation
50 | if gen_data:
51 | return NCars(
52 | root=os.path.join(params.data_root, 'train'),
53 | augmentation=False,
54 | new_cnames=NEW_CNAMES,
55 | )
56 |
57 | # build the training set
58 | train_set = NCars(
59 | root=os.path.join(params.data_root, 'train'),
60 | augmentation=True,
61 | num_shots=params.get('num_shots', None),
62 | new_cnames=NEW_CNAMES,
63 | )
64 | return train_set, test_set
65 |
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | # Install
2 |
3 | We recommend using [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) for environment setup:
4 |
5 | ```
6 | conda create -n eventclip python=3.9.17
7 | conda activate eventclip
8 | ```
9 |
10 | Then install a PyTorch version compatible with your CUDA setup.
11 | In our experiments, we use PyTorch 1.12.1 + CUDA 11.3 (any CUDA version is fine as long as it meets the requirements [here](https://pytorch.org/get-started/previous-versions/)).
12 | PyTorch 2.0.0 is also compatible and can speed up model training:
13 |
14 | ```
15 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
16 | pip install pytorch-lightning==1.8.6 torchmetrics==0.11.4
17 | ```
18 |
19 | The codebase relies heavily on [nerv](https://github.com/Wuziyi616/nerv) for its project template and Trainer.
20 | You can easily install it by:
21 |
22 | ```
23 | git clone git@github.com:Wuziyi616/nerv.git
24 | cd nerv
25 | git checkout v0.3.1 # tested with v0.3.1 release
26 | pip install -e .
27 | ```
28 |
29 | This will automatically install the packages necessary for the project.
30 | Additional packages are listed as follows:
31 |
32 | ```
33 | pip install ftfy regex tqdm # packages required by CLIP
34 | pip install git+https://github.com/openai/CLIP.git # install CLIP from OpenAI
35 | ```
36 |
37 | Finally, clone this project by:
38 |
39 | ```
40 | cd .. # move out from nerv/
41 | git clone git@github.com:Wuziyi616/EventCLIP.git
42 | cd EventCLIP
43 | ```
44 |
45 | We use [wandb](https://wandb.ai/) for logging; please run `wandb login` to log in.
46 |
47 | ## Possible Issues
48 |
49 | - In case you encounter any environment issues, you can refer to the conda env file exported from my server: [environment.yml](../environment.yml).
50 | You can install the same environment with `conda env create -f environment.yml`.
51 |
--------------------------------------------------------------------------------
/scripts/dup_run_sbatch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This is a wrapper for `sbatch_run.sh` to run repeated experiments
4 | # It will duplicate the same params file several times and run them all
5 |
6 | #######################################################################
7 | # An example usage:
8 | # GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=scavenger REPEAT=3 ./scripts/dup_run_sbatch.sh \
9 | # rtx6000 test-sbatch ./train.py ddp params.py --fp16 --ddp --cudnn
10 | #######################################################################
11 |
12 | # read args from command line
13 | REPEAT=${REPEAT:-3}
14 | GPUS=${GPUS:-1}
15 | CPUS_PER_GPU=${CPUS_PER_GPU:-8}
16 | MEM_PER_CPU=${MEM_PER_CPU:-5}
17 | QOS=${QOS:-scavenger}
18 | TIME=${TIME:-96:00:00}
19 |
20 | PY_ARGS=${@:6}
21 | PARTITION=$1
22 | JOB_NAME=$2
23 | PY_FILE=$3
24 | DDP=$4
25 | PARAMS=$5
26 |
27 | for repeat_idx in $(seq 1 $REPEAT)
28 | do
29 | params="${PARAMS:0:(-3)}-dup${repeat_idx}.py"
30 | cp $PARAMS $params
31 | job_name="${JOB_NAME}-dup${repeat_idx}"
32 | # if `$PY_ARGS` contains "--N X", then append "-N_X" to `job_name`
33 | if [[ $PY_ARGS == *"--N"* ]]; then
34 | N=$(echo $PY_ARGS | grep -oP "(?<=--N )\d+")
35 | # only modify when `X` is positive
36 | if [[ $N -gt 0 ]]; then
37 | job_name="${job_name}-N_${N}"
38 | fi
39 | fi
40 | # if `$PY_ARGS` contains "--num_shots X", then append "-Xshot" to `job_name`
41 | if [[ $PY_ARGS == *"--num_shots"* ]]; then
42 | num_shots=$(echo $PY_ARGS | grep -oP "(?<=--num_shots )\d+")
43 | # only modify when `X` is positive
44 | if [[ $num_shots -gt 0 ]]; then
45 | job_name="${job_name}-${num_shots}shot"
46 | fi
47 | fi
48 | cmd="./scripts/sbatch_run.sh $PARTITION $job_name $PY_FILE $DDP --params $params $PY_ARGS"
49 | echo $cmd
50 | eval $cmd
51 | done
52 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def random_shift_events(events, max_shift=20, resolution=(180, 240)):
5 | """Spatially shift events by a random offset."""
6 | H, W = resolution
7 | x_shift, y_shift = np.random.randint(-max_shift, max_shift + 1, size=(2, ))
8 | events[:, 0] += x_shift
9 | events[:, 1] += y_shift
10 |
11 | valid_events = (events[:, 0] >= 0) & (events[:, 0] < W) & \
12 | (events[:, 1] >= 0) & (events[:, 1] < H)
13 | events = events[valid_events]
14 |
15 | return events
16 |
17 |
18 | def random_flip_events_along_x(events, resolution=(180, 240), p=0.5):
19 | """Flip events along horizontally with probability p."""
20 | H, W = resolution
21 | if np.random.random() < p:
22 | events[:, 0] = W - 1 - events[:, 0]
23 | return events
24 |
25 |
26 | def random_time_flip_events(events, p=0.5):
27 | """Flip events over time with probability p."""
28 | if np.random.random() < p:
29 | events = np.flip(events, axis=0)
30 | events = np.ascontiguousarray(events)
31 | # reverse the time
32 | events[:, 2] = events[0, 2] - events[:, 2]
33 | # reverse the polarity
34 | events[:, 3] = -events[:, 3]
35 | return events
36 |
37 |
38 | def center_events(events, resolution=(180, 240)):
39 | """Center the temporal & spatial coordinates of events.
40 | Make min_t == 0.
41 | Make (max_x + min_x + 1) / 2 == W / 2 and (max_y + min_y + 1) / 2 == H / 2.
42 |
43 | Args:
44 | events: [N, 4 (x,y,t,p)]
45 | resolution: (H, W)
46 | """
47 | # temporal
48 | events[:, 2] -= events[:, 2].min()
49 | # spatial
50 | H, W = resolution
51 | x_min, x_max = events[:, 0].min(), events[:, 0].max()
52 | y_min, y_max = events[:, 1].min(), events[:, 1].max()
53 | x_shift = ((x_max + x_min + 1.) - W) // 2.
54 | y_shift = ((y_max + y_min + 1.) - H) // 2.
55 | events[:, 0] -= x_shift
56 | events[:, 1] -= y_shift
57 | return events
58 |
--------------------------------------------------------------------------------
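
A small usage example (not from the repo) of the augmentation utilities above on a synthetic event array. The `(x, y, t, p)` column order follows the `center_events` docstring, and all three functions modify the array in place and return it.

```
import numpy as np

from datasets.utils import (center_events, random_shift_events,
                            random_time_flip_events)

# synthetic events: [N, 4] with columns (x, y, t, p)
events = np.array([
    [10., 20., 0.00, 1.],
    [30., 40., 0.05, -1.],
    [50., 60., 0.10, 1.],
])

events = center_events(events, resolution=(180, 240))  # t starts at 0, x/y centered
events = random_shift_events(events, max_shift=20)      # random spatial jitter
events = random_time_flip_events(events, p=0.5)         # maybe reverse time & polarity
print(events.shape)  # (M, 4) with M <= 3, since out-of-bound events are dropped
```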
/configs/fsclip/joint_adapter/joint_fsclip_nin_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 10
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-5
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_imagenet'
23 | data_root = './data/N_Imagenet/'
24 | num_shots = None
25 | img_aug = True
26 | train_batch_size = 128 // gpus
27 | val_batch_size = train_batch_size * 2
28 | num_workers = 16
29 |
30 | # event2img conversion
31 | quantize_args = dict(
32 | max_imgs=2,
33 | N=70000,
34 | split_method='event_count',
35 | convert_method='event_histogram',
36 | grayscale=True,
37 | count_non_zero=False,
38 | background_mask=True,
39 | )
40 |
41 | # model configs
42 | model = 'FSCLIP'
43 | clip_dict = dict(
44 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
45 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
46 | arch='ViT-L/14',
47 | prompt='a point cloud image of a {}',
48 | agg_func='mean', # aggregate the logits over views
49 | )
50 |
51 | # adapter configs
52 | d_model = 256
53 | adapter_dict = dict(
54 | adapter_type='text-trans',
55 | in_dim=512,
56 | d_model=d_model,
57 | num_heads=d_model // 64,
58 | ffn_dim=d_model * 4,
59 | norm_first=True,
60 | num_layers=2,
61 | residual=0.95,
62 | )
63 |
64 | # loss configs
65 | loss_dict = dict(
66 | use_logits_loss=True, # CE over mean logits
67 | use_probs_loss=False, # CE over mean probs
68 | )
69 |
70 | ce_loss_w = 1.
71 |
72 | # save the model with the highest acc
73 | ckp_monitor = 'val/probs_acc'
74 | ckp_monitor_type = 'max' # 'max' or 'min'
75 |
--------------------------------------------------------------------------------
/configs/fsclip/text_adapter/text_fsclip_nin_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 10
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-5
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_imagenet'
23 | data_root = './data/N_Imagenet/'
24 | num_shots = None
25 | img_aug = True
26 | train_batch_size = 128 // gpus
27 | val_batch_size = train_batch_size * 2
28 | num_workers = 16
29 |
30 | # event2img conversion
31 | quantize_args = dict(
32 | max_imgs=2,
33 | N=70000,
34 | split_method='event_count',
35 | convert_method='event_histogram',
36 | grayscale=True,
37 | count_non_zero=False,
38 | background_mask=True,
39 | )
40 |
41 | # model configs
42 | model = 'FSCLIP'
43 | clip_dict = dict(
44 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
45 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
46 | arch='ViT-L/14',
47 | prompt='a point cloud image of a {}',
48 | agg_func='mean', # aggregate the logits over views
49 | )
50 |
51 | # adapter configs
52 | d_model = 256
53 | adapter_dict = dict(
54 | adapter_type='text-identity',
55 | in_dim=512,
56 | d_model=d_model,
57 | num_heads=d_model // 64,
58 | ffn_dim=d_model * 4,
59 | norm_first=True,
60 | num_layers=2,
61 | residual=0.95,
62 | )
63 |
64 | # loss configs
65 | loss_dict = dict(
66 | use_logits_loss=True, # CE over mean logits
67 | use_probs_loss=False, # CE over mean probs
68 | )
69 |
70 | ce_loss_w = 1.
71 |
72 | # save the model with the highest acc
73 | ckp_monitor = 'val/probs_acc'
74 | ckp_monitor_type = 'max' # 'max' or 'min'
75 |
--------------------------------------------------------------------------------
/configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 1e-4
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_caltech'
23 | data_root = './data/N-Caltech101/'
24 | num_shots = None
25 | repeat_data = True
26 | img_aug = True
27 | train_batch_size = 32 // gpus
28 | val_batch_size = train_batch_size * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=20000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=False,
39 | background_mask=True,
40 | )
41 |
42 | # model configs
43 | model = 'FSCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | )
51 |
52 | # adapter configs
53 | d_model = 256
54 | adapter_dict = dict(
55 | adapter_type='text-trans',
56 | in_dim=512,
57 | d_model=d_model,
58 | num_heads=d_model // 64,
59 | ffn_dim=d_model * 4,
60 | norm_first=True,
61 | num_layers=2,
62 | residual=0.8,
63 | )
64 |
65 | # loss configs
66 | loss_dict = dict(
67 | use_logits_loss=True, # CE over mean logits
68 | use_probs_loss=False, # CE over mean probs
69 | )
70 |
71 | ce_loss_w = 1.
72 |
73 | # save the model with the highest acc
74 | ckp_monitor = 'val/probs_acc'
75 | ckp_monitor_type = 'max' # 'max' or 'min'
76 |
--------------------------------------------------------------------------------
/configs/fsclip/text_adapter/text_fsclip_ncaltech_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 1e-4
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_caltech'
23 | data_root = './data/N-Caltech101/'
24 | num_shots = None
25 | repeat_data = True
26 | img_aug = True
27 | train_batch_size = 32 // gpus
28 | val_batch_size = train_batch_size * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=20000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=False,
39 | background_mask=True,
40 | )
41 |
42 | # model configs
43 | model = 'FSCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | )
51 |
52 | # adapter configs
53 | d_model = 256
54 | adapter_dict = dict(
55 | adapter_type='text-identity',
56 | in_dim=512,
57 | d_model=d_model,
58 | num_heads=d_model // 64,
59 | ffn_dim=d_model * 4,
60 | norm_first=True,
61 | num_layers=2,
62 | residual=0.8,
63 | )
64 |
65 | # loss configs
66 | loss_dict = dict(
67 | use_logits_loss=True, # CE over mean logits
68 | use_probs_loss=False, # CE over mean probs
69 | )
70 |
71 | ce_loss_w = 1.
72 |
73 | # save the model with the highest acc
74 | ckp_monitor = 'val/probs_acc'
75 | ckp_monitor_type = 'max' # 'max' or 'min'
76 |
--------------------------------------------------------------------------------
/configs/fsclip/joint_adapter/joint_fsclip_ncars_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 50
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-4
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_cars'
23 | data_root = './data/N-Cars/'
24 | num_shots = None
25 | img_aug = False
26 | train_batch_size = 32 // gpus if \
27 | num_shots is None else min(num_shots * 2, 32) // gpus
28 | val_batch_size = max(train_batch_size, 32 // gpus) * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=30000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=True,
39 | background_mask=False,
40 | )
41 |
42 | # model configs
43 | model = 'FSCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | )
51 |
52 | # adapter configs
53 | d_model = 256
54 | adapter_dict = dict(
55 | adapter_type='text-trans',
56 | in_dim=512,
57 | d_model=d_model,
58 | num_heads=d_model // 64,
59 | ffn_dim=d_model * 4,
60 | norm_first=True,
61 | num_layers=2,
62 | residual=0.8,
63 | )
64 |
65 | # loss configs
66 | loss_dict = dict(
67 | use_logits_loss=True, # CE over mean logits
68 | use_probs_loss=False, # CE over mean probs
69 | )
70 |
71 | ce_loss_w = 1.
72 |
73 | # save the model with the highest acc
74 | ckp_monitor = 'val/probs_acc'
75 | ckp_monitor_type = 'max' # 'max' or 'min'
76 |
--------------------------------------------------------------------------------
/configs/fsclip/text_adapter/text_fsclip_ncars_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 50
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-4
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_cars'
23 | data_root = './data/N-Cars/'
24 | num_shots = None
25 | img_aug = False
26 | train_batch_size = 32 // gpus if \
27 | num_shots is None else min(num_shots * 2, 32) // gpus
28 | val_batch_size = max(train_batch_size, 32 // gpus) * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=30000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=True,
39 | background_mask=False,
40 | )
41 |
42 | # model configs
43 | model = 'FSCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | )
51 |
52 | # adapter configs
53 | d_model = 256
54 | adapter_dict = dict(
55 | adapter_type='text-identity',
56 | in_dim=512,
57 | d_model=d_model,
58 | num_heads=d_model // 64,
59 | ffn_dim=d_model * 4,
60 | norm_first=True,
61 | num_layers=2,
62 | residual=0.8,
63 | )
64 |
65 | # loss configs
66 | loss_dict = dict(
67 | use_logits_loss=True, # CE over mean logits
68 | use_probs_loss=False, # CE over mean probs
69 | )
70 |
71 | ce_loss_w = 1.
72 |
73 | # save the model with the highest acc
74 | ckp_monitor = 'val/probs_acc'
75 | ckp_monitor_type = 'max' # 'max' or 'min'
76 |
--------------------------------------------------------------------------------
/configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-5
19 | warmup_steps_pct = 0.05
20 |
21 | # data settings
22 | dataset = 'n_imagenet_mini'
23 | data_root = './data/N_Imagenet/' # change to 'data/pseudo-N_Imagenet/xxx'
24 |     num_shots = None  # set this to train with the most confident pseudo labels
25 | repeat_data = True
26 | img_aug = True
27 | train_batch_size = 32 // gpus
28 | val_batch_size = train_batch_size * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=70000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=False,
39 | background_mask=True,
40 | )
41 |
42 | # model configs
43 | model = 'FSCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-B/32',
48 | prompt='a sketch image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | )
51 |
52 | # adapter configs
53 | d_model = 256
54 | adapter_dict = dict(
55 | adapter_type='text-trans',
56 | in_dim=512,
57 | d_model=d_model,
58 | num_heads=d_model // 64,
59 | ffn_dim=d_model * 4,
60 | norm_first=True,
61 | num_layers=2,
62 | residual=0.95,
63 | )
64 |
65 | # loss configs
66 | loss_dict = dict(
67 | use_logits_loss=True, # CE over mean logits
68 | use_probs_loss=False, # CE over mean probs
69 | )
70 |
71 | ce_loss_w = 1.
72 |
73 | # save the model with the highest acc
74 | ckp_monitor = 'val/probs_acc'
75 | ckp_monitor_type = 'max' # 'max' or 'min'
76 |
--------------------------------------------------------------------------------
/docs/data.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | All datasets should be downloaded or soft-linked to `./data/`.
4 | Alternatively, you can modify the `data_root` value in the config files.
5 |
6 | ## N-Caltech101
7 |
8 | We adopt the N-Caltech101 dataset from [EST repo](https://github.com/uzh-rpg/rpg_event_representation_learning#training).
9 | Please download and unzip the data and put it under `./data/N-Caltech101`.
10 |
11 | ## N-Cars
12 |
13 | We also use the N-Cars dataset processed by the EST repo.
14 | Please use [this link](http://rpg.ifi.uzh.ch/datasets/gehrig_et_al_iccv19/N-Cars.zip) to download it and unzip it to `./data/N-Cars`.
15 |
16 | ## N-ImageNet
17 |
18 | Please follow the instructions on the [n_imagenet repo](https://github.com/82magnolia/n_imagenet#n-imagenet-towards-robust-fine-grained-object-recognition-with-event-cameras) to download N-ImageNet and put it under `./data/N_Imagenet`.
19 |
20 | The `N-ImageNet Variants (~150GB)` are not required for training.
21 | But if you want to test the robustness of our method, you can download them, as we provide evaluation code for them.
22 |
23 | The `Mini N-ImageNet (~45 GB)` subset is not used in this project.
24 | But you can modify the dataloader if you want to experiment at a smaller scale.
25 |
26 | ## Summary
27 |
28 | **The final `data` directory should look like this:**
29 |
30 | ```
31 | data/
32 | ├── N-Caltech101/
33 | │ ├── training/
34 | │ │ ├── accordion/ # folder with events in '.npy' format
35 | │ │ ├── airplanes/
36 | • • •
37 | • • •
38 | │ │ └── yin_yang/
39 | │ ├── validation/ # same as 'training'
40 | │ ├── testing/ # same as 'training'
41 | ├── N-Cars/
42 | │ ├── train/
43 | │ │ ├── background/ # folder with events in '.npy' format
44 | │ │ ├── cars/
45 | │ ├── test/ # same as 'train'
46 | ├── N_Imagenet/
47 | │ ├── extracted_train/
48 | │ │ ├── n01440764/ # folder with events in '.npy' format
49 | │ │ ├── n01443537/
50 | • • •
51 | • • •
52 | │ │ └── n15075141/
53 | │ ├── extracted_val/ # same as 'extracted_train'
54 | │ ├── extracted_val_brightness_4/ # these are robustness variants
55 | │ ├── extracted_val_brightness_5/ # brightness change
56 | │ ├── extracted_val_brightness_6/
57 | │ ├── extracted_val_brightness_7/
58 | │ ├── extracted_val_mode_1/ # trajectory change
59 | │ ├── extracted_val_mode_3/
60 | │ ├── extracted_val_mode_5/
61 | │ ├── extracted_val_mode_6/
62 | └ └── extracted_val_mode_7/
63 | ```
64 |
--------------------------------------------------------------------------------
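
Following the note at the top of `docs/data.md` about soft-linking, a minimal sketch (not from the repo) that links already-downloaded datasets into `./data/`; the source paths are placeholders you would replace with your own locations.

```
import os

os.makedirs('data', exist_ok=True)
# placeholder source paths -- point these at your downloaded copies
links = {
    '/path/to/N-Caltech101': 'data/N-Caltech101',
    '/path/to/N-Cars': 'data/N-Cars',
    '/path/to/N_Imagenet': 'data/N_Imagenet',
}
for src, dst in links.items():
    if not os.path.exists(dst):
        os.symlink(src, dst)
```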
/scripts/resubmit_failed_job.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # this is a script called by `sbatch_run.sh` internally
4 | # the goal is to re-submit the job if it doesn't end normally, i.e., `COMPLETED` or `CANCELLED`
5 | # currently we consider `FAILED`, `OUT_OF_MEMORY`, and `TIMEOUT` as abnormal ends
6 | # for other states like `PENDING`, `NONE`, we will do nothing
7 |
8 | # read args from command line
9 | JOB_ID=$1
10 | SLRM_NAME=$2
11 | LOG_FILE=$3
12 |
13 | # we first move the sbatch file to "./sbatch/"
14 | slrm_file="run-${SLRM_NAME}.slrm"
15 | mkdir -p "./sbatch/"
16 | mv "./${slrm_file}" "./sbatch/${slrm_file}"
17 |
18 | # util function to check if string1 contains string2
19 | check_contain() {
20 | local string1="$1"
21 | local string2="$2"
22 |
23 | if [[ $string1 == *"$string2"* ]]; then
24 | return 0 # true
25 | else
26 | return 1 # false
27 | fi
28 | }
29 |
30 | # periodically check the job status
31 | while true; do
32 | read -ra arr <<< "$(sacct -j "$JOB_ID" --format State --noheader)"
33 | status="${arr[0]}"
34 | # re-submit it if it failed or OOM
35 | if check_contain "$status" "FAIL" || check_contain "$status" "OUT_OF_M" || check_contain "$status" "TIMEOUT"; then
36 | # the sbatch file is saved under "./sbatch/run-${SLRM_NAME}.slrm"
37 | # we copy it to "./", run it, and delete it
38 | cp "./sbatch/${slrm_file}" "./${slrm_file}"
39 | # should also update the JOB_ID!
40 | JOB_ID=$(sbatch --parsable $slrm_file)
41 | rm -f $slrm_file
42 | echo "Job $SLRM_NAME failed with status $status, resubmitted with JOB_ID $JOB_ID" >> $LOG_FILE
43 | # exit the loop/this script if it's 1) completed 2) cancelled
44 | # also delete the sbatch file
45 | elif check_contain "$status" "COMPLE" || check_contain "$status" "CANCEL"; then
46 | echo "Job $SLRM_NAME finished with status $status" >> $LOG_FILE
47 | rm -f "./sbatch/${slrm_file}"
48 | exit 0
49 | # do nothing if it's 1) running 2) waiting
50 | else
51 | echo "Job $SLRM_NAME, ID $JOB_ID is good with status $status" >> $LOG_FILE
52 | fi
53 | sleep 600 # check every 10 minutes
54 | done & # run in background
55 |
56 | # detach the background process from the current shell
57 | disown
58 |
59 | # ways to check if there are duplicated runs
60 | # names = str(subprocess.check_output("squeue -u jiaqixi -o '%.100j' --noheader", shell=True))[2:-1]
61 | # names = [n.strip() for n in names.split('\\n')][:-1]
62 | # [n for n in names if names.count(n) > 1]
63 |
--------------------------------------------------------------------------------
/configs/ftclip/ft_text_fsclip_nin_params.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 4
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 10
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-5
19 | clip_lr = lr / 10.
20 | warmup_steps_pct = 0.05
21 |
22 | # data settings
23 | dataset = 'n_imagenet'
24 | data_root = './data/N_Imagenet/'
25 | num_shots = None
26 | img_aug = True
27 | train_batch_size = 128 // gpus
28 | val_batch_size = train_batch_size * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=70000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=False,
39 | background_mask=True,
40 | )
41 |
42 | # model configs
43 | model = 'FTCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16
51 | only_conv1=False, # only tune the first conv layer
52 | only_bias=False, # only tune the bias terms
53 | only_ln=False, # only tune the LayerNorm layers
54 | only_cls_fc=False, # only tune the embedding projection head
55 | only_cls_token=False, # only tune the CLS token
56 | )
57 |
58 | # adapter configs
59 | d_model = 256
60 | adapter_dict = dict(
61 | adapter_type='text-identity',
62 | in_dim=512,
63 | d_model=d_model,
64 | num_heads=d_model // 64,
65 | ffn_dim=d_model * 4,
66 | norm_first=True,
67 | num_layers=2,
68 | residual=0.95,
69 | )
70 |
71 | # loss configs
72 | loss_dict = dict(
73 | use_logits_loss=True, # CE over mean logits
74 | use_probs_loss=False, # CE over mean probs
75 | )
76 |
77 | ce_loss_w = 1.
78 |
79 | # save the model with the highest acc
80 | ckp_monitor = 'val/probs_acc'
81 | ckp_monitor_type = 'max' # 'max' or 'min'
82 |
--------------------------------------------------------------------------------
/configs/ftclip/ft_text_fsclip_nin_params-lora16.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 4
9 | max_epochs = 100
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 10
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 2e-5
19 | clip_lr = lr
20 | warmup_steps_pct = 0.05
21 |
22 | # data settings
23 | dataset = 'n_imagenet'
24 | data_root = './data/N_Imagenet/'
25 | num_shots = None
26 | img_aug = True
27 | train_batch_size = 128 // gpus
28 | val_batch_size = train_batch_size * 2
29 | num_workers = 8
30 |
31 | # event2img conversion
32 | quantize_args = dict(
33 | max_imgs=2,
34 | N=70000,
35 | split_method='event_count',
36 | convert_method='event_histogram',
37 | grayscale=True,
38 | count_non_zero=False,
39 | background_mask=True,
40 | )
41 |
42 | # model configs
43 | model = 'FTCLIP'
44 | clip_dict = dict(
45 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
46 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
47 | arch='ViT-L/14',
48 | prompt='a point cloud image of a {}',
49 | agg_func='mean', # aggregate the logits over views
50 | lora='qkvo-16', # LoRA fine-tuning, 'qv-16', 'qkv-16' (int), 'qkvo-16'
51 | only_conv1=False, # only tune the first conv layer
52 | only_bias=False, # only tune the bias terms
53 | only_ln=False, # only tune the LayerNorm layers
54 | only_cls_fc=False, # only tune the embedding projection head
55 | only_cls_token=False, # only tune the CLS token
56 | )
57 |
58 | # adapter configs
59 | d_model = 256
60 | adapter_dict = dict(
61 | adapter_type='text-identity',
62 | in_dim=512,
63 | d_model=d_model,
64 | num_heads=d_model // 64,
65 | ffn_dim=d_model * 4,
66 | norm_first=True,
67 | num_layers=2,
68 | residual=0.95,
69 | )
70 |
71 | # loss configs
72 | loss_dict = dict(
73 | use_logits_loss=True, # CE over mean logits
74 | use_probs_loss=False, # CE over mean probs
75 | )
76 |
77 | ce_loss_w = 1.
78 |
79 | # save the model with the highest acc
80 | ckp_monitor = 'val/probs_acc'
81 | ckp_monitor_type = 'max' # 'max' or 'min'
82 |
--------------------------------------------------------------------------------
/configs/ftclip/ft_text_fsclip_ncaltech_params-vitb16.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 50
10 | save_interval = 1
11 | eval_interval = 5
12 | save_epoch_end = False
13 | n_samples = 5
14 |
15 | # optimizer settings
16 | # Adam optimizer, Cosine decay with Warmup
17 | optimizer = 'Adam'
18 | lr = 1e-4
19 | clip_lr = lr / 10.
20 | warmup_steps_pct = 0.05
21 |
22 | # data settings
23 | dataset = 'n_caltech'
24 | data_root = './data/N-Caltech101/'
25 | num_shots = None
26 | repeat_data = True
27 | img_aug = True
28 | train_batch_size = 32 // gpus
29 | val_batch_size = train_batch_size * 2
30 | num_workers = 8
31 |
32 | # event2img conversion
33 | quantize_args = dict(
34 | max_imgs=2,
35 | N=20000,
36 | split_method='event_count',
37 | convert_method='event_histogram',
38 | grayscale=True,
39 | count_non_zero=False,
40 | background_mask=True,
41 | )
42 |
43 | # model configs
44 | model = 'FTCLIP'
45 | clip_dict = dict(
46 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
47 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
48 | arch='ViT-B/16', # to compare with E-CLIP
49 | prompt='a point cloud image of a {}',
50 | agg_func='mean', # aggregate the logits over views
51 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16
52 | only_conv1=False, # only tune the first conv layer
53 | only_bias=False, # only tune the bias terms
54 | only_ln=False, # only tune the LayerNorm layers
55 | only_cls_fc=False, # only tune the embedding projection head
56 | only_cls_token=False, # only tune the CLS token
57 | # lora >> bias > conv > fc > ln > CLS
58 | )
59 |
60 | # adapter configs
61 | d_model = 256
62 | adapter_dict = dict(
63 | adapter_type='text-identity',
64 | in_dim=512,
65 | d_model=d_model,
66 | num_heads=d_model // 64,
67 | ffn_dim=d_model * 4,
68 | norm_first=True,
69 | num_layers=2,
70 | residual=0.8,
71 | )
72 |
73 | # loss configs
74 | loss_dict = dict(
75 | use_logits_loss=True, # CE over mean logits
76 | use_probs_loss=False, # CE over mean probs
77 | )
78 |
79 | ce_loss_w = 1.
80 |
81 | # save the model with the highest acc
82 | ckp_monitor = 'val/probs_acc'
83 | ckp_monitor_type = 'max' # 'max' or 'min'
84 |
--------------------------------------------------------------------------------
/configs/ftclip/ft_text_fsclip_nin_params-vitb16.py:
--------------------------------------------------------------------------------
1 | from nerv.training import BaseParams
2 |
3 |
4 | class EventCLIPParams(BaseParams):
5 | project = 'EventCLIP'
6 |
7 | # training settings
8 | gpus = 1
9 | max_epochs = 50
10 | # max_epochs = 3 # following E-CLIP, only train 3 epochs on full dataset
11 | # also change `save_interval=0.05`, `eval_interval = 1`
12 | save_interval = 1
13 | eval_interval = 5
14 | save_epoch_end = False
15 | n_samples = 10
16 |
17 | # optimizer settings
18 | # Adam optimizer, Cosine decay with Warmup
19 | optimizer = 'Adam'
20 | lr = 2e-5
21 | clip_lr = lr / 10.
22 | warmup_steps_pct = 0.05
23 |
24 | # data settings
25 | dataset = 'n_imagenet'
26 | data_root = './data/N_Imagenet/'
27 | num_shots = None
28 | img_aug = True
29 | train_batch_size = 128 // gpus
30 | val_batch_size = train_batch_size * 2
31 | num_workers = 8
32 |
33 | # event2img conversion
34 | quantize_args = dict(
35 | max_imgs=2,
36 | N=70000,
37 | split_method='event_count',
38 | convert_method='event_histogram',
39 | grayscale=True,
40 | count_non_zero=False,
41 | background_mask=True,
42 | )
43 |
44 | # model configs
45 | model = 'FTCLIP'
46 | clip_dict = dict(
47 | # 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32'
48 | # 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
49 | arch='ViT-B/16', # to compare with E-CLIP
50 | prompt='a point cloud image of a {}',
51 | agg_func='mean', # aggregate the logits over views
52 | lora=-1, # use LoRA fine-tuning, typically r = 4, 16
53 | only_conv1=False, # only tune the first conv layer
54 | only_bias=False, # only tune the bias terms
55 | only_ln=False, # only tune the LayerNorm layers
56 | only_cls_fc=False, # only tune the embedding projection head
57 | only_cls_token=False, # only tune the CLS token
58 | )
59 |
60 | # adapter configs
61 | d_model = 256
62 | adapter_dict = dict(
63 | adapter_type='text-identity',
64 | in_dim=512,
65 | d_model=d_model,
66 | num_heads=d_model // 64,
67 | ffn_dim=d_model * 4,
68 | norm_first=True,
69 | num_layers=2,
70 | residual=0.95,
71 | )
72 |
73 | # loss configs
74 | loss_dict = dict(
75 | use_logits_loss=True, # CE over mean logits
76 | use_probs_loss=False, # CE over mean probs
77 | )
78 |
79 | ce_loss_w = 1.
80 |
81 | # save the model with the highest acc
82 | ckp_monitor = 'val/probs_acc'
83 | ckp_monitor_type = 'max' # 'max' or 'min'
84 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | vis/
2 | wandb/
3 | sbatch/
4 | pretrained/
5 | checkpoint/
6 | checkpoints/
7 | data/*/
8 | data/N-Caltech101
9 | data/N-Cars
10 | data/N_Imagenet
11 | data/pseudo-N_Imagenet
12 | .idea/
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | pip-wheel-metadata/
37 | share/python-wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .nox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | *.py,cover
64 | .hypothesis/
65 | .pytest_cache/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
--------------------------------------------------------------------------------
/models/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Adapter(nn.Module):
6 | """Base Adapter class.
7 |
8 | Handle common operations such as residual connection.
9 | """
10 |
11 | def __init__(self, residual=True):
12 | super().__init__()
13 |
14 | assert isinstance(residual, (bool, float))
15 | if isinstance(residual, bool):
16 | residual = 0.5 if residual else 0.
17 | if isinstance(residual, float):
18 | assert 0. <= residual <= 1.
19 |
20 | self.residual = residual
21 |
22 | def residual_add(self, in_feats, new_feats):
23 | """Perform residual connection."""
24 | assert isinstance(self.residual, float)
25 | return in_feats * self.residual + new_feats * (1. - self.residual)
26 |
27 | def forward(self, *args, **kwargs):
28 | raise NotImplementedError
29 |
30 | @property
31 | def dtype(self):
32 | raise NotImplementedError
33 |
34 |
35 | class IdentityAdapter(Adapter):
36 | """Trivial Adapter that does nothing."""
37 |
38 | def __init__(self, *args, **kwargs):
39 | super().__init__(residual=False)
40 |
41 | # dummy parameter to record the dtype & device
42 | self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)
43 |
44 | def forward(self, feats, valid_masks):
45 | """feats: [B, num_views, C]."""
46 | return feats
47 |
48 | @property
49 | def dtype(self):
50 | return self.dummy.dtype
51 |
52 |
53 | class TransformerAdapter(Adapter):
54 | """Transformer Adapter which is order-invariant."""
55 |
56 | def __init__(
57 | self,
58 | in_dim,
59 | d_model=256,
60 | num_heads=4,
61 | ffn_dim=256 * 4,
62 | norm_first=True,
63 | num_layers=2,
64 | residual=False,
65 | ):
66 | super().__init__(residual=residual)
67 |
68 | self.d_model = d_model
69 | enc_layer = nn.TransformerEncoderLayer(
70 | d_model=d_model,
71 | nhead=num_heads,
72 | dim_feedforward=ffn_dim,
73 | norm_first=norm_first,
74 | batch_first=True,
75 | )
76 | self.transformer_encoder = nn.TransformerEncoder(
77 | encoder_layer=enc_layer, num_layers=num_layers)
78 |
79 | self.in_proj = nn.Linear(in_dim, d_model)
80 | self.out_proj = nn.Linear(d_model, in_dim)
81 |
82 | def forward(self, feats, valid_masks):
83 | """Inter-view interaction via Attention.
84 |
85 | Args:
86 | feats: [B, num_views, C]
87 | valid_masks: [B, num_views], True for valid views.
88 | Should mask the Attention in Transformer accordingly.
89 | """
90 | in_feats = feats
91 |
92 | # [B, num_views, d_model]
93 | feats = self.in_proj(feats)
94 |
95 | # [B, num_views, d_model]
96 | pad_masks = (~valid_masks) # True --> padded
97 | feats = self.transformer_encoder(feats, src_key_padding_mask=pad_masks)
98 |
99 | # [B, num_views, C]
100 | feats = self.out_proj(feats)
101 |
102 | # residual connection
103 | feats = self.residual_add(in_feats, feats)
104 |
105 | return feats
106 |
107 | @property
108 | def dtype(self):
109 | return self.in_proj.weight.dtype
110 |
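A minimal usage sketch for the `TransformerAdapter` above (a sketch only: the 512-dim feature size matches the `in_dim` used by this repo's configs, while the batch/view sizes and `residual=0.8` are purely illustrative):

```python
# Minimal sketch of using the TransformerAdapter defined above.
import torch

from models.adapter import TransformerAdapter

B, num_views, C = 4, 2, 512
adapter = TransformerAdapter(
    in_dim=C,
    d_model=256,
    num_heads=4,
    ffn_dim=256 * 4,
    norm_first=True,
    num_layers=2,
    residual=0.8,  # keep 80% of the input feature, mix in 20% of the new one
)

feats = torch.randn(B, num_views, C)           # per-view image features
valid_masks = torch.ones(B, num_views).bool()  # True = valid (non-padded) view

out = adapter(feats, valid_masks)              # [B, num_views, C]
```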
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | from .caltech import NCaltech101
6 |
7 |
8 | def load_event(event_path):
9 | """Load event data from npz file."""
10 | event = np.load(event_path)['event_data']
11 | event = np.stack([
12 | event['x'],
13 | event['y'],
14 | event['t'],
15 | event['p'].astype(np.uint8),
16 | ], 1) # [N, 4]
17 |
18 | event = event.astype(float)
19 |
20 | # Account for int-type timestamp (convert us to s)
21 | event[:, 2] /= 1e6
22 |
23 | # Account for zero polarity: map {0, 1} polarities to {-1, +1}
24 | if event[:, 3].min() >= -0.5:
25 | event[:, 3][event[:, 3] <= 0.5] = -1
26 |
27 | return event
28 |
29 |
30 | class NImageNet(NCaltech101):
31 | """Dataset class for N-ImageNet dataset."""
32 |
33 | def __init__(
34 | self,
35 | root,
36 | augmentation=False,
37 | num_shots=None,
38 | ):
39 | super().__init__(
40 | root=root,
41 | augmentation=augmentation,
42 | num_shots=num_shots,
43 | repeat=False,
44 | new_cnames=None,
45 | )
46 |
47 | # data stats
48 | self.resolution = (480, 640)
49 | self.max_t = 0.055 # max
50 | self.max_n = 135000 # 95th percentile
51 |
52 | # data augmentation
53 | self.flip_time = True
54 |
55 | # load folder name to class name mapping
56 | cur_dir = os.path.dirname(os.path.abspath(__file__))
57 | label_map = os.path.join(cur_dir, 'files/CLIP-IN_ClassNames.txt')
58 | with open(label_map, 'r') as f:
59 | lines = f.readlines()[:1000]
60 | lines = [line.strip() for line in lines]
61 | """
62 | n01440764 tench
63 | n01443537 goldfish
64 | n01484850 great white shark
65 | n01491361 tiger shark
66 | n01494475 hammerhead shark
67 | """
68 | folder2name = {
69 | s.split(' ')[0]: ' '.join(s.split(' ')[1:])
70 | for s in lines
71 | }
72 | self.folder2name = folder2name
73 | self.name2folder = {v: k for k, v in folder2name.items()}
74 | self.classes = [folder2name[c] for c in self.classes]
75 |
76 | @staticmethod
77 | def _load_events(event_path):
78 | """Load events from a file."""
79 | return load_event(event_path).astype(np.float32)
80 |
81 |
82 | def build_n_imagenet_dataset(
83 | params,
84 | val_only=False,
85 | gen_data=False,
86 | subset=-1,
87 | ):
88 | """Build the N-ImageNet dataset."""
89 | val_names = {
90 | 1: 'val_mode_1',
91 | 2: 'val_mode_5',
92 | 3: 'val_mode_6',
93 | 4: 'val_mode_7',
94 | 5: 'val_mode_3',
95 | 6: 'val_brightness_4',
96 | 7: 'val_brightness_5',
97 | 8: 'val_brightness_6',
98 | 9: 'val_brightness_7',
99 | }
100 | if subset > 0:
101 | val_set = val_names[subset]
102 | val_root = os.path.join(params.data_root, f'extracted_{val_set}')
103 | print('Using N-ImageNet subset:', val_set)
104 | else:
105 | val_root = os.path.join(params.data_root, 'extracted_val')
106 | print('Using normal N-ImageNet val set')
107 |
108 | # only build the test set
109 | test_set = NImageNet(
110 | root=val_root,
111 | augmentation=False,
112 | )
113 | if val_only:
114 | assert not gen_data
115 | return test_set
116 | # build the training set for pseudo label generation
117 | if gen_data:
118 | return NImageNet(
119 | root=os.path.join(params.data_root, 'extracted_train'),
120 | augmentation=False,
121 | )
122 |
123 | # build the training set
124 | train_set = NImageNet(
125 | root=os.path.join(params.data_root, 'extracted_train'),
126 | augmentation=True,
127 | num_shots=params.get('num_shots', None),
128 | )
129 | return train_set, test_set
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EventCLIP
2 |
3 | [**EventCLIP: Adapting CLIP for Event-based Object Recognition**](https://github.com/Wuziyi616/EventCLIP)
4 | [Ziyi Wu](https://wuziyi616.github.io/),
5 | [Xudong Liu](https://www.linkedin.com/in/xudong-frank-liu-566513198/),
6 | [Igor Gilitschenski](https://tisl.cs.utoronto.ca/author/igor-gilitschenski/)
7 | _[arXiv'23](https://arxiv.org/abs/2306.06354) |
8 | [GitHub](https://github.com/Wuziyi616/EventCLIP) |
9 | [arXiv](https://arxiv.org/abs/2306.06354)_
10 |
11 | ## Introduction
12 |
13 | This is the official PyTorch implementation for paper: [EventCLIP: Adapting CLIP for Event-based Object Recognition](https://arxiv.org/abs/2306.06354).
14 | The code contains:
15 |
16 | - Zero-shot EventCLIP inference on N-Caltech, N-Cars, N-ImageNet datasets
17 | - Few-shot adaptation of EventCLIP on the three datasets, with SOTA results in the low-data regime
18 | - Data-efficient fine-tuning of EventCLIP on N-Caltech & N-ImageNet, achieving superior accuracy over fully-trained baselines
19 | - Learning from unlabeled data on N-Caltech & N-ImageNet, including both fully unsupervised and semi-supervised settings
20 |
21 | ### Motivation
22 |
23 | [Event cameras](https://tub-rip.github.io/eventvision2023/#null) are bio-inspired, low-latency, and energy-efficient sensors that have gained significant interest recently.
24 | However, due to the lack of large-scale datasets, the event-based vision community cannot enjoy the recent success of foundation models in RGB vision.
25 | This paper thus seeks to adapt one of the most impactful VLMs, [CLIP](https://openai.com/research/clip), to recognize event data.
26 | We study common practices in data-efficient model adaptation and propose a general framework named EventCLIP.
27 | The overall pipeline is shown below:
28 |
29 |

30 |
31 | ## Update
32 |
33 | - 2023.9.14: Release code for learning with unlabeled data
34 | - 2023.7.17: Release fine-tuning code
35 | - 2023.5.17: Initial code release!
36 |
37 | ## Installation
38 |
39 | Please refer to [install.md](docs/install.md) for step-by-step guidance on how to install the packages.
40 |
41 | ## Experiments
42 |
43 | **This codebase is tailored to [Slurm](https://slurm.schedmd.com/documentation.html) GPU clusters with a preemption mechanism.**
44 | For the configs, we mainly use A40 GPUs with 40GB memory (though many experiments don't require that much memory).
45 | Please modify the code accordingly if you are using other hardware settings:
46 |
47 | - Please go through `train.py` and change the fields marked by `TODO:`
48 | - Please read the config file for the model you want to train.
49 | We use DDP with multiple GPUs to accelerate training.
50 | You can use fewer GPUs to achieve a better memory-speed trade-off
51 |
52 | ### Dataset Preparation
53 |
54 | Please refer to [data.md](docs/data.md) for dataset downloading and pre-processing.
55 |
56 | ### Reproduce Results
57 |
58 | Please see [benchmark.md](docs/benchmark.md) for detailed instructions on how to reproduce our results in the paper.
59 |
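As a quick example of the command format (a sketch only; [benchmark.md](docs/benchmark.md) is the authoritative reference, and the config path and flags below are just one valid combination from this repo), a single-GPU few-shot run on N-Caltech101 can be launched with:

```
python train.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --num_shots 5 --fp16
```

On Slurm clusters, `./scripts/sbatch_run.sh` wraps the same command (see the example usage in its header comment).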
60 | ## Possible Issues
61 |
62 | See the troubleshooting section of [nerv](https://github.com/Wuziyi616/nerv#possible-issues) for potential issues.
63 |
64 | Please open an issue if you encounter any errors running the code.
65 |
66 | ## Citation
67 |
68 | Please cite our paper if you find it useful in your research:
69 |
70 | ```
71 | @article{wu2023eventclip,
72 | title={{EventCLIP}: Adapting CLIP for Event-based Object Recognition},
73 | author={Wu, Ziyi and Liu, Xudong and Gilitschenski, Igor},
74 | journal={arXiv preprint arXiv:2306.06354},
75 | year={2023}
76 | }
77 | ```
78 |
79 | ## Acknowledgement
80 |
81 | We thank the authors of [CLIP](https://github.com/openai/CLIP), [EST](https://github.com/uzh-rpg/rpg_event_representation_learning), [n_imagenet](https://github.com/82magnolia/n_imagenet), [PointCLIP](https://github.com/ZrrSkywalker/PointCLIP), [LoRA](https://github.com/microsoft/LoRA) for open-sourcing their wonderful works.
82 |
83 | ## License
84 |
85 | EventCLIP is released under the MIT License. See the LICENSE file for more details.
86 |
87 | ## Contact
88 |
89 | If you have any questions about the code, please contact Ziyi Wu at dazitu616@gmail.com.
90 |
--------------------------------------------------------------------------------
/datasets/vis.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 |
5 |
6 | def make_event_histogram(x, y, p, red, blue, shape, thresh=10., **kwargs):
7 | """Event polarity histogram."""
8 | # count the number of positive and negative events per pixel
9 | H, W = shape
10 | pos_x, pos_y = x[p > 0].astype(np.int32), y[p > 0].astype(np.int32)
11 | pos_count = np.bincount(pos_x + pos_y * W, minlength=H * W).reshape(H, W)
12 | neg_x, neg_y = x[p < 0].astype(np.int32), y[p < 0].astype(np.int32)
13 | neg_count = np.bincount(neg_x + neg_y * W, minlength=H * W).reshape(H, W)
14 | hist = np.stack([pos_count, neg_count], axis=-1) # [H, W, 2]
15 |
16 | # remove hotpixels, i.e. pixels with event num > thresh * std + mean
17 | if thresh > 0:
18 | if kwargs.get('count_non_zero', False):
19 | mean = hist[hist > 0].mean()
20 | std = hist[hist > 0].std()
21 | else:
22 | mean = hist.mean()
23 | std = hist.std()
24 | hist[hist > thresh * std + mean] = 0
25 |
26 | # normalize
27 | hist = hist.astype(np.float32) / hist.max() # [H, W, 2]
28 |
29 | # colorize
30 | cmap = np.stack([red, blue], axis=0).astype(np.float32) # [2, 3]
31 | img = hist @ cmap # [H, W, 3]
32 |
33 | # alpha-masking with pure white background
34 | if kwargs.get('background_mask', True):
35 | weights = np.clip(hist.sum(-1, keepdims=True), a_min=0, a_max=1)
36 | background = np.ones_like(img) * 255.
37 | img = img * weights + background * (1. - weights)
38 |
39 | img = np.round(img).astype(np.uint8) # [H, W, 3], np.uint8 in (0, 255)
40 |
41 | return img
42 |
43 |
44 | def parse_events(events):
45 | """Read (x,y,t,p) from input events (can be np.array or dict)."""
46 | if isinstance(events, dict):
47 | x, y, t, p = events['x'], events['y'], events['t'], events['p']
48 | else:
49 | x, y, t, p = events[:, 0], events[:, 1], events[:, 2], events[:, 3]
50 | x, y, p = x.astype(np.int32), y.astype(np.int32), p.astype(np.int32)
51 | t_us = t * 1e6 # convert to us unit
52 | return x, y, t_us, p
53 |
54 |
55 | def split_event_count(t, N=30000):
56 | """Split the events according to event count."""
57 | tot_cnt = len(t)
58 |
59 | # if the event count is too small, just return the whole chunk
60 | if tot_cnt < N:
61 | return [0], [tot_cnt], [t[0]], [t[-1]]
62 |
63 | # find the start and end time of each event chunk w.r.t event index
64 | idx = np.arange(0, tot_cnt, N).tolist()
65 | idx1, idx0 = idx[1:], idx[:-1]
66 | # add the last index if the last chunk of events is not so small
67 | if tot_cnt - idx[-1] > N * 0.5:
68 | idx0.append(tot_cnt - N)
69 | idx1.append(tot_cnt)
70 | t0, t1 = t[idx0], t[np.array(idx1) - 1]
71 |
72 | return idx0, idx1, t0, t1
73 |
74 |
75 | def events2frames(
76 | events, # [N, 4 (x,y,t,p)]
77 | split_method, # 'event_count'
78 | convert_method, # 'event_histogram'
79 | shape=(180, 240),
80 | **kwargs,
81 | ):
82 | """Convert events to 2D frames."""
83 | # some additional arguments
84 | grayscale = kwargs.pop('grayscale', True) # True, False
85 |
86 | # parse different input formats
87 | x, y, t, p = parse_events(events)
88 |
89 | # split the events into different chunks
90 | assert split_method == 'event_count'
91 | N = int(kwargs['N'])
92 | idx0, idx1, t0, t1 = split_event_count(t, N)
93 |
94 | # color map for pos and neg events
95 | if grayscale:
96 | if isinstance(grayscale, bool):
97 | v = 127
98 | else:
99 | v = np.array(grayscale) # custom grayscale value(s) instead of the default 127
100 | red = np.round(np.ones(3) * v).astype(np.uint8)
101 | blue = np.round(np.ones(3) * v).astype(np.uint8)
102 | else:
103 | red = np.array([255, 0, 0], dtype=np.uint8)
104 | blue = np.array([0, 0, 255], dtype=np.uint8)
105 |
106 | frames = []
107 | for t_idx, (i0, i1) in enumerate(zip(idx0, idx1)):
108 | xx, yy, pp, tt = x[i0:i1], y[i0:i1], p[i0:i1], t[i0:i1]
109 | if convert_method == 'event_histogram':
110 | frame = make_event_histogram(xx, yy, pp, red, blue, shape,
111 | **kwargs)
112 | else:
113 | raise NotImplementedError(f'{convert_method} not implemented!')
114 | frames.append(copy.deepcopy(frame))
115 | frames = np.stack(frames) # [N, H, W, 3]
116 |
117 | return frames # [N, H, W, 3]
118 |
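A minimal sketch of calling `events2frames` above on synthetic events (the `[N, 4] = (x, y, t, p)` layout and argument names follow this file; the random events and the `N=30000` chunk size are illustrative):

```python
# Minimal sketch: convert a synthetic event stream to 2D frames.
import numpy as np

from datasets import events2frames

H, W, n = 180, 240, 100000  # N-Caltech-like resolution
events = np.stack([
    np.random.randint(0, W, n),    # x
    np.random.randint(0, H, n),    # y
    np.sort(np.random.rand(n)),    # t in seconds (converted to us internally)
    np.random.choice([-1, 1], n),  # p in {-1, +1}
], axis=1).astype(np.float32)

frames = events2frames(
    events,
    split_method='event_count',
    convert_method='event_histogram',
    shape=(H, W),
    N=30000,          # events per frame chunk
    grayscale=False,  # red/blue polarity coloring instead of gray
)
print(frames.shape)   # (num_chunks, 180, 240, 3), dtype uint8
```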
--------------------------------------------------------------------------------
/scripts/sbatch_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # SBATCH file can't directly take command args
4 | # as a workaround, I first use a sh script to read in args
5 | # and then create a new .slrm file for SBATCH execution
6 |
7 | #######################################################################
8 | # An example usage:
9 | # GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=scavenger ./scripts/sbatch_run.sh rtx6000 \
10 | # test-sbatch ./train.py ddp --params params.py --fp16 --ddp --cudnn
11 | #######################################################################
12 |
13 | # read args from command line
14 | GPUS=${GPUS:-1}
15 | CPUS_PER_GPU=${CPUS_PER_GPU:-8}
16 | MEM_PER_CPU=${MEM_PER_CPU:-5}
17 | QOS=${QOS:-scavenger}
18 | TIME=${TIME:-96:00:00}
19 | if [[ $QOS == "cpu" ]]; then
20 | QOS="cpu_qos"
21 | GPUS=0
22 | CPUS_PER_TASK=$CPUS_PER_GPU
23 | else
24 | CPUS_PER_TASK=$((GPUS * CPUS_PER_GPU))
25 | fi
26 |
27 | # python args
28 | PY_ARGS=${@:5}
29 | PARTITION=$1
30 | JOB_NAME=$2
31 | PY_FILE=$3
32 | DDP=$4
33 |
34 | # create log files
35 | SLRM_NAME="${JOB_NAME/\//"_"}"
36 | LOG_DIR=checkpoint/"$(basename -- $JOB_NAME)"
37 | DATETIME=$(date "+%Y-%m-%d_%H:%M:%S")
38 | LOG_FILE=$LOG_DIR/${DATETIME}.log
39 | SLRM_LOG="${LOG_DIR}/slurm.log"
40 |
41 | # set up log output folder
42 | mkdir -p $LOG_DIR
43 |
44 | # create new .slrm file
45 | slrm_file="run-${SLRM_NAME}.slrm"
46 |
47 | # python runner for DDP
48 | if [[ $DDP == "ddp" ]]; then
49 | PORT=$((29501 + $RANDOM % 100)) # randomly select a port
50 | PYTHON="python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT"
51 | else
52 | PYTHON="python"
53 | fi
54 |
55 | # get the max possible time limit from QOS
56 | # please refer to https://support.vectorinstitute.ai/Vaughan_slurm_changes
57 | get_time() {
58 | local req_time="$1"
59 | if [[ $req_time == "0" ]]; then
60 | req_time="96:00:00"
61 | fi
62 | local qos="$2" # make it lower case
63 | qos="${qos,,}"
64 | if [[ $qos == "cpu_qos" ]]; then
65 | max_time="96:00:00"
66 | elif [[ $qos == "normal" ]]; then
67 | max_time="16:00:00"
68 | elif [[ $qos == "m" ]]; then
69 | max_time="12:00:00"
70 | elif [[ $qos == "m2" ]]; then
71 | max_time="08:00:00"
72 | elif [[ $qos == "m3" ]]; then
73 | max_time="04:00:00"
74 | elif [[ $qos == "m4" ]]; then
75 | max_time="02:00:00"
76 | elif [[ $qos == "m5" ]]; then
77 | max_time="01:00:00"
78 | elif [[ $qos == "long" ]]; then
79 | max_time="48:00:00"
80 | elif [[ $qos == "deadline" ]]; then
81 | max_time="00:00:00"
82 | elif [[ $qos == "high" ]]; then
83 | max_time="08:00:00"
84 | elif [[ $qos == "scavenger" ]]; then
85 | max_time="96:00:00"
86 | else
87 | echo "Invalid QOS $qos"
88 | return # this will trigger `Invalid --time specification` and fail the job
89 | fi
90 | # return the smaller one
91 | # Remove colons and compare as numbers
92 | num_req_time=$(echo "${req_time//:/}" | sed 's/^0*//')
93 | num_max_time=$(echo "${max_time//:/}" | sed 's/^0*//')
94 | if [[ $num_req_time -lt $num_max_time ]]; then
95 | echo $req_time
96 | else
97 | echo $max_time
98 | fi
99 | }
100 | TIME=$(get_time $TIME $QOS)
101 | echo "Run with QOS=$QOS, TIME=$TIME"
102 |
103 | # write to new file
104 | echo "#!/bin/bash
105 |
106 | # set up SBATCH args
107 | #SBATCH --job-name=$SLRM_NAME
108 | #SBATCH --output=$LOG_FILE
109 | #SBATCH --error=$LOG_FILE
110 | #SBATCH --open-mode=append
111 | #SBATCH --partition=$PARTITION # self-explanatory, set to your preference (e.g. gpu or cpu on MaRS, p100, t4, or cpu on Vaughan)
112 | #SBATCH --cpus-per-task=$CPUS_PER_TASK # self-explanatory, set to your preference
113 | #SBATCH --ntasks=1
114 | #SBATCH --ntasks-per-node=1
115 | #SBATCH --mem-per-cpu=${MEM_PER_CPU}G # self-explanatory, set to your preference
116 | #SBATCH --gres=gpu:$GPUS # NOTE: you need a GPU for CUDA support; self-explanatory, set to your preference
117 | #SBATCH --nodes=1
118 | #SBATCH --qos=$QOS # for 'high' and 'deadline' QoS, refer to https://support.vectorinstitute.ai/AboutVaughan2
119 | #SBATCH --time=$TIME # running time limit, 0 as unlimited
120 |
121 | # log some necessary environment params
122 | echo \$SLURM_JOB_ID >> $LOG_FILE # log the job id
123 | echo \$SLURM_JOB_PARTITION >> $LOG_FILE # log the job partition
124 |
125 | echo $CONDA_PREFIX >> $LOG_FILE # log the active conda environment
126 |
127 | python --version >> $LOG_FILE # log Python version
128 | gcc --version >> $LOG_FILE # log GCC version
129 | nvcc --version >> $LOG_FILE # log NVCC version
130 |
131 | # run python file
132 | $PYTHON $PY_FILE $PY_ARGS >> $LOG_FILE # the script above, with its standard output appended to the log file
133 |
134 | " >> ./$slrm_file
135 |
136 | # run the created file
137 | job_id=$(sbatch --parsable $slrm_file)
138 | echo "Submitted batch job $job_id"
139 |
140 | sleep 0.5
141 | if [[ $job_id ]]; then # successfully submitted the job
142 | ./scripts/resubmit_failed_job.sh $job_id $SLRM_NAME $SLRM_LOG
143 | else # failed to submit the job
144 | rm -f run-${SLRM_NAME}.slrm
145 | echo "Failed to submit job $SLRM_NAME"
146 | fi
147 |
--------------------------------------------------------------------------------
/datasets/imagenet_mini.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .caltech import get_real_path
4 | from .imagenet import NImageNet
5 |
6 | # N-ImageNet (Mini) subset, taken from https://arxiv.org/pdf/2308.09383.pdf
7 | # Note that this is slightly different from the Mini-ImageNet used in e.g. MAML
8 | MINI_NAMES = [
9 | "hamster", "academic gown", "airship", "jackfruit", "barbershop",
10 | "cocktail shaker", "Komodo dragon", "sunglasses", "grey fox", "cello",
11 | "comic book", "goldfish", "Bloodhound", "porcupine", "jaguar", "kingsnake",
12 | "altar", "water buffalo", "chiton", "scarf", "storage chest", "tool kit",
13 | "sea anemone", "Border Terrier", "menu", "picket fence", "forklift",
14 | "yellow lady's slipper", "chameleon", "dragonfly", "Pomeranian",
15 | "European garden spider", "Airedale Terrier", "frilled-necked lizard",
16 | "black stork", "valley", "radio telescope", "leopard", "crossword",
17 | "Australian Terrier", "Shih Tzu", "husky", "can opener", "artichoke",
18 | "assault rifle", "fountain pen", "harvestman", "parallel bars",
19 | "harmonica", "half-track", "snoek fish", "pencil sharpener", "submarine",
20 | "muzzle", "eastern diamondback rattlesnake", "Miniature Schnauzer",
21 | "missile", "Komondor", "grand piano", "website", "king penguin", "canoe",
22 | "red-breasted merganser", "trolleybus", "quail", "poke bonnet",
23 | "King Charles Spaniel", "race car", "Malinois", "solar thermal collector",
24 | "slug", "bucket", "dung beetle", "Asian elephant", "window screen",
25 | "Flat-Coated Retriever", "steel drum", "snowplow", "handkerchief",
26 | "tailed frog", "church", "Chesapeake Bay Retriever", "Christmas stocking",
27 | "hatchet", "hair clip", "vulture", "sidewinder rattlesnake",
28 | "oscilloscope", "worm snake", "eel", "wok", "planetarium",
29 | "Old English Sheepdog", "platypus", "Pembroke Welsh Corgi",
30 | "alligator lizard", "consomme", "African rock python", "hot tub",
31 | "Tibetan Mastiff"
32 | ]
33 |
34 |
35 | class NImageNetMini(NImageNet):
36 | """Dataset class for N-ImageNet (Mini) dataset."""
37 |
38 | def __init__(
39 | self,
40 | root,
41 | augmentation=False,
42 | num_shots=None,
43 | repeat=True,
44 | ):
45 | root = get_real_path(root)
46 | self.root = root
47 | # TODO: a hack for identifying generated pseudo labeled datasets
48 | self.is_pseudo = 'pseudo' in root
49 | if self.is_pseudo:
50 | print('Using pseudo labeled dataset!')
51 |
52 | # data stats
53 | self.resolution = (480, 640)
54 | self.max_t = 0.055 # max
55 | self.max_n = 135000 # 95th percentile
56 |
57 | # data augmentation
58 | self.augmentation = augmentation
59 | self.flip_time = True
60 | self.max_shift = 20
61 |
62 | # load folder name to class name mapping
63 | cur_dir = os.path.dirname(os.path.abspath(__file__))
64 | label_map = os.path.join(cur_dir, 'files/CLIP-IN_ClassNames.txt')
65 | with open(label_map, 'r') as f:
66 | lines = f.readlines()[:1000]
67 | lines = [line.strip() for line in lines]
68 | """
69 | n01440764 tench
70 | n01443537 goldfish
71 | n01484850 great white shark
72 | n01491361 tiger shark
73 | n01494475 hammerhead shark
74 | """
75 |
76 | folder2name = {
77 | s.split(' ')[0]: ' '.join(s.split(' ')[1:])
78 | for s in lines
79 | }
80 | # only take a subset of 100 classes
81 | folder2name = {k: v for k, v in folder2name.items() if v in MINI_NAMES}
82 | assert len(folder2name) == 100 == len(MINI_NAMES)
83 | self.classes = list(folder2name.keys())
84 | self.folder2name = folder2name
85 | self.name2folder = {v: k for k, v in folder2name.items()}
86 |
87 | # few-shot cls
88 | self.num_shots = num_shots # number of labeled data per class
89 | self.few_shot = (num_shots is not None and num_shots > 0)
90 | if self.few_shot:
91 | assert 'train' in root.lower(), 'Only sample data in training set'
92 | self.repeat = repeat
93 |
94 | self.labeled_files, self.labels = self._get_sample_idx()
95 | assert len(self.labeled_files) == len(self.labels)
96 |
97 | # finally, get semantically meaningful class names
98 | self.classes = [folder2name[c] for c in self.classes]
99 | assert all(c in self.classes for c in MINI_NAMES) and \
100 | len(self.classes) == 100
101 | self.new_cnames = None
102 |
103 |
104 | def build_n_imagenet_mini_dataset(params, val_only=False, gen_data=False):
105 | """Build the N-ImageNet (Mini) dataset."""
106 | # only build the test set
107 | test_set = NImageNetMini(
108 | root=os.path.join(params.data_root, 'extracted_val'),
109 | augmentation=False,
110 | )
111 | if val_only:
112 | assert not gen_data, 'Only generate pseudo labels on the training set'
113 | return test_set
114 | # build the training set for pseudo label generation
115 | if gen_data:
116 | return NImageNetMini(
117 | root=os.path.join(params.data_root, 'extracted_train'),
118 | augmentation=False,
119 | )
120 |
121 | # build the training set
122 | train_set = NImageNetMini(
123 | root=os.path.join(params.data_root, 'extracted_train'),
124 | augmentation=True,
125 | num_shots=params.get('num_shots', None),
126 | repeat=params.get('repeat_data', True),
127 | )
128 | return train_set, test_set
129 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """EventCLIP training script"""
2 |
3 | import os
4 | import sys
5 | import pwd
6 | import importlib
7 | import argparse
8 | import wandb
9 |
10 | import torch
11 |
12 | import clip
13 |
14 | from nerv.utils import mkdir_or_exist
15 | from nerv.training import BaseDataModule, find_old_slurm_id
16 |
17 | from models import build_model
18 | from method import build_method
19 | from datasets import build_dataset
20 |
21 |
22 | def main(params):
23 | # have to load CLIP model first
24 | arch = params.clip_dict.pop('arch')
25 | device = 'cuda'
26 | model, preprocess = clip.load(arch, device=device)
27 | # cast weights to FP32
28 | for p in model.parameters():
29 | p.data = p.data.float()
30 |
31 | # build dataset
32 | params.data_transforms = preprocess
33 | train_set, val_set = build_dataset(params)
34 | datamodule = BaseDataModule(
35 | params, train_set=train_set, val_set=val_set, use_ddp=params.ddp)
36 |
37 | # build model
38 | params.clip_dict['clip_model'] = model
39 | params.clip_dict['class_names'] = train_set.classes
40 | params.resolution = train_set.resolution
41 | params.class_names = train_set.classes
42 | params.adapter_dict['in_dim'] = model.visual.output_dim
43 | model = build_model(params)
44 |
45 | # create checkpoint dir
46 | exp_name = os.path.basename(args.params)
47 | ckp_path = os.path.join('checkpoint', exp_name, 'models')
48 | if args.local_rank == 0:
49 | mkdir_or_exist(os.path.dirname(ckp_path))
50 |
51 | # on clusters, quota under user dir is usually limited
52 | # soft link to save the weights in temp space for checkpointing
53 | # e.g. on our cluster, the temp dir is /checkpoint/$USR/$SLURM_JOB_ID/
54 | # TODO: modify this if you are not running on clusters
55 | SLURM_JOB_ID = os.environ.get('SLURM_JOB_ID')
56 | if os.path.exists(ckp_path):
57 | SLURM_JOB_ID = find_old_slurm_id(ckp_path)
58 | else:
59 | if SLURM_JOB_ID:
60 | os.system(r'ln -s /checkpoint/{}/{}/ {}'.format(
61 | pwd.getpwuid(os.getuid())[0], SLURM_JOB_ID, ckp_path))
62 | else:
63 | os.makedirs(ckp_path, exist_ok=True)
64 |
65 | # it's not good to hard-code the wandb id
66 | # but on preemption clusters, we want the job to resume the same wandb
67 | # process after resuming training (i.e. drawing the same graph)
68 | # so we have to keep the same wandb id
69 | # TODO: modify this if you are not running on preemption clusters
70 | preemption = True
71 | if SLURM_JOB_ID and preemption:
72 | logger_id = logger_name = f'{exp_name}-{SLURM_JOB_ID}'
73 | else:
74 | logger_name = exp_name
75 | logger_id = None
76 |
77 | wandb.init(
78 | project=params.project,
79 | name=logger_name,
80 | id=logger_id,
81 | dir=ckp_path,
82 | )
83 |
84 | method = build_method(
85 | model=model,
86 | datamodule=datamodule,
87 | params=params,
88 | ckp_path=ckp_path,
89 | local_rank=args.local_rank,
90 | use_ddp=args.ddp,
91 | use_fp16=args.fp16,
92 | )
93 |
94 | method.fit(
95 | resume_from=args.weight, san_check_val_step=params.san_check_val_step)
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = argparse.ArgumentParser(description='EventCLIP')
100 | parser.add_argument('--params', type=str, required=True)
101 | parser.add_argument('--num_shots', type=int, default=-1)
102 | parser.add_argument('--N', type=int, default=-1)
103 | parser.add_argument('--weight', type=str, default='', help='load weight')
104 | parser.add_argument('--fp16', action='store_true')
105 | parser.add_argument('--ddp', action='store_true')
106 | parser.add_argument('--cudnn', action='store_true')
107 | parser.add_argument('--local_rank', type=int, default=0)
108 | parser.add_argument('--local-rank', type=int, default=0)
109 | args = parser.parse_args()
110 |
111 | if args.params.endswith('.py'):
112 | args.params = args.params[:-3]
113 | sys.path.append(os.path.dirname(args.params))
114 | params = importlib.import_module(os.path.basename(args.params))
115 | params = params.EventCLIPParams()
116 | params.ddp = args.ddp
117 |
118 | assert params.model != 'ZSCLIP', \
119 | 'zero-shot EventCLIP does not require training'
120 |
121 | if args.N > 0:
122 | params.quantize_args['N'] = int(args.N * 1000)
123 | args.params = args.params + f'-N_{args.N}'
124 |
125 | if args.num_shots > 0:
126 | params.num_shots = args.num_shots
127 | args.params = args.params + f'-{args.num_shots}shot'
128 |
129 | # adjust the batch size since N-Cars only has 2 classes
130 | if params.dataset == 'n_cars':
131 | params.train_batch_size = min(
132 | params.num_shots * 2, # 2 classes
133 | params.train_batch_size)
134 | print(f'Set {params.train_batch_size=} for N-Cars')
135 | if params.dataset == 'n_imagenet_mini':
136 | params.train_batch_size = min(
137 | params.num_shots * 100, # 100 classes
138 | params.train_batch_size)
139 | print(f'Set {params.train_batch_size=} for N-ImageNet (Mini)')
140 |
141 | if args.fp16:
142 | print('INFO: using FP16 training!')
143 | if args.ddp:
144 | print('INFO: using DDP training!')
145 | if args.cudnn:
146 | torch.backends.cudnn.benchmark = True
147 | print('INFO: using cudnn benchmark!')
148 |
149 | main(params)
150 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: eventclip
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - brotlipy=0.7.0=py39h27cfd23_1003
10 | - bzip2=1.0.8=h7b6447c_0
11 | - ca-certificates=2023.08.22=h06a4308_0
12 | - certifi=2023.7.22=py39h06a4308_0
13 | - cffi=1.15.1=py39h5eee18b_3
14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
15 | - cryptography=41.0.3=py39hdda0065_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - ffmpeg=4.3=hf484d3e_0
18 | - freetype=2.12.1=h4a9f257_0
19 | - giflib=5.2.1=h5eee18b_3
20 | - gmp=6.2.1=h295c915_3
21 | - gnutls=3.6.15=he1e5248_0
22 | - idna=3.4=py39h06a4308_0
23 | - intel-openmp=2023.1.0=hdb19cb5_46305
24 | - jpeg=9e=h5eee18b_1
25 | - lame=3.100=h7b6447c_0
26 | - lcms2=2.12=h3be6417_0
27 | - ld_impl_linux-64=2.38=h1181459_1
28 | - lerc=3.0=h295c915_0
29 | - libdeflate=1.17=h5eee18b_0
30 | - libffi=3.4.4=h6a678d5_0
31 | - libgcc-ng=11.2.0=h1234567_1
32 | - libgomp=11.2.0=h1234567_1
33 | - libiconv=1.16=h7f8727e_2
34 | - libidn2=2.3.4=h5eee18b_0
35 | - libpng=1.6.39=h5eee18b_0
36 | - libstdcxx-ng=11.2.0=h1234567_1
37 | - libtasn1=4.19.0=h5eee18b_0
38 | - libtiff=4.5.1=h6a678d5_0
39 | - libunistring=0.9.10=h27cfd23_0
40 | - libwebp=1.3.2=h11a3e52_0
41 | - libwebp-base=1.3.2=h5eee18b_0
42 | - lz4-c=1.9.4=h6a678d5_0
43 | - mkl=2023.1.0=h213fc3f_46343
44 | - mkl-service=2.4.0=py39h5eee18b_1
45 | - mkl_fft=1.3.8=py39h5eee18b_0
46 | - mkl_random=1.2.4=py39hdb19cb5_0
47 | - ncurses=6.4=h6a678d5_0
48 | - nettle=3.7.3=hbbd107a_1
49 | - numpy=1.25.2=py39h5f9d8c6_0
50 | - numpy-base=1.25.2=py39hb5e798b_0
51 | - openh264=2.1.1=h4ff587b_0
52 | - openssl=3.0.11=h7f8727e_2
53 | - pillow=9.4.0=py39h6a678d5_1
54 | - pip=23.2.1=py39h06a4308_0
55 | - pycparser=2.21=pyhd3eb1b0_0
56 | - pyopenssl=23.2.0=py39h06a4308_0
57 | - pysocks=1.7.1=py39h06a4308_0
58 | - python=3.9.17=h955ad1f_0
59 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
60 | - pytorch-mutex=1.0=cuda
61 | - readline=8.2=h5eee18b_0
62 | - requests=2.31.0=py39h06a4308_0
63 | - setuptools=68.0.0=py39h06a4308_0
64 | - sqlite=3.41.2=h5eee18b_0
65 | - tbb=2021.8.0=hdb19cb5_0
66 | - tk=8.6.12=h1ccaba5_0
67 | - torchaudio=0.12.1=py39_cu113
68 | - torchvision=0.13.1=py39_cu113
69 | - typing_extensions=4.7.1=py39h06a4308_0
70 | - urllib3=1.26.16=py39h06a4308_0
71 | - wheel=0.38.4=py39h06a4308_0
72 | - xz=5.4.2=h5eee18b_0
73 | - zlib=1.2.13=h5eee18b_0
74 | - zstd=1.5.5=hc292b87_0
75 | - pip:
76 | - addict==2.4.0
77 | - aiohttp==3.8.5
78 | - aiosignal==1.3.1
79 | - ansi2html==1.8.0
80 | - appdirs==1.4.4
81 | - asttokens==2.4.0
82 | - async-timeout==4.0.3
83 | - attrs==23.1.0
84 | - backcall==0.2.0
85 | - beautifulsoup4==4.12.2
86 | - click==8.1.7
87 | - clip==1.0
88 | - comm==0.1.4
89 | - configargparse==1.7
90 | - contourpy==1.1.1
91 | - cycler==0.11.0
92 | - dash==2.13.0
93 | - dash-core-components==2.0.0
94 | - dash-html-components==2.0.0
95 | - dash-table==5.0.0
96 | - decorator==4.4.2
97 | - docker-pycreds==0.4.0
98 | - exceptiongroup==1.1.3
99 | - executing==1.2.0
100 | - fastjsonschema==2.18.0
101 | - filelock==3.12.4
102 | - flask==2.2.5
103 | - fonttools==4.42.1
104 | - frozenlist==1.4.0
105 | - fsspec==2023.9.2
106 | - ftfy==6.1.1
107 | - gdown==4.7.1
108 | - gitdb==4.0.10
109 | - gitpython==3.1.37
110 | - imageio==2.31.4
111 | - imageio-ffmpeg==0.4.9
112 | - importlib-metadata==6.8.0
113 | - importlib-resources==6.1.0
114 | - ipython==8.15.0
115 | - ipywidgets==8.1.1
116 | - itsdangerous==2.1.2
117 | - jedi==0.19.0
118 | - jinja2==3.1.2
119 | - joblib==1.3.2
120 | - jsonschema==4.19.1
121 | - jsonschema-specifications==2023.7.1
122 | - jupyter-core==5.3.1
123 | - jupyterlab-widgets==3.0.9
124 | - kiwisolver==1.4.5
125 | - lightning-utilities==0.9.0
126 | - markupsafe==2.1.3
127 | - matplotlib==3.8.0
128 | - matplotlib-inline==0.1.6
129 | - moviepy==1.0.3
130 | - multidict==6.0.4
131 | - nbformat==5.7.0
132 | - nest-asyncio==1.5.8
133 | - open3d==0.17.0
134 | - opencv-python==4.8.0.76
135 | - packaging==23.1
136 | - pandas==2.1.1
137 | - parso==0.8.3
138 | - pathtools==0.1.2
139 | - pexpect==4.8.0
140 | - pickleshare==0.7.5
141 | - platformdirs==3.10.0
142 | - plotly==5.17.0
143 | - proglog==0.1.10
144 | - prompt-toolkit==3.0.39
145 | - protobuf==4.24.3
146 | - psutil==5.9.5
147 | - ptyprocess==0.7.0
148 | - pure-eval==0.2.2
149 | - pygments==2.16.1
150 | - pyparsing==3.1.1
151 | - pyquaternion==0.9.9
152 | - python-dateutil==2.8.2
153 | - pytorch-lightning==1.8.6
154 | - pytz==2023.3.post1
155 | - pyyaml==6.0.1
156 | - referencing==0.30.2
157 | - regex==2023.8.8
158 | - retrying==1.3.4
159 | - rpds-py==0.10.3
160 | - scikit-learn==1.3.1
161 | - scipy==1.11.2
162 | - sentry-sdk==1.31.0
163 | - setproctitle==1.3.2
164 | - six==1.16.0
165 | - smmap==5.0.1
166 | - soupsieve==2.5
167 | - stack-data==0.6.2
168 | - tenacity==8.2.3
169 | - tensorboardx==2.6.2.2
170 | - threadpoolctl==3.2.0
171 | - torchmetrics==0.11.4
172 | - tqdm==4.66.1
173 | - traitlets==5.10.0
174 | - tzdata==2023.3
175 | - wandb==0.15.11
176 | - wcwidth==0.2.6
177 | - werkzeug==2.2.3
178 | - widgetsnbextension==4.0.9
179 | - yarl==1.9.2
180 | - zipp==3.17.0
181 |
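A typical way to create this environment (a sketch; docs/install.md in this repo is the authoritative step-by-step guide and may differ, e.g. for your CUDA version):

```
conda env create -f environment.yml
conda activate eventclip
```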
--------------------------------------------------------------------------------
/datasets/event2img.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | from PIL import Image
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | from .vis import events2frames
9 | from .augment import RandAugment, InterpolationMode
10 | from .utils import random_time_flip_events as tflip_events
11 | from .utils import random_flip_events_along_x as hflip_events
12 |
13 |
14 | class Event2ImageDataset(Dataset):
15 | """A wrapper for EventDataset that converts events to 2D images."""
16 |
17 | def __init__(
18 | self,
19 | transforms,
20 | event_dataset,
21 | quantize_args=dict(
22 | max_imgs=2,
23 | split_method='event_count',
24 | convert_method='event_histogram',
25 | N=30000,
26 | grayscale=True,
27 | count_non_zero=False, # hotpixel statistics
28 | background_mask=True, # apply white background via alpha-masking
29 | ),
30 | augment=False,
31 | tta=False,
32 | ):
33 |
34 | # data augmentation
35 | self.augment = augment
36 | if augment:
37 | self.augmentation = RandAugment(
38 | num_ops=2, # follow common practice
39 | interpolation=InterpolationMode.BICUBIC, # CLIP uses bicubic
40 | fill=[255, 255, 255] # pad with white pixels
41 | if quantize_args['background_mask'] else [0, 0, 0],
42 | )
43 |
44 | # transforms to apply to the 2D images
45 | self.transforms = transforms
46 |
47 | # dataset that loads raw events in shape [N, 4 (x, y, t, p)]
48 | self.event_dataset = event_dataset
49 | self.classes = event_dataset.classes
50 | self.resolution = event_dataset.resolution
51 | self.max_t = event_dataset.max_t # timestamp
52 | self.max_n = event_dataset.max_n # number of events
53 | self.tta = tta
54 | if tta:
55 | assert not event_dataset.augmentation, \
56 | 'Do not augment events in pseudo label generation'
57 | assert not augment, 'Do not augment twice'
58 | assert event_dataset.num_shots is None, 'Should sample all data'
59 | assert 'train' in event_dataset.root, \
60 | 'Generate pseudo labels only on training set'
61 | print('Apply h- and t-flip TTA in pseudo label generation')
62 |
63 | # arguments for mapping events to 2D images
64 | self.quantize_args = copy.deepcopy(quantize_args)
65 | self.quantize_args['shape'] = self.resolution
66 |
67 | self.split_method = quantize_args['split_method']
68 | self.event_rep = quantize_args['convert_method']
69 | assert self.split_method == 'event_count'
70 | max_imgs = round(self.max_n / quantize_args['N'])
71 | max_max_imgs = quantize_args.pop('max_imgs', 10) # hard limit
72 | self.max_imgs = max(min(max_imgs, max_max_imgs), 1)
73 |
74 | # a hack in visualization to also load the raw events data
75 | self.keep_events = False
76 |
77 | def __len__(self):
78 | return len(self.event_dataset)
79 |
80 | def _subsample_imgs(self, imgs):
81 | """Randomly select a subset of images or pad with zeros."""
82 | valid_mask = torch.zeros(self.max_imgs).bool()
83 | if len(imgs) > self.max_imgs:
84 | valid_mask[:] = True
85 | idxs = torch.randperm(len(imgs))[:self.max_imgs]
86 | imgs = imgs[idxs]
87 | else:
88 | valid_mask[:len(imgs)] = True
89 | pad = torch.zeros(
90 | (self.max_imgs - len(imgs), *imgs.shape[1:])).type_as(imgs)
91 | imgs = torch.cat([imgs, pad], dim=0)
92 | return imgs, valid_mask
93 |
94 | def _load_tta_data(self, idx):
95 | """Apply h- and t-flip to the loaded events, then convert to images."""
96 | data_dict = self.event_dataset[idx]
97 | events = data_dict.pop('events')
98 | assert not self.keep_events, 'val dataset should not be TTA'
99 | h_events = hflip_events(
100 | copy.deepcopy(events), resolution=self.resolution, p=1.)
101 | t_events = tflip_events(copy.deepcopy(events), p=1.)
102 | h_t_events = tflip_events(copy.deepcopy(h_events), p=1.)
103 | tta_events = [events, h_events, t_events, h_t_events]
104 | tta_imgs, tta_valid_mask = [], []
105 | for events in tta_events:
106 | imgs, valid_mask = self._event2img(events)
107 | tta_imgs.append(imgs)
108 | tta_valid_mask.append(valid_mask)
109 | data_dict['img'] = torch.stack(tta_imgs, dim=0) # [4, N, 3, H, W]
110 | data_dict['valid_mask'] = torch.stack(tta_valid_mask, dim=0) # [4, N]
111 | # `label` is still just an integer
112 | return data_dict
113 |
114 | def _event2img(self, events):
115 | """Convert events to 2D images."""
116 | # events: [N, 4 (x, y, t, p)]
117 | # get [N, H, W, 3] images with dtype np.uint8
118 | imgs = events2frames(events, **self.quantize_args)
119 | imgs = [Image.fromarray(img) for img in imgs]
120 | if self.augment:
121 | imgs = self.augmentation(imgs)
122 | imgs = torch.stack([self.transforms(img) for img in imgs])
123 | # to [N, 3, H, W] torch.Tensor as model inputs
124 |
125 | # randomly select a subset of images or pad with zeros
126 | imgs, valid_mask = self._subsample_imgs(imgs)
127 |
128 | return imgs, valid_mask
129 |
130 | def __getitem__(self, idx):
131 | if self.tta:
132 | return self._load_tta_data(idx)
133 |
134 | data_dict = self.event_dataset[idx]
135 | events = data_dict.pop('events')
136 |
137 | if self.keep_events:
138 | data_dict['events'] = copy.deepcopy(events)
139 |
140 | imgs, valid_mask = self._event2img(events)
141 |
142 | data_dict['img'] = imgs
143 | data_dict['valid_mask'] = valid_mask
144 |
145 | return data_dict
146 |
147 |
148 | def build_event2img_dataset(params, event_dataset, augment=False, tta=False):
149 | """Wrap an event dataset with a Event2Image processing pipeline."""
150 | return Event2ImageDataset(
151 | transforms=params.data_transforms,
152 | event_dataset=event_dataset,
153 | quantize_args=params.quantize_args,
154 | augment=augment,
155 | tta=tta,
156 | )
157 |
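A short worked example of the `max_imgs` computation in `Event2ImageDataset.__init__` above, plugging in the N-ImageNet values that appear elsewhere in this repo (`max_n = 135000`, `N = 70000`, `max_imgs = 2` in the configs):

```python
# Worked example (values from the N-ImageNet dataset stats and configs in this repo).
max_n, N, max_max_imgs = 135000, 70000, 2

max_imgs = round(max_n / N)                     # round(1.93) = 2 frames per sample
max_imgs = max(min(max_imgs, max_max_imgs), 1)  # clamp to [1, max_max_imgs] -> still 2
```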
--------------------------------------------------------------------------------
/method.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import wandb
3 | import numpy as np
4 |
5 | import torch
6 | import torch.optim as optim
7 | import torch.cuda.amp as amp
8 |
9 | from nerv.utils import AverageMeter, MeanMetric
10 | from nerv.training import BaseMethod, CosineAnnealingWarmupRestarts
11 |
12 | from datasets import events2frames
13 |
14 |
15 | def denormalize(x):
16 | """Reverse the input image normalization."""
17 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).type_as(x)
18 | std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).type_as(x)
19 | return x * std[None, :, None, None] + mean[None, :, None, None]
20 |
21 |
22 | def build_method(**kwargs):
23 | params = kwargs['params']
24 | if params.model in ['ZSCLIP', 'FSCLIP', 'FTCLIP']:
25 | return EventCLIPMethod(**kwargs)
26 | else:
27 | raise NotImplementedError(f'{params.model} method is not implemented.')
28 |
29 |
30 | class EventBaseMethod(BaseMethod):
31 | """Base method in this project."""
32 |
33 | def _round(self, v, n):
34 | if isinstance(v, (float, int)):
35 | return float(round(v, n))
36 | return type(v)([self._round(i, n) for i in v])
37 |
38 | @staticmethod
39 | def _convert_video(video, caption=None):
40 | """Convert torch.FloatTensor video to wandb.Video."""
41 | assert isinstance(video, torch.Tensor)
42 | video = denormalize(video)
43 | video = (video * 255.).cpu().numpy()
44 | video = np.round(video).clip(0, 255).astype(np.uint8)
45 | return wandb.Video(video, fps=2, caption=caption)
46 |
47 | @staticmethod
48 | def _get_sample_idx(N, dst):
49 | """Load data uniformly from the dataset."""
50 | dst_len = len(dst)
51 | N = N - 1 if dst_len % N != 0 else N
52 | sampled_idx = torch.arange(0, dst_len, dst_len // N)
53 | return sampled_idx.numpy()
54 |
55 | @torch.no_grad()
56 | @amp.autocast()
57 | def validation_epoch(self, model, san_check_step=-1, sample_events=True):
58 | """Validate one epoch.
59 | We aggregate the avg of all statistics and only log once.
60 | """
61 | out_dict = super().validation_epoch(
62 | model, san_check_step=san_check_step)
63 | if self.local_rank != 0:
64 | return
65 | # visualization after every epoch
66 | if sample_events:
67 | self._sample_events(model)
68 |
69 | return out_dict
70 |
71 | @staticmethod
72 | def event2video(events, caption=None, **quantize_args):
73 | """Convert events to wandb loggable videos."""
74 | imgs = events2frames(events, **quantize_args).astype(np.uint8)
75 | imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))
76 | # add a black border to the video
77 | T, C, H, W = imgs.shape
78 | video = np.zeros((T, C, H + 8, W + 8), dtype=np.uint8)
79 | video[:, :, 4:-4, 4:-4] = imgs
80 | return wandb.Video(video, fps=2, caption=caption)
81 |
82 | def _configure_optimizers(self):
83 | """Returns an optimizer, a scheduler and its frequency (step/epoch)."""
84 | optimizer = super()._configure_optimizers()[0]
85 |
86 | lr = self.params.lr
87 | total_steps = self.params.max_epochs * len(self.train_loader)
88 | warmup_steps = self.params.warmup_steps_pct * total_steps
89 |
90 | scheduler = CosineAnnealingWarmupRestarts(
91 | optimizer,
92 | total_steps,
93 | max_lr=lr,
94 | min_lr=lr / 100.,
95 | warmup_steps=warmup_steps,
96 | )
97 |
98 | return optimizer, (scheduler, 'step')
99 |
100 |
101 | class EventCLIPMethod(EventBaseMethod):
102 |
103 | @torch.no_grad()
104 | def _sample_events(self, model):
105 | """model is a simple nn.Module, not warpped in e.g. DataParallel."""
106 | model.eval()
107 | dst = self.val_loader.dataset
108 | dst.keep_events = True
109 | classes = dst.classes
110 | quantize_args = copy.deepcopy(dst.quantize_args)
111 | quantize_args['background_mask'] = True
112 |
113 | sampled_idx = self._get_sample_idx(self.params.n_samples, dst)
114 | log_dict = {}
115 | for i, idx in enumerate(sampled_idx):
116 | data_dict = dst[idx]
117 | events, label = data_dict.pop('events'), data_dict.pop('label')
118 | img, valid_mask = data_dict['img'], data_dict['valid_mask']
119 | # [N, 3, H, W], [N]
120 | in_dict = {
121 | 'img': img[None].to(model.device),
122 | 'valid_mask': valid_mask[None].to(model.device),
123 | }
124 | probs = model(in_dict)['probs'][0] # [n_cls]
125 |
126 | # keep the topk predictions
127 | k = min(3, probs.shape[-1])
128 | topk = probs.topk(k, dim=-1)
129 | idxs, probs = \
130 | topk.indices.cpu().numpy(), topk.values.cpu().numpy()
131 | caption = f'GT: {classes[label]}\n' + '\t'.join([
132 | f'{classes[idx]}: {prob:.4f}'
133 | for idx, prob in zip(idxs, probs)
134 | ])
135 |
136 | # visualize the events
137 | # raw events
138 | raw_events = self.event2video(
139 | events, caption=caption, **quantize_args)
140 | log_dict[f'val/raw_events_{i}'] = raw_events
141 | # model inputs
142 | img = img[valid_mask]
143 | video = self._convert_video(img, caption=caption)
144 | log_dict[f'val/video_{i}'] = video
145 |
146 | wandb.log(log_dict, step=self.it)
147 | torch.cuda.empty_cache()
148 | dst.keep_events = False
149 |
150 | def _configure_optimizers(self):
151 | """Returns an optimizer, a scheduler and its frequency (step/epoch)."""
152 | if self.params.model != 'FTCLIP':
153 | return super()._configure_optimizers()
154 |
155 | # use smaller lr for finetuning CLIP
156 | if self.params.optimizer.lower() == 'adam':
157 | opt = optim.Adam
158 | elif self.params.optimizer.lower() == 'adamw':
159 | opt = optim.AdamW
160 | else:
161 | raise ValueError('Should use Adam or AdamW optimizer!')
162 | assert self.params.weight_decay == 0.
163 | lr = self.params.lr
164 | clip_lr = self.params.clip_lr
165 | name = 'model.visual'
166 |
167 | adapter_params = list(
168 | filter(lambda kv: name not in kv[0] and kv[1].requires_grad,
169 | self.model.named_parameters()))
170 | clip_params = list(
171 | filter(lambda kv: name in kv[0] and kv[1].requires_grad,
172 | self.model.named_parameters()))
173 | # assert len(adapter_params) > 0 and len(clip_params) > 0
174 | params_list = [{
175 | 'params': [kv[1] for kv in adapter_params],
176 | 'lr': lr,
177 | }, {
178 | 'params': [kv[1] for kv in clip_params],
179 | 'lr': clip_lr,
180 | }]
181 |
182 | optimizer = opt(params_list)
183 | total_steps = self.params.max_epochs * len(self.train_loader)
184 | warmup_steps = self.params.warmup_steps_pct * total_steps
185 | scheduler = CosineAnnealingWarmupRestarts(
186 | optimizer,
187 | total_steps,
188 | max_lr=(lr, clip_lr),
189 | min_lr=(lr / 100., clip_lr / 100.),
190 | warmup_steps=warmup_steps,
191 | )
192 |
193 | return optimizer, (scheduler, 'step')
194 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """EventCLIP testing script"""
2 |
3 | import os
4 | import sys
5 | import importlib
6 | import argparse
7 |
8 | from tqdm import tqdm
9 |
10 | import torch
11 |
12 | import clip
13 |
14 | from nerv.training import BaseDataModule
15 | from nerv.utils import AverageMeter
16 |
17 | from models import build_model
18 | from datasets import build_dataset
19 |
20 |
21 | @torch.no_grad()
22 | def main(params, printing=True):
23 | # have to load CLIP model first
24 | arch = params.clip_dict['arch']
25 | device = 'cuda'
26 | model, preprocess = clip.load(arch, device=device)
27 |
28 | # build dataset
29 | params.data_transforms = preprocess
30 | if args.subset > 0:
31 | test_set = build_dataset(params, val_only=True, subset=args.subset)
32 | else:
33 | test_set = build_dataset(params, val_only=True)
34 | is_nin = (params.dataset == 'n_imagenet')
35 |
36 | datamodule = BaseDataModule(
37 | params, train_set=None, val_set=test_set, use_ddp=False)
38 | test_loader = datamodule.val_loader
39 |
40 | # build model
41 | params.clip_dict['clip_model'] = model
42 | params.clip_dict['class_names'] = test_set.classes
43 | if not is_zs:
44 | params.adapter_dict['in_dim'] = model.visual.output_dim
45 | model = build_model(params)
46 |
47 | # load weight
48 | # don't load for zero-shot models
49 | if args.weight and not is_zs:
50 | model.load_weight(args.weight)
51 | print(f'Loading weight: {args.weight}')
52 | model = model.cuda().eval()
53 |
54 | # test
55 | probs_acc_meter, logits_acc_meter = AverageMeter(), AverageMeter()
56 | if is_nin:
57 | probs_acc5_meter, logits_acc5_meter = AverageMeter(), AverageMeter()
58 |
59 | for data_dict in tqdm(test_loader):
60 | data_dict = {k: v.cuda() for k, v in data_dict.items()}
61 | out_dict = model(data_dict)
62 | labels = data_dict['label']
63 |
64 | # based on aggregated probs
65 | probs = out_dict['probs']
66 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean().item()
67 | probs_acc_meter.update(probs_acc, labels.shape[0])
68 |
69 | # based on aggregated logits
70 | logits = out_dict['logits']
71 | logits_acc = (logits.argmax(dim=-1) == labels).float().mean().item()
72 | logits_acc_meter.update(logits_acc, labels.shape[0])
73 |
74 | # top5 accuracy
75 | if is_nin:
76 | probs_acc5 = (probs.topk(5, dim=-1).indices == labels[:, None]).\
77 | float().sum(dim=-1).mean().item()
78 | probs_acc5_meter.update(probs_acc5, labels.shape[0])
79 | logits_acc5 = (logits.topk(5, dim=-1).indices == labels[:, None]).\
80 | float().sum(dim=-1).mean().item()
81 | logits_acc5_meter.update(logits_acc5, labels.shape[0])
82 |
83 | if not printing:
84 | return probs_acc_meter.avg, logits_acc_meter.avg
85 |
86 | print(f'\n\nTesting {args.params}')
87 | print(f'Model weight: {args.weight}')
88 | print(f'\tProbs-based accuracy@1: {probs_acc_meter.avg * 100.:.2f}%')
89 | print(f'\tLogits-based accuracy@1: {logits_acc_meter.avg * 100.:.2f}%\n')
90 | if not is_nin:
91 | return
92 | print(f'\tProbs-based accuracy@5: {probs_acc5_meter.avg * 100.:.2f}%')
93 | print(f'\tLogits-based accuracy@5: {logits_acc5_meter.avg * 100.:.2f}%\n')
94 |
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser(description='EventCLIP')
98 | parser.add_argument('--params', type=str, required=True)
99 | parser.add_argument('--weight', type=str, default='', help='load weight')
100 | parser.add_argument('--N', type=int, default=-1)
101 | parser.add_argument('--arch', type=str, default='')
102 | parser.add_argument('--prompt', type=str, default='')
103 | parser.add_argument('--bs', type=int, default=-1)
104 | parser.add_argument('--subset', type=int, default=-1)
105 | parser.add_argument('--train_shots', nargs='+', default=[-1], type=int)
106 | args = parser.parse_args()
107 |
108 | if args.params.endswith('.py'):
109 | args.params = args.params[:-3]
110 | sys.path.append(os.path.dirname(args.params))
111 | params = importlib.import_module(os.path.basename(args.params))
112 | params = params.EventCLIPParams()
113 |
114 | # adjust params
115 | is_zs = (params.model == 'ZSCLIP')
116 | if args.N > 0:
117 | params.quantize_args['N'] = int(args.N * 1e3)
118 | assert is_zs, 'can only change N in zero-shot testing'
119 | if args.arch:
120 | params.clip_dict['arch'] = args.arch
121 | assert is_zs, 'can only change ViT arch in zero-shot testing'
122 | if args.prompt:
123 | params.clip_dict['prompt'] = args.prompt
124 | assert is_zs, 'can only change text prompt in zero-shot testing'
125 | if args.bs > 0:
126 | params.val_batch_size = args.bs
127 | if args.subset > 0:
128 | assert params.dataset == 'n_imagenet', 'only N-ImageNet has subsets'
129 |
130 | # automatically find the model weight if `train_shots` is provided
131 | # we will ignore the provided `args.weight`
132 | # instead search for 'checkpoint/$PARAMS-${NUM}shot/models/model_xxx.pth'
133 | if args.train_shots[0] <= 0:
134 | main(params)
135 | exit(-1)
136 |
137 | all_probs_acc, all_logits_acc = [], []
138 | for num_shot in args.train_shots:
139 | # first, find all dup-run dirs
140 | dup_weight_dir = os.path.join('checkpoint',
141 | os.path.basename(args.params))
142 | all_weight_dirs = [f'{dup_weight_dir}-{num_shot}shot']
143 | for i in range(1, 11, 1): # at most dup 10 times
144 | weight_dir = f'{dup_weight_dir}-dup{i}-{num_shot}shot'
145 | if os.path.exists(weight_dir):
146 | all_weight_dirs.append(weight_dir)
147 |
148 | # average the accuracy over all weights
149 | probs_acc_avg, logits_acc_avg = AverageMeter(), AverageMeter()
150 |
151 | # now, for each weight_dir, find a weight to test
152 | for weight_dir in all_weight_dirs:
153 | if not os.path.exists(weight_dir):
154 | continue
155 | weight_dir = os.path.join(weight_dir, 'models')
156 |
157 | # load the best weight if it is saved
158 | if os.path.exists(os.path.join(weight_dir, 'best.pth')):
159 | args.weight = os.path.join(weight_dir, 'best.pth')
160 | # find the latest one
161 | else:
162 | all_weights = [
163 | w for w in os.listdir(weight_dir) if w.endswith('.pth')
164 | ]
165 | all_weights = sorted(
166 | all_weights, key=lambda x: int(x[:-4].split('_')[1]))
167 | args.weight = os.path.join(weight_dir, all_weights[-1])
168 |
169 | probs_acc, logits_acc = main(params, printing=False)
170 | probs_acc_avg.update(probs_acc, 1)
171 | logits_acc_avg.update(logits_acc, 1)
172 |
173 | # print the results for this `num_shot`
174 | print(f'\n\nTesting {args.params}-{num_shot}shot')
175 | print(f'Average accuracy over {probs_acc_avg.count} runs:')
176 | print(f'\tProbs-based accuracy@1: {probs_acc_avg.avg * 100.:.2f}%')
177 | print(f'\tLogits-based accuracy@1: {logits_acc_avg.avg * 100.:.2f}%\n')
178 | all_probs_acc.append(round(probs_acc_avg.avg * 100., 2))
179 | all_logits_acc.append(round(logits_acc_avg.avg * 100., 2))
180 |
181 | # print the results for recording & LaTeX
182 | print('\n\n')
183 | print(f'Probs-based accuracy@1: {all_probs_acc}')
184 | print('\t', ' & '.join([str(acc) for acc in all_probs_acc]))
185 | print(f'Logits-based accuracy@1: {all_logits_acc}')
186 | print('\t', ' & '.join([str(acc) for acc in all_logits_acc]))
187 |
--------------------------------------------------------------------------------
/datasets/augment.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, List, Optional, Tuple
3 |
4 | import torch
5 | from torch import Tensor
6 | from torchvision.transforms import functional as F
7 | from torchvision.transforms import InterpolationMode
8 |
9 |
10 | def _apply_op(
11 | img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
12 | ):
13 | if op_name == "ShearX":
14 | # magnitude should be arctan(magnitude)
15 | # official autoaug: (1, level, 0, 0, 1, 0)
16 | # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
17 | # compared to
18 | # torchvision: (1, tan(level), 0, 0, 1, 0)
19 | # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
20 | img = F.affine(
21 | img,
22 | angle=0.0,
23 | translate=[0, 0],
24 | scale=1.0,
25 | shear=[math.degrees(math.atan(magnitude)), 0.0],
26 | interpolation=interpolation,
27 | fill=fill,
28 | center=[0, 0],
29 | )
30 | elif op_name == "ShearY":
31 | # magnitude should be arctan(magnitude)
32 | # See above
33 | img = F.affine(
34 | img,
35 | angle=0.0,
36 | translate=[0, 0],
37 | scale=1.0,
38 | shear=[0.0, math.degrees(math.atan(magnitude))],
39 | interpolation=interpolation,
40 | fill=fill,
41 | center=[0, 0],
42 | )
43 | elif op_name == "TranslateX":
44 | img = F.affine(
45 | img,
46 | angle=0.0,
47 | translate=[int(magnitude), 0],
48 | scale=1.0,
49 | interpolation=interpolation,
50 | shear=[0.0, 0.0],
51 | fill=fill,
52 | )
53 | elif op_name == "TranslateY":
54 | img = F.affine(
55 | img,
56 | angle=0.0,
57 | translate=[0, int(magnitude)],
58 | scale=1.0,
59 | interpolation=interpolation,
60 | shear=[0.0, 0.0],
61 | fill=fill,
62 | )
63 | elif op_name == "Rotate":
64 | img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
65 | elif op_name == "Brightness":
66 | img = F.adjust_brightness(img, 1.0 + magnitude)
67 | elif op_name == "Color":
68 | img = F.adjust_saturation(img, 1.0 + magnitude)
69 | elif op_name == "Contrast":
70 | img = F.adjust_contrast(img, 1.0 + magnitude)
71 | elif op_name == "Sharpness":
72 | img = F.adjust_sharpness(img, 1.0 + magnitude)
73 | elif op_name == "Posterize":
74 | img = F.posterize(img, int(magnitude))
75 | elif op_name == "Solarize":
76 | img = F.solarize(img, magnitude)
77 | elif op_name == "AutoContrast":
78 | img = F.autocontrast(img)
79 | elif op_name == "Equalize":
80 | img = F.equalize(img)
81 | elif op_name == "Invert":
82 | img = F.invert(img)
83 | elif op_name == "Identity":
84 | pass
85 | else:
86 | raise ValueError(f"The provided operator {op_name} is not recognized.")
87 | return img
88 |
89 |
90 | class RandAugment(torch.nn.Module):
91 | r"""RandAugment data augmentation method based on
92 | `"RandAugment: Practical automated data augmentation with a reduced search space"
93 | <https://arxiv.org/abs/1909.13719>`_.
94 | If the image is torch Tensor, it should be of type torch.uint8, and it is expected
95 | to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
96 | If img is PIL Image, it is expected to be in mode "L" or "RGB".
97 |
98 | Args:
99 |         num_ops (int): Number of augmentation transformations to apply sequentially.
100 |             Note that, unlike the torchvision version, the magnitude is not a
101 |             constructor argument here; it is randomly sampled in `randomize_ops()`.
102 | interpolation (InterpolationMode): Desired interpolation enum defined by
103 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
104 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
105 | fill (sequence or number, optional): Pixel fill value for the area outside the transformed
106 | image. If given a number, the value is used for all bands respectively.
107 | """
108 |
109 | def __init__(
110 | self,
111 | num_ops: int = 2,
112 | interpolation: InterpolationMode = InterpolationMode.NEAREST,
113 | fill: Optional[List[float]] = None,
114 | ) -> None:
115 | super().__init__()
116 |
117 | self.num_ops = num_ops
118 | self.interpolation = interpolation
119 | self.fill = fill
120 |
121 | self.cur_ops = None
122 |
123 | def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
124 | return {
125 | # op_name: (magnitudes, signed)
126 | "Identity": (torch.tensor(0.0), False),
127 | "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
128 | "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
129 | "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
130 | "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
131 | "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
132 | "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
133 | "Color": (torch.linspace(0.0, 0.9, num_bins), True),
134 | "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
135 | "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
136 | "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
137 | "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
138 | "AutoContrast": (torch.tensor(0.0), False),
139 | "Equalize": (torch.tensor(0.0), False),
140 | }
141 |
142 | def randomize_ops(self, resolution: Tuple[int, int]) -> None:
143 | """Randomly select `self.num_ops` augmentations to apply."""
144 | assert self.cur_ops is None, 'Unused RandAugment ops'
145 | self.cur_ops = []
146 | num_magnitude_bins = 30
147 | cur_magnitude = int(torch.randint(num_magnitude_bins, (1,)).item())
148 | op_meta = self._augmentation_space(num_magnitude_bins, resolution)
149 | for _ in range(self.num_ops):
150 | op_index = int(torch.randint(len(op_meta), (1,)).item())
151 | op_name = list(op_meta.keys())[op_index]
152 | magnitudes, signed = op_meta[op_name]
153 | magnitude = float(magnitudes[cur_magnitude].item()) if \
154 | magnitudes.ndim > 0 else 0.0
155 | if signed and torch.randint(2, (1,)):
156 | magnitude *= -1.0
157 | self.cur_ops.append((op_name, magnitude))
158 |
159 |     def forward(self, imgs: List[Tensor]) -> List[Tensor]:
160 |         """Apply the same RandAugment ops to all images.
161 | 
162 |         Args:
163 |             imgs (List[PIL Image or Tensor]): Images to be transformed.
164 | 
165 |         Returns:
166 |             List[PIL Image or Tensor]: Transformed images.
167 | """
168 | channels, height, width = F.get_dimensions(imgs[0])
169 | self.randomize_ops((height, width))
170 | fill = self.fill
171 | if isinstance(imgs[0], Tensor):
172 | if isinstance(fill, (int, float)):
173 | fill = [float(fill)] * channels
174 | elif fill is not None:
175 | fill = [float(f) for f in fill]
176 | imgs = [self._forward(img, fill) for img in imgs]
177 | self.cur_ops = None
178 | return imgs
179 |
180 | def _forward(self, img: Tensor, fill: Optional[List[float]]) -> Tensor:
181 | """Apply augmentation to one image.
182 |
183 | Args:
184 | img (PIL Image or Tensor): Image to be transformed.
185 |
186 | Returns:
187 | PIL Image or Tensor: Transformed image.
188 | """
189 | assert len(self.cur_ops) == self.num_ops, 'Wrong number of RandAugment ops'
190 | for i in range(self.num_ops):
191 | op_name, magnitude = self.cur_ops[i]
192 | img = _apply_op(img, op_name, magnitude,
193 | interpolation=self.interpolation, fill=fill)
194 | return img
195 |
--------------------------------------------------------------------------------
/datasets/caltech.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir
3 | from os.path import join
4 | import random
5 |
6 | import numpy as np
7 |
8 | from torch.utils.data import Dataset
9 |
10 | from nerv.utils import load_obj, dump_obj
11 |
12 | from .utils import random_time_flip_events, random_shift_events, \
13 | random_flip_events_along_x, center_events
14 |
15 | # from https://github.com/KaiyangZhou/CoOp/blob/main/datasets/caltech101.py
16 | NEW_CNAMES = {
17 | "airplanes": "airplane",
18 | "Faces": "face", # actually doesn't exist
19 | "Faces_easy": "face",
20 | "Leopards": "leopard",
21 | "Motorbikes": "motorbike",
22 | "BACKGROUND_Google": "background", # random images, hard to categorize
23 | }
24 |
25 |
26 | def get_real_path(path):
27 | while os.path.islink(path):
28 | path = os.readlink(path)
29 | return path
30 |
31 |
32 | class NCaltech101(Dataset):
33 | """Dataset class for N-Caltech101 dataset."""
34 |
35 | def __init__(
36 | self,
37 | root,
38 | augmentation=False,
39 | num_shots=None,
40 | repeat=True,
41 | new_cnames=None,
42 | ):
43 | root = get_real_path(root)
44 | self.root = root
45 | self.classes = sorted(listdir(root))
46 | # TODO: a hack for identifying generated pseudo labeled datasets
47 | self.is_pseudo = 'pseudo' in root
48 | if self.is_pseudo:
49 | print('Using pseudo labeled dataset!')
50 |
51 | # data stats (computed from the test set)
52 | self.resolution = (180, 240)
53 | # t is very uniform, i.e. different samples have similar max_t
54 | # so just take the max (unit: second)
55 | self.max_t = 0.325
56 | # the number of events are VERY unbalanced, so instead of taking
57 | # the max, we take the 95th percentile
58 | self.max_n = 225000
59 |
60 | # data augmentation
61 | self.augmentation = augmentation
62 | self.flip_time = False
63 | self.max_shift = 20
64 |
65 | # few-shot cls
66 | self.num_shots = num_shots # number of labeled data per class
67 | self.few_shot = (num_shots is not None and num_shots > 0)
68 | if self.few_shot:
69 | assert 'train' in root.lower(), 'Only sample data in training set'
70 | self.repeat = repeat
71 |
72 | self.labeled_files, self.labels = self._get_sample_idx()
73 | assert len(self.labeled_files) == len(self.labels)
74 |
75 | # change some class names
76 | self.new_cnames = new_cnames
77 | if new_cnames is None:
78 | return
79 | for i in range(len(self.classes)):
80 | if self.classes[i] in new_cnames:
81 | new_name = new_cnames[self.classes[i]]
82 | print(f'Rename {self.classes[i]} to {new_name}')
83 | self.classes[i] = new_name
84 |
85 | def _get_sample_idx(self):
86 | """Load event file_name and label pairs."""
87 | # load pre-generated splits if available
88 | if self.few_shot and not self.is_pseudo:
89 | cur_dir = os.path.dirname(os.path.realpath(__file__))
90 | split_fn = os.path.join(
91 | cur_dir, 'files', self.__class__.__name__,
92 | f'{self.num_shots}shot-repeat={self.repeat}.pkl')
93 | if os.path.exists(split_fn):
94 | print(f'Loading pre-generated split from {split_fn}')
95 | splits = load_obj(split_fn) # Dict[event_fn: label]
96 | labeled_files = np.array(list(splits.keys()))
97 | labels = np.array(list(splits.values()))
98 | return labeled_files, labels
99 |
100 | labeled_files, labels = [], []
101 |
102 | # fix the random seed since we'll sample data
103 | random.seed(0)
104 | for i, c in enumerate(self.classes):
105 | cls_files = [
106 | get_real_path(join(self.root, c, f))
107 | for f in sorted(listdir(join(self.root, c)))
108 | ]
109 | if len(cls_files) == 0:
110 | print(f'Warning: class {c} has no data!')
111 | continue
112 |
113 | # randomly sample `num_shots` labeled data for each class
114 | if self.few_shot:
115 | if self.num_shots <= len(cls_files):
116 | lbl_files = random.sample(cls_files, k=self.num_shots)
117 | else:
118 | if self.repeat:
119 | lbl_files = random.choices(cls_files, k=self.num_shots)
120 | else:
121 | lbl_files = cls_files
122 | elif self.num_shots is None:
123 | lbl_files = cls_files
124 | else:
125 | raise ValueError(f'Invalid num_shots: {self.num_shots}')
126 | labeled_files += lbl_files
127 | labels += [i] * len(lbl_files)
128 |
129 | # save the splits for future use
130 | if self.few_shot and not self.is_pseudo:
131 | splits = {fn: lbl for fn, lbl in zip(labeled_files, labels)}
132 | os.makedirs(os.path.dirname(split_fn), exist_ok=True)
133 | dump_obj(splits, split_fn)
134 | print(f'Saving split file to {split_fn}')
135 |
136 | labeled_files = np.array(labeled_files)
137 | labels = np.array(labels)
138 | return labeled_files, labels
139 |
140 | def __len__(self):
141 | return len(self.labeled_files)
142 |
143 | def _rand_another(self):
144 | """Randomly sample another data."""
145 | idx = np.random.randint(0, len(self))
146 | return self.__getitem__(idx)
147 |
148 | @staticmethod
149 | def _load_events(event_path):
150 | """Load events from a file."""
151 | return np.load(event_path).astype(np.float32)
152 |
153 | def _augment_events(self, events):
154 | """Data augmentation on events."""
155 | if self.flip_time:
156 | events = random_time_flip_events(events)
157 | events = random_shift_events(
158 | events, max_shift=self.max_shift, resolution=self.resolution)
159 | events = random_flip_events_along_x(events, resolution=self.resolution)
160 | # not using time flip on N-Caltech and N-Cars dataset
161 | return events
162 |
163 | def __getitem__(self, idx):
164 | """
165 | returns events and label, potentially with augmentation
166 | :param idx: data_idx
167 |         :return: dict with 'events' [N, (x,y,t,p)], 'label', 'data_idx'
168 | """
169 | f = str(self.labeled_files[idx])
170 | label = int(self.labels[idx])
171 | events = self._load_events(f)
172 | # the spatial resolution of N-Caltech events is 180x240
173 | # we should center the spatial coordinates of events
174 | # some events only reside in e.g. [0, 0] x [100, 160]
175 | # which will be largely removed after center crop!
176 | events = center_events(events, resolution=self.resolution)
177 |
178 | if self.augmentation:
179 | events = self._augment_events(events)
180 |
181 | if events.shape[0] == 0:
182 | return self._rand_another()
183 |
184 | # events: [N, 4 (x, y, t, p)], label: int
185 | # N is usually 1e5 ~ 1e6
186 | return {
187 | 'events': events,
188 | # 't': events[:, 2],
189 | 'label': label,
190 | 'data_idx': idx,
191 | }
192 |
193 |
194 | def build_n_caltech_dataset(params, val_only=False, gen_data=False):
195 | """Build the N-Caltech101 dataset."""
196 | # only build the test set
197 | if val_only:
198 | assert not gen_data, 'Only generate pseudo labels on the training set'
199 | return NCaltech101(
200 | root=os.path.join(params.data_root, 'testing'),
201 | augmentation=False,
202 | new_cnames=NEW_CNAMES,
203 | )
204 | # build the training set for pseudo label generation
205 | if gen_data:
206 | return NCaltech101(
207 | root=os.path.join(params.data_root, 'training'),
208 | augmentation=False,
209 | new_cnames=NEW_CNAMES,
210 | )
211 |
212 | # build the training set
213 | train_set = NCaltech101(
214 | root=os.path.join(params.data_root, 'training'),
215 | augmentation=True,
216 | num_shots=params.get('num_shots', None),
217 | repeat=params.get('repeat_data', True),
218 | new_cnames=NEW_CNAMES,
219 | )
220 | val_set = NCaltech101(
221 | root=os.path.join(params.data_root, 'testing'),
222 | augmentation=False,
223 | new_cnames=NEW_CNAMES,
224 | )
225 | return train_set, val_set
226 |
--------------------------------------------------------------------------------
/docs/benchmark.md:
--------------------------------------------------------------------------------
1 | # Benchmark
2 |
3 | ## Overview
4 |
5 | We provide instructions on reproducing the results reported in the paper, including:
6 |
7 | - Zero-shot EventCLIP on N-Caltech, N-Cars, and N-ImageNet datasets
8 | - Few-shot EventCLIP with text adapter (most stable), or joint adapter (best performing) on 3 datasets
9 | - Fine-tuning EventCLIP on N-Caltech and N-ImageNet to achieve SOTA performance
10 | - Learning with unlabeled data by self-training on generated pseudo labels on the N-ImageNet (Mini) dataset
11 |
12 | In the following instructions, we will mostly use **EventCLIP with joint adapter on N-Caltech under the 5-shot setting** as an example.
13 | Other settings can be reproduced by simply changing the config file or the command-line flags.
14 |
15 | ### Pre-trained Weights
16 |
17 | Since most of the experiments in the paper can be trained within 1-2 hours, we only provide pre-trained weights for long-running experiments, or those involving multi-step training.
18 | Please download the pre-trained weights from [Google Drive](https://drive.google.com/file/d/1QW7sn5BYjRdUe6xD_jQUQgSa9oIveq0s/view?usp=sharing) and unzip them under [pretrained/](../pretrained/).
19 |
20 | ## Training EventCLIP Feature Adapter
21 |
22 | **We provide a unified script [train.py](../train.py) to train all models used in this project.**
23 | You should always call it from the **root directory** of this repo (i.e. run `python train.py xxx`).
24 |
25 | **All of the model training can be done by specifying a config file (called `params` here), and adding other args (e.g. `--num_shots`).**
26 | Please check the config file for the number of GPUs and other resources (e.g. `num_workers` CPUs) before launching a training run.
27 |
28 | Here is one example:
29 |
30 | ```
31 | python train.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --num_shots 5 --fp16 --cudnn
32 | ```
33 |
34 | Other arguments include:
35 |
36 | - `--weight`: resume training from this weight
37 | - `--ddp`: use DDP multi-GPU training (needed when using `>=2` GPUs)
38 | - `--fp16`: enable half-precision training (highly recommended)
39 | - `--cudnn`: enable cudnn benchmark (highly recommended)
40 | - `--local_rank`/`--local-rank`: required by DDP, don't change it
41 |
42 | During training, model checkpoints and visualizations will be saved under `./checkpoint/$PARAMS/models/`.
43 |
44 | We provide config files for training EventCLIP under the few-shot classification setting:
45 |
46 | - [Text Adapter configs](../configs/fsclip/text_adapter/)
47 | - [Joint Adapter configs](../configs/fsclip/joint_adapter/)
48 |
49 | ## Testing
50 |
51 | Testing can be done with [test.py](../test.py).
52 | To test the above trained model, simply run:
53 |
54 | ```
55 | python test.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --weight $WEIGHT
56 | ```
57 |
58 | Or you can use these configs to test the zero-shot classification performance:
59 |
60 | - [Zero-shot configs](../configs/zsclip/)
61 |
62 | Other arguments in `test.py` include:
63 |
64 | - `--bs`: testing batch size
65 | - `--subset`: specify which N-ImageNet robustness variant to test. See the N-ImageNet paper's Appendix for the mapping between subset IDs and the actual data variations
66 | - `--arch`: change CLIP's image encoder backbone in zero-shot testing
67 | - `--prompt`: change the text prompt in zero-shot testing
68 |
69 | We also provide a `--train_shots` argument to automatically gather results over different shots.
70 | If you train the same model with different `--num_shots` values in `train.py`, you can list all the shot values here to test them all together.
71 | For example, you can run:
72 |
73 | ```
74 | python test.py --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --train_shots 20 10 5 3 1
75 | ```
76 |
77 | **Note that testing is always conducted over the entire test set without few-shot filtering.**
78 |
79 | ## Fine-tuning EventCLIP Full Model
80 |
81 | Fine-tuning EventCLIP is similar to training an adapter model,
82 | but it requires more GPU memory and training time.
83 | We provide config files for fine-tuning EventCLIP under both the few-shot and the full-data settings.
84 | Please refer to them for the detailed training requirements:
85 |
86 | - [Fine-tuning configs](../configs/ftclip/)
87 |
88 | We provide the weight for our fine-tuned EventCLIP (with ViT-B/16 backbone) on the N-ImageNet dataset.
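
As a concrete example, the command below is only a sketch of how such a run would be launched, mirroring the adapter-training command above (the exact resource requirements are listed in each config; add `--num_shots` for the few-shot variants):

```
python train.py --params configs/ftclip/ft_text_fsclip_nin_params-vitb16.py --fp16 --cudnn
```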
89 |
90 | ## Learning with Unlabeled Data
91 |
92 | To generate pseudo labels on unlabeled data, please use [gen_data.py](../gen_data.py).
93 | For example, you can use the zero-shot EventCLIP to generate pseudo labels on the N-ImageNet (Mini) dataset by:
94 |
95 | ```
96 | python gen_data.py --params configs/zsclip/zsclip_nin_mini_params-vitb32.py --weight '' \
97 | --conf_thresh 0.999 --tta --tta_min_prob --tta_consistent --topk 30 \
98 | --save_path data/pseudo-N_Imagenet/vitb32_zs-tta-thresh_0999-top30 \
99 | --gt_shots -1
100 | ```
101 |
102 | Here, we use a very high confidence threshold of `0.999` to filter predictions.
103 | This is because the pre-trained CLIP model always makes over-confident predictions (likely due to the learned temperature parameter `\tau`).
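
As a rough illustration of this over-confidence: the released CLIP models have a logit scale saturated near its clamped value of `100`, so a top-2 cosine-similarity gap of only `0.05` becomes a logit gap of `5`, which in a two-class comparison already corresponds to a softmax confidence of about `1 / (1 + e^-5) ≈ 0.993`.
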
104 | Other arguments include:
105 |
106 | - `--tta`, `--tta_min_prob`, and `--tta_consistent` are the techniques introduced in the paper to further improve the label quality
107 | - `--topk 30` means we only select the top-30 most confident predictions for each class
108 | - `--save_path` indicates the path to save the generated dataset
109 | - `--gt_shots` specifies the number of labeled shots used to train the model given by `--weight`.
110 |   Since we are using the zero-shot model here, we set it to `-1` and leave `--weight` empty
111 |
112 | If you want to study the semi-supervised setting, where we have `X` labeled data and all the remaining unlabeled data, you can first pre-train an EventCLIP with joint adapter using the [provided config file](../configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py).
113 | We provide the 1-, 3-, 5-, 10-, and 20-shot pre-trained weights in this setting.
114 | Then, run `gen_data.py` again, but use the joint adapter's config file as `--params`, `--gt_shots X`, and `--weight` pointing to the pre-trained model's weight.
115 | Also, please use a lower confidence threshold `--conf_thresh 0.5`, since the few-shot EventCLIP is now better calibrated.
116 | An example command is:
117 |
118 | ```
119 | python gen_data.py --params configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py \
120 | --weight pretrained/joint_fsclip_nin_mini_params-vitb32-1shot-pretrain.pth \
121 | --conf_thresh 0.5 --tta --tta_min_prob --tta_consistent --topk 30 \
122 | --save_path data/pseudo-N_Imagenet/vitb32_1shot-tta-thresh_05-top30 \
123 | --gt_shots 1
124 | ```
125 |
126 | Finally, to train on this generated dataset (i.e. self-training), please modify the [config file](../configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py)'s `data_root` field to the `save_path` above and run `train.py`.
127 | **Note that** you should set `--num_shots` to `X + topk`.
128 | This is because we select the `topk` most confident predictions per class as pseudo labels, plus the `X` GT labels per class to train the model.
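
For instance, continuing the 1-shot example above with `--topk 30`, a self-training run would look roughly like this (a sketch, assuming the config's `data_root` has been changed to the `save_path` generated above):

```
python train.py --params configs/fsclip/joint_adapter/joint_fsclip_nin_mini_params-vitb32.py --num_shots 31 --fp16 --cudnn
```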
129 |
130 | We provide the weight for our EventCLIP (with ViT-B/32 backbone) trained on zero-shot generated pseudo labels on the N-ImageNet (Mini) dataset.
131 |
132 | ## Useful Scripts
133 |
134 | We provide helper scripts for Slurm cluster job submission, and for training/testing over multiple settings.
135 |
136 | - You can use [sbatch_run.sh](../scripts/sbatch_run.sh) to automatically generate an sbatch file and submit the job to Slurm.
137 | Simply run:
138 |
139 | ```
140 | GPUS=$NUM_GPU CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=$QOS \
141 | ./scripts/sbatch_run.sh $PARTITION $JOB_NAME \
142 | train.py none (if DDP then change `none` to `ddp`) --py_args...
143 | ```
144 |
145 | Using the same example as above, we can set `--py_args...` as follows (see the config file for the number of GPUs/CPUs to use):
146 |
147 | ```
148 | --params configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py \
149 | --num_shots 5 --fp16 --cudnn
150 | ```
151 |
152 | This is then equivalent to running the above `python train.py xxx` command directly in the CLI.
153 |
154 | Note that `sbatch_run.sh` internally calls the `resubmit_failed_job.sh` script, which monitors the job status and resubmits the job if it fails.
155 |
156 | - We provide a script to submit multiple runs of the same experiment with different random seeds to Slurm.
157 | To use the script [dup_run_sbatch.sh](../scripts/dup_run_sbatch.sh), simply run:
158 |
159 | ```
160 | GPUS=$NUM_GPU CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=$QOS REPEAT=$NUM_REPEAT \
161 | ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME \
162 | train.py none $PARAMS --py_args...
163 | ```
164 |
165 | The usage is largely the same as `sbatch_run.sh`.
166 | The only difference is that the config file `$PARAMS` is passed separately, so that the script can make several copies of it and submit one job per copy.
167 |
168 | To train the same model as above, duplicated `3` times, on the `rtx6000` partition, under the job name `joint_fsclip_ncaltech_params`, simply run:
169 |
170 | ```
171 | GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 \
172 | ./scripts/dup_run_sbatch.sh rtx6000 joint_fsclip_ncaltech_params \
173 | train.py none \
174 | configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py \
175 | --num_shots 5 --fp16 --cudnn
176 | ```
177 |
178 | The model weights will be saved under `./checkpoint/joint_fsclip_ncaltech_params-dup$X-5shot/`.
179 |
180 | - We provide a script, [train_all_shots.sh](../scripts/train_all_shots.sh), to train one EventCLIP under all five shot settings used in the paper.
181 | For example, to run the above `dup_run_sbatch.sh` command with `20, 10, 5, 3, 1` shots, simply wrap that command with this script:
182 |
183 | ```
184 | ./scripts/train_all_shots.sh "GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 joint_fsclip_ncaltech_params train.py none configs/fsclip/joint_adapter/joint_fsclip_ncaltech_params.py --fp16 --cudnn" 20 10 5 3 1
185 | ```
186 |
187 | The model weights will be saved under `./checkpoint/joint_fsclip_ncaltech_params-dup$X-$Yshot/`.
188 | See the `Testing` section above on how to efficiently test all these models with the `--train_shots` flag.
189 |
190 | - For zero-shot testing, we provide a script, [test_all_arch.sh](../scripts/test_all_arch.sh), to test EventCLIP with different CLIP image encoder architectures.
191 | Simply wrap your testing command with this script and append the arches you want to try (see the example sketch after this list).
192 |
193 | - For N-ImageNet, we provide a script, [test_all_subset.sh](../scripts/test_all_subset.sh), to test EventCLIP over all robustness variants.
194 | Simply wrap your testing command with this script and append the subset IDs you want to try.
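
These two wrapper scripts are not documented in detail here; assuming they follow the same wrap-the-command convention as `train_all_shots.sh` above (quoted base command first, then the values to sweep over), an architecture sweep might look like the sketch below, and `test_all_subset.sh` would be used analogously with subset IDs:

```
./scripts/test_all_arch.sh "python test.py --params configs/zsclip/zsclip_ncaltech_params.py" ViT-B/32 ViT-B/16 ViT-L/14
```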
195 |
--------------------------------------------------------------------------------
/models/clip_cls_ft.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | import clip
8 |
9 | from nerv.training import BaseModel
10 |
11 | from .adapter import IdentityAdapter, TransformerAdapter
12 | from .lora import inject_trainable_lora
13 |
14 |
15 | class FTCLIPClassifier(BaseModel):
16 | """Finetune CLIP model for **few-shot** classification."""
17 |
18 | def __init__(
19 | self,
20 | adapter_dict=dict(
21 | adapter_type='text-identity',
22 | residual=True,
23 | ),
24 | clip_dict=dict(
25 | clip_model=None,
26 | prompt='a point cloud image of a {}',
27 | class_names=None,
28 | agg_func='sum',
29 | ),
30 | loss_dict=dict(
31 | use_logits_loss=True,
32 | use_probs_loss=False,
33 | ),
34 | ):
35 | super().__init__()
36 |
37 | self.clip_dict = clip_dict
38 | self.loss_dict = loss_dict
39 | self.adapter_dict = copy.deepcopy(adapter_dict)
40 |
41 | self._build_clip()
42 | self._build_loss()
43 | self._build_adapter()
44 |
45 | def _build_clip(self):
46 | # freeze the CLIP model
47 | model = self.clip_dict['clip_model']
48 | for p in model.parameters():
49 | p.requires_grad = False
50 | # LoRA fine-tuning
51 | lora = self.clip_dict.get('lora', -1)
52 | if isinstance(lora, str) or lora > 0:
53 | model.visual = inject_trainable_lora(model.visual, r=lora)
54 | # finetune CLIP.visual or its sub-layers
55 | conv1 = self.clip_dict['only_conv1']
56 | bias = self.clip_dict['only_bias']
57 | ln = self.clip_dict['only_ln']
58 | cls_fc = self.clip_dict.get('only_cls_fc', False)
59 | cls_token = self.clip_dict.get('only_cls_token', False)
60 | if conv1: # only tune the first conv layer
61 | for p in model.visual.conv1.parameters():
62 | p.requires_grad = True
63 | if bias: # only tune the bias terms
64 | for name, p in model.visual.named_parameters():
65 | if 'bias' in name and p is not None:
66 | p.requires_grad = True
67 | if ln: # only tune the LayerNorm layers
68 | for m in model.visual.modules():
69 | if isinstance(m, nn.LayerNorm):
70 | for p in m.parameters():
71 | p.requires_grad = True
72 | if cls_fc: # only tune the final projection head
73 | model.visual.proj.requires_grad = True
74 | if cls_token: # only tune the CLS token
75 | model.visual.class_embedding.requires_grad = True
76 | # tune all
77 | if (isinstance(lora, int) and lora <= 0) and \
78 | not (conv1 or bias or ln or cls_fc or cls_token):
79 | for p in model.visual.parameters():
80 | p.requires_grad = True
81 | # set as eval
82 | self.model = model.eval()
83 | self.logit_scale = model.logit_scale.exp().item()
84 |
85 | # text prompt for zero-shot cls
86 | self.prompt = self.clip_dict['prompt']
87 | self.class_names = self.clip_dict['class_names']
88 | self.text_feats = None
89 |
90 | # aggregation function
91 | self.agg_func = self.clip_dict['agg_func']
92 | assert self.agg_func in ['sum', 'mean', 'max']
93 |
94 | def _build_loss(self):
95 | self.use_logits_loss = self.loss_dict['use_logits_loss']
96 | self.use_probs_loss = self.loss_dict['use_probs_loss']
97 | assert int(self.use_logits_loss) + int(self.use_probs_loss) == 1
98 |
99 | def _build_prompts(self, adapter_type):
100 | """Build the text features for prompt tuning."""
101 | with torch.no_grad():
102 | text_feats = self._get_text_feats().float() # [n_classes, C]
103 | self.text_feats = nn.Parameter(text_feats, requires_grad=True)
104 | adapter_type = adapter_type[5:]
105 | return adapter_type
106 |
107 | def _build_adapter(self):
108 | # whether to tune the text features as well
109 | adapter_type = self.adapter_dict.pop('adapter_type').lower()
110 | if adapter_type.startswith('text-'):
111 | print('Tune text features as well!')
112 | self.prompt_tuning = True
113 | adapter_type = self._build_prompts(adapter_type)
114 | else:
115 | self.prompt_tuning = False
116 |
117 | # image feature adapter
118 | self.adapter_type = adapter_type
119 |         assert adapter_type == 'identity', 'FTCLIP only supports the identity image adapter for now'
120 | if adapter_type == 'identity': # not tuning image features
121 | model = IdentityAdapter
122 | elif adapter_type == 'trans': # Transformer to fuse image features
123 | model = TransformerAdapter
124 | else:
125 | raise NotImplementedError(f'adapter {adapter_type} not supported!')
126 | self.adapter = model(**self.adapter_dict)
127 |
128 | def _same_class_names(self, class_names):
129 | """Check if the input `class_names` matches `self.class_names`."""
130 | return all([c1 == c2 for c1, c2 in zip(class_names, self.class_names)])
131 |
132 | def _get_text_feats(self, class_names=None):
133 | """Compute the text prompt features using CLIP text encoder."""
134 | # no `class_names` provided
135 |         no_cls_flag = (class_names is None)
136 |         if no_cls_flag:
137 |             class_names = self.class_names
138 | # with cached `text_feats`
139 | if self.text_feats is not None:
140 | return self.text_feats
141 | # `class_names` matches
142 | elif self._same_class_names(class_names):
143 | # with cached `text_feats`
144 | if self.text_feats is not None:
145 | return self.text_feats
146 |
147 | # compute the text prompt features
148 | class_names = [c.lower().replace('_', ' ') for c in class_names]
149 | prompts = torch.cat([
150 | clip.tokenize(self.prompt.format(c)) for c in class_names
151 | ]).to(self.device)
152 | text_feats = self.model.encode_text(prompts)
153 | text_feats = F.normalize(text_feats, p=2, dim=-1)
154 |
155 | # cache the `text_feats` if
156 | # 1) the `class_names` matches
157 | # 2) the `class_names` is not provided
158 | if no_cls_flag or self._same_class_names(class_names):
159 | self.text_feats = text_feats
160 |
161 | return text_feats # [n_classes, C]
162 |
163 | def get_text_feats(self, class_names=None):
164 | # finetune the text features (i.e. prompt tuning)
165 | if self.prompt_tuning:
166 | assert self.text_feats.requires_grad, 'prompt should be trainable!'
167 | text_feats = F.normalize(self.text_feats, p=2, dim=-1)
168 | # otherwise, we use fixed text features
169 | else:
170 | with torch.no_grad():
171 | text_feats = self._get_text_feats(class_names)
172 | return self._adjust_dtype(text_feats)
173 |
174 | def _get_img_feats(self, imgs):
175 | """Compute the image features using CLIP image encoder.
176 |
177 | Args:
178 | imgs (torch.Tensor): [B, C, H, W]
179 | """
180 | img_feats = self.model.encode_image(imgs)
181 | return img_feats # [B, C]
182 |
183 | def get_img_feats(self, imgs):
184 | img_feats = self._get_img_feats(imgs)
185 | return self._adjust_dtype(img_feats)
186 |
187 | def _aggregate_logits(self, logits, valid_masks):
188 | """Aggregate logits for each data.
189 |
190 | Args:
191 | logits (torch.Tensor): [B, T, n_classes]
192 | valid_masks (torch.Tensor): [B, T]
193 | """
194 | if self.agg_func == 'sum':
195 | logits = logits.sum(1)
196 | elif self.agg_func == 'mean':
197 | logits = logits.sum(1) / valid_masks.float().sum(1, keepdim=True)
198 | elif self.agg_func == 'max':
199 | # make invalid logits very small
200 | logits = logits - (1. - valid_masks.float()) * 1e6
201 | logits = logits.max(1)[0]
202 | else:
203 | raise NotImplementedError
204 | return logits
205 |
206 | def _aggregate_probs(self, logits, valid_masks):
207 |         """This one always takes the mean."""
208 | valid_masks = valid_masks.detach().float()
209 | probs = logits.softmax(dim=-1)
210 | probs = probs * valid_masks[..., None]
211 | probs = probs.sum(1) / valid_masks.sum(1, keepdim=True)
212 | return probs
213 |
214 | def forward(self, data_dict):
215 | """Forward function."""
216 | imgs = data_dict['img'] # [B, T, C, H, W], `T` is number of views
217 | valid_masks = data_dict['valid_mask'] # [B, T]
218 | B, T = valid_masks.shape
219 |
220 | # compute image features
221 | valid_imgs = imgs[valid_masks] # [N, C, H, W]
222 | img_feats = self.get_img_feats(valid_imgs) # [N, C]
223 |
224 | # update image features using adapter
225 | C = img_feats.shape[-1]
226 | full_img_feats = torch.zeros(B, T, C).type_as(img_feats)
227 | full_img_feats[valid_masks] = img_feats
228 | # full_img_feats = self.adapter(full_img_feats, valid_masks)
229 | # [B, T, C], multi-view image features
230 | # normalize the output features
231 | # all zeros vector will still be zeros after F.normalize()
232 | full_img_feats = F.normalize(
233 | full_img_feats, p=2, dim=-1).type_as(full_img_feats)
234 | # make invalid features zeros
235 | full_img_feats = full_img_feats * valid_masks.float().unsqueeze(-1)
236 |
237 | # compute text features
238 | # we may need to compute gradients w.r.t. text features
239 | # so we can't use torch.no_grad() here
240 | text_feats = self.get_text_feats() # [n_classes, C]
241 |
242 | # compute logits
243 | full_logits = (self.logit_scale * full_img_feats @ text_feats.T)
244 | # [B, T, n_cls], multi-view logits
245 |
246 | # convert to [B, n_cls] for loss computation!
247 | logits = self._aggregate_logits(full_logits, valid_masks)
248 | probs = self._aggregate_probs(full_logits, valid_masks)
249 |
250 | out_dict = {
251 | 'full_logits': full_logits, # [B, T, n_classes]
252 | 'valid_masks': valid_masks, # [B, T]
253 | 'logits': logits, # [B, n_classes]
254 | 'probs': probs, # [B, n_classes]
255 | }
256 | return out_dict
257 |
258 | def calc_train_loss(self, data_dict, out_dict):
259 | """Compute training loss."""
260 | labels = data_dict['label'] # [B]
261 | logits = out_dict['logits'] # [B, n_classes]
262 | probs = out_dict['probs'] # [B, n_classes]
263 | loss_dict = {}
264 | if self.use_logits_loss:
265 | loss_dict['ce_loss'] = F.cross_entropy(logits, labels)
266 | if self.use_probs_loss:
267 | probs = probs + 1e-6 # avoid nan
268 | loss_dict['ce_loss'] = F.nll_loss(probs.log(), labels)
269 | return loss_dict
270 |
271 | @torch.no_grad()
272 | def calc_eval_loss(self, data_dict, out_dict):
273 | """Loss computation in eval."""
274 | loss_dict = self.calc_train_loss(data_dict, out_dict)
275 |
276 | # also compute the cls accuracy
277 | labels = data_dict['label'] # [B]
278 | # based on aggregated probs
279 | probs = out_dict['probs'] # [B, n_classes]
280 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean()
281 | loss_dict['probs_acc'] = probs_acc
282 | # based on aggregated logits
283 | logits = out_dict['logits'] # [B, n_classes]
284 | logits_acc = (logits.argmax(dim=-1) == labels).float().mean()
285 | loss_dict['logits_acc'] = logits_acc
286 | return loss_dict
287 |
288 | def _adjust_dtype(self, x):
289 | """CLIP model returns features in FP16.
290 | During training, torch.amp will help us handle this.
291 | However, during inference, we need to manually convert them to FP32.
292 | """
293 | if self.training:
294 | return x
295 | return x.type(self.dtype)
296 |
297 | @property
298 | def dtype(self):
299 | return self.adapter.dtype
300 |
301 | @property
302 | def device(self):
303 | return self.model.logit_scale.device
304 |
305 | def train(self, mode=True):
306 | nn.Module.train(self, mode)
307 | # keep CLIP in eval mode
308 | self.model.eval()
309 | # but adjust CLIP.visual
310 | self.model.visual.train(mode)
311 | return self
312 |
313 | def state_dict(self):
314 | """Remove CLIP weight (keep `model.visual`) from the state dict."""
315 | w = super().state_dict()
316 | w = {
317 | k: v
318 | for k, v in w.items()
319 | if ((not k.startswith('model.')) or k.startswith('model.visual.'))
320 | }
321 | return w
322 |
323 | def load_state_dict(self, state_dict, strict=True):
324 | """Don't load CLIP weight (load `model.visual`) from the state dict."""
325 | # load CLIP weight from the state dict
326 | clip_w = {
327 | f'model.{k}': v
328 | for k, v in self.model.state_dict().items()
329 | if not k.startswith('visual.')
330 | }
331 | assert all(k not in state_dict for k in clip_w)
332 | state_dict = {**clip_w, **state_dict}
333 | super().load_state_dict(state_dict, strict=strict)
334 |
--------------------------------------------------------------------------------
/gen_data.py:
--------------------------------------------------------------------------------
1 | """EventCLIP testing script"""
2 |
3 | import os
4 | import os.path as osp
5 | import sys
6 | import importlib
7 | import argparse
8 |
9 | from tqdm import tqdm
10 |
11 | import torch
12 |
13 | import clip
14 |
15 | from nerv.training import BaseDataModule
16 | from nerv.utils import AverageMeter, load_obj
17 |
18 | from models import build_model
19 | from datasets import build_dataset
20 |
21 |
22 | def get_real_path(path):
23 | while osp.islink(path):
24 | path = os.readlink(path)
25 | return path
26 |
27 |
28 | def get_folder_and_fn(path):
29 | return osp.join(osp.basename(osp.dirname(path)), osp.basename(path))
30 |
31 |
32 | def find_key_from_value(d, v):
33 | for k, v_ in d.items():
34 | if v_ == v:
35 | return k
36 | return None
37 |
38 |
39 | def print_stats(class_names, gt_class_cnt, sel_class_cnt,
40 | sel_correct_class_cnt):
41 | print('\nClass stats:')
42 | for k in class_names:
43 | gt_num, sel_num, correct_num = \
44 | gt_class_cnt[k], sel_class_cnt[k], sel_correct_class_cnt[k]
45 | print(f'\t{k}: GT {gt_num}, select {sel_num}, {correct_num} correct')
46 | print('Not accurate classes')
47 | less_accurate_cnt = 0
48 | for k in class_names:
49 | gt_num, sel_num, correct_num = \
50 | gt_class_cnt[k], sel_class_cnt[k], sel_correct_class_cnt[k]
51 | ratio = correct_num / sel_num if sel_num > 0 else 0.
52 | if ratio < 0.5:
53 | print(f'\t{k}: GT {gt_num}, select {correct_num}/{sel_num} -- {ratio:.2f}')
54 | less_accurate_cnt += 1
55 | print(f'Not accurate classes: {less_accurate_cnt}/{len(class_names)}')
56 |
57 | total_num = sum(gt_class_cnt.values())
58 | select_num = sum(sel_class_cnt.values())
59 | select_correct_num = sum(sel_correct_class_cnt.values())
60 | sel_acc = select_correct_num / select_num * 100. if select_num > 0 else 0.
61 | print(f'\nUsing {args.conf_thresh=}')
62 | if args.topk > 0:
63 | print(f'Using {args.topk=}')
64 | print(f'\tSelect {select_num} from {total_num}, Acc={sel_acc:.2f}%')
65 | if args.tta:
66 | print(f'Using TTA with {args.tta_consistent=} + {args.tta_min_prob=}')
67 |
68 |
69 | @torch.no_grad()
70 | def main(params):
71 | # have to load CLIP model first
72 | arch = params.clip_dict['arch']
73 | device = 'cuda'
74 | model, preprocess = clip.load(arch, device=device)
75 |
76 | # build training dataset for generating pseudo labels
77 | params.data_transforms = preprocess
78 | tta = args.tta
79 | is_nin = ('n_imagenet' in params.dataset)
80 | if not is_nin:
81 | assert params.dataset == 'n_caltech', f'{params.dataset} not supported'
82 | print(f'Generate pseudo labels for {params.dataset}')
83 | test_set = build_dataset(params, val_only=False, gen_data=True, tta=tta)
84 | ev_dst = test_set.event_dataset
85 | class_names, labels = test_set.classes, ev_dst.labels
86 |
87 | datamodule = BaseDataModule(
88 | params, train_set=None, val_set=test_set, use_ddp=False)
89 | test_loader = datamodule.val_loader
90 |
91 | # build model
92 | params.clip_dict['clip_model'] = model
93 | params.clip_dict['class_names'] = test_set.classes
94 | if not is_zs:
95 | params.adapter_dict['in_dim'] = model.visual.output_dim
96 | model = build_model(params)
97 |
98 | # load weight
99 | gt_data = {} # we might have some labeled data
100 | if args.weight:
101 | assert not is_zs, 'Zero-shot models should not have pre-trained weight'
102 | model.load_weight(args.weight)
103 | print(f'Loading weight: {args.weight}')
104 | # load labeled data, we won't generate pseudo labels for them
105 | assert args.gt_shots > 0, \
106 | 'Should specify the num_shots used to pre-train the model'
107 | assert f'{args.gt_shots}shot' in args.weight or \
108 | f'{args.gt_shots}-shot' in args.weight, \
109 | f'Weight {args.weight} does not match `{args.gt_shots}-shot`'
110 | assert f'{args.gt_shots}shot' in save_path or \
111 | f'{args.gt_shots}-shot' in save_path, \
112 | f'Should put `{args.gt_shots}shot` in `save_path`'
113 | split_fn = osp.join('./datasets/files', ev_dst.__class__.__name__,
114 | f'{args.gt_shots}shot-repeat=True.pkl')
115 | gt_split = load_obj(split_fn) # Dict[event_fn (str): label (int)]
116 | # convert to Dict[event_fn (str): class_name (str)]
117 | gt_data = {k: class_names[v] for k, v in gt_split.items()}
118 |     gt_data_paths = [get_folder_and_fn(k) for k in gt_data.keys()]  # [] if no labeled data
119 | model = model.cuda().eval()
120 |
121 | # test
122 | all_acc_meter = AverageMeter()
123 | gt_class_cnt = {k: (labels == i).sum() for i, k in enumerate(class_names)}
124 | sel_class_cnt = {k: 0 for k in class_names}
125 | sel_correct_class_cnt = {k: 0 for k in class_names}
126 | pred_path2cls = {}
127 |
128 | conf_thresh, topk = args.conf_thresh, args.topk
129 | for data_dict in tqdm(test_loader):
130 | data_idx = data_dict.pop('data_idx').numpy() # [B]
131 | data_dict = {k: v.cuda() for k, v in data_dict.items()}
132 | if tta: # loaded data in shape [B, 4, N, ...]
133 | data_dict['img'] = data_dict['img'].flatten(0, 1)
134 | data_dict['valid_mask'] = data_dict['valid_mask'].flatten(0, 1)
135 | out_dict = model(data_dict)
136 | labels = data_dict['label'] # [B]
137 |
138 | # based on aggregated probs
139 | pred_probs = out_dict['probs']
140 | # aggregate probs from multi-view TTA predictions
141 | if tta:
142 | probs = pred_probs.unflatten(0, (-1, 4)) # [B, 4, n_cls]
143 | tta_mask = torch.ones_like(labels).bool() # [B]
144 | # predictions over 4 views should be consistent
145 | if args.tta_consistent:
146 | pred_cls = probs.argmax(dim=-1) # [B, 4]
147 | tta_mask &= (pred_cls[:, 0] == pred_cls[:, 1]) & \
148 | (pred_cls[:, 0] == pred_cls[:, 2]) & \
149 | (pred_cls[:, 0] == pred_cls[:, 3])
150 | # the minimum confidence should be larger than conf_thresh
151 | if args.tta_min_prob:
152 | min_probs = probs.max(-1).values.min(-1).values
153 | tta_mask &= (min_probs > conf_thresh)
154 | probs = probs.mean(dim=1) # [B, n_cls]
155 | else:
156 | probs = pred_probs
157 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean().item()
158 | all_acc_meter.update(probs_acc, labels.shape[0])
159 |
160 | # only trust probs > conf_thresh
161 | max_probs, pred_labels = probs.max(dim=-1)
162 | sel_mask = (max_probs > conf_thresh)
163 | if tta:
164 | sel_mask &= tta_mask
165 | # update class cnt
166 | for i, (lbl, pred_lbl) in enumerate(zip(labels, pred_labels)):
167 | ev_path = str(ev_dst.labeled_files[data_idx[i]])
168 | # skip labeled data as they will be included later anyway
169 | if get_folder_and_fn(ev_path) in gt_data_paths:
170 | continue
171 | pred_cls_name = class_names[pred_lbl.item()]
172 | if sel_mask[i].item():
173 | sel_class_cnt[pred_cls_name] += 1
174 | if pred_lbl.item() == lbl.item():
175 | sel_correct_class_cnt[pred_cls_name] += 1
176 | if sel_mask[i].item():
177 | if topk > 0: # also record the probs, take top-k later
178 | pred_path2cls[ev_path] = {
179 | 'cls': pred_cls_name,
180 | 'prob': max_probs[i].item(),
181 | }
182 | else:
183 | pred_path2cls[ev_path] = pred_cls_name
184 |
185 | print_stats(class_names, gt_class_cnt, sel_class_cnt,
186 | sel_correct_class_cnt)
187 | print(f'\n\nTesting {args.params}')
188 | if args.weight:
189 | print(f'Model weight: {args.weight}')
190 | print(f'\tProbs-based accuracy@1: {all_acc_meter.avg * 100.:.2f}%')
191 |
192 | if not save_path:
193 | return
194 | # save pseudo labels to a new dataset
195 | train_path = osp.join(save_path, 'extracted_train') if \
196 | is_nin else osp.join(save_path, 'training')
197 | assert not osp.exists(save_path), f'{save_path} already exists!'
198 | os.makedirs(train_path, exist_ok=True)
199 | # some classes might be renamed
200 | new_cnames = ev_dst.new_cnames
201 | # only take `topk` predictions for each class
202 | if topk > 0:
203 | topk_pred_path2cls, sel_class_cnt, sel_correct_class_cnt = {}, {}, {}
204 | for cls_name in class_names:
205 | sel_correct_class_cnt[cls_name] = 0
206 | cls_pred_paths, cls_pred_probs = [], []
207 | # find data that are classified as this class
208 | for path, pred in pred_path2cls.items():
209 | if pred['cls'] == cls_name:
210 | cls_pred_paths.append(path)
211 | cls_pred_probs.append(pred['prob'])
212 | cls_pred_probs = torch.tensor(cls_pred_probs)
213 | k = min(topk, cls_pred_probs.shape[0])
214 | _, topk_idx = cls_pred_probs.topk(k)
215 | for i in topk_idx:
216 | path = cls_pred_paths[i] # get the GT label from path
217 | gt_cls_name = osp.basename(osp.dirname(path))
218 | if is_nin:
219 | gt_cls_name = ev_dst.folder2name[gt_cls_name]
220 | if new_cnames is not None:
221 | gt_cls_name = new_cnames.get(gt_cls_name, gt_cls_name)
222 | if gt_cls_name == cls_name:
223 | sel_correct_class_cnt[cls_name] += 1
224 | topk_pred_path2cls[path] = cls_name
225 | sel_class_cnt[cls_name] = k
226 | pred_path2cls = topk_pred_path2cls
227 | print_stats(class_names, gt_class_cnt, sel_class_cnt,
228 | sel_correct_class_cnt)
229 | # update data with GT labels
230 | pred_path2cls.update(gt_data)
231 | # save pseudo labels
232 | for path, pred_cls in pred_path2cls.items():
233 | path = get_real_path(path)
234 | # path: xxx/N-Caltech101/training/airplanes/airplanes_150.npy
235 | # xxx/N_Imagenet/extracted_train/n02114855/n02114855_15515.npz
236 | # pred_cls is a semantic class name
237 | # some class names might have been altered
238 | if new_cnames is not None:
239 | ori_cls = find_key_from_value(new_cnames, pred_cls)
240 | if ori_cls is not None:
241 | pred_cls = ori_cls
242 | if is_nin:
243 | folder_name = ev_dst.name2folder[pred_cls]
244 | else:
245 | folder_name = pred_cls
246 | ev_name = osp.basename(path)
247 | # save to train_path/folder_name/ev_name
248 | # use soft link to save disk space
249 | new_path = osp.join(train_path, folder_name, ev_name)
250 | os.makedirs(osp.dirname(new_path), exist_ok=True)
251 | os.symlink(path, new_path)
252 | # also soft-link val/test set if they exist
253 | if is_nin:
254 | val_path = osp.join(save_path, 'extracted_val')
255 | ori_val_path = osp.join(osp.dirname(ev_dst.root), 'extracted_val')
256 | ori_val_path = get_real_path(ori_val_path)
257 | os.symlink(ori_val_path, val_path)
258 | else:
259 | val_path = osp.join(save_path, 'validation')
260 | test_path = osp.join(save_path, 'testing')
261 | ori_val_path = osp.join(osp.dirname(ev_dst.root), 'validation')
262 | ori_test_path = osp.join(osp.dirname(ev_dst.root), 'testing')
263 | ori_val_path = get_real_path(ori_val_path)
264 | ori_test_path = get_real_path(ori_test_path)
265 | os.symlink(ori_val_path, val_path)
266 | os.symlink(ori_test_path, test_path)
267 | print(f'\nSaved pseudo labels to {save_path}')
268 | # some classes don't have any pseudo labels
269 | # we still create the folder for consistency
270 | for k in class_names:
271 | if new_cnames is not None:
272 | ori_cls = find_key_from_value(new_cnames, k)
273 | if ori_cls is not None:
274 | k = ori_cls
275 | if is_nin:
276 | folder_name = ev_dst.name2folder[k]
277 | else:
278 | folder_name = k
279 | folder_path = osp.join(train_path, folder_name)
280 | os.makedirs(folder_path, exist_ok=True)
281 |
282 |
283 | if __name__ == "__main__":
284 | parser = argparse.ArgumentParser(description='EventCLIP')
285 | parser.add_argument('--params', type=str, required=True)
286 | parser.add_argument('--save_path', type=str, default='')
287 | parser.add_argument('--weight', type=str, default='', help='load weight')
288 | parser.add_argument('--conf_thresh', type=float, default=-1.)
289 | parser.add_argument('--tta', action='store_true')
290 | parser.add_argument('--tta_consistent', action='store_true')
291 | parser.add_argument('--tta_min_prob', action='store_true')
292 | parser.add_argument('--topk', type=int, default=-1)
293 | parser.add_argument('--gt_shots', type=int, default=-1) # labeled data
294 | args = parser.parse_args()
295 |
296 | if args.params.endswith('.py'):
297 | args.params = args.params[:-3]
298 | sys.path.append(osp.dirname(args.params))
299 | params = importlib.import_module(osp.basename(args.params))
300 | params = params.EventCLIPParams()
301 |
302 | # adjust params
303 | is_zs = (params.model == 'ZSCLIP')
304 | save_path = args.save_path
305 | if save_path:
306 | assert not osp.exists(save_path), f'{save_path} already exists!'
307 |
308 | main(params)
309 | exit(-1)
310 |
--------------------------------------------------------------------------------
/models/clip_cls.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | import clip
8 |
9 | from nerv.training import BaseModel
10 |
11 | from .adapter import IdentityAdapter, TransformerAdapter
12 |
13 |
14 | class ZSCLIPClassifier(BaseModel):
15 | """CLIP model for **zero-shot** classification."""
16 |
17 | def __init__(
18 | self,
19 | clip_dict=dict(
20 | clip_model=None,
21 | prompt='a point cloud image of a {}',
22 | class_names=None,
23 | agg_func='sum',
24 | ),
25 | loss_dict=dict(
26 | use_logits_loss=True,
27 | use_probs_loss=False,
28 | ),
29 | ):
30 | super().__init__()
31 |
32 | self.clip_dict = clip_dict
33 | self.loss_dict = loss_dict
34 |
35 | self._build_clip()
36 | self._build_loss()
37 |
38 | def _build_clip(self):
39 | # freeze the CLIP model
40 | model = self.clip_dict['clip_model']
41 | for p in model.parameters():
42 | p.requires_grad = False
43 | self.model = model.eval()
44 | self.logit_scale = model.logit_scale.exp().item()
45 |
46 | # text prompt for zero-shot cls
47 | self.prompt = self.clip_dict['prompt']
48 | self.class_names = self.clip_dict['class_names']
49 | self.text_feats = None
50 |
51 | # aggregation function
52 | self.agg_func = self.clip_dict['agg_func']
53 | assert self.agg_func in ['sum', 'mean', 'max']
54 |
55 | def _build_loss(self):
56 | self.use_logits_loss = self.loss_dict['use_logits_loss']
57 | self.use_probs_loss = self.loss_dict['use_probs_loss']
58 | assert int(self.use_logits_loss) + int(self.use_probs_loss) == 1
59 |
60 | def _same_class_names(self, class_names):
61 | """Check if the input `class_names` matches `self.class_names`."""
62 | return len(class_names) == len(self.class_names) and all(c1 == c2 for c1, c2 in zip(class_names, self.class_names))
63 |
64 | def get_text_feats(self, class_names=None):
65 | """Compute the text prompt features using CLIP text encoder."""
66 | # no `class_names` provided
67 | if class_names is None:
68 | no_cls_flag = True
69 | class_names = self.class_names
70 | # with cached `text_feats`
71 | if self.text_feats is not None:
72 | return self.text_feats
73 | else:
74 | no_cls_flag = False
75 | # `class_names` matches, with cached `text_feats`
76 | if self._same_class_names(class_names) and self.text_feats is not None:
77 | return self.text_feats
78 |
79 | # compute the text prompt features
80 | class_names = [c.lower().replace('_', ' ') for c in class_names]
81 | prompts = torch.cat([
82 | clip.tokenize(self.prompt.format(c)) for c in class_names
83 | ]).to(self.device)
84 | text_feats = self.model.encode_text(prompts)
85 | text_feats = F.normalize(text_feats, p=2, dim=-1)
86 |
87 | # cache the `text_feats` if
88 | # 1) the `class_names` matches
89 | # 2) the `class_names` is not provided
90 | if no_cls_flag or self._same_class_names(class_names):
91 | self.text_feats = text_feats
92 |
93 | return text_feats # [n_classes, C]
94 |
95 | def get_img_feats(self, imgs):
96 | """Compute the image features using CLIP image encoder.
97 |
98 | Args:
99 | imgs (torch.Tensor): [B, C, H, W]
100 | """
101 | img_feats = self.model.encode_image(imgs)
102 | return img_feats # [B, C]
103 |
104 | def _aggregate_logits(self, logits, valid_masks):
105 | """Aggregate logits for each data.
106 |
107 | Args:
108 | logits (torch.Tensor): [B, T, n_classes]
109 | valid_masks (torch.Tensor): [B, T]
110 | """
111 | if self.agg_func == 'sum':
112 | logits = logits.sum(1)
113 | elif self.agg_func == 'mean':
114 | logits = logits.sum(1) / valid_masks.float().sum(1, keepdim=True)
115 | elif self.agg_func == 'max':
116 | # make invalid logits very small
117 | logits = logits - (1. - valid_masks.float())[..., None] * 1e6
118 | logits = logits.max(1)[0]
119 | else:
120 | raise NotImplementedError
121 | return logits
122 |
123 | def _aggregate_probs(self, logits, valid_masks):
124 | """This one always take the mean."""
125 | valid_masks = valid_masks.detach().float()
126 | probs = logits.softmax(dim=-1)
127 | probs = probs * valid_masks[..., None]
128 | probs = probs.sum(1) / valid_masks.sum(1, keepdim=True)
129 | return probs
130 |
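# Illustration of the two aggregation helpers above: with valid_masks = [[True, False]],
# the zero-filled logits of the invalid view leave 'sum' unchanged, 'mean' divides by
# the number of valid views (1 here), and 'max' pushes invalid views down by 1e6
# before taking the per-class max; `_aggregate_probs` averages the per-view softmax
# probs over valid views only.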
131 | def forward(self, data_dict):
132 | """Forward function."""
133 | imgs = data_dict['img'] # [B, T, C, H, W]
134 | valid_masks = data_dict['valid_mask'] # [B, T]
135 | B, T = valid_masks.shape
136 |
137 | # compute image features
138 | with torch.no_grad():
139 | valid_imgs = imgs[valid_masks] # [N, C, H, W]
140 | img_feats = self.get_img_feats(valid_imgs) # [N, C]
141 |
142 | # compute text features
143 | with torch.no_grad():
144 | text_feats = self.get_text_feats() # [n_classes, C]
145 | n_cls = text_feats.shape[0]
146 |
147 | # compute logits
148 | logits = (self.logit_scale * img_feats @ text_feats.T) # [N, n_cls]
149 |
150 | # map logits back to [B, T, n_cls]
151 | full_logits = torch.zeros(B, T, n_cls).type_as(logits)
152 | full_logits[valid_masks] = logits
153 | logits = self._aggregate_logits(full_logits, valid_masks)
154 | probs = self._aggregate_probs(full_logits, valid_masks)
155 |
156 | out_dict = {
157 | 'full_logits': full_logits, # [B, T, n_classes]
158 | 'valid_masks': valid_masks, # [B, T]
159 | 'logits': logits, # [B, n_classes]
160 | 'probs': probs, # [B, n_classes]
161 | }
162 | return out_dict
163 |
164 | def calc_train_loss(self, data_dict, out_dict):
165 | """Compute training loss."""
166 | labels = data_dict['label'] # [B]
167 | logits = out_dict['logits'] # [B, n_classes]
168 | probs = out_dict['probs'] # [B, n_classes]
169 | loss_dict = {}
170 | if self.use_logits_loss:
171 | loss_dict['ce_loss'] = F.cross_entropy(logits, labels)
172 | if self.use_probs_loss:
173 | probs = probs + 1e-6 # avoid nan
174 | loss_dict['ce_loss'] = F.nll_loss(probs.log(), labels)
175 | return loss_dict
176 |
177 | @torch.no_grad()
178 | def calc_eval_loss(self, data_dict, out_dict):
179 | """Loss computation in eval."""
180 | loss_dict = self.calc_train_loss(data_dict, out_dict)
181 |
182 | # also compute the cls accuracy
183 | labels = data_dict['label'] # [B]
184 | # based on aggregated probs
185 | probs = out_dict['probs'] # [B, n_classes]
186 | probs_acc = (probs.argmax(dim=-1) == labels).float().mean()
187 | loss_dict['probs_acc'] = probs_acc
188 | # based on aggregated logits
189 | logits = out_dict['logits'] # [B, n_classes]
190 | logits_acc = (logits.argmax(dim=-1) == labels).float().mean()
191 | loss_dict['logits_acc'] = logits_acc
192 | return loss_dict
193 |
194 | @property
195 | def dtype(self):
196 | return self.model.logit_scale.dtype
197 |
198 | @property
199 | def device(self):
200 | return self.model.logit_scale.device
201 |
202 | def train(self, mode=True):
203 | nn.Module.train(self, mode)
204 | # keep CLIP in eval mode
205 | self.model.eval()
206 | return self
207 |
208 | def state_dict(self):
209 | """Remove CLIP weight from the state dict."""
210 | w = super().state_dict()
211 | w = {k: v for k, v in w.items() if not k.startswith('model.')}
212 | return w
213 |
214 | def load_state_dict(self, state_dict, strict=True):
215 | """Don't load CLIP weight from the state dict."""
216 | # load CLIP weight from the state dict
217 | clip_w = {f'model.{k}': v for k, v in self.model.state_dict().items()}
218 | state_dict = {**clip_w, **state_dict}
219 | super().load_state_dict(state_dict, strict=strict)
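# net effect: checkpoints saved by `state_dict()` above exclude the CLIP weights
# (for the few-shot subclass below they only hold adapter / prompt weights);
# the frozen CLIP weights are re-filled here from the freshly constructed CLIP model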
220 |
221 |
222 | class FSCLIPClassifier(ZSCLIPClassifier):
223 | """CLIP model for **few-shot** classification."""
224 |
225 | def __init__(
226 | self,
227 | adapter_dict=dict(
228 | # image adapter: 'trans' (Transformer) or 'identity'
229 | # 'text-{}' prefix: also tune the text features as FC weights
230 | adapter_type='trans',
231 | residual=True,
232 | ),
233 | clip_dict=dict(
234 | clip_model=None,
235 | prompt='a point cloud image of a {}',
236 | class_names=None,
237 | agg_func='sum',
238 | ),
239 | loss_dict=dict(
240 | use_logits_loss=False,
241 | use_probs_loss=True,
242 | ),
243 | ):
244 | super().__init__(
245 | clip_dict=clip_dict,
246 | loss_dict=loss_dict,
247 | )
248 |
249 | self.adapter_dict = copy.deepcopy(adapter_dict)
250 |
251 | self._build_adapter()
252 |
253 | def _build_prompts(self, adapter_type):
254 | """Build the text features for prompt tuning."""
255 | with torch.no_grad():
256 | text_feats = super().get_text_feats().float() # [n_classes, C]
257 | self.text_feats = nn.Parameter(text_feats, requires_grad=True)
258 | adapter_type = adapter_type[5:]
259 | return adapter_type
260 |
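# Note: `_build_prompts` strips the 'text-' prefix, so e.g. 'text-identity' means
# learnable text features with the identity image adapter, and 'text-trans' means
# learnable text features with the Transformer image adapter.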
261 | def _build_adapter(self):
262 | # whether to tune the text features as well
263 | adapter_type = self.adapter_dict.pop('adapter_type').lower()
264 | if adapter_type.startswith('text-'):
265 | print('Tune text features as well!')
266 | self.prompt_tuning = True
267 | adapter_type = self._build_prompts(adapter_type)
268 | else:
269 | self.prompt_tuning = False
270 |
271 | # image feature adapter
272 | self.adapter_type = adapter_type
273 | if adapter_type == 'identity': # not tuning image features
274 | model = IdentityAdapter
275 | elif adapter_type == 'trans': # Transformer to fuse image features
276 | model = TransformerAdapter
277 | else:
278 | raise NotImplementedError(f'adapter {adapter_type} not supported!')
279 | self.adapter = model(**self.adapter_dict)
280 |
281 | def _adjust_dtype(self, x):
282 | """CLIP model returns features in FP16.
283 | During training, torch.amp will help us handle this.
284 | However, during inference, we need to manually convert them to FP32.
285 | """
286 | if self.training:
287 | return x
288 | return x.type(self.dtype)
289 |
290 | def get_text_feats(self, class_names=None):
291 | # finetune the text features (i.e. prompt tuning)
292 | if self.prompt_tuning:
293 | assert self.text_feats.requires_grad or not self.training, \
294 | 'prompt should be trainable!'
295 | text_feats = F.normalize(self.text_feats, p=2, dim=-1)
296 | # otherwise, we use fixed text features
297 | else:
298 | with torch.no_grad():
299 | text_feats = super().get_text_feats(class_names)
300 | return self._adjust_dtype(text_feats)
301 |
302 | def get_img_feats(self, imgs):
303 | # fixed CLIP image features
304 | with torch.no_grad():
305 | img_feats = super().get_img_feats(imgs)
306 | return self._adjust_dtype(img_feats)
307 |
308 | def forward(self, data_dict):
309 | """Forward function."""
310 | imgs = data_dict['img'] # [B, T, C, H, W], `T` is number of views
311 | valid_masks = data_dict['valid_mask'] # [B, T]
312 | B, T = valid_masks.shape
313 |
314 | # compute image features
315 | valid_imgs = imgs[valid_masks] # [N, C, H, W]
316 | img_feats = self.get_img_feats(valid_imgs) # [N, C]
317 |
318 | # update image features using adapter
319 | C = img_feats.shape[-1]
320 | full_img_feats = torch.zeros(B, T, C).type_as(img_feats)
321 | full_img_feats[valid_masks] = img_feats
322 | full_img_feats = self.adapter(full_img_feats, valid_masks)
323 | # [B, T, C], multi-view image features
324 | # normalize the output features
325 | # an all-zeros vector stays all zeros after F.normalize()
326 | full_img_feats = F.normalize(
327 | full_img_feats, p=2, dim=-1).type_as(full_img_feats)
328 | # make invalid features zeros
329 | full_img_feats = full_img_feats * valid_masks.float().unsqueeze(-1)
330 |
331 | # compute text features
332 | # we may need to compute gradients w.r.t. text features
333 | # so we can't use torch.no_grad() here
334 | text_feats = self.get_text_feats() # [n_classes, C]
335 |
336 | # compute logits
337 | full_logits = (self.logit_scale * full_img_feats @ text_feats.T)
338 | # [B, T, n_cls], multi-view logits
339 |
340 | # convert to [B, n_cls] for loss computation!
341 | logits = self._aggregate_logits(full_logits, valid_masks)
342 | probs = self._aggregate_probs(full_logits, valid_masks)
343 |
344 | out_dict = {
345 | 'full_logits': full_logits, # [B, T, n_classes]
346 | 'valid_masks': valid_masks, # [B, T]
347 | 'logits': logits, # [B, n_classes]
348 | 'probs': probs, # [B, n_classes]
349 | }
350 | return out_dict
351 |
352 | @property
353 | def dtype(self):
354 | return self.adapter.dtype
355 |
--------------------------------------------------------------------------------
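For reference, a minimal zero-shot inference sketch with the `ZSCLIPClassifier` above (illustrative only: the class names, tensor shapes, and CUDA device are assumptions, and the `clip` / `nerv` packages from the repo's environment must be installed):

import torch
import clip

from models.clip_cls import ZSCLIPClassifier

# frozen CLIP backbone (fp16 on GPU)
clip_model, _ = clip.load('ViT-B/32', device='cuda')

model = ZSCLIPClassifier(
    clip_dict=dict(
        clip_model=clip_model,
        prompt='a point cloud image of a {}',
        class_names=['car', 'background'],  # illustrative class list
        agg_func='mean',
    ),
    loss_dict=dict(use_logits_loss=True, use_probs_loss=False),
).cuda().eval()

B, T = 2, 3  # 2 samples, 3 event-to-image views each
data_dict = {
    'img': torch.randn(B, T, 3, 224, 224, device='cuda'),
    'valid_mask': torch.ones(B, T, dtype=torch.bool, device='cuda'),
}
with torch.no_grad():
    out = model(data_dict)
print(out['probs'].shape)  # torch.Size([2, 2]), i.e. [B, n_classes]
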
/models/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.nn import MultiheadAttention
6 |
7 |
8 | def lora_w_init_(lora_down, lora_up, r):
9 | """Initialize the LoRA weights."""
10 | nn.init.normal_(lora_down, std=1. / r)
11 | nn.init.zeros_(lora_up)
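# `lora_up` starts at zero, so the initial update `lora_up @ lora_down` is exactly
# zero and the injected layer reproduces the frozen pretrained weights at step 0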
12 |
13 |
14 | class LoraInjectedLinear(nn.Module):
15 | """Apply LoRA to a standard Linear layer."""
16 |
17 | def __init__(self, linear, r=4):
18 | super().__init__()
19 |
20 | in_features = linear.in_features
21 | out_features = linear.out_features
22 | if r > min(in_features, out_features):
23 | raise ValueError(
24 | f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
25 | )
26 |
27 | self.r = r
28 | self.linear = linear
29 | for p in self.linear.parameters():
30 | p.requires_grad = False
31 | self.lora_down = nn.Linear(in_features, r, bias=False)
32 | self.lora_up = nn.Linear(r, out_features, bias=False)
33 |
34 | lora_w_init_(self.lora_down.weight, self.lora_up.weight, r)
35 |
36 | def forward(self, input):
37 | return (self.linear(input) + self.lora_up(self.lora_down(input)))
38 |
39 | @property
40 | def dtype(self):
41 | """Return the dtype of the projection weight."""
42 | return self.linear.weight.dtype
43 |
44 | @property
45 | def device(self):
46 | """Return the device of the projection weight."""
47 | return self.linear.weight.device
48 |
49 | @property
50 | def weight(self):
51 | """Return the LoRA adjusted weight."""
52 | return self.linear.weight + self.lora_up.weight @ self.lora_down.weight
53 |
54 | @property
55 | def bias(self):
56 | """Return the bias of the projection layer."""
57 | return self.linear.bias
58 |
59 |
60 | class LoraInjectedProj(nn.Module):
61 | """Apply LoRA on the projection head in MultiHeadAttention.
62 |
63 | The Q/K/V projection weight W is nn.Parameter(d_model, in_dim).
64 | We learn a lora_down [r, in_dim] and a lora_up [d_model, r].
65 | """
66 |
67 | def __init__(self, proj, r=4):
68 | super().__init__()
69 |
70 | d_model, in_dim = proj.shape
71 | if r > min(d_model, in_dim):
72 | raise ValueError(
73 | f"LoRA rank {r} must be less or equal than {min(d_model, in_dim)}"
74 | )
75 |
76 | self.d_model = d_model
77 | self.r = r
78 | self.proj = proj # original projection weight, nn.Parameter
79 | self.proj.requires_grad = False # freeze the original weight
80 |
81 | self.lora_down = nn.Parameter(torch.empty(r, in_dim))
82 | self.lora_up = nn.Parameter(torch.empty(d_model, r))
83 |
84 | lora_w_init_(self.lora_down, self.lora_up, r)
85 |
86 | def forward(self):
87 | """Return the LoRA updated weight."""
88 | return self.proj + self.lora_up @ self.lora_down
89 |
90 | @property
91 | def dtype(self):
92 | """Return the dtype of the projection weight."""
93 | return self.proj.dtype
94 |
95 | @property
96 | def device(self):
97 | """Return the device of the projection weight."""
98 | return self.proj.device
99 |
100 |
101 | class LoraInjectedMergedProj(nn.Module):
102 | """Apply LoRA on the **merged** projection head in MultiHeadAttention.
103 |
104 | The merged projection weight W is nn.Parameter(3*d_model, in_dim).
105 | We learn a (lora_down [r, in_dim], lora_up [d_model, r]) pair for Q, V, and optionally K.
106 | """
107 |
108 | def __init__(self, merged_proj, r=4, lora_k=True):
109 | super().__init__()
110 |
111 | d_model_3, in_dim = merged_proj.shape
112 | assert d_model_3 % 3 == 0, "MergedProj's dim must be divisible by 3"
113 | d_model = d_model_3 // 3
114 | if r > min(d_model, in_dim):
115 | raise ValueError(
116 | f"LoRA rank {r} must be less or equal than {min(d_model, in_dim)}"
117 | )
118 |
119 | self.d_model = d_model
120 | self.r = r
121 | self.merged_proj = merged_proj # original projection weight
122 | self.merged_proj.requires_grad = False # freeze the original weight
123 |
124 | self.lora_down_q = nn.Parameter(torch.empty(r, in_dim))
125 | self.lora_up_q = nn.Parameter(torch.empty(d_model, r))
126 | self.lora_down_v = nn.Parameter(torch.empty(r, in_dim))
127 | self.lora_up_v = nn.Parameter(torch.empty(d_model, r))
128 |
129 | lora_w_init_(self.lora_down_q, self.lora_up_q, r)
130 | lora_w_init_(self.lora_down_v, self.lora_up_v, r)
131 |
132 | self.lora_k = lora_k
133 | if lora_k:
134 | self.lora_down_k = nn.Parameter(torch.empty(r, in_dim))
135 | self.lora_up_k = nn.Parameter(torch.empty(d_model, r))
136 | lora_w_init_(self.lora_down_k, self.lora_up_k, r)
137 |
138 | def forward(self):
139 | """Return the LoRA updated weight."""
140 | return torch.cat([
141 | self.merged_proj[:self.d_model] +
142 | self.lora_up_q @ self.lora_down_q,
143 | self.merged_proj[self.d_model:2 * self.d_model] +
144 | self.lora_up_k @ self.lora_down_k if self.lora_k else
145 | self.merged_proj[self.d_model:2 * self.d_model],
146 | self.merged_proj[2 * self.d_model:] +
147 | self.lora_up_v @ self.lora_down_v,
148 | ],
149 | dim=0)
150 |
151 | @property
152 | def dtype(self):
153 | """Return the dtype of the projection weight."""
154 | return self.merged_proj.dtype
155 |
156 | @property
157 | def device(self):
158 | """Return the device of the projection weight."""
159 | return self.merged_proj.device
160 |
161 |
162 | class LoraInjectedMHA(MultiheadAttention):
163 | """MultiHeadAttention with LoRA fine-tuning."""
164 |
165 | def forward(
166 | self,
167 | query,
168 | key,
169 | value,
170 | key_padding_mask=None,
171 | need_weights=True,
172 | attn_mask=None,
173 | average_attn_weights=True,
174 | ):
175 | is_batched = query.dim() == 3
176 | why_not_fast_path = ''
177 | if not is_batched:
178 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
179 | elif query is not key or key is not value:
180 | # When lifting this restriction, don't forget to either
181 | # enforce that the dtypes all match or test cases where
182 | # they don't!
183 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
184 | elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
185 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
186 | elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
187 | # this case will fail anyway, but at least they'll get a useful error message.
188 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
189 | elif self.training:
190 | why_not_fast_path = "training is enabled"
191 | elif not self.batch_first:
192 | why_not_fast_path = "batch_first was not True"
193 | elif self.bias_k is not None:
194 | why_not_fast_path = "self.bias_k was not None"
195 | elif self.bias_v is not None:
196 | why_not_fast_path = "self.bias_v was not None"
197 | elif self.dropout:
198 | why_not_fast_path = f"dropout was {self.dropout}, required zero"
199 | elif self.add_zero_attn:
200 | why_not_fast_path = "add_zero_attn was enabled"
201 | elif not self._qkv_same_embed_dim:
202 | why_not_fast_path = "_qkv_same_embed_dim was not True"
203 | elif query.is_nested and (key_padding_mask is not None
204 | or attn_mask is not None):
205 | why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
206 | elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
207 | why_not_fast_path = "key_padding_mask and attn_mask were both supplied"
208 |
209 | if not why_not_fast_path:
210 | tensor_args = (
211 | query,
212 | key,
213 | value,
214 | self.in_proj_weight(),
215 | self.in_proj_bias,
216 | self.out_proj.weight,
217 | self.out_proj.bias,
218 | )
219 | # We have to use list comprehensions below because TorchScript does not support
220 | # generator expressions.
221 | if torch.overrides.has_torch_function(tensor_args):
222 | why_not_fast_path = "some Tensor argument has_torch_function"
223 | elif not all([(x.is_cuda or 'cpu' in str(x.device))
224 | for x in tensor_args]):
225 | why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
226 | elif torch.is_grad_enabled() and any(
227 | [x.requires_grad for x in tensor_args]):
228 | why_not_fast_path = (
229 | "grad is enabled and at least one of query or the "
230 | "input/output projection weights or biases requires_grad")
231 | if not why_not_fast_path:
232 | return torch._native_multi_head_attention(
233 | query, key, value, self.embed_dim, self.num_heads,
234 | self.in_proj_weight(), self.in_proj_bias,
235 | self.out_proj.weight, self.out_proj.bias, key_padding_mask
236 | if key_padding_mask is not None else attn_mask,
237 | need_weights, average_attn_weights)
238 | any_nested = query.is_nested or key.is_nested or value.is_nested
239 | assert not any_nested, (
240 | "MultiheadAttention does not support NestedTensor outside of its fast path. "
241 | + f"The fast path was not hit because {why_not_fast_path}")
242 |
243 | if self.batch_first and is_batched:
244 | # make sure that the transpose op does not affect the "is" property
245 | if key is value:
246 | if query is key:
247 | query = key = value = query.transpose(1, 0)
248 | else:
249 | query, key = [x.transpose(1, 0) for x in (query, key)]
250 | value = key
251 | else:
252 | query, key, value = [
253 | x.transpose(1, 0) for x in (query, key, value)
254 | ]
255 |
256 | if not self._qkv_same_embed_dim:
257 | attn_output, attn_output_weights = F.multi_head_attention_forward(
258 | query,
259 | key,
260 | value,
261 | self.embed_dim,
262 | self.num_heads,
263 | self.in_proj_weight,
264 | self.in_proj_bias,
265 | self.bias_k,
266 | self.bias_v,
267 | self.add_zero_attn,
268 | self.dropout,
269 | self.out_proj.weight,
270 | self.out_proj.bias,
271 | training=self.training,
272 | key_padding_mask=key_padding_mask,
273 | need_weights=need_weights,
274 | attn_mask=attn_mask,
275 | use_separate_proj_weight=True,
276 | q_proj_weight=self.q_proj_weight(),
277 | k_proj_weight=self.k_proj_weight(),
278 | v_proj_weight=self.v_proj_weight(),
279 | average_attn_weights=average_attn_weights)
280 | else:
281 | attn_output, attn_output_weights = F.multi_head_attention_forward(
282 | query,
283 | key,
284 | value,
285 | self.embed_dim,
286 | self.num_heads,
287 | self.in_proj_weight(),
288 | self.in_proj_bias,
289 | self.bias_k,
290 | self.bias_v,
291 | self.add_zero_attn,
292 | self.dropout,
293 | self.out_proj.weight,
294 | self.out_proj.bias,
295 | training=self.training,
296 | key_padding_mask=key_padding_mask,
297 | need_weights=need_weights,
298 | attn_mask=attn_mask,
299 | average_attn_weights=average_attn_weights)
300 | if self.batch_first and is_batched:
301 | return attn_output.transpose(1, 0), attn_output_weights
302 | else:
303 | return attn_output, attn_output_weights
304 |
305 | def __setattr__(self, name, value):
306 | # special-case hack: allow replacing a registered nn.Parameter (e.g. `in_proj_weight`) with a LoRA nn.Module
307 | def remove_from(*dicts_or_sets):
308 | for d in dicts_or_sets:
309 | if name in d:
310 | if isinstance(d, dict):
311 | del d[name]
312 | else:
313 | d.discard(name)
314 |
315 | try:
316 | super().__setattr__(name, value)
317 | return
318 | except TypeError:
319 | pass
320 |
321 | assert isinstance(getattr(self, name), nn.Parameter) and isinstance(
322 | value, nn.Module)
323 | remove_from(self.__dict__, self._buffers, self._parameters,
324 | self._modules, self._non_persistent_buffers_set)
325 | modules = self.__dict__.get('_modules')
326 | modules[name] = value
327 |
328 |
329 | def build_lora_proj(proj, r):
330 | """Given an old projection, build a LoRA projection."""
331 | lora_proj = LoraInjectedProj(proj=proj, r=r)
332 | return lora_proj
333 |
334 |
335 | def build_lora_merged_proj(merged_proj, r, lora_k=True):
336 | """Given an old merged projection, build a LoRA merged projection."""
337 | lora_merged_proj = LoraInjectedMergedProj(
338 | merged_proj=merged_proj, r=r, lora_k=lora_k)
339 | return lora_merged_proj
340 |
341 |
342 | def build_lora_mha(mha, r):
343 | """Given an old MHA, build a LoRA-MHA."""
344 | lora_mha = LoraInjectedMHA(
345 | embed_dim=mha.embed_dim,
346 | num_heads=mha.num_heads,
347 | dropout=mha.dropout,
348 | kdim=mha.kdim,
349 | vdim=mha.vdim,
350 | batch_first=mha.batch_first,
351 | device=mha.out_proj.weight.device,
352 | dtype=mha.out_proj.weight.dtype,
353 | )
354 | lora_mha.load_state_dict(mha.state_dict())
355 | for p in lora_mha.parameters():
356 | p.requires_grad = False
357 | # parse LoRA arguments
358 | # if it's an int, then it's the dim
359 | # otherwise can be 'qv-$DIM', 'qkv-$DIM', 'qkvo-$DIM'
360 | # by default we apply LoRA to q,k,v projections, no output projection
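# e.g. r=4 -> rank-4 LoRA on q, k, v; r='qv-16' -> rank-16 LoRA on q and v only;
# r='qkvo-4' -> rank-4 LoRA on q, k, v and the output projection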
361 | if not isinstance(r, int):
362 | assert 'q' in r and 'v' in r
363 | lora_k = ('k' in r)
364 | lora_o = ('o' in r)
365 | r = int(r.split('-')[-1])
366 | else:
367 | lora_k = True
368 | lora_o = False
369 | assert r > 0
370 | # inject LoRA to projection head weights
371 | if mha._qkv_same_embed_dim: # replace `in_proj_weight`
372 | lora_mha.in_proj_weight = build_lora_merged_proj(
373 | mha.in_proj_weight, r=r, lora_k=lora_k)
374 | else: # replace `q_proj_weight`, `k_proj_weight`, `v_proj_weight`
375 | lora_mha.q_proj_weight = build_lora_proj(mha.q_proj_weight, r)
376 | lora_mha.v_proj_weight = build_lora_proj(mha.v_proj_weight, r)
377 | if lora_k:
378 | lora_mha.k_proj_weight = build_lora_proj(mha.k_proj_weight, r)
379 | # inject LoRA to output projection head weights
380 | if lora_o:
381 | lora_mha.out_proj = LoraInjectedLinear(mha.out_proj, r=r)
382 | return lora_mha
383 |
384 |
385 | def inject_trainable_lora(model, r=4):
386 | """Replace all the MHA in `model` with LoRA-MHA."""
387 | to_replace = []
388 | for name, module in model.named_modules():
389 | if isinstance(module, nn.MultiheadAttention):
390 | lora_mha = build_lora_mha(module, r)
391 | to_replace.append((name, lora_mha))
392 | for name, lora_mha in to_replace:
393 | # cannot directly do `setattr(model, name, lora_mha)`
394 | # name like `transformer.resblocks.23.attn`, containing `.` and numbers
395 | # workaround: make it `transformer.resblocks[23].attn`
396 | # then use `eval()`
397 | for s in name.split('.'):
398 | if s.isdigit():
399 | name = name.replace(f'.{s}', f'[{s}]')
400 | assert name[-4:] == 'attn'
401 | m = eval(f'model.{name[:-5]}')
402 | m.attn = lora_mha
403 | return model
404 |
--------------------------------------------------------------------------------