├── .DS_Store
├── LICENSE
├── README.md
├── pic
│   ├── Figure1.png
│   ├── Table1.png
│   └── Table8-11.png
├── requirements.txt
└── src
    ├── .DS_Store
    ├── LE.py
    ├── classifier-tuning
    │   ├── .DS_Store
    │   ├── clip
    │   │   ├── README.md
    │   │   ├── bpe_simple_vocab_16e6.txt.gz
    │   │   ├── clip.py
    │   │   ├── model.py
    │   │   └── tokenizer.py
    │   └── src
    │       ├── .DS_Store
    │       ├── args.py
    │       ├── ct_fsl.py
    │       ├── ct_zsl.py
    │       ├── datasets
    │       │   ├── .DS_Store
    │       │   ├── __init__.py
    │       │   ├── cifar10.py
    │       │   ├── cifar100.py
    │       │   ├── common.py
    │       │   ├── fmow.py
    │       │   ├── imagenet.py
    │       │   ├── imagenet_a.py
    │       │   ├── imagenet_classnames.py
    │       │   ├── imagenet_r.py
    │       │   ├── imagenet_sketch.py
    │       │   ├── imagenet_vid_robust.py
    │       │   ├── imagenetv2.py
    │       │   ├── iwildcam.py
    │       │   ├── iwildcam_metadata
    │       │   │   └── labels.csv
    │       │   ├── objectnet.py
    │       │   ├── objectnet_metadata
    │       │   │   ├── folder_to_objectnet_label.json
    │       │   │   ├── imagenet_to_label_2012_v2
    │       │   │   ├── objectnet_to_imagenet_1k.json
    │       │   │   └── pytorch_to_imagenet_2012_id.json
    │       │   ├── transfer_ds
    │       │   │   ├── Randaug.py
    │       │   │   ├── __init__.py
    │       │   │   ├── aircraft.py
    │       │   │   ├── cal_mean_std.py
    │       │   │   ├── caltech.py
    │       │   │   ├── constants.py
    │       │   │   ├── cub.py
    │       │   │   ├── cub_for_robust_codebase.py
    │       │   │   ├── cub_transform.py
    │       │   │   ├── dtd.py
    │       │   │   ├── fine_tunify.py
    │       │   │   ├── food_101.py
    │       │   │   ├── imbalance_cifar.py
    │       │   │   ├── process_dataset
    │       │   │   │   ├── pro_aircraft.py
    │       │   │   │   ├── pro_caltech101.py
    │       │   │   │   ├── pro_cars.py
    │       │   │   │   ├── pro_flowers.py
    │       │   │   │   ├── pro_imgnet.py
    │       │   │   │   ├── pro_pool15.py
    │       │   │   │   └── process_pets.py
    │       │   │   ├── transfer_datasets.py
    │       │   │   ├── transform_ckpt.py
    │       │   │   └── utils.py
    │       │   ├── ytbb-robust_metadata
    │       │   │   ├── anchor_labels.json
    │       │   │   ├── class_idx_map.json
    │       │   │   ├── pmk_labels.json
    │       │   │   ├── rev_class_idx_map.json
    │       │   │   ├── ytbb_class_index.json
    │       │   │   └── ytbb_robustness_test_anchors_full.csv
    │       │   └── ytbb_robust.py
    │       ├── get_classifier_weights.py
    │       ├── models
    │       │   ├── .DS_Store
    │       │   ├── __init__.py
    │       │   ├── eval.py
    │       │   ├── finetune.py
    │       │   ├── modeling.py
    │       │   ├── utils.py
    │       │   └── zeroshot.py
    │       ├── select_glide_ims_by_clip.py
    │       └── templates
    │           ├── __init__.py
    │           ├── fmow_template.py
    │           ├── iwildcam_template.py
    │           ├── openai_imagenet_template.py
    │           ├── simple_template.py
    │           ├── transfer_ds_template.py
    │           └── utils.py
    └── glide
        ├── .DS_Store
        ├── gen_fsl.sh
        ├── gen_zsl.sh
        ├── glide_fsl.py
        ├── glide_zsl.py
        └── glide_text2im
            ├── .gitignore
            ├── __init__.py
            ├── clip
            │   ├── __init__.py
            │   ├── attention.py
            │   ├── config.yaml
            │   ├── encoders.py
            │   ├── model_creation.py
            │   └── utils.py
            ├── download.py
            ├── fp16_util.py
            ├── gaussian_diffusion.py
            ├── model_creation.py
            ├── nn.py
            ├── respace.py
            ├── text2im_model.py
            ├── tokenizer
            │   ├── __init__.py
            │   ├── bpe.py
            │   ├── bpe_simple_vocab_16e6.txt.gz
            │   ├── encoder.json.gz
            │   ├── simple_tokenizer.py
            │   └── vocab.bpe.gz
            ├── unet.py
            └── xf.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Is synthetic data from generative models ready for image recognition?
2 |
3 |
4 |
5 |
6 |
7 |
8 | Is synthetic data from generative models ready for image recognition? (ICLR 2023, Spotlight)
9 | By
10 | Ruifei He,
11 | Shuyang Sun,
12 | Xin Yu,
13 | Chuhui Xue,
14 | Wenqing Zhang,
15 | Philip Torr,
16 | Song Bai,
17 | Xiaojuan Qi.
18 |
19 |
20 |
21 |
22 | ## Abstract
23 |
24 | Recent text-to-image generation models have shown promising results in generating high-fidelity photo-realistic images. Though the results are astonishing to human eyes, how applicable these generated images are for recognition tasks remains under-explored. In this work, we extensively study whether and how synthetic images generated from state-of-the-art text-to-image generation models can be used for image recognition tasks, and focus on two perspectives: synthetic data for improving classification models in data-scarce settings (i.e., zero-shot and few-shot), and synthetic data for large-scale model pre-training for transfer learning. We showcase the strengths and shortcomings of synthetic data from existing generative models, and propose strategies for better applying synthetic data for recognition tasks.
25 |
26 | 
27 |
28 | 
29 |
30 | 
31 |
32 |
33 |
34 | ## Getting started
35 |
36 | 1. Clone our repo: `git clone https://github.com/CVMI-Lab/SyntheticData.git`
37 |
38 | 2. Install dependencies:
39 | ```sh
40 | conda create -n SyntheticData python=3.7
41 | conda activate SyntheticData
42 | pip install -r requirements.txt
43 | ```
44 |
45 |
46 |
47 | ## Zero-shot settings
48 |
49 | ### Synthetic data generation
50 |
51 | #### Language Enhancement
52 |
53 | We generate sentences from label names of a specific dataset and save the generated sentences offline.
54 |
55 | Set the target label space in the variable `labels` in `src/LE.py`, then run:
56 |
57 | ```sh
58 | python3.7 src/LE.py 200 /path/to/save/dataset.pkl
59 | ```
60 |
61 | where 200 is the number of sentences per label, and the second argument is the save path for the generated sentences.
62 |
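The prompts are saved as a pickled dictionary mapping each label name to its list of generated sentences (see `src/LE.py`). A quick way to inspect the output:

```python
import pickle

# Load the prompts produced by language enhancement.
with open('/path/to/save/dataset.pkl', 'rb') as f:
    sentences = pickle.load(f)  # dict: {label name: [generated sentences]}

for label, prompts in sentences.items():
    print(label, len(prompts), prompts[:2])  # e.g. 200 prompts per label
```
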
63 | #### Text-to-Image generation
64 |
65 | We use [GLIDE](https://github.com/openai/glide-text2im) for text-to-image generation, and follow the [official instructions](https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb) for the generation process.
66 |
67 | We use text generated from language enhancement as prompts for the text-to-image generation.
68 |
69 | We provide a multi-GPU generation example in `src/glide/glide_zsl.py`; run it like:
70 |
71 | ```sh
72 | sh glide/gen_zsl.sh /path/to/save/dataset.pkl /path/to/save/dataset
73 | ```
74 |
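The sharding logic inside `src/glide/glide_zsl.py` is not reproduced here; the sketch below (with hypothetical `rank` / `world_size` arguments) only illustrates one simple way to split the language-enhanced prompts across GPUs so that each process generates a disjoint slice:

```python
import pickle

def shard_prompts(pkl_path, rank, world_size):
    """Return the (label, prompt) pairs assigned to one GPU process."""
    with open(pkl_path, 'rb') as f:
        sentences = pickle.load(f)  # {label: [prompts]} produced by src/LE.py
    pairs = [(label, p) for label, prompts in sentences.items() for p in prompts]
    return pairs[rank::world_size]  # strided split over processes

# e.g. the first of 8 GPU workers
work = shard_prompts('/path/to/save/dataset.pkl', rank=0, world_size=8)
```
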
75 | #### CLIP Filter
76 |
77 | We use CLIP to help filter out unreliable images:
78 |
79 | ```sh
80 | # under dir: classifier-tuning
81 | python3.7 src/select_glide_ims_by_clip.py /path/to/synthetic/dataset 10 # 10 is the number of classes for the given task
82 | ```
83 |
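`src/select_glide_ims_by_clip.py` is not reproduced here; conceptually, the filter scores each synthetic image with CLIP against the class names of the task and discards low-confidence samples. The sketch below only illustrates that idea — the prompt format, model choice, and threshold are placeholders, not the script's actual settings:

```python
import clip  # https://github.com/openai/CLIP
import torch
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-B/16', device=device)

def clip_score(image_path, classnames, target_idx):
    """CLIP probability that the image depicts its intended class."""
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    text = clip.tokenize([f'a photo of a {c}' for c in classnames]).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1)[0]
    return probs[target_idx].item()

# keep a synthetic image only if CLIP is reasonably confident about its class
keep = clip_score('synthetic/forest/0001.png', ['forest', 'river'], target_idx=0) > 0.5
```
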
84 | ### Synthetic data for ZSL: Classifier-Tuning with CLIP
85 |
86 | Our code is adapted from the [Wise-ft](https://github.com/mlfoundations/wise-ft) codebase. Here we provide an example for the EuroSAT dataset.
87 |
88 | "model" could choose "RN50"/"ViT-B/16".
89 |
90 | Note that you should download the validation/test data for each dataset and revise the paths in `src/classifier-tuning/src/datasets/transfer_ds/transfer_datasets.py`. In the command below, `${exp_name}` and `$now` are shell variables that you define yourself for the log path.
91 |
92 | ```sh
93 | python3.7 src/ct_zsl.py \
94 | --freeze-encoder \
95 | --sl=0.5 \
96 | --sl_T=2 \
97 | --train-dataset=Eurosat \
98 | --save=/path/to/save/results \
99 | --epochs=30 \
100 | --lr=2e-3 \
101 | --wd=0.1 \
102 | --batch-size=512 \
103 | --warmup_length=0 \
104 | --cache-dir=cache \
105 | --model=RN50 \
106 | --eval-datasets=Eurosat \
107 | --template=eurosat_template \
108 | --results-db=results.jsonl \
109 | --data-location=/path/to/synthetic/data | tee results/${exp_name}/train-$now.log
110 | ```
111 |
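For intuition: classifier tuning keeps the CLIP image encoder frozen (`--freeze-encoder`) and trains only the classification head, combining the labels of the synthetic images with temperature-scaled CLIP soft labels (cf. `--sl` and `--sl_T`). The snippet below is a minimal sketch of that idea under an assumed weighting scheme; the actual objective is implemented in `src/ct_zsl.py` and `src/models/finetune.py`:

```python
import torch
import torch.nn.functional as F

def ct_step(head, features, targets, clip_logits, sl=0.5, sl_T=2.0):
    """One illustrative optimization step on frozen CLIP features.

    features:    image features from the frozen CLIP encoder  [B, D]
    targets:     class indices of the synthetic images         [B]
    clip_logits: zero-shot CLIP logits used as soft labels     [B, C]
    """
    logits = head(features)                       # trainable classification head
    loss_hard = F.cross_entropy(logits, targets)  # labels from the generation prompts
    soft = F.softmax(clip_logits / sl_T, dim=-1)  # temperature-scaled soft labels
    loss_soft = F.kl_div(F.log_softmax(logits / sl_T, dim=-1), soft,
                         reduction='batchmean') * (sl_T ** 2)
    return (1 - sl) * loss_hard + sl * loss_soft  # assumed weighting, for illustration

# toy usage: RN50 feature dim 1024, 10 EuroSAT classes
head = torch.nn.Linear(1024, 10)
loss = ct_step(head, torch.randn(4, 1024), torch.randint(0, 10, (4,)), torch.randn(4, 10))
loss.backward()
```
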
112 |
113 |
114 | ## Few-shot settings
115 |
116 | ### Synthetic data generation-RG
117 |
118 | We provide the code for our proposed Real Guidance strategy. First, obtain a set of few-shot images for the given task. You may need to revise the function `get_few_shot_images_path_prompt_pairs()`, which returns a list of (im_path, prompt) pairs, in `src/glide/glide_fsl.py`.
119 |
120 | Also, set the variable `refer_img_iters` to 15, 20, 35, 40, and 50 for 16, 8, 4, 2, and 1 shots, respectively, and choose the batch settings such that `batch_size * batch_size_time * shot = 800` (e.g., for 4 shots, `batch_size * batch_size_time` should equal 200).
121 |
122 | We provide a multi-GPU generation example in `src/glide/glide_fsl.py`; run it like:
123 |
124 | ```sh
125 | sh glide/gen_fsl.sh /path/to/few-shot/images /path/to/save/dataset
126 | ```
127 |
128 | ### Synthetic data for FSL: Classifier-Tuning with CLIP
129 |
130 | Again, our code is adapted from the [Wise-ft](https://github.com/mlfoundations/wise-ft) codebase. The following is an example:
131 |
132 | ```sh
133 | python3.7 src/ct_fsl.py \
134 | --freeze-encoder \
135 | --sl=0.5 \
136 | --sl_T=2 \
137 | --train-dataset=Eurosat \
138 | --save=/path/to/save/results \
139 | --epochs=30 \
140 | --lr=1e-3 \
141 | --wd=0.1 \
142 | --batch-size-real=32 \
143 | --batch-size-syn=512 \
144 | --loss-weight=1.0 \
145 | --loss-weight-real=1.0 \
146 | --warmup_length=0 \
147 | --cache-dir=cache \
148 | --model=RN50 \
149 | --eval-datasets=Eurosat \
150 | --template=eurosat_template \
151 | --results-db=results.jsonl \
152 | --data-location=/path/to/synthetic/data \
153 | --data-location-real=/path/to/few-shot/data | tee results/${exp_name}/train-$now.log
154 |
155 | ```
156 |
157 |
158 |
159 | ## Pre-training settings
160 |
161 | ### Synthetic data generation
162 |
163 | We adopt only the language enhancement strategy for the pre-training setting. Please modify the corresponding zero-shot files (`src/LE.py`, `src/glide/glide_zsl.py`) to generate the synthetic pre-training data.
164 |
165 | ### Pre-training with synthetic data
166 |
167 | We recommend the [timm](https://github.com/rwightman/pytorch-image-models) codebase for its excellent pre-training implementation. For concrete hyper-parameters, please refer to Sec. C.5.3 of our Appendix.
168 |
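We do not provide a pre-training script in this repo. As a minimal illustration only (the paths and hyper-parameters below are placeholders, not our recipe — see Sec. C.5.3 of the Appendix for the actual settings), a timm backbone can be trained on an ImageFolder of synthetic images like this:

```python
import timm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Placeholder data pipeline over the generated synthetic images.
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder('/path/to/synthetic/pretrain/data', transform=train_tf)
loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=16)

model = timm.create_model('resnet50', pretrained=False, num_classes=len(dataset.classes)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

for images, targets in loader:  # one epoch; wrap in an outer loop for full pre-training
    loss = F.cross_entropy(model(images.to(device)), targets.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
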
169 |
170 |
171 | ## Citing this work
172 |
173 | If you find this repo useful for your research, please consider citing our paper:
174 |
175 | ```
176 | @article{he2022synthetic,
177 | title={Is synthetic data from generative models ready for image recognition?},
178 | author={He, Ruifei and Sun, Shuyang and Yu, Xin and Xue, Chuhui and Zhang, Wenqing and Torr, Philip and Bai, Song and Qi, Xiaojuan},
179 | journal={arXiv preprint arXiv:2210.07574},
180 | year={2022}
181 | }
182 | ```
183 |
184 |
185 |
186 |
187 | ## Acknowledgement
188 |
189 | We thank the authors of the open-source code we build upon: [GLIDE](https://github.com/openai/glide-text2im), [CLIP](https://github.com/openai/CLIP), [keytotext](https://github.com/gagan3012/keytotext), [Wise-ft](https://github.com/mlfoundations/wise-ft), [timm](https://github.com/rwightman/pytorch-image-models), [Detectron2](https://github.com/facebookresearch/Detectron2), [DeiT](https://github.com/facebookresearch/deit), [MoCo](https://github.com/facebookresearch/moco).
190 |
191 |
--------------------------------------------------------------------------------
/pic/Figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/pic/Figure1.png
--------------------------------------------------------------------------------
/pic/Table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/pic/Table1.png
--------------------------------------------------------------------------------
/pic/Table8-11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/pic/Table8-11.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | robustness
4 | numpy
5 | scipy
6 | GPUtil
7 | dill
8 | tensorboardX
9 | tables
10 | tqdm
11 | seaborn
12 | jupyter
13 | cox
14 | scikit-learn
15 | pillow
16 | timm
17 | ftfy
18 | regex
19 | transformers
20 | wilds
21 | keytotext
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/.DS_Store
--------------------------------------------------------------------------------
/src/LE.py:
--------------------------------------------------------------------------------
1 | import sys, random
2 | from keytotext import pipeline
3 | import pickle
4 |
5 | labels = \
6 | ['a Annual Crop Land', 'a Forest', 'a Herbaceous Vegetation Land', 'a Highway or Road', 'a Industrial Building', 'a Pasture Land', 'a Permanent Crop Land', 'a Residential Building', 'a River', 'a Sea or Lake']
7 |
8 |
9 | nlp = pipeline("mrm8488/t5-base-finetuned-common_gen")
10 |
11 | def word2sentence(classnames, num=200, save_path=''):
12 | sentence_dict = {}
13 | for n in classnames:
14 | sentence_dict[n] = []
15 | for n in classnames:
16 | for i in range(num+50):
17 | sentence = nlp([n], num_return_sequences=1, do_sample=True)
18 | sentence_dict[n].append(sentence)
19 |
20 | # remove duplicate
21 | sampled_dict = {}
22 | for k, v in sentence_dict.items():
23 | v_unique = list(set(v))
24 | sampled_v = random.sample(v_unique, num)
25 | sampled_dict[k] = sampled_v
26 |
27 | r = open(save_path,"wb")
28 | pickle.dump(sampled_dict, r)
29 | r.close()
30 |
31 | if __name__ == "__main__":
32 | num = sys.argv[1]
33 | save_path = sys.argv[2]
34 | word2sentence(labels, int(num), save_path)
35 |
36 | '''
37 | python3.7 src/LE.py 200 /path/to/save/dataset.pkl
38 | '''
--------------------------------------------------------------------------------
/src/classifier-tuning/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/.DS_Store
--------------------------------------------------------------------------------
/src/classifier-tuning/clip/README.md:
--------------------------------------------------------------------------------
1 | This folder is a lightly modified version of https://github.com/openai/CLIP.
2 |
--------------------------------------------------------------------------------
/src/classifier-tuning/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/classifier-tuning/clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | if not special_tokens:
74 | special_tokens = ['<start_of_text>', '<end_of_text>']
75 | else:
76 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
77 | vocab.extend(special_tokens)
78 | self.encoder = dict(zip(vocab, range(len(vocab))))
79 | self.decoder = {v: k for k, v in self.encoder.items()}
80 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
81 | self.cache = {t:t for t in special_tokens}
82 | special = "|".join(special_tokens)
83 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
84 |
85 | self.vocab_size = len(self.encoder)
86 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
87 |
88 | def bpe(self, token):
89 | if token in self.cache:
90 | return self.cache[token]
91 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92 | pairs = get_pairs(word)
93 |
94 | if not pairs:
95 | return token+'</w>'
96 |
97 | while True:
98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99 | if bigram not in self.bpe_ranks:
100 | break
101 | first, second = bigram
102 | new_word = []
103 | i = 0
104 | while i < len(word):
105 | try:
106 | j = word.index(first, i)
107 | new_word.extend(word[i:j])
108 | i = j
109 | except:
110 | new_word.extend(word[i:])
111 | break
112 |
113 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
114 | new_word.append(first+second)
115 | i += 2
116 | else:
117 | new_word.append(word[i])
118 | i += 1
119 | new_word = tuple(new_word)
120 | word = new_word
121 | if len(word) == 1:
122 | break
123 | else:
124 | pairs = get_pairs(word)
125 | word = ' '.join(word)
126 | self.cache[token] = word
127 | return word
128 |
129 | def encode(self, text):
130 | bpe_tokens = []
131 | text = whitespace_clean(basic_clean(text)).lower()
132 | for token in re.findall(self.pat, text):
133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135 | return bpe_tokens
136 |
137 | def decode(self, tokens):
138 | text = ''.join([self.decoder[token] for token in tokens])
139 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140 | return text
--------------------------------------------------------------------------------
/src/classifier-tuning/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/src/.DS_Store
--------------------------------------------------------------------------------
/src/classifier-tuning/src/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | def parse_arguments():
7 | parser = argparse.ArgumentParser()
8 |
9 | parser.add_argument('--data_aug_lp', action='store_true',
10 | help='train data aug for linear probe, not feature datasets')
11 | parser.add_argument('--hard_pseudo_label', action='store_true',
12 | help='use hard pseudo-labels instead of soft labels')
13 |
14 | parser.add_argument(
15 | "--data-location",
16 | type=str,
17 | default=os.path.expanduser('~/data'),
18 | help="The root directory for the datasets.",
19 | )
20 | parser.add_argument(
21 | "--data-location-real",
22 | type=str,
23 | default=os.path.expanduser('~/data'),
24 | help="The root directory for the datasets.",
25 | )
26 | parser.add_argument(
27 | "--data-location-syn",
28 | type=str,
29 | default=os.path.expanduser('~/data'),
30 | help="The root directory for the datasets.",
31 | )
32 | parser.add_argument(
33 | "--eval-datasets",
34 | default=None,
35 | type=lambda x: x.split(","),
36 | help="Which datasets to use for evaluation. Split by comma, e.g. CIFAR101,CIFAR102."
37 | " Note that same model used for all datasets, so much have same classnames"
38 | "for zero shot.",
39 | )
40 | parser.add_argument(
41 | "--train-dataset",
42 | default=None,
43 | help="For fine tuning or linear probe, which dataset to train on",
44 | )
45 | # parser.add_argument(
46 | # "--ds-real",
47 | # default=None,
48 | # help="real data dataset to train on",
49 | # )
50 | # parser.add_argument(
51 | # "--ds-syn",
52 | # default=None,
53 | # help="syn data dataset to train on",
54 | # )
55 | parser.add_argument(
56 | "--template",
57 | type=str,
58 | default=None,
59 | help="Which prompt template is used. Leave as None for linear probe, etc.",
60 | )
61 | parser.add_argument(
62 | "--classnames",
63 | type=str,
64 | default="openai",
65 | help="Which class names to use.",
66 | )
67 | parser.add_argument(
68 | "--alpha",
69 | default=[0.5],
70 | nargs='*',
71 | type=float,
72 | help=(
73 | 'Interpolation coefficient for ensembling. '
74 | 'Users should specify N-1 values, where N is the number of '
75 | 'models being ensembled. The specified numbers should sum to '
76 | 'less than 1. Note that the order of these values matter, and '
77 | 'should be the same as the order of the classifiers being ensembled.'
78 | )
79 | )
80 | parser.add_argument(
81 | "--exp_name",
82 | type=str,
83 | default=None,
84 | help="Name of the experiment, for organization purposes only."
85 | )
86 | parser.add_argument(
87 | "--results-db",
88 | type=str,
89 | default=None,
90 | help="Where to store the results, else does not store",
91 | )
92 | parser.add_argument(
93 | "--model",
94 | type=str,
95 | default=None,
96 | help="The type of model (e.g. RN50, ViT-B/32).",
97 | )
98 | parser.add_argument(
99 | "--batch-size",
100 | type=int,
101 | default=128,
102 | )
103 | parser.add_argument(
104 | "--batch-size-real",
105 | type=int,
106 | default=32,
107 | )
108 | parser.add_argument(
109 | "--batch-size-syn",
110 | type=int,
111 | default=512,
112 | )
113 | parser.add_argument(
114 | "--lr",
115 | type=float,
116 | default=0.001,
117 | help="Learning rate."
118 | )
119 | parser.add_argument(
120 | "--wd",
121 | type=float,
122 | default=0.1,
123 | help="Weight decay"
124 | )
125 | parser.add_argument(
126 | "--loss-weight",
127 | type=float,
128 | default=1.0,
129 | help="Loss weight balancing real and syn loss."
130 | )
131 | parser.add_argument(
132 | "--loss-weight-real",
133 | type=float,
134 | default=1.0,
135 | help="Loss weight balancing real and syn loss."
136 | )
137 | parser.add_argument(
138 | "--ls",
139 | type=float,
140 | default=0.0,
141 | help="Label smoothing."
142 | )
143 | parser.add_argument(
144 | "--sl",
145 | type=float,
146 | default=0.0,
147 | help="soft label."
148 | )
149 | parser.add_argument(
150 | "--sl_T",
151 | type=float,
152 | default=4.0,
153 | help="soft label Temperature."
154 | )
155 | parser.add_argument(
156 | "--warmup_length",
157 | type=int,
158 | default=500,
159 | )
160 | parser.add_argument(
161 | "--epochs",
162 | type=int,
163 | default=10,
164 | )
165 | parser.add_argument(
166 | "--load",
167 | type=lambda x: x.split(","),
168 | default=None,
169 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
170 | )
171 | parser.add_argument(
172 | "--save",
173 | type=str,
174 | default=None,
175 | help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
176 | )
177 | parser.add_argument(
178 | "--freeze-encoder",
179 | default=False,
180 | action="store_true",
181 | help="Whether or not to freeze the image encoder. Only relevant for fine-tuning."
182 | )
183 | parser.add_argument(
184 | "--cache-dir",
185 | type=str,
186 | default=None,
187 | help="Directory for caching features and encoder",
188 | )
189 | parser.add_argument(
190 | "--coopCSC",
191 | type=bool,
192 | default=False,
193 | help="coopCSC",
194 | )
195 | parser.add_argument(
196 | "--coopN_CTX",
197 | type=int,
198 | default=16,
199 | )
200 | parser.add_argument(
201 | "--coopPOS",
202 | type=str,
203 | default="end",
204 | help="end or middle or front",
205 | )
206 | parsed_args = parser.parse_args()
207 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"
208 |
209 | if parsed_args.load is not None and len(parsed_args.load) == 1:
210 | parsed_args.load = parsed_args.load[0]
211 | return parsed_args
212 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/ct_fsl.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from src.models.finetune import finetune_fsl
4 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
5 | from src.models.zeroshot import get_zeroshot_classifier
6 | from src.args import parse_arguments
7 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip
8 | from PIL import Image
9 |
10 | def _convert_to_rgb(image):
11 | return image.convert('RGB')
12 |
13 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
14 |
15 | def classifier_tuning(args):
16 | assert args.save is not None, 'Please provide a path to store models'
17 | print('import success')
18 |
19 | # Build and save zero-shot model
20 | image_encoder = ImageEncoder(args, keep_lang=True)
21 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
22 | delattr(image_encoder.model, 'transformer')
23 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
24 |
25 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt')
26 | classifier.save(zeroshot_checkpoint)
27 |
28 | # Standard fine-tuning
29 | args.load = zeroshot_checkpoint
30 | args.save = os.path.join(args.save, 'finetuned')
31 |
32 | # Mimic eurosat low-res images, val data aug
33 | train_data_aug = Compose([
34 | # Resize(64), # resize to 32/64 for Cifar / Eurosat
35 | Resize(224, interpolation=Image.BICUBIC),
36 | CenterCrop(224),
37 | _convert_to_rgb,
38 | ToTensor(),
39 | normalize,
40 | ])
41 |
42 | finetuned_checkpoint = finetune_fsl(args, train_data_aug)
43 |
44 |
45 |
46 | if __name__ == '__main__':
47 | args = parse_arguments()
48 | classifier_tuning(args)
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/ct_zsl.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from src.models.finetune import finetune_zsl
4 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
5 | from src.models.zeroshot import get_zeroshot_classifier
6 | from src.args import parse_arguments
7 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip
8 | from PIL import Image
9 |
10 | def _convert_to_rgb(image):
11 | return image.convert('RGB')
12 |
13 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
14 |
15 | def classifier_tuning(args):
16 | assert args.save is not None, 'Please provide a path to store models'
17 | print('import success')
18 |
19 | # Build and save zero-shot model
20 | image_encoder = ImageEncoder(args, keep_lang=True)
21 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
22 | delattr(image_encoder.model, 'transformer')
23 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
24 |
25 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt')
26 | classifier.save(zeroshot_checkpoint)
27 |
28 | # Standard fine-tuning
29 | args.load = zeroshot_checkpoint
30 | args.save = os.path.join(args.save, 'finetuned')
31 |
32 | # Mimic eurosat low-res images, val data aug
33 | train_data_aug = Compose([
34 | # Resize(64), # resize to 32/64 for Cifar / Eurosat
35 | Resize(224, interpolation=Image.BICUBIC),
36 | CenterCrop(224),
37 | _convert_to_rgb,
38 | ToTensor(),
39 | normalize,
40 | ])
41 |
42 | finetuned_checkpoint = finetune_zsl(args, train_data_aug)
43 |
44 |
45 |
46 | if __name__ == '__main__':
47 | args = parse_arguments()
48 | classifier_tuning(args)
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/src/datasets/.DS_Store
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar10 import *
2 | from .cifar100 import *
3 | from .fmow import FMOWID, FMOWOOD, FMOW
4 | from .imagenet import ImageNet as ImageNet_ft
5 | # from .imagenetv2 import ImageNetV2
6 | from .imagenet_a import ImageNetAValClasses, ImageNetA
7 | from .imagenet_r import ImageNetRValClasses, ImageNetR
8 | from .imagenet_sketch import ImageNetSketch
9 | from .imagenet_vid_robust import ImageNetVidRobustValClasses, ImageNetVidRobust
10 | from .iwildcam import IWildCamID, IWildCamOOD, IWildCamIDNonEmpty, IWildCamOODNonEmpty, IWildCam
11 | from .objectnet import ObjectNetValClasses, ObjectNet
12 | from .ytbb_robust import YTBBRobustValClasses, YTBBRobust
13 | from .transfer_ds.transfer_datasets import *
14 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import torch
4 | import numpy as np
5 | import torchvision
6 | from torchvision import transforms
7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
8 | from torchvision.datasets import VisionDataset
9 |
10 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
11 |
12 | class CIFAR10_theirs:
13 | def __init__(self, preprocess,
14 | location=os.path.expanduser('~/data'),
15 | batch_size=128,
16 | num_workers=16,
17 | classnames=None):
18 |
19 |
20 | self.train_dataset = PyTorchCIFAR10(
21 | root=location, download=True, train=True, transform=preprocess
22 | )
23 |
24 | self.train_loader = torch.utils.data.DataLoader(
25 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
26 | )
27 |
28 | self.test_dataset = PyTorchCIFAR10(
29 | root=location, download=True, train=False, transform=preprocess
30 | )
31 |
32 | self.test_loader = torch.utils.data.DataLoader(
33 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
34 | )
35 |
36 | self.classnames = self.test_dataset.classes
37 |
38 | def convert(x):
39 | if isinstance(x, np.ndarray):
40 | return torchvision.transforms.functional.to_pil_image(x)
41 | return x
42 |
43 | class BasicVisionDataset(VisionDataset):
44 | def __init__(self, images, targets, transform=None, target_transform=None):
45 | if transform is not None:
46 | transform.transforms.insert(0, convert)
47 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
48 | assert len(images) == len(targets)
49 |
50 | self.images = images
51 | self.targets = targets
52 |
53 | def __getitem__(self, index):
54 | return self.transform(self.images[index]), self.targets[index]
55 |
56 | def __len__(self):
57 | return len(self.targets)
58 |
59 | class CIFAR101:
60 | def __init__(self,
61 | preprocess,
62 | location=os.path.expanduser('~/data'),
63 | batch_size=128,
64 | num_workers=16,
65 | classnames=None):
66 |
67 | data_root = os.path.join(location, "CIFAR-10.1")
68 | data = np.load(os.path.join(data_root, 'cifar10.1_v6_data.npy'), allow_pickle=True)
69 | labels = np.load(os.path.join(data_root, 'cifar10.1_v6_labels.npy'), allow_pickle=True)
70 |
71 | use_cuda = torch.cuda.is_available()
72 |
73 | # Data loading code
74 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
75 |
76 | self.train_loader = None
77 |
78 | self.test_dataset = BasicVisionDataset(
79 | images=data, targets=torch.Tensor(labels).long(),
80 | transform=preprocess,
81 | )
82 |
83 | self.test_loader = torch.utils.data.DataLoader(
84 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs
85 | )
86 |
87 | self.classnames = cifar_classnames
88 |
89 |
90 | class CIFAR102:
91 | def __init__(self,
92 | preprocess,
93 | location=os.path.expanduser('~/data'),
94 | batch_size=128,
95 | num_workers=16,
96 | classnames=None):
97 |
98 | train_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_train.npy'), allow_pickle=True).item()
99 | test_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_test.npy'), allow_pickle=True).item()
100 |
101 |
102 | train_data_images = train_data['images']
103 | train_data_labels = train_data['labels']
104 |
105 | test_data_images = test_data['images']
106 | test_data_labels = test_data['labels']
107 |
108 | use_cuda = torch.cuda.is_available()
109 |
110 | # Data loading code
111 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
112 |
113 | self.test_dataset = BasicVisionDataset(
114 | images=test_data_images, targets=torch.Tensor(test_data_labels).long(),
115 | transform=preprocess,
116 | )
117 |
118 | self.test_loader = torch.utils.data.DataLoader(
119 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs
120 | )
121 |
122 | self.classnames = cifar_classnames
123 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
4 |
5 | class CIFAR100_theirs:
6 | def __init__(self,
7 | preprocess,
8 | location=os.path.expanduser('~/data'),
9 | batch_size=128,
10 | num_workers=16,
11 | classnames=None):
12 |
13 | self.train_dataset = PyTorchCIFAR100(
14 | root=location, download=True, train=True, transform=preprocess
15 | )
16 |
17 | self.train_loader = torch.utils.data.DataLoader(
18 | self.train_dataset, batch_size=batch_size, num_workers=num_workers
19 | )
20 |
21 | self.test_dataset = PyTorchCIFAR100(
22 | root=location, download=True, train=False, transform=preprocess
23 | )
24 |
25 | self.test_loader = torch.utils.data.DataLoader(
26 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
27 | )
28 |
29 | self.classnames = self.test_dataset.classes
30 |
31 |
32 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import glob
5 | import collections
6 | import random
7 |
8 | import numpy as np
9 |
10 | from tqdm import tqdm
11 |
12 | import torchvision.datasets as datasets
13 | from torch.utils.data import Dataset, DataLoader, Sampler
14 |
15 |
16 | class SubsetSampler(Sampler):
17 | def __init__(self, indices):
18 | self.indices = indices
19 |
20 | def __iter__(self):
21 | return (i for i in self.indices)
22 |
23 | def __len__(self):
24 | return len(self.indices)
25 |
26 | class ImageFolderWithPaths(datasets.ImageFolder):
27 | def __init__(self, path, transform, flip_label_prob=0.0):
28 | super().__init__(path, transform)
29 | self.flip_label_prob = flip_label_prob
30 | if self.flip_label_prob > 0:
31 | print(f'Flipping labels with probability {self.flip_label_prob}')
32 | num_classes = len(self.classes)
33 | for i in range(len(self.samples)):
34 | if random.random() < self.flip_label_prob:
35 | new_label = random.randint(0, num_classes-1)
36 | self.samples[i] = (
37 | self.samples[i][0],
38 | new_label
39 | )
40 |
41 | def __getitem__(self, index):
42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index)
43 | return {
44 | 'images': image,
45 | 'labels': label,
46 | 'image_paths': self.samples[index][0]
47 | }
48 |
49 |
50 | def maybe_dictionarize(batch):
51 | if isinstance(batch, dict):
52 | return batch
53 |
54 | if len(batch) == 2:
55 | batch = {'images': batch[0], 'labels': batch[1]}
56 | elif len(batch) == 3:
57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
58 | else:
59 | raise ValueError(f'Unexpected number of elements: {len(batch)}')
60 |
61 | return batch
62 |
63 |
64 | def get_features_helper(image_encoder, dataloader, device):
65 | all_data = collections.defaultdict(list)
66 |
67 | image_encoder = image_encoder.to(device)
68 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
69 | image_encoder.eval()
70 |
71 | with torch.no_grad():
72 | for batch in tqdm(dataloader):
73 | batch = maybe_dictionarize(batch)
74 | features = image_encoder(batch['images'].cuda())
75 |
76 | all_data['features'].append(features.cpu())
77 |
78 | for key, val in batch.items():
79 | if key == 'images':
80 | continue
81 | if hasattr(val, 'cpu'):
82 | val = val.cpu()
83 | all_data[key].append(val)
84 | else:
85 | all_data[key].extend(val)
86 |
87 | for key, val in all_data.items():
88 | if torch.is_tensor(val[0]):
89 | all_data[key] = torch.cat(val).numpy()
90 |
91 | return all_data
92 |
93 |
94 | def get_features(is_train, image_encoder, dataset, device, is_real=True):
95 | split = 'train' if is_train else 'val'
96 | if is_real is False:
97 | split= 'trainsyn'
98 | dname = type(dataset).__name__
99 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
100 | if image_encoder.cache_dir is not None:
101 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
102 | cached_files = glob.glob(f'{cache_dir}/*')
103 | if image_encoder.cache_dir is not None and len(cached_files) > 0:
104 | print(f'Getting features from {cache_dir}')
105 | data = {}
106 | for cached_file in cached_files:
107 | name = os.path.splitext(os.path.basename(cached_file))[0]
108 | data[name] = torch.load(cached_file)
109 | else:
110 | # import ipdb
111 | # ipdb.set_trace(context=20)
112 | print(f'Did not find cached features at {cache_dir}. Building from scratch.')
113 | loader = dataset.train_loader if is_train else dataset.test_loader
114 | data = get_features_helper(image_encoder, loader, device)
115 | if image_encoder.cache_dir is None:
116 | print('Not caching because no cache directory was passed.')
117 | else:
118 | os.makedirs(cache_dir, exist_ok=True)
119 | print(f'Caching data at {cache_dir}')
120 | for name, val in data.items():
121 | torch.save(val, f'{cache_dir}/{name}.pt', pickle_protocol=4)
122 | return data
123 |
124 |
125 | class FeatureDataset(Dataset):
126 | def __init__(self, is_train, image_encoder, dataset, device, is_real=True):
127 | self.data = get_features(is_train, image_encoder, dataset, device, is_real)
128 |
129 | def __len__(self):
130 | return len(self.data['features'])
131 |
132 | def __getitem__(self, idx):
133 | data = {k: v[idx] for k, v in self.data.items()}
134 | data['features'] = torch.from_numpy(data['features']).float()
135 | return data
136 |
137 |
138 | def get_dataloader(dataset, is_train, args, image_encoder=None, is_real=True):
139 | if image_encoder is not None:
140 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device, is_real)
141 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
142 | else:
143 | dataloader = dataset.train_loader if is_train else dataset.test_loader
144 | return dataloader
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/fmow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import wilds
4 |
5 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
6 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
7 |
8 | class FMOW:
9 | test_subset = None
10 |
11 | def __init__(self,
12 | preprocess,
13 | location=os.path.expanduser('~/data'),
14 | batch_size=128,
15 | num_workers=16,
16 | subset='test',
17 | classnames=None,
18 | **kwargs):
19 |
20 | self.dataset = wilds.get_dataset(dataset='fmow', root_dir=location)
21 |
22 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
23 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size)
24 |
25 | self.test_dataset = self.dataset.get_subset(self.test_subset, transform=preprocess)
26 | self.test_loader = get_eval_loader("standard", self.test_dataset, num_workers=num_workers, batch_size=batch_size)
27 |
28 | self.classnames = [
29 | "airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture",
30 | "archaeological_site", "barn", "border_checkpoint", "burial_site", "car_dealership",
31 | "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution",
32 | "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain",
33 | "gas_station", "golf_course", "ground_transportation_station", "helipad", "hospital",
34 | "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", "military_facility",
35 | "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park",
36 | "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track",
37 | "railway_bridge", "recreational_facility", "road_bridge", "runway", "shipyard", "shopping_mall",
38 | "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium", "storage_tank",
39 | "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening", "waste_disposal",
40 | "water_treatment_facility", "wind_farm", "zoo"
41 | ]
42 |
43 | def post_loop_metrics(self, labels, preds, metadata, args):
44 | metadata = torch.stack(metadata)
45 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels)
46 | results = self.dataset.eval(preds, labels, metadata)
47 | return results[0]
48 |
49 | class FMOWID(FMOW):
50 | def __init__(self, *args, **kwargs):
51 | self.test_subset = 'id_test'
52 | super().__init__(*args, **kwargs)
53 |
54 | class FMOWOOD(FMOW):
55 | def __init__(self, *args, **kwargs):
56 | self.test_subset = 'test'
57 | super().__init__(*args, **kwargs)
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from .common import ImageFolderWithPaths, SubsetSampler
5 | from .imagenet_classnames import get_classnames
6 | import numpy as np
7 |
8 | class ImageNet:
9 | def __init__(self,
10 | preprocess,
11 | location=os.path.expanduser('~/data'),
12 | batch_size=32,
13 | num_workers=32,
14 | classnames='openai'):
15 | self.preprocess = preprocess
16 | self.location = location
17 | self.batch_size = batch_size
18 | self.num_workers = num_workers
19 | self.classnames = get_classnames(classnames)
20 |
21 | self.populate_train()
22 | self.populate_test()
23 |
24 | def populate_train(self):
25 | traindir = os.path.join(self.location, self.name(), 'train')
26 | self.train_dataset = ImageFolderWithPaths(
27 | traindir,
28 | transform=self.preprocess)
29 | sampler = self.get_train_sampler()
30 | kwargs = {'shuffle' : True} if sampler is None else {}
31 | self.train_loader = torch.utils.data.DataLoader(
32 | self.train_dataset,
33 | sampler=sampler,
34 | batch_size=self.batch_size,
35 | num_workers=self.num_workers,
36 | **kwargs,
37 | )
38 |
39 | def populate_test(self):
40 | self.test_dataset = self.get_test_dataset()
41 | self.test_loader = torch.utils.data.DataLoader(
42 | self.test_dataset,
43 | batch_size=self.batch_size,
44 | num_workers=self.num_workers,
45 | sampler=self.get_test_sampler()
46 | )
47 |
48 | def get_test_path(self):
49 | test_path = os.path.join(self.location, self.name(), 'val_in_folder')
50 | if not os.path.exists(test_path):
51 | test_path = os.path.join(self.location, self.name(), 'val')
52 | return test_path
53 |
54 | def get_train_sampler(self):
55 | return None
56 |
57 | def get_test_sampler(self):
58 | return None
59 |
60 | def get_test_dataset(self):
61 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
62 |
63 | def name(self):
64 | # return 'imagenet'
65 | return 'ILSVRC2012_img_train'
66 |
67 | class ImageNetTrain(ImageNet):
68 |
69 | def get_test_dataset(self):
70 | pass
71 |
72 | class ImageNetK(ImageNet):
73 |
74 | def get_train_sampler(self):
75 | idxs = np.zeros(len(self.train_dataset.targets))
76 | target_array = np.array(self.train_dataset.targets)
77 | for c in range(1000):
78 | m = target_array == c
79 | n = len(idxs[m])
80 | arr = np.zeros(n)
81 | arr[:self.k()] = 1
82 | np.random.shuffle(arr)
83 | idxs[m] = arr
84 |
85 | idxs = idxs.astype('int')
86 | sampler = SubsetSampler(np.where(idxs)[0])
87 | return sampler
88 |
89 |
90 | def project_logits(logits, class_sublist_mask, device):
91 | if isinstance(logits, list):
92 | return [project_logits(l, class_sublist_mask, device) for l in logits]
93 | if logits.size(1) > sum(class_sublist_mask):
94 | return logits[:, class_sublist_mask].to(device)
95 | else:
96 | return logits.to(device)
97 |
98 | class ImageNetSubsample(ImageNet):
99 | def __init__(self, *args, **kwargs):
100 | super().__init__(*args, **kwargs)
101 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
102 | self.classnames = [self.classnames[i] for i in class_sublist]
103 |
104 | def get_class_sublist_and_mask(self):
105 | raise NotImplementedError()
106 |
107 | def populate_train(self):
108 | pass
109 |
110 | def project_logits(self, logits, device):
111 | return project_logits(logits, self.class_sublist_mask, device)
112 |
113 | class ImageNetSubsampleValClasses(ImageNet):
114 | def get_class_sublist_and_mask(self):
115 | raise NotImplementedError()
116 |
117 | def populate_train(self):
118 | pass
119 |
120 | def get_test_sampler(self):
121 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
122 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist]
123 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
124 |
125 | sampler = SubsetSampler(idx_subsample_list)
126 | return sampler
127 |
128 | def project_labels(self, labels, device):
129 | projected_labels = [self.class_sublist.index(int(label)) for label in labels]
130 | return torch.LongTensor(projected_labels).to(device)
131 |
132 | def project_logits(self, logits, device):
133 | return project_logits(logits, self.class_sublist_mask, device)
134 |
135 | ks = [1, 2, 4, 8, 16, 25, 32, 50, 64, 128, 600]
136 |
137 | for k in ks:
138 | cls_name = f"ImageNet{k}"
139 | dyn_cls = type(cls_name, (ImageNetK, ), {
140 | "k": lambda self, num_samples=k: num_samples,
141 | })
142 | globals()[cls_name] = dyn_cls
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/imagenet_a.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses
6 | import numpy as np
7 |
8 |
9 | CLASS_SUBLIST = [
10 | 6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107,
11 | 108, 110,
12 | 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307,
13 | 308, 309,
14 | 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397,
15 | 400, 401,
16 | 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488,
17 | 492, 496,
18 | 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614,
19 | 626, 627,
20 | 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773,
21 | 774, 776,
22 | 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870,
23 | 879, 880,
24 | 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980,
25 | 981, 984,
26 | 986, 987, 988]
27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)]
28 |
29 |
30 | class ImageNetAValClasses(ImageNetSubsampleValClasses):
31 | def get_class_sublist_and_mask(self):
32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
33 |
34 |
35 | class ImageNetA(ImageNetSubsample):
36 | def get_class_sublist_and_mask(self):
37 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
38 |
39 | def get_test_path(self):
40 | return os.path.join(self.location, 'imagenet-a')
41 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/imagenet_r.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses
6 | import numpy as np
7 |
8 |
9 | CLASS_SUBLIST = [
10 | 1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107,
11 | 113, 122,
12 | 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203,
13 | 207, 208, 219,
14 | 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289,
15 | 291, 292, 293,
16 | 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347,
17 | 353, 355, 361,
18 | 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447,
19 | 448, 457, 462,
20 | 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613,
21 | 617, 621, 629,
22 | 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852,
23 | 866, 875, 883,
24 | 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965,
25 | 967, 980, 981,
26 | 983, 988]
27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)]
28 |
29 |
30 | class ImageNetRValClasses(ImageNetSubsampleValClasses):
31 | def get_class_sublist_and_mask(self):
32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
33 |
34 | class ImageNetR(ImageNetSubsample):
35 | def get_class_sublist_and_mask(self):
36 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
37 |
38 | def get_test_path(self):
39 | return os.path.join(self.location, 'imagenet-r')
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/imagenet_sketch.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .imagenet import ImageNet
3 |
4 |
5 | class ImageNetSketch(ImageNet):
6 |
7 | def populate_train(self):
8 | pass
9 |
10 | def get_test_path(self):
11 | return os.path.join(self.location, 'sketch')
12 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/imagenet_vid_robust.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | from .imagenet import ImageNet
6 | import numpy as np
7 | import pathlib
8 | import json
9 |
10 | from .common import ImageFolderWithPaths, SubsetSampler
11 |
12 |
13 | class VidRobustDataset(ImageFolderWithPaths):
14 | def __init__(self, label_map, path, transform):
15 | self.label_map = label_map
16 | super().__init__(path, transform=transform)
17 |
18 | def __getitem__(self, index):
19 | data = super().__getitem__(index)
20 | label_key = '/'.join(data['image_paths'].split('/')[-3:])
21 | data['labels'] = self.label_map[label_key][0]
22 | return data
23 |
24 | class ImageNetVidRobustBase(ImageNet):
25 | def __init__(self, *args, **kwargs):
26 | data_loc = pathlib.Path(kwargs.get('location', '~')) / 'imagenet_vid_ytbb_robust/imagenet-vid-robust'
27 | with open((data_loc / 'misc/wnid_map.json').resolve()) as f:
28 | self.wnid_map = json.load(f)
29 | with open((data_loc / 'misc/rev_wnid_map.json').resolve()) as f:
30 | self.rev_wnid_map = json.load(f)
31 | with open((data_loc / 'misc/imagenet_class_index.json').resolve()) as f:
32 | self.imagenet_class_index = json.load(f)
33 | with open((data_loc / 'misc/imagenet_vid_class_index.json').resolve()) as f:
34 | self.imagenet_vid_class_index = json.load(f)
35 | with open((data_loc / 'metadata/labels.json').resolve()) as f:
36 | self.label_map = json.load(f)
37 | with open((data_loc / 'metadata/pmsets.json').resolve()) as f:
38 | self.pmsets = json.load(f)
39 |
40 | rev_imagenet = {v[0] : k for k, v in self.imagenet_class_index.items()}
41 | rev_vid = {v[0] : k for k,v in self.imagenet_vid_class_index.items()}
42 | self.CLASS_IDX_LIST = sorted([int(rev_imagenet[k]) for k in self.wnid_map])
43 | self.CLASS_IDX_MAP = {int(rev_imagenet[k]) : int(rev_vid[v]) for k, v in self.wnid_map.items()}
44 | self.rev_class_idx_map = {int(rev_vid[k]): [int(rev_imagenet[elt]) for elt in v] for k, v in self.rev_wnid_map.items()}
45 | self.merge_op = 'max'
46 |
47 | super().__init__(*args, **kwargs)
48 |
49 | self.classnames = [self.imagenet_vid_class_index[str(i)][1] for i in range(30)]
50 |
51 | def populate_train(self):
52 | pass
53 |
54 | def project_logits(self, logits, device):
55 | if isinstance(logits, list) or isinstance(logits, tuple):
56 | return [self.project_logits(l, device) for l in logits]
57 | if logits.shape[1] == 30:
58 | return logits
59 | if torch.is_tensor(logits):
60 | logits = logits.cpu().numpy()
61 | logits_projected = np.zeros((logits.shape[0], 30))
62 | for k, v in self.rev_class_idx_map.items():
63 | if self.merge_op == 'mean':
64 | logits_projected[:, k] = np.mean(logits[:, v], axis=1).squeeze()
65 | elif self.merge_op == 'median':
66 | logits_projected[:, k] = np.median(logits[:, v], axis=1).squeeze()
67 | elif self.merge_op == 'max':
68 | logits_projected[:, k] = np.max(logits[:, v], axis=1).squeeze()
69 | elif self.merge_op == 'sum':
70 | logits_projected[:, k] = np.sum(logits[:, v], axis=1)
71 | else:
72 | raise Exception(f'unsupported merge operation {self.merge_op} not allowed')
73 | return torch.tensor(logits_projected).to(device)
74 |
75 | def scatter_weights(self, weights):
76 | if weights.size(1) == 1000:
77 | return weights
78 | new_weights = torch.ones((weights.size(0), 1000)).to(weights.device) * -10e10
79 | for k, v in self.rev_class_idx_map.items():
80 | for vv in v:
81 | new_weights[:, vv] = weights[:, k]
82 | return new_weights
83 |
84 |
85 | class ImageNetVidRobustValClasses(ImageNetVidRobustBase):
86 |
87 | def post_loop_metrics(self, targets, logits, image_paths, args):
88 | logits = logits.numpy()
89 | targets = targets.numpy()
90 | return {'acc' : self.score_predictions(logits, targets)}
91 |
92 | def score_predictions(self, logits_projected, targets):
93 | preds = logits_projected.argmax(axis=1)
94 | acc = np.sum(np.equal(preds, targets))
95 | n = len(preds)
96 | return acc/n
97 |
98 | def get_test_sampler(self):
99 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.CLASS_IDX_LIST]
100 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
101 |
102 | sampler = SubsetSampler(idx_subsample_list)
103 | return sampler
104 |
105 | def project_labels(self, labels, device):
106 | labels = labels.cpu().numpy()
107 | labels_projected = torch.tensor([self.CLASS_IDX_MAP[label] for label in labels]).to(device)
108 | return labels_projected
109 |
110 |
111 | class ImageNetVidRobust(ImageNetVidRobustBase):
112 |
113 | def score_predictions(self, preds, pmsets):
114 | correct_anchor = 0
115 | correct_pmk = 0
116 | N = len(pmsets)
117 | wrong_map = {}
118 | for anchor, pmset in pmsets.items():
119 | pmset_correct = 0
120 | wrongs = []
121 | for elem in pmset:
122 | if np.argmax(preds[elem]) in self.label_map[elem]:
123 | pmset_correct += 1
124 | else:
125 | wrongs.append(elem)
126 |
127 | if np.argmax(preds[anchor]) in self.label_map[anchor]:
128 | correct_anchor += 1
129 | pmset_correct += 1
130 | if len(wrongs) > 0:
131 | wrong_map[anchor] = wrongs[-1]
132 |
133 | if pmset_correct == len(pmset) + 1:
134 | correct_pmk += 1
135 |
136 | return correct_anchor/N, correct_pmk/N
137 |
138 | def post_loop_metrics(self, labels, logits, image_paths, args):
139 | logits = logits.numpy()
140 | labels = labels.numpy()
141 |
142 | preds_dict = {}
143 | for i, img_name in enumerate(image_paths):
144 | preds_dict['val/' + img_name.split('val/')[1]] = logits[i]
145 |
146 | benign,pmk = self.score_predictions(preds_dict, self.pmsets)
147 | metrics_dict = {}
148 | metrics_dict['pm0'] = benign
149 | metrics_dict['pm10'] = pmk
150 | metrics_dict['merge_op'] = self.merge_op
151 | return metrics_dict
152 |
153 | def get_test_dataset(self):
154 | valdir = os.path.join(self.location, 'imagenet_vid_ytbb_robust/imagenet-vid-robust/val')
155 | return VidRobustDataset(self.label_map, valdir, transform=self.preprocess)
156 |
157 |
--------------------------------------------------------------------------------
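A minimal sketch of the logit projection performed by `project_logits` in imagenet_vid_robust.py above: each of the 30 imagenet-vid-robust classes takes the maximum (the default `merge_op`) over the logits of the ImageNet classes mapped to it. The mapping below is made up purely for illustration.

import numpy as np

# hypothetical mapping: vid-robust class index -> list of ImageNet class indices
rev_class_idx_map = {0: [281, 282, 283], 1: [7]}

logits = np.random.randn(4, 1000)              # a batch of 1000-way ImageNet logits
projected = np.zeros((logits.shape[0], 30))
for k, v in rev_class_idx_map.items():
    projected[:, k] = np.max(logits[:, v], axis=1)   # 'max' merge, as in project_logits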
/src/classifier-tuning/src/datasets/imagenetv2.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from imagenetv2_pytorch import ImageNetV2Dataset
4 |
5 | from .imagenet import ImageNet
6 |
7 | class ImageNetV2DatasetWithPaths(ImageNetV2Dataset):
8 | def __getitem__(self, i):
9 | img, label = Image.open(self.fnames[i]), int(self.fnames[i].parent.name)
10 | if self.transform is not None:
11 | img = self.transform(img)
12 | return {
13 | 'images': img,
14 | 'labels': label,
15 | 'image_paths': str(self.fnames[i])
16 | }
17 |
18 | class ImageNetV2(ImageNet):
19 | def get_test_dataset(self):
20 | return ImageNetV2DatasetWithPaths(transform=self.preprocess, location=self.location)
21 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/iwildcam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import json
4 | import numpy as np
5 | import pathlib
6 |
7 | import wilds
8 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
9 | from wilds.datasets.wilds_dataset import WILDSSubset
10 |
11 |
12 | def get_mask_non_empty(dataset):
13 | metadf = pd.read_csv(dataset._data_dir / 'metadata.csv')
14 | filename = os.path.expanduser(dataset._data_dir / 'iwildcam2020_megadetector_results.json')
15 | with open(filename, 'r') as f:
16 | md_data = json.load(f)
17 | id_to_maxdet = {x['id']: x['max_detection_conf'] for x in md_data['images']}
18 | threshold = 0.95
19 | mask_non_empty = [id_to_maxdet[x] >= threshold for x in metadf['image_id']]
20 | return mask_non_empty
21 |
22 |
23 | def get_nonempty_subset(dataset, split, frac=1.0, transform=None):
24 | if split not in dataset.split_dict:
25 | raise ValueError(f"Split {split} not found in dataset's split_dict.")
26 | split_mask = dataset.split_array == dataset.split_dict[split]
27 |
28 | # intersect split mask with non_empty. here is the only place this fn differs
29 | # from https://github.com/p-lambda/wilds/blob/main/wilds/datasets/wilds_dataset.py#L56
30 | mask_non_empty = get_mask_non_empty(dataset)
31 | split_mask = split_mask & mask_non_empty
32 |
33 | split_idx = np.where(split_mask)[0]
34 | if frac < 1.0:
35 | num_to_retain = int(np.round(float(len(split_idx)) * frac))
36 | split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])
37 | subset = WILDSSubset(dataset, split_idx, transform)
38 | return subset
39 |
40 |
41 | class IWildCam:
42 | def __init__(self,
43 | preprocess,
44 | location=os.path.expanduser('~/data'),
45 | remove_non_empty=False,
46 | batch_size=128,
47 | num_workers=16,
48 | classnames=None,
49 | subset='train'):
50 | self.dataset = wilds.get_dataset(dataset='iwildcam', root_dir=location)
51 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
52 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size)
53 |
54 | if remove_non_empty:
55 | self.train_dataset = get_nonempty_subset(self.dataset, 'train', transform=preprocess)
56 | else:
57 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
58 |
59 | if remove_non_empty:
60 | self.test_dataset = get_nonempty_subset(self.dataset, subset, transform=preprocess)
61 | else:
62 | self.test_dataset = self.dataset.get_subset(subset, transform=preprocess)
63 |
64 | self.test_loader = get_eval_loader(
65 | "standard", self.test_dataset,
66 | num_workers=num_workers,
67 | batch_size=batch_size)
68 |
69 | labels_csv = pathlib.Path(__file__).parent / 'iwildcam_metadata' / 'labels.csv'
70 | df = pd.read_csv(labels_csv)
71 | df = df[df['y'] < 99999]
72 |
73 | self.classnames = [s.lower() for s in list(df['english'])]
74 |
75 | def post_loop_metrics(self, labels, preds, metadata, args):
76 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels)
77 | results = self.dataset.eval(preds, labels, metadata)
78 | return results[0]
79 |
80 |
81 | class IWildCamID(IWildCam):
82 | def __init__(self, *args, **kwargs):
83 | kwargs['subset'] = 'id_test'
84 | super().__init__(*args, **kwargs)
85 |
86 |
87 | class IWildCamOOD(IWildCam):
88 | def __init__(self, *args, **kwargs):
89 | kwargs['subset'] = 'test'
90 | super().__init__(*args, **kwargs)
91 |
92 |
93 | class IWildCamNonEmpty(IWildCam):
94 | def __init__(self, *args, **kwargs):
95 | kwargs['subset'] = 'train'
96 | super().__init__(*args, **kwargs)
97 |
98 |
99 | class IWildCamIDNonEmpty(IWildCam):
100 | def __init__(self, *args, **kwargs):
101 | kwargs['subset'] = 'id_test'
102 | super().__init__(*args, **kwargs)
103 |
104 |
105 | class IWildCamOODNonEmpty(IWildCam):
106 | def __init__(self, *args, **kwargs):
107 | kwargs['subset'] = 'test'
108 | super().__init__(*args, **kwargs)
109 |
--------------------------------------------------------------------------------
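A usage sketch for the IWildCam wrappers above, assuming the WILDS iWildCam data has already been downloaded under ~/data and that `preprocess` is any torchvision-style transform; the batch size and worker count here are illustrative only.

import os
from torchvision import transforms

preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor()])

# IWildCamOOD is defined in iwildcam.py above (OOD "test" split);
# remove_non_empty=True keeps only images whose MegaDetector
# max_detection_conf is >= 0.95 (see get_mask_non_empty above)
ds = IWildCamOOD(preprocess,
                 location=os.path.expanduser('~/data'),
                 remove_non_empty=True,
                 batch_size=64,
                 num_workers=4)
print(len(ds.classnames), 'classes')
batch = next(iter(ds.test_loader))   # (x, y, metadata) in WILDS' convention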
/src/classifier-tuning/src/datasets/objectnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pathlib import Path
4 | import PIL
5 |
6 | import numpy as np
7 |
8 | import torch
9 | from torchvision import datasets
10 | from torchvision.transforms import Compose
11 |
12 | from .common import ImageFolderWithPaths, SubsetSampler
13 | from .imagenet import ImageNet, ImageNetSubsampleValClasses
14 |
15 |
16 | def get_metadata():
17 | metadata = Path(__file__).parent / 'objectnet_metadata'
18 |
19 | with open(metadata / 'folder_to_objectnet_label.json', 'r') as f:
20 | folder_map = json.load(f)
21 | folder_map = {v: k for k, v in folder_map.items()}
22 | with open(metadata / 'objectnet_to_imagenet_1k.json', 'r') as f:
23 | objectnet_map = json.load(f)
24 |
25 | with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r') as f:
26 | pytorch_map = json.load(f)
27 | pytorch_map = {v: k for k, v in pytorch_map.items()}
28 |
29 | with open(metadata / 'imagenet_to_label_2012_v2', 'r') as f:
30 | imagenet_map = {v.strip(): str(pytorch_map[i]) for i, v in enumerate(f)}
31 |
32 | folder_to_ids, class_sublist = {}, []
33 | classnames = []
34 | for objectnet_name, imagenet_names in objectnet_map.items():
35 | imagenet_names = imagenet_names.split('; ')
36 | imagenet_ids = [int(imagenet_map[imagenet_name]) for imagenet_name in imagenet_names]
37 | class_sublist.extend(imagenet_ids)
38 | folder_to_ids[folder_map[objectnet_name]] = imagenet_ids
39 |
40 | class_sublist = sorted(class_sublist)
41 | class_sublist_mask = [(i in class_sublist) for i in range(1000)]
42 | classname_map = {v: k for k, v in folder_map.items()}
43 | return class_sublist, class_sublist_mask, folder_to_ids, classname_map
44 |
45 |
46 | def crop(img):
47 | width, height = img.size
48 | cropArea = (2, 2, width - 2, height - 2)
49 | img = img.crop(cropArea)
50 | return img
51 |
52 |
53 | class ObjectNetDataset(datasets.ImageFolder):
54 |
55 | def __init__(self, label_map, path, transform):
56 | self.label_map = label_map
57 | super().__init__(path, transform=transform)
58 | self.samples = [
59 | d for d in self.samples
60 | if os.path.basename(os.path.dirname(d[0])) in self.label_map
61 | ]
62 | self.imgs = self.samples
63 |
64 | def __len__(self):
65 | return len(self.samples)
66 |
67 | def __getitem__(self, index):
68 | path, target = self.samples[index]
69 | sample = self.loader(path)
70 | if self.transform is not None:
71 | sample = self.transform(sample)
72 | label = os.path.basename(os.path.dirname(path))
73 | return {
74 | 'images': sample,
75 | 'labels': self.label_map[label],
76 | 'image_paths': path
77 | }
78 |
79 |
80 | class ObjectNetBase(ImageNet):
81 | def __init__(self, *args, **kwargs):
82 | (self._class_sublist,
83 | self.class_sublist_mask,
84 | self.folders_to_ids,
85 | self.classname_map) = get_metadata()
86 |
87 | super().__init__(*args, **kwargs)
88 |
89 | self.classnames = sorted(list(self.folders_to_ids.keys()))
90 | self.rev_class_idx_map = {}
91 | self.class_idx_map = {}
92 | for idx, name in enumerate(self.classnames):
93 | self.rev_class_idx_map[idx] = self.folders_to_ids[name]
94 | for imagenet_idx in self.rev_class_idx_map[idx]:
95 | self.class_idx_map[imagenet_idx] = idx
96 |
97 | self.crop = crop
98 | self.preprocess = Compose([crop, self.preprocess])
99 | self.classnames = [self.classname_map[c].lower() for c in self.classnames]
100 |
101 | def populate_train(self):
102 | pass
103 |
104 | def get_test_dataset(self):
105 | subdir = 'objectnet-1.0/images'
106 | valdir = os.path.join(self.location, subdir)
107 | label_map = {name: idx for idx, name in enumerate(sorted(list(self.folders_to_ids.keys())))}
108 | return ObjectNetDataset(label_map, valdir, transform=self.preprocess)
109 |
110 | def project_logits(self, logits, device):
111 | if isinstance(logits, list) or isinstance(logits, tuple):
112 | return [self.project_logits(l, device) for l in logits]
113 | if logits.shape[1] == 113:
114 | return logits
115 | if torch.is_tensor(logits):
116 | logits = logits.cpu().numpy()
117 | logits_projected = np.zeros((logits.shape[0], 113))
118 | for k, v in self.rev_class_idx_map.items():
119 | logits_projected[:, k] = np.max(logits[:, v], axis=1).squeeze()
120 | return torch.tensor(logits_projected).to(device)
121 |
122 | def scatter_weights(self, weights):
123 | if weights.size(1) == 1000:
124 | return weights
125 | new_weights = torch.ones((weights.size(0), 1000)).to(weights.device) * -10e8
126 | for k, v in self.rev_class_idx_map.items():
127 | for vv in v:
128 | new_weights[:, vv] = weights[:, k]
129 | return new_weights
130 |
131 |
132 |
133 | def accuracy(logits, targets, img_paths, args):
134 | assert logits.shape[1] == 113
135 | preds = logits.argmax(dim=1)
136 | if torch.is_tensor(preds):
137 | preds = preds.cpu().numpy()
138 | if torch.is_tensor(targets):
139 | targets = targets.cpu().numpy()
140 | return np.sum(preds == targets), len(preds)
141 |
142 |
143 | class ObjectNetValClasses(ObjectNetBase):
144 |
145 | def get_test_sampler(self):
146 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self._class_sublist]
147 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
148 |
149 | sampler = SubsetSampler(idx_subsample_list)
150 | return sampler
151 |
152 | def get_test_dataset(self):
153 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
154 |
155 | def project_labels(self, labels, device):
156 | projected_labels = [self.class_idx_map[int(label)] for label in labels]
157 | return torch.LongTensor(projected_labels).to(device)
158 |
159 |
160 | class ObjectNet(ObjectNetBase):
161 |
162 | def accuracy(self, logits, targets, img_paths, args):
163 | return accuracy(logits, targets, img_paths, args)
164 |
--------------------------------------------------------------------------------
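A small torch sketch of what `scatter_weights` in objectnet.py above does: a 113-way ObjectNet classifier head is scattered back into 1000 ImageNet slots, and every ImageNet class with no ObjectNet counterpart receives a very large negative weight so it can never win the argmax. The mapping below is hypothetical.

import torch

# hypothetical mapping: ObjectNet class index -> overlapping ImageNet class indices
rev_class_idx_map = {0: [409, 530], 1: [414]}

weights = torch.randn(512, 113)                 # feature_dim x 113 ObjectNet classes
new_weights = torch.ones(512, 1000) * -10e8     # same sentinel value as scatter_weights above
for k, v in rev_class_idx_map.items():
    for vv in v:
        new_weights[:, vv] = weights[:, k]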
/src/classifier-tuning/src/datasets/objectnet_metadata/objectnet_to_imagenet_1k.json:
--------------------------------------------------------------------------------
1 | {
2 | "Alarm clock": "analog clock; digital clock",
3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack",
4 | "Banana": "banana",
5 | "Band Aid": "Band Aid",
6 | "Basket": "shopping basket",
7 | "Bath towel": "bath towel",
8 | "Beer bottle": "beer bottle",
9 | "Bench": "park bench",
10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem",
11 | "Binder (closed)": "binder, ring-binder",
12 | "Bottle cap": "bottlecap",
13 | "Bread loaf": "French loaf",
14 | "Broom": "broom",
15 | "Bucket": "bucket, pail",
16 | "Butcher's knife": "cleaver, meat cleaver, chopper",
17 | "Can opener": "can opener, tin opener",
18 | "Candle": "candle, taper, wax light",
19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
20 | "Chair": "barber chair; folding chair; rocking chair, rocker",
21 | "Clothes hamper": "hamper",
22 | "Coffee/French press": "espresso maker",
23 | "Combination lock": "combination lock",
24 | "Computer mouse": "mouse, computer mouse",
25 | "Desk lamp": "table lamp",
26 | "Dishrag or hand towel": "dishrag, dishcloth",
27 | "Doormat": "doormat, welcome mat",
28 | "Dress shoe (men)": "Loafer",
29 | "Drill": "power drill",
30 | "Drinking Cup": "cup",
31 | "Drying rack for plates": "plate rack",
32 | "Envelope": "envelope",
33 | "Fan": "electric fan, blower",
34 | "Frying pan": "frying pan, frypan, skillet",
35 | "Dress": "gown",
36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
37 | "Hammer": "hammer",
38 | "Helmet": "football helmet; crash helmet",
39 | "Iron (for clothes)": "iron, smoothing iron",
40 | "Jeans": "jean, blue jean, denim",
41 | "Keyboard": "computer keyboard, keypad",
42 | "Ladle": "ladle",
43 | "Lampshade": "lampshade, lamp shade",
44 | "Laptop (open)": "laptop, laptop computer",
45 | "Lemon": "lemon",
46 | "Letter opener": "letter opener, paper knife, paperknife",
47 | "Lighter": "lighter, light, igniter, ignitor",
48 | "Lipstick": "lipstick, lip rouge",
49 | "Match": "matchstick",
50 | "Measuring cup": "measuring cup",
51 | "Microwave": "microwave, microwave oven",
52 | "Mixing / Salad Bowl": "mixing bowl",
53 | "Monitor": "monitor",
54 | "Mug": "coffee mug",
55 | "Nail (fastener)": "nail",
56 | "Necklace": "necklace",
57 | "Orange": "orange",
58 | "Padlock": "padlock",
59 | "Paintbrush": "paintbrush",
60 | "Paper towel": "paper towel",
61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen",
62 | "Pill bottle": "pill bottle",
63 | "Pillow": "pillow",
64 | "Pitcher": "pitcher, ewer",
65 | "Plastic bag": "plastic bag",
66 | "Plate": "plate",
67 | "Plunger": "plunger, plumber's helper",
68 | "Pop can": "pop bottle, soda bottle",
69 | "Portable heater": "space heater",
70 | "Printer": "printer",
71 | "Remote control": "remote control, remote",
72 | "Ruler": "rule, ruler",
73 | "Running shoe": "running shoe",
74 | "Safety pin": "safety pin",
75 | "Salt shaker": "saltshaker, salt shaker",
76 | "Sandal": "sandal",
77 | "Screw": "screw",
78 | "Shovel": "shovel",
79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt",
80 | "Sleeping bag": "sleeping bag",
81 | "Soap dispenser": "soap dispenser",
82 | "Sock": "sock",
83 | "Soup Bowl": "soup bowl",
84 | "Spatula": "spatula",
85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera",
87 | "Strainer": "strainer",
88 | "Stuffed animal": "teddy, teddy bear",
89 | "Suit jacket": "suit, suit of clothes",
90 | "Sunglasses": "sunglasses, dark glasses, shades",
91 | "Sweater": "sweatshirt",
92 | "Swimming trunks": "swimming trunks, bathing trunks",
93 | "T-shirt": "jersey, T-shirt, tee shirt",
94 | "TV": "television, television system",
95 | "Teapot": "teapot",
96 | "Tennis racket": "racket, racquet",
97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie",
98 | "Toaster": "toaster",
99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue",
100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
101 | "Tray": "tray",
102 | "Umbrella": "umbrella",
103 | "Vacuum cleaner": "vacuum, vacuum cleaner",
104 | "Vase": "vase",
105 | "Wallet": "wallet, billfold, notecase, pocketbook",
106 | "Watch": "digital watch",
107 | "Water bottle": "water bottle",
108 | "Weight (exercise)": "dumbbell",
109 | "Weight scale": "scale, weighing machine",
110 | "Wheel": "car wheel; paddlewheel, paddle wheel",
111 | "Whistle": "whistle",
112 | "Wine bottle": "wine bottle",
113 | "Winter glove": "mitten",
114 | "Wok": "wok"
115 | }
116 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/Randaug.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 |
11 | def ShearX(img, v): # [-0.3, 0.3]
12 | assert -0.3 <= v <= 0.3
13 | if random.random() > 0.5:
14 | v = -v
15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
16 |
17 |
18 | def ShearY(img, v): # [-0.3, 0.3]
19 | assert -0.3 <= v <= 0.3
20 | if random.random() > 0.5:
21 | v = -v
22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
23 |
24 |
25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
26 | assert -0.45 <= v <= 0.45
27 | if random.random() > 0.5:
28 | v = -v
29 | v = v * img.size[0]
30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
31 |
32 |
33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
34 | assert 0 <= v
35 | if random.random() > 0.5:
36 | v = -v
37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
38 |
39 |
40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
41 | assert -0.45 <= v <= 0.45
42 | if random.random() > 0.5:
43 | v = -v
44 | v = v * img.size[1]
45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
46 |
47 |
48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
49 | assert 0 <= v
50 | if random.random() > 0.5:
51 | v = -v
52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
53 |
54 |
55 | def Rotate(img, v): # [-30, 30]
56 | assert -30 <= v <= 30
57 | if random.random() > 0.5:
58 | v = -v
59 | return img.rotate(v)
60 |
61 |
62 | def AutoContrast(img, _):
63 | return PIL.ImageOps.autocontrast(img)
64 |
65 |
66 | def Invert(img, _):
67 | return PIL.ImageOps.invert(img)
68 |
69 |
70 | def Equalize(img, _):
71 | return PIL.ImageOps.equalize(img)
72 |
73 |
74 | def Flip(img, _): # not from the paper
75 | return PIL.ImageOps.mirror(img)
76 |
77 |
78 | def Solarize(img, v): # [0, 256]
79 | assert 0 <= v <= 256
80 | return PIL.ImageOps.solarize(img, v)
81 |
82 |
83 | def SolarizeAdd(img, addition=0, threshold=128):
84 |     img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24+
85 | img_np = img_np + addition
86 | img_np = np.clip(img_np, 0, 255)
87 | img_np = img_np.astype(np.uint8)
88 | img = Image.fromarray(img_np)
89 | return PIL.ImageOps.solarize(img, threshold)
90 |
91 |
92 | def Posterize(img, v): # [4, 8]
93 | v = int(v)
94 | v = max(1, v)
95 | return PIL.ImageOps.posterize(img, v)
96 |
97 |
98 | def Contrast(img, v): # [0.1,1.9]
99 | assert 0.1 <= v <= 1.9
100 | return PIL.ImageEnhance.Contrast(img).enhance(v)
101 |
102 |
103 | def Color(img, v): # [0.1,1.9]
104 | assert 0.1 <= v <= 1.9
105 | return PIL.ImageEnhance.Color(img).enhance(v)
106 |
107 |
108 | def Brightness(img, v): # [0.1,1.9]
109 | assert 0.1 <= v <= 1.9
110 | return PIL.ImageEnhance.Brightness(img).enhance(v)
111 |
112 |
113 | def Sharpness(img, v): # [0.1,1.9]
114 | assert 0.1 <= v <= 1.9
115 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
116 |
117 |
118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
119 | assert 0.0 <= v <= 0.2
120 | if v <= 0.:
121 | return img
122 |
123 | v = v * img.size[0]
124 | return CutoutAbs(img, v)
125 |
126 |
127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
128 | # assert 0 <= v <= 20
129 | if v < 0:
130 | return img
131 | w, h = img.size
132 | x0 = np.random.uniform(w)
133 | y0 = np.random.uniform(h)
134 |
135 | x0 = int(max(0, x0 - v / 2.))
136 | y0 = int(max(0, y0 - v / 2.))
137 | x1 = min(w, x0 + v)
138 | y1 = min(h, y0 + v)
139 |
140 | xy = (x0, y0, x1, y1)
141 | color = (125, 123, 114)
142 |     if img.mode == 'L':  # grayscale image: use a scalar fill value
143 | color = (125)
144 | # color = (0, 0, 0)
145 | img = img.copy()
146 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
147 |
148 | return img
149 |
150 |
151 | def SamplePairing(imgs): # [0, 0.4]
152 | def f(img1, v):
153 | i = np.random.choice(len(imgs))
154 | img2 = PIL.Image.fromarray(imgs[i])
155 | return PIL.Image.blend(img1, img2, v)
156 |
157 | return f
158 |
159 |
160 | def Identity(img, v):
161 | return img
162 |
163 |
164 | def augment_list(): # 16 operations and their ranges
165 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
166 | # l = [
167 | # (Identity, 0., 1.0),
168 | # (ShearX, 0., 0.3), # 0
169 | # (ShearY, 0., 0.3), # 1
170 | # (TranslateX, 0., 0.33), # 2
171 | # (TranslateY, 0., 0.33), # 3
172 | # (Rotate, 0, 30), # 4
173 | # (AutoContrast, 0, 1), # 5
174 | # (Invert, 0, 1), # 6
175 | # (Equalize, 0, 1), # 7
176 | # (Solarize, 0, 110), # 8
177 | # (Posterize, 4, 8), # 9
178 | # # (Contrast, 0.1, 1.9), # 10
179 | # (Color, 0.1, 1.9), # 11
180 | # (Brightness, 0.1, 1.9), # 12
181 | # (Sharpness, 0.1, 1.9), # 13
182 | # # (Cutout, 0, 0.2), # 14
183 | # # (SamplePairing(imgs), 0, 0.4), # 15
184 | # ]
185 |
186 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
187 | l = [
188 | (AutoContrast, 0, 1),
189 | (Equalize, 0, 1),
190 | (Invert, 0, 1),
191 | (Rotate, 0, 30),
192 | (Posterize, 0, 4),
193 | (Solarize, 0, 256),
194 | (SolarizeAdd, 0, 110),
195 | (Color, 0.1, 1.9),
196 | (Contrast, 0.1, 1.9),
197 | (Brightness, 0.1, 1.9),
198 | (Sharpness, 0.1, 1.9),
199 | (ShearX, 0., 0.3),
200 | (ShearY, 0., 0.3),
201 | (CutoutAbs, 0, 40),
202 | (TranslateXabs, 0., 100),
203 | (TranslateYabs, 0., 100),
204 | ]
205 |
206 | return l
207 |
208 |
209 | class Lighting(object):
210 |     """Lighting noise (AlexNet-style PCA-based noise)"""
211 |
212 | def __init__(self, alphastd, eigval, eigvec):
213 | self.alphastd = alphastd
214 | self.eigval = torch.Tensor(eigval)
215 | self.eigvec = torch.Tensor(eigvec)
216 |
217 | def __call__(self, img):
218 | if self.alphastd == 0:
219 | return img
220 |
221 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
222 | rgb = self.eigvec.type_as(img).clone() \
223 | .mul(alpha.view(1, 3).expand(3, 3)) \
224 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
225 | .sum(1).squeeze()
226 |
227 | return img.add(rgb.view(3, 1, 1).expand_as(img))
228 |
229 |
230 | class CutoutDefault(object):
231 | """
232 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
233 | """
234 | def __init__(self, length):
235 | self.length = length
236 |
237 | def __call__(self, img):
238 | h, w = img.size(1), img.size(2)
239 | mask = np.ones((h, w), np.float32)
240 | y = np.random.randint(h)
241 | x = np.random.randint(w)
242 |
243 | y1 = np.clip(y - self.length // 2, 0, h)
244 | y2 = np.clip(y + self.length // 2, 0, h)
245 | x1 = np.clip(x - self.length // 2, 0, w)
246 | x2 = np.clip(x + self.length // 2, 0, w)
247 |
248 | mask[y1: y2, x1: x2] = 0.
249 | mask = torch.from_numpy(mask)
250 | mask = mask.expand_as(img)
251 | img *= mask
252 | return img
253 |
254 |
255 | class RandAugment:
256 | def __init__(self, n, m):
257 | self.n = n
258 | self.m = m # [0, 30]
259 | self.augment_list = augment_list()
260 |
261 | def __call__(self, img):
262 | ops = random.choices(self.augment_list, k=self.n)
263 | for op, minval, maxval in ops:
264 | val = (float(self.m) / 30) * float(maxval - minval) + minval
265 | img = op(img, val)
266 |
267 | return img
--------------------------------------------------------------------------------
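A usage sketch: `RandAugment(n, m)` applies `n` randomly chosen PIL operations at magnitude `m` (on a 0-30 scale), so it can simply be prepended to a torchvision pipeline, mirroring the commented-out `TRAIN_TRANSFORMS.transforms.insert(0, RandAugment(N, M))` in constants.py below.

from torchvision import transforms

# RandAugment is the class defined in Randaug.py above
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# apply 2 random ops at magnitude 9 on the PIL image, before the other transforms
train_tf.transforms.insert(0, RandAugment(2, 9))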
/src/classifier-tuning/src/datasets/transfer_ds/cal_mean_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
4 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train',transform=transforms.ToTensor())
5 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train',transform=transforms.ToTensor())
6 | dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train',
7 | transform=transforms.Compose([transforms.Resize(256),
8 | transforms.CenterCrop(224),
9 | transforms.ToTensor()]))
10 |
11 | # --------- PASS
12 | # mean: tensor([0.4646, 0.4484, 0.4129])
13 | # std: tensor([0.2750, 0.2689, 0.2885])
14 |
15 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train',
16 | # transform=transforms.Compose([transforms.Resize(256),
17 | # transforms.CenterCrop(224),
18 | # transforms.ToTensor()]))
19 |
20 | loader = torch.utils.data.DataLoader(dataset,
21 | batch_size=1000,
22 | num_workers=8,
23 | shuffle=False)
24 |
25 | # mean = 0.
26 | # meansq = 0.
27 | # i=0
28 | # for data,_ in loader:
29 | # print('{}/{}'.format(i,len(loader)))
30 | # i+=1
31 | # mean = data.mean()
32 | # meansq = (data ** 2).mean()
33 | #
34 | # std = torch.sqrt(meansq - mean ** 2)
35 | # print("mean: " + str(mean))
36 | # print("std: " + str(std))
37 | # print()
38 |
39 | # mean = 0.0
40 | # i=0
41 | # for images, _ in loader:
42 | # batch_samples = images.size(0)
43 | # images = images.view(batch_samples, images.size(1), -1)
44 | # mean += images.mean(2).sum(0)
45 | # print('{}/{}'.format(i, len(loader)))
46 | # i+=1
47 | # print(mean / i / 1000)
48 | # mean = mean / len(loader.dataset) / 1000
49 |
50 | # import ipdb
51 | # ipdb.set_trace(context=20)
52 | # mean = torch.FloatTensor([0.485, 0.456, 0.406])
53 | mean = torch.FloatTensor([0.4646, 0.4484, 0.4129])
54 | var = 0.0
55 | i=0
56 | for images, _ in loader:
57 | batch_samples = images.size(0)
58 | images = images.view(batch_samples, images.size(1), -1)
59 | var += ((images - mean.unsqueeze(1))**2).sum([0,2])
60 | print('{}/{}'.format(i, len(loader)))
61 | i += 1
62 | print(torch.sqrt(var / (i*224*224)))
63 | print(torch.sqrt(var / (i*1000*224*224)))
64 | std = torch.sqrt(var / (len(loader.dataset)*224*224))
65 |
66 | import ipdb
67 | ipdb.set_trace(context=20)
68 |
69 | a=1
70 |
--------------------------------------------------------------------------------
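For reference, a compact two-pass version of the per-channel mean/std computation that cal_mean_std.py sketches above; the dataset path is a placeholder, and any ImageFolder-style directory works.

import torch
from torchvision import datasets, transforms

tf = transforms.Compose([transforms.Resize(256),
                         transforms.CenterCrop(224),
                         transforms.ToTensor()])
ds = datasets.ImageFolder('/path/to/train', transform=tf)   # placeholder path
loader = torch.utils.data.DataLoader(ds, batch_size=256, num_workers=8)

n_pixels = 0
mean = torch.zeros(3)
for images, _ in loader:                  # first pass: per-channel mean
    images = images.view(images.size(0), 3, -1)
    mean += images.sum(dim=(0, 2))
    n_pixels += images.size(0) * images.size(2)
mean /= n_pixels

var = torch.zeros(3)
for images, _ in loader:                  # second pass: per-channel variance
    images = images.view(images.size(0), 3, -1)
    var += ((images - mean.view(1, 3, 1)) ** 2).sum(dim=(0, 2))
std = (var / n_pixels).sqrt()
print('mean:', mean, 'std:', std)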
/src/classifier-tuning/src/datasets/transfer_ds/constants.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from .Randaug import RandAugment
3 |
4 | prefix = "data"
5 |
6 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_100k/"
7 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_128k/"
8 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_200k/"
9 | IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train/"
10 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_plus15tasks/"
11 | # IMGNET_PATH = prefix + "/imagenet/imagenet.10_1000/"
12 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_200cls_640shot_128.0k/"
13 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_500cls_256shot_128.0k/"
14 | # IMGNET_PATH = prefix + "/place_128k"
15 | # IMGNET_PATH = prefix + "/pass/PASS_128k"
16 |
17 | # Planes dataset
18 | FGVC_PATH = prefix + "/fgvc-aircraft-2013b/"
19 |
20 | # Oxford Flowers dataset
21 | # FLOWERS_PATH = prefix + "/oxford_flowers_pytorch/"
22 | FLOWERS_PATH = prefix + "/flowers_new/"
23 |
24 | # DTD dataset
25 | DTD_PATH = prefix + "/dtd/"
26 |
27 | # Stanford Cars dataset
28 | CARS_PATH = prefix + "/cars_new"
29 |
30 | # SUN397 dataset
31 | SUN_PATH = prefix + "/SUN397/splits_01/"
32 |
33 | # FOOD dataset
34 | FOOD_PATH = prefix + "/food-101"
35 |
36 | # BIRDS dataset
37 | BIRDS_PATH = prefix + "/birdsnap"
38 |
39 | # CUB-200-2011 birds
40 | CUB_PATH = prefix + "/CUB_200_2011"
41 |
42 | # COCO
43 | COCO_PATH = prefix + "/coco_cls"
44 |
45 | # ade20k
46 | ADE20K_PATH = prefix + "/ade20k_cls"
47 |
48 | # Mix seg: cs voc ade
49 | MIX_SEG_PATH = prefix + "/mix_seg"
50 |
51 | # PETS dataset
52 | PETS_PATH = prefix + ""
53 |
54 | # Caltech datasets
55 | CALTECH101_PATH = prefix + ""
56 | CALTECH256_PATH = prefix + ""
57 |
58 | value_scale = 255
59 | mean = [0.485, 0.456, 0.406]
60 | mean = [0.48145466, 0.4578275, 0.40821073]
61 | mean = [item * value_scale for item in mean]
62 | std = [0.229, 0.224, 0.225]
63 | std = [0.26862954, 0.26130258, 0.27577711]
64 | std = [item * value_scale for item in std]
65 |
66 | # Data Augmentation defaults
67 | TRAIN_TRANSFORMS = transforms.Compose([
68 | # transforms.Resize(32),
69 | transforms.RandomResizedCrop(224),
70 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.75,1.333333)),
71 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.5,2.0)),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | # transforms.Normalize(mean=mean, std=std),
75 | ])
76 | # TRAIN_TRANSFORMS = transforms.Compose([
77 | # # transforms.Resize(32),
78 | # transforms.Resize(256),
79 | # transforms.CenterCrop(224),
80 | # transforms.ToTensor(),
81 | # # transforms.Normalize(mean=mean, std=std),
82 | # ])
83 |
84 | TEST_TRANSFORMS = transforms.Compose([
85 | # transforms.Resize(32),
86 | transforms.Resize(256),
87 | transforms.CenterCrop(224),
88 | transforms.ToTensor(),
89 | # transforms.Normalize(mean=mean, std=std),
90 | ])
91 |
92 | # from PIL import Image
93 | # BICUBIC = Image.BICUBIC
94 | # TEST_TRANSFORMS = transforms.Compose([
95 | # # transforms.Resize(32),
96 | # transforms.Resize(224,interpolation=BICUBIC),
97 | # transforms.CenterCrop(224),
98 | # transforms.ToTensor(),
99 | # # transforms.Normalize(mean=mean, std=std),
100 | # ])
101 |
102 | # Add RandAugment with N, M(hyperparameter)
103 | # N=3
104 | # M=9
105 | # TRAIN_TRANSFORMS.transforms.insert(0, RandAugment(N, M))
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class CUB(Dataset):
8 |
9 | def __init__(self, path, train=True, transform=None, target_transform=None):
10 |
11 | self.root = path
12 | self.is_train = train
13 | self.transform = transform
14 | self.target_transform = target_transform
15 | self.images_path = {}
16 | with open(os.path.join(self.root, 'images.txt')) as f:
17 | for line in f:
18 | image_id, path = line.split()
19 | self.images_path[image_id] = path
20 |
21 | self.class_ids = {}
22 | with open(os.path.join(self.root, 'image_class_labels.txt')) as f:
23 | for line in f:
24 | image_id, class_id = line.split()
25 | self.class_ids[image_id] = class_id
26 |
27 | self.data_id = []
28 | if self.is_train:
29 | with open(os.path.join(self.root, 'train_test_split.txt')) as f:
30 | for line in f:
31 | image_id, is_train = line.split()
32 | if int(is_train):
33 | self.data_id.append(image_id)
34 | if not self.is_train:
35 | with open(os.path.join(self.root, 'train_test_split.txt')) as f:
36 | for line in f:
37 | image_id, is_train = line.split()
38 | if not int(is_train):
39 | self.data_id.append(image_id)
40 |
41 | def __len__(self):
42 | return len(self.data_id)
43 |
44 | def __getitem__(self, index):
45 | """
46 | Args:
47 | index: index of training dataset
48 | Returns:
49 | image and its corresponding label
50 | """
51 | image_id = self.data_id[index]
52 | class_id = int(self._get_class_by_id(image_id)) - 1
53 | path = self._get_path_by_id(image_id)
54 | image = cv2.imread(os.path.join(self.root, 'images', path))
55 |
56 | if self.transform:
57 | image = self.transform(image)
58 |
59 | if self.target_transform:
60 | class_id = self.target_transform(class_id)
61 | return image, class_id
62 |
63 | def _get_path_by_id(self, image_id):
64 |
65 | return self.images_path[image_id]
66 |
67 | def _get_class_by_id(self, image_id):
68 |
69 | return self.class_ids[image_id]
--------------------------------------------------------------------------------
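Note that `CUB` above reads images with cv2, so samples are (H, W, 3) BGR numpy arrays and pair with the OpenCV-based transforms in cub_transform.py further below rather than torchvision's PIL transforms. A minimal usage sketch with a placeholder data root:

train_set = CUB('/path/to/CUB_200_2011', train=True)   # placeholder path; CUB is defined above
image, class_id = train_set[0]                         # BGR ndarray, label in [0, 199]
print(image.shape, class_id)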
/src/classifier-tuning/src/datasets/transfer_ds/cub_for_robust_codebase.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | def main():
7 | path= "/opt/tiger/filter_transfer/data/CUB_200_2011"
8 | root = path
9 | images_path = {}
10 |
11 | os.system('mkdir train')
12 | os.system('cp -r images/* train/')
13 | os.system('mkdir val')
14 | os.system('cp -r images/* val/')
15 |
16 | with open(os.path.join(root, 'images.txt')) as f:
17 | for line in f:
18 | image_id, path = line.split()
19 | images_path[image_id] = path
20 |
21 | class_ids = {}
22 | with open(os.path.join(root, 'image_class_labels.txt')) as f:
23 | for line in f:
24 | image_id, class_id = line.split()
25 | class_ids[image_id] = class_id
26 |
27 | train_id = [] # train not val
28 | with open(os.path.join(root, 'train_test_split.txt')) as f:
29 | for line in f:
30 | image_id, is_train = line.split()
31 | if int(is_train):
32 | train_id.append(image_id)
33 |
34 | with open(os.path.join(root, 'images.txt')) as f:
35 | for line in f:
36 | image_id, path = line.split()
37 | if image_id in train_id:
38 | os.system('rm val/{}'.format(path))
39 | else:
40 | # import ipdb
41 | # ipdb.set_trace(context=20)
42 | os.system('rm train/{}'.format(path))
43 |
44 |
45 |
46 |
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/cub_transform.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numbers
4 |
5 | import cv2
6 | import numpy as np
7 |
8 | import torch
9 |
10 |
11 | class Compose:
12 | """Composes several transforms together.
13 |
14 | Args:
15 | transforms(list of 'Transform' object): list of transforms to compose
16 |
17 | """
18 |
19 | def __init__(self, transforms):
20 | self.transforms = transforms
21 |
22 | def __call__(self, img):
23 |
24 | for trans in self.transforms:
25 | img = trans(img)
26 |
27 | return img
28 |
29 | def __repr__(self):
30 | format_string = self.__class__.__name__ + '('
31 | for t in self.transforms:
32 | format_string += '\n'
33 | format_string += ' {0}'.format(t)
34 | format_string += '\n)'
35 | return format_string
36 |
37 |
38 | class ToCVImage:
39 | """Convert an Opencv image to a 3 channel uint8 image
40 | """
41 |
42 | def __call__(self, image):
43 | """
44 | Args:
45 |             image (numpy array): Image to be converted to a 3-channel uint8 image
46 |
47 | Returns:
48 | image (numpy array): Converted Image
49 | """
50 | if len(image.shape) == 2:
51 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
52 |
53 | image = image.astype('uint8')
54 |
55 | return image
56 |
57 |
58 | class RandomResizedCrop:
59 | """Randomly crop a rectangle region whose aspect ratio is randomly sampled
60 | in [3/4, 4/3] and area randomly sampled in [8%, 100%], then resize the cropped
61 | region into a 224-by-224 square image.
62 |
63 | Args:
64 | size: expected output size of each edge
65 | scale: range of size of the origin size cropped
66 | ratio: range of aspect ratio of the origin aspect ratio cropped (w / h)
67 | interpolation: Default: cv2.INTER_LINEAR:
68 | """
69 |
70 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation='linear'):
71 |
72 | self.methods = {
73 | "area": cv2.INTER_AREA,
74 | "nearest": cv2.INTER_NEAREST,
75 | "linear": cv2.INTER_LINEAR,
76 | "cubic": cv2.INTER_CUBIC,
77 | "lanczos4": cv2.INTER_LANCZOS4
78 | }
79 |
80 | self.size = (size, size)
81 | self.interpolation = self.methods[interpolation]
82 | self.scale = scale
83 | self.ratio = ratio
84 |
85 | def __call__(self, img):
86 | h, w, _ = img.shape
87 |
88 | area = w * h
89 |
90 | for attempt in range(10):
91 | target_area = random.uniform(*self.scale) * area
92 | target_ratio = random.uniform(*self.ratio)
93 |
94 | output_h = int(round(math.sqrt(target_area * target_ratio)))
95 | output_w = int(round(math.sqrt(target_area / target_ratio)))
96 |
97 | if random.random() < 0.5:
98 | output_w, output_h = output_h, output_w
99 |
100 | if output_w <= w and output_h <= h:
101 | topleft_x = random.randint(0, w - output_w)
102 | topleft_y = random.randint(0, h - output_h)
103 | break
104 |
105 | if output_w > w or output_h > h:
106 | output_w = min(w, h)
107 | output_h = output_w
108 | topleft_x = random.randint(0, w - output_w)
109 | topleft_y = random.randint(0, h - output_w)
110 |
111 | cropped = img[topleft_y: topleft_y + output_h, topleft_x: topleft_x + output_w]
112 |
113 | resized = cv2.resize(cropped, self.size, interpolation=self.interpolation)
114 |
115 | return resized
116 |
117 | def __repr__(self):
118 | for name, inter in self.methods.items():
119 | if inter == self.interpolation:
120 | inter_name = name
121 |
122 | interpolate_str = inter_name
123 | format_str = self.__class__.__name__ + '(size={0}'.format(self.size)
124 | format_str += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
125 | format_str += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
126 | format_str += ', interpolation={0})'.format(interpolate_str)
127 |
128 | return format_str
129 |
130 |
131 | class RandomHorizontalFlip:
132 | """Horizontally flip the given opencv image with given probability p.
133 |
134 | Args:
135 | p: probability of the image being flipped
136 | """
137 |
138 | def __init__(self, p=0.5):
139 | self.p = p
140 |
141 | def __call__(self, img):
142 | """
143 | Args:
144 | the image to be flipped
145 | Returns:
146 | flipped image
147 | """
148 | if random.random() < self.p:
149 | img = cv2.flip(img, 1)
150 |
151 | return img
152 |
153 |
154 | class ToTensor:
155 |     """convert an opencv image (h, w, c) ndarray with values in [0, 255] to a pytorch
156 |     float tensor (c, h, w) with values in [0, 1]
157 | """
158 |
159 | def __call__(self, img):
160 | """
161 | Args:
162 | a numpy array (h, w, c) range from [0, 255]
163 |
164 | Returns:
165 | a pytorch tensor
166 | """
167 | # convert format H W C to C H W
168 | img = img.transpose(2, 0, 1)
169 | img = torch.from_numpy(img)
170 | img = img.float() / 255.0
171 |
172 | return img
173 |
174 |
175 | class Normalize:
176 |     """Normalize a torch tensor (C, H, W; BGR channel order) with mean and standard deviation
177 |
178 | for each channel in torch tensor:
179 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
180 |
181 | Args:
182 | mean: sequence of means for each channel
183 | std: sequence of stds for each channel
184 | """
185 |
186 | def __init__(self, mean, std, inplace=False):
187 | self.mean = mean
188 | self.std = std
189 | self.inplace = inplace
190 |
191 | def __call__(self, img):
192 | """
193 | Args:
194 |             (C H W) format float tensor with values in [0, 1]
195 |         Returns:
196 |             (C H W) format float tensor, normalized per channel
197 | """
198 | assert torch.is_tensor(img) and img.ndimension() == 3, 'not an image tensor'
199 |
200 | if not self.inplace:
201 | img = img.clone()
202 |
203 | mean = torch.tensor(self.mean, dtype=torch.float32)
204 | std = torch.tensor(self.std, dtype=torch.float32)
205 | img.sub_(mean[:, None, None]).div_(std[:, None, None])
206 |
207 | return img
208 |
209 |
210 | class Resize:
211 |
212 | def __init__(self, resized=256, interpolation='linear'):
213 | methods = {
214 | "area": cv2.INTER_AREA,
215 | "nearest": cv2.INTER_NEAREST,
216 | "linear": cv2.INTER_LINEAR,
217 | "cubic": cv2.INTER_CUBIC,
218 | "lanczos4": cv2.INTER_LANCZOS4
219 | }
220 | self.interpolation = methods[interpolation]
221 |
222 | if isinstance(resized, numbers.Number):
223 | resized = (resized, resized)
224 |
225 | self.resized = resized
226 |
227 | def __call__(self, img):
228 | img = cv2.resize(img, self.resized, interpolation=self.interpolation)
229 |
230 | return img
231 |
--------------------------------------------------------------------------------
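A sketch of a training pipeline assembled from the OpenCV transforms above; the mean/std values are the CLIP ones also listed in constants.py, and since cv2 loads BGR the channel order may need flipping before applying RGB statistics.

# Compose, ToCVImage, RandomResizedCrop, RandomHorizontalFlip, ToTensor and
# Normalize are the classes defined in cub_transform.py above
train_tf = Compose([
    ToCVImage(),
    RandomResizedCrop(224),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize([0.48145466, 0.4578275, 0.40821073],
              [0.26862954, 0.26130258, 0.27577711]),
])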
/src/classifier-tuning/src/datasets/transfer_ds/dtd.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | # from . import constants as cs
3 | from torch.utils.data.dataset import Dataset
4 | from torch.utils.data import DataLoader
5 | from os.path import join as osj
6 | from PIL import Image
7 | from torchvision import transforms
8 | import os
9 |
10 | TRAIN_TRANSFORMS = transforms.Compose([
11 | # transforms.Resize(32),
12 | transforms.RandomResizedCrop(224),
13 | transforms.RandomHorizontalFlip(),
14 | transforms.ToTensor(),
15 | # transforms.Normalize(mean=mean, std=std),
16 | ])
17 |
18 | TEST_TRANSFORMS = transforms.Compose([
19 | # transforms.Resize(32),
20 | transforms.Resize(256),
21 | transforms.CenterCrop(224),
22 | transforms.ToTensor(),
23 | # transforms.Normalize(mean=mean, std=std),
24 | ])
25 |
26 | class DTD(Dataset):
27 | def __init__(self, split="1", train=False, transform=TRAIN_TRANSFORMS):
28 | super().__init__()
29 | DTD_PATH='/opt/tiger/filter_transfer/data/dtd'
30 | train_path = osj(DTD_PATH, f"labels/train{split}.txt")
31 | val_path = osj(DTD_PATH, f"labels/val{split}.txt")
32 | test_path = osj(DTD_PATH, f"labels/test{split}.txt")
33 | if train:
34 | print(DTD_PATH)
35 | self.ims = open(train_path).readlines() + \
36 | open(val_path).readlines()
37 | else:
38 | self.ims = open(test_path).readlines()
39 |
40 | self.full_ims = [osj(DTD_PATH, "images", x) for x in self.ims]
41 |
42 | pth = osj(DTD_PATH, f"labels/classes.txt")
43 | self.c_to_t = {x.strip(): i for i, x in enumerate(open(pth).readlines())}
44 |
45 | # self.transform = TRAIN_TRANSFORMS if train else TEST_TRANSFORMS
46 | self.transform = transform
47 | self.labels = [self.c_to_t[x.split("/")[0]] for x in self.ims]
48 |
49 | def __getitem__(self, index):
50 | im = Image.open(self.full_ims[index].strip())
51 | im = self.transform(im)
52 | return im, self.labels[index]
53 |
54 | def __len__(self):
55 | return len(self.ims)
56 |
57 | if __name__ == "__main__":
58 | dtd = DTD(train=True)
59 | # import ipdb
60 | # ipdb.set_trace(context=20)
61 | target_folder = "/opt/tiger/filter_transfer/data/dtd/mix_dtd/"
62 | for im in dtd.full_ims:
63 | img = im[:-1]
64 | category = img.split('/')[-2]
65 | if not os.path.exists(target_folder+category):os.makedirs(target_folder+category)
66 | os.system('cp {} {}'.format(img, target_folder+category))
67 | a=1
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/fine_tunify.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from robustness.tools.custom_modules import SequentialWithArgs
3 |
4 | def ft(model_name, model_ft, num_classes, additional_hidden=0):
5 | if model_name in ['clip_resnest50d','resnet50_feat_pca_pre_relu_multi_pool','resnet18_feat_pre_relu_multi_pool','resnet50_feat_pca_pre_relu_multi','resnet18_feat_pre_relu_multi','resnet50_feat_interpolate_multi','resnet18_multi','resnet50_feat_pre_relu_multi','resnet50_overhaul','resnet18_feat_pre_relu_regressor','resnet18_custom','resnet18_feat_pre_relu',"resnet50_feat_interpolate","resnet50_feat_pca","resnet50_feat_nmf","resnet50_feat_lda","resnet50_feat_mag","resnet18_feat","resnet152_feat","resnet50_feat","resnet","resnet20_as_gift","resnet50_clean", "resnet18", "resnet34","resnet50", "wide_resnet50_2", "wide_resnet50_4", "resnext50_32x4d", 'shufflenet']:
6 | num_ftrs = model_ft.fc.in_features
7 | # The two cases are split just to allow loading
8 | # models trained prior to adding the additional_hidden argument
9 | # without errors
10 | if additional_hidden == 0:
11 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
12 | else:
13 | model_ft.fc = SequentialWithArgs(
14 | *list(sum([[nn.Linear(num_ftrs, num_ftrs), nn.ReLU()] for i in range(additional_hidden)], [])),
15 | nn.Linear(num_ftrs, num_classes)
16 | )
17 | input_size = 224
18 |
19 | elif model_name == 'RN50':
20 | num_ftrs = 1024
21 | # model_ft.fc = nn.Linear(num_ftrs, num_classes)
22 | input_size = 224
23 | elif model_name == "alexnet":
24 | num_ftrs = model_ft.classifier[6].in_features
25 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
26 | input_size = 224
27 | elif "vgg" in model_name:
28 | num_ftrs = model_ft.classifier[6].in_features
29 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
30 | input_size = 224
31 | elif model_name == "squeezenet":
32 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
33 | model_ft.num_classes = num_classes
34 | input_size = 224
35 | elif model_name == "densenet":
36 | num_ftrs = model_ft.classifier.in_features
37 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
38 | input_size = 224
39 | elif model_name in ["mnasnet", "mobilenet"]:
40 | num_ftrs = model_ft.classifier.in_features
41 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
42 | input_size = 224
43 | else:
44 | pass
45 | # raise ValueError("Invalid model type, exiting...")
46 |
47 | return model_ft
48 |
--------------------------------------------------------------------------------
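A usage sketch of `ft()` above: it swaps the classification head of a backbone for one with the target number of classes (importing fine_tunify.py requires the `robustness` package because of the SequentialWithArgs import at the top).

import torch
from torchvision import models

backbone = models.resnet50()
model = ft('resnet50', backbone, num_classes=200)   # replaces model.fc with nn.Linear(2048, 200)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)                                    # torch.Size([2, 200])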
/src/classifier-tuning/src/datasets/transfer_ds/food_101.py:
--------------------------------------------------------------------------------
1 | # pytorch imports
2 | import torch
3 | from torchvision import models, transforms, datasets
4 | from torch.utils.data import DataLoader
5 | from robustness import data_augmentation as da
6 | from . import constants as cs
7 |
8 | class FOOD101():
9 | def __init__(self, transform=None):
10 | # self.TRAIN_PATH = cs.FOOD_PATH+"/train"
11 | self.TRAIN_PATH = "/opt/tiger/filter_transfer/data/food-101/train"
12 | # self.VALID_PATH = cs.FOOD_PATH+"/valid"
13 | self.VALID_PATH = "/opt/tiger/filter_transfer/data/food-101/valid"
14 |
15 | self.train_ds, self.valid_ds, self.train_cls, self.valid_cls = [None]*4
16 | self.transform = transform
17 |
18 | def _get_tfms(self):
19 | train_tfms = cs.TRAIN_TRANSFORMS
20 | valid_tfms = cs.TEST_TRANSFORMS
21 | return train_tfms, valid_tfms
22 |
23 | def get_dataset(self):
24 | # train_tfms, valid_tfms = self._get_tfms() # transformations
25 | train_tfms, valid_tfms = self.transform, self.transform
26 | self.train_ds = datasets.ImageFolder(root=self.TRAIN_PATH,
27 | transform=train_tfms)
28 | self.valid_ds = datasets.ImageFolder(root=self.VALID_PATH,
29 | transform=valid_tfms)
30 | self.train_classes = self.train_ds.classes
31 | self.valid_classes = self.valid_ds.classes
32 |
33 | # print(self.train_classes)
34 |
35 | assert self.train_classes==self.valid_classes
36 | return self.train_ds, self.valid_ds, self.train_classes
37 |
38 | def get_dls(self, train_ds, valid_ds, bs, **kwargs):
39 | return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
40 | DataLoader(valid_ds, batch_size=bs, shuffle=True, **kwargs))
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/imbalance_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 |
6 |
7 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
8 | cls_num = 10
9 |
10 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
11 | transform=None, target_transform=None,
12 | download=False):
13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
14 | np.random.seed(rand_number)
15 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
16 | self.gen_imbalanced_data(img_num_list)
17 |
18 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
19 | img_max = len(self.data) / cls_num
20 | img_num_per_cls = []
21 | if imb_type == 'exp':
22 | for cls_idx in range(cls_num):
23 | num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
24 | img_num_per_cls.append(int(num))
25 | elif imb_type == 'step':
26 | for cls_idx in range(cls_num // 2):
27 | img_num_per_cls.append(int(img_max))
28 | for cls_idx in range(cls_num // 2):
29 | img_num_per_cls.append(int(img_max * imb_factor))
30 | else:
31 | img_num_per_cls.extend([int(img_max)] * cls_num)
32 | return img_num_per_cls
33 |
34 | def gen_imbalanced_data(self, img_num_per_cls):
35 | new_data = []
36 | new_targets = []
37 | targets_np = np.array(self.targets, dtype=np.int64)
38 | classes = np.unique(targets_np)
39 | # np.random.shuffle(classes)
40 | self.num_per_cls_dict = dict()
41 | for the_class, the_img_num in zip(classes, img_num_per_cls):
42 | self.num_per_cls_dict[the_class] = the_img_num
43 | idx = np.where(targets_np == the_class)[0]
44 | np.random.shuffle(idx)
45 | selec_idx = idx[:the_img_num]
46 | new_data.append(self.data[selec_idx, ...])
47 | new_targets.extend([the_class, ] * the_img_num)
48 | new_data = np.vstack(new_data)
49 | self.data = new_data
50 | self.targets = new_targets
51 |
52 | def get_cls_num_list(self):
53 | cls_num_list = []
54 | for i in range(self.cls_num):
55 | cls_num_list.append(self.num_per_cls_dict[i])
56 | return cls_num_list
57 |
58 |
59 | class IMBALANCECIFAR100(IMBALANCECIFAR10):
60 |     """CIFAR100 Dataset.
61 | This is a subclass of the `CIFAR10` Dataset.
62 | """
63 | base_folder = 'cifar-100-python'
64 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
65 | filename = "cifar-100-python.tar.gz"
66 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
67 | train_list = [
68 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
69 | ]
70 |
71 | test_list = [
72 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
73 | ]
74 | meta = {
75 | 'filename': 'meta',
76 | 'key': 'fine_label_names',
77 | 'md5': '7973b15100ade9c7d40fb424638fde48',
78 | }
79 | cls_num = 100
80 |
81 |
82 | if __name__ == '__main__':
83 | transform = transforms.Compose(
84 | [transforms.ToTensor(),
85 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
86 | trainset = IMBALANCECIFAR100(root='./data', train=True,
87 | download=True, transform=transform)
88 | trainloader = iter(trainset)
89 | data, label = next(trainloader)
90 | import pdb;
91 |
92 | pdb.set_trace()
--------------------------------------------------------------------------------
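For intuition, with `imb_type='exp'` and `imb_factor=0.01` on CIFAR-10 (5000 images per class), `get_img_num_per_cls` above keeps int(5000 * 0.01 ** (c / 9)) images for class c, so the per-class counts decay geometrically from 5000 down to 50:

counts = [int(5000 * 0.01 ** (c / 9)) for c in range(10)]
print(counts)   # [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]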
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_aircraft.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import numpy as np
5 |
6 | # --- get class names
7 | # read_filepath = 'data/variants.txt'
8 | # names=[]
9 | # with open(read_filepath, 'r') as f:
10 | # for line in f.readlines():
11 | # print(line.strip())
12 | # names.append(line.strip())
13 | #
14 | # names = ["a "+n for n in names]
15 | # print(names)
16 |
17 |
18 | # --- find train images
19 | read_filepath = 'data/images_variant_trainval.txt'
20 | names=[]
21 | with open(read_filepath, 'r') as f:
22 | for line in f.readlines():
23 | line=line.strip().split(' ')
24 | img_name = line[0] + '.jpg'
25 | print(img_name)
26 | names.append(img_name)
27 |
28 | # names = ["a "+n for n in names]
29 | print(names)
30 |
31 | for name in names:
32 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name))
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_caltech101.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import numpy as np
5 |
6 |
7 | # --- find train images
8 | read_filepath = 'data/images_variant_trainval.txt'
9 | names=[]
10 | with open(read_filepath, 'r') as f:
11 | for line in f.readlines():
12 | line=line.strip().split(' ')
13 | img_name = line[0] + '.jpg'
14 | print(img_name)
15 | names.append(img_name)
16 |
17 | # names = ["a "+n for n in names]
18 | print(names)
19 |
20 | for name in names:
21 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name))
22 | os.system('pwd')
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_cars.py:
--------------------------------------------------------------------------------
1 | # from scipy.io import loadmat
2 | # import pandas as pd
3 | #
4 | # names=[]
5 | #
6 | # label_dir = 'devkit'
7 | # car_annos = loadmat(label_dir+'/cars_meta.mat')
8 | # # import ipdb
9 | # # ipdb.set_trace(context=20)
10 | # labels = [c for c in car_annos['class_names'][0]]
11 | #
12 | # names = ['a '+label.item() for label in labels]
13 | # # names.append(label[i])
14 | # print(names)
15 | # # label = pd.DataFrame(label, columns = ['Model'])
16 | # # print('{} classes '.format(len(label)))
17 | # # label.head(196)
18 | #
19 | # a=1
20 | #
21 | # car_annos['class_names'].keys()
22 | #
23 | # names = [s.split('/[[')[1] for s in all]
24 | # names = [s.split(']]')[0] for s in names]
25 |
26 | car_dict={}
27 | cars = \
28 | ['a AM General Hummer SUV 2000', 'a Acura RL Sedan 2012', 'a Acura TL Sedan 2012', 'a Acura TL Type-S 2008', 'a Acura TSX Sedan 2012', 'a Acura Integra Type R 2001', 'a Acura ZDX Hatchback 2012', 'a Aston Martin V8 Vantage Convertible 2012', 'a Aston Martin V8 Vantage Coupe 2012', 'a Aston Martin Virage Convertible 2012', 'a Aston Martin Virage Coupe 2012', 'a Audi RS 4 Convertible 2008', 'a Audi A5 Coupe 2012', 'a Audi TTS Coupe 2012', 'a Audi R8 Coupe 2012', 'a Audi V8 Sedan 1994', 'a Audi 100 Sedan 1994', 'a Audi 100 Wagon 1994', 'a Audi TT Hatchback 2011', 'a Audi S6 Sedan 2011', 'a Audi S5 Convertible 2012', 'a Audi S5 Coupe 2012', 'a Audi S4 Sedan 2012', 'a Audi S4 Sedan 2007', 'a Audi TT RS Coupe 2012', 'a BMW ActiveHybrid 5 Sedan 2012', 'a BMW 1 Series Convertible 2012', 'a BMW 1 Series Coupe 2012', 'a BMW 3 Series Sedan 2012', 'a BMW 3 Series Wagon 2012', 'a BMW 6 Series Convertible 2007', 'a BMW X5 SUV 2007', 'a BMW X6 SUV 2012', 'a BMW M3 Coupe 2012', 'a BMW M5 Sedan 2010', 'a BMW M6 Convertible 2010', 'a BMW X3 SUV 2012', 'a BMW Z4 Convertible 2012', 'a Bentley Continental Supersports Conv. Convertible 2012', 'a Bentley Arnage Sedan 2009', 'a Bentley Mulsanne Sedan 2011', 'a Bentley Continental GT Coupe 2012', 'a Bentley Continental GT Coupe 2007', 'a Bentley Continental Flying Spur Sedan 2007', 'a Bugatti Veyron 16.4 Convertible 2009', 'a Bugatti Veyron 16.4 Coupe 2009', 'a Buick Regal GS 2012', 'a Buick Rainier SUV 2007', 'a Buick Verano Sedan 2012', 'a Buick Enclave SUV 2012', 'a Cadillac CTS-V Sedan 2012', 'a Cadillac SRX SUV 2012', 'a Cadillac Escalade EXT Crew Cab 2007', 'a Chevrolet Silverado 1500 Hybrid Crew Cab 2012', 'a Chevrolet Corvette Convertible 2012', 'a Chevrolet Corvette ZR1 2012', 'a Chevrolet Corvette Ron Fellows Edition Z06 2007', 'a Chevrolet Traverse SUV 2012', 'a Chevrolet Camaro Convertible 2012', 'a Chevrolet HHR SS 2010', 'a Chevrolet Impala Sedan 2007', 'a Chevrolet Tahoe Hybrid SUV 2012', 'a Chevrolet Sonic Sedan 2012', 'a Chevrolet Express Cargo Van 2007', 'a Chevrolet Avalanche Crew Cab 2012', 'a Chevrolet Cobalt SS 2010', 'a Chevrolet Malibu Hybrid Sedan 2010', 'a Chevrolet TrailBlazer SS 2009', 'a Chevrolet Silverado 2500HD Regular Cab 2012', 'a Chevrolet Silverado 1500 Classic Extended Cab 2007', 'a Chevrolet Express Van 2007', 'a Chevrolet Monte Carlo Coupe 2007', 'a Chevrolet Malibu Sedan 2007', 'a Chevrolet Silverado 1500 Extended Cab 2012', 'a Chevrolet Silverado 1500 Regular Cab 2012', 'a Chrysler Aspen SUV 2009', 'a Chrysler Sebring Convertible 2010', 'a Chrysler Town and Country Minivan 2012', 'a Chrysler 300 SRT-8 2010', 'a Chrysler Crossfire Convertible 2008', 'a Chrysler PT Cruiser Convertible 2008', 'a Daewoo Nubira Wagon 2002', 'a Dodge Caliber Wagon 2012', 'a Dodge Caliber Wagon 2007', 'a Dodge Caravan Minivan 1997', 'a Dodge Ram Pickup 3500 Crew Cab 2010', 'a Dodge Ram Pickup 3500 Quad Cab 2009', 'a Dodge Sprinter Cargo Van 2009', 'a Dodge Journey SUV 2012', 'a Dodge Dakota Crew Cab 2010', 'a Dodge Dakota Club Cab 2007', 'a Dodge Magnum Wagon 2008', 'a Dodge Challenger SRT8 2011', 'a Dodge Durango SUV 2012', 'a Dodge Durango SUV 2007', 'a Dodge Charger Sedan 2012', 'a Dodge Charger SRT-8 2009', 'a Eagle Talon Hatchback 1998', 'a FIAT 500 Abarth 2012', 'a FIAT 500 Convertible 2012', 'a Ferrari FF Coupe 2012', 'a Ferrari California Convertible 2012', 'a Ferrari 458 Italia Convertible 2012', 'a Ferrari 458 Italia Coupe 2012', 'a Fisker Karma Sedan 2012', 'a Ford F-450 Super Duty Crew Cab 2012', 'a Ford Mustang Convertible 2007', 
'a Ford Freestar Minivan 2007', 'a Ford Expedition EL SUV 2009', 'a Ford Edge SUV 2012', 'a Ford Ranger SuperCab 2011', 'a Ford GT Coupe 2006', 'a Ford F-150 Regular Cab 2012', 'a Ford F-150 Regular Cab 2007', 'a Ford Focus Sedan 2007', 'a Ford E-Series Wagon Van 2012', 'a Ford Fiesta Sedan 2012', 'a GMC Terrain SUV 2012', 'a GMC Savana Van 2012', 'a GMC Yukon Hybrid SUV 2012', 'a GMC Acadia SUV 2012', 'a GMC Canyon Extended Cab 2012', 'a Geo Metro Convertible 1993', 'a HUMMER H3T Crew Cab 2010', 'a HUMMER H2 SUT Crew Cab 2009', 'a Honda Odyssey Minivan 2012', 'a Honda Odyssey Minivan 2007', 'a Honda Accord Coupe 2012', 'a Honda Accord Sedan 2012', 'a Hyundai Veloster Hatchback 2012', 'a Hyundai Santa Fe SUV 2012', 'a Hyundai Tucson SUV 2012', 'a Hyundai Veracruz SUV 2012', 'a Hyundai Sonata Hybrid Sedan 2012', 'a Hyundai Elantra Sedan 2007', 'a Hyundai Accent Sedan 2012', 'a Hyundai Genesis Sedan 2012', 'a Hyundai Sonata Sedan 2012', 'a Hyundai Elantra Touring Hatchback 2012', 'a Hyundai Azera Sedan 2012', 'a Infiniti G Coupe IPL 2012', 'a Infiniti QX56 SUV 2011', 'a Isuzu Ascender SUV 2008', 'a Jaguar XK XKR 2012', 'a Jeep Patriot SUV 2012', 'a Jeep Wrangler SUV 2012', 'a Jeep Liberty SUV 2012', 'a Jeep Grand Cherokee SUV 2012', 'a Jeep Compass SUV 2012', 'a Lamborghini Reventon Coupe 2008', 'a Lamborghini Aventador Coupe 2012', 'a Lamborghini Gallardo LP 570-4 Superleggera 2012', 'a Lamborghini Diablo Coupe 2001', 'a Land Rover Range Rover SUV 2012', 'a Land Rover LR2 SUV 2012', 'a Lincoln Town Car Sedan 2011', 'a MINI Cooper Roadster Convertible 2012', 'a Maybach Landaulet Convertible 2012', 'a Mazda Tribute SUV 2011', 'a McLaren MP4-12C Coupe 2012', 'a Mercedes-Benz 300-Class Convertible 1993', 'a Mercedes-Benz C-Class Sedan 2012', 'a Mercedes-Benz SL-Class Coupe 2009', 'a Mercedes-Benz E-Class Sedan 2012', 'a Mercedes-Benz S-Class Sedan 2012', 'a Mercedes-Benz Sprinter Van 2012', 'a Mitsubishi Lancer Sedan 2012', 'a Nissan Leaf Hatchback 2012', 'a Nissan NV Passenger Van 2012', 'a Nissan Juke Hatchback 2012', 'a Nissan 240SX Coupe 1998', 'a Plymouth Neon Coupe 1999', 'a Porsche Panamera Sedan 2012', 'a Ram C/V Cargo Van Minivan 2012', 'a Rolls-Royce Phantom Drophead Coupe Convertible 2012', 'a Rolls-Royce Ghost Sedan 2012', 'a Rolls-Royce Phantom Sedan 2012', 'a Scion xD Hatchback 2012', 'a Spyker C8 Convertible 2009', 'a Spyker C8 Coupe 2009', 'a Suzuki Aerio Sedan 2007', 'a Suzuki Kizashi Sedan 2012', 'a Suzuki SX4 Hatchback 2012', 'a Suzuki SX4 Sedan 2012', 'a Tesla Model S Sedan 2012', 'a Toyota Sequoia SUV 2012', 'a Toyota Camry Sedan 2012', 'a Toyota Corolla Sedan 2012', 'a Toyota 4Runner SUV 2012', 'a Volkswagen Golf Hatchback 2012', 'a Volkswagen Golf Hatchback 1991', 'a Volkswagen Beetle Hatchback 2012', 'a Volvo C30 Hatchback 2012', 'a Volvo 240 Sedan 1993', 'a Volvo XC90 SUV 2007', 'a smart fortwo Convertible 2012']
29 |
30 | for j in range(196):
31 | i=j+1
32 | car_dict[str(i)]=cars[j]
33 |
34 | order = ['100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '10', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '11', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '12', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '13', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '14', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '15', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '16', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '17', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '18', '190', '191', '192', '193', '194', '195', '196', '19', '1', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '2', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '3', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '4', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '5', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '6', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '7', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '8', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '9']
35 |
36 | final_names = [car_dict[v] for v in order]
37 | print(final_names)
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | cat2name = {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"}
5 |
6 | # class folders are named by their index in cat2name; the trailing glob entry is skipped
7 | cates = sorted(glob.glob('*'))[:-1]
8 |
9 | # build prompt-style class names, e.g. "a pink primrose"
10 | names = ["a " + cat2name[cat] for cat in cates]
11 | print(names)
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_imgnet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | # flatten all ImageNet-1k training classes into a single image folder
5 | src_classes = sorted(glob.glob('ILSVRC2012_img_train/train/*'))
6 | os.system('mkdir -p ILSVRC2012_img_train_plus15tasks/train/img')
7 |
8 | for i, src_cls in enumerate(src_classes):
9 |     print(i)
10 |     os.system('cp {}/* ILSVRC2012_img_train_plus15tasks/train/img'.format(src_cls))
11 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/pro_pool15.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | # pwd = /opt/tiger/filter_transfer/data
4 | # pool the training images of every dataset under pool15/ into one flat folder
5 | new_ds = 'ds15img'
6 | all_ds = sorted(glob.glob('pool15/*'))
7 | os.system('mkdir {}'.format(new_ds))
8 |
9 | for ds in all_ds:
10 |     os.system('mv {}/train/*/* {}'.format(ds, new_ds))
11 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/process_dataset/process_pets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | names = []
5 |
6 | # each folder is one pet breed; derive a prompt-style class name
7 | # from the first image filename in that folder
8 | all_classes = sorted(glob.glob('*'))
9 | for cls in all_classes:
10 |     name = sorted(glob.glob('{}/*.jpg'.format(cls)))[0]
11 |     # filenames look like '<breed>_<index>.jpg'; keep only the breed part
12 |     name = 'a ' + name.split('/')[-1].split('_1')[0]
13 |     print(name)
14 |     names.append(name)
15 |
16 | print(names)
17 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/transfer_ds/utils.py:
--------------------------------------------------------------------------------
1 | from . import generic_dataset
2 | from . import food_101
3 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/ytbb-robust_metadata/class_idx_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "7": 1,
3 | "8": 1,
4 | "9": 1,
5 | "10": 1,
6 | "11": 1,
7 | "12": 1,
8 | "13": 1,
9 | "14": 1,
10 | "15": 1,
11 | "16": 1,
12 | "17": 1,
13 | "18": 1,
14 | "19": 1,
15 | "20": 1,
16 | "21": 1,
17 | "22": 1,
18 | "23": 1,
19 | "24": 1,
20 | "80": 1,
21 | "81": 1,
22 | "82": 1,
23 | "83": 1,
24 | "84": 1,
25 | "85": 1,
26 | "86": 1,
27 | "87": 1,
28 | "88": 1,
29 | "89": 1,
30 | "90": 1,
31 | "91": 1,
32 | "92": 1,
33 | "93": 1,
34 | "94": 1,
35 | "95": 1,
36 | "96": 1,
37 | "97": 1,
38 | "98": 1,
39 | "99": 1,
40 | "100": 1,
41 | "127": 1,
42 | "128": 1,
43 | "129": 1,
44 | "130": 1,
45 | "131": 1,
46 | "132": 1,
47 | "133": 1,
48 | "134": 1,
49 | "135": 1,
50 | "136": 1,
51 | "137": 1,
52 | "138": 1,
53 | "139": 1,
54 | "140": 1,
55 | "141": 1,
56 | "142": 1,
57 | "143": 1,
58 | "144": 1,
59 | "145": 1,
60 | "146": 1,
61 | "151": 19,
62 | "152": 19,
63 | "153": 19,
64 | "154": 19,
65 | "155": 19,
66 | "156": 19,
67 | "157": 19,
68 | "158": 19,
69 | "159": 19,
70 | "160": 19,
71 | "161": 19,
72 | "162": 19,
73 | "163": 19,
74 | "164": 19,
75 | "165": 19,
76 | "166": 19,
77 | "167": 19,
78 | "168": 19,
79 | "169": 19,
80 | "170": 19,
81 | "171": 19,
82 | "172": 19,
83 | "173": 19,
84 | "174": 19,
85 | "175": 19,
86 | "176": 19,
87 | "177": 19,
88 | "178": 19,
89 | "179": 19,
90 | "180": 19,
91 | "181": 19,
92 | "182": 19,
93 | "183": 19,
94 | "184": 19,
95 | "185": 19,
96 | "186": 19,
97 | "187": 19,
98 | "188": 19,
99 | "189": 19,
100 | "190": 19,
101 | "191": 19,
102 | "192": 19,
103 | "193": 19,
104 | "194": 19,
105 | "195": 19,
106 | "196": 19,
107 | "197": 19,
108 | "198": 19,
109 | "199": 19,
110 | "200": 19,
111 | "201": 19,
112 | "202": 19,
113 | "203": 19,
114 | "204": 19,
115 | "205": 19,
116 | "206": 19,
117 | "207": 19,
118 | "208": 19,
119 | "209": 19,
120 | "210": 19,
121 | "211": 19,
122 | "212": 19,
123 | "213": 19,
124 | "214": 19,
125 | "215": 19,
126 | "216": 19,
127 | "217": 19,
128 | "218": 19,
129 | "219": 19,
130 | "220": 19,
131 | "221": 19,
132 | "222": 19,
133 | "223": 19,
134 | "224": 19,
135 | "225": 19,
136 | "226": 19,
137 | "227": 19,
138 | "228": 19,
139 | "229": 19,
140 | "230": 19,
141 | "231": 19,
142 | "232": 19,
143 | "233": 19,
144 | "234": 19,
145 | "235": 19,
146 | "236": 19,
147 | "237": 19,
148 | "238": 19,
149 | "239": 19,
150 | "240": 19,
151 | "241": 19,
152 | "242": 19,
153 | "243": 19,
154 | "244": 19,
155 | "245": 19,
156 | "246": 19,
157 | "247": 19,
158 | "248": 19,
159 | "249": 19,
160 | "250": 19,
161 | "251": 19,
162 | "252": 19,
163 | "253": 19,
164 | "254": 19,
165 | "255": 19,
166 | "256": 19,
167 | "257": 19,
168 | "258": 19,
169 | "259": 19,
170 | "260": 19,
171 | "261": 19,
172 | "262": 19,
173 | "263": 19,
174 | "264": 19,
175 | "265": 19,
176 | "266": 19,
177 | "267": 19,
178 | "268": 19,
179 | "281": 7,
180 | "282": 7,
181 | "283": 7,
182 | "284": 7,
183 | "285": 7,
184 | "286": 7,
185 | "287": 7,
186 | "294": 5,
187 | "295": 5,
188 | "296": 5,
189 | "297": 5,
190 | "339": 10,
191 | "340": 17,
192 | "385": 20,
193 | "386": 20,
194 | "404": 13,
195 | "407": 23,
196 | "436": 23,
197 | "444": 2,
198 | "466": 15,
199 | "468": 23,
200 | "472": 3,
201 | "499": 12,
202 | "511": 23,
203 | "554": 3,
204 | "555": 16,
205 | "569": 16,
206 | "576": 3,
207 | "609": 23,
208 | "623": 12,
209 | "625": 3,
210 | "627": 23,
211 | "654": 4,
212 | "656": 23,
213 | "661": 23,
214 | "665": 11,
215 | "671": 2,
216 | "675": 16,
217 | "717": 16,
218 | "734": 16,
219 | "751": 23,
220 | "779": 4,
221 | "814": 3,
222 | "817": 23,
223 | "864": 16,
224 | "867": 16,
225 | "874": 4,
226 | "879": 21,
227 | "914": 3,
228 | "981": 0,
229 | "982": 0,
230 | "983": 0,
231 | "985": 9,
232 | "986": 9
233 | }
--------------------------------------------------------------------------------
/src/classifier-tuning/src/datasets/ytbb-robust_metadata/rev_class_idx_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": [
3 | 981,
4 | 982,
5 | 983
6 | ],
7 | "1": [
8 | 7,
9 | 8,
10 | 9,
11 | 10,
12 | 11,
13 | 12,
14 | 13,
15 | 14,
16 | 15,
17 | 16,
18 | 17,
19 | 18,
20 | 19,
21 | 20,
22 | 21,
23 | 22,
24 | 23,
25 | 24,
26 | 80,
27 | 81,
28 | 82,
29 | 83,
30 | 84,
31 | 85,
32 | 86,
33 | 87,
34 | 88,
35 | 89,
36 | 90,
37 | 91,
38 | 92,
39 | 93,
40 | 94,
41 | 95,
42 | 96,
43 | 97,
44 | 98,
45 | 99,
46 | 100,
47 | 127,
48 | 128,
49 | 129,
50 | 130,
51 | 131,
52 | 132,
53 | 133,
54 | 134,
55 | 135,
56 | 136,
57 | 137,
58 | 138,
59 | 139,
60 | 140,
61 | 141,
62 | 142,
63 | 143,
64 | 144,
65 | 145,
66 | 146
67 | ],
68 | "2": [
69 | 444,
70 | 671
71 | ],
72 | "3": [
73 | 472,
74 | 554,
75 | 576,
76 | 625,
77 | 814,
78 | 914
79 | ],
80 | "4": [
81 | 654,
82 | 779,
83 | 874
84 | ],
85 | "5": [
86 | 294,
87 | 295,
88 | 296,
89 | 297
90 | ],
91 | "7": [
92 | 281,
93 | 282,
94 | 283,
95 | 284,
96 | 285,
97 | 286,
98 | 287
99 | ],
100 | "9": [
101 | 985,
102 | 986
103 | ],
104 | "10": [
105 | 339
106 | ],
107 | "11": [
108 | 665
109 | ],
110 | "12": [
111 | 499,
112 | 623
113 | ],
114 | "13": [
115 | 404
116 | ],
117 | "15": [
118 | 466
119 | ],
120 | "16": [
121 | 555,
122 | 569,
123 | 675,
124 | 717,
125 | 734,
126 | 864,
127 | 867
128 | ],
129 | "17": [
130 | 340
131 | ],
132 | "19": [
133 | 151,
134 | 152,
135 | 153,
136 | 154,
137 | 155,
138 | 156,
139 | 157,
140 | 158,
141 | 159,
142 | 160,
143 | 161,
144 | 162,
145 | 163,
146 | 164,
147 | 165,
148 | 166,
149 | 167,
150 | 168,
151 | 169,
152 | 170,
153 | 171,
154 | 172,
155 | 173,
156 | 174,
157 | 175,
158 | 176,
159 | 177,
160 | 178,
161 | 179,
162 | 180,
163 | 181,
164 | 182,
165 | 183,
166 | 184,
167 | 185,
168 | 186,
169 | 187,
170 | 188,
171 | 189,
172 | 190,
173 | 191,
174 | 192,
175 | 193,
176 | 194,
177 | 195,
178 | 196,
179 | 197,
180 | 198,
181 | 199,
182 | 200,
183 | 201,
184 | 202,
185 | 203,
186 | 204,
187 | 205,
188 | 206,
189 | 207,
190 | 208,
191 | 209,
192 | 210,
193 | 211,
194 | 212,
195 | 213,
196 | 214,
197 | 215,
198 | 216,
199 | 217,
200 | 218,
201 | 219,
202 | 220,
203 | 221,
204 | 222,
205 | 223,
206 | 224,
207 | 225,
208 | 226,
209 | 227,
210 | 228,
211 | 229,
212 | 230,
213 | 231,
214 | 232,
215 | 233,
216 | 234,
217 | 235,
218 | 236,
219 | 237,
220 | 238,
221 | 239,
222 | 240,
223 | 241,
224 | 242,
225 | 243,
226 | 244,
227 | 245,
228 | 246,
229 | 247,
230 | 248,
231 | 249,
232 | 250,
233 | 251,
234 | 252,
235 | 253,
236 | 254,
237 | 255,
238 | 256,
239 | 257,
240 | 258,
241 | 259,
242 | 260,
243 | 261,
244 | 262,
245 | 263,
246 | 264,
247 | 265,
248 | 266,
249 | 267,
250 | 268
251 | ],
252 | "20": [
253 | 385,
254 | 386
255 | ],
256 | "21": [
257 | 879
258 | ],
259 | "23": [
260 | 407,
261 | 436,
262 | 468,
263 | 511,
264 | 609,
265 | 627,
266 | 656,
267 | 661,
268 | 751,
269 | 817
270 | ]
271 | }
--------------------------------------------------------------------------------
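The two metadata files above encode the same mapping in opposite directions: `class_idx_map.json` sends an ImageNet-1k class index to a YouTube-BB class id, and `rev_class_idx_map.json` groups the ImageNet indices per YouTube-BB class. A minimal sketch of recovering the reverse map from the forward one (assuming it is run from the `ytbb-robust_metadata` directory):

```python
import json
from collections import defaultdict

# forward map: ImageNet-1k class index (as a string) -> YTBB class id
with open('class_idx_map.json') as f:
    cls_map = json.load(f)

# group ImageNet indices by YTBB class id
rev = defaultdict(list)
for imagenet_idx, ytbb_idx in cls_map.items():
    rev[str(ytbb_idx)].append(int(imagenet_idx))

# `rev` now matches rev_class_idx_map.json (up to key ordering)
```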
/src/classifier-tuning/src/datasets/ytbb-robust_metadata/ytbb_class_index.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "person",
3 | "1": "bird",
4 | "2": "bicycle",
5 | "3": "boat",
6 | "4": "bus",
7 | "5": "bear",
8 | "6": "cow",
9 | "7": "cat",
10 | "8": "giraffe",
11 | "9": "potted plant",
12 | "10": "horse",
13 | "11": "motorcycle",
14 | "12": "knife",
15 | "13": "airplane",
16 | "14": "skateboard",
17 | "15": "train",
18 | "16": "truck",
19 | "17": "zebra",
20 | "18": "toilet",
21 | "19": "dog",
22 | "20": "elephant",
23 | "21": "umbrella",
24 | "22": "none",
25 | "23": "car"
26 | }
27 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/get_classifier_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import torch
6 |
7 |
8 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
9 | # from ..src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
10 |
11 |
12 | # load_path = '/opt/tiger/filter_transfer/src/wise-ft/results/extest.r1/zeroshotEurosat.pt'
13 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_init_tfda_55.68.pt'
14 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_ref16.1_iters50_71.72.pt'
15 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/real_init_16.1_tfda_86.85.pt'
16 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters50_88.21.pt'
17 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters20_88.86.pt'
18 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_v4_20k_87.86.pt'
19 | load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_4.1_iter50_81.72.pt'
20 | image_classifier = ImageClassifier.load(load_path)
21 |
22 | # extract the classification-head weights and cache them for later reuse
23 | head = image_classifier.classification_head
24 | weights = head.weight.detach().numpy()
25 |
26 | torch.save(weights, '/opt/tiger/filter_transfer/src/wise-ft/cache/Eurosat/mix_4.1_iter50_81.72_weights.pt')
27 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/src/models/.DS_Store
--------------------------------------------------------------------------------
/src/classifier-tuning/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/classifier-tuning/src/models/__init__.py
--------------------------------------------------------------------------------
/src/classifier-tuning/src/models/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from src.models import utils
8 | from src.datasets.common import get_dataloader, maybe_dictionarize
9 |
10 | import src.datasets as datasets
11 |
12 |
13 | def eval_single_dataset(image_classifier, dataset, args):
14 | if args.freeze_encoder and not args.data_aug_lp:
15 | model = image_classifier.classification_head
16 | input_key = 'features'
17 | image_enc = image_classifier.image_encoder
18 | else:
19 | model = image_classifier
20 | input_key = 'images'
21 | image_enc = None
22 |
23 | model.eval()
24 | dataloader = get_dataloader(
25 | dataset, is_train=False, args=args, image_encoder=image_enc)
26 | batched_data = enumerate(dataloader)
27 | device = args.device
28 |
29 | if hasattr(dataset, 'post_loop_metrics'):
30 | # keep track of labels, predictions and metadata
31 | all_labels, all_preds, all_metadata = [], [], []
32 |
33 | with torch.no_grad():
34 | top1, correct, n = 0., 0., 0.
35 | for i, data in batched_data:
36 | data = maybe_dictionarize(data)
37 | x = data[input_key].to(device)
38 | y = data['labels'].to(device)
39 |
40 | if 'image_paths' in data:
41 | image_paths = data['image_paths']
42 |
43 | logits = utils.get_logits(x, model)
44 | projection_fn = getattr(dataset, 'project_logits', None)
45 | if projection_fn is not None:
46 | logits = projection_fn(logits, device)
47 |
48 | if hasattr(dataset, 'project_labels'):
49 | y = dataset.project_labels(y, device)
50 | pred = logits.argmax(dim=1, keepdim=True).to(device)
51 | if hasattr(dataset, 'accuracy'):
52 | acc1, num_total = dataset.accuracy(logits, y, image_paths, args)
53 | correct += acc1
54 | n += num_total
55 | else:
56 | correct += pred.eq(y.view_as(pred)).sum().item()
57 | n += y.size(0)
58 |
59 | if hasattr(dataset, 'post_loop_metrics'):
60 | all_labels.append(y.cpu().clone().detach())
61 | all_preds.append(logits.cpu().clone().detach())
62 | metadata = data['metadata'] if 'metadata' in data else image_paths
63 | all_metadata.extend(metadata)
64 |
65 | top1 = correct / n
66 |
67 | if hasattr(dataset, 'post_loop_metrics'):
68 | all_labels = torch.cat(all_labels)
69 | all_preds = torch.cat(all_preds)
70 | metrics = dataset.post_loop_metrics(all_labels, all_preds, all_metadata, args)
71 | if 'acc' in metrics:
72 | metrics['top1'] = metrics['acc']
73 | else:
74 | metrics = {}
75 | if 'top1' not in metrics:
76 | metrics['top1'] = top1
77 |
78 | return metrics
79 |
80 | def evaluate(image_classifier, args):
81 | if args.eval_datasets is None:
82 | return
83 | info = vars(args)
84 | for i, dataset_name in enumerate(args.eval_datasets):
85 | print('Evaluating on', dataset_name)
86 | dataset_class = getattr(datasets, dataset_name)
87 | dataset = dataset_class(
88 | image_classifier.val_preprocess,
89 | location=args.data_location,
90 | batch_size=args.batch_size
91 | )
92 |
93 | results = eval_single_dataset(image_classifier, dataset, args)
94 |
95 | if 'top1' in results:
96 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
97 | for key, val in results.items():
98 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
99 | print(f"{dataset_name} {key}: {val:.4f}")
100 | info[dataset_name + ':' + key] = val
101 |
102 | if args.results_db is not None:
103 | dirname = os.path.dirname(args.results_db)
104 | if dirname:
105 | os.makedirs(dirname, exist_ok=True)
106 | with open(args.results_db, 'a+') as f:
107 | f.write(json.dumps(info) + '\n')
108 | print(f'Results saved to {args.results_db}.')
109 | else:
110 | print('Results not saved (to do so, use --results_db to specify a path).')
111 |
112 | return info
--------------------------------------------------------------------------------
/src/classifier-tuning/src/models/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
4 | import clip.clip as clip
5 |
6 | from src.models import utils
7 |
8 |
9 | class ImageEncoder(torch.nn.Module):
10 | def __init__(self, args, keep_lang=False):
11 | super().__init__()
12 |
13 | self.model, self.train_preprocess, self.val_preprocess = clip.load(
14 | args.model, args.device, jit=False)
15 |
16 | self.cache_dir = args.cache_dir
17 |
18 | if not keep_lang and hasattr(self.model, 'transformer'):
19 | delattr(self.model, 'transformer')
20 |
21 | def forward(self, images):
22 | assert self.model is not None
23 | return self.model.encode_image(images)
24 |
25 | def save(self, filename):
26 | print(f'Saving image encoder to {filename}')
27 | utils.torch_save(self, filename)
28 |
29 | @classmethod
30 | def load(cls, filename):
31 | print(f'Loading image encoder from {filename}')
32 | return utils.torch_load(filename)
33 |
34 |
35 | class ClassificationHead(torch.nn.Linear):
36 | def __init__(self, normalize, weights, biases=None):
37 | output_size, input_size = weights.shape
38 | super().__init__(input_size, output_size)
39 | self.normalize = normalize
40 | if weights is not None:
41 | self.weight = torch.nn.Parameter(weights.clone())
42 | if biases is not None:
43 | self.bias = torch.nn.Parameter(biases.clone())
44 | else:
45 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
46 |
47 | def forward(self, inputs):
48 | if self.normalize:
49 | inputs = inputs / inputs.norm(dim=-1, keepdim=True)
50 | return super().forward(inputs)
51 |
52 | def save(self, filename):
53 | print(f'Saving classification head to {filename}')
54 | utils.torch_save(self, filename)
55 |
56 | @classmethod
57 | def load(cls, filename):
58 | print(f'Loading classification head from {filename}')
59 | return utils.torch_load(filename)
60 |
61 |
62 | class ImageClassifier(torch.nn.Module):
63 | def __init__(self, image_encoder, classification_head, process_images=True):
64 | super().__init__()
65 | self.image_encoder = image_encoder
66 | self.classification_head = classification_head
67 | self.process_images = process_images
68 | if self.image_encoder is not None:
69 | self.train_preprocess = self.image_encoder.train_preprocess
70 | self.val_preprocess = self.image_encoder.val_preprocess
71 |
72 | def forward(self, inputs):
73 | if self.process_images:
74 | inputs = self.image_encoder(inputs)
75 | outputs = self.classification_head(inputs)
76 | return outputs
77 |
78 | def save(self, filename):
79 | print(f'Saving image classifier to {filename}')
80 | utils.torch_save(self, filename)
81 |
82 | @classmethod
83 | def load(cls, filename):
84 | print(f'Loading image classifier from {filename}')
85 | return utils.torch_load(filename)
86 |
--------------------------------------------------------------------------------
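To see how the three classes above fit together: `ImageEncoder` wraps the CLIP visual backbone, `ClassificationHead` is a (optionally feature-normalizing) linear layer initialized from given weights, and `ImageClassifier` chains the two. A minimal sketch, where the args namespace and the zero-initialized head weights are placeholders rather than values from the repo:

```python
import argparse
import torch

from src.models.modeling import ImageEncoder, ClassificationHead, ImageClassifier

# ImageEncoder only reads args.model, args.device and args.cache_dir
args = argparse.Namespace(model='ViT-B/16', device='cpu', cache_dir=None)
encoder = ImageEncoder(args, keep_lang=False)

num_classes, feat_dim = 10, 512                  # CLIP ViT-B/16 features are 512-d
head = ClassificationHead(normalize=True,
                          weights=torch.zeros(num_classes, feat_dim))

classifier = ImageClassifier(encoder, head)      # process_images=True by default
images = torch.randn(2, 3, 224, 224)             # an already-preprocessed batch
logits = classifier(images)                      # encode_image -> normalize -> linear head
```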
/src/classifier-tuning/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import pickle
5 | from tqdm import tqdm
6 | import math
7 | import torch.nn.functional as F
8 | import numpy as np
9 |
10 |
11 | def assign_learning_rate(param_group, new_lr):
12 | param_group["lr"] = new_lr
13 |
14 |
15 | def _warmup_lr(base_lr, warmup_length, step):
16 | return base_lr * (step + 1) / warmup_length
17 |
18 |
19 | def cosine_lr(optimizer, base_lrs, warmup_length, steps):
20 | if not isinstance(base_lrs, list):
21 | base_lrs = [base_lrs for _ in optimizer.param_groups]
22 | assert len(base_lrs) == len(optimizer.param_groups)
23 | def _lr_adjuster(step):
24 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
25 | if step < warmup_length:
26 | lr = _warmup_lr(base_lr, warmup_length, step)
27 | else:
28 | e = step - warmup_length
29 | es = steps - warmup_length
30 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
31 | assign_learning_rate(param_group, lr)
32 | return _lr_adjuster
33 |
34 | def cosine_lr_const_warmup(optimizer, base_lrs, warmup_length, steps, warmup_const=1e-5):
35 | if not isinstance(base_lrs, list):
36 | base_lrs = [base_lrs for _ in optimizer.param_groups]
37 | assert len(base_lrs) == len(optimizer.param_groups)
38 | def _lr_adjuster(step):
39 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
40 | if step < warmup_length:
41 | lr = warmup_const
42 | else:
43 | e = step - warmup_length
44 | es = steps - warmup_length
45 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
46 | assign_learning_rate(param_group, lr)
47 | return _lr_adjuster
48 |
49 | def get_cosine_lr_value(optimizer, base_lrs, warmup_length, steps):
50 | if not isinstance(base_lrs, list):
51 | base_lrs = [base_lrs for _ in optimizer.param_groups]
52 | assert len(base_lrs) == len(optimizer.param_groups)
53 | def _lr_adjuster(step):
54 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
55 | if step < warmup_length:
56 | lr = _warmup_lr(base_lr, warmup_length, step)
57 | else:
58 | e = step - warmup_length
59 | es = steps - warmup_length
60 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
61 | assign_learning_rate(param_group, lr)
62 | return _lr_adjuster
63 |
64 | def accuracy(output, target, topk=(1,)):
65 | pred = output.topk(max(topk), 1, True, True)[1].t()
66 | correct = pred.eq(target.view(1, -1).expand_as(pred))
67 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
68 |
69 |
70 | def torch_save(classifier, save_path):
71 | if os.path.dirname(save_path) != '':
72 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
73 | with open(save_path, 'wb') as f:
74 | pickle.dump(classifier.cpu(), f)
75 |
76 |
77 | def torch_load(save_path, device=None):
78 | with open(save_path, 'rb') as f:
79 | classifier = pickle.load(f)
80 | if device is not None:
81 | classifier = classifier.to(device)
82 | return classifier
83 |
84 |
85 | def get_logits(inputs, classifier):
86 | assert callable(classifier)
87 | if hasattr(classifier, 'to'):
88 | classifier = classifier.to(inputs.device)
89 | return classifier(inputs)
90 |
91 |
92 | def get_probs(inputs, classifier):
93 | if hasattr(classifier, 'predict_proba'):
94 | probs = classifier.predict_proba(inputs.detach().cpu().numpy())
95 | return torch.from_numpy(probs)
96 | logits = get_logits(inputs, classifier)
97 | return logits.softmax(dim=1)
98 |
99 |
100 | class LabelSmoothing(torch.nn.Module):
101 | def __init__(self, smoothing=0.0):
102 | super(LabelSmoothing, self).__init__()
103 | self.confidence = 1.0 - smoothing
104 | self.smoothing = smoothing
105 |
106 | def forward(self, x, target):
107 | # import ipdb
108 | # ipdb.set_trace(context=20)
109 | logprobs = F.log_softmax(x, dim=-1) # B C
110 |
111 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) # B 1
112 | nll_loss = nll_loss.squeeze(1) # B
113 | smooth_loss = -logprobs.mean(dim=-1) # B
114 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
115 | return loss.mean()
116 |
117 | class SoftTargetCrossEntropy_T(torch.nn.Module):
118 | '''
119 |     Temperature-scaled soft-target cross entropy (adapted from timm); deprecated.
120 | '''
121 |
122 | def __init__(self, T):
123 | super(SoftTargetCrossEntropy_T, self).__init__()
124 | self.T = T
125 |
126 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
127 | '''
128 | Args:
129 | x: student logit
130 | target: teacher logit
131 |
132 | Returns:
133 |
134 | '''
135 | soft_labels = torch.softmax(target/self.T, dim=1)
136 | loss = torch.sum(-soft_labels * F.log_softmax(x, dim=-1), dim=-1)
137 | # loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
138 | return loss.mean()
139 |
140 | class SoftTargetCrossEntropy(torch.nn.Module):
141 | '''
142 | from timm
143 | '''
144 |
145 | def __init__(self):
146 | super(SoftTargetCrossEntropy, self).__init__()
147 |
148 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
149 | '''
150 | Args:
151 | x: student logit
152 | target: teacher logit
153 |
154 | Returns:
155 |
156 | '''
157 | # soft_labels = torch.softmax(target/self.T, dim=1)
158 | # loss = torch.sum(-soft_labels * F.log_softmax(x, dim=-1), dim=-1)
159 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
160 | return loss.mean()
161 |
162 | class KD_loss(torch.nn.Module):
163 | """Distilling the Knowledge in a Neural Network"""
164 | def __init__(self, T):
165 | super(KD_loss, self).__init__()
166 | self.T = T
167 |
168 | def forward(self, y_s, y_t):
169 | p_s = F.log_softmax(y_s/self.T, dim=1)
170 | p_t = F.softmax(y_t/self.T, dim=1)
171 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
172 | return loss
173 |
174 |
175 | class Hard_pseudo_label_CE(torch.nn.Module):
176 | '''
177 |     Cross entropy against hard pseudo-labels (argmax of the teacher logits).
178 | '''
179 |
180 | def __init__(self):
181 | super(Hard_pseudo_label_CE, self).__init__()
182 | self.loss_fn = torch.nn.CrossEntropyLoss()
183 |
184 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
185 | '''
186 | Args:
187 | x: student logit
188 | target: teacher logit
189 |
190 | Returns:
191 |
192 | '''
193 | hard_pseudo_labels = torch.argmax(torch.softmax(target, dim=1), dim=1)
194 | loss = self.loss_fn(x, hard_pseudo_labels)
195 | # loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
196 | return loss
197 |
198 | def onehot(labels: torch.Tensor, label_num):
199 | return torch.zeros(labels.shape[0], label_num, device=labels.device).scatter_(1, labels.view(-1, 1), 1)
200 |
201 | class Poly1_cross_entropy(torch.nn.Module):
202 | '''
203 | https://arxiv.org/pdf/2204.12511.pdf
204 | '''
205 |
206 | def __init__(self, epsilon=1.0):
207 | super(Poly1_cross_entropy, self).__init__()
208 | self.loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
209 | self.epsilon=epsilon
210 |
211 | def forward(self, logits, labels):
212 | # pt, CE, and Poly1 have shape [batch].
213 | # import ipdb
214 | # ipdb.set_trace(context=20)
215 | onehot_labels = onehot(labels, logits.shape[-1])
216 | pt = torch.sum(onehot_labels * F.softmax(logits, dim=1), dim=-1)
217 | # CE = tf.nn.softmax_cross_entropy_with_logits(labels, logits)
218 | CE = self.loss_fn(logits, labels)
219 | Poly1 = CE + self.epsilon * (1 - pt)
220 | return Poly1.mean()
--------------------------------------------------------------------------------
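For reference, the learning-rate helpers above are meant to be called once per optimization step: `cosine_lr` applies a linear warmup for `warmup_length` steps and then cosine-decays each parameter group's learning rate towards zero at `steps`. A minimal sketch of driving it (the model and hyperparameters are placeholders):

```python
import torch

from src.models import utils

model = torch.nn.Linear(512, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

total_steps = 1000
scheduler = utils.cosine_lr(optimizer, base_lrs=1e-3,
                            warmup_length=50, steps=total_steps)

for step in range(total_steps):
    scheduler(step)          # sets param_group['lr'] for this step
    # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad() ...
```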
/src/classifier-tuning/src/models/zeroshot.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 |
8 | import clip.clip as clip
9 |
10 | import src.templates as templates
11 | import src.datasets as datasets
12 |
13 | from src.args import parse_arguments
14 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
15 | from src.models.eval import evaluate
16 |
17 |
18 | def get_classnames_zs(args, clip_model):
19 | assert args.template is not None
20 | assert args.train_dataset is not None
21 | template = getattr(templates, args.template)
22 | logit_scale = clip_model.logit_scale
23 | dataset_class = getattr(datasets, args.train_dataset)
24 | dataset = dataset_class(
25 | None,
26 | location=args.data_location,
27 | batch_size=args.batch_size,
28 | classnames=args.classnames
29 | )
30 | classnames = dataset.classnames
31 | return classnames
32 |
33 |
34 |
35 | def get_zeroshot_classifier(args, clip_model):
36 | assert args.template is not None
37 | assert args.train_dataset is not None
38 | template = getattr(templates, args.template)
39 | logit_scale = clip_model.logit_scale
40 | dataset_class = getattr(datasets, args.train_dataset)
41 | dataset = dataset_class(
42 | None,
43 | location=args.data_location,
44 | batch_size=args.batch_size,
45 | classnames=args.classnames
46 | )
47 | device = args.device
48 | clip_model.eval()
49 | clip_model.to(device)
50 |
51 | print('Getting zeroshot weights.')
52 | with torch.no_grad():
53 | zeroshot_weights = []
54 |
55 | for classname in tqdm(dataset.classnames):
56 | texts = []
57 | for t in template:
58 | texts.append(t(classname))
59 | texts = clip.tokenize(texts).to(device) # tokenize
60 | embeddings = clip_model.encode_text(texts) # embed with text encoder
61 | embeddings /= embeddings.norm(dim=-1, keepdim=True)
62 |
63 | embeddings = embeddings.mean(dim=0, keepdim=True)
64 | embeddings /= embeddings.norm()
65 |
66 | zeroshot_weights.append(embeddings)
67 |
68 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
69 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
70 |
71 | zeroshot_weights *= logit_scale.exp()
72 |
73 | zeroshot_weights = zeroshot_weights.squeeze().float()
74 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
75 | # import ipdb
76 | # ipdb.set_trace(context=20)
77 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
78 |
79 | return classification_head
80 |
81 |
82 | def eval(args):
83 | args.freeze_encoder = True
84 | if args.load is not None:
85 | classifier = ImageClassifier.load(args.load)
86 | else:
87 | image_encoder = ImageEncoder(args, keep_lang=True)
88 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
89 | delattr(image_encoder.model, 'transformer')
90 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
91 |
92 | evaluate(classifier, args)
93 |
94 | if args.save is not None:
95 | classifier.save(args.save)
96 |
97 |
98 | if __name__ == '__main__':
99 | args = parse_arguments()
100 | eval(args)
--------------------------------------------------------------------------------
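In `get_zeroshot_classifier` above, each class weight is the re-normalized mean of the normalized text embeddings of that class's prompt templates, scaled by the CLIP logit scale. A minimal standalone sketch of the same construction against the bundled `clip` module (the model name, class names and template here are illustrative, not taken from the repo):

```python
import torch
import clip.clip as clip

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# the bundled clip.load returns (model, train_preprocess, val_preprocess)
model, _, _ = clip.load('ViT-B/16', device, jit=False)

classnames = ['annual crop land', 'forest']            # hypothetical classes
template = [lambda c: f'a centered satellite photo of {c}.']

with torch.no_grad():
    weights = []
    for classname in classnames:
        texts = clip.tokenize([t(classname) for t in template]).to(device)
        emb = model.encode_text(texts)
        emb = emb / emb.norm(dim=-1, keepdim=True)     # normalize each prompt embedding
        emb = emb.mean(dim=0)                          # average over templates
        emb = emb / emb.norm()                         # re-normalize the mean
        weights.append(emb)
    zeroshot_weights = torch.stack(weights) * model.logit_scale.exp()
```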
/src/classifier-tuning/src/templates/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai_imagenet_template import openai_imagenet_template
2 | from .simple_template import simple_template
3 | from .fmow_template import fmow_template
4 | from .iwildcam_template import iwildcam_template
5 | from .transfer_ds_template import *
--------------------------------------------------------------------------------
/src/classifier-tuning/src/templates/fmow_template.py:
--------------------------------------------------------------------------------
1 | from .utils import append_proper_article, get_plural
2 |
3 | fmow_template = [
4 | lambda c : f"satellite photo of a {c}.",
5 | lambda c : f"aerial photo of a {c}.",
6 | lambda c : f"satellite photo of {append_proper_article(c)}.",
7 | lambda c : f"aerial photo of {append_proper_article(c)}.",
8 | lambda c : f"satellite photo of a {c} in asia.",
9 | lambda c : f"aerial photo of a {c} in asia.",
10 | lambda c : f"satellite photo of a {c} in africa.",
11 | lambda c : f"aerial photo of a {c} in africa.",
12 | lambda c : f"satellite photo of a {c} in the americas.",
13 | lambda c : f"aerial photo of a {c} in the americas.",
14 | lambda c : f"satellite photo of a {c} in europe.",
15 | lambda c : f"aerial photo of a {c} in europe.",
16 | lambda c : f"satellite photo of a {c} in oceania.",
17 | lambda c : f"aerial photo of a {c} in oceania.",
18 | lambda c: f"a photo of a {c}.",
19 | lambda c: f"{c}.",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/classifier-tuning/src/templates/iwildcam_template.py:
--------------------------------------------------------------------------------
1 | from .utils import append_proper_article, get_plural
2 |
3 | iwildcam_template = [
4 | lambda c: f"a photo of {c}.",
5 | lambda c: f"{c} in the wild.",
6 | ]
--------------------------------------------------------------------------------
/src/classifier-tuning/src/templates/openai_imagenet_template.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | openai_imagenet_template = [
5 | lambda c: f'a bad photo of a {c}.',
6 | lambda c: f'a photo of many {c}.',
7 | lambda c: f'a sculpture of a {c}.',
8 | lambda c: f'a photo of the hard to see {c}.',
9 | lambda c: f'a low resolution photo of the {c}.',
10 | lambda c: f'a rendering of a {c}.',
11 | lambda c: f'graffiti of a {c}.',
12 | lambda c: f'a bad photo of the {c}.',
13 | lambda c: f'a cropped photo of the {c}.',
14 | lambda c: f'a tattoo of a {c}.',
15 | lambda c: f'the embroidered {c}.',
16 | lambda c: f'a photo of a hard to see {c}.',
17 | lambda c: f'a bright photo of a {c}.',
18 | lambda c: f'a photo of a clean {c}.',
19 | lambda c: f'a photo of a dirty {c}.',
20 | lambda c: f'a dark photo of the {c}.',
21 | lambda c: f'a drawing of a {c}.',
22 | lambda c: f'a photo of my {c}.',
23 | lambda c: f'the plastic {c}.',
24 | lambda c: f'a photo of the cool {c}.',
25 | lambda c: f'a close-up photo of a {c}.',
26 | lambda c: f'a black and white photo of the {c}.',
27 | lambda c: f'a painting of the {c}.',
28 | lambda c: f'a painting of a {c}.',
29 | lambda c: f'a pixelated photo of the {c}.',
30 | lambda c: f'a sculpture of the {c}.',
31 | lambda c: f'a bright photo of the {c}.',
32 | lambda c: f'a cropped photo of a {c}.',
33 | lambda c: f'a plastic {c}.',
34 | lambda c: f'a photo of the dirty {c}.',
35 | lambda c: f'a jpeg corrupted photo of a {c}.',
36 | lambda c: f'a blurry photo of the {c}.',
37 | lambda c: f'a photo of the {c}.',
38 | lambda c: f'a good photo of the {c}.',
39 | lambda c: f'a rendering of the {c}.',
40 | lambda c: f'a {c} in a video game.',
41 | lambda c: f'a photo of one {c}.',
42 | lambda c: f'a doodle of a {c}.',
43 | lambda c: f'a close-up photo of the {c}.',
44 | lambda c: f'a photo of a {c}.',
45 | lambda c: f'the origami {c}.',
46 | lambda c: f'the {c} in a video game.',
47 | lambda c: f'a sketch of a {c}.',
48 | lambda c: f'a doodle of the {c}.',
49 | lambda c: f'a origami {c}.',
50 | lambda c: f'a low resolution photo of a {c}.',
51 | lambda c: f'the toy {c}.',
52 | lambda c: f'a rendition of the {c}.',
53 | lambda c: f'a photo of the clean {c}.',
54 | lambda c: f'a photo of a large {c}.',
55 | lambda c: f'a rendition of a {c}.',
56 | lambda c: f'a photo of a nice {c}.',
57 | lambda c: f'a photo of a weird {c}.',
58 | lambda c: f'a blurry photo of a {c}.',
59 | lambda c: f'a cartoon {c}.',
60 | lambda c: f'art of a {c}.',
61 | lambda c: f'a sketch of the {c}.',
62 | lambda c: f'a embroidered {c}.',
63 | lambda c: f'a pixelated photo of a {c}.',
64 | lambda c: f'itap of the {c}.',
65 | lambda c: f'a jpeg corrupted photo of the {c}.',
66 | lambda c: f'a good photo of a {c}.',
67 | lambda c: f'a plushie {c}.',
68 | lambda c: f'a photo of the nice {c}.',
69 | lambda c: f'a photo of the small {c}.',
70 | lambda c: f'a photo of the weird {c}.',
71 | lambda c: f'the cartoon {c}.',
72 | lambda c: f'art of the {c}.',
73 | lambda c: f'a drawing of the {c}.',
74 | lambda c: f'a photo of the large {c}.',
75 | lambda c: f'a black and white photo of a {c}.',
76 | lambda c: f'the plushie {c}.',
77 | lambda c: f'a dark photo of a {c}.',
78 | lambda c: f'itap of a {c}.',
79 | lambda c: f'graffiti of the {c}.',
80 | lambda c: f'a toy {c}.',
81 | lambda c: f'itap of my {c}.',
82 | lambda c: f'a photo of a cool {c}.',
83 | lambda c: f'a photo of a small {c}.',
84 | lambda c: f'a tattoo of the {c}.',
85 | ]
--------------------------------------------------------------------------------
/src/classifier-tuning/src/templates/simple_template.py:
--------------------------------------------------------------------------------
1 | from src.templates.utils import append_proper_article
2 |
3 | simple_template = [
4 | lambda c: f"a photo of a {c}."
5 | # lambda c: f"a sketch of a {c}."
6 | ]
--------------------------------------------------------------------------------
/src/classifier-tuning/src/templates/transfer_ds_template.py:
--------------------------------------------------------------------------------
1 | from src.templates.utils import append_proper_article
2 |
3 | aircraft_template = [
4 | lambda c: f"a photo of a {c}, a type of aircraft.",
5 | # lambda c: f"a photo of the {c}, a type of aircraft."
6 | ]
7 |
8 | birds_template = [
9 | lambda c: f"a photo of a {c}, a type of bird."
10 | ]
11 |
12 | eurosat_template = [
13 | lambda c: f"a centered satellite photo of {c}."
14 | ]
15 |
16 | # eurosat_template = [
17 | # lambda c: f"a centered satellite photo of {c}.",
18 | # lambda c: f'a centered satellite photo of a {c}.',
19 | # lambda c: f'a centered satellite photo of the {c}.',
20 | # ]
21 |
22 |
23 | flowers_template = [
24 | lambda c: f"a photo of a {c}, a type of flower."
25 | ]
26 |
27 | food_template = [
28 | lambda c: f"a photo of a {c}, a type of food."
29 | ]
30 |
31 | pets_template = [
32 | lambda c: f"a photo of a {c}, a type of pet."
33 | ]
34 |
35 | imagenet_template = [
36 | lambda c: f"itap of a {c}.",
37 | lambda c: f"a bad photo of the {c}.",
38 | lambda c: f"a origami {c}.",
39 | lambda c: f"a photo of the large {c}.",
40 | lambda c: f"a {c} in a video game.",
41 | lambda c: f"art of the {c}.",
42 | lambda c: f"a photo of the small {c}."]
43 |
44 | cifar100_template = [
45 | lambda c: f'a photo of a {c}.',
46 | lambda c: f'a blurry photo of a {c}.',
47 | lambda c: f'a black and white photo of a {c}.',
48 | lambda c: f'a low contrast photo of a {c}.',
49 | lambda c: f'a high contrast photo of a {c}.',
50 | lambda c: f'a bad photo of a {c}.',
51 | lambda c: f'a good photo of a {c}.',
52 | lambda c: f'a photo of a small {c}.',
53 | lambda c: f'a photo of a big {c}.',
54 | lambda c: f'a photo of the {c}.',
55 | lambda c: f'a blurry photo of the {c}.',
56 | lambda c: f'a black and white photo of the {c}.',
57 | lambda c: f'a low contrast photo of the {c}.',
58 | lambda c: f'a high contrast photo of the {c}.',
59 | lambda c: f'a bad photo of the {c}.',
60 | lambda c: f'a good photo of the {c}.',
61 | lambda c: f'a photo of the small {c}.',
62 | lambda c: f'a photo of the big {c}.',
63 | ]
64 |
65 | cifar10_templates = [
66 | lambda c: f'a photo of a {c}.',
67 | lambda c: f'a blurry photo of a {c}.',
68 | lambda c: f'a black and white photo of a {c}.',
69 | lambda c: f'a low contrast photo of a {c}.',
70 | lambda c: f'a high contrast photo of a {c}.',
71 | lambda c: f'a bad photo of a {c}.',
72 | lambda c: f'a good photo of a {c}.',
73 | lambda c: f'a photo of a small {c}.',
74 | lambda c: f'a photo of a big {c}.',
75 | lambda c: f'a photo of the {c}.',
76 | lambda c: f'a blurry photo of the {c}.',
77 | lambda c: f'a black and white photo of the {c}.',
78 | lambda c: f'a low contrast photo of the {c}.',
79 | lambda c: f'a high contrast photo of the {c}.',
80 | lambda c: f'a bad photo of the {c}.',
81 | lambda c: f'a good photo of the {c}.',
82 | lambda c: f'a photo of the small {c}.',
83 | lambda c: f'a photo of the big {c}.',
84 | ]
85 |
86 | sun_template = [
87 | lambda c: f'a photo of a {c}.',
88 | lambda c: f'a photo of the {c}.',
89 | ]
90 |
91 | cars_template = [
92 | lambda c: f'a photo of a {c}.',
93 | lambda c: f'a photo of the {c}.',
94 | lambda c: f'a photo of my {c}.',
95 | lambda c: f'i love my {c}!',
96 | lambda c: f'a photo of my dirty {c}.',
97 | lambda c: f'a photo of my clean {c}.',
98 | lambda c: f'a photo of my new {c}.',
99 | lambda c: f'a photo of my old {c}.',
100 | ]
101 |
102 | dtd_template = [
103 | lambda c: f"{c} texture."
104 | ]
105 |
106 | # dtd_template = [
107 | # lambda c: f'a photo of a {c} texture.',
108 | # lambda c: f'a photo of a {c} pattern.',
109 | # lambda c: f'a photo of a {c} thing.',
110 | # lambda c: f'a photo of a {c} object.',
111 | # lambda c: f'a photo of the {c} texture.',
112 | # lambda c: f'a photo of the {c} pattern.',
113 | # lambda c: f'a photo of the {c} thing.',
114 | # lambda c: f'a photo of the {c} object.',
115 | # ]
--------------------------------------------------------------------------------
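Each template list above is simply a list of callables taking a class name; the zero-shot classifier builder iterates over them to produce one prompt per template. For example:

```python
from src.templates.transfer_ds_template import eurosat_template, cars_template

print([t('forest') for t in eurosat_template])
# ['a centered satellite photo of forest.']

print([t('Tesla Model S Sedan 2012') for t in cars_template][:2])
# ['a photo of a Tesla Model S Sedan 2012.', 'a photo of the Tesla Model S Sedan 2012.']
```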
/src/classifier-tuning/src/templates/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def get_plural(name):
3 | name = name.replace('_', ' ')
4 | if name[-2:] == 'sh':
5 | name = name + 'es'
6 | elif name[-2:] == 'ch':
7 | name = name + 'es'
8 | elif name[-1:] == 'y':
9 | name = name[:-1] + 'ies'
10 | elif name[-1:] == 's':
11 | name = name + 'es'
12 | elif name[-1:] == 'x':
13 | name = name + 'es'
14 | elif name[-3:] == 'man':
15 | name = name[:-3] + 'men'
16 | elif name == 'mouse':
17 | name = 'mice'
18 | elif name[-1:] == 'f':
19 | name = name[:-1] + 'ves'
20 | else:
21 | name = name + 's'
22 | return name
23 |
24 |
25 | def append_proper_article(name):
26 | name = name.replace('_', ' ')
27 | if name[0] in 'aeiou':
28 | return 'an ' + name
29 | return 'a ' + name
30 |
--------------------------------------------------------------------------------
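A quick illustration of the two helpers above; the outputs follow directly from the rules in the code:

```python
from src.templates.utils import append_proper_article, get_plural

print(append_proper_article('airplane'))   # 'an airplane'
print(append_proper_article('dog'))        # 'a dog'
print(get_plural('church'))                # 'churches'
print(get_plural('daisy'))                 # 'daisies'
```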
/src/glide/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide/.DS_Store
--------------------------------------------------------------------------------
/src/glide/gen_fsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # sh glide/gen_fsl.sh /path/to/few-shot/images /path/to/save/dataset
3 |
4 | ref_img_path=$1
5 | save_path=$2
6 |
7 | for i in $(seq 0 1 7)
8 | do
9 | CUDA_VISIBLE_DEVICES=${i} python3.7 glide/glide_fsl.py ${i} ${ref_img_path} ${save_path} &
10 | done
11 |
--------------------------------------------------------------------------------
/src/glide/gen_zsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # sh glide/gen_zsl.sh /path/to/save/dataset.pkl /path/to/save/dataset
3 | # sh glide/gen_zsl.sh ../dataset_eurosat.pkl syn_eurosat
4 |
5 | pkl_path=$1
6 | save_path=$2
7 |
8 | for i in $(seq 0 1 7)
9 | do
10 | CUDA_VISIBLE_DEVICES=${i} python3.7 glide/glide_zsl.py ${i} ${pkl_path} ${save_path} &
11 | done
12 |
--------------------------------------------------------------------------------
/src/glide_text2im/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.egg-info/
3 | .DS_Store
4 |
--------------------------------------------------------------------------------
/src/glide_text2im/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A codebase for performing model inference with a text-conditional diffusion model.
3 | """
4 |
--------------------------------------------------------------------------------
/src/glide_text2im/clip/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide_text2im/clip/__init__.py
--------------------------------------------------------------------------------
/src/glide_text2im/clip/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | from abc import ABC, abstractmethod
3 | from itertools import product
4 | from typing import Any, Optional
5 |
6 | import attr
7 | import numpy as np
8 | import torch
9 |
10 |
11 | @attr.s
12 | class AttentionMask(ABC):
13 | query_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore
14 | key_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore
15 | block_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore
16 | n_head: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore
17 | is_head_specific: bool = attr.ib(default=False)
18 | n_query_pad: int = attr.ib(default=0)
19 | n_key_pad: int = attr.ib(default=0)
20 |
21 | def __attrs_post_init__(self) -> None:
22 | if self.query_context_size % self.block_size != 0:
23 | raise ValueError()
24 | if self.key_context_size % self.block_size != 0:
25 | raise ValueError()
26 | if self.n_query_pad >= self.query_context_size:
27 | raise ValueError()
28 | if self.n_key_pad >= self.key_context_size:
29 | raise ValueError()
30 |
31 | self.n_query_block = self.query_context_size // self.block_size
32 | self.n_key_block = self.key_context_size // self.block_size
33 | self.first_pad_query_block_idx = self.n_query_block - int(
34 | math.ceil(self.n_query_pad / self.block_size)
35 | )
36 | self.first_pad_key_block_idx = self.n_key_block - int(
37 | math.ceil(self.n_key_pad / self.block_size)
38 | )
39 |
40 | def _make_global_layout(self) -> None:
41 | if not self.is_head_specific:
42 | m = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)
43 | r = product(*[range(n) for n in m.shape])
44 |
45 | for qb, kb in r:
46 | m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0))
47 | else:
48 | m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=np.bool)
49 | r = product(*[range(n) for n in m.shape])
50 |
51 | for h, qb, kb in r:
52 | m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0))
53 |
54 | self.global_layout = m
55 |
56 | @abstractmethod
57 | def _block_layout(
58 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
59 | ) -> np.ndarray:
60 | raise NotImplementedError()
61 |
62 | def block_layout(
63 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
64 | ) -> np.ndarray:
65 | """
66 | `query_idx`, `key_idx` are block-level, zero-based indices.
67 | """
68 |
69 | m = np.ones([self.block_size, self.block_size], dtype=np.bool)
70 |
71 | if query_idx >= self.first_pad_query_block_idx:
72 | n_pad = min(
73 | self.block_size,
74 | (query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad),
75 | )
76 | assert n_pad > 0
77 | m[self.block_size - n_pad :] = False
78 | if key_idx >= self.first_pad_key_block_idx:
79 | n_pad = min(
80 | self.block_size,
81 | (key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad),
82 | )
83 | assert n_pad > 0
84 | m[:, self.block_size - n_pad :] = False
85 |
86 | return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx)
87 |
88 |
89 | @attr.s
90 | class DenseAttentionMask(AttentionMask):
91 | def __attrs_post_init__(self) -> None:
92 | super().__attrs_post_init__()
93 |
94 | self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)
95 | n_zero_query_blocks = self.n_query_pad // self.block_size
96 | n_zero_key_blocks = self.n_key_pad // self.block_size
97 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
98 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
99 |
100 | def _block_layout(
101 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
102 | ) -> np.ndarray:
103 | return np.ones([self.block_size, self.block_size], dtype=np.bool)
104 |
105 |
106 | @attr.s
107 | class DenseCausalAttentionMask(AttentionMask):
108 | def __attrs_post_init__(self) -> None:
109 | super().__attrs_post_init__()
110 |
111 | self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=np.bool))
112 | n_zero_query_blocks = self.n_query_pad // self.block_size
113 | n_zero_key_blocks = self.n_key_pad // self.block_size
114 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
115 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
116 |
117 | def _block_layout(
118 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
119 | ) -> np.ndarray:
120 | if query_idx > key_idx:
121 | return np.ones(2 * [self.block_size], dtype=np.bool)
122 | elif query_idx < key_idx:
123 | return np.zeros(2 * [self.block_size], dtype=np.bool)
124 | else:
125 | return np.tril(np.ones(2 * [self.block_size], dtype=np.bool))
126 |
127 |
128 | @attr.s(eq=False, repr=False)
129 | class AttentionInfo:
130 | n_heads: int = attr.ib()
131 | ctx_blks_q: int = attr.ib()
132 | ctx_blks_k: int = attr.ib()
133 | block_size: int = attr.ib()
134 | pytorch_attn_bias: Optional[torch.Tensor] = attr.ib()
135 |
136 |
137 | def to_attention_info(d: AttentionMask) -> AttentionInfo:
138 | return AttentionInfo(
139 | n_heads=d.n_head,
140 | ctx_blks_q=d.n_query_block,
141 | ctx_blks_k=d.n_key_block,
142 | block_size=d.block_size,
143 | pytorch_attn_bias=None,
144 | )
145 |
146 |
147 | def make_full_layout(d: AttentionMask) -> np.ndarray:
148 | """
149 | Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of
150 | the attention head, a `attention_head x context_size x context_size` layout matrix is returned instead.
151 | """
152 |
153 | if not d.is_head_specific:
154 | u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1])
155 | r = product(range(d.n_query_block), range(d.n_key_block))
156 | v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r])
157 | v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size])
158 |
159 | w = u * v
160 | w = np.transpose(w, [0, 2, 1, 3])
161 | w = np.reshape(w, [d.query_context_size, d.key_context_size])
162 | return w
163 | else:
164 | if len(d.global_layout.shape) == 2:
165 | u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1])
166 | u = np.tile(u, [d.n_head, 1, 1, 1, 1])
167 | elif len(d.global_layout.shape) == 3:
168 | u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1])
169 | else:
170 | raise RuntimeError()
171 |
172 | s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block))
173 | v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s])
174 | v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size])
175 |
176 | w = u * v
177 | w = np.transpose(w, [0, 1, 3, 2, 4])
178 | w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size])
179 | return w
180 |
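To make the block-layout construction above concrete, here is a small, self-contained numpy sketch (editorial, not part of the repo) that reproduces what `DenseCausalAttentionMask` plus `make_full_layout` would build for two query/key blocks of width 4 with no padding:

    import numpy as np
    from itertools import product

    block_size, n_blocks = 4, 2
    # Block-level causal layout (lower-triangular over blocks).
    global_layout = np.tril(np.ones([n_blocks, n_blocks], dtype=bool))

    def block_layout(q, k):
        # Within-block layout: full below the block diagonal, empty above, causal on it.
        if q > k:
            return np.ones([block_size, block_size], dtype=bool)
        if q < k:
            return np.zeros([block_size, block_size], dtype=bool)
        return np.tril(np.ones([block_size, block_size], dtype=bool))

    u = np.reshape(global_layout, [n_blocks, n_blocks, 1, 1])
    v = np.array([block_layout(i, j) for i, j in product(range(n_blocks), range(n_blocks))])
    v = v.reshape(n_blocks, n_blocks, block_size, block_size)
    w = (u * v).transpose(0, 2, 1, 3).reshape(n_blocks * block_size, n_blocks * block_size)
    print(w.astype(int))  # an 8x8 lower-triangular attention mask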
--------------------------------------------------------------------------------
/src/glide_text2im/clip/config.yaml:
--------------------------------------------------------------------------------
1 | logit_scale: 100.0
2 |
3 | # Diffusion settings
4 | beta_schedule: "squaredcos_cap_v2"
5 | n_timesteps: 1000
6 |
7 | # Architecture settings
8 | image_size: 64
9 | patch_size: 4
10 | n_vocab: 65536
11 | max_text_len: 77
12 | n_embd: 512
13 | n_head_state_text: 64
14 | n_head_text: 8
15 | n_xf_blocks_text: 12
16 | n_head_state_image: 64
17 | n_head_image: 12
18 | n_xf_blocks_image: 12
19 |
--------------------------------------------------------------------------------
/src/glide_text2im/clip/model_creation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Any, Callable, Dict, List, Optional, Tuple
4 |
5 | import attr
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import yaml
10 | from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer
11 |
12 | from .encoders import ImageEncoder, TextEncoder
13 |
14 |
15 | @lru_cache()
16 | def default_config_path() -> str:
17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
18 |
19 |
20 | @attr.s
21 | class CLIPModel:
22 | config: Dict[str, Any] = attr.ib()
23 | text_encoder: nn.Module = attr.ib()
24 | image_encoder: nn.Module = attr.ib()
25 | logit_scale: torch.Tensor = attr.ib()
26 | device: torch.device = attr.ib()
27 | tokenizer: SimpleTokenizer = attr.ib()
28 |
29 | def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
30 | tokens = []
31 | lens = []
32 | for prompt in prompts:
33 | sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
34 | self.tokenizer.encode(prompt), self.text_encoder.max_text_len
35 | )
36 | tokens.append(sub_tokens)
37 | lens.append(sub_len)
38 | return (
39 | torch.tensor(tokens).to(dtype=torch.long, device=self.device),
40 | torch.tensor(lens).to(dtype=torch.long, device=self.device),
41 | )
42 |
43 | def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
44 | tokens, lens = self.encode_prompts(prompts)
45 | z_t = self.text_encoder(tokens, lens)
46 | return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)
47 |
48 | def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
49 | z_i = self.image_encoder((images + 1) * 127.5, t)
50 | return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)
51 |
52 | def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
53 | with torch.no_grad():
54 | z_t = self.text_embeddings(prompts)
55 |
56 | def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
57 | with torch.enable_grad():
58 | x_var = x.detach().requires_grad_(True)
59 | z_i = self.image_embeddings(x_var, t)
60 | loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
61 | grad = torch.autograd.grad(loss, x_var)[0].detach()
62 | return grad * grad_scale
63 |
64 | return cond_fn
65 |
66 |
67 | def create_clip_model(
68 | config_path: Optional[str] = None,
69 | device: Optional[torch.device] = None,
70 | tokenizer: Optional[SimpleTokenizer] = None,
71 | ) -> CLIPModel:
72 | if config_path is None:
73 | config_path = default_config_path()
74 | if device is None:
75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76 | if tokenizer is None:
77 | tokenizer = SimpleTokenizer()
78 |
79 | with open(config_path, "r") as f:
80 | config = yaml.load(f, Loader=yaml.SafeLoader)
81 |
82 | text_encoder = TextEncoder(
83 | n_bpe_vocab=config["n_vocab"],
84 | max_text_len=config["max_text_len"],
85 | n_embd=config["n_embd"],
86 | n_head=config["n_head_text"],
87 | n_xf_blocks=config["n_xf_blocks_text"],
88 | n_head_state=config["n_head_state_text"],
89 | device=device,
90 | )
91 |
92 | image_encoder = ImageEncoder(
93 | image_size=config["image_size"],
94 | patch_size=config["patch_size"],
95 | n_embd=config["n_embd"],
96 | n_head=config["n_head_image"],
97 | n_xf_blocks=config["n_xf_blocks_image"],
98 | n_head_state=config["n_head_state_image"],
99 | n_timestep=config["n_timesteps"],
100 | device=device,
101 | )
102 |
103 | logit_scale = torch.tensor(
104 | np.log(config["logit_scale"]),
105 | dtype=torch.float32,
106 | device=device,
107 | requires_grad=False,
108 | )
109 |
110 | return CLIPModel(
111 | config=config,
112 | text_encoder=text_encoder,
113 | image_encoder=image_encoder,
114 | logit_scale=logit_scale,
115 | device=device,
116 | tokenizer=tokenizer,
117 | )
118 |
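A hedged usage sketch (editorial; the checkpoint names come from download.py below, while the prompt, guidance scale, and the sampler call are assumptions following the usual GLIDE recipe): build the noise-aware CLIP model, load its released weights, and hand its `cond_fn` to a diffusion sampler.

    import torch as th
    from glide_text2im.download import load_checkpoint
    from glide_text2im.clip.model_creation import create_clip_model

    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    clip_model = create_clip_model(device=device)
    clip_model.image_encoder.load_state_dict(load_checkpoint("clip/image-enc", device))
    clip_model.text_encoder.load_state_dict(load_checkpoint("clip/text-enc", device))

    prompt, batch_size, guidance_scale = "a corgi wearing a red bowtie", 4, 3.0
    cond_fn = clip_model.cond_fn([prompt] * batch_size, guidance_scale)
    # `model` and `diffusion` are assumed to come from create_model_and_diffusion
    # (see glide_text2im/model_creation.py further below); its sampler accepts cond_fn:
    # samples = diffusion.p_sample_loop(model, (batch_size, 3, 64, 64), cond_fn=cond_fn)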
--------------------------------------------------------------------------------
/src/glide_text2im/clip/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable, Optional
3 |
4 | import attr
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | FilterFn = Callable[[torch.Tensor], torch.Tensor]
10 |
11 |
12 | class ZeroKeyBiasGrad(torch.autograd.Function):
13 | @staticmethod
14 | def forward(ctx, x):
15 | return x
16 |
17 | @staticmethod
18 | def backward(ctx, output_grad):
19 | output_grad = output_grad.clone()
20 | output_grad.chunk(3)[1].zero_()
21 | return output_grad
22 |
23 |
24 | def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
25 | return ZeroKeyBiasGrad.apply(x)
26 |
27 |
28 | @attr.s(eq=False, repr=False)
29 | class LayerNorm(nn.Module):
30 | n_state: int = attr.ib()
31 | eps: float = attr.ib(default=1e-6)
32 | device: torch.device = attr.ib(default=torch.device("cuda"))
33 |
34 | def __attrs_post_init__(self) -> None:
35 | super().__init__()
36 | self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
37 | self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
38 | self.g.weight_decay_level = "disable" # type: ignore
39 | self.b.weight_decay_level = "disable" # type: ignore
40 |
41 | def forward(self, x: torch.Tensor) -> torch.Tensor:
42 | return F.layer_norm(
43 | x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
44 | )
45 |
46 |
47 | @attr.s(eq=False, repr=False)
48 | class Affine(nn.Module):
49 | n_in: int = attr.ib()
50 | n_out: int = attr.ib()
51 | use_bias: bool = attr.ib(default=True)
52 | use_admnet_init: bool = attr.ib(default=False)
53 | std: Optional[float] = attr.ib(default=None)
54 | extra_init_scale: Optional[float] = attr.ib(default=None)
55 | bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
56 | device: torch.device = attr.ib(default=torch.device("cuda"))
57 |
58 | def __attrs_post_init__(self) -> None:
59 | super().__init__()
60 |
61 | if not self.use_admnet_init:
62 | self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
63 | self.std = (
64 | self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
65 | )
66 |
67 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
68 | self.w = nn.Parameter(w)
69 |
70 | if self.use_bias:
71 | self.b = nn.Parameter(
72 | torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
73 | )
74 | self.b.weight_decay_level = "disable" # type: ignore
75 | else:
76 | if self.extra_init_scale is not None:
77 | raise ValueError("extra_init_scale incompatible with admnet init")
78 |
79 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
80 |
81 | if self.use_bias:
82 | b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)
83 |
84 | self.w = nn.Parameter(w)
85 |
86 | if self.use_bias:
87 | self.b = nn.Parameter(b)
88 | self.b.weight_decay_level = "disable" # type: ignore
89 |
90 | def forward(self, x: torch.Tensor) -> torch.Tensor:
91 | w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
92 | b = (
93 | self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
94 | if self.use_bias
95 | else None
96 | )
97 | return F.linear(x, w, b)
98 |
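A tiny editorial sketch of what `zero_key_bias_grad` does: the fused q/k/v bias is a single tensor, and the backward pass zeroes the gradient flowing into its middle (key) third.

    import torch
    from glide_text2im.clip.utils import zero_key_bias_grad

    qkv_bias = torch.zeros(12, requires_grad=True)  # stand-in for a fused q/k/v bias
    out = zero_key_bias_grad(qkv_bias)              # identity in the forward pass
    out.sum().backward()
    print(qkv_bias.grad.view(3, 4))
    # tensor([[1., 1., 1., 1.],
    #         [0., 0., 0., 0.],   <- key-bias gradient zeroed
    #         [1., 1., 1., 1.]])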
--------------------------------------------------------------------------------
/src/glide_text2im/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Dict, Optional
4 |
5 | import requests
6 | import torch as th
7 | from filelock import FileLock
8 | from tqdm.auto import tqdm
9 |
10 | MODEL_PATHS = {
11 | "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt",
12 | "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt",
13 | "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
14 | "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
15 | "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
16 | "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
17 | }
18 |
19 |
20 | @lru_cache()
21 | def default_cache_dir() -> str:
22 | return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache")
23 |
24 |
25 | def fetch_file_cached(
26 | url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
27 | ) -> str:
28 | """
29 | Download the file at the given URL into a local file and return the path.
30 |
31 | If cache_dir is specified, it will be used to download the files.
32 | Otherwise, default_cache_dir() is used.
33 | """
34 | if cache_dir is None:
35 | cache_dir = default_cache_dir()
36 | os.makedirs(cache_dir, exist_ok=True)
37 | local_path = os.path.join(cache_dir, url.split("/")[-1])
38 | if os.path.exists(local_path):
39 | return local_path
40 | response = requests.get(url, stream=True)
41 | size = int(response.headers.get("content-length", "0"))
42 | with FileLock(local_path + ".lock"):
43 | if progress:
44 | pbar = tqdm(total=size, unit="iB", unit_scale=True)
45 | tmp_path = local_path + ".tmp"
46 | with open(tmp_path, "wb") as f:
47 | for chunk in response.iter_content(chunk_size):
48 | if progress:
49 | pbar.update(len(chunk))
50 | f.write(chunk)
51 | os.rename(tmp_path, local_path)
52 | if progress:
53 | pbar.close()
54 | return local_path
55 |
56 |
57 | def load_checkpoint(
58 | checkpoint_name: str,
59 | device: th.device,
60 | progress: bool = True,
61 | cache_dir: Optional[str] = None,
62 | chunk_size: int = 4096,
63 | ) -> Dict[str, th.Tensor]:
64 | if checkpoint_name not in MODEL_PATHS:
65 | raise ValueError(
66 | f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
67 | )
68 | path = fetch_file_cached(
69 | MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
70 | )
71 | return th.load(path, map_location=device)
72 |
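A minimal editorial sketch of the helpers above: fetch the released base checkpoint into the default local cache (./glide_model_cache) and load it onto CPU; the file is reused on subsequent calls.

    import torch as th
    from glide_text2im.download import load_checkpoint

    state_dict = load_checkpoint("base", device=th.device("cpu"))
    print(len(state_dict), "tensors loaded")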
--------------------------------------------------------------------------------
/src/glide_text2im/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for inference with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 |
7 |
8 | def convert_module_to_f16(l):
9 | """
10 | Convert primitive modules to float16.
11 | """
12 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
13 | l.weight.data = l.weight.data.half()
14 | if l.bias is not None:
15 | l.bias.data = l.bias.data.half()
16 |
17 |
18 | def convert_module_to_f32(l):
19 | """
20 | Convert primitive modules to float32, undoing convert_module_to_f16().
21 | """
22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23 | l.weight.data = l.weight.data.float()
24 | if l.bias is not None:
25 | l.bias.data = l.bias.data.float()
26 |
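These helpers are meant to be passed to `nn.Module.apply`; a short editorial sketch:

    import torch.nn as nn
    from glide_text2im.fp16_util import convert_module_to_f16, convert_module_to_f32

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 3, 3))
    net.apply(convert_module_to_f16)   # conv weights/biases become float16
    net.apply(convert_module_to_f32)   # ...and back to float32
    print(next(net.parameters()).dtype)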
--------------------------------------------------------------------------------
/src/glide_text2im/model_creation.py:
--------------------------------------------------------------------------------
1 | from glide_text2im.gaussian_diffusion import get_named_beta_schedule
2 | from glide_text2im.respace import SpacedDiffusion, space_timesteps
3 | from glide_text2im.text2im_model import (
4 | InpaintText2ImUNet,
5 | SuperResInpaintText2ImUnet,
6 | SuperResText2ImUNet,
7 | Text2ImUNet,
8 | )
9 | from glide_text2im.tokenizer.bpe import get_encoder
10 |
11 |
12 | def model_and_diffusion_defaults():
13 | return dict(
14 | image_size=64,
15 | num_channels=192,
16 | num_res_blocks=3,
17 | channel_mult="",
18 | num_heads=1,
19 | num_head_channels=64,
20 | num_heads_upsample=-1,
21 | attention_resolutions="32,16,8",
22 | dropout=0.1,
23 | text_ctx=128,
24 | xf_width=512,
25 | xf_layers=16,
26 | xf_heads=8,
27 | xf_final_ln=True,
28 | xf_padding=True,
29 | diffusion_steps=1000,
30 | noise_schedule="squaredcos_cap_v2",
31 | timestep_respacing="",
32 | use_scale_shift_norm=True,
33 | resblock_updown=True,
34 | use_fp16=True,
35 | cache_text_emb=False,
36 | inpaint=False,
37 | super_res=False,
38 | )
39 |
40 |
41 | def model_and_diffusion_defaults_upsampler():
42 | result = model_and_diffusion_defaults()
43 | result.update(
44 | dict(
45 | image_size=256,
46 | num_res_blocks=2,
47 | noise_schedule="linear",
48 | super_res=True,
49 | )
50 | )
51 | return result
52 |
53 |
54 | def create_model_and_diffusion(
55 | image_size,
56 | num_channels,
57 | num_res_blocks,
58 | channel_mult,
59 | num_heads,
60 | num_head_channels,
61 | num_heads_upsample,
62 | attention_resolutions,
63 | dropout,
64 | text_ctx,
65 | xf_width,
66 | xf_layers,
67 | xf_heads,
68 | xf_final_ln,
69 | xf_padding,
70 | diffusion_steps,
71 | noise_schedule,
72 | timestep_respacing,
73 | use_scale_shift_norm,
74 | resblock_updown,
75 | use_fp16,
76 | cache_text_emb,
77 | inpaint,
78 | super_res,
79 | ):
80 | model = create_model(
81 | image_size,
82 | num_channels,
83 | num_res_blocks,
84 | channel_mult=channel_mult,
85 | attention_resolutions=attention_resolutions,
86 | num_heads=num_heads,
87 | num_head_channels=num_head_channels,
88 | num_heads_upsample=num_heads_upsample,
89 | use_scale_shift_norm=use_scale_shift_norm,
90 | dropout=dropout,
91 | text_ctx=text_ctx,
92 | xf_width=xf_width,
93 | xf_layers=xf_layers,
94 | xf_heads=xf_heads,
95 | xf_final_ln=xf_final_ln,
96 | xf_padding=xf_padding,
97 | resblock_updown=resblock_updown,
98 | use_fp16=use_fp16,
99 | cache_text_emb=cache_text_emb,
100 | inpaint=inpaint,
101 | super_res=super_res,
102 | )
103 | diffusion = create_gaussian_diffusion(
104 | steps=diffusion_steps,
105 | noise_schedule=noise_schedule,
106 | timestep_respacing=timestep_respacing,
107 | )
108 | return model, diffusion
109 |
110 |
111 | def create_model(
112 | image_size,
113 | num_channels,
114 | num_res_blocks,
115 | channel_mult,
116 | attention_resolutions,
117 | num_heads,
118 | num_head_channels,
119 | num_heads_upsample,
120 | use_scale_shift_norm,
121 | dropout,
122 | text_ctx,
123 | xf_width,
124 | xf_layers,
125 | xf_heads,
126 | xf_final_ln,
127 | xf_padding,
128 | resblock_updown,
129 | use_fp16,
130 | cache_text_emb,
131 | inpaint,
132 | super_res,
133 | ):
134 | if channel_mult == "":
135 | if image_size == 256:
136 | channel_mult = (1, 1, 2, 2, 4, 4)
137 | elif image_size == 128:
138 | channel_mult = (1, 1, 2, 3, 4)
139 | elif image_size == 64:
140 | channel_mult = (1, 2, 3, 4)
141 | else:
142 | raise ValueError(f"unsupported image size: {image_size}")
143 | else:
144 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
145 | assert 2 ** (len(channel_mult) + 2) == image_size
146 |
147 | attention_ds = []
148 | for res in attention_resolutions.split(","):
149 | attention_ds.append(image_size // int(res))
150 |
151 | if inpaint and super_res:
152 | model_cls = SuperResInpaintText2ImUnet
153 | elif inpaint:
154 | model_cls = InpaintText2ImUNet
155 | elif super_res:
156 | model_cls = SuperResText2ImUNet
157 | else:
158 | model_cls = Text2ImUNet
159 | return model_cls(
160 | text_ctx=text_ctx,
161 | xf_width=xf_width,
162 | xf_layers=xf_layers,
163 | xf_heads=xf_heads,
164 | xf_final_ln=xf_final_ln,
165 | tokenizer=get_encoder(),
166 | xf_padding=xf_padding,
167 | in_channels=3,
168 | model_channels=num_channels,
169 | out_channels=6,
170 | num_res_blocks=num_res_blocks,
171 | attention_resolutions=tuple(attention_ds),
172 | dropout=dropout,
173 | channel_mult=channel_mult,
174 | use_fp16=use_fp16,
175 | num_heads=num_heads,
176 | num_head_channels=num_head_channels,
177 | num_heads_upsample=num_heads_upsample,
178 | use_scale_shift_norm=use_scale_shift_norm,
179 | resblock_updown=resblock_updown,
180 | cache_text_emb=cache_text_emb,
181 | )
182 |
183 |
184 | def create_gaussian_diffusion(
185 | steps,
186 | noise_schedule,
187 | timestep_respacing,
188 | ):
189 | betas = get_named_beta_schedule(noise_schedule, steps)
190 | if not timestep_respacing:
191 | timestep_respacing = [steps]
192 | return SpacedDiffusion(
193 | use_timesteps=space_timesteps(steps, timestep_respacing),
194 | betas=betas,
195 | )
196 |
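A hedged usage sketch (editorial; the "base" checkpoint name comes from download.py above, and the respacing value is an illustrative choice): take the defaults, override a couple of fields, and build the base model together with its diffusion sampler.

    import torch as th
    from glide_text2im.download import load_checkpoint
    from glide_text2im.model_creation import (
        create_model_and_diffusion,
        model_and_diffusion_defaults,
    )

    options = model_and_diffusion_defaults()
    options["use_fp16"] = False            # keep fp32 for a CPU sanity check
    options["timestep_respacing"] = "100"  # 100 sampling steps instead of 1000
    model, diffusion = create_model_and_diffusion(**options)
    model.eval()
    model.load_state_dict(load_checkpoint("base", th.device("cpu")))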
--------------------------------------------------------------------------------
/src/glide_text2im/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class GroupNorm32(nn.GroupNorm):
13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5):
14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
15 | self.swish = swish
16 |
17 | def forward(self, x):
18 | y = super().forward(x.float()).to(x.dtype)
19 | if self.swish == 1.0:
20 | y = F.silu(y)
21 | elif self.swish:
22 | y = y * F.sigmoid(y * float(self.swish))
23 | return y
24 |
25 |
26 | def conv_nd(dims, *args, **kwargs):
27 | """
28 | Create a 1D, 2D, or 3D convolution module.
29 | """
30 | if dims == 1:
31 | return nn.Conv1d(*args, **kwargs)
32 | elif dims == 2:
33 | return nn.Conv2d(*args, **kwargs)
34 | elif dims == 3:
35 | return nn.Conv3d(*args, **kwargs)
36 | raise ValueError(f"unsupported dimensions: {dims}")
37 |
38 |
39 | def linear(*args, **kwargs):
40 | """
41 | Create a linear module.
42 | """
43 | return nn.Linear(*args, **kwargs)
44 |
45 |
46 | def avg_pool_nd(dims, *args, **kwargs):
47 | """
48 | Create a 1D, 2D, or 3D average pooling module.
49 | """
50 | if dims == 1:
51 | return nn.AvgPool1d(*args, **kwargs)
52 | elif dims == 2:
53 | return nn.AvgPool2d(*args, **kwargs)
54 | elif dims == 3:
55 | return nn.AvgPool3d(*args, **kwargs)
56 | raise ValueError(f"unsupported dimensions: {dims}")
57 |
58 |
59 | def zero_module(module):
60 | """
61 | Zero out the parameters of a module and return it.
62 | """
63 | for p in module.parameters():
64 | p.detach().zero_()
65 | return module
66 |
67 |
68 | def scale_module(module, scale):
69 | """
70 | Scale the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().mul_(scale)
74 | return module
75 |
76 |
77 | def normalization(channels, swish=0.0):
78 | """
79 | Make a standard normalization layer, with an optional swish activation.
80 |
81 | :param channels: number of input channels.
82 | :return: an nn.Module for normalization.
83 | """
84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
85 |
86 |
87 | def timestep_embedding(timesteps, dim, max_period=10000):
88 | """
89 | Create sinusoidal timestep embeddings.
90 |
91 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
92 | These may be fractional.
93 | :param dim: the dimension of the output.
94 | :param max_period: controls the minimum frequency of the embeddings.
95 | :return: an [N x dim] Tensor of positional embeddings.
96 | """
97 | half = dim // 2
98 | freqs = th.exp(
99 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
100 | ).to(device=timesteps.device)
101 | args = timesteps[:, None].float() * freqs[None]
102 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
103 | if dim % 2:
104 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
105 | return embedding
106 |
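A quick editorial check of `timestep_embedding`'s output shape and layout (cosine half followed by sine half):

    import torch as th
    from glide_text2im.nn import timestep_embedding

    ts = th.tensor([0, 10, 100, 999])
    emb = timestep_embedding(ts, dim=128)
    print(emb.shape)    # torch.Size([4, 128])
    print(emb[0, :4])   # t=0 gives cos(0)=1 in the first half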
--------------------------------------------------------------------------------
/src/glide_text2im/respace.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for changing sampling schedules of a trained model.
3 |
4 | Simplified from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
5 | """
6 |
7 | import numpy as np
8 | import torch as th
9 |
10 | from .gaussian_diffusion import GaussianDiffusion
11 |
12 |
13 | def space_timesteps(num_timesteps, section_counts):
14 | """
15 | Create a list of timesteps to use from an original diffusion process,
16 | given the number of timesteps we want to take from equally-sized portions
17 | of the original process.
18 |
19 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
20 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
21 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
22 |
23 | :param num_timesteps: the number of diffusion steps in the original
24 | process to divide up.
25 | :param section_counts: either a list of numbers, or a string containing
26 | comma-separated numbers, indicating the step count
27 | per section. As a special case, use "ddimN" where N
28 | is a number of steps to use the striding from the
29 | DDIM paper.
30 | :return: a set of diffusion steps from the original process to use.
31 | """
32 | if isinstance(section_counts, str):
33 | if section_counts.startswith("ddim"):
34 | desired_count = int(section_counts[len("ddim") :])
35 | for i in range(1, num_timesteps):
36 | if len(range(0, num_timesteps, i)) == desired_count:
37 | return set(range(0, num_timesteps, i))
38 |             raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
39 | elif section_counts == "fast27":
40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2")
41 | # Help reduce DDIM artifacts from noisiest timesteps.
42 | steps.remove(num_timesteps - 1)
43 | steps.add(num_timesteps - 3)
44 | return steps
45 | section_counts = [int(x) for x in section_counts.split(",")]
46 | size_per = num_timesteps // len(section_counts)
47 | extra = num_timesteps % len(section_counts)
48 | start_idx = 0
49 | all_steps = []
50 | for i, section_count in enumerate(section_counts):
51 | size = size_per + (1 if i < extra else 0)
52 | if size < section_count:
53 | raise ValueError(f"cannot divide section of {size} steps into {section_count}")
54 | if section_count <= 1:
55 | frac_stride = 1
56 | else:
57 | frac_stride = (size - 1) / (section_count - 1)
58 | cur_idx = 0.0
59 | taken_steps = []
60 | for _ in range(section_count):
61 | taken_steps.append(start_idx + round(cur_idx))
62 | cur_idx += frac_stride
63 | all_steps += taken_steps
64 | start_idx += size
65 | return set(all_steps)
66 |
67 |
68 | class SpacedDiffusion(GaussianDiffusion):
69 | """
70 | A diffusion process which can skip steps in a base diffusion process.
71 |
72 | :param use_timesteps: a collection (sequence or set) of timesteps from the
73 | original diffusion process to retain.
74 | :param kwargs: the kwargs to create the base diffusion process.
75 | """
76 |
77 | def __init__(self, use_timesteps, **kwargs):
78 | self.use_timesteps = set(use_timesteps)
79 | self.timestep_map = []
80 | self.original_num_steps = len(kwargs["betas"])
81 |
82 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
83 | last_alpha_cumprod = 1.0
84 | new_betas = []
85 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
86 | if i in self.use_timesteps:
87 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
88 | last_alpha_cumprod = alpha_cumprod
89 | self.timestep_map.append(i)
90 | kwargs["betas"] = np.array(new_betas)
91 | super().__init__(**kwargs)
92 |
93 | def p_mean_variance(self, model, *args, **kwargs):
94 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
95 |
96 | def condition_mean(self, cond_fn, *args, **kwargs):
97 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
98 |
99 | def condition_score(self, cond_fn, *args, **kwargs):
100 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def _wrap_model(self, model):
103 | if isinstance(model, _WrappedModel):
104 | return model
105 | return _WrappedModel(model, self.timestep_map, self.original_num_steps)
106 |
107 |
108 | class _WrappedModel:
109 | def __init__(self, model, timestep_map, original_num_steps):
110 | self.model = model
111 | self.timestep_map = timestep_map
112 | self.original_num_steps = original_num_steps
113 |
114 | def __call__(self, x, ts, **kwargs):
115 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
116 | new_ts = map_tensor[ts]
117 | return self.model(x, new_ts, **kwargs)
118 |
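An editorial sketch of the docstring example for `space_timesteps`, plus the DDIM-style shorthand:

    from glide_text2im.respace import space_timesteps

    steps = space_timesteps(300, [10, 15, 20])
    print(len(steps))                               # 45 steps drawn from three 100-step sections
    print(sorted(space_timesteps(1000, "ddim25")))  # 25 evenly strided steps: 0, 40, 80, ...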
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide_text2im/tokenizer/__init__.py
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/bpe.py:
--------------------------------------------------------------------------------
1 | """
2 | Byte pair encoding utilities adapted from:
3 | https://github.com/openai/gpt-2/blob/master/src/encoder.py
4 | """
5 |
6 | import gzip
7 | import json
8 | import os
9 | from functools import lru_cache
10 | from typing import List, Tuple
11 |
12 | import regex as re
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = (
27 | list(range(ord("!"), ord("~") + 1))
28 | + list(range(ord("¡"), ord("¬") + 1))
29 | + list(range(ord("®"), ord("ÿ") + 1))
30 | )
31 | cs = bs[:]
32 | n = 0
33 | for b in range(2 ** 8):
34 | if b not in bs:
35 | bs.append(b)
36 | cs.append(2 ** 8 + n)
37 | n += 1
38 | cs = [chr(n) for n in cs]
39 | return dict(zip(bs, cs))
40 |
41 |
42 | def get_pairs(word):
43 | """Return set of symbol pairs in a word.
44 | Word is represented as tuple of symbols (symbols being variable-length strings).
45 | """
46 | pairs = set()
47 | prev_char = word[0]
48 | for char in word[1:]:
49 | pairs.add((prev_char, char))
50 | prev_char = char
51 | return pairs
52 |
53 |
54 | class Encoder:
55 | def __init__(self, encoder, bpe_merges, errors="replace"):
56 | self.encoder = encoder
57 | self.decoder = {v: k for k, v in self.encoder.items()}
58 | self.errors = errors # how to handle errors in decoding
59 | self.byte_encoder = bytes_to_unicode()
60 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
61 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
62 | self.cache = {}
63 |
64 |         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
65 | self.pat = re.compile(
66 | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
67 | )
68 |
69 | @property
70 | def n_vocab(self) -> int:
71 | return len(self.encoder)
72 |
73 | @property
74 | def end_token(self) -> int:
75 | return self.n_vocab - 1
76 |
77 | def padded_tokens_and_mask(
78 | self, tokens: List[int], text_ctx: int
79 | ) -> Tuple[List[int], List[bool]]:
80 | tokens = tokens[:text_ctx]
81 | padding = text_ctx - len(tokens)
82 | padded_tokens = tokens + [self.end_token] * padding
83 | mask = [True] * len(tokens) + [False] * padding
84 | return padded_tokens, mask
85 |
86 | def bpe(self, token):
87 | if token in self.cache:
88 | return self.cache[token]
89 | word = tuple(token)
90 | pairs = get_pairs(word)
91 |
92 | if not pairs:
93 | return token
94 |
95 | while True:
96 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
97 | if bigram not in self.bpe_ranks:
98 | break
99 | first, second = bigram
100 | new_word = []
101 | i = 0
102 | while i < len(word):
103 | try:
104 | j = word.index(first, i)
105 | new_word.extend(word[i:j])
106 | i = j
107 | except: # pylint: disable=bare-except
108 | new_word.extend(word[i:])
109 | break
110 |
111 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
112 | new_word.append(first + second)
113 | i += 2
114 | else:
115 | new_word.append(word[i])
116 | i += 1
117 | new_word = tuple(new_word)
118 | word = new_word
119 | if len(word) == 1:
120 | break
121 | else:
122 | pairs = get_pairs(word)
123 | word = " ".join(word)
124 | self.cache[token] = word
125 | return word
126 |
127 | def encode(self, text):
128 | text = text.lower()
129 | bpe_tokens = []
130 | for token in re.findall(self.pat, text):
131 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
133 | return bpe_tokens
134 |
135 | def decode(self, tokens):
136 | text = "".join([self.decoder[token] for token in tokens])
137 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
138 | return text
139 |
140 |
141 | def get_encoder():
142 | root_dir = os.path.dirname(os.path.abspath(__file__))
143 | with gzip.open(os.path.join(root_dir, "encoder.json.gz"), "r") as f:
144 | encoder = json.load(f)
145 | with gzip.open(os.path.join(root_dir, "vocab.bpe.gz"), "r") as f:
146 | bpe_data = str(f.read(), "utf-8")
147 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
148 | return Encoder(
149 | encoder=encoder,
150 | bpe_merges=bpe_merges,
151 | )
152 |
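A short editorial sketch of how the GLIDE transformer tokenizer above gets used (it needs the encoder.json.gz and vocab.bpe.gz files that sit next to this module in the repo):

    from glide_text2im.tokenizer.bpe import get_encoder

    enc = get_encoder()
    tokens = enc.encode("a corgi wearing a red bowtie")
    padded, mask = enc.padded_tokens_and_mask(tokens, text_ctx=128)
    print(len(padded), sum(mask))  # 128 total positions; only the real tokens are unmasked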
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide_text2im/tokenizer/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/encoder.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide_text2im/tokenizer/encoder.json.gz
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py
3 | """
4 |
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import List, Tuple
10 |
11 | import ftfy
12 | import regex as re
13 |
14 |
15 | @lru_cache()
16 | def default_bpe():
17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
18 |
19 |
20 | @lru_cache()
21 | def bytes_to_unicode():
22 | """
23 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
24 | The reversible bpe codes work on unicode strings.
25 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
26 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
27 |     This is a significant percentage of your normal, say, 32K bpe vocab.
28 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
29 | And avoids mapping to whitespace/control characters the bpe code barfs on.
30 | """
31 | bs = (
32 | list(range(ord("!"), ord("~") + 1))
33 | + list(range(ord("¡"), ord("¬") + 1))
34 | + list(range(ord("®"), ord("ÿ") + 1))
35 | )
36 | cs = bs[:]
37 | n = 0
38 | for b in range(2 ** 8):
39 | if b not in bs:
40 | bs.append(b)
41 | cs.append(2 ** 8 + n)
42 | n += 1
43 | cs = [chr(n) for n in cs]
44 | return dict(zip(bs, cs))
45 |
46 |
47 | def get_pairs(word):
48 | """Return set of symbol pairs in a word.
49 | Word is represented as tuple of symbols (symbols being variable-length strings).
50 | """
51 | pairs = set()
52 | prev_char = word[0]
53 | for char in word[1:]:
54 | pairs.add((prev_char, char))
55 | prev_char = char
56 | return pairs
57 |
58 |
59 | def basic_clean(text):
60 | text = ftfy.fix_text(text)
61 | text = html.unescape(html.unescape(text))
62 | return text.strip()
63 |
64 |
65 | def whitespace_clean(text):
66 | text = re.sub(r"\s+", " ", text)
67 | text = text.strip()
68 | return text
69 |
70 |
71 | class SimpleTokenizer(object):
72 | def __init__(self, bpe_path: str = default_bpe()):
73 | self.byte_encoder = bytes_to_unicode()
74 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
75 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
76 | merges = merges[1 : 49152 - 256 - 2 + 1]
77 | merges = [tuple(merge.split()) for merge in merges]
78 | vocab = list(bytes_to_unicode().values())
79 |         vocab = vocab + [v + "</w>" for v in vocab]
80 | for merge in merges:
81 | vocab.append("".join(merge))
82 | vocab.extend(["<|startoftext|>", "<|endoftext|>"])
83 | self.encoder = dict(zip(vocab, range(len(vocab))))
84 | self.decoder = {v: k for k, v in self.encoder.items()}
85 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
86 | self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
87 | self.pat = re.compile(
88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
89 | re.IGNORECASE,
90 | )
91 |
92 | @property
93 | def start_token(self):
94 | return self.encoder["<|startoftext|>"]
95 |
96 | @property
97 | def end_token(self):
98 | return self.encoder["<|endoftext|>"]
99 |
100 | def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]:
101 | tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token]
102 | text_len = len(tokens)
103 | padding = text_ctx - len(tokens)
104 | padded_tokens = tokens + [0] * padding
105 | return padded_tokens, text_len
106 |
107 | def bpe(self, token):
108 | if token in self.cache:
109 | return self.cache[token]
110 |         word = tuple(token[:-1]) + (token[-1] + "</w>",)
111 | pairs = get_pairs(word)
112 |
113 | if not pairs:
114 |             return token + "</w>"
115 |
116 | while True:
117 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
118 | if bigram not in self.bpe_ranks:
119 | break
120 | first, second = bigram
121 | new_word = []
122 | i = 0
123 | while i < len(word):
124 | try:
125 | j = word.index(first, i)
126 | new_word.extend(word[i:j])
127 | i = j
128 | except: # pylint: disable=bare-except
129 | new_word.extend(word[i:])
130 | break
131 |
132 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
133 | new_word.append(first + second)
134 | i += 2
135 | else:
136 | new_word.append(word[i])
137 | i += 1
138 | new_word = tuple(new_word)
139 | word = new_word
140 | if len(word) == 1:
141 | break
142 | else:
143 | pairs = get_pairs(word)
144 | word = " ".join(word)
145 | self.cache[token] = word
146 | return word
147 |
148 | def encode(self, text):
149 | bpe_tokens = []
150 | text = whitespace_clean(basic_clean(text)).lower()
151 | for token in re.findall(self.pat, text):
152 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
153 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
154 | return bpe_tokens
155 |
156 | def decode(self, tokens):
157 | text = "".join([self.decoder[token] for token in tokens])
158 | text = (
159 | bytearray([self.byte_decoder[c] for c in text])
160 | .decode("utf-8", errors="replace")
161 |             .replace("</w>", " ")
162 | )
163 | return text
164 |
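A matching editorial sketch for the CLIP tokenizer above, which pads with zeros and returns the true length rather than a mask (it relies on bpe_simple_vocab_16e6.txt.gz in this directory):

    from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer

    tok = SimpleTokenizer()
    tokens, text_len = tok.padded_tokens_and_len(tok.encode("a photo of a dog"), text_ctx=77)
    print(len(tokens), text_len)  # 77 padded positions; text_len counts the start/end tokens too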
--------------------------------------------------------------------------------
/src/glide_text2im/tokenizer/vocab.bpe.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SyntheticData/062ac23b64aaa89020e1e41e5ca9d7719688d589/src/glide_text2im/tokenizer/vocab.bpe.gz
--------------------------------------------------------------------------------
/src/glide_text2im/xf.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer implementation adapted from CLIP ViT:
3 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
4 | """
5 |
6 | import math
7 |
8 | import torch as th
9 | import torch.nn as nn
10 |
11 |
12 | def convert_module_to_f16(l):
13 | """
14 | Convert primitive modules to float16.
15 | """
16 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
17 | l.weight.data = l.weight.data.half()
18 | if l.bias is not None:
19 | l.bias.data = l.bias.data.half()
20 |
21 |
22 | class LayerNorm(nn.LayerNorm):
23 | """
24 | Implementation that supports fp16 inputs but fp32 gains/biases.
25 | """
26 |
27 | def forward(self, x: th.Tensor):
28 | return super().forward(x.float()).to(x.dtype)
29 |
30 |
31 | class MultiheadAttention(nn.Module):
32 | def __init__(self, n_ctx, width, heads):
33 | super().__init__()
34 | self.n_ctx = n_ctx
35 | self.width = width
36 | self.heads = heads
37 | self.c_qkv = nn.Linear(width, width * 3)
38 | self.c_proj = nn.Linear(width, width)
39 | self.attention = QKVMultiheadAttention(heads, n_ctx)
40 |
41 | def forward(self, x):
42 | x = self.c_qkv(x)
43 | x = self.attention(x)
44 | x = self.c_proj(x)
45 | return x
46 |
47 |
48 | class MLP(nn.Module):
49 | def __init__(self, width):
50 | super().__init__()
51 | self.width = width
52 | self.c_fc = nn.Linear(width, width * 4)
53 | self.c_proj = nn.Linear(width * 4, width)
54 | self.gelu = nn.GELU()
55 |
56 | def forward(self, x):
57 | return self.c_proj(self.gelu(self.c_fc(x)))
58 |
59 |
60 | class QKVMultiheadAttention(nn.Module):
61 | def __init__(self, n_heads: int, n_ctx: int):
62 | super().__init__()
63 | self.n_heads = n_heads
64 | self.n_ctx = n_ctx
65 |
66 | def forward(self, qkv):
67 | bs, n_ctx, width = qkv.shape
68 | attn_ch = width // self.n_heads // 3
69 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
70 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
71 | q, k, v = th.split(qkv, attn_ch, dim=-1)
72 | weight = th.einsum(
73 | "bthc,bshc->bhts", q * scale, k * scale
74 | ) # More stable with f16 than dividing afterwards
75 | wdtype = weight.dtype
76 | weight = th.softmax(weight.float(), dim=-1).type(wdtype)
77 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
78 |
79 |
80 | class ResidualAttentionBlock(nn.Module):
81 | def __init__(
82 | self,
83 | n_ctx: int,
84 | width: int,
85 | heads: int,
86 | ):
87 | super().__init__()
88 |
89 | self.attn = MultiheadAttention(
90 | n_ctx,
91 | width,
92 | heads,
93 | )
94 | self.ln_1 = LayerNorm(width)
95 | self.mlp = MLP(width)
96 | self.ln_2 = LayerNorm(width)
97 |
98 | def forward(self, x: th.Tensor):
99 | x = x + self.attn(self.ln_1(x))
100 | x = x + self.mlp(self.ln_2(x))
101 | return x
102 |
103 |
104 | class Transformer(nn.Module):
105 | def __init__(
106 | self,
107 | n_ctx: int,
108 | width: int,
109 | layers: int,
110 | heads: int,
111 | ):
112 | super().__init__()
113 | self.n_ctx = n_ctx
114 | self.width = width
115 | self.layers = layers
116 | self.resblocks = nn.ModuleList(
117 | [
118 | ResidualAttentionBlock(
119 | n_ctx,
120 | width,
121 | heads,
122 | )
123 | for _ in range(layers)
124 | ]
125 | )
126 |
127 | def forward(self, x: th.Tensor):
128 | for block in self.resblocks:
129 | x = block(x)
130 | return x
131 |
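A final editorial shape check for the transformer stack above: the residual attention blocks preserve the (batch, n_ctx, width) shape.

    import torch as th
    from glide_text2im.xf import Transformer

    xf = Transformer(n_ctx=128, width=512, layers=2, heads=8)
    x = th.randn(2, 128, 512)
    print(xf(x).shape)  # torch.Size([2, 128, 512])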
--------------------------------------------------------------------------------