├── DATASET.md
├── README.md
├── calip.png
├── class_template.py
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   └── imagenet.yaml
├── requirements.txt
├── run_imagenet.py
└── utils.py
/DATASET.md:
--------------------------------------------------------------------------------
1 | # How to install datasets
2 |
3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management, and following the instructions below to organize them so that the source code does not need to be modified. The file structure looks like
4 |
5 | ```
6 | $DATA/
7 | |–– imagenet/
8 | |–– caltech-101/
9 | |–– oxford_pets/
10 | |–– stanford_cars/
11 | ```
12 |
13 | If you already have some datasets installed elsewhere, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate downloads, as in the sketch below.
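A minimal sketch of creating such a link with Python's standard library; both paths below are placeholders for wherever `$DATA` and your existing copy actually live.

```python
import os

DATA = os.path.expanduser("~/data")                # placeholder for $DATA
existing = "/path/to/your/existing/imagenet"       # placeholder for an already-downloaded copy

os.makedirs(DATA, exist_ok=True)
link = os.path.join(DATA, "imagenet")
if not os.path.exists(link):
    os.symlink(existing, link)                     # $DATA/imagenet -> existing download
```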
14 |
15 | Datasets list:
16 | - [ImageNet](#imagenet)
17 | - [Caltech101](#caltech101)
18 | - [OxfordPets](#oxfordpets)
19 | - [StanfordCars](#stanfordcars)
20 | - [Flowers102](#flowers102)
21 | - [Food101](#food101)
22 | - [FGVCAircraft](#fgvcaircraft)
23 | - [SUN397](#sun397)
24 | - [DTD](#dtd)
25 | - [EuroSAT](#eurosat)
26 | - [UCF101](#ucf101)
27 |
28 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we utilize CoOp-style train/val/test splits for all datasets except ImageNet, where the validation set is used as the test set.
29 |
30 | ### ImageNet
31 | - Create a folder named `imagenet/` under `$DATA`.
32 | - Create `images/` under `imagenet/`.
33 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like
34 | ```
35 | imagenet/
36 | |–– images/
37 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc.
38 | | |–– val/
39 | ```
40 | - If you have already downloaded the ImageNet dataset, you can create symbolic links that map the existing training and validation sets to `$DATA/imagenet/images`.
41 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).
42 |
43 | ### Caltech101
44 | - Create a folder named `caltech-101/` under `$DATA`.
45 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`.
46 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.
47 |
48 | The directory structure should look like
49 | ```
50 | caltech-101/
51 | |–– 101_ObjectCategories/
52 | |–– split_zhou_Caltech101.json
53 | ```
54 |
55 | ### OxfordPets
56 | - Create a folder named `oxford_pets/` under `$DATA`.
57 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz.
58 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.
59 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing).
60 |
61 | The directory structure should look like
62 | ```
63 | oxford_pets/
64 | |–– images/
65 | |–– annotations/
66 | |–– split_zhou_OxfordPets.json
67 | ```
68 |
69 | ### StanfordCars
70 | - Create a folder named `stanford_cars/` under `$DATA`.
71 | - Download the train images from http://ai.stanford.edu/~jkrause/car196/cars_train.tgz.
72 | - Download the test images from http://ai.stanford.edu/~jkrause/car196/cars_test.tgz.
73 | - Download the train labels from https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz.
74 | - Download the test labels from http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat.
75 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing).
76 |
77 | The directory structure should look like
78 | ```
79 | stanford_cars/
80 | |–– cars_test/
81 | |–– cars_test_annos_withlabels.mat
82 | |–– cars_train/
83 | |–– devkit/
84 | |–– split_zhou_StanfordCars.json
85 | ```
86 |
87 | ### Flowers102
88 | - Create a folder named `oxford_flowers/` under `$DATA`.
89 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively.
90 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing).
91 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing).
92 |
93 | The directory structure should look like
94 | ```
95 | oxford_flowers/
96 | |–– cat_to_name.json
97 | |–– imagelabels.mat
98 | |–– jpg/
99 | |–– split_zhou_OxfordFlowers.json
100 | ```
101 |
102 | ### Food101
103 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`.
104 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing).
105 |
106 | The directory structure should look like
107 | ```
108 | food-101/
109 | |–– images/
110 | |–– license_agreement.txt
111 | |–– meta/
112 | |–– README.txt
113 | |–– split_zhou_Food101.json
114 | ```
115 |
116 | ### FGVCAircraft
117 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz.
118 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`.
119 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`.
120 |
121 | The directory structure should look like
122 | ```
123 | fgvc_aircraft/
124 | |–– images/
125 | |–– ... # a bunch of .txt files
126 | ```
127 |
128 | ### SUN397
129 | - Create a folder named `sun397/` under `$DATA`.
130 | - Download the images from http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz.
131 | - Download the partitions from https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip.
132 | - Extract these files under `$DATA/sun397/`.
133 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing).
134 |
135 | The directory structure should look like
136 | ```
137 | sun397/
138 | |–– SUN397/
139 | |–– split_zhou_SUN397.json
140 | |–– ... # a bunch of .txt files
141 | ```
142 |
143 | ### DTD
144 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`.
145 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).
146 |
147 | The directory structure should look like
148 | ```
149 | dtd/
150 | |–– images/
151 | |–– imdb/
152 | |–– labels/
153 | |–– split_zhou_DescribableTextures.json
154 | ```
155 |
156 | ### EuroSAT
157 | - Create a folder named `eurosat/` under `$DATA`.
158 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.
159 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).
160 |
161 | The directory structure should look like
162 | ```
163 | eurosat/
164 | |–– 2750/
165 | |–– split_zhou_EuroSAT.json
166 | ```
167 |
168 | ### UCF101
169 | - Create a folder named `ucf101/` under `$DATA`.
170 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames.
171 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).
172 |
173 | The directory structure should look like
174 | ```
175 | ucf101/
176 | |–– UCF-101-midframes/
177 | |–– split_zhou_UCF101.json
178 | ```
179 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CALIP: Zero-Shot Enhancement of CLIP with Parameter-free Attention
2 | Official implementation of ['CALIP: Zero-Shot Enhancement of CLIP with Parameter-free Attention'](https://arxiv.org/pdf/2209.14169.pdf).
3 |
4 | The paper has been accepted by **AAAI 2023**.
5 |
6 | ## Introduction
7 | CALIP is a free-lunch enhancement method to boost CLIP’s zero-shot performance via a parameter-free attention module. Specifically, we guide visual and textual representations to interact with each other and explore cross-modal informative features via attention. As the pre-training has largely reduced the embedding distances between the two modalities, we discard all learnable parameters in the attention and bidirectionally update the multi-modal features, enabling the whole process to be parameter-free and training-free. In this way, the images are blended with textual-aware signals and the text representations become visual-guided for better adaptive zero-shot alignment. We evaluate CALIP on various benchmarks of 14 datasets for both 2D image and 3D point cloud few-shot classification, showing consistent zero-shot performance improvement over CLIP. Based on that, we further insert a small number of linear layers in CALIP’s attention module and verify our robustness under few-shot settings, where we also achieve leading performance compared to existing methods.
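As a rough illustration of the idea (not the repository's exact implementation, which follows the paper's update rules and the `beta` weights used in `run_imagenet.py`), parameter-free cross-modal attention between L2-normalized spatial visual features and textual class features could look like the minimal sketch below; `alpha` is a smoothing temperature introduced here only for illustration.

```python
import torch
import torch.nn.functional as F

def parameter_free_attention(f_v, f_t, alpha=1.0):
    """Minimal sketch of CALIP-style parameter-free cross-modal attention.

    f_v: [M, C] L2-normalized spatial (patch) features of one image
    f_t: [K, C] L2-normalized textual features, one per class
    No learnable projections are involved: the pre-trained embeddings
    themselves serve as queries, keys and values.
    """
    attn = f_v @ f_t.t()                                  # [M, K] cross-modal similarities
    f_v_a = F.softmax(attn / alpha, dim=-1) @ f_t         # visual features blended with textual signals
    f_t_a = F.softmax(attn.t() / alpha, dim=-1) @ f_v     # textual features guided by the visual content
    return f_v_a, f_t_a
```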
8 |
9 |
10 |
11 |
12 | ## Requirements
13 | ### Installation
14 | Create a conda environment and install dependencies:
15 | ```bash
16 | git clone https://github.com/ZiyuGuo99/CALIP.git
17 | cd CALIP
18 |
19 | conda create -n calip python=3.7
20 | conda activate calip
21 |
22 | # Install the appropriate versions of torch and torchvision
23 | conda install pytorch torchvision cudatoolkit
24 |
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ### Dataset
29 | Follow [DATASET.md](https://github.com/ZiyuGuo99/CALIP/blob/main/DATASET.md) to install ImageNet and the other 10 datasets according to CoOp.
30 |
31 | ## Get Started
32 | ### Configs
33 | The configuration for running on each dataset can be modified in `configs/*.yaml`. You need to fill in `data_root` with your own data path. You can also edit the `backbone` and `search` settings as needed, and feel free to adjust `beta2` and `beta3` for a wider or finer search range.
34 |
35 | Note that `load_cache` defaults to `False` for the first run, which stores the encoded features and labels to disk. It can then be set to `True` for faster hyperparameter tuning in subsequent runs. A minimal sketch of inspecting the config is given below.
36 |
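For reference, a minimal sketch of loading and tweaking the YAML config with PyYAML; the actual parsing inside `run_imagenet.py` may differ, and the `data_root` value below is a placeholder.

```python
import yaml

with open("configs/imagenet.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["data_root"] = "/path/to/your/data"   # placeholder: point this at your $DATA folder
print(cfg["backbone"], cfg["load_cache"], cfg["search"], cfg["beta2"], cfg["beta3"])
```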
37 | ### Running
38 | For ImageNet dataset:
39 | ```bash
40 | CUDA_VISIBLE_DEVICES=0 python run_imagenet.py --config configs/imagenet.yaml
41 | ```
42 |
43 | For the other 10 datasets:
44 | TODO...
45 |
46 | ## Acknowledgement
47 | This repo benefits from [CLIP](https://github.com/openai/CLIP), [CoOp](https://github.com/KaiyangZhou/Dassl.pytorch), [CLIP-Adapter](https://github.com/gaopengcuhk/CLIP-Adapter) and [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter). Thanks for their wonderful work.
48 |
49 | ## Citation
50 | ```bibtex
51 | @article{guo2022calip,
52 | title={Calip: Zero-shot enhancement of clip with parameter-free attention},
53 | author={Guo, Ziyu and Zhang, Renrui and Qiu, Longtian and Ma, Xianzheng and Miao, Xupeng and He, Xuming and Cui, Bin},
54 | journal={arXiv preprint arXiv:2209.14169},
55 | year={2022}
56 | }
57 | ```
58 |
59 | ## Contact
60 | If you have any questions about this project, please feel free to contact 2101210573@pku.edu.cn.
61 |
--------------------------------------------------------------------------------
/calip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZiyuGuo99/CALIP/83e21b6b985934dce6e8414c26300671b4825841/calip.png
--------------------------------------------------------------------------------
/class_template.py:
--------------------------------------------------------------------------------
1 |
2 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
3 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
4 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
5 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
6 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
7 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
8 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
9 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
10 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
11 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
12 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
13 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
14 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
15 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
16 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
17 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
18 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
19 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
20 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
21 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
22 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
23 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
24 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
25 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
26 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
27 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
28 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
29 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
30 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
31 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
32 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
33 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
34 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
35 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
36 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
37 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
38 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
39 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
40 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
41 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
42 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
43 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
44 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
45 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
46 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
47 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
48 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
49 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
50 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
51 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
52 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
53 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
54 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
55 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
56 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
57 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
58 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
59 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
60 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
61 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
62 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
63 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
64 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
65 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
66 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
67 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
68 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
69 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
70 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
71 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
72 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
73 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
74 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
75 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
76 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
77 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
78 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
79 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
80 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
81 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
82 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
83 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
84 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
85 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
86 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
87 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
88 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
89 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
90 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
91 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
92 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
93 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
94 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
95 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
96 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
97 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
98 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
99 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
100 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
101 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
102 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
103 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
104 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
105 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
106 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
107 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
108 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
109 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
110 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
111 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
112 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
113 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
114 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
115 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
116 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
117 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
118 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
119 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
120 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
121 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
122 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
123 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
124 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
125 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
126 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
127 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
128 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
129 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
130 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
131 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
132 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
133 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
134 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
135 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
136 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
137 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
138 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
139 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
140 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
141 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
142 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
143 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
144 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
145 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
146 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
147 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
148 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
149 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
150 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
151 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
152 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
153 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
154 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
155 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
156 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
157 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
158 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
159 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
160 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
161 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
162 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
163 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
164 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
165 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
166 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
167 | imagenet_templates = [
168 | "itap of a {}.",
169 | "a bad photo of the {}.",
170 | "a origami {}.",
171 | "a photo of the large {}.",
172 | "a {} in a video game.",
173 | "art of the {}.",
174 | "a photo of the small {}.",
175 | ]
176 | TEMPLATE = {
177 | 'imagenet' : imagenet_templates,
178 | }
179 | CLASS_NAME = {
180 | 'imagenet' : imagenet_classes,
181 | }
--------------------------------------------------------------------------------
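A hypothetical usage sketch of the templates and class names defined in `class_template.py` above: each prompt is a template with a class name filled in, and the full prompt set is the cross product of classes and templates (the exact prompt-ensembling logic in `run_imagenet.py` is not shown here).

```python
from class_template import TEMPLATE, CLASS_NAME

classnames = CLASS_NAME["imagenet"]
templates = TEMPLATE["imagenet"]

# One prompt per (class, template) pair, e.g. "itap of a goldfish."
prompts = [t.format(c) for c in classnames for t in templates]
print(len(prompts))  # 1000 classes x 7 templates = 7000 prompts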
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZiyuGuo99/CALIP/83e21b6b985934dce6e8414c26300671b4825841/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | # if torch.__version__.split(".") < ["1", "7", "1"]:
23 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | }
37 |
38 |
39 | def _download(url: str, root: str):
40 | os.makedirs(root, exist_ok=True)
41 | filename = os.path.basename(url)
42 |
43 | expected_sha256 = url.split("/")[-2]
44 | download_target = os.path.join(root, filename)
45 |
46 | if os.path.exists(download_target) and not os.path.isfile(download_target):
47 | raise RuntimeError(f"{download_target} exists and is not a regular file")
48 |
49 | if os.path.isfile(download_target):
50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51 | return download_target
52 | else:
53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54 |
55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
57 | while True:
58 | buffer = source.read(8192)
59 | if not buffer:
60 | break
61 |
62 | output.write(buffer)
63 | loop.update(len(buffer))
64 |
65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
67 |
68 | return download_target
69 |
70 |
71 | def _convert_image_to_rgb(image):
72 | return image.convert("RGB")
73 |
74 |
75 | def _transform(n_px):
76 | return Compose([
77 | Resize(n_px, interpolation=BICUBIC),
78 | CenterCrop(n_px),
79 | _convert_image_to_rgb,
80 | ToTensor(),
81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
82 | ])
83 |
84 |
85 | def available_models() -> List[str]:
86 | """Returns the names of available CLIP models"""
87 | return list(_MODELS.keys())
88 |
89 |
90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
91 | """Load a CLIP model
92 |
93 | Parameters
94 | ----------
95 | name : str
96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
97 |
98 | device : Union[str, torch.device]
99 | The device to put the loaded model
100 |
101 | jit : bool
102 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
103 |
104 | download_root: str
105 | path to download the model files; by default, it uses "~/.cache/clip"
106 |
107 | Returns
108 | -------
109 | model : torch.nn.Module
110 | The CLIP model
111 |
112 | preprocess : Callable[[PIL.Image], torch.Tensor]
113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
114 | """
115 | if name in _MODELS:
116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
117 | elif os.path.isfile(name):
118 | model_path = name
119 | else:
120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
121 |
122 | try:
123 | # loading JIT archive
124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
125 | state_dict = None
126 | except RuntimeError:
127 | # loading saved state dict
128 | if jit:
129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
130 | jit = False
131 | state_dict = torch.load(model_path, map_location="cpu")
132 |
133 | if not jit:
134 | model = build_model(state_dict or model.state_dict()).to(device)
135 | if str(device) == "cpu":
136 | model.float()
137 | return model, _transform(model.visual.input_resolution)
138 |
139 | # patch the device names
140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
142 |
143 | def patch_device(module):
144 | try:
145 | graphs = [module.graph] if hasattr(module, "graph") else []
146 | except RuntimeError:
147 | graphs = []
148 |
149 | if hasattr(module, "forward1"):
150 | graphs.append(module.forward1.graph)
151 |
152 | for graph in graphs:
153 | for node in graph.findAllNodes("prim::Constant"):
154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
155 | node.copyAttributes(device_node)
156 |
157 | model.apply(patch_device)
158 | patch_device(model.encode_image)
159 | patch_device(model.encode_text)
160 |
161 | # patch dtype to float32 on CPU
162 | if str(device) == "cpu":
163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
165 | float_node = float_input.node()
166 |
167 | def patch_float(module):
168 | try:
169 | graphs = [module.graph] if hasattr(module, "graph") else []
170 | except RuntimeError:
171 | graphs = []
172 |
173 | if hasattr(module, "forward1"):
174 | graphs.append(module.forward1.graph)
175 |
176 | for graph in graphs:
177 | for node in graph.findAllNodes("aten::to"):
178 | inputs = list(node.inputs())
179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
180 | if inputs[i].node()["value"] == 5:
181 | inputs[i].node().copyAttributes(float_node)
182 |
183 | model.apply(patch_float)
184 | patch_float(model.encode_image)
185 | patch_float(model.encode_text)
186 |
187 | model.float()
188 |
189 | return model, _transform(model.input_resolution.item())
190 |
191 |
192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
193 | """
194 | Returns the tokenized representation of given input string(s)
195 |
196 | Parameters
197 | ----------
198 | texts : Union[str, List[str]]
199 | An input string or a list of input strings to tokenize
200 |
201 | context_length : int
202 | The context length to use; all CLIP models use 77 as the context length
203 |
204 | truncate: bool
205 | Whether to truncate the text in case its encoding is longer than the context length
206 |
207 | Returns
208 | -------
209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
210 | """
211 | if isinstance(texts, str):
212 | texts = [texts]
213 |
214 | sot_token = _tokenizer.encoder["<|startoftext|>"]
215 | eot_token = _tokenizer.encoder["<|endoftext|>"]
216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
218 |
219 | for i, tokens in enumerate(all_tokens):
220 | if len(tokens) > context_length:
221 | if truncate:
222 | tokens = tokens[:context_length]
223 | tokens[-1] = eot_token
224 | else:
225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
226 | result[i, :len(tokens)] = torch.tensor(tokens)
227 |
228 | return result
229 |
--------------------------------------------------------------------------------
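As an aside, a minimal usage sketch of the `load` and `tokenize` helpers defined in `clip/clip.py` above; the image path is a placeholder, and note that in this repository the visual encoders keep spatial token features (for CALIP's parameter-free attention), so the shape of `encode_image`'s output differs from vanilla CLIP.

```python
import torch
from PIL import Image
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)   # any name from clip.available_models()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image path
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)  # spatial token features in this repo
    text_features = model.encode_text(text)     # [2, embed_dim]
```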
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 | return x  # keep attended features for all positions (pooled token + spatial tokens); the original CLIP pool returns only the pooled token
90 |
91 |
92 | class ModifiedResNet(nn.Module):
93 | """
94 | A ResNet class that is similar to torchvision's but contains the following changes:
95 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
96 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
97 | - The final pooling layer is a QKV attention instead of an average pool
98 | """
99 |
100 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
101 | super().__init__()
102 | self.output_dim = output_dim
103 | self.input_resolution = input_resolution
104 |
105 | # the 3-layer stem
106 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
107 | self.bn1 = nn.BatchNorm2d(width // 2)
108 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
109 | self.bn2 = nn.BatchNorm2d(width // 2)
110 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
111 | self.bn3 = nn.BatchNorm2d(width)
112 | self.avgpool = nn.AvgPool2d(2)
113 | self.relu = nn.ReLU(inplace=True)
114 |
115 | # residual layers
116 | self._inplanes = width # this is a *mutable* variable used during construction
117 | self.layer1 = self._make_layer(width, layers[0])
118 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
119 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
120 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
121 |
122 | embed_dim = width * 32 # the ResNet feature dimension
123 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
124 |
125 | def _make_layer(self, planes, blocks, stride=1):
126 | layers = [Bottleneck(self._inplanes, planes, stride)]
127 |
128 | self._inplanes = planes * Bottleneck.expansion
129 | for _ in range(1, blocks):
130 | layers.append(Bottleneck(self._inplanes, planes))
131 |
132 | return nn.Sequential(*layers)
133 |
134 | def forward(self, x):
135 | def stem(x):
136 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
137 | x = self.relu(bn(conv(x)))
138 | x = self.avgpool(x)
139 | return x
140 |
141 | x = x.type(self.conv1.weight.dtype)
142 | x = stem(x)
143 | x = self.layer1(x)
144 | x = self.layer2(x)
145 | x = self.layer3(x)
146 | x = self.layer4(x)
147 | x = self.attnpool(x)
148 |
149 | return x
150 |
151 |
152 | class LayerNorm(nn.LayerNorm):
153 | """Subclass torch's LayerNorm to handle fp16."""
154 |
155 | def forward(self, x: torch.Tensor):
156 | orig_type = x.dtype
157 | ret = super().forward(x.type(torch.float32))
158 | return ret.type(orig_type)
159 |
160 |
161 | class QuickGELU(nn.Module):
162 | def forward(self, x: torch.Tensor):
163 | return x * torch.sigmoid(1.702 * x)
164 |
165 |
166 | class ResidualAttentionBlock(nn.Module):
167 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
168 | super().__init__()
169 |
170 | self.attn = nn.MultiheadAttention(d_model, n_head)
171 | self.ln_1 = LayerNorm(d_model)
172 | self.mlp = nn.Sequential(OrderedDict([
173 | ("c_fc", nn.Linear(d_model, d_model * 4)),
174 | ("gelu", QuickGELU()),
175 | ("c_proj", nn.Linear(d_model * 4, d_model))
176 | ]))
177 | self.ln_2 = LayerNorm(d_model)
178 | self.attn_mask = attn_mask
179 |
180 | def attention(self, x: torch.Tensor):
181 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
182 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
183 |
184 | def forward(self, x: torch.Tensor):
185 | x = x + self.attention(self.ln_1(x))
186 | x = x + self.mlp(self.ln_2(x))
187 | return x
188 |
189 |
190 | class Transformer(nn.Module):
191 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
192 | super().__init__()
193 | self.width = width
194 | self.layers = layers
195 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
196 |
197 | def forward(self, x: torch.Tensor):
198 | return self.resblocks(x)
199 |
200 |
201 | class VisionTransformer(nn.Module):
202 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
203 | super().__init__()
204 | self.input_resolution = input_resolution
205 | self.output_dim = output_dim
206 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
207 |
208 | scale = width ** -0.5
209 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
210 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
211 | self.ln_pre = LayerNorm(width)
212 |
213 | self.transformer = Transformer(width, layers, heads)
214 |
215 | self.ln_post = LayerNorm(width)
216 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
217 |
218 | def forward(self, x: torch.Tensor):
219 | x = self.conv1(x) # shape = [*, width, grid, grid]
220 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
221 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
222 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
223 | x = x + self.positional_embedding.to(x.dtype)
224 | x = self.ln_pre(x)
225 |
226 | x = x.permute(1, 0, 2) # NLD -> LND
227 | x = self.transformer(x)
228 | x = x.permute(1, 0, 2) # LND -> NLD
229 | x = self.ln_post(x)
230 |
231 | if self.proj is not None:
232 | x = x @ self.proj
233 | return x.permute(1, 0, 2)  # NLD -> LND; keep features for all tokens (CALIP uses the spatial patch features), unlike the original CLIP ViT, which returns only the projected class token
234 |
235 |
236 | class CLIP(nn.Module):
237 | def __init__(self,
238 | embed_dim: int,
239 | # vision
240 | image_resolution: int,
241 | vision_layers: Union[Tuple[int, int, int, int], int],
242 | vision_width: int,
243 | vision_patch_size: int,
244 | # text
245 | context_length: int,
246 | vocab_size: int,
247 | transformer_width: int,
248 | transformer_heads: int,
249 | transformer_layers: int
250 | ):
251 | super().__init__()
252 |
253 | self.context_length = context_length
254 |
255 | if isinstance(vision_layers, (tuple, list)):
256 | vision_heads = vision_width * 32 // 64
257 | self.visual = ModifiedResNet(
258 | layers=vision_layers,
259 | output_dim=embed_dim,
260 | heads=vision_heads,
261 | input_resolution=image_resolution,
262 | width=vision_width
263 | )
264 | else:
265 | vision_heads = vision_width // 64
266 | self.visual = VisionTransformer(
267 | input_resolution=image_resolution,
268 | patch_size=vision_patch_size,
269 | width=vision_width,
270 | layers=vision_layers,
271 | heads=vision_heads,
272 | output_dim=embed_dim
273 | )
274 |
275 | self.transformer = Transformer(
276 | width=transformer_width,
277 | layers=transformer_layers,
278 | heads=transformer_heads,
279 | attn_mask=self.build_attention_mask()
280 | )
281 |
282 | self.vocab_size = vocab_size
283 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
284 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
285 | self.ln_final = LayerNorm(transformer_width)
286 |
287 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
288 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
289 |
290 | self.initialize_parameters()
291 |
292 | def initialize_parameters(self):
293 | nn.init.normal_(self.token_embedding.weight, std=0.02)
294 | nn.init.normal_(self.positional_embedding, std=0.01)
295 |
296 | if isinstance(self.visual, ModifiedResNet):
297 | if self.visual.attnpool is not None:
298 | std = self.visual.attnpool.c_proj.in_features ** -0.5
299 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
300 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
301 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
302 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
303 |
304 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
305 | for name, param in resnet_block.named_parameters():
306 | if name.endswith("bn3.weight"):
307 | nn.init.zeros_(param)
308 |
309 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
310 | attn_std = self.transformer.width ** -0.5
311 | fc_std = (2 * self.transformer.width) ** -0.5
312 | for block in self.transformer.resblocks:
313 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
314 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
315 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
316 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
317 |
318 | if self.text_projection is not None:
319 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
320 |
321 | def build_attention_mask(self):
322 | # lazily create causal attention mask, with full attention between the vision tokens
323 | # pytorch uses additive attention mask; fill with -inf
324 | mask = torch.empty(self.context_length, self.context_length)
325 | mask.fill_(float("-inf"))
326 | mask.triu_(1) # zero out the lower diagonal
327 | return mask
328 |
329 | @property
330 | def dtype(self):
331 | return self.visual.conv1.weight.dtype
332 |
333 | def encode_image(self, image):
334 | return self.visual(image.type(self.dtype))
335 |
336 | def encode_text(self, text):
337 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
338 |
339 | x = x + self.positional_embedding.type(self.dtype)
340 | x = x.permute(1, 0, 2) # NLD -> LND
341 | x = self.transformer(x)
342 | x = x.permute(1, 0, 2) # LND -> NLD
343 | x = self.ln_final(x).type(self.dtype)
344 |
345 | # x.shape = [batch_size, n_ctx, transformer.width]
346 | # take features from the eot embedding (eot_token is the highest number in each sequence)
347 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
348 |
349 | return x
350 |
351 | def forward(self, image, text):
352 | image_features = self.encode_image(image)
353 | text_features = self.encode_text(text)
354 |
355 | # normalized features
356 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
357 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
358 |
359 | # cosine similarity as logits
360 | logit_scale = self.logit_scale.exp()
361 | logits_per_image = logit_scale * image_features @ text_features.t()
362 | logits_per_text = logits_per_image.t()
363 |
364 | # shape = [global_batch_size, global_batch_size]
365 | return logits_per_image, logits_per_text
366 |
367 |
368 | def convert_weights(model: nn.Module):
369 | """Convert applicable model parameters to fp16"""
370 |
371 | def _convert_weights_to_fp16(l):
372 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
373 | l.weight.data = l.weight.data.half()
374 | if l.bias is not None:
375 | l.bias.data = l.bias.data.half()
376 |
377 | if isinstance(l, nn.MultiheadAttention):
378 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
379 | tensor = getattr(l, attr)
380 | if tensor is not None:
381 | tensor.data = tensor.data.half()
382 |
383 | for name in ["text_projection", "proj"]:
384 | if hasattr(l, name):
385 | attr = getattr(l, name)
386 | if attr is not None:
387 | attr.data = attr.data.half()
388 |
389 | model.apply(_convert_weights_to_fp16)
390 |
391 |
392 | def build_model(state_dict: dict):
393 | vit = "visual.proj" in state_dict
394 |
395 | if vit:
396 | vision_width = state_dict["visual.conv1.weight"].shape[0]
397 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
398 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
399 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
400 | image_resolution = vision_patch_size * grid_size
401 | else:
402 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
403 | vision_layers = tuple(counts)
404 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
405 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
406 | vision_patch_size = None
407 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
408 | image_resolution = output_width * 32
409 |
410 | embed_dim = state_dict["text_projection"].shape[1]
411 | context_length = state_dict["positional_embedding"].shape[0]
412 | vocab_size = state_dict["token_embedding.weight"].shape[0]
413 | transformer_width = state_dict["ln_final.weight"].shape[0]
414 | transformer_heads = transformer_width // 64
415 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
416 |
417 | model = CLIP(
418 | embed_dim,
419 | image_resolution, vision_layers, vision_width, vision_patch_size,
420 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
421 | )
422 |
423 | for key in ["input_resolution", "context_length", "vocab_size"]:
424 | if key in state_dict:
425 | del state_dict[key]
426 |
427 | convert_weights(model)
428 | model.load_state_dict(state_dict)
429 | return model.eval()
430 |
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 |             except ValueError:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
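The `bytes_to_unicode` docstring above explains why BPE runs over a 256-entry byte-to-unicode table rather than raw bytes, and `encode`/`decode` wrap that table with the BPE merges. A small usage sketch, assuming the packaged `bpe_simple_vocab_16e6.txt.gz` is in place; the printed ids are illustrative, since they depend on that vocabulary.

```
from clip.simple_tokenizer import SimpleTokenizer, bytes_to_unicode

# The byte-to-unicode table is a bijection over all 256 byte values onto
# printable unicode characters, so BPE never sees raw whitespace/control bytes.
b2u = bytes_to_unicode()
assert len(b2u) == 256 and len(set(b2u.values())) == 256

# Round-trip sketch; clip.tokenize() is what adds <|startoftext|>/<|endoftext|>.
tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a dog")
print(ids)
print(tokenizer.decode(ids))  # "a photo of a dog " (trailing space from the </w> marker)
```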
/configs/imagenet.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path/dataset_name ------
2 | data_root: '/data0/pgao/coop/'
3 | dataset: 'imagenet'
4 |
5 | # ------ Basic Config ------
6 | backbone: 'RN101'
7 |
8 | # ------ Load Cache and Features ------
9 | load_cache: False
10 | # load_cache: True
11 |
12 | # ------ Hyperparameters ------
13 | search: True
14 | # search: False
15 |
16 | beta2: 2.0
17 | beta3: 0.1
--------------------------------------------------------------------------------
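A sketch of how this config is consumed: `run_imagenet.py` loads the YAML and, when `search` is enabled, expands `beta2` and `beta3` into 200-point grids over [0.001, beta2] and [0.001, beta3]. The snippet assumes it is run from the repo root so `configs/imagenet.yaml` resolves.

```
import yaml

# safe_load suffices here; the script itself uses yaml.load(..., Loader=yaml.Loader).
with open('configs/imagenet.yaml') as f:
    cfg = yaml.safe_load(f)

beta2_list = [i * (cfg['beta2'] - 0.001) / 200 + 0.001 for i in range(200)]
beta3_list = [i * (cfg['beta3'] - 0.001) / 200 + 0.001 for i in range(200)]
print(len(beta2_list), beta2_list[0], beta2_list[-1])  # 200 0.001 ~1.99
```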
/requirements.txt:
--------------------------------------------------------------------------------
1 | flake8==3.7.9
2 | yapf==0.29.0
3 | isort==4.3.21
4 | yacs
5 | gdown
6 | tb-nightly
7 | future
8 | scipy
9 | scikit-learn
10 | tqdm
11 | ftfy
12 | regex
13 | wilds==1.2.2
14 | tabulate
15 | chardet
--------------------------------------------------------------------------------
/run_imagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import clip
4 | import torchvision
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | import argparse, os, yaml
9 | from class_template import TEMPLATE, CLASS_NAME
10 | from utils import accuracy, text_encode
11 |
12 | def get_arguments():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--config', dest='config', help='settings of CALIP in yaml format')
15 | args = parser.parse_args()
16 | print(args)
17 | return args
18 |
19 | def main(cfg):
20 | backbone = cfg['backbone']
21 | total_feat_path = os.path.join('cache', 'total_features', backbone)
22 | label_path = os.path.join('cache', 'label', backbone)
23 | os.makedirs(total_feat_path, exist_ok=True)
24 | os.makedirs(label_path, exist_ok=True)
25 |
26 | clip.available_models()
27 | model, preprocess = clip.load(backbone)
28 | model.eval()
29 |
30 | print(f"Loading {cfg['dataset']} and templates for CALIP: {len(CLASS_NAME[cfg['dataset']])} classes, {len(TEMPLATE[cfg['dataset']])} templates")
31 | dataset = torchvision.datasets.ImageNet(cfg['data_root'] + cfg['dataset'], split='val', transform=preprocess)
32 | loader = torch.utils.data.DataLoader(dataset, batch_size=128, num_workers=8, shuffle=False)
33 |
34 | print('Encoding text features...')
35 | feat_t = text_encode(CLASS_NAME[cfg['dataset']], TEMPLATE[cfg['dataset']], model)
36 |     print('Finished encoding text features.')
37 |
38 | if cfg['load_cache']:
39 | print('Loading cached image features and labels from ./cache/...')
40 | total_features = torch.load(total_feat_path + '/' + cfg['dataset'] + '.pt')
41 | labels = torch.load(label_path + '/' + cfg['dataset'] + '.pt')
42 | else:
43 |         print('No cached features or labels found; encoding image features with CLIP...')
44 | total_features = []
45 | labels = []
46 | with torch.no_grad():
47 | for i, (images, label) in enumerate(tqdm(loader)):
48 | images = images.cuda()
49 | label = label.cuda()
50 | features = model.encode_image(images)
51 |
52 | features = features.permute(1, 0, 2)
53 | features /= features.norm(dim=-1, keepdim=True)
54 |
55 | total_features.append(features)
56 | labels.append(label)
57 |
58 | total_features = torch.cat(total_features, dim=0)
59 | labels = torch.cat(labels, dim=0)
60 | torch.save(total_features, total_feat_path + '/' + cfg['dataset'] + '.pt')
61 | torch.save(labels, label_path + '/' + cfg['dataset'] + '.pt')
62 |
63 | img_global_feat = total_features[:, 0, :]
64 |     img_spatial_feat = total_features[:, 1:, :]
65 | img_spatial_feat = img_spatial_feat.permute(0, 2, 1)
66 |
67 | # ------------------------------------------ CLIP Zero-shot ------------------------------------------
68 | logits = 100. * img_global_feat @ feat_t
69 | acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
70 | print(f"CLIP zero-shot accuracy: {acc:.2f}")
71 |
72 | # ------------------------------------------ CALIP Zero-shot -----------------------------------------
73 | def get_logits():
74 | with torch.no_grad():
75 | logits1 = []
76 | logits2 = []
77 | for i, feat_v in enumerate(tqdm(img_spatial_feat)):
78 | A_weight = torch.matmul(feat_v.permute(1, 0), feat_t) * 2
79 | A_weight1 = F.softmax(A_weight, dim=0)
80 | A_weight2 = F.softmax(A_weight, dim=1)
81 |
82 | feat_t_a = torch.matmul(feat_v, A_weight1)
83 | feat_v_a = torch.matmul(A_weight2, feat_t.permute(1, 0))
84 | feat_v_a = feat_v_a.mean(0) + feat_v_a.max(0)[0]
85 |
86 | l1 = 100. * img_global_feat[i] @ feat_t_a
87 | l2 = 100. * feat_v_a @ feat_t
88 | logits1.append(l1.unsqueeze(0))
89 | logits2.append(l2.unsqueeze(0))
90 |
91 | logits1 = torch.cat(logits1, dim=0)
92 | logits2 = torch.cat(logits2, dim=0)
93 | return logits1, logits2
94 |
95 | if cfg['search']:
96 | logits1, logits2 = get_logits()
97 | beta2_list = [i * (cfg['beta2'] - 0.001) / 200 + 0.001 for i in range(200)]
98 | beta3_list = [i * (cfg['beta3'] - 0.001) / 200 + 0.001 for i in range(200)]
99 | print('-' * 20)
100 |         print('Starting search...')
101 | print(' beta1 = 1.0')
102 | print(' beta2 searching range: [0.001, ' + str(cfg['beta2']) + ']')
103 | print(' beta3 searching range: [0.001, ' + str(cfg['beta3']) + ']')
104 | print('-' * 20)
105 |
106 | best_acc = 0.
107 | best_beta2 = 0.
108 | best_beta3 = 0.
109 |
110 | for beta2 in beta2_list:
111 | for beta3 in beta3_list:
112 | logits = 100. * img_global_feat @ feat_t
113 | logits = logits + logits1 * beta2 + logits2 * beta3
114 | acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
115 |
116 | if acc > best_acc:
117 | print('New best setting, beta1: {:.4f}; beta2: {:.4f}; beta3: {:.4f}; Acc: {:.2f}'.format(1, beta2, beta3, acc))
118 | best_acc = acc
119 | best_beta2 = beta2
120 | best_beta3 = beta3
121 |
122 |         print(f"Finished searching {cfg['dataset']} on backbone {cfg['backbone']}. Final Acc: {best_acc:.2f}")
123 |
124 | if __name__ == '__main__':
125 | args = get_arguments()
126 |     assert os.path.exists(args.config)
127 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
128 | print(cfg)
129 | main(cfg)
130 |
--------------------------------------------------------------------------------
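For reference, the searched CALIP score in the script above is a weighted sum of three terms: the plain CLIP logits, the text branch re-weighted by visual attention (`logits1`), and the visual branch re-weighted by textual attention (`logits2`), with `beta1` fixed to 1.0. A standalone sketch with random stand-in tensors (shapes assumed: N images, C classes):

```
import torch

# Random stand-ins for the three logit terms computed in run_imagenet.py.
N, C = 8, 1000
logits_clip = torch.randn(N, C)  # 100. * img_global_feat @ feat_t
logits1 = torch.randn(N, C)      # attention-updated text features vs. global image feature
logits2 = torch.randn(N, C)      # attention-updated visual features vs. text features
beta2, beta3 = 1.0, 0.05         # candidate values from the search grids; beta1 = 1.0

logits = logits_clip + beta2 * logits1 + beta3 * logits2
print(logits.shape)  # torch.Size([8, 1000])
```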
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 |
4 | def text_encode(classnames, templates, model):
5 | with torch.no_grad():
6 | text_feat = []
7 | for classname in classnames:
8 | texts = [template.format(classname) for template in templates] # format with class
9 | texts = clip.tokenize(texts).cuda() # tokenize
10 | class_embeddings = model.encode_text(texts) # embed with text encoder
11 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
12 | class_embedding = class_embeddings.mean(dim=0)
13 | class_embedding /= class_embedding.norm()
14 | text_feat.append(class_embedding)
15 | text_feat = torch.stack(text_feat, dim=1).cuda()
16 | return text_feat
17 |
18 | def accuracy(output, label, n, topk=(1, 5)):
19 | pred = output.topk(max(topk), 1, True, True)[1].t()
20 | correct = pred.eq(label.view(1, -1).expand_as(pred))
21 | return (100 * float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk)
22 |
23 |
--------------------------------------------------------------------------------
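A small sketch of the two helpers above: `text_encode` builds one prompt-ensembled, L2-normalized embedding per class (it needs a loaded CLIP model and a CUDA device, so it is not exercised here), while `accuracy` returns a generator of top-k accuracies, which is why callers unpack it as `acc, _ = accuracy(...)`. The tensors below are random stand-ins, and the snippet assumes the repo's `utils.py` is importable from the working directory.

```
import torch
from utils import accuracy

# 10 samples over 5 classes; top-5 accuracy is trivially 100% in this toy setup.
output = torch.randn(10, 5)
label = torch.randint(0, 5, (10,))
top1, top5 = accuracy(output, label, n=output.size(0))
print(f"top-1: {top1:.2f}%, top-5: {top5:.2f}%")
```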