├── train ├── __init__.py ├── base.py └── hash_train.py ├── dataset ├── __init__.py ├── used_label.txt ├── make_mirflickr25k.py ├── dataloader.py ├── base.py ├── make_nuswide.py └── make_coco.py ├── result └── init ├── model ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── hash_model.py ├── simple_tokenizer.py ├── optimization.py ├── clip.py └── model.py ├── utils ├── codetable.xlsx ├── __init__.py ├── logger.py ├── get_args.py ├── calc_utils.py └── utils.py ├── requirements.txt ├── main.py └── README.md /train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /result/init: -------------------------------------------------------------------------------- 1 | save result 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /utils/codetable.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QinLab-WFU/DSPH/HEAD/utils/codetable.xlsx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | scipy 3 | numpy 4 | pillow 5 | matplotlib 6 | sklearn 7 | pytorch==1.12.1 8 | 9 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from train.hash_train import Trainer 2 | 3 | 4 | if __name__ == "__main__": 5 | 6 | Trainer() 7 | 8 | 9 | -------------------------------------------------------------------------------- /model/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QinLab-WFU/DSPH/HEAD/model/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .logger import get_logger, get_summary_writer 3 | from .get_args import get_args -------------------------------------------------------------------------------- /dataset/used_label.txt: -------------------------------------------------------------------------------- 1 | Labels_animal.txt 2 | Labels_beach.txt 3 | Labels_buildings.txt 4 | Labels_clouds.txt 5 | Labels_flowers.txt 6 | Labels_grass.txt 7 | Labels_lake.txt 8 | Labels_mountain.txt 9 | Labels_ocean.txt 10 | Labels_person.txt 11 | Labels_plants.txt 12 | Labels_reflection.txt 13 | Labels_road.txt 14 | Labels_rocks.txt 15 | Labels_sky.txt 16 | Labels_snow.txt 17 | Labels_sunset.txt 18 | Labels_tree.txt 19 | Labels_vehicle.txt 20 | Labels_water.txt 21 | Labels_window.txt 22 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from torch.utils.tensorboard import SummaryWriter 5 | 
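# Logging helpers used across the project: get_logger() configures console output via
# logging.basicConfig and, when a filename is given, also attaches a FileHandler to the
# root logger so that records from every module are written to the same log file.
# get_summary_writer() creates the directory if needed and returns a TensorBoard
# SummaryWriter bound to it. A minimal usage sketch (the paths are illustrative only):
#   logger = get_logger("./result/64-bit/train.log")
#   writer = get_summary_writer("./result/64-bit/tensorboard")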
6 | def get_logger(filename=None): 7 | logger = logging.getLogger('logger') 8 | logger.setLevel(logging.DEBUG) 9 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 10 | datefmt='%m/%d/%Y %H:%M:%S', 11 | level=logging.INFO) 12 | if filename is not None: 13 | handler = logging.FileHandler(filename) 14 | handler.setLevel(logging.DEBUG) 15 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 16 | logging.getLogger().addHandler(handler) 17 | return logger 18 | 19 | def get_summary_writer(dirname: str): 20 | 21 | os.makedirs(dirname, exist_ok=True) 22 | return SummaryWriter(log_dir=dirname) -------------------------------------------------------------------------------- /utils/get_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import xlrd 4 | 5 | 6 | def get_args(): 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--save-dir", type=str, default="./result/64-bit") 11 | parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.") 12 | parser.add_argument("--pretrained", type=str, default="") 13 | parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]") 14 | parser.add_argument("--index-file", type=str, default="index.mat") 15 | parser.add_argument("--caption-file", type=str, default="caption.mat") 16 | parser.add_argument("--label-file", type=str, default="label.mat") 17 | 18 | parser.add_argument("--output-dim", type=int, default=64) 19 | parser.add_argument("--numclass", type=int, default=24) 20 | parser.add_argument("--epochs", type=int, default=100) 21 | parser.add_argument("--max-words", type=int, default=32) 22 | parser.add_argument("--resolution", type=int, default=224) 23 | parser.add_argument("--batch-size", type=int, default=64) 24 | parser.add_argument("--num-workers", type=int, default=4) 25 | parser.add_argument("--query-num", type=int, default=5000) 26 | parser.add_argument("--train-num", type=int, default=10000) 27 | parser.add_argument("--lr-decay-freq", type=int, default=5) 28 | parser.add_argument("--display-step", type=int, default=50) 29 | parser.add_argument("--seed", type=int, default=1814) 30 | parser.add_argument("--hypseed", type=int, default=0) 31 | 32 | parser.add_argument("--lr", type=float, default=0.001) 33 | parser.add_argument("--alpha", type=float, default=0.8) 34 | parser.add_argument("--lr-decay", type=float, default=0.9) 35 | parser.add_argument("--clip-lr", type=float, default=0.00001) 36 | parser.add_argument("--weight-decay", type=float, default=0.2) 37 | parser.add_argument("--warmup-proportion", type=float, default=0.1, 38 | help="Proportion of training to perform linear learning rate warmup for. 
E.g., 0.1 = 10%% of training.") 39 | 40 | parser.add_argument("--is-train", action="store_true") 41 | 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | args = get_args() 48 | 49 | sheet = xlrd.open_workbook('./utils/codetable.xlsx').sheet_by_index(0) 50 | threshold = sheet.row(args.output_dim)[math.ceil(math.log(args.numclass, 2))].value 51 | 52 | -------------------------------------------------------------------------------- /dataset/make_mirflickr25k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as scio 3 | import numpy as np 4 | 5 | # mirflickr25k_annotations_v080 and mirflickr 6 | # mkdir mat 7 | # mv make_mirflickr25k.py mat 8 | # python make_mirflickr25k.py 9 | root_dir = "/home/admin00/dataset/Cleared-Set/archive/" 10 | 11 | file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080") 12 | 13 | file_list = os.listdir(file_path) 14 | 15 | file_list = [item for item in file_list if "_r1" not in item and "README" not in item] 16 | 17 | print("class num:", len(file_list)) 18 | 19 | class_index = {} 20 | for i, item in enumerate(file_list): 21 | class_index.update({item: i}) 22 | 23 | label_dict = {} 24 | for path_id in file_list: 25 | path = os.path.join(file_path, path_id) 26 | with open(path, "r") as f: 27 | for item in f: 28 | item = item.strip() 29 | if item not in label_dict: 30 | label = np.zeros(len(file_list)) 31 | label[class_index[path_id]] = 1 32 | label_dict.update({item: label}) 33 | else: 34 | # print() 35 | label_dict[item][class_index[path_id]] = 1 36 | 37 | # print(label_dict) 38 | print("create label:", len(label_dict)) 39 | keys = list(label_dict.keys()) 40 | keys.sort() 41 | 42 | labels = [] 43 | for key in keys: 44 | labels.append(label_dict[key]) 45 | print("labels created:", len(labels)) 46 | labels = {"category": labels} 47 | 48 | 49 | PATH = os.path.join(root_dir, "mirflickr25k", "mirflickr") 50 | index = [os.path.join(PATH, "im" + item + ".jpg") for item in keys] 51 | print("index created:", len(index)) 52 | index= {"index": index} 53 | 54 | 55 | captions_path = os.path.join(root_dir, "mirflickr25k", "mirflickr/meta/tags") 56 | captions_list = os.listdir(captions_path) 57 | captions_dict = {} 58 | for item in captions_list: 59 | id_ = item.split(".")[0].replace("tags", "") 60 | caption = "" 61 | with open(os.path.join(captions_path, item), "r") as f: 62 | for word in f.readlines(): 63 | caption += word.strip() + " " 64 | caption = caption.strip() 65 | captions_dict.update({id_: caption}) 66 | 67 | captions = [] 68 | 69 | for item in keys: 70 | captions.append([captions_dict[item]]) 71 | 72 | print("captions created:", len(captions)) 73 | captions = {"caption": captions} 74 | 75 | scio.savemat("/home/admin00/dataset/Cleared-Set/archive/index.mat", index) 76 | scio.savemat("/home/admin00/dataset/Cleared-Set/archive/caption.mat", captions) 77 | scio.savemat("/home/admin00/dataset/Cleared-Set/archive/label.mat", labels) 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Semantic-aware Proxy Hashing for Multi-label Cross-modal Retrieval [Paper](https://ieeexplore.ieee.org/document/10149001) 2 | This paper is accepted for IEEE Transactions on Circuits and Systems for Video Technology (TCSVT). 3 | If you have any questions please contact hyd199810@163.com. 
4 | 5 | ## Dependencies 6 | The code is written in Python. Install the following packages before running it (see `requirements.txt` for the full list): 7 | 8 | - pytorch 1.12.1 9 | - sklearn 10 | - tqdm 11 | - pillow 12 | 13 | ## Training 14 | 15 | ### Processing dataset 16 | Before training, you need to download the original data from [coco](https://www.kaggle.com/datasets/awsaf49/coco-2017-dataset) (including the 2017 train, val and annotations), nuswide [Google drive](https://drive.google.com/file/d/11w3J98uL_KHWn9j22GeKWc5K_AYM5U3V/view?usp=drive_link), and mirflickr25k [Baidu, code: u9e1](https://pan.baidu.com/s/1upgnBNNVfBzMiIET9zPfZQ) or [Google drive](https://drive.google.com/file/d/18oGgziSwhRzKlAjbqNZfj-HuYzbxWYTh/view?usp=sharing) (including mirflickr25k and mirflickr25k_annotations_v080), then use the corresponding "dataset/make_XXX.py" script to generate the .mat files. The generated data is also available from [Baidu, code: wyb8](https://pan.baidu.com/s/17QeFeZgTAOiY9qe0wo8OxQ). 17 | 18 | After all the .mat files have been generated, the `dataset` directory will look like this: 19 | ~~~ 20 | dataset 21 | ├── base.py 22 | ├── __init__.py 23 | ├── dataloader.py 24 | ├── coco 25 | │   ├── caption.mat 26 | │   ├── index.mat 27 | │   └── label.mat 28 | ├── flickr25k 29 | │   ├── caption.mat 30 | │   ├── index.mat 31 | │   └── label.mat 32 | └── nuswide 33 |     ├── caption.txt # Notice! It is a txt file! 34 |     ├── index.mat 35 |     └── label.mat 36 | ~~~ 37 | 38 | ### Download CLIP pretrained model 39 | The pretrained model URLs can be found around line 30 of [CLIP/clip/clip.py](https://github.com/openai/CLIP/blob/main/clip/clip.py). This code is based on "ViT-B/32". 40 | 41 | Copy ViT-B-32.pt into this directory (the repository root). 42 | 43 | ### Start 44 | 45 | After the dataset has been prepared, run the following command to train: 46 | > python main.py --is-train --dataset coco --caption-file caption.mat --index-file index.mat --label-file label.mat --lr 0.001 --output-dim 64 --save-dir ./result/coco/64 --clip-path ./ViT-B-32.pt --batch-size 128 --numclass 80 47 | 48 | 49 | ## Citation 50 | @ARTICLE{10149001, 51 | author={Huo, Yadong and Qin, Qibing and Dai, Jiangyan and Wang, Lei and Zhang, Wenfeng and Huang, Lei and Wang, Chengduan}, 52 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 53 | title={Deep Semantic-Aware Proxy Hashing for Multi-Label Cross-Modal Retrieval}, 54 | year={2024}, 55 | volume={34}, 56 | number={1}, 57 | pages={576-589}, 58 | doi={[10.1109/TCSVT.2023.3285266](https://ieeexplore.ieee.org/document/10149001)}} 59 | 60 | 61 | ## Acknowledgements 62 | [DCHMT](https://github.com/kalenforn/DCHMT) 63 |
-------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataset 2 | import os 3 | import numpy as np 4 | import scipy.io as scio 5 | 6 | 7 | def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None): 8 | np.random.seed(seed=seed) 9 | random_index = np.random.permutation(range(len(indexs))) 10 | query_index = random_index[: query_num] 11 | train_index = random_index[query_num: query_num + train_num] 12 | retrieval_index = random_index[query_num:] 13 | 14 | query_indexs = indexs[query_index] 15 | query_captions = captions[query_index] 16 | query_labels = labels[query_index] 17 | 18 | train_indexs = indexs[train_index] 19 | train_captions = captions[train_index] 20 | train_labels = labels[train_index] 21 | 22 | retrieval_indexs = indexs[retrieval_index] 23 | retrieval_captions = 
captions[retrieval_index] 24 | retrieval_labels = labels[retrieval_index] 25 | 26 | split_indexs = (query_indexs, train_indexs, retrieval_indexs) 27 | split_captions = (query_captions, train_captions, retrieval_captions) 28 | split_labels = (query_labels, train_labels, retrieval_labels) 29 | return split_indexs, split_captions, split_labels 30 | 31 | def dataloader(captionFile: str, 32 | indexFile: str, 33 | labelFile: str, 34 | maxWords=32, 35 | imageResolution=224, 36 | query_num=5000, 37 | train_num=10000, 38 | seed=None, 39 | npy=False): 40 | if captionFile.endswith("mat"): 41 | captions = scio.loadmat(captionFile)["caption"] 42 | captions = captions[0] if captions.shape[0] == 1 else captions 43 | elif captionFile.endswith("txt"): 44 | with open(captionFile, "r") as f: 45 | captions = f.readlines() 46 | captions = np.asarray([[item.strip()] for item in captions]) 47 | else: 48 | raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.") 49 | if not npy: 50 | indexs = scio.loadmat(indexFile)["index"] 51 | else: 52 | indexs = np.load(indexFile, allow_pickle=True) 53 | labels = scio.loadmat(labelFile)["category"] 54 | 55 | 56 | split_indexs, split_captions, split_labels = split_data(captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed) 57 | 58 | train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], maxWords=maxWords, imageResolution=imageResolution, npy=npy) 59 | query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy) 60 | retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy) 61 | 62 | return train_data, query_data, retrieval_data 63 | 64 | 65 | -------------------------------------------------------------------------------- /utils/calc_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | def calc_hammingDist(B1, B2): 7 | q = B2.shape[1] 8 | if len(B1.shape) < 2: 9 | B1 = B1.unsqueeze(0) 10 | distH = 0.5 * (q - B1.mm(B2.transpose(0, 1))) 11 | return distH 12 | 13 | 14 | def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0): 15 | 16 | num_query = query_L.shape[0] 17 | if qB.is_cuda: 18 | qB = qB.cpu() 19 | rB = rB.cpu() 20 | map = 0 21 | if k is None: 22 | k = retrieval_L.shape[0] 23 | for iter in range(num_query): 24 | gnd = (query_L[iter].unsqueeze(0).mm(retrieval_L.t()) > 0).type(torch.float).squeeze() 25 | tsum = torch.sum(gnd) 26 | if tsum == 0: 27 | continue 28 | hamm = calc_hammingDist(qB[iter, :], rB) 29 | _, ind = torch.sort(hamm) 30 | ind.squeeze_() 31 | gnd = gnd[ind] 32 | total = min(k, int(tsum)) 33 | count = torch.arange(1, total + 1).type(torch.float).to(gnd.device) 34 | tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float) + 1.0 35 | map += torch.mean(count / tindex) 36 | map = map / num_query 37 | return map 38 | 39 | 40 | def calc_neighbor(label1, label2): 41 | # calculate the similar matrix 42 | Sim = label1.matmul(label2.transpose(0, 1)) > 0 43 | return Sim.float() 44 | 45 | 46 | def norm_max_min(x: torch.Tensor, dim=None): 47 | if dim is None: 48 | max = torch.max(x) 49 | min = torch.min(x) 50 | if dim is not None: 51 | max = torch.max(x, dim=dim)[0] 52 | min = torch.min(x, 
dim=dim)[0] 53 | if dim > 0: 54 | max = max.unsqueeze(len(x.shape) - 1) 55 | min = min.unsqueeze(len(x.shape) - 1) 56 | norm = (x - min) / (max - min) 57 | return norm 58 | 59 | 60 | def norm_mean(x: torch.Tensor, dim=None): 61 | if dim is None: 62 | mean = torch.mean(x) 63 | std = torch.std(x) 64 | if dim is not None: 65 | mean = torch.mean(x, dim=dim) 66 | std = torch.std(x, dim=dim) 67 | if dim > 0: 68 | mean = mean.unsqueeze(len(x.shape) - 1) 69 | std = std.unsqueeze(len(x.shape) - 1) 70 | norm = (x - mean) / std 71 | return norm 72 | 73 | 74 | def norm_abs_mean(x: torch.Tensor, dim=None): 75 | if dim is None: 76 | mean = torch.mean(x) 77 | std = torch.std(x) 78 | if dim is not None: 79 | mean = torch.mean(x, dim=dim) 80 | std = torch.std(x, dim=dim) 81 | if dim > 0: 82 | mean = mean.unsqueeze(len(x.shape) - 1) 83 | std = std.unsqueeze(len(x.shape) - 1) 84 | norm = torch.abs(x - mean) / std 85 | return norm 86 | 87 | 88 | def factorial(n): 89 | if n == 0: 90 | return 1 91 | else: 92 | return n * factorial(n - 1) 93 | 94 | 95 | def calc_IF(all_bow): 96 | word_num = torch.sum(all_bow, dim=0) 97 | total_num = torch.sum(word_num) 98 | IF = word_num / total_num 99 | return IF 100 | -------------------------------------------------------------------------------- /train/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | 5 | from torch import distributed as dist 6 | from utils import get_logger, get_summary_writer 7 | 8 | 9 | class TrainBase(object): 10 | 11 | def __init__(self, 12 | args, 13 | rank=1): 14 | 15 | self.args = args 16 | os.makedirs(args.save_dir, exist_ok=True) 17 | self._init_writer() 18 | self.logger.info(self.args) 19 | self.rank = rank 20 | 21 | self._init_dataset() 22 | self._init_model() 23 | 24 | self.global_step = 0 25 | # self.global_step_t = 0 26 | self.max_mapi2t = 0 27 | self.max_mapt2i = 0 28 | self.best_epoch_i = 0 29 | self.best_epoch_t = 0 30 | 31 | def _init_dataset(self): 32 | self.train_loader = None 33 | self.query_loader = None 34 | self.retrieval_loader = None 35 | 36 | def _init_model(self): 37 | self.model = None 38 | self.model_ddp = None 39 | 40 | def _init_writer(self): 41 | self.logger = get_logger(os.path.join(self.args.save_dir, "train.log" if self.args.is_train else "test.log")) 42 | self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard")) 43 | 44 | def run(self): 45 | if self.args.is_train: 46 | self.train() 47 | else: 48 | self.test() 49 | 50 | def change_state(self, mode): 51 | 52 | if mode == "train": 53 | self.model.train() 54 | elif mode == "valid": 55 | self.model.eval() 56 | 57 | def get_code(self, data_loader, length: int, feature_map): 58 | 59 | img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) 60 | text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) 61 | 62 | for image, text, label, index in tqdm(data_loader): 63 | image = image.to(self.rank, non_blocking=True) 64 | text = text.to(self.rank, non_blocking=True) 65 | index = index.numpy() 66 | image_hash = self.model.encode_image(image, feature_map) 67 | image_hash = torch.sign(image_hash) 68 | text_hash = self.model.encode_text(text, feature_map) 69 | text_hash = torch.sign(text_hash) 70 | 71 | img_buffer[index, :] = image_hash.data 72 | text_buffer[index, :] = text_hash.data 73 | 74 | return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) 75 | 76 | def 
save_model(self, epoch): 77 | torch.save(self.model.state_dict(), os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")) 78 | self.logger.info("save mode to {}".format(os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth"))) 79 | 80 | def train(self): 81 | raise NotImplementedError("Function of 'train' doesn't implement.") 82 | 83 | def valid(self): 84 | raise NotImplementedError("Function of 'valid' doesn't implement.") 85 | 86 | def test(self): 87 | raise NotImplementedError("Function of 'test' doesn't implement.") 88 | 89 | def compute_loss(self): 90 | raise NotImplementedError("Function of 'compute_loss' doesn't implement.") 91 | -------------------------------------------------------------------------------- /model/hash_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.nn as nn 5 | import numpy as np 6 | from typing import Union 7 | 8 | from model.model import build_model 9 | from utils import get_logger, get_summary_writer 10 | 11 | 12 | class LinearHash(nn.Module): 13 | 14 | def __init__(self, inputDim=2048, outputDim=64): 15 | super(LinearHash, self).__init__() 16 | self.fc = nn.Linear(inputDim, outputDim) 17 | self.drop_out = nn.Dropout(p=0.2) 18 | 19 | def forward(self, data): 20 | result = self.fc(data) 21 | return torch.tanh(self.drop_out(result)) 22 | 23 | 24 | class DSPH(nn.Module): 25 | 26 | def __init__(self, 27 | outputDim=64, 28 | clipPath="./ViT-B-32.pt", 29 | writer=None, 30 | saveDir="./result/log", 31 | logger: logging.Logger=None, 32 | is_train=True): 33 | super(DSPH, self).__init__() 34 | os.makedirs(saveDir, exist_ok=True) 35 | self.logger = logger if logger is not None else get_logger(os.path.join(saveDir, "train.log" if is_train else "test.log")) 36 | self.writer = writer if writer is not None and is_train else get_summary_writer(os.path.join(saveDir, "tensorboard")) 37 | embedDim, self.clip = self.load_clip(clipPath) 38 | self.image_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) 39 | self.text_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) 40 | 41 | def freezen(self): 42 | for name, param in self.clip.named_parameters(): 43 | # print(name) 44 | if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \ 45 | or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0: 46 | # print("1") 47 | continue 48 | elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0: 49 | layer_num = int(name.split(".resblocks.")[1].split(".")[0]) 50 | if layer_num >= 12: 51 | # print("2") 52 | continue 53 | if name.find("conv2.") == 0: 54 | # print("3") 55 | continue 56 | else: 57 | # paramenters which < freeze_layer_num will be freezed 58 | param.requires_grad = False 59 | 60 | def load_clip(self, clipPath: str) -> tuple: 61 | try: 62 | model = torch.jit.load(clipPath, map_location="cpu").eval() 63 | state_dict = model.state_dict() 64 | except RuntimeError: 65 | state_dict = torch.load(clipPath, map_location="cpu") 66 | 67 | return state_dict["text_projection"].shape[1], build_model(state_dict) 68 | 69 | def encode_image(self, image): 70 | 71 | image_embed = self.clip.encode_image(image) #512 72 | 73 | image_embed = self.image_hash(image_embed) 74 | 75 | return image_embed 76 | 77 | def eval(self): 78 | self.image_hash.eval() 79 | self.text_hash.eval() 80 | # self.clip.eval() 81 | 82 | def train(self): 83 | self.image_hash.train() 84 | 
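# Note: train() and eval() only toggle the two hash heads; the CLIP backbone's
# train/eval mode is left unchanged (see the commented-out self.clip.eval() above),
# and gradient flow into CLIP is controlled separately by freezen().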
self.text_hash.train() 85 | 86 | def encode_text(self, text): 87 | 88 | text_embed = self.clip.encode_text(text) 89 | 90 | text_embed = self.text_hash(text_embed) 91 | 92 | return text_embed 93 | 94 | def forward(self, image, text): 95 | image_embed = self.encode_image(image) 96 | text_embed = self.encode_text(text) 97 | return image_embed, text_embed 98 | 99 | -------------------------------------------------------------------------------- /dataset/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | from torch.utils.data import Dataset 7 | import torch 8 | import random 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from model.simple_tokenizer import SimpleTokenizer as Tokenizer 12 | 13 | 14 | class BaseDataset(Dataset): 15 | 16 | def __init__(self, 17 | 18 | captions: dict, 19 | indexs: dict, 20 | labels: dict, 21 | is_train=True, 22 | tokenizer=Tokenizer(), 23 | maxWords=32, 24 | imageResolution=224, 25 | npy=False): 26 | 27 | self.captions = captions 28 | self.indexs = indexs 29 | self.labels = labels 30 | self.npy = npy 31 | 32 | self.maxWords = maxWords 33 | self.tokenizer = tokenizer 34 | 35 | self.transform = Compose([ 36 | Resize(imageResolution, interpolation=Image.BICUBIC), 37 | CenterCrop(imageResolution), 38 | ToTensor(), 39 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 40 | ]) if is_train else Compose([ 41 | Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC), 42 | ToTensor(), 43 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 44 | ]) 45 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 46 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 47 | 48 | self.__length = len(self.indexs) 49 | 50 | def __len__(self): 51 | return self.__length 52 | 53 | def _load_image(self, index: int) -> torch.Tensor: 54 | if not self.npy: 55 | image_path = self.indexs[index].strip() 56 | # print(image_path) 57 | image = Image.open(image_path).convert("RGB") 58 | else: 59 | image = Image.fromarray(self.indexs[index]).convert("RGB") 60 | image = self.transform(image) 61 | 62 | return image 63 | 64 | def _load_text(self, index: int): 65 | captions = self.captions[index] 66 | use_cap = captions[random.randint(0, len(captions) - 1)] 67 | 68 | words = self.tokenizer.tokenize(use_cap) 69 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 70 | total_length_with_CLS = self.maxWords - 1 71 | if len(words) > total_length_with_CLS: 72 | words = words[:total_length_with_CLS] 73 | 74 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 75 | caption = self.tokenizer.convert_tokens_to_ids(words) 76 | 77 | while len(caption) < self.maxWords: 78 | caption.append(0) 79 | caption = torch.tensor(caption) 80 | 81 | return caption 82 | 83 | def _load_label(self, index: int) -> torch.Tensor: 84 | label = self.labels[index] 85 | label = torch.from_numpy(label) 86 | 87 | return label 88 | 89 | def get_all_label(self): 90 | labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.float32) 91 | for i, item in enumerate(self.labels): 92 | 93 | labels[i] = torch.from_numpy(item) 94 | return labels 95 | 96 | def __getitem__(self, index): 97 | image = self._load_image(index) 98 | 
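# The caption loaded below is a tensor of maxWords token ids built by _load_text()
# (CLIP BPE tokens wrapped in <|startoftext|> / <|endoftext|> and zero-padded),
# and label is the multi-hot category vector returned by _load_label().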
caption = self._load_text(index) 99 | label = self._load_label(index) 100 | 101 | return image, caption, label, index -------------------------------------------------------------------------------- /dataset/make_nuswide.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import scipy.io as scio 5 | import numpy as np 6 | 7 | # mkdir mat 8 | # mv make_nuswide.py mat 9 | # python make_nuswide.py 10 | root_dir = "/home/admin00/dataset/NUS-WIDE" 11 | 12 | 13 | imageListFile = os.path.join(root_dir, "ImageList", "Imagelist.txt") 14 | labelPath = os.path.join(root_dir, "Groundtruth", "AllLabels") 15 | textFile = os.path.join(root_dir, "NUS_WID_Tags", "All_Tags.txt") 16 | classIndexFile = os.path.join(root_dir, "ConceptsList", "Concepts81.txt") 17 | 18 | # you can use the image urls to download images 19 | imagePath = os.path.join("/home/admin00/dataset/nuswide/Flickr") 20 | 21 | with open(imageListFile, "r") as f: 22 | indexs = f.readlines() 23 | 24 | indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs] 25 | print("indexs length:", len(indexs)) 26 | 27 | #class_index = {} 28 | #with open(classIndexFile, "r") as f: 29 | # data = f.readlines() 30 | # 31 | #for i, item in enumerate(data): 32 | # class_index.update({item.strip(): i}) 33 | 34 | captions = [] 35 | with open(textFile, "r", encoding='utf-8') as f: 36 | for line in f: 37 | if len(line.strip()) == 0: 38 | print("some line empty!") 39 | continue 40 | caption = line.split()[1:] 41 | caption = " ".join(caption).strip() 42 | # caption = re.sub(r'[^a-zA-Z]+', "", str(caption)) 43 | if len(caption) == 0: 44 | caption = "123456" 45 | captions.append(caption) 46 | 47 | print("captions length:", len(captions)) 48 | 49 | #labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) 50 | # label_lists = os.listdir(labelPath) 51 | with open("/home/admin00/dataset/NUS-WIDE/Groundtruth/used_label.txt", encoding='utf-8') as f: 52 | label_lists = f.readlines() 53 | label_lists = [item.strip() for item in label_lists] 54 | 55 | class_index = {} 56 | for i, item in enumerate(label_lists): 57 | class_index.update({item: i}) 58 | 59 | labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) 60 | 61 | for item in label_lists: 62 | path = os.path.join(labelPath, item) 63 | class_label = item# .split(".")[0].split("_")[-1] 64 | 65 | with open(path, "r") as f: 66 | data = f.readlines() 67 | for i, val in enumerate(data): 68 | labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0 69 | print("labels sum:", labels.sum()) 70 | 71 | not_used_id = [] 72 | with open("/home/admin00/dataset/NUS-WIDE/Groundtruth/not_used_id.txt", encoding='utf-8') as f: 73 | not_used_id = f.readlines() 74 | not_used_id = [int(int(item.strip())-2) for item in not_used_id] 75 | 76 | # for item in not_used_id: 77 | # indexs.pop(item) 78 | # captions.pop(item) 79 | # labels = np.delete(labels, item, 0) 80 | ind = list(range(len(indexs))) 81 | for item in not_used_id: 82 | ind.remove(item) 83 | indexs[item] = "" 84 | captions[item] = "" 85 | indexs = [item for item in indexs if item != ""] 86 | captions = [item for item in captions if item != ""] 87 | ind = np.asarray(ind) 88 | labels = labels[ind] 89 | # ind = range(len(indexs)) 90 | 91 | print("indexs length:", len(indexs)) 92 | print("captions length:", len(captions)) 93 | print("labels shape:", labels.shape) 94 | 95 | indexs = {"index": indexs} 96 | captions = {"caption": captions} 97 | labels = {"category": 
labels} 98 | 99 | scio.savemat('/home/admin00/DSPH-main/dataset/nuswide/index.mat', indexs) 100 | # scio.savemat("caption.mat", captions) 101 | scio.savemat('/home/admin00/DSPH-main/dataset/nuswide/label.mat', labels) 102 | 103 | 104 | captions = [item + "\n" for item in captions["caption"]] 105 | 106 | with open('/home/admin00/DSPH-main/dataset/nuswide/caption.txt', "w", encoding='utf-8') as f: 107 | f.writelines(captions) 108 | 109 | print("finished!") 110 | 111 | -------------------------------------------------------------------------------- /model/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | 134 | def tokenize(self, text): 135 | tokens = [] 136 | text = whitespace_clean(basic_clean(text)).lower() 137 | for token in re.findall(self.pat, text): 138 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 139 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 140 | return tokens 141 | 142 | def convert_tokens_to_ids(self, tokens): 143 | return [self.encoder[bpe_token] for bpe_token in tokens] 144 | -------------------------------------------------------------------------------- /dataset/make_coco.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | def make_index(jsonData: dict, indexDict: dict): 6 | """ 7 | use coco dict data as orignial data. 8 | indexDict: {jsonData's key: [index_key, index_value]} 9 | """ 10 | result = [] 11 | for name in indexDict: 12 | data = jsonData[name] 13 | middle_dict = {} 14 | for item in data: 15 | if item[indexDict[name][0]] not in middle_dict: 16 | middle_dict.update({item[indexDict[name][0]]: [item[indexDict[name][1]]]}) 17 | else: 18 | middle_dict[item[indexDict[name][0]]].append(item[indexDict[name][1]]) 19 | result.append(middle_dict) 20 | 21 | return result 22 | 23 | def check_file_exist(indexDict: dict, file_path: str): 24 | keys = list(indexDict.keys()) 25 | for item in keys: 26 | # print(indexDict[item]) 27 | if not os.path.exists(os.path.join(file_path, indexDict[item][0])): 28 | print(item, indexDict[item]) 29 | indexDict.pop(item) 30 | indexDict[item] = os.path.join(file_path, indexDict[item][0]) 31 | return indexDict 32 | 33 | def chage_categories2numpy(category_ids: dict, data: dict): 34 | 35 | for item in data: 36 | class_item = [0] * len(category_ids) 37 | for class_id in data[item]: 38 | class_item[category_ids[class_id]] = 1 39 | data[item] = np.asarray(class_item) 40 | 41 | return data 42 | 43 | def get_all_use_key(categoryDict: dict): 44 | return list(categoryDict.keys()) 45 | 46 | def remove_not_use(data: dict, used_key: list): 47 | 48 | keys = list(data.keys()) 49 | for item in keys: 50 | if item not in used_key: 51 | # print("remove:", item, indexDict[item]) 52 | data.pop(item) 53 | # print(len(category_list)) 54 | return data 55 | 56 | def merge_to_list(data: dict): 57 | 58 | result = [] 59 | key_sort = list(data.keys()) 60 | key_sort.sort() 61 | # print(key_sort) 62 | # print(key_sort.index(91654)) 63 | 64 | for item in key_sort: 65 | result.append(data[item]) 66 | 67 | return result 68 | 69 | 70 | if __name__ == "__main__": 71 | import json 72 | import scipy.io as scio 73 | import argparse 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--coco-dir", default="/home/admin00/dataset/coco/archive (1)/coco2017", type=str, help="the coco dataset dir") 77 | parser.add_argument("--save-dir", default="./dataset/coco", type=str, help="mat file saved dir") 78 | args = parser.parse_args() 79 | 80 | 81 | PATH = args.coco_dir 82 | jsonFile = os.path.join(PATH, "annotations", "captions_train2017.json") 83 | with open(jsonFile, "r") as f: 84 | jsonData = json.load(f) 85 | indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]} 86 | result = make_index(jsonData, indexDict) 87 | indexDict_, captionDict = result 88 | indexDict_ = check_file_exist(indexDict_, os.path.join(PATH, "train2017")) 89 | print("caption:", len(indexDict_), len(captionDict)) 90 | 91 | jsonFile = os.path.join(PATH, "annotations", "instances_train2017.json") 92 | with open(jsonFile, "r") as f: 93 | jsonData = json.load(f) 94 | categroy_ids = {} 95 | for i, item in enumerate(jsonData['categories']): 96 | categroy_ids.update({item['id']: i}) 97 | indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]} 98 | result = make_index(jsonData, indexDict) 99 | categoryDict = result[0] 100 | cateIndexDict = result[1] 101 | 102 | categoryDict = chage_categories2numpy(categroy_ids, categoryDict) 103 | 104 | used_key = get_all_use_key(categoryDict) 105 | # 统一index 106 | indexDict_ = remove_not_use(indexDict_, used_key) 107 | 
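# Translation of the Chinese comments in this block: "统一index" above means
# "unify the indexes", i.e. keep only the image ids that carry category annotations
# so that index/caption/label stay aligned; "转变为list" below means "convert to
# lists": merge_to_list() flattens the aligned dicts into lists sorted by image id.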
captionDict = remove_not_use(captionDict, used_key) 108 | categoryIndexDict = remove_not_use(cateIndexDict, used_key) 109 | categoryDict = remove_not_use(categoryDict, used_key) 110 | # 转变为list 111 | indexList = merge_to_list(indexDict_) 112 | captionList = merge_to_list(captionDict) 113 | categoryIndexList = merge_to_list(categoryIndexDict) 114 | categoryList = merge_to_list(categoryDict) 115 | print("result", len(indexDict_), len(categoryDict)) 116 | print("category:", len(categoryDict), len(categoryIndexList)) 117 | for i in range(len(indexList)): 118 | if indexList[i] != categoryIndexList[i]: 119 | print("Not the same:", i, indexList[i], categoryIndexList[i]) 120 | 121 | val_jsonFile = os.path.join(PATH, "annotations", "captions_val2017.json") 122 | with open(val_jsonFile, "r") as f: 123 | jsonData = json.load(f) 124 | indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]} 125 | result = make_index(jsonData, indexDict) 126 | val_indexDict = result[0] 127 | val_captionDict = result[1] 128 | val_indexDict = check_file_exist(val_indexDict, os.path.join(PATH, "val2017")) 129 | jsonFile = os.path.join(PATH, "annotations", "instances_val2017.json") 130 | with open(jsonFile, "r") as f: 131 | jsonData = json.load(f) 132 | categroy_ids = {} 133 | for i, item in enumerate(jsonData['categories']): 134 | categroy_ids.update({item['id']: i}) 135 | indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]} 136 | result = make_index(jsonData, indexDict) 137 | val_categoryDict = result[0] 138 | val_categoryIndexDict = result[1] 139 | val_categoryDict = chage_categories2numpy(categroy_ids, val_categoryDict) 140 | used_key = get_all_use_key(val_categoryDict) 141 | val_indexDict = remove_not_use(val_indexDict, used_key) 142 | val_captionDict = remove_not_use(val_captionDict, used_key) 143 | val_categoryIndexDict = remove_not_use(val_categoryIndexDict, used_key) 144 | val_categoryDict = remove_not_use(val_categoryDict, used_key) 145 | 146 | val_indexList = merge_to_list(val_indexDict) 147 | val_captionList = merge_to_list(val_captionDict) 148 | val_categoryIndexList = merge_to_list(val_categoryIndexDict) 149 | val_categoryList = merge_to_list(val_categoryDict) 150 | 151 | indexList.extend(val_indexList) 152 | captionList.extend(val_captionList) 153 | categoryIndexList.extend(val_categoryIndexList) 154 | categoryList.extend(val_categoryList) 155 | 156 | print(len(indexList), len(captionList), len(categoryIndexList)) 157 | indexs = {"index": indexList} 158 | captions = {"caption": captionList} 159 | categorys = {"category": categoryList} 160 | 161 | scio.savemat("/home/admin00/DSPH-main/dataset/coco/index.mat", indexs) 162 | scio.savemat("/home/admin00/DSPH-main/dataset/coco/caption.mat", captions) 163 | scio.savemat("/home/admin00/DSPH-main/dataset/coco/label.mat", categorys) 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /model/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. 
This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from utils.get_args import threshold 7 | from sklearn.metrics.pairwise import euclidean_distances 8 | from utils.get_args import get_args 9 | 10 | 11 | class HyP(torch.nn.Module): 12 | def __init__(self): 13 | torch.nn.Module.__init__(self) 14 | self.args = get_args() 15 | torch.manual_seed(self.args.hypseed) 16 | # Initialization 17 | self.proxies = torch.nn.Parameter(torch.randn(self.args.numclass, self.args.output_dim).to(1)) 18 | nn.init.kaiming_normal_(self.proxies, mode = 'fan_out') 19 | 20 | def forward(self, x=None, y=None, label=None): 21 | P_one_hot = label 22 | 23 | cos = F.normalize(x, p = 2, dim = 1).mm(F.normalize(self.proxies, p = 2, dim = 1).T) 24 | pos = 1 - cos 25 | neg = F.relu(cos - threshold) 26 | 27 | cos_t = F.normalize(y, p = 2, dim = 1).mm(F.normalize(self.proxies, p = 2, dim = 1).T) 28 | pos_t = 1 - cos_t 29 | neg_t = F.relu(cos_t - threshold) 30 | 31 | P_num = len(P_one_hot.nonzero()) 32 | N_num = len((P_one_hot == 0).nonzero()) 33 | 34 | pos_term = torch.where(P_one_hot == 1, pos.to(torch.float32), torch.zeros_like(cos).to(torch.float32)).sum() / P_num 35 | neg_term = torch.where(P_one_hot == 0, neg.to(torch.float32), torch.zeros_like(cos).to(torch.float32)).sum() / N_num 36 | 37 | pos_term_t = torch.where(P_one_hot == 1, pos_t.to(torch.float32), torch.zeros_like(cos_t).to(torch.float32)).sum() / P_num 38 | neg_term_t = torch.where(P_one_hot == 0, neg_t.to(torch.float32), torch.zeros_like(cos_t).to(torch.float32)).sum() / N_num 39 | 40 | if self.args.alpha > 0: 41 | index = label.sum(dim = 1) > 1 42 | label_ = label[index].float() 43 | 44 | x_ = x[index] 45 | t_ = y[index] 46 | 47 | cos_sim = label_.mm(label_.T) 48 | 49 | if len((cos_sim == 0).nonzero()) == 0: 50 | reg_term = 0 51 | reg_term_t = 0 52 | reg_term_xt = 0 53 | else: 54 | x_sim = F.normalize(x_, p = 2, dim = 1).mm(F.normalize(x_, p = 2, dim = 1).T) 55 | t_sim = F.normalize(t_, p = 2, dim = 1).mm(F.normalize(t_, p = 2, dim = 1).T) 56 | xt_sim = F.normalize(x_, p = 2, dim = 1).mm(F.normalize(t_, p = 2, dim = 1).T) 57 | 58 | neg = self.args.alpha * F.relu(x_sim - threshold) 59 | neg_t = self.args.alpha * F.relu(t_sim - threshold) 60 | neg_xt = self.args.alpha * F.relu(xt_sim - threshold) 61 | 62 | reg_term = torch.where(cos_sim == 0, neg, torch.zeros_like(x_sim)).sum() / len((cos_sim == 0).nonzero()) 63 | reg_term_t = torch.where(cos_sim == 0, neg_t, torch.zeros_like(t_sim)).sum() / len((cos_sim == 0).nonzero()) 64 | reg_term_xt = torch.where(cos_sim == 0, neg_xt, torch.zeros_like(xt_sim)).sum() / len((cos_sim == 0).nonzero()) 65 | else: 66 | reg_term = 0 67 | reg_term_t = 0 68 | reg_term_xt = 0 69 | 70 | return 
pos_term + neg_term + pos_term_t + neg_term_t + reg_term + reg_term_t + reg_term_xt 71 | 72 | 73 | def compute_metrics(x): 74 | # 取复值的原因在于cosine的值越大说明越相似,但是需要取的是前N个值,所以取符号变为增函数s 75 | sx = np.sort(-x, axis=1) 76 | d = np.diag(-x) 77 | d = d[:, np.newaxis] 78 | ind = sx - d 79 | ind = np.where(ind == 0) 80 | ind = ind[1] 81 | metrics = {} 82 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 83 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 84 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 85 | metrics['MR'] = np.median(ind) + 1 86 | metrics["MedianR"] = metrics['MR'] 87 | metrics["MeanR"] = np.mean(ind) + 1 88 | metrics["cols"] = [int(i) for i in list(ind)] 89 | return metrics 90 | 91 | 92 | def calc_neighbor(a: torch.Tensor, b: torch.Tensor): 93 | # print(a.dtype, b.dtype) 94 | return (a.matmul(b.transpose(0, 1)) > 0).float() 95 | 96 | 97 | def euclidean_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): 98 | 99 | if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): 100 | similarity = torch.cdist(a, b, p=2.0) 101 | elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): 102 | similarity = euclidean_distances(a, b) 103 | else: 104 | raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b))) 105 | return similarity 106 | 107 | 108 | def euclidean_dist_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor): 109 | """ 110 | calculate euclidean distance as inner product 111 | :param tensor1: a tensor with shape (a, c) 112 | :param tensor2: a tensor with shape (b, c) 113 | :return: the euclidean distance matrix which each point is the distance between a row in tensor1 and a row in tensor2. 114 | """ 115 | dim1 = tensor1.shape[0] 116 | dim2 = tensor2.shape[0] 117 | multi = torch.matmul(tensor1, tensor2.t()) 118 | a2 = torch.sum(torch.pow(tensor1, 2), dim=1, keepdim=True).expand(dim1, dim2) 119 | b2 = torch.sum(torch.pow(tensor2, 2), dim=1, keepdim=True).t().expand(dim1, dim2) 120 | dist = torch.sqrt(a2 + b2 - 2 * multi) 121 | return dist 122 | 123 | 124 | def cosine_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): 125 | 126 | if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): 127 | a = a / a.norm(dim=-1, keepdim=True) if len(torch.where(a != 0)[0]) > 0 else a 128 | b = b / b.norm(dim=-1, keepdim=True) if len(torch.where(b != 0)[0]) > 0 else b 129 | return torch.matmul(a, b.t()) 130 | elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): 131 | a = a / np.linalg.norm(a, axis=-1, keepdims=True) if len(np.where(a != 0)[0]) > 0 else a 132 | b = b / np.linalg.norm(b, axis=-1, keepdims=True) if len(np.where(b != 0)[0]) > 0 else b 133 | return np.matmul(a, b.T) 134 | else: 135 | raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b))) 136 | 137 | def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0): 138 | # qB: {-1,+1}^{mxq} 139 | # rB: {-1,+1}^{nxq} 140 | # query_L: {0,1}^{mxl} 141 | # retrieval_L: {0,1}^{nxl} 142 | num_query = query_L.shape[0] 143 | qB = torch.sign(qB) 144 | rB = torch.sign(rB) 145 | map = 0 146 | if k is None: 147 | k = retrieval_L.shape[0] 148 | # print("query num:", num_query) 149 | for iter in range(num_query): 150 | q_L = query_L[iter] 151 | if len(q_L.shape) < 2: 152 | q_L = q_L.unsqueeze(0) # [1, hash length] 153 | gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) 154 | tsum = torch.sum(gnd) 155 | if tsum == 0: 
156 | continue 157 | hamm = calcHammingDist(qB[iter, :], rB) 158 | _, ind = torch.sort(hamm) 159 | ind.squeeze_() 160 | gnd = gnd[ind] 161 | total = min(k, int(tsum)) 162 | count = torch.arange(1, total + 1).type(torch.float32) 163 | tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0 164 | if tindex.is_cuda: 165 | count = count.to(rank) 166 | map = map + torch.mean(count / tindex) 167 | map = map / num_query 168 | return map 169 | 170 | 171 | def calcHammingDist(B1, B2): 172 | 173 | if len(B1.shape) < 2: 174 | B1.view(1, -1) 175 | if len(B2.shape) < 2: 176 | B2.view(1, -1) 177 | q = B2.shape[1] 178 | if isinstance(B1, torch.Tensor): 179 | distH = 0.5 * (q - torch.matmul(B1, B2.t())) 180 | elif isinstance(B1, np.ndarray): 181 | distH = 0.5 * (q - np.matmul(B1, B2.transpose())) 182 | else: 183 | raise ValueError("B1, B2 must in [torch.Tensor, np.ndarray]") 184 | return distH 185 | -------------------------------------------------------------------------------- /model/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with 
tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or the more hackable non-JIT model (default). 99 | 100 | download_root: str 101 | path to download the model files; by default, it uses "~/.cache/clip" 102 | 103 | Returns 104 | ------- 105 | model : torch.nn.Module 106 | The CLIP model 107 | 108 | preprocess : Callable[[PIL.Image], torch.Tensor] 109 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 110 | """ 111 | if name in _MODELS: 112 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 113 | elif os.path.isfile(name): 114 | model_path = name 115 | else: 116 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 117 | 118 | try: 119 | # loading JIT archive 120 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 121 | state_dict = None 122 | except RuntimeError: 123 | # loading saved state dict 124 | if jit: 125 | warnings.warn(f"File {model_path} is not a JIT archive.
Loading as a state dict instead") 126 | jit = False 127 | state_dict = torch.load(model_path, map_location="cpu") 128 | 129 | if not jit: 130 | model = build_model(state_dict or model.state_dict()).to(device) 131 | if str(device) == "cpu": 132 | model.float() 133 | return model, _transform(model.visual.input_resolution) 134 | 135 | # patch the device names 136 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 137 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 138 | 139 | def patch_device(module): 140 | try: 141 | graphs = [module.graph] if hasattr(module, "graph") else [] 142 | except RuntimeError: 143 | graphs = [] 144 | 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("prim::Constant"): 150 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 151 | node.copyAttributes(device_node) 152 | 153 | model.apply(patch_device) 154 | patch_device(model.encode_image) 155 | patch_device(model.encode_text) 156 | 157 | # patch dtype to float32 on CPU 158 | if str(device) == "cpu": 159 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 160 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 161 | float_node = float_input.node() 162 | 163 | def patch_float(module): 164 | try: 165 | graphs = [module.graph] if hasattr(module, "graph") else [] 166 | except RuntimeError: 167 | graphs = [] 168 | 169 | if hasattr(module, "forward1"): 170 | graphs.append(module.forward1.graph) 171 | 172 | for graph in graphs: 173 | for node in graph.findAllNodes("aten::to"): 174 | inputs = list(node.inputs()) 175 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 176 | if inputs[i].node()["value"] == 5: 177 | inputs[i].node().copyAttributes(float_node) 178 | 179 | model.apply(patch_float) 180 | patch_float(model.encode_image) 181 | patch_float(model.encode_text) 182 | 183 | model.float() 184 | 185 | return model, _transform(model.input_resolution.item()) 186 | 187 | 188 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 189 | """ 190 | Returns the tokenized representation of given input string(s) 191 | 192 | Parameters 193 | ---------- 194 | texts : Union[str, List[str]] 195 | An input string or a list of input strings to tokenize 196 | 197 | context_length : int 198 | The context length to use; all CLIP models use 77 as the context length 199 | 200 | truncate: bool 201 | Whether to truncate the text in case its encoding is longer than the context length 202 | 203 | Returns 204 | ------- 205 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 206 | """ 207 | if isinstance(texts, str): 208 | texts = [texts] 209 | 210 | sot_token = _tokenizer.encoder["<|startoftext|>"] 211 | eot_token = _tokenizer.encoder["<|endoftext|>"] 212 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 213 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 214 | 215 | for i, tokens in enumerate(all_tokens): 216 | if len(tokens) > context_length: 217 | if truncate: 218 | tokens = tokens[:context_length] 219 | tokens[-1] = eot_token 220 | else: 221 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 222 | result[i, 
:len(tokens)] = torch.tensor(tokens) 223 | 224 | return result 225 | -------------------------------------------------------------------------------- /train/hash_train.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import loss 2 | from model.hash_model import DSPH as DSPH 3 | import os 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import scipy.io as scio 9 | 10 | from .base import TrainBase 11 | from model.optimization import BertAdam 12 | from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity 13 | from utils.calc_utils import calc_map_k_matrix as calc_map_k 14 | from utils.utils import HyP 15 | from dataset.dataloader import dataloader 16 | import time 17 | 18 | 19 | class Trainer(TrainBase): 20 | 21 | def __init__(self, 22 | rank=1): 23 | args = get_args() 24 | super(Trainer, self).__init__(args, rank) 25 | self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) 26 | self.run() 27 | 28 | def _init_model(self): 29 | self.logger.info("init model.") 30 | 31 | self.logger.info("ViT+GPT!") 32 | HashModel = DSPH 33 | self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, 34 | writer=self.writer, logger=self.logger, is_train=self.args.is_train).to(self.rank) 35 | if self.args.pretrained != "" and os.path.exists(self.args.pretrained): 36 | self.logger.info("load pretrained model.") 37 | self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) 38 | 39 | self.model.float() 40 | self.optimizer = BertAdam([ 41 | {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, 42 | {'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, 43 | {'params': self.model.text_hash.parameters(), 'lr': self.args.lr} 44 | ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', 45 | b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs, 46 | weight_decay=self.args.weight_decay, max_grad_norm=1.0) 47 | 48 | self.hyp = HyP().to(self.rank) 49 | self.optimizer_loss = torch.optim.SGD(params=self.hyp.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0005) 50 | self.total_time = 0 51 | print(self.model) 52 | 53 | def _init_dataset(self): 54 | self.logger.info("init dataset.") 55 | self.logger.info(f"Using {self.args.dataset} dataset.") 56 | self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) 57 | self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) 58 | self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) 59 | train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, 60 | indexFile=self.args.index_file, 61 | labelFile=self.args.label_file, 62 | maxWords=self.args.max_words, 63 | imageResolution=self.args.resolution, 64 | query_num=self.args.query_num, 65 | train_num=self.args.train_num, 66 | seed=self.args.seed) 67 | self.train_labels = train_data.get_all_label().to(self.rank) 68 | self.query_labels = query_data.get_all_label() 69 | self.retrieval_labels = retrieval_data.get_all_label() 70 | self.args.retrieval_num = len(self.retrieval_labels) 71 | self.logger.info(f"query shape: {self.query_labels.shape}") 72 | self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") 73 | self.train_loader = DataLoader( 74 | dataset=train_data, 75 | batch_size=self.args.batch_size, 76 | 
num_workers=self.args.num_workers, 77 | pin_memory=True, 78 | shuffle=True 79 | ) 80 | self.query_loader = DataLoader( 81 | dataset=query_data, 82 | batch_size=self.args.batch_size, 83 | num_workers=self.args.num_workers, 84 | pin_memory=True, 85 | shuffle=True 86 | ) 87 | self.retrieval_loader = DataLoader( 88 | dataset=retrieval_data, 89 | batch_size=self.args.batch_size, 90 | num_workers=self.args.num_workers, 91 | pin_memory=True, 92 | shuffle=True 93 | ) 94 | 95 | 96 | def train_epoch(self, epoch): 97 | self.change_state(mode="train") 98 | self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) 99 | all_loss = 0 100 | for image, text, label, index in self.train_loader: 101 | start_time = time.time() 102 | self.global_step += 1 103 | image = image.float() 104 | image = image.to(self.rank, non_blocking=True) 105 | text = text.to(self.rank, non_blocking=True) 106 | label = label.to(self.rank, non_blocking=True) 107 | 108 | index = index.numpy() 109 | 110 | hash_img, hash_text = self.model(image, text) 111 | 112 | loss = self.hyp(hash_img, hash_text, label) 113 | 114 | all_loss += loss 115 | 116 | self.optimizer.zero_grad() 117 | self.optimizer_loss.zero_grad() 118 | loss.backward() 119 | self.optimizer.step() 120 | self.optimizer_loss.step() 121 | self.total_time += time.time() - start_time 122 | 123 | 124 | self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}, time: {self.total_time}") 125 | 126 | def train(self): 127 | self.logger.info("Start training.") 128 | 129 | for epoch in range(self.args.epochs): 130 | self.train_epoch(epoch) 131 | self.valid(epoch) 132 | # self.save_model(epoch) 133 | 134 | self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") 135 | 136 | 137 | def get_code(self, data_loader, length: int): 138 | 139 | img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) 140 | text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) 141 | encoder_time = 0 142 | for image, text, label, index in tqdm(data_loader): 143 | start_encoder_time = time.time() 144 | image = image.to(self.rank, non_blocking=True) 145 | text = text.to(self.rank, non_blocking=True) 146 | index = index.numpy() 147 | image_hash = self.model.encode_image(image) 148 | image_hash = torch.sign(image_hash) 149 | text_hash = self.model.encode_text(text) 150 | text_hash = torch.sign(text_hash) 151 | encoder_time += time.time() - start_encoder_time 152 | img_buffer[index, :] = image_hash.data 153 | text_buffer[index, :] = text_hash.data 154 | 155 | return img_buffer, text_buffer, encoder_time 156 | 157 | 158 | def test(self, mode_name="i2t"): 159 | if self.args.pretrained == "": 160 | raise RuntimeError("test step must load a model!
please set the --pretrained argument.") 161 | self.change_state(mode="valid") 162 | save_dir = os.path.join(self.args.save_dir, "PR_cruve") 163 | os.makedirs(save_dir, exist_ok=True) 164 | query_img, query_txt, q_encoder_time = self.get_code(self.query_loader, self.args.query_num) 165 | retrieval_img, retrieval_txt, r_encoder_time = self.get_code(self.retrieval_loader, self.args.retrieval_num) 166 | mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) 167 | mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) 168 | mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) 169 | mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) 170 | self.max_mapt2i = max(self.max_mapt2i, mAPt2i) 171 | self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}") 172 | 173 | query_img = query_img.cpu().detach().numpy() 174 | query_txt = query_txt.cpu().detach().numpy() 175 | retrieval_img = retrieval_img.cpu().detach().numpy() 176 | retrieval_txt = retrieval_txt.cpu().detach().numpy() 177 | query_labels = self.query_labels.numpy() 178 | retrieval_labels = self.retrieval_labels.numpy() 179 | 180 | result_dict = { 181 | 'q_img': query_img, 182 | 'q_txt': query_txt, 183 | 'r_img': retrieval_img, 184 | 'r_txt': retrieval_txt, 185 | 'q_l': query_labels, 186 | 'r_l': retrieval_labels 187 | } 188 | scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) 189 | self.logger.info(">>>>>> save all data!") 190 | 191 | 192 | def valid(self, epoch): 193 | self.logger.info("Valid.") 194 | self.change_state(mode="valid") 195 | query_img, query_txt, q_encoder_time = self.get_code(self.query_loader, self.args.query_num) 196 | retrieval_img, retrieval_txt, r_encoder_time = self.get_code(self.retrieval_loader, self.args.retrieval_num) 197 | mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) 198 | mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) 199 | mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) 200 | mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) 201 | if self.max_mapi2t < mAPi2t: 202 | self.best_epoch_i = epoch 203 | self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") 204 | self.max_mapi2t = max(self.max_mapi2t, mAPi2t) 205 | if self.max_mapt2i < mAPt2i: 206 | self.best_epoch_t = epoch 207 | self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") 208 | self.max_mapt2i = max(self.max_mapt2i, mAPt2i) 209 | self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ 210 | MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}, query_encoder_time: {q_encoder_time}, retrieval_encoder_time: {r_encoder_time}") 211 | 212 | def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): 213 | 214 | save_dir = os.path.join(self.args.save_dir, "PR_cruve") 215 | os.makedirs(save_dir, exist_ok=True) 216 | 217 | query_img = query_img.cpu().detach().numpy() 218 | query_txt = query_txt.cpu().detach().numpy() 219 | 
retrieval_img = retrieval_img.cpu().detach().numpy() 220 | retrieval_txt = retrieval_txt.cpu().detach().numpy() 221 | query_labels = self.query_labels.numpy() 222 | retrieval_labels = self.retrieval_labels.numpy() 223 | 224 | result_dict = { 225 | 'q_img': query_img, 226 | 'q_txt': query_txt, 227 | 'r_img': retrieval_img, 228 | 'r_txt': retrieval_txt, 229 | 'q_l': query_labels, 230 | 'r_l': retrieval_labels 231 | } 232 | scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) 233 | self.logger.info(f">>>>>> save best {mode_name} data!") 234 | 235 | 236 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | 
k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = 
LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | # self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | # return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | attn_mask_ = self.attn_mask 185 | if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): 186 | attn_mask_ = self.attn_mask(x.size(0)) # LND 187 | 188 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None 189 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 190 | 191 | def forward(self, x: torch.Tensor): 192 | # x, video_frame = x_tuple 193 | # print(x.shape) 194 | x = x + self.attention(self.ln_1(x)) 195 | x = x + self.mlp(self.ln_2(x)) 196 | return x 197 | 198 | 199 | class Transformer(nn.Module): 200 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 201 | super().__init__() 202 | self.width = width 203 | self.layers = layers 204 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 205 | 206 | def forward(self, x: torch.Tensor): 207 | return self.resblocks(x) 208 | 209 | 210 | class VisionTransformer(nn.Module): 211 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 212 | super().__init__() 213 | self.input_resolution = input_resolution 214 | self.output_dim = output_dim 215 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 216 | 217 | 218 | scale = width ** -0.5 219 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 220 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 221 | self.ln_pre = LayerNorm(width) 222 | 223 | self.transformer = Transformer(width, layers, heads) 224 | 225 | self.ln_post = LayerNorm(width) 226 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 227 | 228 | def forward(self, x: torch.Tensor): 229 | # print(x.shape) 230 | # print(x.shape) 231 | x = self.conv1(x) # shape = [*, width, grid, grid] 232 | # print("image feature map:", x.shape) 233 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 234 | # print(x.shape) 235 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 236 | # print(x.shape) 237 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 238 | x = x + self.positional_embedding.to(x.dtype) 239 | x = self.ln_pre(x) 240 | # print(x.shape) 241 | 242 | x = x.permute(1, 0, 2) # NLD -> LND 243 | x = self.transformer(x) 244 | # print(x.shape) 245 | x = x.permute(1, 0, 2) # LND -> NLD 246 | 247 | x = self.ln_post(x[:, 0, :]) 248 | 249 | if self.proj is not None: 250 | x = x @ self.proj 251 | 252 | return x 253 | 254 | 255 | class CLIP(nn.Module): 256 | def __init__(self, 257 | embed_dim: int, 258 | # vision 259 | image_resolution: int, 260 | vision_layers: Union[Tuple[int, int, int, int], int], 261 | vision_width: int, 262 | vision_patch_size: int, 263 | # text 264 | context_length: int, 265 | vocab_size: 
int, 266 | transformer_width: int, 267 | transformer_heads: int, 268 | transformer_layers: int 269 | ): 270 | super().__init__() 271 | 272 | self.context_length = context_length 273 | 274 | if isinstance(vision_layers, (tuple, list)): 275 | vision_heads = vision_width * 32 // 64 276 | self.visual = ModifiedResNet( 277 | layers=vision_layers, 278 | output_dim=embed_dim, 279 | heads=vision_heads, 280 | input_resolution=image_resolution, 281 | width=vision_width 282 | ) 283 | else: 284 | vision_heads = vision_width // 64 285 | self.visual = VisionTransformer( 286 | input_resolution=image_resolution, 287 | patch_size=vision_patch_size, 288 | width=vision_width, 289 | layers=vision_layers, 290 | heads=vision_heads, 291 | output_dim=embed_dim 292 | ) 293 | 294 | self.transformer = Transformer( 295 | width=transformer_width, 296 | layers=transformer_layers, 297 | heads=transformer_heads, 298 | attn_mask=self.build_attention_mask 299 | ) 300 | 301 | self.vocab_size = vocab_size 302 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 303 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 304 | self.ln_final = LayerNorm(transformer_width) 305 | 306 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 307 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 308 | 309 | self.initialize_parameters() 310 | 311 | def initialize_parameters(self): 312 | nn.init.normal_(self.token_embedding.weight, std=0.02) 313 | nn.init.normal_(self.positional_embedding, std=0.01) 314 | 315 | if isinstance(self.visual, ModifiedResNet): 316 | if self.visual.attnpool is not None: 317 | std = self.visual.attnpool.c_proj.in_features ** -0.5 318 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 319 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 320 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 321 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 322 | 323 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 324 | for name, param in resnet_block.named_parameters(): 325 | if name.endswith("bn3.weight"): 326 | nn.init.zeros_(param) 327 | 328 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 329 | attn_std = self.transformer.width ** -0.5 330 | fc_std = (2 * self.transformer.width) ** -0.5 331 | for block in self.transformer.resblocks: 332 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 333 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 334 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 335 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 336 | 337 | if self.text_projection is not None: 338 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 339 | 340 | def build_attention_mask(self, context_length): 341 | # lazily create causal attention mask, with full attention between the vision tokens 342 | # pytorch uses additive attention mask; fill with -inf 343 | mask = torch.empty(context_length, context_length) 344 | mask.fill_(float("-inf")) 345 | mask.triu_(1) # zero out the lower diagonal 346 | return mask 347 | 348 | @property 349 | def dtype(self): 350 | return self.visual.conv1.weight.dtype 351 | 352 | def encode_image(self, image): 353 | return self.visual(image.type(self.dtype)) 354 | 355 | def encode_text(self, text): 356 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 357 | 358 | x = 
x + self.positional_embedding[:x.size(1), :].type(self.dtype) 359 | x = x.permute(1, 0, 2) # NLD -> LND 360 | x = self.transformer(x) 361 | x = x.permute(1, 0, 2) # LND -> NLD 362 | x = self.ln_final(x).type(self.dtype) 363 | 364 | # x.shape = [batch_size, n_ctx, transformer.width] 365 | # take features from the eot embedding (eot_token is the highest number in each sequence) 366 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 367 | 368 | return x 369 | 370 | def forward(self, image, text): 371 | image_features = self.encode_image(image) 372 | text_features = self.encode_text(text) 373 | 374 | # normalized features 375 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 376 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 377 | 378 | # cosine similarity as logits 379 | logit_scale = self.logit_scale.exp() 380 | logits_per_image = logit_scale * image_features @ text_features.t() 381 | logits_per_text = logits_per_image.t() 382 | 383 | # shape = [global_batch_size, global_batch_size] 384 | return logits_per_image, logits_per_text 385 | 386 | 387 | def convert_weights(model: nn.Module): 388 | """Convert applicable model parameters to fp16""" 389 | 390 | def _convert_weights_to_fp16(l): 391 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 392 | l.weight.data = l.weight.data.half() 393 | if l.bias is not None: 394 | l.bias.data = l.bias.data.half() 395 | 396 | if isinstance(l, nn.MultiheadAttention): 397 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 398 | tensor = getattr(l, attr) 399 | if tensor is not None: 400 | tensor.data = tensor.data.half() 401 | 402 | for name in ["text_projection", "proj"]: 403 | if hasattr(l, name): 404 | attr = getattr(l, name) 405 | if attr is not None: 406 | attr.data = attr.data.half() 407 | 408 | model.apply(_convert_weights_to_fp16) 409 | 410 | 411 | def build_model(state_dict: dict): 412 | vit = "visual.proj" in state_dict 413 | 414 | if vit: 415 | vision_width = state_dict["visual.conv1.weight"].shape[0] 416 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 417 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 418 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 419 | image_resolution = vision_patch_size * grid_size 420 | else: 421 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 422 | vision_layers = tuple(counts) 423 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 424 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 425 | vision_patch_size = None 426 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 427 | image_resolution = output_width * 32 428 | 429 | embed_dim = state_dict["text_projection"].shape[1] 430 | context_length = state_dict["positional_embedding"].shape[0] 431 | vocab_size = state_dict["token_embedding.weight"].shape[0] 432 | transformer_width = state_dict["ln_final.weight"].shape[0] 433 | transformer_heads = transformer_width // 64 434 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 435 | 436 | model = CLIP( 437 | embed_dim, 438 | image_resolution, vision_layers, vision_width, vision_patch_size, 439 | context_length, vocab_size, 
transformer_width, transformer_heads, transformer_layers 440 | ) 441 | # print("vision width:", vision_width) 442 | # print("vision patch size", vision_patch_size) 443 | 444 | for key in ["input_resolution", "context_length", "vocab_size"]: 445 | if key in state_dict: 446 | del state_dict[key] 447 | 448 | convert_weights(model) 449 | model.load_state_dict(state_dict) 450 | # return model.eval() 451 | return model 452 | --------------------------------------------------------------------------------
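Editor's usage sketch (not a file from the repository): the snippet below shows how the loader and tokenizer defined in model/clip.py are typically exercised. It assumes the package layout listed above (so model/__init__.py re-exports the public names from model/clip.py), the dependencies in requirements.txt installed, the working directory set to the repository root, and network access so that load() can fetch the ViT-B/32 weights into ~/.cache/clip; the prompt strings are purely illustrative.

import torch
from model import available_models, load, tokenize  # re-exported via model/__init__.py

print(available_models())  # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

# Downloads the checkpoint on first use and returns the model plus its image preprocessing transform;
# preprocess(PIL.Image) produces a tensor suitable for model.encode_image().
model, preprocess = load("ViT-B/32", device="cpu", jit=False)

# tokenize() pads every prompt to the 77-token context length used by all CLIP models.
tokens = tokenize(["a photo of a dog", "a photo of a cat"])  # shape [2, 77]

with torch.no_grad():
    text_features = model.encode_text(tokens)  # shape [2, embed_dim]
print(text_features.shape)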