├── data
│   └── .gitkeep
├── images
│   ├── model.png
│   ├── test
│   │   ├── bus.jpeg
│   │   ├── bus2.jpeg
│   │   ├── cat.jpeg
│   │   ├── cat2.jpeg
│   │   ├── cock.jpeg
│   │   ├── dog.jpeg
│   │   ├── dog2.jpeg
│   │   ├── tiger.jpeg
│   │   ├── lake_tree.jpeg
│   │   ├── autumn_car.jpeg
│   │   ├── cute_chick.jpeg
│   │   └── tiger_river.jpeg
│   └── train_loss.png
├── requirements.txt
├── save_bert_checkpoint.py
├── component
│   ├── argument.py
│   ├── datacollator.py
│   ├── dataset.py
│   ├── configuration.py
│   └── model.py
├── train_args
│   └── train_clip.json
├── filter_data.py
├── download_image.py
├── train_clip.py
├── predict_similarity.py
└── README.md

--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
# Ignore everything in this directory
*
# Except this file
!.gitkeep

--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/model.png

--------------------------------------------------------------------------------
/images/test/bus.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/bus.jpeg

--------------------------------------------------------------------------------
/images/test/bus2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/bus2.jpeg

--------------------------------------------------------------------------------
/images/test/cat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/cat.jpeg

--------------------------------------------------------------------------------
/images/test/cat2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/cat2.jpeg

--------------------------------------------------------------------------------
/images/test/cock.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/cock.jpeg

--------------------------------------------------------------------------------
/images/test/dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/dog.jpeg

--------------------------------------------------------------------------------
/images/test/dog2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/dog2.jpeg

--------------------------------------------------------------------------------
/images/train_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/train_loss.png

--------------------------------------------------------------------------------
/images/test/tiger.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/tiger.jpeg
--------------------------------------------------------------------------------
/images/test/lake_tree.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/lake_tree.jpeg

--------------------------------------------------------------------------------
/images/test/autumn_car.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/autumn_car.jpeg

--------------------------------------------------------------------------------
/images/test/cute_chick.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/cute_chick.jpeg

--------------------------------------------------------------------------------
/images/test/tiger_river.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangjianxin1/CLIP-Chinese/HEAD/images/test/tiger_river.jpeg

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
huggingface_hub==0.11.0
loguru==0.6.0
numpy==1.22.4
opencv_python==4.6.0.66
pandas==1.4.2
Pillow==9.3.0
requests==2.27.1
torch==1.12.0
tqdm==4.64.0
transformers==4.18.0

--------------------------------------------------------------------------------
/save_bert_checkpoint.py:
--------------------------------------------------------------------------------
"""
Extract the BERT weights from a trained BertCLIP checkpoint and save them as a standalone model.
"""
from component.model import BertCLIPTextModel
from transformers import BertTokenizerFast

if __name__ == '__main__':
    model_name_or_path = 'output/clip/checkpoint-final'
    save_path = 'output/bert'
    text_model = BertCLIPTextModel.from_pretrained(model_name_or_path)
    tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)
    text_model.text_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
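Note: a quick sanity check for the exported checkpoint (a minimal sketch; it assumes the script above has already written `output/bert`):

```python
from transformers import BertModel, BertTokenizerFast

# Load the standalone BERT exported by save_bert_checkpoint.py.
model = BertModel.from_pretrained('output/bert')
tokenizer = BertTokenizerFast.from_pretrained('output/bert')

# Encode one sentence and check the pooled output shape: (1, hidden_size).
inputs = tokenizer('一只小狗在摇尾巴', return_tensors='pt')
outputs = model(**inputs)
print(outputs.pooler_output.shape)
```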
8 | """ 9 | max_seq_length: int = field(metadata={"help": "输入最大长度"}) 10 | train_file: str = field(metadata={"help": "训练集"}) 11 | test_file: str = field(metadata={"help": "测试集"}) 12 | clip_pretrain_path: str = field(metadata={"help": "clip的预训练权重路径"}) 13 | bert_pretrain_path: str = field(default=False, metadata={"help": "bert的预训练权重路径"}) 14 | image_path: str = field(default=False, metadata={"help": "图片存储路径"}) 15 | load_from_bert_clip: bool = field(default=False, metadata={"help": "是否加载BertCLIPModel的预训练权重"}) 16 | 17 | -------------------------------------------------------------------------------- /train_args/train_clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "output/clip", 3 | "clip_pretrain_path": "openai/clip-vit-base-patch32", 4 | "bert_pretrain_path": "Langboat/mengzi-bert-base", 5 | "load_from_bert_clip": false, 6 | "image_path": "./data/images", 7 | "train_file": "./data/train.csv", 8 | "test_file": null, 9 | "num_train_epochs": 50, 10 | "max_steps": -1, 11 | "per_device_train_batch_size": 768, 12 | "per_device_eval_batch_size": 256, 13 | "learning_rate": 5e-5, 14 | "max_seq_length": 100, 15 | "logging_steps": 500, 16 | "save_steps": 500, 17 | "save_total_limit": 3, 18 | "lr_scheduler_type": "cosine", 19 | "warmup_steps": 1000, 20 | "warmup_ratio": 0, 21 | 22 | "gradient_accumulation_steps": 1, 23 | "optim": "adamw_torch", 24 | "seed": 42, 25 | "fp16": true, 26 | "no_cuda": false, 27 | "dataloader_num_workers": 30, 28 | "save_strategy": "steps", 29 | "weight_decay": 0, 30 | "max_grad_norm": 1.0 31 | } 32 | -------------------------------------------------------------------------------- /component/datacollator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | import torch 3 | 4 | 5 | class CLIPCollator(object): 6 | def __init__(self, clip_processor, max_seq_length): 7 | self.clip_processor = clip_processor 8 | self.max_seq_length = max_seq_length 9 | 10 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 11 | texts, pixel_values_list = [], [] 12 | for data in features: 13 | # 如果图片预处理失败,则跳过该图片 14 | if data['pixel_values'] is None: 15 | continue 16 | texts.append(data['text']) 17 | pixel_values_list.append(data['pixel_values']) 18 | # 进行tokenize 19 | inputs = self.clip_processor( 20 | text=texts, return_tensors="pt", max_length=self.max_seq_length, truncation=True, padding=True 21 | ) 22 | pixel_values_list = torch.concat(pixel_values_list, dim=0) 23 | inputs['return_loss'] = True 24 | inputs['pixel_values'] = pixel_values_list 25 | inputs.pop('token_type_ids') 26 | return inputs 27 | -------------------------------------------------------------------------------- /filter_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对训练数据进行过滤 3 | 下载的图片中,存在gif图片,对此类图片进行删除,并且把该条数据从训练集中删除,得到过滤后的训练集 4 | """ 5 | import pandas as pd 6 | import os 7 | from os.path import join 8 | from tqdm import tqdm 9 | import imghdr 10 | 11 | 12 | def main(): 13 | train_file = './data/train.csv' 14 | image_path = './data/images' 15 | out_file = './data/train-filter.csv' 16 | 17 | result = [] 18 | df = pd.read_csv(train_file) 19 | for _, row in tqdm(df.iterrows()): 20 | filename = row['filename'] 21 | file = join(image_path, filename) 22 | 23 | # 如果存在该图片 24 | if os.path.exists(file): 25 | # 判断图片是否为gif图或者损坏 26 | img_type = imghdr.what(file) 27 | # 图片损坏,或者为gif图,则跳过 28 | if img_type is None or img_type == 
--------------------------------------------------------------------------------
/filter_data.py:
--------------------------------------------------------------------------------
"""
Filter the training data.
Some downloaded images are GIFs or corrupted; delete such files and drop the
corresponding rows from the training set to produce a filtered training file.
"""
import pandas as pd
import os
from os.path import join
from tqdm import tqdm
import imghdr


def main():
    train_file = './data/train.csv'
    image_path = './data/images'
    out_file = './data/train-filter.csv'

    result = []
    df = pd.read_csv(train_file)
    for _, row in tqdm(df.iterrows()):
        filename = row['filename']
        file = join(image_path, filename)

        # Only keep rows whose image file exists
        if os.path.exists(file):
            # Check whether the image is a GIF or corrupted
            img_type = imghdr.what(file)
            # Corrupted or GIF: delete the file and skip the row
            if img_type is None or img_type == 'gif':
                print('remove file:{}'.format(file))
                os.remove(file)
            else:
                result.append(row)
    print('len of filter data:{}'.format(len(result)))
    df = pd.DataFrame(result)
    df.to_csv(out_file, index=False)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/component/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
from loguru import logger
from os.path import join


class CLIPDataset(Dataset):

    def __init__(self, file, clip_processor, image_path):
        df = pd.read_csv(file, usecols=['text', 'filename'])
        data_list = df.to_dict('records')
        print('len of data:{}'.format(len(data_list)))
        self.data_list = data_list
        self.clip_processor = clip_processor
        self.image_path = image_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        row = self.data_list[index]
        text = row['text'].strip()
        filename = row['filename']
        file = join(self.image_path, filename)
        try:
            image = Image.open(file).convert('RGB')
        except Exception as e:
            # Failed to open the image
            logger.info('open image error, text: {}, filename:{}'.format(text, filename))
            logger.info(e)
            image = None

        if image is None:
            pixel_values = None
        else:
            pixel_values = self.clip_processor(images=image, return_tensors='pt')['pixel_values']
        data = {'pixel_values': pixel_values, 'text': text}
        return data

--------------------------------------------------------------------------------
/download_image.py:
--------------------------------------------------------------------------------
"""
Download the images.
"""
import requests
import pandas as pd
import os
from os.path import join
from tqdm import tqdm
import multiprocessing
from loguru import logger


def download(file, url):
    try:
        headers = {"User-Agent": "Chrome/68.0.3440.106"}
        # A timeout keeps a single stuck request from blocking a worker forever
        response = requests.get(url, headers=headers, timeout=10)
        status_code = response.status_code
        content_type = response.headers.get('Content-Type')
        # Save only successful responses that are not GIFs
        if status_code == 200 and content_type != 'image/gif':
            image = response.content
            with open(file, 'wb') as f:
                f.write(image)
    except Exception as e:
        # Downloading the image failed
        logger.info('downloading image error, url:{}'.format(url))
        logger.info(e)


def main():
    process_num = 20  # number of worker processes
    in_file = './data/train.csv'
    out_path = './data/images'
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    df = pd.read_csv(in_file)
    print(len(df))

    # Initialize the process pool
    pool = multiprocessing.Pool(processes=process_num)
    for _, row in tqdm(df.iterrows()):
        filename = row['filename']
        url = row['url']
        file = join(out_path, filename)
        # Skip files that have already been downloaded
        if os.path.exists(file):
            continue
        pool.apply_async(download, (file, url))  # dispatch asynchronously
    pool.close()
    pool.join()


if __name__ == '__main__':
    main()
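Note: `download` can also be called directly for a single file (a sketch; the URL and target path below are placeholders):

```python
from download_image import download

# Hypothetical URL and target path, for illustration only.
download('./data/images/example.jpg', 'https://example.com/images/example.jpg')
```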
--------------------------------------------------------------------------------
/component/configuration.py:
--------------------------------------------------------------------------------
from transformers import (
    CLIPTextConfig,
    CLIPVisionConfig,
    BertConfig,
    CLIPConfig
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class BertCLIPConfig(CLIPConfig):
    r"""
    [`BertCLIPConfig`] is the configuration class to store the configuration of a [`BertCLIPModel`]. It is used to
    instantiate the model according to the specified arguments, defining the text model and vision model configs.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIP
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture, with the text
    tower replaced by BERT.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`BertConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of text and vision projection layers.
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
        kwargs (*optional*):
            Dictionary of keyword arguments.
    Example:
    ```python
    >>> from transformers import CLIPConfig, CLIPModel

    >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
    >>> configuration = CLIPConfig()

    >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
    >>> model = CLIPModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
    >>> # Initializing a CLIPText and CLIPVision configuration
    >>> config_text = CLIPTextConfig()
    >>> config_vision = CLIPVisionConfig()

    >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
    ```"""

    model_type = "clip"
    is_composition = True

    def __init__(
        self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
    ):
        super().__init__(**kwargs)

        # If `_config_dict` exist, we use them for the backward compatibility.
        text_config_dict = kwargs.pop("text_config_dict", None)
        vision_config_dict = kwargs.pop("vision_config_dict", None)
        if text_config_dict is not None:
            text_config = text_config_dict
        if vision_config_dict is not None:
            vision_config = vision_config_dict

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the BertConfig with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. Initializing the CLIPVisionConfig with default values.")

        # The text tower is BERT, so the text config is a BertConfig
        self.text_config = BertConfig(**text_config)
        self.vision_config = CLIPVisionConfig(**vision_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value
        self.initializer_factor = 1.0
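Note: a minimal sketch of building a `BertCLIPConfig` by hand; both config arguments are plain dicts, and the values below are illustrative library defaults, not the released model's:

```python
from transformers import BertConfig, CLIPVisionConfig
from component.configuration import BertCLIPConfig

# Compose the config from a BERT text config and a CLIP vision config.
config = BertCLIPConfig(
    text_config=BertConfig().to_dict(),          # defaults: bert-base sizes
    vision_config=CLIPVisionConfig().to_dict(),  # defaults: ViT-B/32 sizes
    projection_dim=512,
)
print(type(config.text_config).__name__)  # BertConfig
```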
--------------------------------------------------------------------------------
/train_clip.py:
--------------------------------------------------------------------------------
from component.model import BertCLIPModel
from transformers import (
    CLIPConfig,
    BertModel,
    CLIPFeatureExtractor,
    CLIPProcessor,
    BertTokenizerFast,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    Trainer,
)
from loguru import logger
from component.dataset import CLIPDataset
from component.argument import CLIPArguments
import argparse
import os
import json
from os.path import join
from component.datacollator import CLIPCollator


def load_model_and_processor(clip_pretrain_path, bert_pretrain_path):
    """
    Load the model and the input processor. The text encoder and image encoder come from
    different pretrained models; use this for the first round of in-domain pretraining.
    :param clip_pretrain_path:
    :param bert_pretrain_path:
    :return:
    """
    # Load the BERT model
    bert_model = BertModel.from_pretrained(bert_pretrain_path)
    bert_config = bert_model.config
    # Load the CLIP model
    clip_config = CLIPConfig.from_pretrained(clip_pretrain_path)
    # text_config in CLIPConfig defaults to CLIPTextConfig; replace it with BertConfig
    clip_config.text_config = bert_config
    # Ignore mismatched pretrained weights, mainly because the text encoder was swapped for BERT
    bert_clip_model = BertCLIPModel.from_pretrained(clip_pretrain_path, config=clip_config, ignore_mismatched_sizes=True)
    # Replace CLIP's text encoder with the BERT weights
    setattr(bert_clip_model, 'text_model', bert_model)
    # Freeze the vision_model weights
    for name, param in bert_clip_model.vision_model.named_parameters():
        param.requires_grad = False

    # Verify that the BERT inside the CLIP model is wired up correctly
    logger.info(
        'bert_clip_model data_ptr:{}'.format(bert_clip_model.text_model.embeddings.word_embeddings.weight.data_ptr()))
    logger.info('bert data_ptr:{}'.format(bert_model.embeddings.word_embeddings.weight.data_ptr()))

    # Load the feature_extractor and tokenizer
    feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_pretrain_path)
    tokenizer = BertTokenizerFast.from_pretrained(bert_pretrain_path)
    # note: the library defaults to CLIPTokenizer; set the tokenizer class we actually need
    CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
    clip_processor = CLIPProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    return bert_clip_model, clip_processor


def load_model_and_processor_from_bert_clip(clip_pretrain_path):
    """
    Load the model and the input processor. All model weights come from a BertCLIP checkpoint;
    use this to continue pretraining from a checkpoint already pretrained on in-domain data.
    :param clip_pretrain_path:
    """
    # Load the model
    model = BertCLIPModel.from_pretrained(clip_pretrain_path)
    # Freeze the vision_model weights
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = False

    # note: the library defaults to CLIPTokenizer; set the tokenizer class we actually need
    CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
    clip_processor = CLIPProcessor.from_pretrained(clip_pretrain_path)
    return model, clip_processor


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_args_file", type=str, default='train_args/train_clip.json', help="")
    args = parser.parse_args()
    train_args_file = args.train_args_file
    # Read the training configuration
    parser = HfArgumentParser((CLIPArguments, TrainingArguments))
    args, training_args = parser.parse_json_file(json_file=train_args_file)
    # Create the output directory
    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir)
    # Record the training arguments
    with open(train_args_file, 'r', encoding='utf8') as f:
        train_args = json.load(f)
    with open(join(training_args.output_dir, 'train_args.json'), 'w', encoding='utf8') as f:
        json.dump(train_args, f, indent=2)
    # Set the random seed
    set_seed(training_args.seed)
    # A BertCLIP pretrained checkpoint already exists: load it directly
    if args.load_from_bert_clip:
        bert_clip_model, clip_processor = load_model_and_processor_from_bert_clip(args.clip_pretrain_path)
    # The vision encoder and text encoder are loaded from different pretrained weights
    else:
        bert_clip_model, clip_processor = load_model_and_processor(args.clip_pretrain_path, args.bert_pretrain_path)
    # Load the dataset
    train_dataset = CLIPDataset(args.train_file, clip_processor, args.image_path)
    # Initialize the collator
    data_collator = CLIPCollator(clip_processor=clip_processor, max_seq_length=args.max_seq_length)

    # Initialize the trainer
    # tokenizer is set to clip_processor only so that the processor config is saved
    # alongside the model checkpoints; it has no other effect.
    trainer = Trainer(
        model=bert_clip_model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=clip_processor
    )

    # Start training
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    trainer.save_model(join(training_args.output_dir, 'checkpoint-final'))

    # Evaluate on the test set
    if args.test_file is not None:
        logger.info("*** start test ***")
        test_dataset = CLIPDataset(args.test_file, clip_processor, args.image_path)
        metrics = trainer.evaluate(test_dataset)
        trainer.log_metrics("test", metrics)
        trainer.save_metrics("test", metrics)


if __name__ == '__main__':
    main()
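Note: to see the effect of the LiT-tuning freeze, a minimal sketch that counts trainable parameters; it assumes `bert_clip_model` was returned by `load_model_and_processor` above:

```python
# With LiT-tuning, the frozen vision tower contributes no trainable parameters.
trainable = sum(p.numel() for p in bert_clip_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in bert_clip_model.parameters() if not p.requires_grad)
print(f'trainable: {trainable}, frozen: {frozen}')
```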
--------------------------------------------------------------------------------
/predict_similarity.py:
--------------------------------------------------------------------------------
"""
Script for computing image-text similarity, text-text similarity, and image-image similarity.
"""
import torch

from component.model import BertCLIPModel
from transformers import CLIPProcessor
from PIL import Image


def load_model_and_processor(model_name_or_path):
    # Load the model
    model = BertCLIPModel.from_pretrained(model_name_or_path)
    # note: the library defaults to CLIPTokenizer; set the tokenizer class we actually need
    CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
    processor = CLIPProcessor.from_pretrained(model_name_or_path)
    return model, processor


def process_data(texts, image_files, clip_processor):
    # If images need preprocessing, read the files
    if image_files is not None:
        images = [Image.open(x).convert('RGB') for x in image_files]
    else:
        images = None
    # Preprocess
    inputs = clip_processor(images=images, text=texts, return_tensors='pt', padding=True)
    if 'token_type_ids' in inputs:
        inputs.pop('token_type_ids')
    return inputs
def cal_image_text_sim(model, clip_processor):
    """
    Compute the similarity between each image and all candidate texts.
    """
    print('-------------- image-text similarity --------------')
    texts = [
        '秋天跑车唯美图片桌面壁纸', '可爱的小鸡', '一群可爱的小黄鸡在篮子里', '一只小狗', '一只可爱的小猫', '清澈的湖水,蓝蓝的天空,茂密的树木',
        '冬日里,一只老虎在雪地玩耍', '一只老虎在河边喝水', '一辆公交车停在路边', '一只公鸡在打鸣'
    ]
    image_files = [
        './images/test/autumn_car.jpeg', './images/test/bus.jpeg', './images/test/cat.jpeg', './images/test/cock.jpeg',
        './images/test/cute_chick.jpeg', './images/test/dog.jpeg', './images/test/lake_tree.jpeg',
        './images/test/tiger.jpeg', './images/test/tiger_river.jpeg'
    ]
    # Feature processing
    inputs = process_data(texts, image_files, clip_processor)

    with torch.no_grad():
        out = model(**inputs)

    # For each image, its similarity to all texts
    logits_per_image = out.logits_per_image
    # Softmax over the scores
    logits_per_image = torch.softmax(logits_per_image, dim=-1)
    # For each image, sort its similarities to all texts in descending order
    logits_per_image = logits_per_image.numpy().tolist()
    for scores, file in zip(logits_per_image, image_files):
        result = sorted([(text, score) for text, score in zip(texts, scores)], key=lambda x: x[1], reverse=True)
        print('similarity between image and all candidate texts: {}'.format(file))
        print(result)
        print()


def cal_text_text_sim(model, clip_processor):
    """
    Compute the similarity between texts.
    """
    print('-------------- text-text similarity --------------')
    texts = [
        '桑巴军团', '巴西', '日耳曼战车',
        '德国', '一个圆圆的月亮高高挂在天空', '夜幕中的白玉盘升起,星光灿烂', '小猪', '佩奇', '足球场', '绿茵',
        '雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', '北国风光,千里冰封,万里雪飘,银装素裹',
        '大漠沙如雪,燕山月似钩', '月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般',
        '天街小雨润如酥,草色摇看近却无', '长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏',
        '一只老虎在草原上追捕一只小鹿', '大猫在飞速狂奔,捕杀猎物', '英雄联盟', 'lol'
    ]
    inputs = process_data(texts, None, clip_processor)
    with torch.no_grad():
        text_embeds = model.get_text_features(**inputs)
        # normalized features
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        # Compute pairwise text-text similarities
        logit_scale = model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, text_embeds.t()) * logit_scale

        # For each text, set its own score to -10000
        # (otherwise every text would always be most similar to itself)
        batch_size = logits_per_text.size(0)
        eye = torch.eye(batch_size) * -10000
        logits_per_text = logits_per_text + eye
        # Softmax over the scores
        logits_per_text = torch.softmax(logits_per_text, dim=-1)

    # For each text, sort its similarities to all texts in descending order
    logits_per_text = logits_per_text.numpy().tolist()
    for scores, text in zip(logits_per_text, texts):
        result = sorted([(text, score) for text, score in zip(texts, scores)], key=lambda x: x[1], reverse=True)
        print('similarity between text and all candidate texts: {}'.format(text))
        print(result)
        print()
def cal_image_image_sim(model, clip_processor):
    """
    Compute the similarity between images.
    """
    print('-------------- image-image similarity --------------')
    image_files = [
        './images/test/bus.jpeg', './images/test/bus2.jpeg',
        './images/test/cat.jpeg', './images/test/cat2.jpeg', './images/test/cock.jpeg',
        './images/test/cute_chick.jpeg', './images/test/dog.jpeg', './images/test/dog2.jpeg',
        './images/test/tiger.jpeg', './images/test/tiger_river.jpeg', './images/test/autumn_car.jpeg'
    ]
    # Feature processing
    inputs = process_data(None, image_files, clip_processor)

    with torch.no_grad():
        image_embeds = model.get_image_features(**inputs)
        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        # Compute pairwise image-image similarities
        logit_scale = model.logit_scale.exp()
        logits_per_image = torch.matmul(image_embeds, image_embeds.t()) * logit_scale

        # For each image, set its own score to -10000
        batch_size = logits_per_image.size(0)
        eye = torch.eye(batch_size) * -10000
        logits_per_image = logits_per_image + eye
        # Softmax over the scores
        logits_per_image = torch.softmax(logits_per_image, dim=-1)

    # For each image, sort its similarities to all images in descending order
    logits_per_image = logits_per_image.numpy().tolist()
    for scores, image in zip(logits_per_image, image_files):
        result = sorted([(image, score) for image, score in zip(image_files, scores)], key=lambda x: x[1], reverse=True)
        print('similarity between image and all candidate images: {}'.format(image))
        print(result)
        print()


def main():
    model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
    # Load the model
    model, clip_processor = load_model_and_processor(model_name_or_path)
    # Predict similarities
    cal_image_text_sim(model, clip_processor)
    cal_text_text_sim(model, clip_processor)
    cal_image_image_sim(model, clip_processor)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/component/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from typing import Optional, Tuple, Union
from transformers import (
    BertModel,
    CLIPPreTrainedModel,
    CLIPVisionConfig,
    BertConfig
)
from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPVisionTransformer
from .configuration import BertCLIPConfig


class BertCLIPModel(CLIPPreTrainedModel):
    config_class = BertCLIPConfig

    def __init__(self, config: BertCLIPConfig):
        super(BertCLIPModel, self).__init__(config)

        if not isinstance(config.text_config, BertConfig):
            raise ValueError(
                f"config.text_config is expected to be of type BertConfig but is of type {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                f"config.vision_config is expected to be of type CLIPVisionConfig but is of type {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Swap the text encoder for BERT
        self.text_model = BertModel(text_config)
        self.vision_model = CLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)

        # Initialize weights and apply final processing
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`BertModel`].
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


class BertCLIPTextModel(CLIPPreTrainedModel):
    config_class = BertCLIPConfig

    def __init__(self, config: BertCLIPConfig):
        super().__init__(config)
        bert_config = config.text_config
        self.text_model = BertModel(bert_config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # BertEmbeddings stores its token embeddings as `word_embeddings`
        return self.text_model.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.text_model.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
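Note: `forward` delegates the contrastive objective to `clip_loss` from `transformers.models.clip.modeling_clip`. For reference, it is equivalent to the following symmetric cross-entropy (a sketch of the upstream behavior, not code in this repo):

```python
import torch
import torch.nn.functional as F


def clip_loss_reference(logits_per_text: torch.Tensor) -> torch.Tensor:
    # Matched image-text pairs sit on the diagonal, so the target for row i
    # is index i, in both the text->image and image->text directions.
    labels = torch.arange(len(logits_per_text), device=logits_per_text.device)
    caption_loss = F.cross_entropy(logits_per_text, labels)
    image_loss = F.cross_entropy(logits_per_text.t(), labels)
    return (caption_loss + image_loss) / 2.0
```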
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CLIP-Chinese: a Chinese multimodal contrastive-learning CLIP pretrained model

## Project description
WeChat official account [YeungNLP] article: [CLIP-Chinese:中文多模态对比学习预训练模型](https://mp.weixin.qq.com/s/6gQX91M-Lt7eiMimhYRJEw). The article provides the 1.4M Chinese image-text pretraining pairs as well as the Chinese CLIP pretrained weights.

CLIP is a multimodal contrastive-learning method proposed by OpenAI. The original model was trained on 400 million image-text pairs, achieves strong results on a wide range of downstream tasks, and shows impressive zero-shot performance.
For details, see the [CLIP paper: Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020).

Since the original CLIP model was trained on English corpora, it cannot be used for Chinese tasks; this project addresses that gap. The main contributions are:
- Implement a CLIP model with a ViT+BERT architecture (called BertCLIP below), together with a pretraining pipeline.
- Pretrain BertCLIP on 1.4 million Chinese image-text pairs using LiT-tuning (Locked-image Text tuning).
- Validate the pretrained model on image-text similarity, text-text similarity, and image-image similarity tasks.
- Share the 1.4M Chinese image-text pairs and the pretrained model weights.

## Pretrained weights
See below for how to use the pretrained weights.

| Pretrained model | Model name | Address |
|----------|---------|-------------------------------------------------------------|
| Full BertCLIP weights | YeungNLP/clip-vit-bert-chinese-1M | https://huggingface.co/YeungNLP/clip-vit-bert-chinese-1M |
| BERT weights only | YeungNLP/bert-from-clip-chinese-1M | https://huggingface.co/YeungNLP/bert-from-clip-chinese-1M |

## Environment
python==3.8, transformers==4.18.0, torch==1.12.0


## Project structure
- data: stores the training data
- images: stores images for this README
  - test: stores some test images
- component: core modules
  - argument.py: custom training arguments
  - configuration.py: model config
  - datacollator.py: batch collator
  - dataset.py: dataset
  - model.py: model architecture
- train_args: training argument config files
- download_image.py: script for downloading the images
- filter_data.py: script for filtering the training data
- train_clip.py: training script
- predict_similarity.py: script for computing image-text, text-text, and image-image similarity
- save_bert_checkpoint.py: script for exporting the BERT weights from a BertCLIP checkpoint

## Model and training details
The BertCLIP model implemented here has a ViT+BERT architecture; apart from the text tower, it is largely identical to the original CLIP, as shown below.

![model](images/model.png)

For pretraining, the ViT and BERT towers are initialized from different pretrained weights: ViT from OpenAI's CLIP weights, and BERT from the Mengzi Chinese pretrained weights.

Training uses the LiT-tuning (Locked-image Text tuning) strategy: the ViT weights are frozen and all remaining parameters are trained. With 1.4 million Chinese image-text pairs (bad images filtered out), batch size 768, 1000 warmup steps, learning rate 5e-5 with cosine decay, and 50 training epochs (about 73100 steps), the training loss drops to around 0.23.

![model](images/train_loss.png)
## Usage

### Quick Start
The following script loads the shared pretrained weights, preprocesses an image and texts, and gets the model outputs.

```python
from transformers import CLIPProcessor
from component.model import BertCLIPModel
from PIL import Image
import requests

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
# Load the pretrained weights
model = BertCLIPModel.from_pretrained(model_name_or_path)
CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
# Initialize the processor
processor = CLIPProcessor.from_pretrained(model_name_or_path)
# Preprocess the inputs
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=["一只小狗在摇尾巴", "一只小猪在吃饭"], images=image, return_tensors="pt", padding=True)
inputs.pop('token_type_ids')  # the model input does not include token_type_ids

outputs = model(**inputs)

# For each image, its similarity to all texts
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)  # normalize the scores

# For each text, its similarity to all images
logits_per_text = outputs.logits_per_text  # text-image similarity scores
probs = logits_per_text.softmax(dim=1)  # normalize the scores

# Text embeddings
text_embeds = outputs.text_embeds
# Image embeddings
image_embeds = outputs.image_embeds
```

Load only the image encoder for downstream tasks:
```python
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPVisionModel

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
model = CLIPVisionModel.from_pretrained(model_name_or_path)
CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
processor = CLIPProcessor.from_pretrained(model_name_or_path)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output
```

Load only the text encoder for downstream tasks:

```python
from component.model import BertCLIPTextModel
from transformers import BertTokenizerFast

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
model = BertCLIPTextModel.from_pretrained(model_name_or_path)
tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)

inputs = tokenizer(["一只小狗在摇尾巴", "一只小猪在吃饭"], padding=True, return_tensors="pt")
inputs.pop('token_type_ids')  # the model input does not include token_type_ids

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output
```

The trained BERT weights are also released separately and can be loaded directly with BertModel for downstream tasks:
```python
from transformers import BertTokenizer, BertModel

model_name_or_path = 'YeungNLP/bert-from-clip-chinese-1M'
tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
model = BertModel.from_pretrained(model_name_or_path)
```


### Getting the training data
You can use the author's 1.4M Chinese image-text pairs (available via the official-account article) or your own data. The training data is a CSV file in the following format, where filename is the file name under which the image is saved after downloading. A minimal sketch for writing such a file follows the example.
```
text,url,filename
欧美夏季ebay连衣裙 气质圆领通勤绑带收腰连衣裙 zc3730,"https://gimg2.baidu.com/image_search/src=http%3A%2F%2Fcbu01.alicdn.com%2Fimg%2Fibank%2F2020%2F527%2F038%2F17187830725_1528924397.220x220.jpg&refer=http%3A%2F%2Fcbu01.alicdn.com&app=2002&size=f9999,10000&q=a80&n=0&g=0n&fmt=jpeg?sec=1632524815&t=d66159b43fb0335c11898f9764847ea7",test-0.jpg
"曾是名不见经传的王平,为何能够取代魏延,成为蜀汉",https://pic.rmb.bdstatic.com/19539b3b1a7e1daee93b0f3d99b8e795.png,test-1.jpg
女童黄色连衣裙,"https://gimg2.baidu.com/image_search/src=http%3A%2F%2Fa.vpimg2.com%2Fupload%2Fmerchandise%2F227958%2FLYQ-S314186413-3.jpg&refer=http%3A%2F%2Fa.vpimg2.com&app=2002&size=f9999,10000&q=a80&n=0&g=0n&fmt=jpeg?sec=1632501843&t=b0a3b843f9ecebd71fe6f27643c17486",test-2.jpg
```
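A minimal sketch of producing such a CSV with pandas (the rows below are placeholders, not real data):

```python
import pandas as pd

# Placeholder rows; each record needs text, url and a local filename.
rows = [
    {'text': '一只小狗在摇尾巴', 'url': 'https://example.com/0.jpg', 'filename': 'train-0.jpg'},
    {'text': '一只可爱的小猫', 'url': 'https://example.com/1.jpg', 'filename': 'train-1.jpg'},
]
pd.DataFrame(rows).to_csv('./data/train.csv', index=False)
```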
### Downloading the images
Run the download_image.py script to download the images with multiple worker processes; just specify the number of workers, the training file, and the image output directory.

### Configuring the training arguments
Configure the training arguments in train_args/train_clip.json as needed:
- output_dir: training output directory
- clip_pretrain_path: pretrained CLIP weights
- bert_pretrain_path: pretrained BERT weights
- load_from_bert_clip: whether to initialize the model from BertCLIP weights. If False, both clip_pretrain_path and bert_pretrain_path must be set, and the ViT and BERT towers are loaded from different pretrained weights. If True, the whole BertCLIP model is initialized from clip_pretrain_path.
- image_path: directory containing the images
- train_file: training set
- test_file: test set; if None, no evaluation is run
- num_train_epochs: number of training epochs
- max_steps: maximum number of training steps; overrides num_train_epochs
- per_device_train_batch_size: training batch size
- per_device_eval_batch_size: evaluation batch size
- learning_rate: learning rate
- max_seq_length: maximum text length
- logging_steps: how often to log training metrics
- save_steps: how often to save a checkpoint
- save_total_limit: maximum number of checkpoints to keep
- lr_scheduler_type: learning-rate schedule
- warmup_steps: number of warmup steps; overrides warmup_ratio
- warmup_ratio: warmup ratio
- gradient_accumulation_steps: number of gradient accumulation steps
- optim: optimizer
- seed: random seed
- fp16: whether to train with mixed precision; best set to True, which allows a larger batch size and speeds up training
- no_cuda: whether to disable the GPU
- dataloader_num_workers: number of workers for loading training data; set it as high as your machine allows, otherwise reading images becomes the training bottleneck


### Start training
Note: to continue pretraining from the YeungNLP/clip-vit-bert-chinese-1M weights, set load_from_bert_clip=true and clip_pretrain_path="YeungNLP/clip-vit-bert-chinese-1M".
```
CUDA_VISIBLE_DEVICES=0 python train_clip.py --train_args_file train_args/train_clip.json

Run in the background:
CUDA_VISIBLE_DEVICES=0 nohup python train_clip.py --train_args_file train_args/train_clip.json &
```

### Similarity computation
Scripts for computing image-text, text-text, and image-image similarity are provided in predict_similarity.py.


## Results
### Image-text similarity
To compute image-text similarity, we first take the dot product between every image and text embedding. For each image, its similarities to all candidate texts are then softmax-normalized to give the final scores (see the sketch below).
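Concretely, this normalization is a single softmax over each image's row of logits (a minimal sketch; `logits_per_image` is assumed to come from the model, as in the Quick Start):

```python
import torch

# logits_per_image: (num_images, num_texts) dot-product logits from the model.
probs = torch.softmax(logits_per_image, dim=-1)
top_scores, top_idx = probs.topk(k=3, dim=-1)  # top-3 candidate texts per image
```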
| Image | Similarity to candidate texts |
|---------------|-----------|
| ![model](images/test/bus.jpeg) | [('一辆公交车停在路边', 1.0), ('清澈的湖水,蓝蓝的天空,茂密的树木', 5.12310229794366e-08), ('秋天跑车唯美图片桌面壁纸', 8.085075942076969e-10), ('冬日里,一只老虎在雪地玩耍', 4.903254538501933e-11), ('一只小狗', 7.86001212033094e-12), ('一只可爱的小猫', 1.0248470908719165e-12), ('可爱的小鸡', 6.081324679940714e-13), ('一群可爱的小黄鸡在篮子里', 4.469586525434992e-14), ('一只老虎在河边喝水', 3.782940198479535e-15), ('一只公鸡在打鸣', 9.850900002943315e-16)] |
| ![model](images/test/cat.jpeg) | [('一只可爱的小猫', 0.9998341798782349), ('一只小狗', 0.00011115620145574212), ('冬日里,一只老虎在雪地玩耍', 3.2785530493129045e-05), ('可爱的小鸡', 1.3479968401952647e-05), ('一只公鸡在打鸣', 5.406232048699167e-06), ('一只老虎在河边喝水', 1.8825736560756923e-06), ('秋天跑车唯美图片桌面壁纸', 7.272767561516957e-07), ('一群可爱的小黄鸡在篮子里', 3.3080158345910604e-07), ('清澈的湖水,蓝蓝的天空,茂密的树木', 2.4945970622525238e-08), ('一辆公交车停在路边', 3.1998936920324406e-13)] |
| ![model](images/test/lake_tree.jpeg) | [('清澈的湖水,蓝蓝的天空,茂密的树木', 0.9990612864494324), ('秋天跑车唯美图片桌面壁纸', 0.0009054617257788777), ('一只公鸡在打鸣', 3.1990679417504e-05), ('一只老虎在河边喝水', 7.763640610392031e-07), ('一只可爱的小猫', 2.097889790775298e-07), ('冬日里,一只老虎在雪地玩耍', 1.320097595680636e-07), ('一只小狗', 3.0081434232442916e-08), ('一群可爱的小黄鸡在篮子里', 2.7587644169102532e-08), ('一辆公交车停在路边', 1.4087997435296984e-08), ('可爱的小鸡', 2.3810455343498127e-11)] |
| ![model](images/test/tiger.jpeg) | [('冬日里,一只老虎在雪地玩耍', 0.9999402761459351), ('一只老虎在河边喝水', 5.974959640298039e-05), ('一只可爱的小猫', 1.1624400997334305e-08), ('一只小狗', 1.0728960254946518e-11), ('秋天跑车唯美图片桌面壁纸', 2.6702420656554704e-12), ('一只公鸡在打鸣', 1.529327337511377e-13), ('清澈的湖水,蓝蓝的天空,茂密的树木', 4.067204540281373e-14), ('可爱的小鸡', 5.289698732746477e-15), ('一辆公交车停在路边', 6.407785717133061e-17), ('一群可爱的小黄鸡在篮子里', 5.284812596720461e-17)] |
| ![model](images/test/tiger_river.jpeg) | [('一只老虎在河边喝水', 0.9969038367271423), ('冬日里,一只老虎在雪地玩耍', 0.0030961050651967525), ('一只可爱的小猫', 6.944087971305635e-09), ('一只小狗', 3.5471511838913727e-10), ('清澈的湖水,蓝蓝的天空,茂密的树木', 1.8006697521943948e-10), ('一只公鸡在打鸣', 3.4972351403705915e-11), ('可爱的小鸡', 3.3940988040936926e-12), ('一群可爱的小黄鸡在篮子里', 2.376999638786792e-12), ('一辆公交车停在路边', 2.276026318370067e-13), ('秋天跑车唯美图片桌面壁纸', 2.0756604714091548e-13)] |
| ![model](images/test/autumn_car.jpeg) | [('秋天跑车唯美图片桌面壁纸', 1.0), ('冬日里,一只老虎在雪地玩耍', 9.960791913510292e-11), ('一只公鸡在打鸣', 1.591680606760626e-11), ('一只可爱的小猫', 4.712048893434906e-12), ('一只老虎在河边喝水', 5.603533045558939e-13), ('可爱的小鸡', 9.460436448983922e-14), ('一辆公交车停在路边', 9.048587345985432e-14), ('一只小狗', 5.001745647162641e-15), ('一群可爱的小黄鸡在篮子里', 1.828375462031742e-15), ('清澈的湖水,蓝蓝的天空,茂密的树木', 7.682980915206854e-18)] |
| ![model](images/test/cock.jpeg) | [('一只公鸡在打鸣', 0.9975091218948364), ('可爱的小鸡', 0.0022025061771273613), ('一群可爱的小黄鸡在篮子里', 0.00028838840080425143), ('秋天跑车唯美图片桌面壁纸', 6.824043552455805e-09), ('一只老虎在河边喝水', 4.110817908298259e-09), ('一只小狗', 2.337234850102732e-09), ('一只可爱的小猫', 1.6396863866674494e-09), ('清澈的湖水,蓝蓝的天空,茂密的树木', 2.0205015438534701e-10), ('冬日里,一只老虎在雪地玩耍', 4.627530997280971e-11), ('一辆公交车停在路边', 2.0185879335242463e-13)] |
| ![model](images/test/cute_chick.jpeg) | [('一群可爱的小黄鸡在篮子里', 0.8838089108467102), ('可爱的小鸡', 0.07804790884256363), ('一只公鸡在打鸣', 0.03811056911945343), ('一只小狗', 3.069013109779917e-05), ('一只可爱的小猫', 1.8627710005603149e-06), ('清澈的湖水,蓝蓝的天空,茂密的树木', 3.4984658725534246e-08), ('秋天跑车唯美图片桌面壁纸', 1.3271076459986375e-09), ('一只老虎在河边喝水', 1.7967190235612662e-11), ('冬日里,一只老虎在雪地玩耍', 6.9594542802253745e-12), ('一辆公交车停在路边', 1.7240564512588548e-14)] |
| ![model](images/test/dog.jpeg) | [('一只小狗', 0.9999330043792725), ('一只可爱的小猫', 6.655451579717919e-05), ('可爱的小鸡', 3.337503642342199e-07), ('秋天跑车唯美图片桌面壁纸', 1.249009784487498e-07), ('冬日里,一只老虎在雪地玩耍', 1.2343871702569231e-08), ('清澈的湖水,蓝蓝的天空,茂密的树木', 3.481111399139536e-09), ('一只公鸡在打鸣', 2.925292993949391e-11), ('一辆公交车停在路边', 1.3085215203045841e-11), ('一只老虎在河边喝水', 2.5823388566381666e-12), ('一群可爱的小黄鸡在篮子里', 1.0345113437768005e-12)] |

### Text-text similarity
To compute text-text similarity, we first take the dot product between every pair of text embeddings. Each text's similarity to itself is set to -10000 (otherwise every text would always be most similar to itself), and its similarities to all texts are then softmax-normalized to give the final scores (see the sketch below).
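The self-masking step amounts to adding a large negative value on the diagonal before the softmax (a minimal sketch; `text_embeds` is assumed to be L2-normalized and `logit_scale` to come from the model):

```python
import torch

# Pairwise text-text logits; the diagonal is each text with itself.
logits = text_embeds @ text_embeds.t() * logit_scale
# Mask the diagonal so a text is never its own nearest neighbour.
logits = logits + torch.eye(logits.size(0), device=logits.device) * -10000
probs = torch.softmax(logits, dim=-1)
```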
| Text | Similarity to candidate texts |
|--------------------------|-------|
| 桑巴军团 | [('巴西', 0.6179894804954529), ('佩奇', 0.37836360931396484), ('足球场', 0.0035378895699977875), ('日耳曼战车', 0.00010510809806874022), ('绿茵', 2.4702653718122747e-06), ('德国', 1.4552163065673085e-06), ('一个圆圆的月亮高高挂在天空', 1.4657725699862567e-08), ('北国风光,千里冰封,万里雪飘,银装素裹', 8.691507069613635e-09), ('大猫在飞速狂奔,捕杀猎物', 1.1898879659355543e-09), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 3.0943728135390813e-10), ('小猪', 2.8596228163202397e-10), ('天街小雨润如酥,草色摇看近却无', 2.0656101828997464e-10), ('一只老虎在草原上追捕一只小鹿', 4.377333912009007e-11), ('夜幕中的白玉盘升起,星光灿烂', 1.956253908863559e-11), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 5.170562328987716e-12), ('大漠沙如雪,燕山月似钩', 4.7753528562011205e-12), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 3.951113825596548e-16), ('桑巴军团', 0.0)] |
| 日耳曼战车 | [('德国', 0.9601516723632812), ('足球场', 0.0380559079349041), ('桑巴军团', 0.0007087080157361925), ('佩奇', 0.0005535364616662264), ('绿茵', 0.00036988715874031186), ('大猫在飞速狂奔,捕杀猎物', 0.00015183206414803863), ('一只老虎在草原上追捕一只小鹿', 8.405648259213194e-06), ('巴西', 1.0465098654321991e-07), ('北国风光,千里冰封,万里雪飘,银装素裹', 4.625822391801648e-09), ('夜幕中的白玉盘升起,星光灿烂', 1.3382804864292552e-09), ('小猪', 5.867449304197692e-10), ('一个圆圆的月亮高高挂在天空', 6.638854049834109e-11), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 2.6341435976906524e-11), ('天街小雨润如酥,草色摇看近却无', 2.48930927954083e-11), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 1.732185575531453e-11), ('大漠沙如雪,燕山月似钩', 3.200957565414192e-13), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 3.148795785744285e-13), ('日耳曼战车', 0.0)] |
| 一个圆圆的月亮高高挂在天空 | [('夜幕中的白玉盘升起,星光灿烂', 0.7875770330429077), ('大漠沙如雪,燕山月似钩', 0.19447773694992065), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 0.017945053055882454), ('天街小雨润如酥,草色摇看近却无', 1.723899032413101e-07), ('北国风光,千里冰封,万里雪飘,银装素裹', 2.9736675344338437e-08), ('绿茵', 4.084741433985073e-09), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 1.6437875505204147e-09), ('佩奇', 5.808487579805899e-10), ('德国', 2.9951585656107227e-10), ('桑巴军团', 1.1388523457611655e-10), ('小猪', 4.8488959375481144e-11), ('大猫在飞速狂奔,捕杀猎物', 2.1691641538534867e-11), ('足球场', 3.278335285530898e-12), ('一只老虎在草原上追捕一只小鹿', 1.7711488742647163e-12), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 1.8018553771433077e-13), ('日耳曼战车', 7.650024141719544e-14), ('巴西', 1.2753736123292427e-14), ('一个圆圆的月亮高高挂在天空', 0.0)] |
| 小猪 | [('佩奇', 0.9999858140945435), ('天街小雨润如酥,草色摇看近却无', 1.3750308426097035e-05), ('绿茵', 2.648558847795357e-07), ('北国风光,千里冰封,万里雪飘,银装素裹', 1.0770643399382607e-07), ('足球场', 6.4896809703896e-08), ('一只老虎在草原上追捕一只小鹿', 1.265229077063168e-08), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 4.490777971710713e-09), ('夜幕中的白玉盘升起,星光灿烂', 3.1008255962916564e-09), ('巴西', 2.9627589270830867e-09), ('一个圆圆的月亮高高挂在天空', 1.0077703116451175e-09), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 3.178841634365881e-10), ('大猫在飞速狂奔,捕杀猎物', 2.389905495725486e-10), ('德国', 1.0391849880608817e-10), ('桑巴军团', 4.6177273810288355e-11), ('日耳曼战车', 1.4051987004548572e-11), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 6.461968953637084e-14), ('大漠沙如雪,燕山月似钩', 1.0214997115320576e-15), ('小猪', 0.0)] |
| 足球场 | [('绿茵', 0.999913215637207), ('日耳曼战车', 4.463562436285429e-05), ('桑巴军团', 2.7979182050330564e-05), ('佩奇', 1.4129162082099356e-05), ('北国风光,千里冰封,万里雪飘,银装素裹', 1.4699661043948709e-08), ('天街小雨润如酥,草色摇看近却无', 3.888342092750463e-09), ('小猪', 3.1782971809946048e-09), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 3.047165020308995e-10), ('巴西', 2.8028558640702528e-11), ('夜幕中的白玉盘升起,星光灿烂', 2.1283181814157892e-11), ('德国', 1.7147439718145918e-11), ('大猫在飞速狂奔,捕杀猎物', 3.996368340419831e-12), ('一个圆圆的月亮高高挂在天空', 3.3369006342820473e-12), ('一只老虎在草原上追捕一只小鹿', 1.9575203816842718e-13), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 1.5908632308085473e-14), ('大漠沙如雪,燕山月似钩', 1.2009468177814333e-15), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 1.0192180694699505e-16), ('足球场', 0.0)] |
| 雪花漫天飞舞,狂风怒号,天地之间白茫茫一片 | [('北国风光,千里冰封,万里雪飘,银装素裹', 1.0), ('天街小雨润如酥,草色摇看近却无', 3.117766895349705e-12), ('一个圆圆的月亮高高挂在天空', 3.364900724592626e-14), ('大猫在飞速狂奔,捕杀猎物', 1.6608293380978786e-14), ('足球场', 6.128196976024243e-15), ('小猪', 4.423127276497862e-15), ('夜幕中的白玉盘升起,星光灿烂', 1.571655234070251e-15), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 2.2122742988916772e-16), ('大漠沙如雪,燕山月似钩', 1.76497022842987e-16), ('佩奇', 1.389188296545236e-16), ('桑巴军团', 4.921529099193707e-17), ('巴西', 1.3619025032692556e-17), ('绿茵', 7.592168064477108e-18), ('日耳曼战车', 6.21349124828395e-19), ('德国', 5.029247784564751e-19), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 2.3105380141441713e-19), ('一只老虎在草原上追捕一只小鹿', 2.065503593757295e-19), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 0.0)] |
| 大漠沙如雪,燕山月似钩 | [('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 0.9999990463256836), ('夜幕中的白玉盘升起,星光灿烂', 8.503350841237989e-07), ('北国风光,千里冰封,万里雪飘,银装素裹', 6.838811117404475e-08), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 1.222678347456707e-10), ('天街小雨润如酥,草色摇看近却无', 1.3596265606430347e-11), ('佩奇', 8.439372247912025e-13), ('桑巴军团', 5.26145750232021e-13), ('lol', 2.938053988467415e-13), ('英雄联盟', 4.792821074370811e-14), ('德国', 3.41292274931744e-14), ('足球场', 1.6731451971509562e-14), ('巴西', 1.5801006607741447e-14), ('绿茵', 1.5472469632555816e-14), ('日耳曼战车', 5.230573715445759e-15), ('一只老虎在草原上追捕一只小鹿', 1.527517829496487e-15), ('小猪', 6.969782548515633e-16), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 6.336404662481165e-17), ('大猫在飞速狂奔,捕杀猎物', 3.555133268119002e-17), ('大漠沙如雪,燕山月似钩', 0.0)] |
| 天街小雨润如酥,草色摇看近却无 | [('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 0.9994089603424072), ('夜幕中的白玉盘升起,星光灿烂', 0.0003134367580059916), ('北国风光,千里冰封,万里雪飘,银装素裹', 0.0002538462576922029), ('小猪', 1.5032141163828783e-05), ('一个圆圆的月亮高高挂在天空', 3.916868081432767e-06), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 3.4605425298650516e-06), ('绿茵', 9.778947287486517e-07), ('大猫在飞速狂奔,捕杀猎物', 1.992641358583569e-07), ('足球场', 8.67964047301939e-08), ('一只老虎在草原上追捕一只小鹿', 9.750069196456934e-09), ('佩奇', 5.847227146915657e-09), ('德国', 9.09259750825342e-11), ('桑巴军团', 3.646501156584492e-11), ('巴西', 2.7616209666292413e-11), ('大漠沙如雪,燕山月似钩', 2.1784638329358508e-11), ('日耳曼战车', 6.517418982970868e-13), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 3.130157401722358e-14), ('天街小雨润如酥,草色摇看近却无', 0.0)] |
| 一只老虎在草原上追捕一只小鹿 | [('大猫在飞速狂奔,捕杀猎物', 0.9999905824661255), ('日耳曼战车', 8.540602721041068e-06), ('小猪', 5.367821813706541e-07), ('天街小雨润如酥,草色摇看近却无', 3.7838006505808153e-07), ('佩奇', 9.640578113589982e-09), ('一个圆圆的月亮高高挂在天空', 1.5617185322724936e-09), ('绿茵', 1.2935598148189342e-09), ('桑巴军团', 2.998873926962631e-10), ('足球场', 1.6957589499266845e-10), ('北国风光,千里冰封,万里雪飘,银装素裹', 1.6597206248247787e-11), ('德国', 1.4758482630439218e-11), ('夜幕中的白玉盘升起,星光灿烂', 8.974962266428133e-12), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 8.89707058027156e-12), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 1.3045043560991343e-12), ('大漠沙如雪,燕山月似钩', 9.498044711842013e-14), ('巴西', 7.571923638224551e-14), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 5.642470591102474e-15), ('一只老虎在草原上追捕一只小鹿', 0.0)] |
| 英雄联盟 | [('lol', 1.0), ('绿茵', 8.241636929531447e-16), ('足球场', 5.327819702000563e-17), ('桑巴军团', 1.4594147724981305e-17), ('小猪', 5.374621539699783e-19), ('德国', 1.2981428736465124e-19), ('北国风光,千里冰封,万里雪飘,银装素裹', 1.0041148040720617e-19), ('佩奇', 1.3570826911534407e-20), ('一只老虎在草原上追捕一只小鹿', 8.53713238494233e-21), ('夜幕中的白玉盘升起,星光灿烂', 5.394499158339701e-21), ('巴西', 1.3131330665565925e-21), ('天街小雨润如酥,草色摇看近却无', 1.1178349363340949e-21), ('日耳曼战车', 7.541350209812422e-22), ('雪花漫天飞舞,狂风怒号,天地之间白茫茫一片', 5.246146658079659e-22), ('长安街上细密的春雨润滑如酥,远望草色连成一片,近看却又显得稀疏', 1.330888747811307e-22), ('大漠沙如雪,燕山月似钩', 1.0396770843343703e-22), ('月光洒在沙滩上,就像铺上了一层白皑皑的雪。燕山上,月亮像钩子一般', 4.255491910806315e-25), ('大猫在飞速狂奔,捕杀猎物', 2.597710579245674e-25), ('英雄联盟', 0.0)] |


### Image-image similarity
Computed in the same way as text-text similarity. For readability, only the top-1 image and its similarity score are shown.

Note: since the image encoder is frozen while training BertCLIP, this capability mainly comes from OpenAI's pretrained CLIP weights.

| Image | Top-1 image | Top-1 score |
|---------------|--------------------------------|----------|
| ![model](images/test/bus.jpeg) | ![model](images/test/bus2.jpeg) | 1.0 |
| ![model](images/test/cat.jpeg) | ![model](images/test/cat2.jpeg) | 0.9999992847442627 |
| ![model](images/test/cock.jpeg) | ![model](images/test/cute_chick.jpeg) | 0.9951345324516296 |
| ![model](images/test/dog.jpeg) | ![model](images/test/dog2.jpeg) | 0.9999798536300659 |
| ![model](images/test/tiger.jpeg) | ![model](images/test/tiger_river.jpeg) | 1.0 |

--------------------------------------------------------------------------------