├── .gitignore
├── README.md
├── docs
│   ├── bank.jpeg
│   └── gushi.jpg
├── offline
│   ├── bce_rerank_server.py
│   ├── es_index_init.py
│   └── llava_model_server.py
├── src
│   ├── __init__.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── image_search_server.py
│   │   └── image_upload_gradio_server.py
│   └── utils
│       └── __init__.py
└── tests
    ├── __init__.py
    ├── gallary_doc_test.py
    ├── image_size_change.py
    └── paddle_ocr_test.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__
src/data
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This project uses the LLaVA 1.6 multi-modal model to implement text-to-image search and image-to-image search.

### OCR model

PaddleOCR

Model introduction and deployment guide: https://www.paddlepaddle.org.cn/hubdetail?name=ch_pp-ocrv3&en_category=TextRecognition


### Multi-modal model

LLaVA 1.6

GitHub: https://github.com/haotian-liu/LLaVA/tree/main

Demo: https://llava.hliu.cc/

### How it works

At upload time, PaddleOCR extracts any text contained in the image, and LLaVA 1.6 generates a Chinese title and a detailed description (conditioned on the OCR text when it is available). The URL, tag, title, description, OCR result, and insert time are written to an Elasticsearch index (`image-search-ocr`) whose `description` field is analyzed with `ik_smart`. Text-to-image search splits the query into phrases and runs a bool/must full-text match against the descriptions. Image-to-image search first has LLaVA describe the query image, retrieves candidate images by description match, then re-ranks the candidates with the bce-reranker cross-encoder and keeps the top results above a score threshold.
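
A minimal sketch of the indexing half of this flow (assuming the services under `offline/` are running locally on the ports used in this repo, with Elasticsearch on 9200; the image URL and tag below are illustrative):

```python
import json
import requests
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://localhost:9200")
image_url = "https://example.com/some-image.jpg"  # hypothetical URL

# 1. Ask the LLaVA service for a title and description (no OCR text here).
resp = requests.post("http://localhost:50075/img_desc",
                     headers={"Content-Type": "application/json"},
                     data=json.dumps({"url": image_url, "ocr_result": ""}))
info = resp.json()

# 2. Index the result so that both search modes can find it later.
es_client.index(index="image-search-ocr", document={
    "url": image_url,
    "title": info["title"],
    "description": info["desc"],
    "tag": "demo",
    "insert_time": "2024-02-06 12:00:00",
    "ocr_result": "",
})
```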

### Image upload

src/serve/image_upload_gradio_server.py

![image-search-图片上传.png](https://s2.loli.net/2024/02/06/bJCqkv4LVgminpy.png)

### Text-to-image search

src/serve/image_search_server.py

- Single phrase

![image-search-单个短语1.png](https://s2.loli.net/2024/02/07/9xKPRYX1ZbQB5Sz.png)
![image-search-单个短语2.png](https://s2.loli.net/2024/02/07/ajvFCI4NZtBTH5s.png)
![image-search-单个短语3.png](https://s2.loli.net/2024/02/07/CeGMUjNEBZ8ThHQ.png)

- Multiple phrases

![image-search-多个短语1.png](https://s2.loli.net/2024/02/07/YwvpK2BakXuziER.png)
![image-search-多个短语2.png](https://s2.loli.net/2024/02/07/CPZywoEUgXHsRpQ.png)
![image-search-多个短语3.png](https://s2.loli.net/2024/02/07/wpgVYUAE6HdP4fJ.png)

### Image-to-image search

![image-search-以图搜图1.png](https://s2.loli.net/2024/02/07/2ZdHhRr7cgoDFyW.png)
![image-search-以图搜图2.png](https://s2.loli.net/2024/02/07/PXRnKO3tl8zvZm6.png)
![image-search-以图搜图3.png](https://s2.loli.net/2024/02/07/iafwum1IEKvezhn.png)
--------------------------------------------------------------------------------
/docs/bank.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/percent4/multi-modal-image-search/9ccd46677b4a957c0b5dd5cfebd2bfcf8fb8e2c1/docs/bank.jpeg
--------------------------------------------------------------------------------
/docs/gushi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/percent4/multi-modal-image-search/9ccd46677b4a957c0b5dd5cfebd2bfcf8fb8e2c1/docs/gushi.jpg
--------------------------------------------------------------------------------
/offline/bce_rerank_server.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: bce_rerank_server.py
# @time: 2024/2/6 17:06
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from operator import itemgetter
from sentence_transformers import CrossEncoder


app = FastAPI()
# init the rerank model from a local copy of the bce-reranker-base_v1 weights
model = CrossEncoder('/data-ai/usr/lmj/models/bce-reranker-base_v1', max_length=512)


class SentencePair(BaseModel):
    text1: str
    text2: str


class Sentences(BaseModel):
    texts: list[SentencePair]


@app.get('/')
def home():
    return 'hello world'


@app.post('/rerank')
def rerank(sentence_pairs: Sentences):
    # score every (text1, text2) pair and return the pairs sorted by score, descending
    scores = model.predict([[pair.text1, pair.text2] for pair in sentence_pairs.texts]).tolist()
    result = [[scores[i], sentence_pairs.texts[i].text1, sentence_pairs.texts[i].text2] for i in range(len(scores))]
    sorted_result = sorted(result, key=itemgetter(0), reverse=True)
    return {"result": sorted_result}


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=50074)
--------------------------------------------------------------------------------
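
A minimal client sketch for the rerank service above (assumed to be running on localhost:50074; the sentence pairs are illustrative):

```python
import json
import requests

payload = {"texts": [
    {"text1": "故宫的雪景", "text2": "白雪覆盖的古代宫殿"},
    {"text1": "故宫的雪景", "text2": "一只猫在沙发上睡觉"},
]}
response = requests.post("http://localhost:50074/rerank",
                         headers={"Content-Type": "application/json"},
                         data=json.dumps(payload))
# each entry is [score, text1, text2], sorted by score in descending order
print(response.json()["result"])
```
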
/offline/es_index_init.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: es_index_init.py
# @time: 2024/2/5 15:59
from elasticsearch import Elasticsearch

# Connect to Elasticsearch
es_client = Elasticsearch("http://localhost:9200")

# Create a new ES index; the description field uses the ik_smart Chinese analyzer
mapping = {
    'properties': {
        'url': {
            'type': 'text'
        },
        'tag': {
            'type': 'keyword'
        },
        'description': {
            'type': 'text',
            'analyzer': 'ik_smart',
            'search_analyzer': 'ik_smart'
        },
        'title': {
            'type': 'text'
        },
        "insert_time": {
            "type": "date",
            "format": "yyyy-MM-dd HH:mm:ss"
        },
        'ocr_result': {
            'type': 'text'
        }
    }
}

# ignore=400 so that an "index already exists" error is not raised on re-runs
es_client.indices.create(index='image-search-ocr', ignore=400)
result = es_client.indices.put_mapping(index='image-search-ocr', body=mapping)
print(result)
--------------------------------------------------------------------------------
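
A quick sanity check (a sketch, assuming the elasticsearch-analysis-ik plugin is installed on the cluster) that the `ik_smart` analyzer used by the `description` field actually tokenizes Chinese text:

```python
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://localhost:9200")
result = es_client.indices.analyze(index="image-search-ocr",
                                   body={"analyzer": "ik_smart", "text": "故宫的雪景"})
print([token["token"] for token in result["tokens"]])
```
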
/offline/llava_model_server.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: llava_model_server.py
# @time: 2024/2/18 15:45
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from pydantic import BaseModel
import uvicorn
from fastapi import FastAPI


# Model
disable_torch_init()

model_path = "/data-ai/usr/lmj/models/llava-v1.6-34b"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, False, False, device="cuda")


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def model_infer(image_file, inp):
    # pick the conversation template matching the loaded checkpoint
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    conv = conv_templates[conv_mode].copy()

    image = load_image(image_file)
    image_size = image.size
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # single-turn request: prepend the image token(s) to the user prompt
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=True,
            temperature=0.1,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True)

    outputs = tokenizer.decode(output_ids[0, 1:-1]).strip()
    conv.messages[-1][-1] = outputs
    print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
    return outputs


app = FastAPI()


class ImageInput(BaseModel):
    url: str
    ocr_result: str


@app.get('/')
def home():
    return 'hello world'


@app.post('/img_desc')
def image_desc(image_input: ImageInput):
    # ask for a Chinese title, including the OCR text in the prompt when present
    title_string = "请为这张图片生成一个中文标题。" if not image_input.ocr_result else f'这张图片中的文字为"{image_input.ocr_result}"。请为这张图片生成一个中文标题。'
    title_output = model_infer(image_input.url, title_string)
    # ask for a detailed Chinese description of the image content
    desc_string = "请详细描述这张图片中的内容。" if not image_input.ocr_result else f'这张图片中的文字为"{image_input.ocr_result}"。请详细描述这张图片中的内容。'
    desc_output = model_infer(image_input.url, desc_string)
    return {"url": image_input.url, "title": title_output, "desc": desc_output}


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=50075)
--------------------------------------------------------------------------------
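
A minimal client sketch for the captioning service above (assumed to be running on localhost:50075; the image URL and OCR text are illustrative):

```python
import json
import requests

payload = {"url": "https://example.com/bank-card.jpg",  # hypothetical URL
           "ocr_result": "中国工商银行"}
response = requests.post("http://localhost:50075/img_desc",
                         headers={"Content-Type": "application/json"},
                         data=json.dumps(payload))
print(response.json())  # {"url": ..., "title": ..., "desc": ...}
```
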
/src/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: __init__.py
# @time: 2024/2/4 23:40
--------------------------------------------------------------------------------
/src/serve/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: __init__.py
# @time: 2024/2/5 14:58
--------------------------------------------------------------------------------
/src/serve/image_search_server.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: image_search_server.py
# @time: 2024/2/6 10:36
import json
import requests
import gradio as gr
from PIL import Image
from io import BytesIO
from datetime import datetime as dt
from elasticsearch import Elasticsearch

# Connect to Elasticsearch
es_client = Elasticsearch("http://localhost:9200")


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')

    return image


# check if the url is already indexed in ES
def is_image_in_es(url):
    dsl = {
        'query': {
            'match': {
                'url': url
            }
        },
        "size": 5
    }
    search_result = es_client.search(index='image-search-ocr', body=dsl)
    if search_result['hits']['hits']:
        url_result = {_['_source']['url']: _['_id'] for _ in search_result['hits']['hits']}
        if url in url_result:
            print("the url is already in Elasticsearch!")
            return True, url_result[url]
    print("the url is not in Elasticsearch!")
    return False, None


def get_image_desc(image_url):
    image_exists, _id = is_image_in_es(image_url)
    if image_exists:
        result = es_client.get(index="image-search-ocr", id=_id)
        print("get image info from Elasticsearch!")
        return result["_source"]["title"], result["_source"]["description"]
    else:
        # get the image title and description from the LLaVA service;
        # the service requires an ocr_result field, and no OCR text is available here
        url = "http://localhost:50075/img_desc"
        payload = json.dumps({"url": image_url, "ocr_result": ""})
        headers = {'Content-Type': 'application/json'}
        response = requests.request("POST", url, headers=headers, data=payload)
        result = response.json()
        print("get image info from the LLaVA model!")
        return result["title"], result["desc"]


def insert_es_data(url, title, desc):
    image_exists, _id = is_image_in_es(url)
    if not image_exists:
        doc = {
            "url": url,
            "title": title,
            "description": desc,
            "tag": "search",
            "insert_time": dt.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        es_client.index(index="image-search-ocr", document=doc)
        print(f"insert {url} into es successfully!")


def get_rerank_result(text_list):
    url = "http://localhost:50074/rerank"
    payload = json.dumps({
        "texts": [
            {
                "text1": text[0],
                "text2": text[1]
            }
            for text in text_list
        ]
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    print("rerank result: ")
    for _ in response.json()['result']:
        print(_)
    return response.json()['result']


def image_search_by_image_url(query_str):
    result = []
    # describe the query image, then run full-text retrieval on its description
    image_title, image_desc = get_image_desc(query_str)
    print("image info: ", repr(image_title), repr(image_desc))
    insert_es_data(query_str, image_title, image_desc)
    dsl = {
        'query': {
            'match': {
                'description': image_desc
            }
        },
        "size": 10
    }
    search_result = es_client.search(index='image-search-ocr', body=dsl)
    if search_result['hits']['hits']:
        es_search_result = {_['_source']['description'][:200]: _['_source']['url'] for _ in
                            search_result['hits']['hits']}
        desc_title_dict = {_['_source']['description'][:200]: _['_source']['title'] for _ in
                           search_result['hits']['hits']}
        # rerank the candidate descriptions against the query image's description
        text_list = [[image_desc[:200], key] for key in es_search_result.keys()]
        rerank_result = get_rerank_result(text_list=text_list)
        # keep at most 5 similar images whose rerank score exceeds the threshold
        i = 0
        for record in rerank_result:
            score, query_desc, cand_desc = record
            if query_desc != cand_desc and score > 0.4:
                i += 1
                result.append([es_search_result[cand_desc], desc_title_dict[cand_desc]])
            if i > 4:
                break
    return result


def image_search_by_text(query_str):
    result = []
    # full-text search: every whitespace-separated phrase must match the description
    queries = query_str.split()
    dsl = {
        "query": {
            "bool": {
                "must": [
                    {"match": {"description": _}} for _ in queries
                ]
            }
        },
        "size": 5
    }
    search_result = es_client.search(index='image-search-ocr', body=dsl)
    if search_result['hits']['hits']:
        result = [[_['_source']['url'], _['_source']['title']] for _ in search_result['hits']['hits']]
    print('search result: ', result)
    return result


def image_search(query):
    query_image_url = None
    if query.startswith("http"):
        # the query is an image URL: search by image
        query_image_url = query
        result = image_search_by_image_url(query)
    else:
        result = image_search_by_text(query)

    # pad to exactly three result images for the fixed Gradio layout
    images = [load_image(record[0]) for record in result]
    if len(images) >= 3:
        images = images[:3]
    else:
        for _ in range(3 - len(images)):
            images.append(None)
    return query_image_url, images[0], images[1], images[2]


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0.3):
                user_input = gr.TextArea(lines=1, placeholder="Enter search word", label="Search")
                user_input_image = gr.Image()
            with gr.Column(scale=0.2):
                search_image1 = gr.Image(type='pil', height=200)
                search_image2 = gr.Image(type='pil', height=200)
                search_image3 = gr.Image(type='pil', height=200)
        submit = gr.Button("Search")
        submit.click(fn=image_search,
                     inputs=user_input,
                     outputs=[user_input_image, search_image1, search_image2, search_image3])
    demo.launch(server_name="0.0.0.0", server_port=7680)
--------------------------------------------------------------------------------
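
A small debugging sketch for inspecting what the two Gradio apps have written into the index this server searches (assumes the same local Elasticsearch instance):

```python
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://localhost:9200")
hits = es_client.search(index="image-search-ocr",
                        body={"query": {"match_all": {}}, "size": 3})
for hit in hits["hits"]["hits"]:
    print(hit["_source"]["url"], "|", hit["_source"]["title"])
```
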
/src/serve/image_upload_gradio_server.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: image_upload_gradio_server.py
# @time: 2024/2/4 23:41
import json
import uuid
from datetime import datetime as dt
import cv2
import base64
import requests
import gradio as gr
from PIL import Image
from io import BytesIO
from elasticsearch import Elasticsearch
from urllib.request import urlretrieve

# Connect to Elasticsearch
es_client = Elasticsearch("http://localhost:9200")


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')


def image_ocr(image_url):
    # download the image by url into src/data (gitignored);
    # the relative path assumes the script is launched from src/serve/
    image_path = f'../data/{str(uuid.uuid4())}.jpg'
    urlretrieve(image_url, image_path)
    # get the OCR result from the PaddleOCR hub-serving endpoint
    data = {'images': [cv2_to_base64(cv2.imread(image_path))]}
    headers = {"Content-type": "application/json"}
    url = "http://localhost:50076/predict/ch_pp-ocrv3"
    r = requests.post(url=url, headers=headers, data=json.dumps(data))
    if r.json()["results"]:
        return "\n".join([ocr_record["text"].strip() for ocr_record in r.json()["results"][0]["data"]])
    else:
        return ""


def insert_es_data(url, tag, title, desc, ocr_result):
    doc = {
        "url": url,
        "title": title,
        "description": desc,
        "tag": tag,
        "insert_time": dt.now().strftime("%Y-%m-%d %H:%M:%S"),
        "ocr_result": ocr_result
    }
    es_client.index(index="image-search-ocr", document=doc)
    print("insert into es successfully!")


def get_image_desc(ocr_choice, image_url):
    if not ocr_choice:
        ocr_result = ""
    else:
        ocr_result = image_ocr(image_url=image_url)
        print("ocr result: ", repr(ocr_result))
    # load image
    image = load_image(image_url)
    # get the image title and description from the LLaVA service
    url = "http://localhost:50075/img_desc"
    payload = json.dumps({"url": image_url, "ocr_result": ocr_result})
    headers = {'Content-Type': 'application/json'}
    response = requests.request("POST", url, headers=headers, data=payload)
    result = response.json()
    return image, ocr_result, result["title"], result["desc"]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            checkout_group = gr.CheckboxGroup(choices=["LLaVA 1.6"], value="LLaVA 1.6", label='models')
            ocr_group = gr.CheckboxGroup(choices=["PaddleOCR"], label='OCR')
            user_input = gr.TextArea(lines=5, placeholder="Enter the url of an image", label="image url")
            tags = gr.TextArea(lines=1, placeholder="Enter the tags of an image", label="image tag")
        with gr.Column():
            image_box = gr.Image()
            ocr_output = gr.TextArea(lines=2, label='OCR result')
            title_output = gr.TextArea(lines=1, label='image title')
            desc_output = gr.TextArea(lines=5, label='image description')
    submit = gr.Button("Submit")
    insert_data = gr.Button("Insert")
    submit.click(fn=get_image_desc,
                 inputs=[ocr_group, user_input],
                 outputs=[image_box, ocr_output, title_output, desc_output])
    insert_data.click(fn=insert_es_data,
                      inputs=[user_input, tags, title_output, desc_output, ocr_output])

demo.launch()
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: __init__.py
# @time: 2024/2/5 14:58
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: __init__.py
# @time: 2024/2/4 23:40
--------------------------------------------------------------------------------
/tests/gallary_doc_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: gallary_doc_test.py
# @time: 2024/2/7 16:22
# This demo is adapted from the Gradio docs and needs to be run from the Gradio
# repo folder: python demo/fake_gan/run.py
import random

import gradio as gr


def fake_gan():
    images = [
        (random.choice(
            [
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1151ce9f4b2043de0d2e3b7826127998.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-116b5e92936b766b7fdfc242649337f7.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1163530ca19b5cebe1b002b8ec67b6fc.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1116395d6e6a6581eef8b8038f4c8e55.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-11319be65db395d0e8e6855d18ddcef0.jpg",
            ]
        ), f"label {i}")
        for i in range(3)
    ]
    print(images)
    return images


with gr.Blocks() as demo:
    gallery = gr.Gallery(
        label="Generated images", show_label=False, elem_id="gallery",
        columns=[3], rows=[1], object_fit="contain", height="auto")
    btn = gr.Button("Generate images", scale=0)

    btn.click(fake_gan, None, gallery)

if __name__ == "__main__":
    demo.launch()
--------------------------------------------------------------------------------
/tests/image_size_change.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: image_size_change.py
# @time: 2024/2/6 21:54
import gradio as gr
from PIL import Image
import numpy as np


# resize an uploaded image to the dimensions chosen with the sliders
def resize_image(image, width, height):
    # convert the uploaded numpy array to a Pillow image
    img = Image.fromarray(image.astype('uint8'), 'RGB')
    print(img.size)
    # resize the image to the requested dimensions
    resized_image = img.resize((int(width), int(height)))
    # convert back to a numpy array so Gradio can display it
    return np.array(resized_image)


# build the Gradio interface
iface = gr.Interface(fn=resize_image,
                     inputs=[gr.inputs.Image(shape=(200, 200)),
                             gr.inputs.Slider(minimum=100, maximum=1000, default=200, label="Width"),
                             gr.inputs.Slider(minimum=100, maximum=1000, default=200, label="Height")],
                     outputs=gr.Image(shape=(80, 80), type="numpy", label="Resized Image"),
                     description="Upload an image to resize it to your desired dimensions.")

# launch the app
iface.launch()
--------------------------------------------------------------------------------
/tests/paddle_ocr_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: paddle_ocr_test.py
# @time: 2024/2/8 13:54
import pprint
import requests
import json
import cv2
import base64


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')


def image_ocr(image_path):
    data = {'images': [cv2_to_base64(cv2.imread(image_path))]}
    headers = {"Content-type": "application/json"}
    url = "http://localhost:50076/predict/ch_pp-ocrv3"
    r = requests.post(url=url, headers=headers, data=json.dumps(data))
    return r.json()["results"]


if __name__ == '__main__':
    image_path_test = "/Users/admin/PycharmProjects/multi-modal-image-search/docs/bank.jpeg"
    res = image_ocr(image_path=image_path_test)
    pprint.pprint(res)
--------------------------------------------------------------------------------
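
For reference, the upload server parses the hub-serving response as `results[0]["data"][i]["text"]`; a sketch of extracting plain text from the value returned by `image_ocr` above (the image path is illustrative):

```python
results = image_ocr(image_path="docs/bank.jpeg")  # illustrative local path
if results:
    # each record in results[0]["data"] carries at least a recognized "text" field
    texts = [record["text"].strip() for record in results[0]["data"]]
    print("\n".join(texts))
```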