├── CLIP.ipynb ├── LLaVA.ipynb ├── README.md ├── Tip-Adapter.ipynb ├── picture_spider.ipynb └── 中英文CLIP.ipynb /CLIP.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"V100","mount_file_id":"1sAnZ_E2Lzfcx8ApdmaOXlz7vvrIR6dl6","authorship_tag":"ABX9TyNArW73rfMsjpk9es4kj3mQ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"RCA-OAtko3eA"},"outputs":[],"source":["!pip install ftfy regex tqdm\n","!pip install git+https://github.com/openai/CLIP.git"]},{"cell_type":"markdown","source":["五个类别+others"],"metadata":{"id":"lgbjlbfwH1yC"}},{"cell_type":"code","source":["import torch\n","import clip\n","from PIL import Image\n","import os\n","\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","\n","# 加载模型\n","model, preprocess = clip.load(\"ViT-B/32\", device=device)\n","\n","# 定义类别和对应的图片命名规则\n","categories = {\n"," \"pet cat\": \"宠物猫\",\n"," \"tomato\": \"番茄\",\n"," \"paper-cut\": \"剪纸\",\n"," \"computer\": \"电脑\",\n"," \"dumpling\": \"饺子\"\n","}\n","\n","# 初始化统计数据\n","stats = {category: {\"TP\": 0, \"FP\": 0, \"FN\": 0} for category in categories.keys()}\n","\n","# 图片文件夹路径\n","image_folder_path = \"/content/drive/MyDrive/picture_data\"\n","\n","# 遍历图片\n","for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n"," image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)\n"," text = clip.tokenize(list(categories.keys()) + [\"others\"]).to(device)\n"," with torch.no_grad():\n"," logits_per_image, _ = model(image, text)\n"," probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n"," # with torch.no_grad():\n"," # image_features = model.encode_image(image)\n"," # text_features = model.encode_text(text)\n","\n"," # image_features /= image_features.norm(dim=-1, keepdim=True)\n"," # text_features /= text_features.norm(dim=-1, keepdim=True)\n","\n"," # similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)\n"," # probs = similarity.cpu().numpy()\n","\n"," predicted_category = (list(categories.keys()) + [\"others\"])[probs.argmax()]\n","\n"," category_matched = False\n"," for category, prefix in categories.items():\n"," if prefix in image_name:\n"," category_matched = True\n"," if category == predicted_category:\n"," stats[category][\"TP\"] += 1\n"," else:\n"," stats[category][\"FN\"] += 1\n"," if predicted_category != \"others\":\n"," stats[predicted_category][\"FP\"] += 1\n"," break\n","\n"," # 处理负样本\n"," if not category_matched:\n"," if predicted_category != \"others\":\n"," stats[predicted_category][\"FP\"] += 1 # 错误地将负样本判断为正样本类别\n","\n","# 计算平均精确度、平均召回率以及平均F1分数\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","num_categories = len(categories)\n","\n","for category, data in stats.items():\n"," if category != \"others\":\n"," precision = data[\"TP\"] / (data[\"TP\"] + data[\"FP\"]) if (data[\"TP\"] + data[\"FP\"]) > 0 else 0\n"," recall = data[\"TP\"] / (data[\"TP\"] + data[\"FN\"]) if (data[\"TP\"] + data[\"FN\"]) > 0 else 0\n"," f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1\n","\n","average_precision = total_precision / num_categories\n","average_recall = total_recall / num_categories\n","average_f1 = total_f1 / num_categories\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Kr4VYvFfqE2J","executionInfo":{"status":"ok","timestamp":1709783155652,"user_tz":-480,"elapsed":68530,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"dd923720-00cf-4d26-f0b1-8fc02d376819"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9256, Recall: 0.9950, F1: 0.9590\n","Category: tomato, Precision: 0.4785, Recall: 1.0000, F1: 0.6472\n","Category: paper-cut, Precision: 0.5535, Recall: 0.9050, F1: 0.6869\n","Category: computer, Precision: 0.5319, Recall: 1.0000, F1: 0.6944\n","Category: dumpling, Precision: 0.3333, Recall: 0.9900, F1: 0.4987\n","Average Precision: 0.5646, Average Recall: 0.9780, Average F1: 0.6973\n"]}]},{"cell_type":"markdown","source":["一个类别+others"],"metadata":{"id":"f_xtcABJqqhe"}},{"cell_type":"code","source":["total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","# 对每个类别分别计算precision、recall及F1分数\n","for category, keyword in categories.items():\n"," # 初始化统计数据\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," # 遍历图片\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n","\n"," # 预处理图片并进行预测\n"," image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)\n"," text_labels = [category] + [\"others\"]\n"," text = clip.tokenize(text_labels).to(device)\n","\n"," with torch.no_grad():\n"," logits_per_image, _ = model(image, text)\n"," probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n","\n"," # 获取最高概率的类别\n"," predicted_category = text_labels[probs.argmax()]\n","\n"," # 判断真实类别并更新统计数据\n"," actual_category = \"others\" if keyword not in image_name else category\n"," if predicted_category == actual_category and actual_category == category:\n"," TP += 1\n"," elif predicted_category == category and actual_category == \"others\":\n"," FP += 1\n"," elif predicted_category == \"others\" and actual_category == category:\n"," FN += 1\n","\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n","\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision、Recall及F1分数\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zQyVwgO6bUVE","executionInfo":{"status":"ok","timestamp":1709779759965,"user_tz":-480,"elapsed":281381,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"3450e79b-0710-444a-c535-466585cf0300"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.5952, Recall: 1.0000, F1 Score: 0.7463\n","Category: tomato, Precision: 0.2217, Recall: 1.0000, F1 Score: 0.3630\n","Category: paper-cut, Precision: 0.3306, Recall: 1.0000, F1 Score: 0.4969\n","Category: computer, Precision: 0.4598, Recall: 1.0000, F1 Score: 0.6299\n","Category: dumpling, Precision: 0.2051, Recall: 1.0000, F1 Score: 0.3404\n","Average Precision: 0.3625, Average Recall: 1.0000, Average F1: 0.5153\n"]}]},{"cell_type":"markdown","source":["卡阈值(单类别+others)"],"metadata":{"id":"UNmaINASvKcB"}},{"cell_type":"code","source":["# 设定阈值\n","threshold = 0.97\n","\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","# 对每个类别分别计算precision和recall\n","for category, keyword in categories.items():\n"," # 初始化统计数据\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," # 遍历图片\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n","\n"," # 预处理图片\n"," image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)\n","\n"," # 准备文本\n"," text_labels = [category, \"others\"]\n"," text = clip.tokenize(text_labels).to(device)\n","\n"," # 进行预测\n"," with torch.no_grad():\n"," logits_per_image, _ = model(image, text)\n"," probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n","\n"," # 使用阈值判断类别\n"," is_positive_prediction = probs[0, 0] > threshold\n"," predicted_category = category if is_positive_prediction else \"others\"\n","\n"," # 判断真实类别\n"," actual_category = category if keyword in image_name else \"others\"\n","\n"," # 更新统计数据\n"," if predicted_category == actual_category and actual_category == category:\n"," TP += 1\n"," elif predicted_category == category and actual_category == \"others\":\n"," FP += 1\n"," elif actual_category == category and predicted_category == \"others\":\n"," FN += 1\n","\n"," # 计算Precision和Recall\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n","\n","\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision和Recall\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"FTT3wwx9cYWO","executionInfo":{"status":"ok","timestamp":1709786207212,"user_tz":-480,"elapsed":275249,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"64a7af62-9714-4101-d3b4-dfdf68999f4a"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9789, Recall: 0.9300, F1 Score: 0.9538\n","Category: tomato, Precision: 0.6923, Recall: 0.9900, F1 Score: 0.8148\n","Category: paper-cut, Precision: 0.8058, Recall: 0.5600, F1 Score: 0.6608\n","Category: computer, Precision: 0.8428, Recall: 0.9650, F1 Score: 0.8998\n","Category: dumpling, Precision: 0.4910, Recall: 0.9600, F1 Score: 0.6497\n","Average Precision: 0.7622, Average Recall: 0.8810, Average F1: 0.7958\n"]}]},{"cell_type":"markdown","source":["卡阈值(单类别相似度)"],"metadata":{"id":"MUnhUQ1vqvCu"}},{"cell_type":"code","source":["import torch\n","import clip\n","from PIL import Image\n","import os\n","\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","\n","# 加载模型\n","model, preprocess = clip.load(\"ViT-B/32\", device=device)\n","\n","# 定义类别和对应的图片命名规则\n","categories = {\n"," \"pet cat\": \"宠物猫\",\n"," \"tomato\": \"番茄\",\n"," \"paper-cut\": \"剪纸\",\n"," \"computer\": \"电脑\",\n"," \"dumpling\": \"饺子\"\n","}\n","\n","# 设定阈值\n","threshold = 0.23\n","\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","# 对每个类别分别计算precision和recall\n","for category, keyword in categories.items():\n"," # 初始化统计数据\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," # 遍历图片\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n","\n"," # 预处理图片\n"," image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)\n","\n"," # 准备文本\n"," text_labels = [category]\n"," text = clip.tokenize(text_labels).to(device)\n","\n"," # 进行预测\n"," with torch.no_grad():\n"," image_features = model.encode_image(image)\n"," text_features = model.encode_text(text)\n","\n"," image_features /= image_features.norm(dim=-1, keepdim=True)\n"," text_features /= text_features.norm(dim=-1, keepdim=True)\n","\n"," similarity = image_features @ text_features.T\n","\n"," # 使用阈值判断类别\n"," is_positive_prediction = similarity[0] > threshold\n"," predicted_category = category if is_positive_prediction else \"others\"\n","\n"," # 判断真实类别\n"," actual_category = category if keyword in image_name else \"others\"\n","\n"," # 更新统计数据\n"," if predicted_category == actual_category and actual_category == category:\n"," TP += 1\n"," elif predicted_category == category and actual_category == \"others\":\n"," FP += 1\n"," elif actual_category == category and predicted_category == \"others\":\n"," FN += 1\n","\n"," # 计算Precision和Recall\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n","\n","\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision和Recall\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"G_3LMzzRTG65","executionInfo":{"status":"ok","timestamp":1710062303479,"user_tz":-480,"elapsed":313717,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"03eb2b6b-5165-41a1-b706-48d0947f45c9"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9132, Recall: 1.0000, F1 Score: 0.9547\n","Category: tomato, Precision: 0.5038, Recall: 0.9850, F1 Score: 0.6667\n","Category: paper-cut, Precision: 0.7036, Recall: 0.8900, F1 Score: 0.7859\n","Category: computer, Precision: 0.9031, Recall: 0.8850, F1 Score: 0.8939\n","Category: dumpling, Precision: 0.3186, Recall: 0.9700, F1 Score: 0.4796\n","Average Precision: 0.6685, Average Recall: 0.9460, Average F1: 0.7561\n"]}]}]} -------------------------------------------------------------------------------- /LLaVA.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"A100","mount_file_id":"1XZVzWbmE82x-zqh7l8jA_OxXpK2Wp3Xo","authorship_tag":"ABX9TyPzjmRc9Wrcat3cJdyvsw2a"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"0ca0009f2a93493bac02da2a64d73e8e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_ca0f5b4219f5403b9fb07fc481b106ce","IPY_MODEL_0cbaabb7826a46648832afb671e1f49c","IPY_MODEL_bf2a53bec0ed4e498850cfdbb8d40c73"],"layout":"IPY_MODEL_4b35cd35ba574525873d92528ccab59e"}},"ca0f5b4219f5403b9fb07fc481b106ce":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_b001ba37765f4282bd3e70f433e0e692","placeholder":"​","style":"IPY_MODEL_37492222fac349a4828812196cc44cbf","value":"Loading checkpoint shards: 100%"}},"0cbaabb7826a46648832afb671e1f49c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_cc786140912243048057de4af9fec28d","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_ef84b19574b5477f9f6fe2d7741a1eba","value":2}},"bf2a53bec0ed4e498850cfdbb8d40c73":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_daa6b1473a5145b9b684b143ccc4c48a","placeholder":"​","style":"IPY_MODEL_666daa6fcd174b71a05eae0005f37fea","value":" 2/2 [00:02<00:00,  1.11s/it]"}},"4b35cd35ba574525873d92528ccab59e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b001ba37765f4282bd3e70f433e0e692":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"37492222fac349a4828812196cc44cbf":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"cc786140912243048057de4af9fec28d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ef84b19574b5477f9f6fe2d7741a1eba":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"daa6b1473a5145b9b684b143ccc4c48a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"666daa6fcd174b71a05eae0005f37fea":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"pDsvkAR8ShoO"},"outputs":[],"source":["!git clone https://github.com/haotian-liu/LLaVA.git\n","%cd LLaVA\n","!pip install -e ."]},{"cell_type":"markdown","source":["加载模型"],"metadata":{"id":"5f4Hyc4NKvC4"}},{"cell_type":"code","source":["from llava.model.builder import load_pretrained_model\n","from llava.mm_utils import get_model_name_from_path\n","from llava.eval.run_llava import eval_model\n","\n","model_path = \"liuhaotian/llava-v1.5-7b\"\n","\n","tokenizer, model, image_processor, context_len = load_pretrained_model(\n"," model_path=model_path,\n"," model_base=None,\n"," model_name=get_model_name_from_path(model_path),\n"," load_in_4bit=True\n",")"],"metadata":{"id":"yC3VhWJlV-VD"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Github demo"],"metadata":{"id":"QLLrWbT5KwHR"}},{"cell_type":"code","source":["model_path = \"liuhaotian/llava-v1.5-7b\"\n","prompt = \"Please determine if this is a photo about computer? Answer yes or no\"\n","image_file = \"/content/drive/MyDrive/target_data/computer/电脑_1.png\"\n","\n","args = type('Args', (), {\n"," \"model_path\": model_path,\n"," \"model_base\": None,\n"," \"model_name\": get_model_name_from_path(model_path),\n"," \"query\": prompt,\n"," \"conv_mode\": None,\n"," \"image_file\": image_file,\n"," \"sep\": \",\",\n"," \"temperature\": 0,\n"," \"top_p\": None,\n"," \"num_beams\": 1,\n"," \"max_new_tokens\": 512\n","})()\n","\n","eval_model(args)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":176,"referenced_widgets":["0ca0009f2a93493bac02da2a64d73e8e","ca0f5b4219f5403b9fb07fc481b106ce","0cbaabb7826a46648832afb671e1f49c","bf2a53bec0ed4e498850cfdbb8d40c73","4b35cd35ba574525873d92528ccab59e","b001ba37765f4282bd3e70f433e0e692","37492222fac349a4828812196cc44cbf","cc786140912243048057de4af9fec28d","ef84b19574b5477f9f6fe2d7741a1eba","daa6b1473a5145b9b684b143ccc4c48a","666daa6fcd174b71a05eae0005f37fea"]},"id":"SESxcO9yWrKJ","executionInfo":{"status":"ok","timestamp":1709881595869,"user_tz":-480,"elapsed":10053,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"fcbfff10-d417-437c-a8e7-8f3e04410421"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.\n"]},{"output_type":"display_data","data":{"text/plain":["Loading checkpoint shards: 0%| | 0/2 [00:00 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n"," return precision, recall, f1\n","\n","\n","def verify_and_calculate_metrics(target_root_folder, categories):\n"," total_precision, total_recall, total_f1 = 0, 0, 0\n","\n"," for category, keyword in categories.items():\n"," TP, FP, FN = 0, 0, 0\n","\n"," category_folder = os.path.join(target_root_folder, category)\n"," prompt = f\"Please determine if this is a photo about {category}? Answer yes or no\"\n"," args.query = prompt\n","\n"," for image_name in os.listdir(category_folder):\n"," image_path = os.path.join(category_folder, image_name)\n"," args.image_file = image_path\n","\n"," output = get_eval_model_output(args)\n"," is_correct = \"yes\" in output.lower()\n","\n"," # 根据图片名称判断真实类别\n"," actual_category = keyword if keyword in image_name else \"其他\"\n","\n"," # 更新统计数据\n"," if is_correct and actual_category == keyword:\n"," TP += 1\n"," elif is_correct and actual_category != keyword:\n"," FP += 1\n","\n"," # 计算并打印当前类别的Precision、Recall和F1 score\n"," FN = 200 - TP\n"," precision, recall, f1 = calculate_metrics(TP, FP, FN)\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}\")\n","\n"," # 计算并打印总体平均Precision、Recall和F1 score\n"," average_precision = total_precision / len(categories)\n"," average_recall = total_recall / len(categories)\n"," average_f1 = total_f1 / len(categories)\n"," print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")\n","\n","\n","# 示例调用函数\n","target_root_folder = \"/content/drive/MyDrive/target_data\"\n","verify_and_calculate_metrics(target_root_folder, categories)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"K9Sb5fJlLJp7","executionInfo":{"status":"ok","timestamp":1709893627937,"user_tz":-480,"elapsed":1242020,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"2b474ee8-fbca-43c4-8919-38bc72e72d43"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9132, Recall: 1.0000, F1 Score: 0.9547\n","Category: tomato, Precision: 0.9949, Recall: 0.9750, F1 Score: 0.9848\n","Category: paper-cut, Precision: 0.8609, Recall: 0.9900, F1 Score: 0.9209\n","Category: computer, Precision: 0.9469, Recall: 0.9800, F1 Score: 0.9631\n","Category: dumpling, Precision: 0.7976, Recall: 0.9850, F1 Score: 0.8814\n","Average Precision: 0.9027, Average Recall: 0.9860, Average F1: 0.9410\n"]}]},{"cell_type":"code","source":["categories = {\n"," \"pet cat\": \"宠物猫\",\n"," \"tomato\": \"番茄\",\n"," \"paper-cut\": \"剪纸\",\n"," \"computer\": \"电脑\",\n"," \"dumpling\": \"饺子\"\n","}\n","\n","def calculate_metrics(TP, FP, FN):\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n"," return precision, recall, f1\n","\n","\n","def verify_and_calculate_metrics(target_root_folder, categories):\n"," total_precision, total_recall, total_f1 = 0, 0, 0\n","\n"," for category, keyword in categories.items():\n"," TP, FP, FN = 0, 0, 0\n","\n"," category_folder = os.path.join(target_root_folder, category)\n"," prompt = f\"Please determine if this is a photo about {category}? Answer yes or no\"\n"," args.query = prompt\n","\n"," for image_name in os.listdir(category_folder):\n"," image_path = os.path.join(category_folder, image_name)\n"," args.image_file = image_path\n","\n"," output = get_eval_model_output(args)\n"," is_correct = \"yes\" in output.lower()\n","\n"," # 根据图片名称判断真实类别\n"," actual_category = keyword if keyword in image_name else \"其他\"\n","\n"," # 更新统计数据\n"," if is_correct and actual_category == keyword:\n"," TP += 1\n"," elif is_correct and actual_category != keyword:\n"," FP += 1\n","\n"," # 计算并打印当前类别的Precision、Recall和F1 score\n"," FN = 200 - TP\n"," precision, recall, f1 = calculate_metrics(TP, FP, FN)\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}\")\n","\n"," # 计算并打印总体平均Precision、Recall和F1 score\n"," average_precision = total_precision / len(categories)\n"," average_recall = total_recall / len(categories)\n"," average_f1 = total_f1 / len(categories)\n"," print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")\n","\n","\n","# 示例调用函数\n","target_root_folder = \"/content/drive/MyDrive/target_data2\"\n","verify_and_calculate_metrics(target_root_folder, categories)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lyVaTEvlSsGH","executionInfo":{"status":"ok","timestamp":1710242624659,"user_tz":-480,"elapsed":1156728,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"2cf4f580-815f-4bee-c6aa-a8f44810b21f"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py:391: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n"," warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9615, Recall: 1.0000, F1 Score: 0.9804\n","Category: tomato, Precision: 0.9949, Recall: 0.9750, F1 Score: 0.9848\n","Category: paper-cut, Precision: 0.8603, Recall: 0.9850, F1 Score: 0.9184\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Category: computer, Precision: 0.9897, Recall: 0.9600, F1 Score: 0.9746\n","Category: dumpling, Precision: 0.8140, Recall: 0.9850, F1 Score: 0.8914\n","Average Precision: 0.9241, Average Recall: 0.9810, Average F1: 0.9499\n"]}]},{"cell_type":"code","source":["categories = {\n"," \"pet cat\": \"宠物猫\",\n"," \"tomato\": \"番茄\",\n"," \"paper-cut\": \"剪纸\",\n"," \"computer\": \"电脑\",\n"," \"dumpling\": \"饺子\"\n","}\n","\n","def calculate_metrics(TP, FP, FN):\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n"," return precision, recall, f1\n","\n","\n","def verify_and_calculate_metrics(target_root_folder, categories):\n"," total_precision, total_recall, total_f1 = 0, 0, 0\n","\n"," for category, keyword in categories.items():\n"," TP, FP, FN = 0, 0, 0\n","\n"," category_folder = os.path.join(target_root_folder, category)\n"," prompt = f\"Please determine if this is a photo about {category}? Answer yes or no\"\n"," args.query = prompt\n","\n"," for image_name in os.listdir(category_folder):\n"," image_path = os.path.join(category_folder, image_name)\n"," args.image_file = image_path\n","\n"," output = get_eval_model_output(args)\n"," is_correct = \"yes\" in output.lower()\n","\n"," # 根据图片名称判断真实类别\n"," actual_category = keyword if keyword in image_name else \"其他\"\n","\n"," # 更新统计数据\n"," if is_correct and actual_category == keyword:\n"," TP += 1\n"," elif is_correct and actual_category != keyword:\n"," FP += 1\n","\n"," # 计算并打印当前类别的Precision、Recall和F1 score\n"," FN = 200 - TP\n"," precision, recall, f1 = calculate_metrics(TP, FP, FN)\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}\")\n","\n"," # 计算并打印总体平均Precision、Recall和F1 score\n"," average_precision = total_precision / len(categories)\n"," average_recall = total_recall / len(categories)\n"," average_f1 = total_f1 / len(categories)\n"," print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")\n","\n","\n","# 示例调用函数\n","target_root_folder = \"/content/drive/MyDrive/target_data3\"\n","verify_and_calculate_metrics(target_root_folder, categories)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Zx3aSVuZzq5e","executionInfo":{"status":"ok","timestamp":1710385138663,"user_tz":-480,"elapsed":75563,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"a27b5ac5-a13f-4f09-d8e3-49c8ced5d261"},"execution_count":3,"outputs":[{"metadata":{"tags":null},"name":"stderr","output_type":"stream","text":["/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py:391: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n"," warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n"," warnings.warn(\n","/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n"," warnings.warn(\n"]},{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["Category: pet cat, Precision: 0.9662, Recall: 1.0000, F1 Score: 0.9828\n","Category: tomato, Precision: 0.9949, Recall: 0.9750, F1 Score: 0.9848\n","Category: paper-cut, Precision: 0.9336, Recall: 0.9850, F1 Score: 0.9586\n"]},{"metadata":{"tags":null},"name":"stderr","output_type":"stream","text":["/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Category: computer, Precision: 1.0000, Recall: 0.9150, F1 Score: 0.9556\n","Category: dumpling, Precision: 0.8383, Recall: 0.9850, F1 Score: 0.9057\n","Average Precision: 0.9466, Average Recall: 0.9720, Average F1: 0.9575\n"]}]},{"cell_type":"markdown","source":["释放内存"],"metadata":{"id":"QgZpS_hLNCU5"}},{"cell_type":"code","source":["del model"],"metadata":{"id":"SQ-LH_vmdbFz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import torch\n","import gc\n","gc.collect()\n","torch.cuda.empty_cache()"],"metadata":{"id":"vmhUHmS6h9oV"},"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于多模态检索的互联网图文匹配 2 | ## 运行环境 3 | Google Colaboratory 4 | LLaVa部分需要A100-40G,其余16G即可 5 | ## 项目背景 6 | 图文检索是当前最重要的多模态能力之一,针对互联网或者AI公司积累的大量历史数据,以及手机图库中的图片数据,利用多模态大模型对这些数据进行检索,是数据治理的关键。 7 | ## 项目内容 8 | ### 1. 数据集: 9 | - 数据来源:爬取多类别的互联网图像数据2000张 10 | - 样本选择:选择5个正样本,每个正样本对应易混淆的负样本,5个正样本中有2个和中国特色相关的词语描述。每个正负样本200张图,10个正负样本共2000张图片。 11 | ### 2. 项目流程: 12 | 中英文CLIP图文匹配+多模态大模型LLaVA结果矫正 13 | 1. 先用英文CLIP模型对图片以及文本标签进行匹配。 14 | 2. 设置对比实验: 15 | - 同时输入5个类别以计算相似度的方式来计算precision、recall和F1 score 16 | - 只用其中一个类别+其他类计算precision和recall和F1 Score 17 | - 通过卡阈值的方式,只输入一个类别计算precision和recall和F1 Score 18 | 在上面三种方式中选择精度最高的。 19 | 3. 使用中文CLIP模型,按照精度最高的方式进行图文匹配,并计算精度。 20 | 4. 将英文和中文CLIP的匹配结果取并集,去重后作为正样本,并计算精度。 21 | 5. 中英文CLIP的结果经过LLaVA做矫正,保留LLaVA输出正确的正样本,并计算精度。 22 | 6. 额外收集5个正样本的1000张图使用CLIP-Adapter进行微调训练,在之前的2000张图上做测试,然后从1000张图中抽取10,100,500张图片用于训练的测试对比。 23 | ### 3. 注: 24 | - 数据集中加入了中国特色的词语,中文模型在中文领域有更好的匹配效果。 25 | - 中英文模型取并集的结果可以尽可能召回更多的正样本,减少遗漏,提升recall。 26 | - 合并后召回的图片数量增加,precision不够高,LLaVA可以进一步筛选,提高precision。 27 | ## 项目结论 28 | 1. 对于原始英文CLIP,精度最高的方式为卡阈值,随着阈值的增加,precision会逐渐提高,recall则逐渐降低,当阈值为0.24时,整体精度最高:Average Precision: 0.7740, Average Recall: 0.8800, Average F1: 0.7984。 29 | 2. 合并中英文CLIP的匹配结果,Average Precision: 0.7794, Average Recall: 0.9830, Average F1: 0.8519,不损失precison的情况下,将recall从88%提高到了98.3%,提高了11.7%。 30 | 3. 经过LLaVA矫正,Average Precision: 0.9466, Average Recall: 0.9720, Average F1: 0.9575,损失少量recall,提高了21.5%precison,相比于原始CLIP,最终精度(F1 Score)提高了19.9%。 31 | 4. 经过CLIP-Adapter的微调,100张图片的训练几乎没有提高精度,500张图片的训练将精度提高了2.37%,1000张图片的训练将精度提高了2.93%。 32 | ## 不足与展望 33 | 1. 数据集使用了中文互联网的图片数据,中文CLIP已经有很好的效果,可以尝试不同的数据来源进行测试。 34 | 2. 使用LLaVA对结果进行矫正有不错的效果,不过precision还有提升的空间,可以在特定的VQA场景上进行微调。 35 | 3. CLIP-Adapter进行微调时使用的数据量较小,效果不太明显,可以尝试构建更大量级的微调数据集。 36 | ## 参考 37 | 原始CLIP:https://github.com/openai/CLIP 38 | 中文CLIP:https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese 39 | LLaVA:https://github.com/haotian-liu/LLaVA 40 | CLIP-Adapter(Tip-Adapter):https://github.com/gaopengcuhk/Tip-Adapter 41 | -------------------------------------------------------------------------------- /Tip-Adapter.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"HIgPKJfKuW6B"},"outputs":[],"source":["!git clone https://github.com/gaopengcuhk/Tip-Adapter.git\n","%cd Tip-Adapter\n","!pip install -r requirements.txt"]},{"cell_type":"markdown","source":["修改Tip-Adapter中的main.py"],"metadata":{"id":"kgTAL-uYU9jw"}},{"cell_type":"code","execution_count":2,"metadata":{"id":"YTP_zhnmwfLE","executionInfo":{"status":"ok","timestamp":1710409110129,"user_tz":-480,"elapsed":1377,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}}},"outputs":[],"source":["!cp -r /content/drive/MyDrive/main.py /content/Tip-Adapter"]},{"cell_type":"markdown","source":["将处理好的数据集复制到指定路径"],"metadata":{"id":"PEZf-F9xeLdO"}},{"cell_type":"code","execution_count":10,"metadata":{"id":"PU7kr_eXuEJP","executionInfo":{"status":"ok","timestamp":1710411250996,"user_tz":-480,"elapsed":96615,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}}},"outputs":[],"source":["!cp -r /content/drive/MyDrive/data_process/data_process_1000/ /content/drive/MyDrive/ucf101/UCF-101-midframes"]},{"cell_type":"markdown","source":["100张图训练"],"metadata":{"id":"nxBUfiHQgMb0"}},{"cell_type":"markdown","source":["注意:训练前需要在/content/Tip-Adapter/configs/ucf101.yaml中补充root_path"],"metadata":{"id":"Ye2rD1MOWNqf"}},{"cell_type":"code","source":["!CUDA_VISIBLE_DEVICES=0 python main.py --config configs/ucf101.yaml"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sCzKeHz1c3se","executionInfo":{"status":"ok","timestamp":1710411964346,"user_tz":-480,"elapsed":264519,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"c50a480e-766f-419c-a5dc-7971e2bc8054"},"execution_count":12,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Running configs.\n","{'root_path': '/content/drive/MyDrive', 'load_cache': False, 'load_pre_feat': False, 'search_hp': True, 'search_scale': [7, 3], 'search_step': [200, 20], 'init_beta': 1, 'init_alpha': 3, 'dataset': 'ucf101', 'shots': 16, 'backbone': 'RN50', 'lr': 0.001, 'augment_epoch': 10, 'train_epoch': 20, 'cache_dir': './caches/ucf101'} \n","\n","Preparing dataset.\n","Reading split from /content/drive/MyDrive/ucf101/split_zhou_UCF101.json\n","Creating a 16-shot dataset\n","\n","Getting textual features as CLIP's classifier.\n","\n","Constructing cache model by few-shot visual features and labels.\n","Augment Epoch: 0 / 10\n","100% 1/1 [00:47<00:00, 47.37s/it]\n","Augment Epoch: 1 / 10\n","100% 1/1 [00:00<00:00, 1.23it/s]\n","Augment Epoch: 2 / 10\n","100% 1/1 [00:00<00:00, 1.23it/s]\n","Augment Epoch: 3 / 10\n","100% 1/1 [00:00<00:00, 1.26it/s]\n","Augment Epoch: 4 / 10\n","100% 1/1 [00:00<00:00, 1.26it/s]\n","Augment Epoch: 5 / 10\n","100% 1/1 [00:00<00:00, 1.24it/s]\n","Augment Epoch: 6 / 10\n","100% 1/1 [00:00<00:00, 1.24it/s]\n","Augment Epoch: 7 / 10\n","100% 1/1 [00:00<00:00, 1.20it/s]\n","Augment Epoch: 8 / 10\n","100% 1/1 [00:00<00:00, 1.18it/s]\n","Augment Epoch: 9 / 10\n","100% 1/1 [00:00<00:00, 1.16it/s]\n","\n","Loading visual features and labels from val set.\n","100% 1/1 [00:20<00:00, 20.72s/it]\n","\n","Loading visual features and labels from test set.\n"," 0% 0/31 [00:00\")\n"," with open(filename, \"wb\") as f:\n"," f.write(res.content)\n"," print(\"存储路径:\" + filename)\n","\n"," # 入口函数\n"," def run(self):\n"," searchName = input(\"查询内容:\")\n"," self.search_name = searchName\n"," searchName_parse = parse.quote(searchName) # 编码\n","\n"," self.create_directory(searchName)\n","\n"," pic_number = 0 # 图像数量\n"," for index in range(self.json_count):\n"," pn = (index+1)*10\n"," request_url = self.url.format(searchName_parse, searchName_parse, str(pn))\n"," list_image_link = self.get_image_link(request_url)\n"," for link in list_image_link:\n"," pic_number += 1\n"," filename = self.directory.format(f\"{self.search_name}_{str(pic_number)}.png\")\n"," self.save_image(link, filename)\n"," time.sleep(0.2) # 休眠0.2秒,防止封ip\n"," print(searchName+\"----图像下载完成--------->\")\n","\n","if __name__ == '__main__':\n"," spider = BaiduImageSpider()\n"," spider.json_count = 20\n"," spider.run()"]}]} -------------------------------------------------------------------------------- /中英文CLIP.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"4H70Mp-DIaPP"},"outputs":[],"source":["!pip install ftfy regex tqdm\n","!pip install git+https://github.com/openai/CLIP.git"]},{"cell_type":"markdown","source":["中文CLIP huggingface demo"],"metadata":{"id":"c1MBvfhyIRXL"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3116,"status":"ok","timestamp":1710218434386,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"},"user_tz":-480},"id":"RI86z8fdCoRX","outputId":"ee4e8ffd-81a2-46d3-98bd-1816ca1919d1"},"outputs":[{"output_type":"stream","name":"stdout","text":["tensor([[ 0.1047, 0.0269, 0.0761, -0.0248, 0.0007]], device='cuda:0')\n","[[0.946 0. 0.054 0. 0. ]]\n"]}],"source":["from PIL import Image\n","import requests\n","import clip\n","import torch\n","from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n","from transformers import CLIPProcessor, CLIPModel\n","import numpy as np\n","\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","\n","query_texts = [\"一只猫\", \"一只狗\",'两只猫', '两只老虎','一只老虎'] # 这里是输入文本的,可以随意替换。\n","# 加载Taiyi 中文 text encoder\n","text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n","text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval().to(device)\n","text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids'].to(device)\n","\n","url = \"https://hbimg.huaban.com/e637198ad1a5a0b4347d1a21abdd4a6118bd5accb4a23-etvyBB_fw658\" # 这里可以换成任意图片的url\n","image_path = \"/content/drive/MyDrive/picture_data/宠物猫_2.png\"\n","# 加载CLIP的image encoder\n","clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n","processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\", device=device)\n","# image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors=\"pt\")\n","image = processor(images=(Image.open(image_path)), return_tensors=\"pt\").to(device)\n","\n","\n","with torch.no_grad():\n"," image_features = clip_model.get_image_features(**image)\n"," text_features = text_encoder(text).logits\n"," # 归一化\n"," image_features = image_features / image_features.norm(dim=1, keepdim=True)\n"," text_features = text_features / text_features.norm(dim=1, keepdim=True)\n"," # 计算余弦相似度 logit_scale是尺度系数\n"," logit_scale = clip_model.logit_scale.exp()\n"," logits_per_image = logit_scale * image_features @ text_features.t()\n"," similarity = image_features @ text_features.T\n"," print(similarity)\n"," logits_per_text = logits_per_image.t()\n"," probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n"," print(np.around(probs, 3))"]},{"cell_type":"markdown","source":["计算类别概率/相似度"],"metadata":{"id":"hgM9QGSVIt7s"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"K-gV92AK4FU-"},"outputs":[],"source":["def predict(image, text):\n"," with torch.no_grad():\n"," image_features = clip_model.get_image_features(**image)\n"," text_features = text_encoder(text).logits\n"," # 归一化\n"," image_features = image_features / image_features.norm(dim=1, keepdim=True)\n"," text_features = text_features / text_features.norm(dim=1, keepdim=True)\n"," # 计算余弦相似度 logit_scale是尺度系数\n"," logit_scale = clip_model.logit_scale.exp()\n"," logits_per_image = logit_scale * image_features @ text_features.t()\n"," logits_per_text = logits_per_image.t()\n"," probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n"," return probs\n","\n","def cal_similarity(image, text):\n"," with torch.no_grad():\n"," image_features = clip_model.get_image_features(**image)\n"," text_features = text_encoder(text).logits\n"," # 归一化\n"," image_features = image_features / image_features.norm(dim=1, keepdim=True)\n"," text_features = text_features / text_features.norm(dim=1, keepdim=True)\n"," # 计算余弦相似度 logit_scale是尺度系数\n"," logit_scale = clip_model.logit_scale.exp()\n"," similarity = image_features @ text_features.t()\n"," # logits_per_text = logits_per_image.t()\n"," # probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n"," return similarity.cpu().numpy()"]},{"cell_type":"markdown","source":["卡阈值(单类别+others)"],"metadata":{"id":"UwN94Fi6I7_5"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":651049,"status":"ok","timestamp":1710234098914,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"},"user_tz":-480},"id":"uwsKUUx-2-Pj","outputId":"be7055d9-82c0-4b38-fb5b-698e95aec7ba"},"outputs":[{"output_type":"stream","name":"stderr","text":["We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.\n","You may ignore this warning if your `pad_token_id` (0) is identical to the `bos_token_id` (0), `eos_token_id` (2), or the `sep_token_id` (None), and your input is not padded.\n"]},{"output_type":"stream","name":"stdout","text":["Category: 宠物猫, Precision: 0.7866, Recall: 0.9950, F1 Score: 0.8786\n","Category: 番茄, Precision: 0.5587, Recall: 1.0000, F1 Score: 0.7168\n","Category: 剪纸, Precision: 0.8032, Recall: 1.0000, F1 Score: 0.8909\n","Category: 电脑, Precision: 0.6576, Recall: 0.9700, F1 Score: 0.7838\n","Category: 饺子, Precision: 0.4158, Recall: 1.0000, F1 Score: 0.5874\n","Average Precision: 0.6444, Average Recall: 0.9930, Average F1: 0.7715\n"]}],"source":["from PIL import Image\n","import requests\n","import clip\n","import torch\n","from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n","from transformers import CLIPProcessor, CLIPModel\n","import numpy as np\n","import os\n","\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","\n","text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n","text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval().to(device)\n","\n","clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n","processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\", device=device)\n","\n","# 定义类别和对应的图片命名规则\n","categories = [\"宠物猫\", \"番茄\", \"剪纸\", \"电脑\", \"饺子\"]\n","\n","# 初始化统计数据\n","stats = {category: {\"TP\": 0, \"FP\": 0, \"FN\": 0} for category in categories}\n","\n","# 图片文件夹路径\n","image_folder_path = \"/content/drive/MyDrive/picture_data\"\n","\n","# 设定阈值\n","threshold = 0.99\n","\n","# 初始化用于计算平均的变量\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","# 对每个类别分别计算precision和recall\n","for category in categories:\n"," # 初始化统计数据\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," # 遍历图片\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n","\n"," # 预处理图片\n"," image = processor(images=(Image.open(image_path)), return_tensors=\"pt\").to(device)\n","\n"," # 准备文本\n"," text_labels = [category, \"其他\"]\n"," # text = clip.tokenize(list(categories.keys()) + [\"others\"]).to(device)\n"," text = text_tokenizer(text_labels, return_tensors='pt', padding=True)['input_ids'].to(device)\n","\n"," # 进行预测\n"," probs = predict(image, text)\n","\n"," # 使用阈值判断类别\n"," is_positive_prediction = probs[0, 0] > threshold\n"," predicted_category = category if is_positive_prediction else \"others\"\n","\n"," # 判断真实类别\n"," actual_category = category if category in image_name else \"others\"\n","\n"," # 更新统计数据\n"," if predicted_category == actual_category and actual_category == category:\n"," TP += 1\n"," elif predicted_category == category and actual_category == \"others\":\n"," FP += 1\n"," elif actual_category == category and predicted_category == \"others\":\n"," FN += 1\n","\n"," # 计算Precision和Recall\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n","\n","\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision和Recall\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"]},{"cell_type":"markdown","source":["卡阈值(单类别相似度)"],"metadata":{"id":"U_DTx1FxJPgF"}},{"cell_type":"code","source":["from PIL import Image\n","import requests\n","import clip\n","import torch\n","from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n","from transformers import CLIPProcessor, CLIPModel\n","import numpy as np\n","import os\n","\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","\n","text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n","text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval().to(device)\n","\n","clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n","processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\", device=device)\n","\n","# 定义类别和对应的图片命名规则\n","categories = [\"宠物猫\", \"番茄\", \"剪纸\", \"电脑\", \"饺子\"]\n","\n","# 初始化统计数据\n","stats = {category: {\"TP\": 0, \"FP\": 0, \"FN\": 0} for category in categories}\n","\n","# 图片文件夹路径\n","image_folder_path = \"/content/drive/MyDrive/picture_data\"\n","\n","# 设定阈值\n","threshold = 0.11\n","\n","# 初始化用于计算平均的变量\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","# 对每个类别分别计算precision和recall\n","for category in categories:\n"," # 初始化统计数据\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," # 遍历图片\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," image_path = os.path.join(image_folder_path, image_name)\n","\n"," # 预处理图片\n"," image = processor(images=(Image.open(image_path)), return_tensors=\"pt\").to(device)\n","\n"," # 准备文本\n"," text_labels = [category]\n"," # text = clip.tokenize(list(categories.keys()) + [\"others\"]).to(device)\n"," text = text_tokenizer(text_labels, return_tensors='pt', padding=True)['input_ids'].to(device)\n","\n"," # 进行预测\n"," similarity = cal_similarity(image, text)\n","\n"," # 使用阈值判断类别\n"," is_positive_prediction = similarity[0] > threshold\n"," predicted_category = category if is_positive_prediction else \"others\"\n","\n"," # 判断真实类别\n"," actual_category = category if category in image_name else \"others\"\n","\n"," # 更新统计数据\n"," if predicted_category == actual_category and actual_category == category:\n"," TP += 1\n"," elif predicted_category == category and actual_category == \"others\":\n"," FP += 1\n"," elif actual_category == category and predicted_category == \"others\":\n"," FN += 1\n","\n"," # 计算Precision和Recall\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n","\n","\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision和Recall\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"jOugK2a23Eub","executionInfo":{"status":"ok","timestamp":1710236639138,"user_tz":-480,"elapsed":639822,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"c353b7d2-379a-4e76-80f5-58162b3e5694"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Category: 宠物猫, Precision: 0.9706, Recall: 0.9900, F1 Score: 0.9802\n","Category: 番茄, Precision: 1.0000, Recall: 0.9600, F1 Score: 0.9796\n","Category: 剪纸, Precision: 1.0000, Recall: 0.9600, F1 Score: 0.9796\n","Category: 电脑, Precision: 1.0000, Recall: 0.6250, F1 Score: 0.7692\n","Category: 饺子, Precision: 0.9213, Recall: 0.9950, F1 Score: 0.9567\n","Average Precision: 0.9784, Average Recall: 0.9060, Average F1: 0.9331\n"]}]},{"cell_type":"markdown","source":["中英文CLIP类别判断函数"],"metadata":{"id":"66R_xJQYJell"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"QGVVfIEXOTi7"},"outputs":[],"source":["import os\n","import torch\n","import clip\n","import shutil\n","from PIL import Image\n","from transformers import BertTokenizer, BertForSequenceClassification, CLIPProcessor, CLIPModel\n","from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n","\n","# 设置设备\n","device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","model_1, processor_1 = clip.load(\"ViT-B/32\", device=device)\n","\n","text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n","text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval().to(device)\n","model_2 = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n","processor_2 = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\", device=device)\n","\n","def run_english_clip_model(image_name, category):\n"," image_path = os.path.join(image_folder_path, image_name)\n"," image = processor_1(Image.open(image_path)).unsqueeze(0).to(device)\n"," # text = clip.tokenize([category, \"others\"]).to(device)\n"," text = clip.tokenize([category]).to(device)\n","\n"," with torch.no_grad():\n"," image_features = model_1.encode_image(image)\n"," text_features = model_1.encode_text(text)\n","\n"," image_features /= image_features.norm(dim=-1, keepdim=True)\n"," text_features /= text_features.norm(dim=-1, keepdim=True)\n","\n"," similarity = image_features @ text_features.T\n","\n"," threshold = 0.24\n"," is_positive_prediction = similarity[0] > threshold\n","\n"," return is_positive_prediction\n","\n","\n","def run_chinese_clip_model(image_name, category):\n"," image_path = os.path.join(image_folder_path, image_name)\n"," image = processor_2(images=(Image.open(image_path)), return_tensors=\"pt\").to(device)\n"," # text = text_tokenizer([category, \"其他\"], return_tensors='pt', padding=True)['input_ids'].to(device)\n"," text = text_tokenizer([category], return_tensors='pt', padding=True)['input_ids'].to(device)\n","\n"," with torch.no_grad():\n"," image_features = model_2.get_image_features(**image)\n"," text_features = text_encoder(text).logits\n"," # 归一化\n"," image_features = image_features / image_features.norm(dim=1, keepdim=True)\n"," text_features = text_features / text_features.norm(dim=1, keepdim=True)\n"," # 计算余弦相似度\n"," logit_scale = model_2.logit_scale.exp()\n"," logits_per_image = logit_scale * image_features @ text_features.t()\n"," # probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n"," similarity = image_features @ text_features.t()\n","\n"," threshold = 0.1\n"," is_positive_prediction = similarity[0] > threshold\n","\n"," return is_positive_prediction"]},{"cell_type":"markdown","source":["中英文CLIP结果取并集,并复制分类结果"],"metadata":{"id":"HF7ljhT3JvHP"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PS5A0NynNdNJ","executionInfo":{"status":"ok","timestamp":1710382652072,"user_tz":-480,"elapsed":872738,"user":{"displayName":"Thomasine Kaczka","userId":"06732724743998191518"}},"outputId":"16905f1f-0d30-4c8a-9718-f34d2727c00f"},"outputs":[{"output_type":"stream","name":"stdout","text":["Category: pet cat, Precision: 0.9390, Recall: 1.0000, F1 Score: 0.9685\n","Category: tomato, Precision: 0.7082, Recall: 0.9950, F1 Score: 0.8274\n","Category: paper-cut, Precision: 0.8615, Recall: 0.9950, F1 Score: 0.9234\n","Category: computer, Precision: 0.9635, Recall: 0.9250, F1 Score: 0.9439\n","Category: dumpling, Precision: 0.4246, Recall: 1.0000, F1 Score: 0.5961\n","Average Precision: 0.7794, Average Recall: 0.9830, Average F1: 0.8519\n"]}],"source":["categories = {\n"," \"pet cat\": \"宠物猫\",\n"," \"tomato\": \"番茄\",\n"," \"paper-cut\": \"剪纸\",\n"," \"computer\": \"电脑\",\n"," \"dumpling\": \"饺子\"\n","}\n","\n","image_folder_path = \"/content/drive/MyDrive/picture_data\"\n","target_root_folder = \"/content/drive/MyDrive/target_data3/\"\n","\n","# 初始化统计数据\n","total_precision = 0\n","total_recall = 0\n","total_f1 = 0\n","\n","for category in categories.keys():\n"," # 确保每个类别的目标文件夹存在\n"," target_folder = os.path.join(target_root_folder, category)\n"," os.makedirs(target_folder, exist_ok=True)\n","\n","for category, keyword in categories.items():\n"," TP = 0\n"," FP = 0\n"," FN = 0\n","\n"," for image_name in os.listdir(image_folder_path):\n"," if image_name.endswith(\".png\"):\n"," # 运行英文CLIP模型\n"," is_positive_english = run_english_clip_model(image_name, category)\n"," # 运行中文CLIP模型\n"," is_positive_chinese = run_chinese_clip_model(image_name, keyword)\n","\n"," # 判断是否为正样本(英文和中文结果的并集)\n"," is_positive = is_positive_english or is_positive_chinese\n","\n"," image_path = os.path.join(image_folder_path, image_name)\n"," if is_positive:\n"," target_folder = os.path.join(target_root_folder, category)\n"," shutil.copy(image_path, target_folder)\n","\n"," # 获取分类后的图片路径\n"," target_folder = os.path.join(target_root_folder, category)\n"," classified_images = os.listdir(target_folder)\n","\n"," for image_name in classified_images:\n"," if image_name.endswith(\".png\"):\n"," # 判断真实类别\n"," actual_category = category if keyword in image_name else \"其他\"\n","\n"," # 更新统计数据\n"," if actual_category == category:\n"," TP += 1\n"," else:\n"," FP += 1\n","\n"," # 计算并打印当前类别的Precision、Recall\n"," precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n"," recall = TP / 200 # 固定正样本总数为200\n"," f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n"," print(f\"Category: {category}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n","\n"," # 更新总体统计数据\n"," total_precision += precision\n"," total_recall += recall\n"," total_f1 += f1_score\n","\n","# 计算并打印平均Precision、Recall和F1分数\n","average_precision = total_precision / len(categories)\n","average_recall = total_recall / len(categories)\n","average_f1 = total_f1 / len(categories)\n","print(f\"Average Precision: {average_precision:.4f}, Average Recall: {average_recall:.4f}, Average F1: {average_f1:.4f}\")\n"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"A100","machine_shape":"hm","provenance":[],"mount_file_id":"1ys6d9h7WcyQPrBWuROLdewS3y2EWBJ7P","authorship_tag":"ABX9TyOmOWZlaBs7gEYdq6PuYnw/"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0} --------------------------------------------------------------------------------