├── README.md
├── cpp
│   ├── CMakeLists.txt
│   ├── include
│   │   ├── Tokenizer.hpp
│   │   └── util.hpp
│   ├── myvocab.txt
│   ├── pokemon.jpeg
│   └── src
│       ├── main.cpp
│       └── util.cpp
├── python
│   ├── clip_tokenizer.py
│   ├── main.py
│   ├── myvocab.txt
│   └── pokemon.jpeg
└── testimgs
    ├── 000000005082.jpg
    ├── 000000014309.jpg
    ├── 000000024283.jpg
    ├── 000000025111.jpg
    ├── 000000043121.jpg
    ├── 000000047693.jpg
    ├── 000000057110.jpg
    ├── 000000060398.jpg
    ├── 000000060621.jpg
    ├── 000000064549.jpg
    ├── 000000064840.jpg
    ├── 000000075717.jpg
    ├── 000000076516.jpg
    └── 000000076706.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
CLIP, released by OpenAI, is a multimodal model trained on images and text in parallel: it maps images and text into the same feature space.
For image recognition, we map the image to be recognized into a feature vector. At the same time, we turn every candidate class label into a sentence and map each sentence into its own feature vector.

The class whose text feature vector lies closest to the image feature vector is the class of the target image.
CLIP can be used in several ways:
(1). Compare an image against class-label texts to get the closest classification result.
(2). Given an input text and a photo album, find the images in the album that best match the text.
(3). For image generation models, use the CLIP distance between the generated result and the input text to produce better images; diffusion models, for example, make use of CLIP.
(4). Extract features with the CLIP encoder and feed their projection into GPT-2 to generate captions.
(5). Search for similar images.

The one that interests me most is the second: describe the picture you want in a sentence, and the matching images are retrieved from the gallery. Software with this feature is already on the market, for example:
https://github.com/mazzzystar/Queryable

The CLIP deployment programs on GitHub all use English text, so I wrote a Chinese CLIP text-to-image retrieval program.
The ONNX files for the image module and the text module are on Baidu Netdisk, link: https://pan.baidu.com/s/18eBA19kMqdJpP5muV9V18w
extraction code: d30y

This program comes in a C++ version and a Python version, and while writing the C++ version I ran into a pitfall. At first, on Windows 10, the compiled program produced garbled characters in TokenizerClipChinese when splitting a Chinese string into its individual characters; after moving the program to Ubuntu, the garbling disappeared.
The function split_chinese, which splits a Chinese string into individual characters, comes from https://github.com/Shellbye/Shellbye.github.io/issues/27
From the pitfall above you can see that split_chinese only works correctly on Linux; on Windows 10 it produces garbled output.
Also note that std::filesystem requires compiling with C++17. The C++ program in this repository must be compiled and run on Ubuntu to get correct results.

The program is very bare-bones. Interested developers could add a GUI that displays the input text alongside the matching images retrieved from the gallery, which would be more intuitive.
In addition, the faiss library can be used to store the feature vectors and to compute the similarity between them; it is very fast (millisecond level), which matters because a real-world album may contain millions or even hundreds of millions of images.
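To make the second use case concrete, the ranking step that both the C++ and Python versions implement boils down to a few lines of NumPy. A minimal sketch, assuming `image_features` is an (N, 512) array of L2-normalized gallery features and `text_feature` is one L2-normalized text feature as produced by the two ONNX models (the function name `rank_gallery` is mine):

```python
import numpy as np

def rank_gallery(text_feature, image_features, imglist, topk=5):
    # Both sides are L2-normalized, so a dot product is cosine similarity;
    # 100 is the logit scale CLIP applies before the softmax.
    logits = 100.0 * image_features @ text_feature   # shape (N,)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                             # softmax over the whole gallery
    order = np.argsort(-probs)                       # descending by probability
    return [(imglist[i], float(probs[i])) for i in order[:topk]]
```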
--------------------------------------------------------------------------------
/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
project(test)
set(CMAKE_CXX_STANDARD 17)

add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/src/main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/util.cpp)

target_include_directories(${PROJECT_NAME}
    PUBLIC "/usr/local/include/opencv4"
    PUBLIC "/opt/onnxruntime-linux-x64-1.11.1/include"
    PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include
)
target_link_libraries(${PROJECT_NAME}
    "/usr/local/lib/libopencv_imgcodecs.so.4.8.0"
    "/usr/local/lib/libopencv_highgui.so.4.8.0"
    "/usr/local/lib/libopencv_imgproc.so.4.8.0"
    "/usr/local/lib/libopencv_core.so.4.8.0"
    "/usr/local/lib/libopencv_dnn.so.4.8.0"
    "/opt/onnxruntime-linux-x64-1.11.1/lib/libonnxruntime.so.1.11.1"
)

--------------------------------------------------------------------------------
/cpp/include/Tokenizer.hpp:
--------------------------------------------------------------------------------
#pragma once
#include <map>
#include <vector>
#include <string>
#include <fstream>
#include <iostream>

// Splits a UTF-8 string into individual code points (e.g. one Chinese character
// per element). For the lead-byte patterns checked below, see
// https://en.wikipedia.org/wiki/UTF-8#Description
std::vector<std::string> split_chinese(std::string s)
{
    std::vector<std::string> t;
    for (size_t i = 0; i < s.length();)
    {
        int cplen = 1;
        if ((s[i] & 0xf8) == 0xf0)      // lead byte 11110xxx: 4-byte code point
            cplen = 4;
        else if ((s[i] & 0xf0) == 0xe0) // lead byte 1110xxxx: 3-byte code point
            cplen = 3;
        else if ((s[i] & 0xe0) == 0xc0) // lead byte 110xxxxx: 2-byte code point
            cplen = 2;
        if ((i + cplen) > s.length())   // guard against a truncated sequence at the end of the string
            cplen = 1;
        t.push_back(s.substr(i, cplen));
        i += cplen;
    }
    return t;
}

class TokenizerBase
{
protected:
    std::map<std::string, int> tokenizer_token2idx;

public:
    virtual bool load_tokenize(std::string vocab_path) = 0;
    virtual void encode_text(std::string text, std::vector<int> &idx) = 0;
};

class TokenizerClip : public TokenizerBase
{
protected:
    std::vector<std::string> stringSplit(const std::string &str, char delim)
    {
        std::vector<std::string> elems;
        auto lastPos = str.find_first_not_of(delim, 0);
        auto pos = str.find_first_of(delim, lastPos);
        while (pos != std::string::npos || lastPos != std::string::npos)
        {
            elems.push_back(str.substr(lastPos, pos - lastPos));
            lastPos = str.find_first_not_of(delim, pos);
            pos = str.find_first_of(delim, lastPos);
        }
        return elems;
    }

    void tokenize(std::string token, std::vector<int> &idx)
    {
        idx.push_back(49406); // <|startoftext|>
        {
            std::vector<std::string> tokens = stringSplit(token, ' ');
            for (auto t : tokens)
            {
                // CLIP's BPE vocabulary marks word-final tokens with "</w>"
                idx.push_back(tokenizer_token2idx[t + "</w>"]);
            }
        }
        idx.push_back(49407); // <|endoftext|>
    }

public:
    bool load_tokenize(std::string vocab_path) override
    {
        std::ifstream infile;
        infile.open(vocab_path.data());
        if (!infile.good())
        {
            return false;
        }

        std::string s;
        int idx = 0;
        while (getline(infile, s))
        {
            tokenizer_token2idx.insert(std::pair<std::string, int>(s, idx));
            idx++;
        }
        infile.close();
        return true;
    }

    void encode_text(std::string text, std::vector<int> &idx) override
    {
        idx.clear();
        return tokenize(text, idx);
    }
};

class TokenizerClipChinese : public TokenizerClip
{
public:
    bool load_tokenize(std::string vocab_path) override
    {
        std::ifstream infile;
        infile.open(vocab_path.data());
        if (!infile.good())
        {
            return false;
        }

        std::string s;
        int idx = 0;
        while (getline(infile, s)) // on Windows 10, parsing a txt file that contains Chinese produces garbled text in C++
        {
            tokenizer_token2idx.insert(std::pair<std::string, int>(s, idx));
            idx++;
        }
        infile.close();
        return true;
    }

    void encode_text(std::string text, std::vector<int> &idx) override
    {
#define CLS 101
#define SEP 102
        idx.clear();
        idx.push_back(CLS);

        for (size_t i = 0; i < text.length();)
        {
            int cplen = 1;
            if ((text[i] & 0xf8) == 0xf0)
                cplen = 4; // 4 bytes, lead bits 11110
            else if ((text[i] & 0xf0) == 0xe0)
                cplen = 3; // 3 bytes, lead bits 1110
            else if ((text[i] & 0xe0) == 0xc0)
                cplen = 2; // 2 bytes, lead bits 110
            if ((i + cplen) > text.length()) // guard against a truncated multi-byte sequence at the end of the string
                cplen = 1;
            auto tmp = text.substr(i, cplen);
            i += cplen;
            idx.push_back(tokenizer_token2idx[tmp]);
        }

        /* An alternative implementation split the text with split_chinese() and
           fell back to per-byte lookup for tokens missing from the vocabulary;
           it is equivalent to the loop above and has been left out of the build. */

        idx.push_back(SEP);
        return;
    }
};
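The lead-byte checks in split_chinese and TokenizerClipChinese::encode_text determine how many bytes each UTF-8 code point occupies (0b11110xxx gives 4, 0b1110xxxx gives 3, 0b110xxxxx gives 2, otherwise 1). A quick way to sanity-check that logic is to mirror it in Python on a byte string; in Python 3, `list(s)` on a `str` already splits by code point, which is why the Python side of this repo never needs an equivalent:

```python
def split_utf8_bytes(b: bytes):
    """Split a UTF-8 byte string into per-code-point chunks,
    mirroring the lead-byte branches in split_chinese()."""
    out, i = [], 0
    while i < len(b):
        lead = b[i]
        if lead & 0xF8 == 0xF0:
            n = 4    # 11110xxx: 4-byte code point
        elif lead & 0xF0 == 0xE0:
            n = 3    # 1110xxxx: 3-byte code point
        elif lead & 0xE0 == 0xC0:
            n = 2    # 110xxxxx: 2-byte code point
        else:
            n = 1    # ASCII (or a stray continuation byte)
        out.append(b[i:i + n].decode("utf-8"))
        i += n
    return out

print(split_utf8_bytes("踢足球的人".encode("utf-8")))  # ['踢', '足', '球', '的', '人']
```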
--------------------------------------------------------------------------------
/cpp/include/util.hpp:
--------------------------------------------------------------------------------
#ifndef UTIL_HPP
#define UTIL_HPP

#include <string>
#include <vector>
#include <tuple>
#include <algorithm>

std::vector<std::string> listdir(const std::string image_dir);
void softmax(std::vector<float> &input); // for a single image; batched inputs are not handled

int write_image_feature_name2bin(int len_feature, const float* output, const std::vector<std::string> imglist, const char* bin_name);
float* read_image_feature_name2bin(int* imgnum, int* len_feature, std::vector<std::string>& imglist, const char* bin_name);

void copyfile_dstpath(std::vector<std::tuple<std::string, float>> imglist, std::string savepath);

// argsort implementation; template definitions usually live in the header
template <typename T> std::vector<int> argsort_ascend(const std::vector<T>& array)
{
    const int array_len(array.size());
    std::vector<int> array_index(array_len, 0);
    for (int i = 0; i < array_len; ++i) // or, without the loop: std::iota(array_index.begin(), array_index.end(), 0);
        array_index[i] = i;

    std::sort(array_index.begin(), array_index.end(),
              [&array](int pos1, int pos2) { return (array[pos1] < array[pos2]); });

    return array_index;
}

#endif

--------------------------------------------------------------------------------
/cpp/pokemon.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/cpp/pokemon.jpeg

--------------------------------------------------------------------------------
/cpp/src/main.cpp:
--------------------------------------------------------------------------------
#define _CRT_SECURE_NO_WARNINGS
#include <fstream>
#include <iostream>
#include <cmath>
#include <memory>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
//#include <cuda_provider_factory.h>
#include <onnxruntime_cxx_api.h>
#include "Tokenizer.hpp"
#include "util.hpp"

using namespace cv;
using namespace std;
using namespace dnn;
using namespace Ort;

typedef struct
{
    string name;
    float prob;
} class_info;

class Clip
{
public:
    Clip(string image_modelpath, string text_modelpath, string vocab_path);
    void generate_image_feature(Mat cv_image);
    void generate_text_feature(std::vector<string> texts);
    class_info zero_shot_image_classify(Mat cv_image, std::vector<string> texts);
    void generate_imagedir_features(const std::string image_dir, const char* bin_name);
    std::vector<std::tuple<std::string, float>> input_text_search_image(std::string text, const float* image_features, const std::vector<std::string> imglist);

private:
    Net net; // image_model
    Mat normalize_(Mat img);
    const int inpWidth = 224;
    const int inpHeight = 224;
    float mean[3] = { 0.48145466, 0.4578275, 0.40821073 };
    float std[3] = { 0.26862954, 0.26130258, 0.27577711 };

    std::shared_ptr<TokenizerBase> tokenizer;
    std::vector<float> image_features_input;
    std::vector<std::vector<float>> text_features_input;
    std::vector<int> text_tokens_input;
    bool load_tokenizer(std::string vocab_path);

    Env env = Env(ORT_LOGGING_LEVEL_ERROR, "CLIP_text_model");
    Ort::Session *ort_session = nullptr; // text_model
    SessionOptions sessionOptions = SessionOptions();
    vector<char*> input_names;
    vector<char*> output_names;
    vector<vector<int64_t>> input_node_dims;  // >=1 inputs
    vector<vector<int64_t>> output_node_dims; // >=1 outputs
    const int context_length = 52;
    const int len_text_feature = 512;
};

Clip::Clip(string image_modelpath, string text_modelpath, string vocab_path)
{
    this->net = readNet(image_modelpath); // OpenCV 4.5 fails to read this model; OpenCV 4.7 loads it successfully

    //OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
    sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
    ort_session = new Session(env, text_modelpath.c_str(), sessionOptions);
    size_t numInputNodes = ort_session->GetInputCount();
    size_t numOutputNodes = ort_session->GetOutputCount();
    AllocatorWithDefaultOptions allocator;
    for (int i = 0; i < numInputNodes; i++)
    {
        input_names.push_back(ort_session->GetInputName(i, allocator));
        Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
        auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
        auto input_dims = input_tensor_info.GetShape();
        input_node_dims.push_back(input_dims);
    }
    for (int i = 0; i < numOutputNodes; i++)
    {
        output_names.push_back(ort_session->GetOutputName(i, allocator));
        Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
        auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
        auto output_dims = output_tensor_info.GetShape();
        output_node_dims.push_back(output_dims);
    }
    //context_length = input_node_dims[0][1];
    //len_text_feature = output_node_dims[0][1];
    this->load_tokenizer(vocab_path);
}

bool Clip::load_tokenizer(std::string vocab_path)
{
    tokenizer.reset(new TokenizerClipChinese);
    this->text_tokens_input = std::vector<int>(1024 * this->context_length);
    return tokenizer->load_tokenize(vocab_path);
}

Mat Clip::normalize_(Mat img)
{
    Mat rgbimg;
    cvtColor(img, rgbimg, COLOR_BGR2RGB);
    vector<Mat> rgbChannels(3);
    split(rgbimg, rgbChannels);
    for (int c = 0; c < 3; c++)
    {
        rgbChannels[c].convertTo(rgbChannels[c], CV_32FC1, 1.0 / (255.0 * std[c]), (0.0 - mean[c]) / std[c]);
    }
    Mat m_normalized_mat;
    merge(rgbChannels, m_normalized_mat);
    return m_normalized_mat;
}

void Clip::generate_image_feature(Mat srcimg)
{
    Mat temp_image;
    resize(srcimg, temp_image, cv::Size(this->inpWidth, this->inpHeight), 0, 0, INTER_CUBIC);
    Mat normalized_mat = this->normalize_(temp_image);
    Mat blob = blobFromImage(normalized_mat);
    this->net.setInput(blob);
    vector<Mat> outs;
    ////net.enableWinograd(false); // if you are on OpenCV 4.7, this line needs to be enabled
    this->net.forward(outs, this->net.getUnconnectedOutLayersNames());
    float* ptr_feat = (float*)outs[0].data;
    const int len_image_feature = outs[0].size[1]; // dim 0 is batchsize=1 and is ignored; len_image_feature is fixed at 512, equal to len_text_feature, and could also be hard-coded as a class member
    this->image_features_input.resize(len_image_feature);
    float norm = 0.0;
    for (int i = 0; i < len_image_feature; i++)
    {
        norm += ptr_feat[i] * ptr_feat[i];
    }
    norm = sqrt(norm);
    for (int i = 0; i < len_image_feature; i++)
    {
        this->image_features_input[i] = ptr_feat[i] / norm; // L2-normalize the image feature
    }
}

void Clip::generate_text_feature(std::vector<string> texts)
{
    std::vector<std::vector<int>> text_token;
    text_token.resize(texts.size());
    for (size_t i = 0; i < texts.size(); i++)
    {
        this->tokenizer->encode_text(texts[i], text_token[i]);
    }

    if (text_token.size() * this->context_length > text_tokens_input.size())
    {
        text_tokens_input.resize(text_token.size() * this->context_length);
    }

    memset(text_tokens_input.data(), 0, text_token.size() * this->context_length * sizeof(int));
    auto text_tokens_input_ptr = text_tokens_input.data();
    for (size_t i = 0; i < text_token.size(); i++)
    {
        if (text_token[i].size() > (size_t)this->context_length)
        {
            printf("text_features index %zu, bigger than %d\n", i, this->context_length);
            continue;
        }
        memcpy(text_tokens_input_ptr + i * this->context_length, text_token[i].data(), text_token[i].size() * sizeof(int));
    }

    std::vector<int64_t> text_token_shape = { 1, this->context_length };
    this->text_features_input.resize(text_token.size());

    std::vector<int64_t> text_tokens_input_64(texts.size() * this->context_length);
    for (size_t i = 0; i < text_tokens_input_64.size(); i++)
    {
        text_tokens_input_64[i] = text_tokens_input[i];
    }

    auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    for (size_t i = 0; i < text_token.size(); i++)
    {
        auto inputTensor = Ort::Value::CreateTensor<int64_t>(allocator_info, text_tokens_input_64.data() + i * this->context_length, this->context_length, text_token_shape.data(), text_token_shape.size());

        Ort::RunOptions runOptions;
        vector<Value> ort_outputs = ort_session->Run(runOptions, &input_names[0], &inputTensor, 1, output_names.data(), output_names.size());
        const float *text_feature_ptr = ort_outputs[0].GetTensorMutableData<float>();

        this->text_features_input[i].resize(this->len_text_feature);
        float norm = 0.0;
        for (int j = 0; j < this->len_text_feature; j++)
        {
            norm += text_feature_ptr[j] * text_feature_ptr[j];
        }
        norm = sqrt(norm);
        for (int j = 0; j < this->len_text_feature; j++)
        {
            this->text_features_input[i][j] = text_feature_ptr[j] / norm; // L2-normalize each text feature
        }
    }
}

class_info Clip::zero_shot_image_classify(Mat cv_image, std::vector<string> texts)
{
    this->generate_image_feature(cv_image);
    this->generate_text_feature(texts);
    vector<float> logits_per_image(texts.size()); // for a single image; batched inputs are not handled
    for (size_t i = 0; i < this->text_features_input.size(); i++)
    {
        float sum = 0;
        for (int j = 0; j < len_text_feature; j++)
        {
            sum += this->image_features_input[j] * this->text_features_input[i][j]; // dot product of the image feature and the text feature
        }
        logits_per_image[i] = 100 * sum;
    }
    softmax(logits_per_image);
    int maxPosition = std::max_element(logits_per_image.begin(), logits_per_image.end()) - logits_per_image.begin(); // index of the maximum
    class_info result = { texts[maxPosition], logits_per_image[maxPosition] };
    return result;
}

void Clip::generate_imagedir_features(const std::string image_dir, const char* bin_name)
{
    std::vector<std::string> imglist = listdir(image_dir);
    const int imgnum = imglist.size();
    cout << "Found " << imgnum << " images" << endl;
    float* imagedir_features = new float[imgnum * this->len_text_feature];
    for (int i = 0; i < imgnum; i++)
    {
        Mat srcimg = imread(imglist[i]);
        this->generate_image_feature(srcimg);
        memcpy(imagedir_features + i * len_text_feature, this->image_features_input.data(), len_text_feature * sizeof(float));
    }

    write_image_feature_name2bin(this->len_text_feature, imagedir_features, imglist, bin_name);

    delete[] imagedir_features;
    imagedir_features = nullptr;
}

std::vector<std::tuple<std::string, float>> Clip::input_text_search_image(std::string text, const float* image_features, const std::vector<std::string> imglist)
{
    const int imgnum = imglist.size();
    std::vector<string> texts = { text };
    this->generate_text_feature(texts);
    vector<float> logits_per_image(imgnum);
    for (int i = 0; i < imgnum; i++)
    {
        float sum = 0;
        for (int j = 0; j < len_text_feature; j++)
        {
            sum += image_features[i * len_text_feature + j] * this->text_features_input[0][j]; // dot product of the image feature and the text feature
        }
        logits_per_image[i] = 100 * sum;
    }
    softmax(logits_per_image);
    std::vector<int> index = argsort_ascend(logits_per_image); // note: sorted in ascending order here
    std::vector<std::tuple<std::string, float>> top5imglist(5);
    for (int i = 0; i < 5; i++)
    {
        const int ind = index[imgnum - 1 - i]; // take from the tail of the ascending order, i.e. the largest scores
        top5imglist[i] = std::make_tuple(imglist[ind], logits_per_image[ind]);
    }
    return top5imglist;
}


int main()
{
    Clip mynet("/project/chinese-clip-cpp/image_model.onnx", "/project/chinese-clip-cpp/text_model.onnx", "/project/chinese-clip-cpp/myvocab.txt");

    const std::string image_dir = "/project/chinese-clip-cpp/testimgs";
    const char* bin_name = "image_features.bin";

    // Step 1: take an input folder, generate feature vectors for its images and save them to a database file
    ///mynet.generate_imagedir_features(image_dir, bin_name);

    // Step 2: take an input sentence and find the most similar images
    string input_text = "踢足球的人"; // "people playing football"
    std::string savepath = "/project/chinese-clip-cpp/resultimgs";

    int imgnum = 0, len_feature = 0;
    vector<string> imglist;
    float* imagedir_features = read_image_feature_name2bin(&imgnum, &len_feature, imglist, bin_name);
    printf("Read %s successfully\n", bin_name);
    cout << "There are " << imgnum << " images, feature length " << len_feature << endl;

    std::vector<std::tuple<std::string, float>> top5imglist = mynet.input_text_search_image(input_text, imagedir_features, imglist);
    copyfile_dstpath(top5imglist, savepath);
    cout << "The top-5 matching images have been copied to " << savepath << endl;

    delete[] imagedir_features;
    imagedir_features = nullptr;

    /*// Zero-shot image classification from prompt texts
    string imgpath = "pokemon.jpeg";
    std::vector<string> texts = { "杰尼龟", "妙蛙种子", "小火龙", "皮卡丘" };
    Mat srcimg = imread(imgpath);
    class_info result = mynet.zero_shot_image_classify(srcimg, texts);
    cout << "Max probability: " << result.prob << ", class: " << result.name << endl;*/

    return 0;
}
--------------------------------------------------------------------------------
/cpp/src/util.cpp:
--------------------------------------------------------------------------------
#include "util.hpp"
#include <filesystem>
#include <iostream>
#include <cmath>
#include <cctype>
#include <cstdio>

std::vector<std::string> listdir(const std::string image_dir)
{
    const std::vector<std::string> exts{".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG", ".heic", ".heif", ".bmp", ".webp", ".dib", ".pbm", ".pgm", ".ppm", ".emf", ".wmf", ".tiff", ".tif"};
    std::vector<std::string> imglist;
    for (const auto &entry : std::filesystem::directory_iterator(image_dir)) // std::filesystem is only available from C++17; you could also enumerate the images in the folder with the C function opendir, or with cv::glob
    {
        if (entry.is_regular_file())
        {
            std::string ext = entry.path().extension();

            std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c)
                           { return std::tolower(c); });

            if (std::find(exts.begin(), exts.end(), ext) != exts.end())
            {
                imglist.push_back(entry.path().string());
            }
        }
    }

    if (imglist.empty())
    {
        std::cout << "input imgdir is empty!" << std::endl;
    }
    return imglist;
}

void softmax(std::vector<float> &input) // for a single image; batched inputs are not handled
{
    const int length = input.size();
    std::vector<float> exp_x(length);
    float maxVal = *std::max_element(input.begin(), input.end());
    float sum = 0;
    for (int i = 0; i < length; i++)
    {
        const float expval = std::exp(input[i] - maxVal); // subtract the max for numerical stability
        exp_x[i] = expval;
        sum += expval;
    }
    for (int i = 0; i < length; i++)
    {
        input[i] = exp_x[i] / sum;
    }
}

int write_image_feature_name2bin(int len_feature, const float* output, const std::vector<std::string> imglist, const char* bin_name)
{
    const int imgnum = imglist.size();
    FILE* fp = fopen(bin_name, "wb");
    fwrite(&imgnum, sizeof(int), 1, fp);
    fwrite(&len_feature, sizeof(int), 1, fp);
    fwrite(output, sizeof(float), imgnum * len_feature, fp);
    for (size_t i = 0; i < imglist.size(); i++)
    {
        int len_s = imglist[i].length();
        fwrite(&len_s, sizeof(int), 1, fp);                      // length of the path string
        fwrite(imglist[i].c_str(), sizeof(char), len_s + 1, fp); // the trailing '\0' also counts as one character
    }
    fclose(fp);
    return 0;
}

float* read_image_feature_name2bin(int* imgnum, int* len_feature, std::vector<std::string>& imglist, const char* bin_name)
{
    FILE* fp = fopen(bin_name, "rb");
    fread(imgnum, sizeof(int), 1, fp);
    fread(len_feature, sizeof(int), 1, fp);
    int len = (*imgnum) * (*len_feature);
    float* output = new float[len];
    fread(output, sizeof(float), len, fp); // load the feature data
    for (int i = 0; i < *imgnum; i++)
    {
        int len_s = 0;
        fread(&len_s, sizeof(int), 1, fp);
        char* name = new char[len_s + 1];         // the trailing '\0' also counts as one character
        fread(name, sizeof(char), len_s + 1, fp); // load the image path
        imglist.push_back(name);
        delete[] name;
        name = nullptr;
    }
    fclose(fp);
    return output;
}

void copyfile_dstpath(std::vector<std::tuple<std::string, float>> imglist, std::string savepath)
{
    if (std::filesystem::exists(savepath))
    {
        std::filesystem::remove_all(savepath);
    }
    std::filesystem::create_directories(savepath);

    for (size_t i = 0; i < imglist.size(); i++)
    {
        std::string imgpath = std::get<0>(imglist[i]);
        ////float score = std::get<1>(imglist[i]);
        std::filesystem::copy(imgpath, savepath);
    }
}
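For reference, write_image_feature_name2bin above lays the database file out as: an int image count, an int feature length, imgnum * len_feature float values, and then, per image, an int path length followed by the NUL-terminated path. A minimal Python sketch of a reader for that layout, assuming the file was written on the Ubuntu x86-64 target (little-endian, 4-byte int); the function name `read_feature_bin` is mine:

```python
import struct
import numpy as np

def read_feature_bin(path):
    """Parse the layout written by write_image_feature_name2bin():
    int32 imgnum, int32 len_feature, imgnum*len_feature float32s,
    then per image: int32 name length, name bytes plus a trailing '\\0'."""
    with open(path, "rb") as f:
        imgnum, len_feature = struct.unpack("<ii", f.read(8))
        feats = np.frombuffer(f.read(4 * imgnum * len_feature),
                              dtype=np.float32).reshape(imgnum, len_feature)
        names = []
        for _ in range(imgnum):
            (len_s,) = struct.unpack("<i", f.read(4))
            names.append(f.read(len_s + 1)[:-1].decode("utf-8"))  # drop the '\0'
    return feats, names
```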
--------------------------------------------------------------------------------
/python/clip_tokenizer.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Tokenization classes."""
import collections
import re
import unicodedata
import six
from functools import lru_cache
import os
import numpy as np


@lru_cache()
def default_vocab():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "myvocab.txt")


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name."""

    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.

    if not init_checkpoint:
        return

    m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if m is None:
        return

    model_name = m.group(1)

    lower_models = [
        "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
    ]

    cased_models = [
        "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12"
    ]

    is_bad_config = False
    if model_name in lower_models and not do_lower_case:
        is_bad_config = True
        actual_flag = "False"
        case_name = "lowercased"
        opposite_flag = "True"

    if model_name in cased_models and do_lower_case:
        is_bad_config = True
        actual_flag = "True"
        case_name = "cased"
        opposite_flag = "False"

    if is_bad_config:
        raise ValueError(
            "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
            "However, `%s` seems to be a %s model, so you "
            "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
            "how the model was pre-trained. If this error is wrong, please "
            "just comment out this check." % (actual_flag, init_checkpoint,
                                              model_name, case_name, opposite_flag))


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")

def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)

    @staticmethod
    def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
        """ Converts a sequence of tokens (string) into a single string. """

        def clean_up_tokenization(out_string):
            """ Clean up a list of simple English tokenization artifacts
            like spaces before punctuation and abbreviated forms.
            """
            out_string = (
                out_string.replace(" .", ".")
                .replace(" ?", "?")
                .replace(" !", "!")
                .replace(" ,", ",")
                .replace(" ' ", "'")
                .replace(" n't", "n't")
                .replace(" 'm", "'m")
                .replace(" 's", "'s")
                .replace(" 've", "'ve")
                .replace(" 're", "'re")
            )
            return out_string

        text = ' '.join(tokens).replace(' ##', '').strip()
        if clean_up_tokenization_spaces:
            clean_text = clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def vocab_size(self):
        return len(self.vocab)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False

def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


_tokenizer = FullTokenizer()


def tokenize(texts, context_length: int = 52):
    if isinstance(texts, str):
        texts = [texts]

    all_tokens = []
    for text in texts:
        all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
            :context_length - 2] + [_tokenizer.vocab['[SEP]']])

    result = np.zeros((len(all_tokens), context_length), dtype=np.int64)
    for i, tokens in enumerate(all_tokens):
        assert len(tokens) <= context_length
        result[i, :len(tokens)] = tokens
    return result

--------------------------------------------------------------------------------
/python/main.py:
--------------------------------------------------------------------------------
import cv2
import onnxruntime as ort
import numpy as np
import os
import shutil
import pickle
from clip_tokenizer import tokenize

class Clip():
    def __init__(self, image_modelpath, text_modelpath):
        self.img_model = cv2.dnn.readNet(image_modelpath)
        self.input_height, self.input_width = 224, 224

        self.mean = np.array([0.48145466, 0.4578275, 0.40821073],
                             dtype=np.float32).reshape((1, 1, 3))
        self.std = np.array([0.26862954, 0.26130258, 0.27577711],
                            dtype=np.float32).reshape((1, 1, 3))

        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.txt_model = ort.InferenceSession(text_modelpath, so)
        self.context_length = 52

    def preprocess(self, srcimg):
        img = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.input_width, self.input_height),
                         interpolation=cv2.INTER_CUBIC)
        img = (img.astype(np.float32)/255.0 - self.mean) / self.std
        return img

    def generate_image_feature(self, srcimg):
        img = self.preprocess(srcimg)
        blob = cv2.dnn.blobFromImage(img)
        self.img_model.setInput(blob)
        image_features = self.img_model.forward(self.img_model.getUnconnectedOutLayersNames())[0]

        img_norm = np.linalg.norm(image_features, axis=-1, keepdims=True)
        image_features /= img_norm
        return image_features

    def generate_text_feature(self, input_text):
        text = tokenize(input_text, context_length=self.context_length)
        text_features = []
        for i in range(len(text)):
            one_text = np.expand_dims(text[i], axis=0)
            text_feature = self.txt_model.run(None, {self.txt_model.get_inputs()[0].name: one_text})[0].squeeze()
            text_features.append(text_feature)
        text_features = np.stack(text_features, axis=0)
        txt_norm = np.linalg.norm(text_features, axis=1, keepdims=True)
        text_features /= txt_norm
        return text_features

    def run_image_classify(self, image, input_strs):
        image_features = self.generate_image_feature(image)
        text_features = self.generate_text_feature(input_strs)
        logits_per_image = 100 * np.dot(image_features, text_features.T)
        exp_logits = np.exp(logits_per_image - np.max(logits_per_image, axis=-1, keepdims=True))  # softmax
        softmax_logit = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)  # softmax
        max_str = input_strs[softmax_logit.argmax()]
        max_str_logit = softmax_logit.max()
        return max_str, max_str_logit

    def generate_imagedir_features(self, image_dir):
        imglist, image_features = [], []
        for imgname in os.listdir(image_dir):
            srcimg = cv2.imread(os.path.join(image_dir, imgname))
            if srcimg is None:  # the current file may not be an image
                continue
            img_feat = self.generate_image_feature(srcimg)
            image_features.append(img_feat.squeeze())
            imglist.append(imgname)

        image_features = np.stack(image_features, axis=0)
        return image_features, imglist

    def input_text_search_image(self, input_text, image_features, imglist):
        text_features = self.generate_text_feature(input_text)
        logits_per_image = 100 * np.dot(text_features, image_features.T)
        exp_logits = np.exp(logits_per_image - np.max(logits_per_image, axis=-1, keepdims=True))  # softmax
        softmax_logit = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)  # softmax
        softmax_logit = softmax_logit.reshape(-1)  # flatten the array
        similar_id = np.argsort(-softmax_logit)  # descending order
        top5_imglist = [(imglist[similar_id[i]], softmax_logit[similar_id[i]]) for i in range(5)]
        return top5_imglist


if __name__ == '__main__':
    mynet = Clip("image_model.onnx", "text_model.onnx")

    # Step 1: take an input folder, generate feature vectors for its images and save them to a database file
    # image_dir = os.path.join(os.getcwd(), 'testimgs')
    # image_features, imglist = mynet.generate_imagedir_features(image_dir)
    # with open('features.pkl', 'wb') as f:
    #     pickle.dump((image_features, imglist), f)  # once the folder is fixed, its image features are a constant; step 2 simply loads them, because the input text changes between runs while the image features never need recomputing
    # print('Feature-vector database generated successfully!!!')

    # Step 2: take an input sentence and find the most similar images
    input_text = "踢足球的人"  # "people playing football"; with the database from step 1, new input text does not require recomputing the image features
    with open('features.pkl', 'rb') as f:
        image_features, imglist = pickle.load(f)
    top5_imglist = mynet.input_text_search_image(input_text, image_features, imglist)
    print(top5_imglist)

    image_dir = os.path.join(os.getcwd(), 'testimgs')
    result_imgs = os.path.join(os.getcwd(), 'result_imgs')
    if os.path.exists(result_imgs):
        shutil.rmtree(result_imgs)
    os.makedirs(result_imgs)
    for imgname, conf in top5_imglist:
        shutil.copy(os.path.join(image_dir, imgname), result_imgs)


    # Zero-shot image classification from prompt texts
    # imgpath = 'pokemon.jpeg'
    # text = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]

    # mynet = Clip("image_model.onnx", "text_model.onnx")

    # srcimg = cv2.imread(imgpath)
    # max_str, max_str_logit = mynet.run_image_classify(srcimg, text)
    # print(f"Max probability: {max_str_logit}, class: {max_str}")
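As the README notes, once the gallery grows to millions of images the brute-force np.dot ranking above should give way to an approximate-nearest-neighbor library such as faiss. A minimal sketch of how the features saved in features.pkl could be indexed, assuming the faiss-cpu package is installed; because the vectors are L2-normalized, inner product equals cosine similarity:

```python
import pickle
import numpy as np
import faiss  # pip install faiss-cpu (assumed available)

with open('features.pkl', 'rb') as f:
    image_features, imglist = pickle.load(f)

d = image_features.shape[1]                    # 512 for this model
index = faiss.IndexFlatIP(d)                   # exact inner-product search; cosine on normalized vectors
index.add(image_features.astype(np.float32))   # faiss requires float32

# text_features: a (1, d) normalized array from Clip.generate_text_feature
# scores, ids = index.search(text_features.astype(np.float32), 5)
# top5 = [(imglist[i], float(s)) for i, s in zip(ids[0], scores[0])]
```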
--------------------------------------------------------------------------------
/python/pokemon.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/python/pokemon.jpeg

--------------------------------------------------------------------------------
/testimgs/000000005082.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000005082.jpg

--------------------------------------------------------------------------------
/testimgs/000000014309.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000014309.jpg

--------------------------------------------------------------------------------
/testimgs/000000024283.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000024283.jpg

--------------------------------------------------------------------------------
/testimgs/000000025111.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000025111.jpg

--------------------------------------------------------------------------------
/testimgs/000000043121.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000043121.jpg

--------------------------------------------------------------------------------
/testimgs/000000047693.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000047693.jpg

--------------------------------------------------------------------------------
/testimgs/000000057110.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000057110.jpg

--------------------------------------------------------------------------------
/testimgs/000000060398.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000060398.jpg

--------------------------------------------------------------------------------
/testimgs/000000060621.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000060621.jpg

--------------------------------------------------------------------------------
/testimgs/000000064549.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000064549.jpg

--------------------------------------------------------------------------------
/testimgs/000000064840.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000064840.jpg
--------------------------------------------------------------------------------
/testimgs/000000075717.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000075717.jpg

--------------------------------------------------------------------------------
/testimgs/000000076516.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000076516.jpg

--------------------------------------------------------------------------------
/testimgs/000000076706.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/Chinese-CLIP-opencv-onnxrun/7d21b5954d93432fe58fbea72285cd1358492090/testimgs/000000076706.jpg
--------------------------------------------------------------------------------