├── DATASET.md
├── README.md
├── calip.png
├── class_template.py
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   └── imagenet.yaml
├── requirements.txt
├── run_imagenet.py
└── utils.py

/DATASET.md:
--------------------------------------------------------------------------------
1 | # How to install datasets
2 | 
3 | We suggest putting all datasets under the same folder (say `$DATA`) for easier management, and following the instructions below to organize them so that the source code does not need to be modified. The file structure looks like
4 | 
5 | ```
6 | $DATA/
7 | |–– imagenet/
8 | |–– caltech-101/
9 | |–– oxford_pets/
10 | |–– stanford_cars/
11 | ```
12 | 
13 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate downloads.
14 | 
15 | Dataset list:
16 | - [ImageNet](#imagenet)
17 | - [Caltech101](#caltech101)
18 | - [OxfordPets](#oxfordpets)
19 | - [StanfordCars](#stanfordcars)
20 | - [Flowers102](#flowers102)
21 | - [Food101](#food101)
22 | - [FGVCAircraft](#fgvcaircraft)
23 | - [SUN397](#sun397)
24 | - [DTD](#dtd)
25 | - [EuroSAT](#eurosat)
26 | - [UCF101](#ucf101)
27 | 
28 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we utilize CoOp-style train/val/test splits for all datasets except ImageNet, where the validation set is used as the test set.
29 | 
30 | ### ImageNet
31 | - Create a folder named `imagenet/` under `$DATA`.
32 | - Create `images/` under `imagenet/`.
33 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like
34 | ```
35 | imagenet/
36 | |–– images/
37 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc.
38 | | |–– val/
39 | ```
40 | - If you have already downloaded the ImageNet dataset, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`.
41 | - Download `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).
42 | 
43 | ### Caltech101
44 | - Create a folder named `caltech-101/` under `$DATA`.
45 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`.
46 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.
47 | 
48 | The directory structure should look like
49 | ```
50 | caltech-101/
51 | |–– 101_ObjectCategories/
52 | |–– split_zhou_Caltech101.json
53 | ```
54 | 
55 | ### OxfordPets
56 | - Create a folder named `oxford_pets/` under `$DATA`.
57 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz.
58 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.
59 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing).
60 | 61 | The directory structure should look like 62 | ``` 63 | oxford_pets/ 64 | |–– images/ 65 | |–– annotations/ 66 | |–– split_zhou_OxfordPets.json 67 | ``` 68 | 69 | ### StanfordCars 70 | - Create a folder named `stanford_cars/` under `$DATA`. 71 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 72 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 73 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 74 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 75 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 76 | 77 | The directory structure should look like 78 | ``` 79 | stanford_cars/ 80 | |–– cars_test\ 81 | |–– cars_test_annos_withlabels.mat 82 | |–– cars_train\ 83 | |–– devkit\ 84 | |–– split_zhou_StanfordCars.json 85 | ``` 86 | 87 | ### Flowers102 88 | - Create a folder named `oxford_flowers/` under `$DATA`. 89 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 90 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 91 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 92 | 93 | The directory structure should look like 94 | ``` 95 | oxford_flowers/ 96 | |–– cat_to_name.json 97 | |–– imagelabels.mat 98 | |–– jpg/ 99 | |–– split_zhou_OxfordFlowers.json 100 | ``` 101 | 102 | ### Food101 103 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 104 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 105 | 106 | The directory structure should look like 107 | ``` 108 | food-101/ 109 | |–– images/ 110 | |–– license_agreement.txt 111 | |–– meta/ 112 | |–– README.txt 113 | |–– split_zhou_Food101.json 114 | ``` 115 | 116 | ### FGVCAircraft 117 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 118 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 119 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 120 | 121 | The directory structure should look like 122 | ``` 123 | fgvc_aircraft/ 124 | |–– images/ 125 | |–– ... # a bunch of .txt files 126 | ``` 127 | 128 | ### SUN397 129 | - Create a folder named `sun397/` under `$DATA`. 130 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 131 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 132 | - Extract these files under `$DATA/sun397/`. 133 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 134 | 135 | The directory structure should look like 136 | ``` 137 | sun397/ 138 | |–– SUN397/ 139 | |–– split_zhou_SUN397.json 140 | |–– ... # a bunch of .txt files 141 | ``` 142 | 143 | ### DTD 144 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. 
This should lead to `$DATA/dtd/`.
145 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).
146 | 
147 | The directory structure should look like
148 | ```
149 | dtd/
150 | |–– images/
151 | |–– imdb/
152 | |–– labels/
153 | |–– split_zhou_DescribableTextures.json
154 | ```
155 | 
156 | ### EuroSAT
157 | - Create a folder named `eurosat/` under `$DATA`.
158 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.
159 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).
160 | 
161 | The directory structure should look like
162 | ```
163 | eurosat/
164 | |–– 2750/
165 | |–– split_zhou_EuroSAT.json
166 | ```
167 | 
168 | ### UCF101
169 | - Create a folder named `ucf101/` under `$DATA`.
170 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle frames of the videos.
171 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).
172 | 
173 | The directory structure should look like
174 | ```
175 | ucf101/
176 | |–– UCF-101-midframes/
177 | |–– split_zhou_UCF101.json
178 | ```
179 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CALIP: Zero-Shot Enhancement of CLIP with Parameter-free Attention
2 | Official implementation of ['CALIP: Zero-Shot Enhancement of CLIP with Parameter-free Attention'](https://arxiv.org/pdf/2209.14169.pdf).
3 | 
4 | The paper has been accepted by **AAAI 2023**.
5 | 
6 | ## Introduction
7 | CALIP is a free-lunch enhancement method that boosts CLIP’s zero-shot performance via a parameter-free attention module. Specifically, we guide visual and textual representations to interact with each other and explore cross-modal informative features via attention. As the pre-training has largely reduced the embedding distances between the two modalities, we discard all learnable parameters in the attention and bidirectionally update the multi-modal features, enabling the whole process to be parameter-free and training-free. In this way, the images are blended with textual-aware signals and the text representations become visual-guided for better adaptive zero-shot alignment. We evaluate CALIP on various benchmarks of 14 datasets for both 2D image and 3D point cloud few-shot classification, showing consistent zero-shot performance improvement over CLIP. Based on that, we further insert a small number of linear layers in CALIP’s attention module and verify our robustness under the few-shot settings, which also achieves leading performance compared to existing methods.
8 | 
9 | 10 |
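To make the description above concrete, here is a minimal sketch of the parameter-free attention for a single image. It assumes pre-computed, L2-normalized CLIP features; the function and tensor names are placeholders for illustration, and the scaling applied to the attention weights in the released code is omitted (see `get_logits` in `run_imagenet.py` for the actual implementation):

```python
import torch
import torch.nn.functional as F

def calip_attention(f_v, f_t, f_global):
    """Parameter-free cross-modal attention (illustrative sketch).

    f_v:      [HW, C] spatial visual features of one image (L2-normalized)
    f_t:      [K, C]  text features, one per class (L2-normalized)
    f_global: [C]     global visual feature of the same image (L2-normalized)
    """
    # Attention weights between every spatial location and every class text
    attn = f_v @ f_t.t()                              # [HW, K]

    # Text branch attends to spatial locations (softmax over HW)
    f_t_a = F.softmax(attn, dim=0).t() @ f_v          # [K, C] visual-guided text features

    # Visual branch attends to class texts (softmax over K)
    f_v_a = F.softmax(attn, dim=1) @ f_t              # [HW, C] text-aware visual features
    f_v_a = f_v_a.mean(dim=0) + f_v_a.max(dim=0)[0]   # spatial pooling -> [C]

    # Two extra sets of classification logits, later blended with CLIP's own logits
    logits_t_a = 100. * f_global @ f_t_a.t()          # [K]
    logits_v_a = 100. * f_v_a @ f_t.t()               # [K]
    return logits_t_a, logits_v_a
```

Note that no learnable parameters are created anywhere in this block, which is what keeps the enhancement training-free.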
11 | 
12 | ## Requirements
13 | ### Installation
14 | Create a conda environment and install dependencies:
15 | ```bash
16 | git clone https://github.com/ZiyuGuo99/CALIP.git
17 | cd CALIP
18 | 
19 | conda create -n calip python=3.7
20 | conda activate calip
21 | 
22 | # Install the appropriate versions of torch and torchvision
23 | conda install pytorch torchvision cudatoolkit
24 | 
25 | pip install -r requirements.txt
26 | ```
27 | 
28 | ### Dataset
29 | Follow [DATASET.md](https://github.com/ZiyuGuo99/CALIP/blob/main/DATASET.md) to install ImageNet and the other 10 datasets, prepared in the same way as CoOp.
30 | 
31 | ## Get Started
32 | ### Configs
33 | The configuration for running on each dataset can be modified in `configs/*.yaml`. Fill in `data_root` with your data path. You can also edit the `backbone` and `search` settings as needed, and feel free to adjust `beta2` and `beta3` for a wider or finer search range (a short sketch of how these betas blend the logits is given at the end of this README).
34 | 
35 | Note that `load_cache` defaults to `False` for the first run, which encodes and stores the image features and labels under `./cache/`. Set it to `True` on later runs for faster hyperparameter tuning.
36 | 
37 | ### Running
38 | For the ImageNet dataset:
39 | ```bash
40 | CUDA_VISIBLE_DEVICES=0 python run_imagenet.py --config configs/imagenet.yaml
41 | ```
42 | 
43 | For the other 10 datasets:
44 | TODO...
45 | 
46 | ## Acknowledgement
47 | This repo benefits from [CLIP](https://github.com/openai/CLIP), [CoOp](https://github.com/KaiyangZhou/Dassl.pytorch), [CLIP-Adapter](https://github.com/gaopengcuhk/CLIP-Adapter) and [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter). Thanks for their wonderful work.
48 | 
49 | ## Citation
50 | ```bibtex
51 | @article{guo2022calip,
52 |   title={Calip: Zero-shot enhancement of clip with parameter-free attention},
53 |   author={Guo, Ziyu and Zhang, Renrui and Qiu, Longtian and Ma, Xianzheng and Miao, Xupeng and He, Xuming and Cui, Bin},
54 |   journal={arXiv preprint arXiv:2209.14169},
55 |   year={2022}
56 | }
57 | ```
58 | 
59 | ## Contact
60 | If you have any questions about this project, please feel free to contact 2101210573@pku.edu.cn.
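As referenced in the Configs section above, `beta2` and `beta3` only control how the two attention-derived logits are blended with CLIP's original zero-shot logits before accuracy is measured. The sketch below illustrates that fusion and the grid search; the names `fuse_logits` and `search_betas` are illustrative, and the grid built with `torch.linspace` is a simplification of the one in `run_imagenet.py` (which fixes `beta1 = 1.0` and sweeps 200 values per axis starting at 0.001):

```python
import torch

def fuse_logits(clip_logits, logits_t_a, logits_v_a, beta2, beta3):
    """Blend CLIP zero-shot logits with the two attention-derived logits."""
    # clip_logits: [N, K] from the global image features and original text features
    # logits_t_a:  [N, K] using visual-guided (attended) text features
    # logits_v_a:  [N, K] using text-aware (attended) spatial visual features
    return clip_logits + beta2 * logits_t_a + beta3 * logits_v_a

def search_betas(clip_logits, logits_t_a, logits_v_a, labels,
                 beta2_max=2.0, beta3_max=0.1, steps=200):
    """Grid-search beta2/beta3 on pre-computed logits (illustrative only)."""
    best_acc, best_b2, best_b3 = 0.0, None, None
    for beta2 in torch.linspace(0.001, beta2_max, steps):
        for beta3 in torch.linspace(0.001, beta3_max, steps):
            logits = fuse_logits(clip_logits, logits_t_a, logits_v_a, beta2, beta3)
            acc = (logits.argmax(dim=-1) == labels).float().mean().item() * 100
            if acc > best_acc:
                best_acc, best_b2, best_b3 = acc, beta2.item(), beta3.item()
    return best_acc, best_b2, best_b3
```

Because the attention logits are computed only once per run, widening the `beta2`/`beta3` ranges in `configs/*.yaml` just lengthens this loop; no images need to be re-encoded.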
61 | -------------------------------------------------------------------------------- /calip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZiyuGuo99/CALIP/83e21b6b985934dce6e8414c26300671b4825841/calip.png -------------------------------------------------------------------------------- /class_template.py: -------------------------------------------------------------------------------- 1 | 2 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 3 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 4 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 5 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 6 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 7 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 8 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 9 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 10 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 11 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 12 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 13 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 14 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 15 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 16 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 17 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 18 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 19 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 20 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 21 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 22 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 23 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 24 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 25 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 26 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 27 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 28 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 29 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 30 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 31 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 32 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 33 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 34 | "Bedlington Terrier", "Border Terrier", "Kerry Blue 
Terrier", "Irish Terrier", 35 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 36 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 37 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 38 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 39 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 40 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 41 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 42 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 43 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 44 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 45 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 46 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 47 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 48 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 49 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 50 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 51 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 52 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 53 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 54 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 55 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 56 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 57 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 58 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 59 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 60 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 61 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 62 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 63 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 64 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 65 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 66 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 67 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 68 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 69 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 70 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 71 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 72 | "ring-tailed lemur", "indri", "Asian elephant", "African bush 
elephant", "red panda", 73 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 74 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 75 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 76 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 77 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 78 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 79 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 80 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 81 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 82 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 83 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 84 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 85 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 86 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 87 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 88 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 89 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 90 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 91 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 92 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 93 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 94 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 95 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 96 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 97 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 98 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 99 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 100 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 101 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 102 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 103 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 104 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 105 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 106 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 107 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 108 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 109 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", 
"library", 110 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 111 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 112 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 113 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 114 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 115 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 116 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 117 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 118 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 119 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 120 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 121 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 122 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 123 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 124 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 125 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 126 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 127 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 128 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 129 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 130 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 131 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 132 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 133 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 134 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 135 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 136 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 137 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 138 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 139 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 140 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 141 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 142 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 143 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 144 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 145 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 146 | "toaster", "tobacco shop", "toilet seat", "torch", 
"totem pole", "tow truck", "toy store", 147 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 148 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 149 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 150 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 151 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 152 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 153 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 154 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 155 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 156 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 157 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 158 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 159 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 160 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 161 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 162 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 163 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 164 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 165 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 166 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 167 | imagenet_templates = [ 168 | "itap of a {}.", 169 | "a bad photo of the {}.", 170 | "a origami {}.", 171 | "a photo of the large {}.", 172 | "a {} in a video game.", 173 | "art of the {}.", 174 | "a photo of the small {}.", 175 | ] 176 | TEMPLATE = { 177 | 'imagenet' : imagenet_templates, 178 | } 179 | CLASS_NAME = { 180 | 'imagenet' : imagenet_classes, 181 | } -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZiyuGuo99/CALIP/83e21b6b985934dce6e8414c26300671b4825841/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = 
InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | # if torch.__version__.split(".") < ["1", "7", "1"]: 23 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP 
models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] 
* x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | return x 90 | 91 | 92 | class ModifiedResNet(nn.Module): 93 | """ 94 | A ResNet class that is similar to torchvision's but contains the following changes: 95 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 96 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 97 | - The final pooling layer is a QKV attention instead of an average pool 98 | """ 99 | 100 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 101 | super().__init__() 102 | self.output_dim = output_dim 103 | self.input_resolution = input_resolution 104 | 105 | # the 3-layer stem 106 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(width // 2) 108 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(width // 2) 110 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(width) 112 | self.avgpool = nn.AvgPool2d(2) 113 | self.relu = nn.ReLU(inplace=True) 114 | 115 | # residual layers 116 | self._inplanes = width # this is a *mutable* variable used during construction 117 | self.layer1 = self._make_layer(width, layers[0]) 118 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 119 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 120 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 121 | 122 | embed_dim = width * 32 # the ResNet feature dimension 123 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 124 | 125 | def _make_layer(self, planes, blocks, stride=1): 126 | layers = [Bottleneck(self._inplanes, planes, stride)] 127 | 128 | self._inplanes = planes * Bottleneck.expansion 129 | for _ in range(1, blocks): 130 | layers.append(Bottleneck(self._inplanes, planes)) 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | def stem(x): 136 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 137 | x = self.relu(bn(conv(x))) 138 | x = self.avgpool(x) 139 | return x 140 | 141 | x = x.type(self.conv1.weight.dtype) 142 | x = stem(x) 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | x = self.attnpool(x) 148 | 149 | return x 150 | 151 | 152 | class LayerNorm(nn.LayerNorm): 153 | """Subclass torch's LayerNorm to handle fp16.""" 154 | 155 | def forward(self, x: torch.Tensor): 156 | orig_type = x.dtype 157 | ret = super().forward(x.type(torch.float32)) 158 | return ret.type(orig_type) 159 | 160 | 161 | class 
QuickGELU(nn.Module): 162 | def forward(self, x: torch.Tensor): 163 | return x * torch.sigmoid(1.702 * x) 164 | 165 | 166 | class ResidualAttentionBlock(nn.Module): 167 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 168 | super().__init__() 169 | 170 | self.attn = nn.MultiheadAttention(d_model, n_head) 171 | self.ln_1 = LayerNorm(d_model) 172 | self.mlp = nn.Sequential(OrderedDict([ 173 | ("c_fc", nn.Linear(d_model, d_model * 4)), 174 | ("gelu", QuickGELU()), 175 | ("c_proj", nn.Linear(d_model * 4, d_model)) 176 | ])) 177 | self.ln_2 = LayerNorm(d_model) 178 | self.attn_mask = attn_mask 179 | 180 | def attention(self, x: torch.Tensor): 181 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 182 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 183 | 184 | def forward(self, x: torch.Tensor): 185 | x = x + self.attention(self.ln_1(x)) 186 | x = x + self.mlp(self.ln_2(x)) 187 | return x 188 | 189 | 190 | class Transformer(nn.Module): 191 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 192 | super().__init__() 193 | self.width = width 194 | self.layers = layers 195 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 196 | 197 | def forward(self, x: torch.Tensor): 198 | return self.resblocks(x) 199 | 200 | 201 | class VisionTransformer(nn.Module): 202 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 203 | super().__init__() 204 | self.input_resolution = input_resolution 205 | self.output_dim = output_dim 206 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 207 | 208 | scale = width ** -0.5 209 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 210 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 211 | self.ln_pre = LayerNorm(width) 212 | 213 | self.transformer = Transformer(width, layers, heads) 214 | 215 | self.ln_post = LayerNorm(width) 216 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 217 | 218 | def forward(self, x: torch.Tensor): 219 | x = self.conv1(x) # shape = [*, width, grid, grid] 220 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 221 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 222 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 223 | x = x + self.positional_embedding.to(x.dtype) 224 | x = self.ln_pre(x) 225 | 226 | x = x.permute(1, 0, 2) # NLD -> LND 227 | x = self.transformer(x) 228 | x = x.permute(1, 0, 2) # LND -> NLD 229 | x = self.ln_post(x) 230 | 231 | if self.proj is not None: 232 | x = x @ self.proj 233 | return x.permute(1, 0, 2) 234 | 235 | 236 | class CLIP(nn.Module): 237 | def __init__(self, 238 | embed_dim: int, 239 | # vision 240 | image_resolution: int, 241 | vision_layers: Union[Tuple[int, int, int, int], int], 242 | vision_width: int, 243 | vision_patch_size: int, 244 | # text 245 | context_length: int, 246 | vocab_size: int, 247 | transformer_width: int, 248 | transformer_heads: int, 249 | transformer_layers: int 250 | ): 251 | super().__init__() 252 | 253 | self.context_length = context_length 254 | 255 | if isinstance(vision_layers, (tuple, list)): 256 | 
vision_heads = vision_width * 32 // 64 257 | self.visual = ModifiedResNet( 258 | layers=vision_layers, 259 | output_dim=embed_dim, 260 | heads=vision_heads, 261 | input_resolution=image_resolution, 262 | width=vision_width 263 | ) 264 | else: 265 | vision_heads = vision_width // 64 266 | self.visual = VisionTransformer( 267 | input_resolution=image_resolution, 268 | patch_size=vision_patch_size, 269 | width=vision_width, 270 | layers=vision_layers, 271 | heads=vision_heads, 272 | output_dim=embed_dim 273 | ) 274 | 275 | self.transformer = Transformer( 276 | width=transformer_width, 277 | layers=transformer_layers, 278 | heads=transformer_heads, 279 | attn_mask=self.build_attention_mask() 280 | ) 281 | 282 | self.vocab_size = vocab_size 283 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 284 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 285 | self.ln_final = LayerNorm(transformer_width) 286 | 287 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 288 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 289 | 290 | self.initialize_parameters() 291 | 292 | def initialize_parameters(self): 293 | nn.init.normal_(self.token_embedding.weight, std=0.02) 294 | nn.init.normal_(self.positional_embedding, std=0.01) 295 | 296 | if isinstance(self.visual, ModifiedResNet): 297 | if self.visual.attnpool is not None: 298 | std = self.visual.attnpool.c_proj.in_features ** -0.5 299 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 300 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 301 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 302 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 303 | 304 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 305 | for name, param in resnet_block.named_parameters(): 306 | if name.endswith("bn3.weight"): 307 | nn.init.zeros_(param) 308 | 309 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 310 | attn_std = self.transformer.width ** -0.5 311 | fc_std = (2 * self.transformer.width) ** -0.5 312 | for block in self.transformer.resblocks: 313 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 314 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 315 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 316 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 317 | 318 | if self.text_projection is not None: 319 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 320 | 321 | def build_attention_mask(self): 322 | # lazily create causal attention mask, with full attention between the vision tokens 323 | # pytorch uses additive attention mask; fill with -inf 324 | mask = torch.empty(self.context_length, self.context_length) 325 | mask.fill_(float("-inf")) 326 | mask.triu_(1) # zero out the lower diagonal 327 | return mask 328 | 329 | @property 330 | def dtype(self): 331 | return self.visual.conv1.weight.dtype 332 | 333 | def encode_image(self, image): 334 | return self.visual(image.type(self.dtype)) 335 | 336 | def encode_text(self, text): 337 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 338 | 339 | x = x + self.positional_embedding.type(self.dtype) 340 | x = x.permute(1, 0, 2) # NLD -> LND 341 | x = self.transformer(x) 342 | x = x.permute(1, 0, 2) # LND -> NLD 343 | x = self.ln_final(x).type(self.dtype) 344 | 345 | # x.shape = [batch_size, 
n_ctx, transformer.width] 346 | # take features from the eot embedding (eot_token is the highest number in each sequence) 347 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 348 | 349 | return x 350 | 351 | def forward(self, image, text): 352 | image_features = self.encode_image(image) 353 | text_features = self.encode_text(text) 354 | 355 | # normalized features 356 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 357 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 358 | 359 | # cosine similarity as logits 360 | logit_scale = self.logit_scale.exp() 361 | logits_per_image = logit_scale * image_features @ text_features.t() 362 | logits_per_text = logits_per_image.t() 363 | 364 | # shape = [global_batch_size, global_batch_size] 365 | return logits_per_image, logits_per_text 366 | 367 | 368 | def convert_weights(model: nn.Module): 369 | """Convert applicable model parameters to fp16""" 370 | 371 | def _convert_weights_to_fp16(l): 372 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 373 | l.weight.data = l.weight.data.half() 374 | if l.bias is not None: 375 | l.bias.data = l.bias.data.half() 376 | 377 | if isinstance(l, nn.MultiheadAttention): 378 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 379 | tensor = getattr(l, attr) 380 | if tensor is not None: 381 | tensor.data = tensor.data.half() 382 | 383 | for name in ["text_projection", "proj"]: 384 | if hasattr(l, name): 385 | attr = getattr(l, name) 386 | if attr is not None: 387 | attr.data = attr.data.half() 388 | 389 | model.apply(_convert_weights_to_fp16) 390 | 391 | 392 | def build_model(state_dict: dict): 393 | vit = "visual.proj" in state_dict 394 | 395 | if vit: 396 | vision_width = state_dict["visual.conv1.weight"].shape[0] 397 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 398 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 399 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 400 | image_resolution = vision_patch_size * grid_size 401 | else: 402 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 403 | vision_layers = tuple(counts) 404 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 405 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 406 | vision_patch_size = None 407 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 408 | image_resolution = output_width * 32 409 | 410 | embed_dim = state_dict["text_projection"].shape[1] 411 | context_length = state_dict["positional_embedding"].shape[0] 412 | vocab_size = state_dict["token_embedding.weight"].shape[0] 413 | transformer_width = state_dict["ln_final.weight"].shape[0] 414 | transformer_heads = transformer_width // 64 415 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 416 | 417 | model = CLIP( 418 | embed_dim, 419 | image_resolution, vision_layers, vision_width, vision_patch_size, 420 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 421 | ) 422 | 423 | for key in ["input_resolution", "context_length", "vocab_size"]: 424 | if key in state_dict: 425 | del state_dict[key] 426 | 427 | convert_weights(model) 428 | 
model.load_state_dict(state_dict) 429 | return model.eval() 430 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 
99 |                     new_word.extend(word[i:j])
100 |                     i = j
101 |                 except:
102 |                     new_word.extend(word[i:])
103 |                     break
104 | 
105 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 |                     new_word.append(first+second)
107 |                     i += 2
108 |                 else:
109 |                     new_word.append(word[i])
110 |                     i += 1
111 |             new_word = tuple(new_word)
112 |             word = new_word
113 |             if len(word) == 1:
114 |                 break
115 |             else:
116 |                 pairs = get_pairs(word)
117 |         word = ' '.join(word)
118 |         self.cache[token] = word
119 |         return word
120 | 
121 |     def encode(self, text):
122 |         bpe_tokens = []
123 |         text = whitespace_clean(basic_clean(text)).lower()
124 |         for token in re.findall(self.pat, text):
125 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 |         return bpe_tokens
128 | 
129 |     def decode(self, tokens):
130 |         text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 |         return text
133 | 
--------------------------------------------------------------------------------
/configs/imagenet.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path/dataset_name ------
2 | data_root: '/data0/pgao/coop/'
3 | dataset: 'imagenet'
4 | 
5 | # ------ Basic Config ------
6 | backbone: 'RN101'
7 | 
8 | # ------ Load Cache and Features ------
9 | load_cache: False
10 | # load_cache: True
11 | 
12 | # ------ Hyperparameters ------
13 | search: True
14 | # search: False
15 | 
16 | beta2: 2.0
17 | beta3: 0.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flake8==3.7.9
2 | yapf==0.29.0
3 | isort==4.3.21
4 | yacs
5 | gdown
6 | tb-nightly
7 | future
8 | scipy
9 | scikit-learn
10 | tqdm
11 | ftfy
12 | regex
13 | wilds==1.2.2
14 | tabulate
15 | chardet
--------------------------------------------------------------------------------
/run_imagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import clip
4 | import torchvision
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | import argparse, os, yaml
9 | from class_template import TEMPLATE, CLASS_NAME
10 | from utils import accuracy, text_encode
11 | 
12 | def get_arguments():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument('--config', dest='config', help='settings of CALIP in yaml format')
15 |     args = parser.parse_args()
16 |     print(args)
17 |     return args
18 | 
19 | def main(cfg):
20 |     backbone = cfg['backbone']
21 |     total_feat_path = os.path.join('cache', 'total_features', backbone)
22 |     label_path = os.path.join('cache', 'label', backbone)
23 |     os.makedirs(total_feat_path, exist_ok=True)
24 |     os.makedirs(label_path, exist_ok=True)
25 | 
26 |     clip.available_models()
27 |     model, preprocess = clip.load(backbone)
28 |     model.eval()
29 | 
30 |     print(f"Loading {cfg['dataset']} and templates for CALIP: {len(CLASS_NAME[cfg['dataset']])} classes, {len(TEMPLATE[cfg['dataset']])} templates")
31 |     dataset = torchvision.datasets.ImageNet(cfg['data_root'] + cfg['dataset'], split='val', transform=preprocess)
32 |     loader = torch.utils.data.DataLoader(dataset, batch_size=128, num_workers=8, shuffle=False)
33 | 
34 |     print('Encoding text features...')
35 |     feat_t = text_encode(CLASS_NAME[cfg['dataset']], TEMPLATE[cfg['dataset']], model)
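    # feat_t has shape (embed_dim, num_classes): per-class prompt ensembles, L2-normalized (see utils.text_encode)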
36 |     print('Finished encoding text features.')
37 | 
38 |     if cfg['load_cache']:
39 |         print('Loading cached image features and labels from ./cache/...')
40 |         total_features = torch.load(total_feat_path + '/' + cfg['dataset'] + '.pt')
41 |         labels = torch.load(label_path + '/' + cfg['dataset'] + '.pt')
42 |     else:
43 |         print('No cached features and labels found, encoding image features with CLIP...')
44 |         total_features = []
45 |         labels = []
46 |         with torch.no_grad():
47 |             for i, (images, label) in enumerate(tqdm(loader)):
48 |                 images = images.cuda()
49 |                 label = label.cuda()
50 |                 features = model.encode_image(images)
51 | 
52 |                 features = features.permute(1, 0, 2)
53 |                 features /= features.norm(dim=-1, keepdim=True)
54 | 
55 |                 total_features.append(features)
56 |                 labels.append(label)
57 | 
58 |         total_features = torch.cat(total_features, dim=0)
59 |         labels = torch.cat(labels, dim=0)
60 |         torch.save(total_features, total_feat_path + '/' + cfg['dataset'] + '.pt')
61 |         torch.save(labels, label_path + '/' + cfg['dataset'] + '.pt')
62 | 
63 |     img_global_feat = total_features[:, 0, :]
64 |     img_spatial_feat = total_features[:, 1: , :]
65 |     img_spatial_feat = img_spatial_feat.permute(0, 2, 1)
66 | 
67 |     # ------------------------------------------ CLIP Zero-shot ------------------------------------------
68 |     logits = 100. * img_global_feat @ feat_t
69 |     acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
70 |     print(f"CLIP zero-shot accuracy: {acc:.2f}")
71 | 
72 |     # ------------------------------------------ CALIP Zero-shot -----------------------------------------
73 |     def get_logits():
74 |         with torch.no_grad():
75 |             logits1 = []
76 |             logits2 = []
77 |             for i, feat_v in enumerate(tqdm(img_spatial_feat)):
78 |                 A_weight = torch.matmul(feat_v.permute(1, 0), feat_t) * 2
79 |                 A_weight1 = F.softmax(A_weight, dim=0)
80 |                 A_weight2 = F.softmax(A_weight, dim=1)
81 | 
82 |                 feat_t_a = torch.matmul(feat_v, A_weight1)
83 |                 feat_v_a = torch.matmul(A_weight2, feat_t.permute(1, 0))
84 |                 feat_v_a = feat_v_a.mean(0) + feat_v_a.max(0)[0]
85 | 
86 |                 l1 = 100. * img_global_feat[i] @ feat_t_a
87 |                 l2 = 100. * feat_v_a @ feat_t
88 |                 logits1.append(l1.unsqueeze(0))
89 |                 logits2.append(l2.unsqueeze(0))
90 | 
91 |             logits1 = torch.cat(logits1, dim=0)
92 |             logits2 = torch.cat(logits2, dim=0)
93 |         return logits1, logits2
94 | 
95 |     if cfg['search']:
96 |         logits1, logits2 = get_logits()
97 |         beta2_list = [i * (cfg['beta2'] - 0.001) / 200 + 0.001 for i in range(200)]
98 |         beta3_list = [i * (cfg['beta3'] - 0.001) / 200 + 0.001 for i in range(200)]
99 |         print('-' * 20)
100 |         print('Starting search...')
101 |         print(' beta1 = 1.0')
102 |         print(' beta2 searching range: [0.001, ' + str(cfg['beta2']) + ']')
103 |         print(' beta3 searching range: [0.001, ' + str(cfg['beta3']) + ']')
104 |         print('-' * 20)
105 | 
106 |         best_acc = 0.
107 |         best_beta2 = 0.
108 |         best_beta3 = 0.
109 | 
110 |         for beta2 in beta2_list:
111 |             for beta3 in beta3_list:
112 |                 logits = 100. * img_global_feat @ feat_t
113 |                 logits = logits + logits1 * beta2 + logits2 * beta3
114 |                 acc, _ = accuracy(logits, labels, n=img_global_feat.size(0))
115 | 
116 |                 if acc > best_acc:
117 |                     print('New best setting, beta1: {:.4f}; beta2: {:.4f}; beta3: {:.4f}; Acc: {:.2f}'.format(1, beta2, beta3, acc))
118 |                     best_acc = acc
119 |                     best_beta2 = beta2
120 |                     best_beta3 = beta3
121 | 
122 |         print(f"Finished searching {cfg['dataset']} on backbone {cfg['backbone']}. Final Acc: {best_acc:.2f}")
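        # beta1, the weight on the plain zero-shot logits, stays fixed at 1.0; best_beta2 and best_beta3 are the grid-searched weights for logits1 and logits2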
123 | 
124 | if __name__ == '__main__':
125 |     args = get_arguments()
126 |     assert (os.path.exists(args.config))
127 |     cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
128 |     print(cfg)
129 |     main(cfg)
130 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | 
4 | def text_encode(classnames, templates, model):
5 |     with torch.no_grad():
6 |         text_feat = []
7 |         for classname in classnames:
8 |             texts = [template.format(classname) for template in templates] # format with class
9 |             texts = clip.tokenize(texts).cuda() # tokenize
10 |             class_embeddings = model.encode_text(texts) # embed with text encoder
11 |             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
12 |             class_embedding = class_embeddings.mean(dim=0)
13 |             class_embedding /= class_embedding.norm()
14 |             text_feat.append(class_embedding)
15 |         text_feat = torch.stack(text_feat, dim=1).cuda()
16 |     return text_feat
17 | 
18 | def accuracy(output, label, n, topk=(1, 5)):
19 |     pred = output.topk(max(topk), 1, True, True)[1].t()
20 |     correct = pred.eq(label.view(1, -1).expand_as(pred))
21 |     return (100 * float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk)
22 | 
--------------------------------------------------------------------------------
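
For reference, here is a minimal, self-contained sketch of the parameter-free cross-modal attention that `get_logits()` computes per image. It is not part of the repository: the tensors are random, and the sizes `D` (embedding dim), `L` (spatial tokens), `C` (classes) as well as the example `beta2`/`beta3` values are assumptions chosen only to make the shapes explicit.

```
import torch
import torch.nn.functional as F

# Illustrative sizes only: D = embedding dim, L = spatial tokens, C = classes.
D, L, C = 512, 49, 1000
feat_v = F.normalize(torch.randn(D, L), dim=0)    # spatial visual features of one image
feat_t = F.normalize(torch.randn(D, C), dim=0)    # prompt-ensembled text features
f_global = F.normalize(torch.randn(D), dim=0)     # global visual feature of the same image

# Parameter-free cross-modal attention, mirroring get_logits():
A = feat_v.t() @ feat_t * 2          # (L, C) visual-text affinities
A_v = F.softmax(A, dim=0)            # attention over spatial locations
A_c = F.softmax(A, dim=1)            # attention over classes

feat_t_a = feat_v @ A_v                            # (D, C) text features re-weighted by image regions
feat_v_a = A_c @ feat_t.t()                        # (L, D) visual features re-weighted by text
feat_v_a = feat_v_a.mean(0) + feat_v_a.max(0)[0]   # (D,) pooled attended visual feature

logits0 = 100. * f_global @ feat_t     # plain CLIP zero-shot logits, shape (C,)
logits1 = 100. * f_global @ feat_t_a   # global feature vs. attended text
logits2 = 100. * feat_v_a @ feat_t     # attended visual feature vs. text

beta2, beta3 = 1.0, 0.05               # example weights; the script grid-searches these
calip_logits = logits0 + beta2 * logits1 + beta3 * logits2
print(calip_logits.shape)              # torch.Size([1000])
```

With real features, `feat_v` corresponds to one slice of `img_spatial_feat` and `f_global` to the matching row of `img_global_feat`. The full script is invoked with a config path, e.g. `python run_imagenet.py --config configs/imagenet.yaml`, and expects a CUDA device, since both the image-feature loop and `utils.text_encode` call `.cuda()`.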