├── .gitignore
├── LICENSE
├── README.md
├── detection
├── coco2yolo.py
├── coco_eval.py
├── vis_yolo_gt_dt.py
└── yolo2coco.py
└── text-image
├── convert_diffusers_to_original_stable_diffusion.py
├── data_filter
├── data_filter_demo.ipynb
├── wukong_filter.py
└── wukong_reader.py
├── fid_clip_score
├── .gitignore
├── coco_sample_generator.py
├── compute_fid.ipynb
├── fid_clip_coco.ipynb
├── fid_clip_coco_cn.ipynb
├── run_generator.sh
└── run_generator_cn.sh
├── imagenet_CN_zeroshot_data.py
├── iterable_tar_unzip.sh
├── save_hg_ckpt.ipynb
└── zeroshot_retrieval_evaluation.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Weifeng Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Some Scripts For DEEP LEARNING
3 |
4 | # 1. detection
5 | ## yolo2coco.py
6 | Converts a YOLO-format dataset to COCO format. `$ROOT_PATH` is the root directory; organize the data as follows:
7 |
8 | ```bash
9 | └── $ROOT_PATH
10 |
11 | ├── classes.txt
12 |
13 | ├── images
14 |
15 | └──labels
16 | ```
17 |
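18 | - `classes.txt` declares the classes, one class per line.
19 |
20 | - The `images` directory contains all images (currently `png` and `jpg` are supported).
21 |
22 | - The `labels` directory contains all labels (`txt` files with the **same name** as their images).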
23 |
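24 | Once the folders are in place, run `python yolo2coco.py --root_dir $ROOT_PATH`; the generated `annotations` folder appears under `$ROOT_PATH`.
25 |
26 | **Arguments**
27 | - `--root_dir` path to the root directory `$ROOT_PATH`.
28 | - `--save_path` when the dataset is not split, the name of the output file (saved as `train.json` by default).
29 | - `--random_split` split the dataset randomly; `annotations` will then contain `train.json`, `val.json` and `test.json` (8:1:1 by default).
30 | - `--split_by_file` split the dataset as you define it; `annotations` will then contain `train.json`, `val.json` and `test.json`. `$ROOT_PATH` must contain `./train.txt`, `./val.txt` and `./test.txt`, which define the training, validation and test sets. **Note:** these files should list image file names, not absolute image paths. (You can change how they are read around line 43 of `yolo2coco.py`; for simplicity, keeping all images in one place is recommended.)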
31 |
32 |
33 | ## coco2yolo.py
34 |
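35 | Reads COCO annotations in JSON format and writes labels that YOLO can train on.
36 |
37 | **Note that the category ids in the official COCO2017 dataset are not contiguous**, which breaks YOLO label loading, so they have to be remapped. This script maps them to the contiguous range 0 to 79 in the order of the `categories` list (ascending id in COCO2017). Custom datasets are remapped the same way, to 0 to N-1 for N classes.
38 |
39 | Run: `python coco2yolo.py --json_path $JSON_FILE_PATH --save_path $LABEL_SAVE_PATH`
40 |
41 | - `$JSON_FILE_PATH` is the path to the JSON file.
42 | - `$LABEL_SAVE_PATH` is the output directory (default: `./labels` under the working directory).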
43 |
44 |
45 | ## zeroshot_retrieval_evaluation.ipynb
46 | - Computes top-N retrieval metrics; supports one-to-many retrieval (one image with several captions).
47 |
48 | ## vis_yolo_gt_dt.py
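49 | Draws the ground truth (GT) and the predictions on the same image. `$DT_DIR` is the directory of predicted labels; each file must have the same name as its GT label. Layout of `$ROOT_PATH`: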
50 |
51 | ```bash
52 | └── $ROOT_PATH
53 |
54 | ├── classes.txt
55 |
56 | ├── images
57 |
58 | └── labels
59 | ```
60 |
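61 | Run `python vis_yolo_gt_dt.py --root $ROOT_PATH --dt $DT_DIR`; the images are written to the `outputs` folder under `$ROOT_PATH`.
62 |
63 | - `classes.txt` and `images` are required.
64 | - `labels` is optional; without it only the `$DT_DIR` predictions are drawn.
65 | - If `$DT_DIR` is not given, only the GT labels are drawn.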
66 |
67 | ## coco_eval.py
68 |
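69 | Evaluates detection results. It targets results produced by **yolov5** (the `--save-json` flag of `test.py` writes `best_predictions.json`), which cocoapi cannot load as-is, so this script converts them first. Run: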
70 |
71 | `python coco_eval.py --gt $GT_PATH --dt $DT_PATH --yolov5`
72 |
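73 | - `--gt` the test-set ground truth in COCO JSON format; if you don't have one, convert it with `yolo2coco.py` above.
74 | - `--dt` the predictions produced by the detector. They are loaded with cocoapi's `loadRes`, so they must already be in the matching result format.
75 | - `--yolov5` converts results produced by the official yolov5 code into a format cocoapi accepts.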
76 |
77 | # 2. text-image
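78 | ## zeroshot_retrieval_evaluation.ipynb
79 | Evaluation metrics for retrieval models (top-K recall). Supports many-to-many matching (e.g. one text matching several images).
80 | ## fid_clip_score
81 | Plots FID-CLIP score curves for text-to-image models.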
--------------------------------------------------------------------------------
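The README's `coco2yolo.py` section says COCO2017 category ids are not contiguous and get remapped to 0 to N-1. A minimal sketch of that remapping, assuming a list of COCO `categories` dicts; `build_id_map` is a hypothetical helper, and it sorts by id explicitly, whereas the script relies on the order of `data['categories']` (which is ascending in COCO2017):

```python
def build_id_map(categories):
    """Map possibly non-contiguous COCO category ids to 0..N-1 in ascending id order."""
    sorted_ids = sorted(cat["id"] for cat in categories)
    return {cat_id: new_id for new_id, cat_id in enumerate(sorted_ids)}


# A few COCO2017 categories; note the gap between ids 11 and 13 (id 12 is unused).
categories = [
    {"id": 1, "name": "person"},
    {"id": 11, "name": "fire hydrant"},
    {"id": 13, "name": "stop sign"},
    {"id": 90, "name": "toothbrush"},
]
id_map = build_id_map(categories)
print(id_map)  # {1: 0, 11: 1, 13: 2, 90: 3}
```

Sorting makes the mapping independent of how a custom JSON happens to order its categories; the cost is that it may differ from the script's output for such files.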
/detection/coco2yolo.py:
--------------------------------------------------------------------------------
1 | """
2 | 2021/1/24
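3 | Convert a COCO-format dataset to YOLO format. The original code scanned all annotations for every image, which was too slow;
4 | this version lowers the time complexity from O(nm) to O(n+m) at the cost of extra memory.
5 | --json_path path to the input json file
6 | --save_path name of the output folder, default: ./labels in the current directory.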
7 | """
8 |
9 | import os
10 | import json
11 | from tqdm import tqdm
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--json_path', default='./instances_val2017.json',type=str, help="input: coco format(json)")
16 | parser.add_argument('--save_path', default='./labels', type=str, help="specify where to save the output dir of labels")
17 | arg = parser.parse_args()
18 |
19 | def convert(size, box):
20 | dw = 1. / (size[0])
21 | dh = 1. / (size[1])
22 | x = box[0] + box[2] / 2.0
23 | y = box[1] + box[3] / 2.0
24 | w = box[2]
25 | h = box[3]
26 |
27 | x = x * dw
28 | w = w * dw
29 | y = y * dh
30 | h = h * dh
31 | return (x, y, w, h)
32 |
33 | if __name__ == '__main__':
34 |     json_file = arg.json_path  # COCO object instance annotations
35 |     ana_txt_save_path = arg.save_path  # output directory
36 |
37 | data = json.load(open(json_file, 'r'))
38 | if not os.path.exists(ana_txt_save_path):
39 | os.makedirs(ana_txt_save_path)
40 |
41 |     id_map = {}  # COCO category ids are not contiguous, so remap them before writing!
42 | for i, category in enumerate(data['categories']):
43 | id_map[category['id']] = i
44 |
45 |     # build a lookup table up front to lower the time complexity
46 | max_id = 0
47 | for img in data['images']:
48 | max_id = max(max_id, img['id'])
49 |     # do not write [[]]*(max_id+1) here, or every entry would be the same shared list
50 | img_ann_dict = [[] for i in range(max_id+1)]
51 | for i, ann in enumerate(data['annotations']):
52 | img_ann_dict[ann['image_id']].append(i)
53 |
54 | for img in tqdm(data['images']):
55 | filename = img["file_name"]
56 | img_width = img["width"]
57 | img_height = img["height"]
58 | img_id = img["id"]
59 | head, tail = os.path.splitext(filename)
60 |         ana_txt_name = head + ".txt"  # txt name matching the image name
61 | f_txt = open(os.path.join(ana_txt_save_path, ana_txt_name), 'w')
62 | '''for ann in data['annotations']:
63 | if ann['image_id'] == img_id:
64 | box = convert((img_width, img_height), ann["bbox"])
65 | f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))'''
66 |         # look the annotations up in the table instead of rescanning them all
67 | for ann_id in img_ann_dict[img_id]:
68 | ann = data['annotations'][ann_id]
69 | box = convert((img_width, img_height), ann["bbox"])
70 | f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))
71 | f_txt.close()
72 |
73 | # Old version, very slow
74 | # """
75 | # Convert a COCO-format dataset to YOLO format
76 | # --json_path path to the input json file
77 | # --save_path name of the output folder, default: ./labels in the current directory.
78 | # """
79 |
80 | # import os
81 | # import json
82 | # from tqdm import tqdm
83 | # import argparse
84 |
85 | # parser = argparse.ArgumentParser()
86 | # parser.add_argument('--json_path', default='./instances_val2017.json',type=str, help="input: coco format(json)")
87 | # parser.add_argument('--save_path', default='./labels', type=str, help="specify where to save the output dir of labels")
88 | # arg = parser.parse_args()
89 |
90 | # def convert(size, box):
91 | # dw = 1. / (size[0])
92 | # dh = 1. / (size[1])
93 | # x = box[0] + box[2] / 2.0
94 | # y = box[1] + box[3] / 2.0
95 | # w = box[2]
96 | # h = box[3]
97 |
98 | # x = x * dw
99 | # w = w * dw
100 | # y = y * dh
101 | # h = h * dh
102 | # return (x, y, w, h)
103 |
104 | # if __name__ == '__main__':
105 | #     json_file = arg.json_path  # COCO object instance annotations
106 | #     ana_txt_save_path = arg.save_path  # output directory
107 |
108 | # data = json.load(open(json_file, 'r'))
109 | # if not os.path.exists(ana_txt_save_path):
110 | # os.makedirs(ana_txt_save_path)
111 |
112 | #     id_map = {}  # COCO category ids are not contiguous, so remap them before writing!
113 | #     with open(os.path.join(ana_txt_save_path, 'classes.txt'), 'w') as f:
114 | #         # write classes.txt
115 | # for i, category in enumerate(data['categories']):
116 | # f.write(f"{category['name']}\n")
117 | # id_map[category['id']] = i
118 | # # print(id_map)
119 |
120 | # for img in tqdm(data['images']):
121 | # filename = img["file_name"]
122 | # img_width = img["width"]
123 | # img_height = img["height"]
124 | # img_id = img["id"]
125 | # head, tail = os.path.splitext(filename)
126 | #         ana_txt_name = head + ".txt"  # txt name matching the image name
127 | # f_txt = open(os.path.join(ana_txt_save_path, ana_txt_name), 'w')
128 | # for ann in data['annotations']:
129 | # if ann['image_id'] == img_id:
130 | # box = convert((img_width, img_height), ann["bbox"])
131 | # f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))
132 | # f_txt.close()
133 |
--------------------------------------------------------------------------------
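The `convert` function in `coco2yolo.py` turns a COCO box `[x_min, y_min, w, h]` in pixels into YOLO's `x_center y_center w h` normalized by the image size, and the per-image lookup table built before the main loop is what brings the script from O(nm) to O(n+m). Below is a minimal standalone sketch of both ideas, assuming a COCO instances dict already loaded with `json.load`. The function names are hypothetical, and it swaps the script's list sized by the largest image id for a `defaultdict` keyed by `image_id`, which avoids allocating slots for unused ids:

```python
import os
from collections import defaultdict


def coco_box_to_yolo(img_w, img_h, bbox):
    """COCO [x_min, y_min, w, h] in pixels -> YOLO (x_center, y_center, w, h) normalized to [0, 1]."""
    x_min, y_min, w, h = bbox
    return ((x_min + w / 2) / img_w, (y_min + h / 2) / img_h, w / img_w, h / img_h)


def group_annotations_by_image(annotations):
    """Single O(m) pass: image_id -> list of that image's annotations."""
    by_image = defaultdict(list)
    for ann in annotations:
        by_image[ann["image_id"]].append(ann)
    return by_image


def coco_to_yolo_lines(data):
    """Return {label file name: [YOLO label lines]} for a COCO instances dict."""
    # Contiguous class ids in the order of data["categories"], as coco2yolo.py does.
    id_map = {cat["id"]: i for i, cat in enumerate(data["categories"])}
    by_image = group_annotations_by_image(data["annotations"])
    out = {}
    for img in data["images"]:  # O(n) over images, O(1) lookup for each
        stem = os.path.splitext(img["file_name"])[0]
        lines = []
        for ann in by_image.get(img["id"], []):
            x, y, w, h = coco_box_to_yolo(img["width"], img["height"], ann["bbox"])
            lines.append(f"{id_map[ann['category_id']]} {x:.6f} {y:.6f} {w:.6f} {h:.6f}")
        out[stem + ".txt"] = lines
    return out


data = {
    "categories": [{"id": 1, "name": "person"}, {"id": 18, "name": "dog"}],
    "images": [
        {"id": 7, "file_name": "a.jpg", "width": 640, "height": 480},
        {"id": 9, "file_name": "b.jpg", "width": 100, "height": 100},
    ],
    "annotations": [{"image_id": 7, "category_id": 18, "bbox": [64, 48, 128, 96]}],
}
labels = coco_to_yolo_lines(data)
print(labels)  # {'a.txt': ['1 0.200000 0.200000 0.200000 0.200000'], 'b.txt': []}
```

Writing each list to the output folder is left out so the sketch has no file I/O. The six-decimal formatting is a choice made here; the script writes the full float with `%s`.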
/detection/coco_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from pycocotools.coco import COCO
4 | from pycocotools.cocoeval import COCOeval
5 | import os
6 | import time
7 |
8 | def transform_yolov5_result(result, filename2id):
9 |     with open(result, 'r', encoding='utf-8') as f:
10 |         dts = json.load(f)
11 |     output_dts = []
12 |     for dt in dts:
13 |         dt['image_id'] = filename2id[str(dt['image_id']) + '.jpg']
14 |         # category_id is passed through unchanged; COCO and YOLO category ids may differ, so remap here if needed.
15 | output_dts.append(dt)
16 | with open('temp.json', 'w') as f:
17 | json.dump(output_dts, f)
18 |
19 | def coco_evaluate(gt_path, dt_path, yolov5_flag):
20 | cocoGt = COCO(gt_path)
21 | imgIds = cocoGt.getImgIds()
22 | gts = cocoGt.loadImgs(imgIds)
23 | filename2id = {}
24 |
25 | for gt in gts:
26 | filename2id[gt['file_name']] = gt['id']
27 | print("NUM OF TEST IMAGES: ",len(filename2id))
28 |
29 | if yolov5_flag:
30 | transform_yolov5_result(dt_path, filename2id)
31 | cocoDt = cocoGt.loadRes('temp.json')
32 | else:
33 | cocoDt = cocoGt.loadRes(dt_path)
34 | cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
35 | cocoEval.evaluate()
36 | cocoEval.accumulate()
37 | cocoEval.summarize()
38 | if yolov5_flag:
39 | os.remove('temp.json')
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser()
43 |     parser.add_argument("--gt", type=str, help="Assign the ground-truth path.", default=None)
44 |     parser.add_argument("--dt", type=str, help="Assign the detection result path.", default=None)
45 |     parser.add_argument("--yolov5", action='store_true', help="convert yolov5 --save-json output to the cocoapi result format")
46 |
47 | args = parser.parse_args()
48 | gt_path = args.gt
49 | dt_path = args.dt
50 | if args.yolov5:
51 | coco_evaluate(gt_path, dt_path, True)
52 | else:
53 | coco_evaluate(gt_path, dt_path, False)
54 |
--------------------------------------------------------------------------------
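`transform_yolov5_result` above appends `.jpg` to each detection's `image_id` and leaves `category_id` unchanged. The sketch below is one way to make that lookup independent of the image extension and to apply an optional class remapping. It assumes the detections follow the yolov5 JSON layout, where `image_id` is the file stem (my understanding is that recent yolov5 versions store it as an int when the stem is numeric, dropping leading zeros); `build_stem_to_id`, `fix_yolov5_detections` and `class_map` are hypothetical names:

```python
import os


def build_stem_to_id(gt_images):
    """Map image file stem -> COCO image id, whatever the extension (.jpg, .png, ...)."""
    stem2id = {}
    for img in gt_images:
        stem = os.path.splitext(img["file_name"])[0]
        stem2id[stem] = img["id"]
        # Numeric stems may come back as ints without leading zeros ("000000000139" -> 139).
        if stem.isnumeric():
            stem2id[str(int(stem))] = img["id"]
    return stem2id


def fix_yolov5_detections(dts, gt_images, class_map=None):
    """Return a copy of yolov5 JSON detections whose image_id (and optionally category_id) match the GT."""
    stem2id = build_stem_to_id(gt_images)
    fixed = []
    for dt in dts:
        dt = dict(dt)  # copy so the caller's list is left untouched
        dt["image_id"] = stem2id[str(dt["image_id"])]
        if class_map is not None:  # e.g. {yolo_class_index: coco_category_id}
            dt["category_id"] = class_map[dt["category_id"]]
        fixed.append(dt)
    return fixed


gt_images = [{"id": 139, "file_name": "000000000139.jpg"}, {"id": 5, "file_name": "cat.png"}]
dts = [
    {"image_id": 139, "category_id": 0, "bbox": [10, 20, 30, 40], "score": 0.9},
    {"image_id": "cat", "category_id": 1, "bbox": [1, 2, 3, 4], "score": 0.5},
]
fixed = fix_yolov5_detections(dts, gt_images, class_map={0: 1, 1: 17})
print([(d["image_id"], d["category_id"]) for d in fixed])  # [(139, 1), (5, 17)]
```

The resulting list can be dumped to a file for `cocoGt.loadRes`, as `coco_eval.py` does with `temp.json`; pycocotools' `loadRes` also accepts a list of result dicts directly, which would avoid the temporary file.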
/detection/vis_yolo_gt_dt.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from glob import glob
4 | import random
5 | import matplotlib.pyplot as plt
6 | import argparse
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--root',type=str ,default='', help="which should include ./images and ./labels and classes.txt")
12 | parser.add_argument('--dt',type=str ,default='' , help="yolo format results of detection, include ./labels")
13 | parser.add_argument('--conf', type=float, default=0.5, help="confidence threshold for visualization")
14 | arg = parser.parse_args()
15 |
16 | colorlist = []
17 | # 5^3 = 125 colors.
18 | for i in range(25,256,50):
19 | for j in range(25,256,50):
20 | for k in range(25,256,50):
21 | colorlist.append((i,j,k))
22 | random.shuffle(colorlist)
23 |
24 | def plot_bbox(img_path, img_dir, out_dir, gt=None ,dt=None, cls2label=None, line_thickness=None):
25 | img = cv2.imread(os.path.join(img_dir, img_path))
26 | height, width,_ = img.shape
27 | tl = line_thickness or round(0.002 * (width + height) / 2) + 1 # line/font thickness
28 | font = cv2.FONT_HERSHEY_SIMPLEX
29 | if gt:
30 | tf = max(tl - 1, 1) # font thickness
31 | with open(gt,'r') as f:
32 | annotations = f.readlines()
33 | # print(annotations)
34 | for ann in annotations:
35 | ann = list(map(float,ann.split()))
36 | ann[0] = int(ann[0])
37 | # print(ann)
38 | cls,x,y,w,h = ann
39 | color = colorlist[cls]
40 | c1, c2 = (int((x-w/2)*width),int((y-h/2)*height)), (int((x+w/2)*width), int((y+h/2)*height))
41 | cv2.rectangle(img, c1, c2, color, thickness=tl*2, lineType=cv2.LINE_AA)
42 |             # draw the class name
43 | cv2.putText(img, str(cls2label[cls]), (c1[0], c1[1] - 2), 0, tl / 4, color, thickness=tf, lineType=cv2.LINE_AA)
44 | if dt:
45 | with open(dt,'r') as f:
46 | annotations = f.readlines()
47 | # print(annotations)
48 | for ann in annotations:
49 | ann = list(map(float,ann.split()))
50 | ann[0] = int(ann[0])
51 | # print(ann)
52 | if len(ann) == 6:
53 | cls,x,y,w,h,conf = ann
54 | if conf < arg.conf:
55 | # thres = 0.5
56 | continue
57 | elif len(ann) == 5:
58 | cls,x,y,w,h = ann
59 | color = colorlist[len(colorlist) - cls - 1]
60 |
61 | c1, c2 = (int((x-w/2)*width), int((y-h/2)*height)), (int((x+w/2)*width), int((y+h/2)*height))
62 | cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
63 |
64 | # # cls label
65 | tf = max(tl - 1, 1) # font thickness
66 | t_size = cv2.getTextSize(cls2label[cls], 0, fontScale=tl / 3, thickness=tf)[0]
67 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
68 | # cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
69 | if len(ann) == 6:
70 | cv2.putText(img, str(round(conf,2)), (c1[0], c1[1] - 2), 0, tl / 4, color, thickness=tf, lineType=cv2.LINE_AA)
71 | cv2.imwrite(os.path.join(out_dir,img_path),img)
72 |
73 | if __name__ == "__main__":
74 | root_path = arg.root
75 | pred_path = arg.dt
76 | img_dir = os.path.join(root_path,'images')
77 | GT_dir = os.path.join(root_path,'labels')
78 | DT_dir = os.path.join(pred_path)
79 | out_dir = os.path.join(root_path,'outputs')
80 | cls_dir = os.path.join(root_path,'classes.txt')
81 | cls_dict = {}
82 |
83 | if not os.path.exists(img_dir):
84 |         raise Exception("image dir {} does not exist!".format(img_dir))
85 |     if not os.path.exists(cls_dir):
86 |         raise Exception("class file {} does not exist!".format(cls_dir))
87 | else:
88 | with open(cls_dir,'r') as f:
89 | classes = f.readlines()
90 | for i in range(len(classes)):
91 | cls_dict[i] = classes[i].strip()
92 | print("class map:", cls_dict)
93 | if not os.path.exists(out_dir):
94 | os.mkdir(out_dir)
95 | if not os.path.exists(GT_dir):
96 |         print(f"WARNING: {GT_dir}, GT not available!")
97 |     if not os.path.exists(DT_dir):
98 |         print(f"WARNING: {DT_dir}, DT not available!")
99 | for each_img in tqdm(os.listdir(img_dir)):
100 | gt = None
101 | dt = None
102 |         if os.path.exists(os.path.join(GT_dir, os.path.splitext(each_img)[0] + '.txt')):
103 |             gt = os.path.join(GT_dir, os.path.splitext(each_img)[0] + '.txt')
104 |         if DT_dir and os.path.exists(os.path.join(DT_dir, os.path.splitext(each_img)[0] + '.txt')):
105 |             dt = os.path.join(DT_dir, os.path.splitext(each_img)[0] + '.txt')
106 |
107 | plot_bbox(each_img, img_dir, out_dir, gt, dt, cls2label=cls_dict)
108 |
--------------------------------------------------------------------------------
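`plot_bbox` turns each normalized `cls x y w h [conf]` line into the integer pixel corners it passes to `cv2.rectangle`, and skips predictions whose confidence is below `--conf`. A minimal OpenCV-free sketch of that parsing and filtering, so the arithmetic can be checked without images; the helper names are hypothetical, and like the script it only distinguishes 5-field GT lines from 6-field prediction lines:

```python
def parse_yolo_line(line):
    """'cls x y w h [conf]' -> (cls, x, y, w, h, conf), with conf None for GT lines."""
    parts = line.split()
    cls = int(float(parts[0]))
    x, y, w, h = map(float, parts[1:5])
    conf = float(parts[5]) if len(parts) == 6 else None
    return cls, x, y, w, h, conf


def yolo_to_corners(x, y, w, h, img_w, img_h):
    """Normalized center box -> integer pixel corners (top-left, bottom-right) for cv2.rectangle."""
    c1 = (int((x - w / 2) * img_w), int((y - h / 2) * img_h))
    c2 = (int((x + w / 2) * img_w), int((y + h / 2) * img_h))
    return c1, c2


def boxes_to_draw(lines, img_w, img_h, conf_thres=0.5):
    """Keep GT boxes and predictions with conf >= conf_thres (the script's --conf)."""
    kept = []
    for line in lines:
        if not line.strip():
            continue  # ignore blank lines
        cls, x, y, w, h, conf = parse_yolo_line(line)
        if conf is not None and conf < conf_thres:
            continue
        c1, c2 = yolo_to_corners(x, y, w, h, img_w, img_h)
        kept.append((cls, c1, c2, conf))
    return kept


lines = ["0 0.5 0.5 0.25 0.5", "1 0.25 0.25 0.125 0.125 0.3", "2 0.5 0.5 1.0 1.0 0.9"]
print(boxes_to_draw(lines, img_w=100, img_h=200))
# [(0, (37, 50), (62, 150), None), (2, (0, 0), (100, 200), 0.9)]
```

`int()` truncates toward zero, matching the script, so a box that extends past the left or top edge gets a small negative or zero corner rather than being clipped.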
/detection/yolo2coco.py:
--------------------------------------------------------------------------------
1 | """
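2 | Convert a YOLO-format dataset to COCO format
3 | --root_dir input root path
4 | --save_path name of the output file (used when the dataset is not split)
5 | --random_split randomly split the dataset, then save the splits as 3 separate files.
6 | --split_by_file split the dataset according to ./train.txt ./val.txt ./test.txt.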
7 | """
8 |
9 | import os
10 | import cv2
11 | import json
12 | from tqdm import tqdm
13 | from sklearn.model_selection import train_test_split
14 | import argparse
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--root_dir', default='./data',type=str, help="root path of images and labels, include ./images and ./labels and classes.txt")
18 | parser.add_argument('--save_path', type=str,default='./train.json', help="if not split the dataset, give a path to a json file")
19 | parser.add_argument('--random_split', action='store_true', help="random split the dataset, default ratio is 8:1:1")
20 | parser.add_argument('--split_by_file', action='store_true', help="define how to split the dataset, include ./train.txt ./val.txt ./test.txt ")
21 |
22 | arg = parser.parse_args()
23 |
24 | def train_test_val_split_random(img_paths,ratio_train=0.8,ratio_test=0.1,ratio_val=0.1):
25 |     # the split ratios can be changed here.
26 |     assert abs(ratio_train + ratio_test + ratio_val - 1) < 1e-6
27 |     train_img, middle_img = train_test_split(img_paths, test_size=1 - ratio_train, random_state=233)
28 |     ratio = ratio_test / (1 - ratio_train)
29 |     val_img, test_img = train_test_split(middle_img, test_size=ratio, random_state=233)
30 | print("NUMS of train:val:test = {}:{}:{}".format(len(train_img), len(val_img), len(test_img)))
31 | return train_img, val_img, test_img
32 |
33 | def train_test_val_split_by_files(img_paths, root_dir):
34 |     # define the train/val/test sets from train.txt, val.txt and test.txt (each lists the image file names of its split)
35 | phases = ['train', 'val', 'test']
36 | img_split = []
37 | for p in phases:
38 | define_path = os.path.join(root_dir, f'{p}.txt')
39 | print(f'Read {p} dataset definition from {define_path}')
40 | assert os.path.exists(define_path)
41 | with open(define_path, 'r') as f:
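42 |             img_paths = [line.strip() for line in f if line.strip()]  # strip newlines so names match os.listdir
43 |         # img_paths = [os.path.split(img_path.strip())[1] for img_path in img_paths]  # NOTE: uncomment to accept absolute paths (only the file name is kept).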
44 | img_split.append(img_paths)
45 | return img_split[0], img_split[1], img_split[2]
46 |
47 |
48 | def yolo2coco(arg):
49 | root_path = arg.root_dir
50 | print("Loading data from ",root_path)
51 |
52 | assert os.path.exists(root_path)
53 | originLabelsDir = os.path.join(root_path, 'labels')
54 | originImagesDir = os.path.join(root_path, 'images')
55 | with open(os.path.join(root_path, 'classes.txt')) as f:
56 |         classes = [line.strip() for line in f if line.strip()]  # one class per line; names may contain spaces
57 | # images dir name
58 | indexes = os.listdir(originImagesDir)
59 |
60 | if arg.random_split or arg.split_by_file:
61 |         # image and annotation info of each split
62 | train_dataset = {'categories': [], 'annotations': [], 'images': []}
63 | val_dataset = {'categories': [], 'annotations': [], 'images': []}
64 | test_dataset = {'categories': [], 'annotations': [], 'images': []}
65 |
66 |         # map class names to numeric ids, starting from 0.
67 | for i, cls in enumerate(classes, 0):
68 | train_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
69 | val_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
70 | test_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
71 |
72 | if arg.random_split:
73 |             print("splitting mode: random split")
74 |             train_img, val_img, test_img = train_test_val_split_random(indexes, 0.8, 0.1, 0.1)
75 |         elif arg.split_by_file:
76 |             print("splitting mode: split by files")
77 | train_img, val_img, test_img = train_test_val_split_by_files(indexes, root_path)
78 | else:
79 | dataset = {'categories': [], 'annotations': [], 'images': []}
80 | for i, cls in enumerate(classes, 0):
81 | dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
82 |
83 |     # annotation id
84 | ann_id_cnt = 0
85 | for k, index in enumerate(tqdm(indexes)):
86 |         # the label file shares the image's stem (png and jpg images are supported).
87 |         txtFile = os.path.splitext(index)[0] + '.txt'
88 |         # read the image width and height
89 |         im = cv2.imread(os.path.join(originImagesDir, index))
90 | height, width, _ = im.shape
91 | if arg.random_split or arg.split_by_file:
92 |             # point `dataset` at the split this image belongs to
93 | if index in train_img:
94 | dataset = train_dataset
95 | elif index in val_img:
96 | dataset = val_dataset
97 | elif index in test_img:
98 | dataset = test_dataset
99 |         # add the image entry
100 | dataset['images'].append({'file_name': index,
101 | 'id': k,
102 | 'width': width,
103 | 'height': height})
104 | if not os.path.exists(os.path.join(originLabelsDir, txtFile)):
105 |             # no label file: skip it, keeping only the image entry.
106 | continue
107 | with open(os.path.join(originLabelsDir, txtFile), 'r') as fr:
108 | labelList = fr.readlines()
109 | for label in labelList:
110 | label = label.strip().split()
111 | x = float(label[1])
112 | y = float(label[2])
113 | w = float(label[3])
114 | h = float(label[4])
115 |
116 | # convert x,y,w,h to x1,y1,x2,y2
117 | H, W, _ = im.shape
118 | x1 = (x - w / 2) * W
119 | y1 = (y - h / 2) * H
120 | x2 = (x + w / 2) * W
121 | y2 = (y + h / 2) * H
122 |             # class ids start from 0; the irregular COCO2017 category ids are not handled here.
123 |             cls_id = int(label[0])
124 |             box_w = max(0, x2 - x1)
125 |             box_h = max(0, y2 - y1)
126 |             dataset['annotations'].append({
127 |                 'area': box_w * box_h,
128 |                 'bbox': [x1, y1, box_w, box_h],
129 | 'category_id': cls_id,
130 | 'id': ann_id_cnt,
131 | 'image_id': k,
132 | 'iscrowd': 0,
133 |                 # mask: the rectangle's four corners, clockwise from the top-left
134 | 'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
135 | })
136 | ann_id_cnt += 1
137 |
138 |     # save the results
139 | folder = os.path.join(root_path, 'annotations')
140 | if not os.path.exists(folder):
141 | os.makedirs(folder)
142 | if arg.random_split or arg.split_by_file:
143 | for phase in ['train','val','test']:
144 | json_name = os.path.join(root_path, 'annotations/{}.json'.format(phase))
145 | with open(json_name, 'w') as f:
146 | if phase == 'train':
147 | json.dump(train_dataset, f)
148 | elif phase == 'val':
149 | json.dump(val_dataset, f)
150 | elif phase == 'test':
151 | json.dump(test_dataset, f)
152 | print('Save annotation to {}'.format(json_name))
153 | else:
154 | json_name = os.path.join(root_path, 'annotations/{}'.format(arg.save_path))
155 | with open(json_name, 'w') as f:
156 | json.dump(dataset, f)
157 | print('Save annotation to {}'.format(json_name))
158 |
159 | if __name__ == "__main__":
160 |
161 | yolo2coco(arg)
--------------------------------------------------------------------------------
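`yolo2coco.py` maps each normalized YOLO line to an absolute-pixel COCO annotation (top-left `bbox`, `area`, and the four-corner rectangle as `segmentation`) and, with `--random_split`, splits the images 8:1:1. The sketch below reproduces the conversion for a single line and swaps the two-stage `sklearn.model_selection.train_test_split` for a plain seeded shuffle, so its partition will not match the script's `random_state=233` output. The function names are hypothetical:

```python
import random


def yolo_line_to_coco_ann(line, img_w, img_h, ann_id, image_id):
    """'cls x y w h' (normalized center box) -> COCO annotation dict in absolute pixels."""
    cls, x, y, w, h = line.split()[:5]
    x, y, w, h = float(x), float(y), float(w), float(h)
    x1, y1 = (x - w / 2) * img_w, (y - h / 2) * img_h
    x2, y2 = (x + w / 2) * img_w, (y + h / 2) * img_h
    box_w, box_h = max(0.0, x2 - x1), max(0.0, y2 - y1)
    return {
        "id": ann_id,
        "image_id": image_id,
        "category_id": int(cls),
        "bbox": [x1, y1, box_w, box_h],  # COCO bbox: top-left corner + width/height
        "area": box_w * box_h,
        "iscrowd": 0,
        # rectangle as a polygon, clockwise from the top-left corner
        "segmentation": [[x1, y1, x2, y1, x2, y2, x1, y2]],
    }


def split_train_val_test(names, ratios=(0.8, 0.1, 0.1), seed=233):
    """Seeded shuffle-and-cut split; NOT the sklearn split used by yolo2coco.py."""
    assert abs(sum(ratios) - 1) < 1e-6, "ratios must sum to 1"
    names = sorted(names)  # independent of os.listdir order
    random.Random(seed).shuffle(names)
    n_train = int(len(names) * ratios[0])
    n_val = int(len(names) * ratios[1])
    return names[:n_train], names[n_train:n_train + n_val], names[n_train + n_val:]


ann = yolo_line_to_coco_ann("3 0.5 0.5 0.25 0.5", img_w=200, img_h=100, ann_id=0, image_id=7)
print(ann["bbox"], ann["area"])  # [75.0, 25.0, 50.0, 50.0] 2500.0
```

Sorting before shuffling makes the split reproducible across machines even though `os.listdir` order is arbitrary; any rounding remainder from `int()` ends up in the test split.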
/text-image/convert_diffusers_to_original_stable_diffusion.py:
--------------------------------------------------------------------------------
1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2 | # *Only* converts the UNet, VAE, and Text Encoder.
3 | # Does not convert optimizer state or any other thing.
4 |
5 | import argparse
6 | import os.path as osp
7 |
8 | import torch
9 |
10 |
11 | # =================#
12 | # UNet Conversion #
13 | # =================#
14 |
15 | unet_conversion_map = [
16 | # (stable-diffusion, HF Diffusers)
17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"),
18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"),
19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"),
20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"),
21 | ("input_blocks.0.0.weight", "conv_in.weight"),
22 | ("input_blocks.0.0.bias", "conv_in.bias"),
23 | ("out.0.weight", "conv_norm_out.weight"),
24 | ("out.0.bias", "conv_norm_out.bias"),
25 | ("out.2.weight", "conv_out.weight"),
26 | ("out.2.bias", "conv_out.bias"),
27 | ]
28 |
29 | unet_conversion_map_resnet = [
30 | # (stable-diffusion, HF Diffusers)
31 | ("in_layers.0", "norm1"),
32 | ("in_layers.2", "conv1"),
33 | ("out_layers.0", "norm2"),
34 | ("out_layers.3", "conv2"),
35 | ("emb_layers.1", "time_emb_proj"),
36 | ("skip_connection", "conv_shortcut"),
37 | ]
38 |
39 | unet_conversion_map_layer = []
40 | # hardcoded number of downblocks and resnets/attentions...
41 | # would need smarter logic for other networks.
42 | for i in range(4):
43 | # loop over downblocks/upblocks
44 |
45 | for j in range(2):
46 | # loop over resnets/attentions for downblocks
47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
50 |
51 | if i < 3:
52 | # no attention layers in down_blocks.3
53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
56 |
57 | for j in range(3):
58 | # loop over resnets/attentions for upblocks
59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
62 |
63 | if i > 0:
64 | # no attention layers in up_blocks.0
65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
68 |
69 | if i < 3:
70 | # no downsample in down_blocks.3
71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
74 |
75 | # no upsample in up_blocks.3
76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
79 |
80 | hf_mid_atn_prefix = "mid_block.attentions.0."
81 | sd_mid_atn_prefix = "middle_block.1."
82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
83 |
84 | for j in range(2):
85 | hf_mid_res_prefix = f"mid_block.resnets.{j}."
86 | sd_mid_res_prefix = f"middle_block.{2*j}."
87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
88 |
89 |
90 | def convert_unet_state_dict(unet_state_dict):
91 | # buyer beware: this is a *brittle* function,
92 | # and correct output requires that all of these pieces interact in
93 | # the exact order in which I have arranged them.
94 | mapping = {k: k for k in unet_state_dict.keys()}
95 | for sd_name, hf_name in unet_conversion_map:
96 | mapping[hf_name] = sd_name
97 | for k, v in mapping.items():
98 | if "resnets" in k:
99 | for sd_part, hf_part in unet_conversion_map_resnet:
100 | v = v.replace(hf_part, sd_part)
101 | mapping[k] = v
102 | for k, v in mapping.items():
103 | for sd_part, hf_part in unet_conversion_map_layer:
104 | v = v.replace(hf_part, sd_part)
105 | mapping[k] = v
106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
107 | return new_state_dict
108 |
109 |
110 | # ================#
111 | # VAE Conversion #
112 | # ================#
113 |
114 | vae_conversion_map = [
115 | # (stable-diffusion, HF Diffusers)
116 | ("nin_shortcut", "conv_shortcut"),
117 | ("norm_out", "conv_norm_out"),
118 | ("mid.attn_1.", "mid_block.attentions.0."),
119 | ]
120 |
121 | for i in range(4):
122 | # down_blocks have two resnets
123 | for j in range(2):
124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
125 | sd_down_prefix = f"encoder.down.{i}.block.{j}."
126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
127 |
128 | if i < 3:
129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
130 | sd_downsample_prefix = f"down.{i}.downsample."
131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
132 |
133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
134 | sd_upsample_prefix = f"up.{3-i}.upsample."
135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
136 |
137 | # up_blocks have three resnets
138 | # also, up blocks in hf are numbered in reverse from sd
139 | for j in range(3):
140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
143 |
144 | # this part accounts for mid blocks in both the encoder and the decoder
145 | for i in range(2):
146 | hf_mid_res_prefix = f"mid_block.resnets.{i}."
147 | sd_mid_res_prefix = f"mid.block_{i+1}."
148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
149 |
150 |
151 | vae_conversion_map_attn = [
152 | # (stable-diffusion, HF Diffusers)
153 | ("norm.", "group_norm."),
154 | ("q.", "query."),
155 | ("k.", "key."),
156 | ("v.", "value."),
157 | ("proj_out.", "proj_attn."),
158 | ]
159 |
160 |
161 | def reshape_weight_for_sd(w):
162 | # convert HF linear weights to SD conv2d weights
163 | return w.reshape(*w.shape, 1, 1)
164 |
165 |
166 | def convert_vae_state_dict(vae_state_dict):
167 | mapping = {k: k for k in vae_state_dict.keys()}
168 | for k, v in mapping.items():
169 | for sd_part, hf_part in vae_conversion_map:
170 | v = v.replace(hf_part, sd_part)
171 | mapping[k] = v
172 | for k, v in mapping.items():
173 | if "attentions" in k:
174 | for sd_part, hf_part in vae_conversion_map_attn:
175 | v = v.replace(hf_part, sd_part)
176 | mapping[k] = v
177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
178 | weights_to_convert = ["q", "k", "v", "proj_out"]
179 | for k, v in new_state_dict.items():
180 | for weight_name in weights_to_convert:
181 | if f"mid.attn_1.{weight_name}.weight" in k:
182 | print(f"Reshaping {k} for SD format")
183 | new_state_dict[k] = reshape_weight_for_sd(v)
184 | return new_state_dict
185 |
186 |
187 | # =========================#
188 | # Text Encoder Conversion #
189 | # =========================#
190 | # pretty much a no-op
191 |
192 |
193 | def convert_text_enc_state_dict(text_enc_dict):
194 | return text_enc_dict
195 |
196 |
197 | if __name__ == "__main__":
198 | parser = argparse.ArgumentParser()
199 |
200 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
201 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
203 |
204 | args = parser.parse_args()
205 |
206 | assert args.model_path is not None, "Must provide a model path!"
207 |
208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
209 |
210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
213 |
214 | # Convert the UNet model
215 | unet_state_dict = torch.load(unet_path, map_location="cpu")
216 | unet_state_dict = convert_unet_state_dict(unet_state_dict)
217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
218 |
219 | # Convert the VAE model
220 | vae_state_dict = torch.load(vae_path, map_location="cpu")
221 | vae_state_dict = convert_vae_state_dict(vae_state_dict)
222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
223 |
224 | # Convert the text encoder model
225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu")
226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
227 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
228 |
229 | # Put together new checkpoint
230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
231 | if args.half:
232 | state_dict = {k: v.half() for k, v in state_dict.items()}
233 | state_dict = {"state_dict": state_dict}
234 | torch.save(state_dict, args.checkpoint_path)
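The renaming in `convert_vae_state_dict` is plain substring replacement applied in list order, followed by a second pass that only touches attention keys. Below is a minimal pure-Python sketch of that mechanism. It uses a toy map built only from the up-block, mid-block, and attention entries visible above. The full map defined earlier in the file has more entries, which are not reproduced here. The four up blocks (`range(4)`) are an assumption inferred from the `3-i` numbering. `build_toy_map`, `rename_keys`, and `sd_conv_shape` are hypothetical names.

```python
def build_toy_map():
    """(stable-diffusion, HF Diffusers) prefix pairs, mirroring the loops above."""
    conversion_map = []
    for i in range(4):  # assumption: four up blocks, implied by the 3-i numbering
        for j in range(3):  # up_blocks have three resnets
            conversion_map.append((f"decoder.up.{3-i}.block.{j}.", f"decoder.up_blocks.{i}.resnets.{j}."))
    for i in range(2):  # mid blocks exist in both the encoder and the decoder
        conversion_map.append((f"mid.block_{i+1}.", f"mid_block.resnets.{i}."))
    return conversion_map


ATTN_MAP = [
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def rename_keys(keys, conversion_map, attn_map):
    """Map each HF key to its SD name by sequential substring replacement."""
    renamed = {}
    for key in keys:
        new_key = key
        for sd_part, hf_part in conversion_map:
            new_key = new_key.replace(hf_part, sd_part)
        # attention renames only apply to keys that came from an attention block
        if "attentions" in key:
            for sd_part, hf_part in attn_map:
                new_key = new_key.replace(hf_part, sd_part)
        renamed[key] = new_key
    return renamed


def sd_conv_shape(linear_shape):
    """Shape after reshape_weight_for_sd: (out, in) -> (out, in, 1, 1)."""
    return (*linear_shape, 1, 1)


if __name__ == "__main__":
    keys = [
        "decoder.up_blocks.0.resnets.1.conv1.weight",
        "decoder.mid_block.resnets.1.norm1.bias",
        "encoder.mid_block.attentions.0.query.weight",
    ]
    for hf_key, sd_key in rename_keys(keys, build_toy_map(), ATTN_MAP).items():
        print(f"{hf_key} -> {sd_key}")
```

The toy map leaves the `mid_block.attentions.0.` prefix unchanged. The reshape check above looks for `mid.attn_1.` keys, so the full map must also rename that prefix.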
--------------------------------------------------------------------------------
/text-image/data_filter/data_filter_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Combines watermark, aesthetics, and CLIP models to score image-text pair quality"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pytorch_lightning as pl\n",
17 | "import torch.nn as nn\n",
18 | "import torch.nn.functional as F\n",
19 | "import torch\n",
20 | "import timm\n",
21 | "from torchvision import transforms as T\n",
22 | "import open_clip\n",
23 | "import os\n",
24 | "from transformers import BertModel, BertTokenizer\n",
25 | "from PIL import Image\n",
26 | "\n",
27 | "class AestheticsMLP(pl.LightningModule):\n",
28 | " # The aesthetics predictor is an MLP on top of CLIP image embeddings\n",
29 | " def __init__(self, input_size, xcol='emb', ycol='avg_rating'):\n",
30 | " super().__init__()\n",
31 | " self.input_size = input_size\n",
32 | " self.xcol = xcol\n",
33 | " self.ycol = ycol\n",
34 | " self.layers = nn.Sequential(\n",
35 | " nn.Linear(self.input_size, 1024),\n",
36 | " #nn.ReLU(),\n",
37 | " nn.Dropout(0.2),\n",
38 | " nn.Linear(1024, 128),\n",
39 | " #nn.ReLU(),\n",
40 | " nn.Dropout(0.2),\n",
41 | " nn.Linear(128, 64),\n",
42 | " #nn.ReLU(),\n",
43 | " nn.Dropout(0.1),\n",
44 | "\n",
45 | " nn.Linear(64, 16),\n",
46 | " #nn.ReLU(),\n",
47 | "\n",
48 | " nn.Linear(16, 1)\n",
49 | " )\n",
50 | "\n",
51 | " def forward(self, x):\n",
52 | " return self.layers(x)\n",
53 | "\n",
54 | " def training_step(self, batch, batch_idx):\n",
55 | " x = batch[self.xcol]\n",
56 | " y = batch[self.ycol].reshape(-1, 1)\n",
57 | " x_hat = self.layers(x)\n",
58 | " loss = F.mse_loss(x_hat, y)\n",
59 | " return loss\n",
60 | " \n",
61 | " def validation_step(self, batch, batch_idx):\n",
62 | " x = batch[self.xcol]\n",
63 | " y = batch[self.ycol].reshape(-1, 1)\n",
64 | " x_hat = self.layers(x)\n",
65 | " loss = F.mse_loss(x_hat, y)\n",
66 | " return loss\n",
67 | "\n",
68 | " def configure_optimizers(self):\n",
69 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
70 | " return optimizer\n",
71 | "\n",
72 | "\n",
73 | "class WaterMarkModel(nn.Module):\n",
74 | " def __init__(self, model_path='./watermark_model_v1.pt'):\n",
75 | " super(WaterMarkModel, self).__init__()\n",
76 | " # model definition\n",
77 | " self.model = timm.create_model(\n",
78 | " 'efficientnet_b3a', pretrained=True, num_classes=2)\n",
79 | "\n",
80 | " self.model.classifier = nn.Sequential(\n",
81 | " # 1536 is the original in_features\n",
82 | " nn.Linear(in_features=1536, out_features=625),\n",
83 | " nn.ReLU(), # ReLU as the activation function\n",
84 | " nn.Dropout(p=0.3),\n",
85 | " nn.Linear(in_features=625, out_features=256),\n",
86 | " nn.ReLU(),\n",
87 | " nn.Linear(in_features=256, out_features=2),\n",
88 | " )\n",
89 | " self.model.load_state_dict(torch.load(model_path))\n",
90 | " def forward(self, x):\n",
91 | " return self.model(x)"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "class FilterSystem:\n",
101 | " def __init__(\n",
102 | " self, \n",
103 | " clip_model_path=\"IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese\",\n",
104 | " aesthetics_model_path=\"./ava+logos-l14-linearMSE.pth\",\n",
105 | " watermark_model_path=\"./watermark_model_v1.pt\"\n",
106 | " ):\n",
107 | " self.clip_model_path = clip_model_path\n",
108 | " self.aesthetics_model_path = aesthetics_model_path\n",
109 | " self.watermark_model_path = watermark_model_path\n",
110 | "\n",
111 | " def init_clip_model(self, ):\n",
112 | " # Initialize the CLIP models and store the text encoder, tokenizer, image model, and processor on self\n",
113 | " text_encoder = BertModel.from_pretrained(self.clip_model_path).eval().cuda()\n",
114 | " text_tokenizer = BertTokenizer.from_pretrained(self.clip_model_path)\n",
115 | " clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')\n",
116 | " clip_model = clip_model.eval().cuda()\n",
117 | " self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = text_encoder, text_tokenizer, clip_model, processor\n",
118 | " print(\"clip model loaded\")\n",
119 | " return None\n",
120 | "\n",
121 | " def init_aesthetics_model(self, ):\n",
122 | " # Initialize the aesthetics model\n",
123 | " self.aesthetics_model = AestheticsMLP(768)\n",
124 | " self.aesthetics_model.load_state_dict(torch.load(self.aesthetics_model_path))\n",
125 | " self.aesthetics_model.eval().cuda()\n",
126 | " print(\"aesthetics model loaded\")\n",
127 | " return None\n",
128 | "\n",
129 | " def init_watermark_model(self, ):\n",
130 | " self.watermark_model = WaterMarkModel(self.watermark_model_path)\n",
131 | " self.watermark_model.eval().cuda()\n",
132 | " self.watermark_processor = T.Compose([\n",
133 | " T.Resize((256, 256)),\n",
134 | " T.ToTensor(),\n",
135 | " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
136 | " ])\n",
137 | " print(\"watermark model loaded\")\n",
138 | " return None\n",
139 | "\n",
140 | " def get_image_feature(self, images):\n",
141 | " # Return L2-normalized image feature vectors\n",
142 | " if isinstance(images, list):\n",
143 | " images = torch.stack([self.processor(image) for image in images]).cuda()\n",
144 | " elif isinstance(images, torch.Tensor):\n",
145 | " images = images.cuda()\n",
146 | "\n",
147 | " with torch.no_grad():\n",
148 | " image_features = self.clip_model.encode_image(images)\n",
149 | " image_features /= image_features.norm(dim=1, keepdim=True)\n",
150 | " return image_features\n",
151 | " \n",
152 | " def get_text_feature(self, text):\n",
153 | " # Return L2-normalized text feature vectors\n",
154 | " if isinstance(text, list) or isinstance(text, str):\n",
155 | " text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()\n",
156 | " elif isinstance(text, torch.Tensor):\n",
157 | " text = text.cuda()\n",
158 | "\n",
159 | " with torch.no_grad():\n",
160 | " text_features = self.text_encoder(text)[1]\n",
161 | " text_features /= text_features.norm(dim=1, keepdim=True)\n",
162 | " return text_features\n",
163 | "\n",
164 | " def calculate_clip_score(self, features1, features2):\n",
165 | " # Similarity between two sets of features; inputs can be image+text, text+text, or image+image.\n",
166 | " # Returns a similarity matrix of shape f1.shape[0] x f2.shape[0]\n",
167 | " score_matrix = features1 @ features2.t()\n",
168 | " return score_matrix\n",
169 | "\n",
170 | " def get_aesthetics_score(self, features):\n",
171 | " # Return aesthetics scores. The input is CLIP features: call get_image_feature first, then pass the result here (the model is ViT-L-14)\n",
172 | " with torch.no_grad():\n",
173 | " scores = self.aesthetics_model(features)\n",
174 | " scores = scores[:, 0].detach().cpu().numpy()\n",
175 | " return scores\n",
176 | " \n",
177 | " def get_watermark_score(self, images):\n",
178 | " if isinstance(images, list):\n",
179 | " images = torch.stack([self.watermark_processor(image) for image in images]).cuda()\n",
180 | " elif isinstance(images, torch.Tensor):\n",
181 | " images = images.cuda()\n",
182 | " with torch.no_grad():\n",
183 | " pred = self.watermark_model(images)\n",
184 | " watermark_scores = F.softmax(pred, dim=1)[:,0].detach().cpu().numpy()\n",
185 | "\n",
186 | " return watermark_scores"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "## Small-scale data test"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "demo = FilterSystem()\n",
203 | "demo.init_clip_model()\n",
204 | "demo.init_aesthetics_model()\n",
205 | "demo.init_watermark_model()"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "image_path = './demo_images/watermark_example.png'\n",
215 | "image_path2 = './demo_images/mengna.jpg'\n",
216 | "image_path3 = './demo_images/shuiyin.jpg'\n",
217 | "image_path4 = './demo_images/1.jpg'\n",
218 | "image_demo = [Image.open(image_path).convert('RGB'), Image.open(image_path2).convert('RGB'), Image.open(image_path3).convert('RGB'), Image.open(image_path4).convert('RGB')]\n",
219 | "image_feature = demo.get_image_feature(image_demo,) # compute image features from a list of images; these can be stored in a database to answer text queries\n",
220 | "aes_score = demo.get_aesthetics_score(image_feature) # compute aesthetics scores from the image features; these can also be stored in a database\n",
221 | "print(aes_score)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
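230 | "text_demo = ['一副很美的画','港口小船', '蒙娜丽莎'] # 'a very beautiful painting', 'small boats in a harbor', 'Mona Lisa'; a single text (i.e. the query) also works\n",
231 | "text_feature = demo.get_text_feature(text_demo) # compute text features from a list of texts\n",
232 | "similarity = demo.calculate_clip_score(image_feature, text_feature) # compute the similarity matrix\n",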
233 | "print(similarity)"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "watermark_score = demo.get_watermark_score(image_demo)\n",
243 | "print(watermark_score)"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {},
249 | "source": [
250 | "## Read, process, and save (single process)"
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": null,
256 | "metadata": {},
257 | "outputs": [],
258 | "source": [
259 | "# data setting\n",
260 | "root_path = \"./project/dataset/laion_chinese_cwf/image_part00\"\n",
261 | "all_folders = sorted(os.listdir(root_path))"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "# model setting\n",
271 | "filter_model = FilterSystem()\n",
272 | "filter_model.init_clip_model()\n",
273 | "filter_model.init_aesthetics_model()\n",
274 | "filter_model.init_watermark_model()"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "from model import FilterSystem\n",
284 | "from dataset import TxtDataset\n",
285 | "import os\n",
286 | "from torch.utils.data import DataLoader\n",
287 | "from tqdm import tqdm\n",
288 | "from PIL import Image\n",
289 | "import numpy as np\n",
290 | "import pandas as pd\n",
291 | "\n",
292 | "def sub_process(filter_model, each_folder_path):\n",
293 | " each_dataset = TxtDataset(each_folder_path)\n",
294 | " each_dataloader = DataLoader(each_dataset, batch_size=8, shuffle=False, num_workers=8)\n",
295 | "\n",
296 | " image_paths = []\n",
297 | " aes_scores = []\n",
298 | " clip_scores = []\n",
299 | " watermark_scores = []\n",
300 | " for iii, (batch_image_paths, texts,) in enumerate(tqdm(each_dataloader)):\n",
301 | " images = [Image.open(each_image_path).convert(\"RGB\") for each_image_path in batch_image_paths]\n",
302 | " image_paths.extend(batch_image_paths)\n",
303 | "\n",
304 | " image_features = filter_model.get_image_feature(images,) # compute image features from a list of images; these can be stored in a database to answer text queries\n",
305 | " aes_score = filter_model.get_aesthetics_score(image_features) # compute aesthetics scores from the image features; these can also be stored in a database\n",
306 | " aes_scores.extend(aes_score)\n",
307 | "\n",
308 | " text_features = filter_model.get_text_feature(list(texts)) # compute text features from a list of texts\n",
309 | " clip_score = filter_model.calculate_clip_score(image_features, text_features) # compute the similarity matrix\n",
310 | " clip_scores.extend(torch.diagonal(clip_score).detach().cpu().numpy()) # take the diagonal: only each image's similarity to its own caption is needed\n",
311 | "\n",
312 | " watermark_score = filter_model.get_watermark_score(images) # compute watermark scores from the list of images\n",
313 | " watermark_scores.extend(watermark_score)\n",
314 | " \n",
315 | " # print('aes_score:', aes_score, '\\n',\n",
316 | " # 'clip_score:', clip_score, '\\n',\n",
317 | " # 'watermark_score:', watermark_score, '\\n',\n",
318 | " # 'image_paths:', image_paths, '\\n',\n",
319 | " # 'texts:', texts)\n",
320 | " \n",
321 | " score_pd = pd.DataFrame({'image_path': image_paths, 'aes_score': aes_scores, 'clip_score': clip_scores, 'watermark_score': watermark_scores})\n",
322 | " score_pd.to_csv(os.path.join(each_folder_path, 'score.csv'), index=False)\n",
323 | " print('save score.csv in {}'.format(each_folder_path), '\\n', '-'*20)\n",
324 | "\n",
325 | "for each_folder in all_folders[:10]:\n",
326 | " each_folder_path = os.path.join(root_path, each_folder)\n",
327 | " sub_process(filter_model, each_folder_path)"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {},
334 | "outputs": [],
335 | "source": [
336 | "from model import FilterSystem\n",
337 | "from dataset import TxtDataset\n",
338 | "import os\n",
339 | "from torch.utils.data import DataLoader\n",
340 | "from tqdm import tqdm\n",
341 | "from PIL import Image\n",
342 | "import numpy as np\n",
343 | "import pandas as pd\n",
344 | "from concurrent.futures import ProcessPoolExecutor, wait, ALL_COMPLETED\n",
345 | "\n",
346 | "p = ProcessPoolExecutor(max_workers=4)\n",
347 | "\n",
348 | "def sub_process(filter_model, each_folder_path):\n",
349 | " each_dataset = TxtDataset(each_folder_path)\n",
350 | " each_dataloader = DataLoader(each_dataset, batch_size=8, shuffle=False, num_workers=8)\n",
351 | "\n",
352 | " image_paths = []\n",
353 | " aes_scores = []\n",
354 | " clip_scores = []\n",
355 | " watermark_scores = []\n",
356 | " for iii, (batch_image_paths, texts,) in enumerate(each_dataloader):\n",
357 | " images = [Image.open(each_image_path).convert(\"RGB\") for each_image_path in batch_image_paths]\n",
358 | " image_paths.extend(batch_image_paths)\n",
359 | "\n",
360 | " image_features = filter_model.get_image_feature(images,) # compute image features from a list of images; these can be stored in a database to answer text queries\n",
361 | " aes_score = filter_model.get_aesthetics_score(image_features) # compute aesthetics scores from the image features; these can also be stored in a database\n",
362 | " aes_scores.extend(aes_score)\n",
363 | "\n",
364 | " text_features = filter_model.get_text_feature(list(texts)) # compute text features from a list of texts\n",
365 | " clip_score = filter_model.calculate_clip_score(image_features, text_features) # compute the similarity matrix\n",
366 | " clip_scores.extend(torch.diagonal(clip_score).detach().cpu().numpy()) # take the diagonal: only each image's similarity to its own caption is needed\n",
367 | "\n",
368 | " watermark_score = filter_model.get_watermark_score(images) # compute watermark scores from the list of images\n",
369 | " watermark_scores.extend(watermark_score)\n",
370 | " \n",
371 | " # print('aes_score:', aes_score, '\\n',\n",
372 | " # 'clip_score:', clip_score, '\\n',\n",
373 | " # 'watermark_score:', watermark_score, '\\n',\n",
374 | " # 'image_paths:', image_paths, '\\n',\n",
375 | " # 'texts:', texts)\n",
376 | " \n",
377 | " score_pd = pd.DataFrame({'image_path': image_paths, 'aes_score': aes_scores, 'clip_score': clip_scores, 'watermark_score': watermark_scores})\n",
378 | " score_pd.to_csv(os.path.join(each_folder_path, 'score.csv'), index=False)\n",
379 | " print('save score.csv in {}'.format(each_folder_path), '\\n', '-'*20)\n",
380 | "\n",
381 | "# model_pool is built in the next cell, so run that cell first.\n",
382 | "# Give each worker a different folder so no two processes write the same score.csv.\n",
383 | "folders = all_folders[:10]\n",
384 | "for start in range(0, len(folders), len(model_pool)):\n",
385 | " futures = [p.submit(sub_process, model, os.path.join(root_path, folder))\n",
386 | " for model, folder in zip(model_pool, folders[start:start + len(model_pool)])]\n",
387 | " res = wait(futures, return_when=ALL_COMPLETED)\n",
388 | "p.shutdown()\n"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "# 用model pool来开启4进程跑\n",
398 | "model_pool = [FilterSystem() for i in range(4)]\n",
399 | "for model in model_pool:\n",
400 | " model.init_clip_model()\n",
401 | " model.init_aesthetics_model()\n",
402 | " model.init_watermark_model()"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "metadata": {},
409 | "outputs": [],
410 | "source": [
411 | "print(aes_scores, clip_scores, watermark_scores)"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {},
418 | "outputs": [],
419 | "source": [
420 | "print('image_paths:', image_paths, '\\n', 'texts:', texts)"
421 | ]
422 | },
423 | {
424 | "cell_type": "markdown",
425 | "metadata": {},
426 | "source": [
427 | "# PyTorch Lightning + multi-process"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 1,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "import pytorch_lightning as pl\n",
437 | "\n",
438 | "class ScoreSystem(pl.LightningModule):\n",
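439 | " def __init__(\n",
440 | " self,\n",
441 | " clip_model_path=\"IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese\",\n",
442 | " aesthetics_model_path=\"./ava+logos-l14-linearMSE.pth\",\n",
443 | " watermark_model_path=\"./watermark_model_v1.pt\"\n",
444 | " ):\n",
445 | " super().__init__()\n",
446 | " # the init_* helpers read these paths, so they must be set first\n",
447 | " self.clip_model_path = clip_model_path\n",
448 | " self.aesthetics_model_path = aesthetics_model_path\n",
449 | " self.watermark_model_path = watermark_model_path\n",
450 | " self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = self.init_clip_model()\n",
451 | " self.aesthetics_model = self.init_aesthetics_model()\n",
452 | " self.watermark_model, self.watermark_processor = self.init_watermark_model()\n",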
444 | "\n",
445 | " def init_clip_model(self):\n",
446 | " text_encoder = BertModel.from_pretrained(self.clip_model_path).eval().cuda()\n",
447 | " text_tokenizer = BertTokenizer.from_pretrained(self.clip_model_path)\n",
448 | " clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')\n",
449 | " clip_model = clip_model.eval().cuda()\n",
450 | " print(\"clip model loaded\")\n",
451 | " return text_encoder, text_tokenizer, clip_model, processor\n",
452 | "\n",
453 | " def init_aesthetics_model(self, ):\n",
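454 | " # Initialize the aesthetics model\n",
455 | " aesthetics_model = AestheticsMLP(768)\n",
456 | " aesthetics_model.load_state_dict(torch.load(self.aesthetics_model_path))\n",
457 | " aesthetics_model.eval().cuda() # load_state_dict returns key-matching info, not the module, so it cannot be chained\n",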
457 | " print(\"aesthetics model loaded\")\n",
458 | " return aesthetics_model\n",
459 | "\n",
460 | " def init_watermark_model(self, ):\n",
461 | " watermark_model = WaterMarkModel(self.watermark_model_path)\n",
462 | " watermark_model.eval().cuda()\n",
463 | " watermark_processor = T.Compose([\n",
464 | " T.Resize((256, 256)),\n",
465 | " T.ToTensor(),\n",
466 | " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
467 | " ])\n",
468 | " print(\"watermark model loaded\")\n",
469 | " return watermark_model, watermark_processor\n",
470 | "\n",
471 | " def get_image_feature(self, images):\n",
472 | " # Return L2-normalized image feature vectors\n",
473 | " if isinstance(images, list):\n",
474 | " images = torch.stack([self.processor(image) for image in images]).cuda()\n",
475 | " elif isinstance(images, torch.Tensor):\n",
476 | " images = images.cuda()\n",
477 | "\n",
478 | " with torch.no_grad():\n",
479 | " image_features = self.clip_model.encode_image(images)\n",
480 | " image_features /= image_features.norm(dim=1, keepdim=True)\n",
481 | " return image_features\n",
482 | " \n",
483 | " def get_text_feature(self, text):\n",
484 | " # Return L2-normalized text feature vectors\n",
485 | " if isinstance(text, list) or isinstance(text, str):\n",
486 | " text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()\n",
487 | " elif isinstance(text, torch.Tensor):\n",
488 | " text = text.cuda()\n",
489 | "\n",
490 | " with torch.no_grad():\n",
491 | " text_features = self.text_encoder(text)[1]\n",
492 | " text_features /= text_features.norm(dim=1, keepdim=True)\n",
493 | " return text_features\n",
494 | "\n",
495 | " def calculate_clip_score(self, features1, features2):\n",
496 | " # Similarity between two sets of features; inputs can be image+text, text+text, or image+image.\n",
497 | " # Returns a similarity matrix of shape f1.shape[0] x f2.shape[0]\n",
498 | " score_matrix = features1 @ features2.t()\n",
499 | " return score_matrix\n",
500 | "\n",
501 | " def get_aesthetics_score(self, features):\n",
502 | " # Return aesthetics scores. The input is CLIP features: call get_image_feature first, then pass the result here (the model is ViT-L-14)\n",
503 | " with torch.no_grad():\n",
504 | " scores = self.aesthetics_model(features)\n",
505 | " scores = scores[:, 0].detach().cpu().numpy()\n",
506 | " return scores\n",
507 | " \n",
508 | " def get_watermark_score(self, images):\n",
509 | " if isinstance(images, list):\n",
510 | " images = torch.stack([self.watermark_processor(image) for image in images]).cuda()\n",
511 | " elif isinstance(images, torch.Tensor):\n",
512 | " images = images.cuda()\n",
513 | " with torch.no_grad():\n",
514 | " pred = self.watermark_model(images)\n",
515 | " watermark_scores = F.softmax(pred, dim=1)[:,0].detach().cpu().numpy()\n",
516 | "\n",
517 | " return watermark_scores\n",
518 | "\n",
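519 | " def predict_step(self, batch, batch_idx):\n",
520 | " images, texts = batch \n",
521 | " # TODO: either pass in both preprocessed image variants (CLIP and watermark), or pass raw images and preprocess them in the functions below (currently raw images are passed)\n",
522 | " image_features = self.get_image_feature(images)\n",
523 | " text_features = self.get_text_feature(list(texts)) # the default collate may yield a tuple of strings\n",
524 | " # keep only each image's similarity to its own caption (the diagonal)\n",
525 | " clip_scores = torch.diagonal(self.calculate_clip_score(image_features, text_features)).cpu()\n",
526 | " # the helpers return NumPy arrays; convert them so torch.cat works at epoch end\n",
527 | " aes_scores = torch.as_tensor(self.get_aesthetics_score(image_features))\n",
528 | " watermark_scores = torch.as_tensor(self.get_watermark_score(images))\n",
529 | " return clip_scores, aes_scores, watermark_scores\n",
530 | "\n",
531 | " def on_predict_epoch_end(self, outputs):\n",
532 | " # Concatenate all per-batch predictions. Note: this hook's signature differs between PyTorch Lightning versions (2.x no longer passes the outputs).\n",
533 | " clip_scores = torch.cat([output[0] for output in outputs], dim=0)\n",
534 | " aes_scores = torch.cat([output[1] for output in outputs], dim=0)\n",
535 | " watermark_scores = torch.cat([output[2] for output in outputs], dim=0)\n",
536 | " return clip_scores, aes_scores, watermark_scores"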
535 | ]
536 | }
537 | ],
538 | "metadata": {
539 | "kernelspec": {
540 | "display_name": "Python 3.9.13 ('base')",
541 | "language": "python",
542 | "name": "python3"
543 | },
544 | "language_info": {
545 | "codemirror_mode": {
546 | "name": "ipython",
547 | "version": 3
548 | },
549 | "file_extension": ".py",
550 | "mimetype": "text/x-python",
551 | "name": "python",
552 | "nbconvert_exporter": "python",
553 | "pygments_lexer": "ipython3",
554 | "version": "3.9.13"
555 | },
556 | "orig_nbformat": 4,
557 | "vscode": {
558 | "interpreter": {
559 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa"
560 | }
561 | }
562 | },
563 | "nbformat": 4,
564 | "nbformat_minor": 2
565 | }
566 |
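Both scoring loops build the full `image_features @ text_features.t()` matrix and then keep only its diagonal. For L2-normalized features, entry (i, i) is just the cosine similarity of pair i, so the same numbers can be computed one pair at a time. Below is a minimal pure-Python sketch of that equivalence. The function names are hypothetical and not part of the notebook.

```python
import math


def l2_normalize(rows):
    """Scale each feature vector to unit length, like x /= x.norm(dim=1, keepdim=True)."""
    return [[v / math.sqrt(sum(u * u for u in row)) for v in row] for row in rows]


def similarity_matrix(a, b):
    """Full matrix a @ b.T, the shape calculate_clip_score returns."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def paired_scores(image_feats, text_feats):
    """Diagonal of the similarity matrix, computed with one dot product per pair."""
    img = l2_normalize(image_feats)
    txt = l2_normalize(text_feats)
    return [sum(x * y for x, y in zip(ri, rt)) for ri, rt in zip(img, txt)]


if __name__ == "__main__":
    print(paired_scores([[3.0, 4.0], [1.0, 0.0]], [[3.0, 4.0], [0.0, 2.0]]))
```

In PyTorch, `(image_features * text_features).sum(dim=1)` gives the same values as `torch.diagonal(image_features @ text_features.t())` without building a B x B matrix. At the batch sizes used here (8 and 64), the difference in cost is small either way.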
--------------------------------------------------------------------------------
/text-image/data_filter/wukong_filter.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from torch.utils.data import Dataset, ConcatDataset
3 | from torchvision import transforms
4 | import os
5 | from PIL import Image
6 | from concurrent.futures import ProcessPoolExecutor
7 | import json
8 | import torch
9 | from transformers import BertModel
10 | import open_clip
11 | import numpy as np
12 | from transformers import BertTokenizer
13 | import pandas as pd
14 | from tqdm import tqdm
15 | import argparse
16 |
17 |
18 | parser = argparse.ArgumentParser(description="Compute CLIP image-text similarity scores for the Wukong CSV files.")
19 | parser.add_argument(
20 | "--part",
21 | type=int,
22 | default=0,
23 | required=True,
24 | )
25 | args = parser.parse_args()
26 |
27 |
28 | class CsvDataset(Dataset):
29 | def __init__(self, input_filename, transforms, input_root, tokenizer, img_key, caption_key, sep="\t"):
30 | # logging.debug(f'Loading csv data from {input_filename}.')
31 | print(f'Loading csv data from {input_filename}.')
32 | self.images = []
33 | self.captions = []
34 | if input_filename.endswith('.csv'):
35 | df = pd.read_csv(input_filename, index_col=0)
36 | df = df[df['used'] == 1]
37 | self.images.extend(df[img_key].tolist())
38 | self.captions.extend(df[caption_key].tolist())
39 | # NOTE: Chinese tokenizer
40 | self.tokenizer = tokenizer
41 | self.context_length = 77
42 | self.root = input_root
43 | self.transforms = transforms
44 |
45 | def __len__(self):
46 | return len(self.images)
47 |
48 | def __getitem__(self, idx):
49 | img_path = str(self.images[idx])
50 | image = self.transforms(Image.open( os.path.join(self.root, img_path )))
51 | text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
52 | return image, text, img_path
53 |
54 |
55 | text_encoder = BertModel.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese").eval().cuda()
56 | clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
57 | clip_model = clip_model.eval().cuda()
58 | text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese")
59 |
60 |
61 | input_filename = './project/dataset/wukong/release'
62 | preprocess_fn = processor
63 | input_root = './project/dataset/wukong/images'
64 | tokenizer = text_tokenizer
65 | all_csvs = sorted(os.listdir(input_filename))
66 |
67 | for i in range(len(all_csvs)*args.part//5, len(all_csvs)*(args.part+1)//5):
68 | # split the CSV files into 5 parts; this run handles part args.part
69 | each_csv_path = os.path.join(input_filename, all_csvs[i])
70 | dataset = CsvDataset(each_csv_path, preprocess_fn, input_root, tokenizer, img_key="name", caption_key="caption")
71 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
72 |
73 | df = pd.read_csv(each_csv_path, index_col=0)
74 | df = df[df['used'] == 1]
75 | scores = []
76 | for iii, (image, text, image_path) in enumerate(tqdm(dataloader)):
77 | # print(image.shape, text.shape)
78 | with torch.no_grad():
79 | image = image.cuda()
80 | text = text.cuda()
81 | # print(image.shape, text.shape)
82 | image_features = clip_model.encode_image(image)
83 | text_features = text_encoder(text)[1]
84 |
85 | # print(image_features.shape, text_features.shape)
86 | # L2-normalize
87 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
88 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
89 | score_each_pair = image_features @ text_features.t()
90 |
91 | scores.extend(torch.diagonal(score_each_pair).detach().cpu().numpy())
92 | # break
93 | df['score'] = scores
94 | df.to_csv( each_csv_path.replace(all_csvs[i], 'score'+all_csvs[i]) , index=False)
95 | print('saving score to', each_csv_path.replace(all_csvs[i], 'score'+all_csvs[i]) )
96 |
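`wukong_filter.py` processes CSV indices from `len(all_csvs)*part//5` up to, but not including, `len(all_csvs)*(part+1)//5`. Running it with `--part 0` through `--part 4` should therefore visit every CSV exactly once, even when the count is not divisible by 5. Here is a small pure-Python check of that claim. `part_range` and `covers_exactly_once` are hypothetical helpers that mirror the loop bounds.

```python
def part_range(n_items, part, n_parts=5):
    """Index range handled by one --part run, mirroring the loop in wukong_filter.py."""
    return range(n_items * part // n_parts, n_items * (part + 1) // n_parts)


def covers_exactly_once(n_items, n_parts=5):
    """True if the parts are disjoint, in order, and together cover 0..n_items-1."""
    seen = [i for part in range(n_parts) for i in part_range(n_items, part, n_parts)]
    return seen == list(range(n_items))


if __name__ == "__main__":
    # e.g. 12 CSVs are split into parts of sizes 2, 2, 3, 2, 3
    print([len(part_range(12, p)) for p in range(5)])
```

Integer floor division makes the boundaries non-decreasing, starting at 0 and ending at `n_items`. As a result, the parts never overlap or leave gaps.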
--------------------------------------------------------------------------------
/text-image/data_filter/wukong_reader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, ConcatDataset
2 | from torchvision import transforms
3 | import os
4 | from PIL import Image
5 | from concurrent.futures import ProcessPoolExecutor
6 | import json
7 | import torch
8 | from transformers import BertModel
9 | import open_clip
10 | import numpy as np
11 | from transformers import BertTokenizer
12 | import pandas as pd
13 | from tqdm import tqdm
14 | import argparse
15 | import torch
16 | # NOTE: to speed up loading, keep the original dataset class and read the CSVs in parallel from outside (30 min -> 3 min)
17 | class CsvDataset(Dataset):
18 | def __init__(self, input_filename, input_root, img_key, caption_key, transforms=None, thres=0.2, sep="\t"):
19 | # logging.debug(f'Loading csv data from {input_filename}.')
20 | print(f'Loading csv data from {input_filename}.')
21 | self.images = []
22 | self.captions = []
23 |
24 | if input_filename.endswith('.csv'):
25 | # print(f"Load Data from{input_filename}")
26 | df = pd.read_csv(input_filename, index_col=0)
27 | df = df[df['used'] == 1]
28 | df = df[df['score']>thres]
29 | self.images.extend(df[img_key].tolist())
30 | self.captions.extend(df[caption_key].tolist())
31 |
32 | # NOTE: Chinese tokenizer
33 | self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
34 |
35 | self.context_length = 77
36 | self.root = input_root
37 | self.transforms = transforms
38 |
39 | def __len__(self):
40 | return len(self.images)
41 |
42 | def __getitem__(self, idx):
43 | img_path = str(self.images[idx])
44 | image = self.transforms(Image.open( os.path.join(self.root, img_path )))
45 | text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
46 | return image, text
47 |
48 |
49 | def process_pool_read_csv_dataset(input_root, input_filename, thres=0.20):
50 | # input_filename is a directory of CSV files; only those whose name contains 'score' (written by wukong_filter.py) are loaded
51 | all_csvs = os.listdir(input_filename)
52 |
53 | csv_with_score = [each for each in all_csvs if 'score' in each ]
54 | all_datasets = []
55 | res = []
56 | p = ProcessPoolExecutor(max_workers=24)
57 | for i in range(len(csv_with_score)):
58 | each_csv_path = os.path.join(input_filename, csv_with_score[i])
59 | print(i, each_csv_path)
60 | res.append(p.submit(CsvDataset, each_csv_path, input_root, img_key="name", caption_key="caption", thres=thres))
61 | p.shutdown()
62 | for future in res:
63 | all_datasets.append(future.result())
64 | dataset = ConcatDataset(all_datasets)
65 | return dataset
66 |
67 |
68 | tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese", model_max_length=512)
69 | input_filename = './project/dataset/wukong/release' # directory holding the CSV annotation files
70 | input_root = './project/dataset/wukong/images'
71 | dataset = process_pool_read_csv_dataset(input_root, input_filename, thres=0.22)
72 |
73 | print(len(dataset))
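`CsvDataset` in `wukong_reader.py` keeps a row only if `used == 1` and its CLIP `score` is strictly greater than `thres` (0.22 in the script). Below is a stdlib-only sketch of that filter applied to CSV text. It assumes columns named `name`, `caption`, `used`, and `score`, as in the code above, and does not use pandas. `filter_rows` is a hypothetical name.

```python
import csv
import io


def filter_rows(csv_text, thres=0.22):
    """Return (names, captions) for rows with used == 1 and score > thres."""
    names, captions = [], []
    for row in csv.DictReader(io.StringIO(csv_text)):
        # pandas parses these columns as numbers; here they are converted by hand
        if int(row["used"]) == 1 and float(row["score"]) > thres:
            names.append(row["name"])
            captions.append(row["caption"])
    return names, captions


if __name__ == "__main__":
    sample = "name,caption,used,score\na.jpg,cat,1,0.30\nb.jpg,dog,0,0.50\n"
    print(filter_rows(sample))
```

Because the comparison is strict (`>`), a row whose score equals `thres` exactly is dropped.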
--------------------------------------------------------------------------------
/text-image/fid_clip_score/.gitignore:
--------------------------------------------------------------------------------
1 | /output*
2 |
--------------------------------------------------------------------------------
/text-image/fid_clip_score/coco_sample_generator.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import pandas as pd
3 | import os
4 | from diffusers import StableDiffusionPipeline
5 | from argparse import ArgumentParser
6 | from tqdm import tqdm
7 | from multiprocessing import Process
8 |
9 | parser = ArgumentParser()
10 | parser.add_argument('--coco_path', type=str, default='../dataset/coco')
11 | parser.add_argument('--coco_cache_file', type=str, default='../dataset/coco/subset.parquet')
12 | parser.add_argument('--output_path', type=str, default='./output')
13 | parser.add_argument('--model_path', type=str, default='../pretrained_models/stable-diffusion-v1-4')
14 | parser.add_argument('--sample_step', type=int, default=20)
15 | parser.add_argument('--guidance_scale', type=float, default=1.5)
16 | parser.add_argument('--batch_size', type=int, default=2)
17 | args = parser.parse_args()
18 |
19 |
20 | class COCOCaptionSubset(Dataset):
21 | def __init__(self, path, transform=None):
22 | self.df = pd.read_parquet(path)
23 |
24 | def __len__(self):
25 | return len(self.df)
26 |
27 | def __getitem__(self, idx):
28 | row = self.df.iloc[idx]
29 | return row['file_name'], row['caption']
30 |
31 | def save_images(images, image_paths, output_path):
32 | for i, image_path in enumerate(image_paths):
33 | image_path = image_path.replace('/', '_')
34 | image_path = os.path.join(output_path, image_path)
35 | images[i].save(image_path)
36 |
37 | if __name__ == '__main__':
38 | # testing
39 | coco_path = args.coco_path
40 | # coco_cache_file = f'{coco_path}/subset.parquet' # sampled subsets
41 | cocosubset = COCOCaptionSubset(args.coco_cache_file)
42 | cocosubsetloader = DataLoader(cocosubset, batch_size=args.batch_size, shuffle=False, num_workers=8)
43 |
44 | # load the t2i model
45 | stable_diffusion = StableDiffusionPipeline.from_pretrained(args.model_path, safety_checker=None, requires_safety_checker=False).to('cuda') # disable the NSFW checker so no outputs are blacked out before scoring
46 |
47 | sample_step = args.sample_step
48 | guidance_scale = args.guidance_scale
49 |
50 |
51 | output_path = os.path.join(
52 | args.output_path,
53 | f'./gs{guidance_scale}_ss{sample_step}'
54 | )
55 | os.makedirs(output_path, exist_ok=True)
56 |
57 | for i, (image_paths, captions) in enumerate(tqdm(cocosubsetloader)):
58 | outputs = stable_diffusion(list(captions), num_inference_steps=sample_step, guidance_scale=guidance_scale).images
59 | p = Process(target=save_images, args=(outputs, image_paths, output_path))
60 | p.start()
--------------------------------------------------------------------------------
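`coco_sample_generator.py` starts a new `Process` for every batch so that saving overlaps with the next diffusion call, but the number of live processes grows with the number of batches. A minimal alternative sketch uses a thread pool with a fixed number of workers and the same `/` to `_` file-name flattening. It assumes the generated images are PIL-like objects with a `.save(path)` method; `flatten_name`, `save_batch_async`, and `max_workers=4` are illustrative choices, not part of the script. Threads also avoid forking the process that holds the pipeline; how much PNG encoding actually runs in parallel depends on how much of it releases the GIL.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def flatten_name(image_path):
    """Mirror the script's naming: nested paths become a single file name."""
    return image_path.replace("/", "_")


def save_batch_async(pool, images, image_paths, output_path):
    """Submit one save job per image and return the futures so callers can wait on them."""
    return [
        pool.submit(img.save, os.path.join(output_path, flatten_name(p)))
        for img, p in zip(images, image_paths)
    ]


# Sketch of the generation loop (pipeline arguments elided):
# with ThreadPoolExecutor(max_workers=4) as pool:
#     pending = []
#     for image_paths, captions in cocosubsetloader:
#         outputs = stable_diffusion(list(captions), ...).images
#         pending += save_batch_async(pool, outputs, image_paths, output_path)
#     for fut in pending:
#         fut.result()  # re-raise any error that happened while saving
```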
/text-image/fid_clip_score/compute_fid.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# FID metric computation"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "FID (same): -0.001\n",
20 | "FID (different): 486.117\n"
21 | ]
22 | }
23 | ],
24 | "source": [
25 | "import numpy as np\n",
26 | "from scipy.linalg import sqrtm\n",
27 | "def calculate_fid(act1, act2):\n",
28 | " # calculate mean and covariance statistics\n",
29 | " mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)\n",
30 | " mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)\n",
31 | " # print(mu1.shape, mu2.shape, sigma1.shape, sigma2.shape)\n",
32 | " # calculate sum squared difference between means\n",
33 | " ssdiff = np.sum((mu1 - mu2)**2.0)\n",
34 | " # print(ssdiff)\n",
35 | " # calculate sqrt of product between cov\n",
36 | " covmean = sqrtm(sigma1.dot(sigma2)) # sqrtm also copes with negative eigenvalues (the result may be complex)\n",
37 | " # print(covmean)\n",
38 | " # check and correct imaginary numbers from sqrt\n",
39 | " if np.iscomplexobj(covmean):\n",
40 | " covmean = covmean.real\n",
41 | " # calculate score\n",
42 | " fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)\n",
43 | " return fid\n",
44 | "\n",
45 | "# define two collections of activations\n",
46 | "act1 = np.random.rand(2, 2048)\n",
47 | "act2 = np.random.rand(3, 2048)\n",
48 | "# fid between act1 and act1\n",
49 | "fid = calculate_fid(act1, act1)\n",
50 | "print('FID (same): %.3f' % fid) # should be ~0.0; with only 2 samples the covariance is rank-deficient, so small round-off is expected\n",
51 | "# fid between act1 and act2\n",
52 | "fid = calculate_fid(act1, act2)\n",
53 | "print('FID (different): %.3f' % fid) # should be > 0.0"
54 | ]
55 | }
56 | ],
57 | "metadata": {
58 | "kernelspec": {
59 | "display_name": "Python 3.9.13 ('base')",
60 | "language": "python",
61 | "name": "python3"
62 | },
63 | "language_info": {
64 | "codemirror_mode": {
65 | "name": "ipython",
66 | "version": 3
67 | },
68 | "file_extension": ".py",
69 | "mimetype": "text/x-python",
70 | "name": "python",
71 | "nbconvert_exporter": "python",
72 | "pygments_lexer": "ipython3",
73 | "version": "3.9.13"
74 | },
75 | "orig_nbformat": 4,
76 | "vscode": {
77 | "interpreter": {
78 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa"
79 | }
80 | }
81 | },
82 | "nbformat": 4,
83 | "nbformat_minor": 2
84 | }
85 |
--------------------------------------------------------------------------------
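The notebook's `calculate_fid` implements FID = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 (Sigma1 Sigma2)^(1/2)). With 2 or 3 samples in 2048 dimensions the covariances are rank-deficient, which is why `FID (same)` prints -0.001 rather than 0. The sketch below works from precomputed (mu, Sigma) statistics and, if `sqrtm` returns non-finite values, retries with a small diagonal offset, similar in spirit to the regularisation in `pytorch_fid`. The function names and `eps=1e-6` are assumptions.

```python
import numpy as np
from scipy.linalg import sqrtm


def activation_stats(act):
    """Mean and covariance of an (n_samples, n_features) activation matrix."""
    return act.mean(axis=0), np.cov(act, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """FID between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Rank-deficient covariances: regularise with a small diagonal offset.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # sqrtm can return tiny imaginary parts from round-off; keep the real part.
    covmean = np.real(covmean)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```

Shifting every feature by a constant leaves the covariance unchanged, so the distance reduces to the squared mean shift; in practice FID is estimated from thousands of images, well above the 2048 feature dimensions.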
/text-image/fid_clip_score/fid_clip_coco.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "reference: https://wandb.ai/dalle-mini/dalle-mini/reports/CLIP-score-vs-FID-pareto-curves--VmlldzoyMDYyNTAy"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "# Sampling data"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "# load data\n",
24 | "import json\n",
25 | "coco_path = '/home/tiger/project/dataset/coco'\n",
26 | "data_file = f'{coco_path}/annotations/captions_val2014.json'\n",
27 | "data = json.load(open(data_file))\n",
28 | "\n",
29 | "\n",
30 | "# merge images and annotations\n",
31 | "import pandas as pd\n",
32 | "images = data['images']\n",
33 | "annotations = data['annotations']\n",
34 | "df = pd.DataFrame(images)\n",
35 | "df_annotations = pd.DataFrame(annotations)\n",
36 | "df = df.merge(df_annotations, how='left', left_on='id', right_on='image_id')\n",
37 | "\n",
38 | "\n",
39 | "# keep only the relevant columns\n",
40 | "df = df[['file_name', 'caption']]\n",
41 | "\n",
42 | "\n",
43 | "# shuffle the dataset\n",
44 | "df = df.sample(frac=1)\n",
45 | "\n",
46 | "\n",
47 | "# remove duplicate images\n",
48 | "df = df.drop_duplicates(subset='file_name')\n",
49 | "\n",
50 | "\n",
51 | "# create a random subset\n",
52 | "n_samples = 10000\n",
53 | "df_sample = df.sample(n_samples)\n",
54 | "\n",
55 | "\n",
56 | "# save the sample to a parquet file\n",
57 | "df_sample.to_parquet(f'{coco_path}/subset.parquet')\n",
58 | "\n",
59 | "\n",
60 | "# copy the images to reference folder\n",
61 | "from pathlib import Path\n",
62 | "import shutil\n",
63 | "subset_path = Path(f'{coco_path}/subset')\n",
64 | "subset_path.mkdir(exist_ok=True)\n",
65 | "for i, row in df_sample.iterrows():\n",
66 | " path = f'{coco_path}/val2014/' + row['file_name']\n",
67 | " shutil.copy(path, f'{coco_path}/subset/')\n"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "# center crop the images\n",
77 | "def center_crop_images(folder, output_folder, size):\n",
78 | " # coco images are not square, so we need to center crop them\n",
79 | " from PIL import Image\n",
80 | " import os\n",
81 | " os.makedirs(output_folder, exist_ok=True)\n",
82 | " for file in os.listdir(folder):\n",
83 | " image_path = os.path.join(folder, file)\n",
84 | " image = Image.open(image_path)\n",
85 | " width, height = image.size\n",
86 | " left = (width - size) / 2 if width > size else 0\n",
87 | " top = (height - size) / 2 if height > size else 0\n",
88 | " right = (width + size) / 2 if width > size else width\n",
89 | " bottom = (height + size) / 2 if height > size else height\n",
90 | " image = image.crop((left, top, right, bottom))\n",
91 | " image = image.resize((size, size)) # resize non-square images\n",
92 | " image.save(os.path.join(output_folder, file))\n",
93 | "\n",
94 | "folder_name = '/home/tiger/project/dataset/coco/subset'\n",
95 | "output_folder = '/home/tiger/project/dataset/coco/subset_cropped'\n",
96 | "center_crop_images(folder_name, output_folder, 320)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "# Load subset as dataloader"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# load the subset\n",
113 | "from torch.utils.data import Dataset, DataLoader\n",
114 | "import pandas as pd \n",
115 | "\n",
116 | "class COCOCaptionSubset(Dataset):\n",
117 | " def __init__(self, path, transform=None):\n",
118 | " self.df = pd.read_parquet(path)\n",
119 | "\n",
120 | " def __len__(self):\n",
121 | " return len(self.df)\n",
122 | "\n",
123 | " def __getitem__(self, idx):\n",
124 | " row = self.df.iloc[idx]\n",
125 | " return row['file_name'], row['caption']\n",
126 | "\n",
127 | "# testing \n",
128 | "coco_path = '/home/tiger/project/dataset/coco'\n",
129 | "coco_cache_file = f'{coco_path}/subset.parquet' # sampled subsets\n",
130 | "cocosubset = COCOCaptionSubset(coco_cache_file)\n",
131 | "cocosubsetloader = DataLoader(cocosubset, batch_size=64, shuffle=False, num_workers=8)\n"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": []
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "# Generating Images Via T2I Model"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "# demo inference, use coco_sample_generator.py to generate more\n",
155 | "# load the t2i model\n",
156 | "# from diffusers import StableDiffusionPipeline\n",
157 | "# stable_diffusion = StableDiffusionPipeline.from_pretrained(\"/home/tiger/project/pretrained_models/stable-diffusion-v1-4\").to('cuda') \n",
158 | "\n",
159 | "# sample_step = 20\n",
160 | "# guidance_scale = 1.5\n",
161 | "\n",
162 | "# import os\n",
163 | "\n",
164 | "# output_path = f'./output_gs{guidance_scale}_ss{sample_step}'\n",
165 | "# os.makedirs(output_path, exist_ok=True)\n",
166 | "\n",
167 | "# for i, (image_paths, captions) in enumerate(cocosubsetloader):\n",
168 | "# outputs = stable_diffusion(list(captions), num_inference_steps=sample_step, guidance_scale=guidance_scale).images\n",
169 | "# for j, image_path in enumerate(image_paths):\n",
170 | "# image_path = image_path.replace('/', '_')\n",
171 | "# image_path = os.path.join(output_path, image_path)\n",
172 | "# outputs[j].save(image_path)\n",
173 | "# break"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": []
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 1,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "import os\nimport torch\n",
188 | "device = torch.device('cuda')\n",
189 | "\n",
190 | "coco_subset_crop_path = '/home/tiger/project/dataset/coco/subset_cropped'\n",
191 | "output_root = '/home/tiger/project/position-guided-t2i/output'\n",
192 | "output_paths = [os.path.join(output_root, out) for out in sorted(os.listdir(output_root))]\n"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 2,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "name": "stderr",
202 | "output_type": "stream",
203 | "text": [
204 | "/home/tiger/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
205 | " warnings.warn(\n",
206 | "/home/tiger/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.\n",
207 | " warnings.warn(msg)\n",
208 | "100%|██████████| 50/50 [00:13<00:00, 3.58it/s]\n",
209 | "100%|██████████| 50/50 [00:16<00:00, 3.06it/s]\n"
210 | ]
211 | },
212 | {
213 | "name": "stdout",
214 | "output_type": "stream",
215 | "text": [
216 | "/home/tiger/project/position-guided-t2i/output/gs1.5_ss20 22.765903388613765\n"
217 | ]
218 | },
219 | {
220 | "name": "stderr",
221 | "output_type": "stream",
222 | "text": [
223 | "100%|██████████| 50/50 [00:09<00:00, 5.50it/s]\n",
224 | "100%|██████████| 50/50 [00:16<00:00, 3.03it/s]\n"
225 | ]
226 | },
227 | {
228 | "name": "stdout",
229 | "output_type": "stream",
230 | "text": [
231 | "/home/tiger/project/position-guided-t2i/output/gs2.0_ss20 18.159921113816665\n"
232 | ]
233 | },
234 | {
235 | "name": "stderr",
236 | "output_type": "stream",
237 | "text": [
238 | "100%|██████████| 50/50 [00:10<00:00, 4.95it/s]\n",
239 | "100%|██████████| 50/50 [00:15<00:00, 3.14it/s]\n"
240 | ]
241 | },
242 | {
243 | "name": "stdout",
244 | "output_type": "stream",
245 | "text": [
246 | "/home/tiger/project/position-guided-t2i/output/gs3.0_ss20 15.94397287378655\n"
247 | ]
248 | },
249 | {
250 | "name": "stderr",
251 | "output_type": "stream",
252 | "text": [
253 | "100%|██████████| 50/50 [00:09<00:00, 5.18it/s]\n",
254 | "100%|██████████| 50/50 [00:15<00:00, 3.14it/s]\n"
255 | ]
256 | },
257 | {
258 | "name": "stdout",
259 | "output_type": "stream",
260 | "text": [
261 | "/home/tiger/project/position-guided-t2i/output/gs4.0_ss20 16.315106185605657\n"
262 | ]
263 | },
264 | {
265 | "name": "stderr",
266 | "output_type": "stream",
267 | "text": [
268 | "100%|██████████| 50/50 [00:09<00:00, 5.17it/s]\n",
269 | "100%|██████████| 50/50 [00:16<00:00, 3.00it/s]\n"
270 | ]
271 | },
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "/home/tiger/project/position-guided-t2i/output/gs5.0_ss20 17.35088805364785\n"
277 | ]
278 | },
279 | {
280 | "name": "stderr",
281 | "output_type": "stream",
282 | "text": [
283 | "100%|██████████| 50/50 [00:09<00:00, 5.15it/s]\n",
284 | "100%|██████████| 50/50 [00:16<00:00, 3.01it/s]\n"
285 | ]
286 | },
287 | {
288 | "name": "stdout",
289 | "output_type": "stream",
290 | "text": [
291 | "/home/tiger/project/position-guided-t2i/output/gs6.0_ss20 17.933771904354728\n"
292 | ]
293 | },
294 | {
295 | "name": "stderr",
296 | "output_type": "stream",
297 | "text": [
298 | "100%|██████████| 50/50 [00:09<00:00, 5.21it/s]\n",
299 | "100%|██████████| 50/50 [00:16<00:00, 3.06it/s]\n"
300 | ]
301 | },
302 | {
303 | "name": "stdout",
304 | "output_type": "stream",
305 | "text": [
306 | "/home/tiger/project/position-guided-t2i/output/gs7.0_ss20 19.059673548019532\n"
307 | ]
308 | },
309 | {
310 | "name": "stderr",
311 | "output_type": "stream",
312 | "text": [
313 | "100%|██████████| 50/50 [00:09<00:00, 5.18it/s]\n",
314 | "100%|██████████| 50/50 [00:17<00:00, 2.90it/s]\n"
315 | ]
316 | },
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "/home/tiger/project/position-guided-t2i/output/gs8.0_ss20 20.12984543749127\n"
322 | ]
323 | }
324 | ],
325 | "source": [
326 | "# fid score\n",
327 | "\n",
328 | "# !pip install pytorch_fid\n",
329 | "# !python -m pytorch_fid /home/tiger/project/dataset/coco/subset_cropped /home/tiger/project/position-guided-t2i/output/gs2.0_ss20\n",
330 | "\n",
331 | "from pytorch_fid.fid_score import calculate_fid_given_paths\n",
332 | "\n",
333 | "fids = []\n",
334 | "for output_path in output_paths:\n",
335 | " fid_value = calculate_fid_given_paths([coco_subset_crop_path, output_path], batch_size=200, device=device, dims=2048, num_workers=8)\n",
336 | " fids.append(fid_value)\n",
337 | " print(output_path, fid_value)"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 3,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer\n",
347 | "from PIL import Image\n",
348 | "import numpy as np\n",
349 | "from tqdm import tqdm\n",
350 | "\n",
351 | "\n",
352 | "def load_clip_model(model_path='openai/clip-vit-large-patch14'):\n",
353 | " # text_encoder = BertModel.from_pretrained(model_path).eval().cuda()\n",
354 | " # text_tokenizer = BertTokenizer.from_pretrained(model_path)\n",
355 | " clip_model = CLIPModel.from_pretrained(model_path)\n",
356 | " processor = CLIPProcessor.from_pretrained(model_path)\n",
357 | " tokenizer = CLIPTokenizer.from_pretrained(model_path)\n",
358 | "\n",
359 | " clip_model = clip_model.eval().cuda()\n",
360 | " return clip_model, processor, tokenizer\n",
361 | "\n",
362 | "\n",
363 | "def clip_score(clip_model, processor, tokenizer, dataloader, output_image_path):\n",
364 | " all_image_features = []\n",
365 | " all_text_features = []\n",
366 | " for (i, (image_paths, captions)) in enumerate(tqdm(dataloader)):\n",
367 | " # print(image_paths, captions)\n",
368 | " text_inputs = tokenizer(list(captions), padding=True, return_tensors=\"pt\").to('cuda')\n",
369 | " text_features = clip_model.get_text_features(**text_inputs)\n",
370 | " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n",
371 | " text_features = text_features.detach().cpu().numpy()\n",
372 | " all_text_features.append(text_features)\n",
373 | "\n",
374 | " # the ViT image encoder is the slow part\n",
375 | " images = [Image.open(os.path.join( output_image_path , image_path)) for image_path in image_paths]\n",
376 | " image_inputs = processor(images = images, return_tensors=\"pt\").to('cuda')\n",
377 | " image_features = clip_model.get_image_features(**image_inputs)\n",
378 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
379 | " image_features = image_features.detach().cpu().numpy()\n",
380 | " all_image_features.append(image_features)\n",
381 | "\n",
382 | " # NOTE: testing only; a full pass takes too long, so only the first 11 batches are scored. Remove this if to score the whole subset.\n",
383 | " if i == 10:\n",
384 | " break\n",
385 | "\n",
386 | " all_text_features = np.concatenate(all_text_features, axis=0)\n",
387 | " all_image_features = np.concatenate(all_image_features, axis=0)\n",
388 | " mean_similarity = (all_image_features @ all_text_features.T).diagonal().mean()\n",
389 | " return mean_similarity\n"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 6,
395 | "metadata": {},
396 | "outputs": [
397 | {
398 | "name": "stderr",
399 | "output_type": "stream",
400 | "text": [
401 | " 6%|▋ | 10/157 [00:15<03:45, 1.54s/it]\n"
402 | ]
403 | },
404 | {
405 | "name": "stdout",
406 | "output_type": "stream",
407 | "text": [
408 | "/home/tiger/project/position-guided-t2i/output/gs1.5_ss20 0.23490335\n"
409 | ]
410 | },
411 | {
412 | "name": "stderr",
413 | "output_type": "stream",
414 | "text": [
415 | " 6%|▋ | 10/157 [00:14<03:39, 1.50s/it]\n"
416 | ]
417 | },
418 | {
419 | "name": "stdout",
420 | "output_type": "stream",
421 | "text": [
422 | "/home/tiger/project/position-guided-t2i/output/gs2.0_ss20 0.24406949\n"
423 | ]
424 | },
425 | {
426 | "name": "stderr",
427 | "output_type": "stream",
428 | "text": [
429 | " 6%|▋ | 10/157 [00:15<03:41, 1.51s/it]\n"
430 | ]
431 | },
432 | {
433 | "name": "stdout",
434 | "output_type": "stream",
435 | "text": [
436 | "/home/tiger/project/position-guided-t2i/output/gs3.0_ss20 0.25112092\n"
437 | ]
438 | },
439 | {
440 | "name": "stderr",
441 | "output_type": "stream",
442 | "text": [
443 | " 6%|▋ | 10/157 [00:14<03:39, 1.49s/it]\n"
444 | ]
445 | },
446 | {
447 | "name": "stdout",
448 | "output_type": "stream",
449 | "text": [
450 | "/home/tiger/project/position-guided-t2i/output/gs4.0_ss20 0.25709876\n"
451 | ]
452 | },
453 | {
454 | "name": "stderr",
455 | "output_type": "stream",
456 | "text": [
457 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]\n"
458 | ]
459 | },
460 | {
461 | "name": "stdout",
462 | "output_type": "stream",
463 | "text": [
464 | "/home/tiger/project/position-guided-t2i/output/gs5.0_ss20 0.25781947\n"
465 | ]
466 | },
467 | {
468 | "name": "stderr",
469 | "output_type": "stream",
470 | "text": [
471 | " 6%|▋ | 10/157 [00:14<03:40, 1.50s/it]\n"
472 | ]
473 | },
474 | {
475 | "name": "stdout",
476 | "output_type": "stream",
477 | "text": [
478 | "/home/tiger/project/position-guided-t2i/output/gs6.0_ss20 0.2593051\n"
479 | ]
480 | },
481 | {
482 | "name": "stderr",
483 | "output_type": "stream",
484 | "text": [
485 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]\n"
486 | ]
487 | },
488 | {
489 | "name": "stdout",
490 | "output_type": "stream",
491 | "text": [
492 | "/home/tiger/project/position-guided-t2i/output/gs7.0_ss20 0.26007786\n"
493 | ]
494 | },
495 | {
496 | "name": "stderr",
497 | "output_type": "stream",
498 | "text": [
499 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]"
500 | ]
501 | },
502 | {
503 | "name": "stdout",
504 | "output_type": "stream",
505 | "text": [
506 | "/home/tiger/project/position-guided-t2i/output/gs8.0_ss20 0.2596085\n"
507 | ]
508 | },
509 | {
510 | "name": "stderr",
511 | "output_type": "stream",
512 | "text": [
513 | "\n"
514 | ]
515 | }
516 | ],
517 | "source": [
518 | "clip_model_path=\"/home/tiger/project/pretrained_models/clip-vit-large-patch14\"\n",
519 | "clip_model, processor, tokenizer = load_clip_model(clip_model_path)\n",
520 | "clip_scores = []\n",
521 | "for output_path in output_paths:\n",
522 | " clip_score_each = clip_score(clip_model, processor, tokenizer, cocosubsetloader, output_path) # a full pass takes ~3-4 min per folder\n",
523 | " print(output_path, clip_score_each)\n",
524 | " clip_scores.append(clip_score_each)"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": 7,
530 | "metadata": {},
531 | "outputs": [
532 | {
533 | "data": {
534 | "image/png": "",
535 | "text/plain": [
536 | ""
537 | ]
538 | },
539 | "metadata": {},
540 | "output_type": "display_data"
541 | }
542 | ],
543 | "source": [
544 | "# plot clip score as x-axis, fid score as y-axis, line chart\n",
545 | "import matplotlib.pyplot as plt\n",
546 | "plt.plot(clip_scores, fids, 'o-')\n",
547 | "plt.xlabel('clip score')\n",
548 | "plt.ylabel('fid score')\n",
549 | "plt.show()"
550 | ]
551 | }
552 | ],
553 | "metadata": {
554 | "kernelspec": {
555 | "display_name": "Python 3.9.13 ('base')",
556 | "language": "python",
557 | "name": "python3"
558 | },
559 | "language_info": {
560 | "codemirror_mode": {
561 | "name": "ipython",
562 | "version": 3
563 | },
564 | "file_extension": ".py",
565 | "mimetype": "text/x-python",
566 | "name": "python",
567 | "nbconvert_exporter": "python",
568 | "pygments_lexer": "ipython3",
569 | "version": "3.9.13"
570 | },
571 | "orig_nbformat": 4,
572 | "vscode": {
573 | "interpreter": {
574 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa"
575 | }
576 | }
577 | },
578 | "nbformat": 4,
579 | "nbformat_minor": 2
580 | }
581 |
--------------------------------------------------------------------------------
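`clip_score` averages the cosine similarity between each generated image and its own caption. Writing it as `(image @ text.T).diagonal()` builds an N x N matrix only to read its diagonal. A numpy sketch of the same quantity computed row-wise, with O(N x d) memory, follows; it assumes the image and text feature rows are aligned (row i of each belongs to the same sample), as in the notebook. `l2_normalize` and `paired_clip_score` are hypothetical names.

```python
import numpy as np


def l2_normalize(x, axis=-1):
    """Scale each row to unit L2 norm, like x / x.norm(dim=-1, keepdim=True) in the notebook."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def paired_clip_score(image_feats, text_feats):
    """Mean cosine similarity between row i of image_feats and row i of text_feats."""
    img = l2_normalize(np.asarray(image_feats, dtype=np.float64))
    txt = l2_normalize(np.asarray(text_feats, dtype=np.float64))
    # Row-wise dot product: identical to the diagonal of img @ txt.T.
    return float(np.einsum("nd,nd->n", img, txt).mean())
```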
/text-image/fid_clip_score/run_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ps aux | grep -E 'run_watch.sh|watch.py' |awk '{print $2}' | xargs kill -9 # kill previous watchdog
3 | guidance_scales=(1.5 2.0 3.0 4.0 5.0 6.0 7.0 8.0)
4 | for i in {0..7}
5 | do
6 | echo ${i}
7 | CUDA_VISIBLE_DEVICES=${i} nohup python coco_sample_generator.py --guidance_scale ${guidance_scales[${i}]} --batch_size 16 --sample_step 20 > stable_generator_${i}.log 2>&1 & # one log per GPU so the jobs do not overwrite each other
8 | done
9 | wait
10 | bash ~/release_watchdog.sh # start watchdog
--------------------------------------------------------------------------------
/text-image/fid_clip_score/run_generator_cn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ps aux | grep -E 'run_watch.sh|watch.py' |awk '{print $2}' | xargs kill -9 # kill previous watchdog
3 | guidance_scales=(1.5 2.0 3.0 4.0 5.0 6.0 7.0 8.0)
4 | for i in {0..7}
5 | do
6 | echo ${i}
7 | CUDA_VISIBLE_DEVICES=${i} nohup python coco_sample_generator.py --model_path ../pretrained_models/stable_cn --coco_cache_file ../dataset/coco/subset_cn.parquet --output_path ./output_cn --guidance_scale ${guidance_scales[${i}]} --batch_size 16 --sample_step 20 > stable_generator_cn_${i}.log 2>&1 & # one log per GPU so the jobs do not overwrite each other
8 | done
9 | wait
10 | bash ~/release_watchdog.sh # start watchdog
--------------------------------------------------------------------------------
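Both launch scripts pin one guidance scale to each of 8 GPUs through `CUDA_VISIBLE_DEVICES` and run `coco_sample_generator.py` in the background. A Python sketch of the same sweep follows, giving each job its own log file. `build_jobs`, `launch`, and the `<prefix>_gs<scale>.log` naming are assumptions, and the watchdog start/stop lines are left out.

```python
import os
import subprocess
import sys

GUIDANCE_SCALES = [1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]


def build_jobs(guidance_scales, extra_args=(), log_prefix="stable_generator"):
    """Return one (gpu_id, argv, log_path) tuple per guidance scale."""
    jobs = []
    for gpu_id, gs in enumerate(guidance_scales):
        argv = [sys.executable, "coco_sample_generator.py",
                "--guidance_scale", str(gs), "--batch_size", "16", "--sample_step", "20",
                *extra_args]
        jobs.append((gpu_id, argv, f"{log_prefix}_gs{gs}.log"))
    return jobs


def launch(jobs):
    """Start every job on its own GPU, wait for all of them, and return their exit codes."""
    procs = []
    for gpu_id, argv, log_path in jobs:
        env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu_id)}
        log = open(log_path, "w")
        procs.append((subprocess.Popen(argv, env=env, stdout=log, stderr=subprocess.STDOUT), log))
    for proc, log in procs:
        proc.wait()
        log.close()
    return [proc.returncode for proc, _ in procs]
```

The Chinese-model sweep corresponds to `launch(build_jobs(GUIDANCE_SCALES, ["--model_path", "../pretrained_models/stable_cn", "--coco_cache_file", "../dataset/coco/subset_cn.parquet", "--output_path", "./output_cn"], log_prefix="stable_generator_cn"))`.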
/text-image/imagenet_CN_zeroshot_data.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | imagenet_classnames = [
4 | "丁鲷",
5 | "金鱼",
6 | "大白鲨",
7 | "虎鲨",
8 | "锤头鲨",
9 | "电鳐",
10 | "黄貂鱼",
11 | "公鸡",
12 | "母鸡",
13 | "鸵鸟",
14 | "燕雀",
15 | "金翅雀",
16 | "家朱雀",
17 | "灯芯草雀",
18 | "靛蓝雀",
19 | "蓝鹀",
20 | "夜莺",
21 | "松鸦",
22 | "喜鹊",
23 | "山雀",
24 | "河鸟",
25 | "鸢(猛禽)",
26 | "秃头鹰",
27 | "秃鹫",
28 | "大灰猫头鹰",
29 | "欧洲火蝾螈",
30 | "普通蝾螈",
31 | "水蜥",
32 | "斑点蝾螈",
33 | "蝾螈",
34 | "牛蛙",
35 | "树蛙",
36 | "尾蛙",
37 | "红海龟",
38 | "皮革龟",
39 | "泥龟",
40 | "淡水龟",
41 | "箱龟",
42 | "带状壁虎",
43 | "普通鬣蜥",
44 | "美国变色龙",
45 | "鞭尾蜥蜴",
46 | "飞龙科蜥蜴",
47 | "褶边蜥蜴",
48 | "鳄鱼蜥蜴",
49 | "毒蜥",
50 | "绿蜥蜴",
51 | "非洲变色龙",
52 | "科莫多蜥蜴",
53 | "非洲鳄",
54 | "美国鳄鱼",
55 | "三角龙",
56 | "雷蛇",
57 | "环蛇",
58 | "希腊蛇",
59 | "绿蛇",
60 | "国王蛇",
61 | "袜带蛇",
62 | "水蛇",
63 | "藤蛇",
64 | "夜蛇",
65 | "大蟒蛇",
66 | "岩石蟒蛇",
67 | "印度眼镜蛇",
68 | "绿曼巴",
69 | "海蛇",
70 | "角腹蛇",
71 | "菱纹响尾蛇",
72 | "角响尾蛇",
73 | "三叶虫",
74 | "盲蜘蛛",
75 | "蝎子",
76 | "黑金花园蜘蛛",
77 | "谷仓蜘蛛",
78 | "花园蜘蛛",
79 | "黑寡妇蜘蛛",
80 | "狼蛛",
81 | "狼蜘蛛",
82 | "壁虱",
83 | "蜈蚣",
84 | "黑松鸡",
85 | "松鸡",
86 | "披肩鸡",
87 | "草原鸡",
88 | "孔雀",
89 | "鹌鹑",
90 | "鹧鸪",
91 | "非洲灰鹦鹉",
92 | "金刚鹦鹉",
93 | "硫冠鹦鹉",
94 | "短尾鹦鹉",
95 | "褐翅鸦鹃",
96 | "食蜂鸟;蜂虎",
97 | "犀鸟",
98 | "蜂鸟",
99 | "鹟䴕",
100 | "巨嘴鸟;大嘴鸟",
101 | "野鸭",
102 | "红胸秋沙鸭",
103 | "鹅",
104 | "黑天鹅",
105 | "大象",
106 | "针鼹鼠",
107 | "鸭嘴兽",
108 | "沙袋鼠",
109 | "考拉",
110 | "袋熊",
111 | "水母",
112 | "海葵",
113 | "脑珊瑚",
114 | "扁形虫扁虫",
115 | "线虫",
116 | "海螺",
117 | "蜗牛",
118 | "鼻涕虫",
119 | "海蛞蝓;海参",
120 | "石鳖",
121 | "鹦鹉螺",
122 | "珍宝蟹",
123 | "石蟹",
124 | "招潮蟹",
125 | "帝王蟹",
126 | "美国龙虾",
127 | "大螯虾",
128 | "小龙虾",
129 | "寄居蟹",
130 | "等足目动物(明虾和螃蟹近亲)",
131 | "白鹳",
132 | "黑鹳",
133 | "鹭",
134 | "火烈鸟",
135 | "小蓝鹭",
136 | "美国鹭",
137 | "麻鸦",
138 | "鹤",
139 | "秧鹤",
140 | "欧洲水鸡",
141 | "沼泽泥母鸡",
142 | "鸨",
143 | "红翻石鹬",
144 | "红背鹬",
145 | "红脚鹬",
146 | "半蹼鹬",
147 | "蛎鹬",
148 | "鹈鹕",
149 | "国王企鹅",
150 | "信天翁",
151 | "灰鲸",
152 | "杀人鲸",
153 | "海牛",
154 | "海狮",
155 | "吉娃娃",
156 | "日本狆犬",
157 | "马尔济斯犬",
158 | "狮子狗",
159 | "西施犬",
160 | "布莱尼姆猎犬",
161 | "巴比狗",
162 | "玩具犬",
163 | "罗得西亚长背猎狗",
164 | "阿富汗猎犬",
165 | "巴吉度猎犬",
166 | "比格犬",
167 | "侦探犬",
168 | "蓝色快狗",
169 | "黑褐猎浣熊犬",
170 | "沃克猎犬",
171 | "英国猎狐犬",
172 | "美洲赤狗",
173 | "俄罗斯猎狼犬",
174 | "爱尔兰猎狼犬",
175 | "意大利灰狗",
176 | "惠比特犬",
177 | "依比沙猎犬",
178 | "挪威猎犬",
179 | "奥达猎犬",
180 | "沙克犬",
181 | "苏格兰猎鹿犬",
182 | "威玛猎犬",
183 | "斯塔福德郡斗牛犬",
184 | "美国斯塔福德郡梗",
185 | "贝德灵顿梗",
186 | "边境梗",
187 | "凯丽蓝梗",
188 | "爱尔兰梗",
189 | "诺福克梗",
190 | "诺维奇梗",
191 | "约克犬;约克夏梗犬",
192 | "刚毛猎狐梗",
193 | "莱克兰梗",
194 | "锡利哈姆梗",
195 | "艾尔谷犬",
196 | "凯恩梗",
197 | "澳大利亚梗",
198 | "丹迪丁蒙梗",
199 | "波士顿梗",
200 | "迷你雪纳瑞犬",
201 | "巨型雪纳瑞犬",
202 | "标准雪纳瑞犬",
203 | "苏格兰梗犬",
204 | "西藏梗",
205 | "丝毛梗",
206 | "爱尔兰软毛梗犬",
207 | "西高地白梗",
208 | "拉萨阿普索犬",
209 | "平毛寻回犬",
210 | "卷毛寻回犬",
211 | "金毛猎犬",
212 | "拉布拉多猎犬",
213 | "乞沙比克猎犬",
214 | "德国短毛指示犬",
215 | "维兹拉犬",
216 | "英国塞特犬",
217 | "爱尔兰雪达犬",
218 | "戈登雪达犬",
219 | "布列塔尼犬猎犬",
220 | "黄毛",
221 | "英国史宾格犬",
222 | "威尔士史宾格犬",
223 | "可卡犬",
224 | "萨塞克斯猎犬",
225 | "爱尔兰水猎犬",
226 | "哥威斯犬",
227 | "舒柏奇犬",
228 | "比利时牧羊犬",
229 | "马里努阿犬",
230 | "伯瑞犬",
231 | "凯尔皮犬",
232 | "匈牙利牧羊犬",
233 | "老英国牧羊犬",
234 | "喜乐蒂牧羊犬",
235 | "牧羊犬",
236 | "边境牧羊犬",
237 | "法兰德斯牧牛狗",
238 | "罗特韦尔犬",
239 | "德国牧羊犬",
240 | "多伯曼犬",
241 | "鹿犬;迷你杜宾犬",
242 | "大瑞士山地犬",
243 | "伯恩山犬",
244 | "阿策尔山犬",
245 | "恩特尔布赫山犬",
246 | "拳师狗",
247 | "斗牛獒",
248 | "藏獒",
249 | "法国斗牛犬",
250 | "大丹犬",
251 | "圣伯纳德狗",
252 | "爱斯基摩犬",
253 | "阿拉斯加雪橇犬",
254 | "哈士奇",
255 | "达尔马提亚",
256 | "狮毛狗",
257 | "巴辛吉狗",
258 | "八哥犬",
259 | "莱昂贝格狗",
260 | "纽芬兰犬",
261 | "大白熊犬",
262 | "萨摩耶犬",
263 | "博美犬",
264 | "松狮",
265 | "凯斯犬",
266 | "布鲁塞尔格林芬犬",
267 | "彭布洛克威尔士科基犬",
268 | "威尔士柯基犬",
269 | "玩具贵宾犬",
270 | "迷你贵宾犬",
271 | "标准贵宾犬",
272 | "墨西哥无毛犬",
273 | "灰狼",
274 | "白狼",
275 | "红太狼",
276 | "狼",
277 | "澳洲野狗",
278 | "豺",
279 | "非洲猎犬",
280 | "鬣狗",
281 | "红狐狸",
282 | "沙狐",
283 | "北极狐狸",
284 | "灰狐狸",
285 | "虎斑猫",
286 | "山猫",
287 | "波斯猫",
288 | "暹罗猫",
289 | "埃及猫",
290 | "美洲狮",
291 | "猞猁",
292 | "豹子",
293 | "雪豹",
294 | "美洲虎",
295 | "狮子",
296 | "老虎",
297 | "猎豹",
298 | "棕熊",
299 | "美洲黑熊",
300 | "冰熊",
301 | "懒熊",
302 | "獴",
303 | "猫鼬",
304 | "虎甲虫",
305 | "瓢虫",
306 | "土鳖虫",
307 | "天牛",
308 | "龟甲虫",
309 | "粪甲虫",
310 | "犀牛甲虫",
311 | "象甲",
312 | "苍蝇",
313 | "蜜蜂",
314 | "蚂蚁",
315 | "蚱蜢",
316 | "蟋蟀",
317 | "竹节虫",
318 | "蟑螂",
319 | "螳螂",
320 | "蝉",
321 | "叶蝉",
322 | "草蜻蛉",
323 | "蜻蜓",
324 | "豆娘",
325 | "优红蛱蝶",
326 | "小环蝴蝶",
327 | "君主蝴蝶",
328 | "菜粉蝶",
329 | "白蝴蝶",
330 | "灰蝶",
331 | "海星",
332 | "海胆",
333 | "海黄瓜;海参",
334 | "野兔",
335 | "兔",
336 | "安哥拉兔",
337 | "仓鼠",
338 | "刺猬",
339 | "黑松鼠",
340 | "土拨鼠",
341 | "海狸",
342 | "豚鼠",
343 | "栗色马",
344 | "斑马",
345 | "猪",
346 | "野猪",
347 | "疣猪",
348 | "河马",
349 | "牛",
350 | "水牛",
351 | "野牛",
352 | "公羊",
353 | "大角羊",
354 | "山羊",
355 | "狷羚",
356 | "黑斑羚",
357 | "瞪羚",
358 | "阿拉伯单峰骆驼",
359 | "骆驼",
360 | "黄鼠狼",
361 | "水貂",
362 | "臭猫",
363 | "黑足鼬",
364 | "水獭",
365 | "臭鼬",
366 | "獾",
367 | "犰狳",
368 | "树懒",
369 | "猩猩",
370 | "大猩猩",
371 | "黑猩猩",
372 | "长臂猿",
373 | "合趾猿长臂猿",
374 | "长尾猴",
375 | "赤猴",
376 | "狒狒",
377 | "恒河猴",
378 | "白头叶猴",
379 | "疣猴",
380 | "长鼻猴",
381 | "狨(美洲产小型长尾猴)",
382 | "卷尾猴",
383 | "吼猴",
384 | "伶猴",
385 | "蜘蛛猴",
386 | "松鼠猴",
387 | "马达加斯加环尾狐猴",
388 | "大狐猴",
389 | "印度大象",
390 | "非洲象",
391 | "小熊猫",
392 | "大熊猫",
393 | "杖鱼",
394 | "鳗鱼",
395 | "银鲑",
396 | "三色刺蝶鱼",
397 | "海葵鱼",
398 | "鲟鱼",
399 | "雀鳝",
400 | "狮子鱼",
401 | "河豚",
402 | "算盘",
403 | "长袍",
404 | "学位袍",
405 | "手风琴",
406 | "原声吉他",
407 | "航空母舰",
408 | "客机",
409 | "飞艇",
410 | "祭坛",
411 | "救护车",
412 | "水陆两用车",
413 | "模拟时钟",
414 | "蜂房",
415 | "围裙",
416 | "垃圾桶",
417 | "攻击步枪",
418 | "背包",
419 | "面包店",
420 | "平衡木",
421 | "热气球",
422 | "圆珠笔",
423 | "创可贴",
424 | "班卓琴",
425 | "栏杆",
426 | "杠铃",
427 | "理发师的椅子",
428 | "理发店",
429 | "牲口棚",
430 | "晴雨表",
431 | "圆筒",
432 | "园地小车",
433 | "棒球",
434 | "篮球",
435 | "婴儿床",
436 | "巴松管",
437 | "游泳帽",
438 | "沐浴毛巾",
439 | "浴缸",
440 | "沙滩车",
441 | "灯塔",
442 | "烧杯",
443 | "熊皮高帽",
444 | "啤酒瓶",
445 | "啤酒杯",
446 | "钟塔",
447 | "(小儿用的)围嘴",
448 | "串联自行车",
449 | "比基尼",
450 | "装订册",
451 | "双筒望远镜",
452 | "鸟舍",
453 | "船库",
454 | "双人雪橇",
455 | "饰扣式领带",
456 | "阔边女帽",
457 | "书橱",
458 | "书店",
459 | "瓶盖",
460 | "弓箭",
461 | "蝴蝶结领结",
462 | "铜制牌位",
463 | "奶罩",
464 | "防波堤",
465 | "铠甲",
466 | "扫帚",
467 | "桶",
468 | "扣环",
469 | "防弹背心",
470 | "动车",
471 | "肉铺",
472 | "出租车",
473 | "大锅",
474 | "蜡烛",
475 | "大炮",
476 | "独木舟",
477 | "开瓶器",
478 | "开衫",
479 | "车镜",
480 | "旋转木马",
481 | "木匠的工具包",
482 | "纸箱",
483 | "车轮",
484 | "取款机",
485 | "盒式录音带",
486 | "卡带播放器",
487 | "城堡",
488 | "双体船",
489 | "CD播放器",
490 | "大提琴",
491 | "移动电话",
492 | "铁链",
493 | "围栏",
494 | "链甲",
495 | "电锯",
496 | "箱子",
497 | "梳妆台",
498 | "编钟",
499 | "中国橱柜",
500 | "圣诞袜",
501 | "教堂",
502 | "电影院",
503 | "切肉刀",
504 | "悬崖屋",
505 | "斗篷",
506 | "木屐",
507 | "鸡尾酒调酒器",
508 | "咖啡杯",
509 | "咖啡壶",
510 | "螺旋结构(楼梯)",
511 | "组合锁",
512 | "电脑键盘",
513 | "糖果",
514 | "集装箱船",
515 | "敞篷车",
516 | "瓶塞钻",
517 | "短号",
518 | "牛仔靴",
519 | "牛仔帽",
520 | "摇篮",
521 | "起重机",
522 | "头盔",
523 | "板条箱",
524 | "小儿床",
525 | "砂锅",
526 | "槌球",
527 | "拐杖",
528 | "胸甲",
529 | "大坝",
530 | "书桌",
531 | "台式电脑",
532 | "有线电话",
533 | "尿布湿",
534 | "数字时钟",
535 | "数字手表",
536 | "餐桌板",
537 | "抹布",
538 | "洗碗机",
539 | "盘式制动器",
540 | "码头",
541 | "狗拉雪橇",
542 | "圆顶",
543 | "门垫",
544 | "钻井平台",
545 | "鼓",
546 | "鼓槌",
547 | "哑铃",
548 | "荷兰烤箱",
549 | "电风扇",
550 | "电吉他",
551 | "电力机车",
552 | "组合电视柜",
553 | "信封",
554 | "浓缩咖啡机",
555 | "扑面粉",
556 | "女用长围巾",
557 | "文件",
558 | "消防船",
559 | "消防车",
560 | "火炉栏",
561 | "旗杆",
562 | "长笛",
563 | "折叠椅",
564 | "橄榄球头盔",
565 | "叉车",
566 | "喷泉",
567 | "钢笔",
568 | "有四根帷柱的床",
569 | "运货车厢",
570 | "圆号",
571 | "煎锅",
572 | "裘皮大衣",
573 | "垃圾车",
574 | "防毒面具",
575 | "汽油泵",
576 | "高脚杯",
577 | "卡丁车",
578 | "高尔夫球",
579 | "高尔夫球车",
580 | "狭长小船",
581 | "锣",
582 | "礼服",
583 | "钢琴",
584 | "温室",
585 | "散热器格栅",
586 | "杂货店",
587 | "断头台",
588 | "小发夹",
589 | "头发喷雾",
590 | "半履带装甲车",
591 | "锤子",
592 | "大篮子",
593 | "手摇鼓风机",
594 | "手提电脑",
595 | "手帕",
596 | "硬盘",
597 | "口琴",
598 | "竖琴",
599 | "收割机",
600 | "斧头",
601 | "手枪皮套",
602 | "家庭影院",
603 | "蜂窝",
604 | "钩爪",
605 | "衬裙",
606 | "单杠",
607 | "马车",
608 | "沙漏",
609 | "iPod",
610 | "熨斗",
611 | "南瓜灯笼",
612 | "牛仔裤",
613 | "吉普车",
614 | "T恤衫",
615 | "拼图",
616 | "人力车",
617 | "操纵杆",
618 | "和服",
619 | "护膝",
620 | "蝴蝶结",
621 | "大褂",
622 | "长柄勺",
623 | "灯罩",
624 | "笔记本电脑",
625 | "割草机",
626 | "镜头盖",
627 | "开信刀",
628 | "图书馆",
629 | "救生艇",
630 | "点火器",
631 | "豪华轿车",
632 | "远洋班轮",
633 | "唇膏",
634 | "平底便鞋",
635 | "洗剂",
636 | "扬声器",
637 | "放大镜",
638 | "锯木厂",
639 | "磁罗盘",
640 | "邮袋",
641 | "信箱",
642 | "女游泳衣",
643 | "有肩带浴衣",
644 | "窨井盖",
645 | "沙球(一种打击乐器)",
646 | "马林巴木琴",
647 | "面膜",
648 | "火柴",
649 | "花柱",
650 | "迷宫",
651 | "量杯",
652 | "药箱",
653 | "巨石",
654 | "麦克风",
655 | "微波炉",
656 | "军装",
657 | "奶桶",
658 | "迷你巴士",
659 | "迷你裙",
660 | "面包车",
661 | "导弹",
662 | "连指手套",
663 | "搅拌钵",
664 | "活动房屋(由汽车拖拉的)",
665 | "T型发动机小汽车",
666 | "调制解调器",
667 | "修道院",
668 | "显示器",
669 | "电瓶车",
670 | "砂浆",
671 | "学士",
672 | "清真寺",
673 | "蚊帐",
674 | "摩托车",
675 | "山地自行车",
676 | "登山帐",
677 | "鼠标",
678 | "捕鼠器",
679 | "搬家货车",
680 | "动物的口套",
681 | "金属钉子",
682 | "颈托",
683 | "项链",
684 | "乳头(瓶)",
685 | "笔记本",
686 | "方尖碑",
687 | "双簧管",
688 | "陶笛",
689 | "里程表",
690 | "滤油器",
691 | "风琴",
692 | "示波器",
693 | "罩裙",
694 | "牛车",
695 | "氧气面罩",
696 | "包装",
697 | "船桨",
698 | "明轮",
699 | "挂锁",
700 | "画笔",
701 | "睡衣",
702 | "宫殿",
703 | "排箫",
704 | "纸巾",
705 | "降落伞",
706 | "双杠",
707 | "公园长椅",
708 | "停车收费表",
709 | "客车",
710 | "露台",
711 | "付费电话",
712 | "基座",
713 | "铅笔盒",
714 | "卷笔刀",
715 | "香水(瓶)",
716 | "培养皿",
717 | "复印机",
718 | "拨弦片",
719 | "尖顶头盔",
720 | "用尖板条连成的尖桩篱栅",
721 | "皮卡",
722 | "桥墩",
723 | "存钱罐",
724 | "药瓶",
725 | "枕头",
726 | "乒乓球",
727 | "风车",
728 | "海盗船",
729 | "水罐",
730 | "木工刨",
731 | "天文馆",
732 | "塑料袋",
733 | "板架",
734 | "犁型铲雪机",
735 | "手压皮碗泵",
736 | "宝丽来相机",
737 | "电线杆",
738 | "警车",
739 | "雨披",
740 | "台球桌",
741 | "充气饮料瓶",
742 | "花盆",
743 | "陶工旋盘",
744 | "电钻",
745 | "祈祷垫",
746 | "打印机",
747 | "监狱",
748 | "炮弹",
749 | "投影仪",
750 | "冰球",
751 | "沙包",
752 | "小钱袋;手袋",
753 | "羽管笔",
754 | "被子",
755 | "赛车",
756 | "球拍",
757 | "散热器",
758 | "收音机",
759 | "射电望远镜",
760 | "雨桶",
761 | "休闲车",
762 | "卷轴",
763 | "反射式照相机",
764 | "冰箱",
765 | "遥控器",
766 | "餐厅",
767 | "左轮手枪",
768 | "步枪",
769 | "摇椅",
770 | "电转烤肉架",
771 | "橡皮",
772 | "橄榄球",
773 | "直尺",
774 | "跑步鞋",
775 | "保险柜",
776 | "安全别针",
777 | "盐瓶(调味用)",
778 | "凉鞋",
779 | "纱笼",
780 | "萨克斯管",
781 | "剑鞘",
782 | "秤",
783 | "校车",
784 | "帆船",
785 | "记分牌",
786 | "屏幕",
787 | "螺丝",
788 | "螺丝刀",
789 | "安全带",
790 | "缝纫机",
791 | "盾牌",
792 | "皮鞋店",
793 | "障子",
794 | "购物篮",
795 | "购物车",
796 | "铁锹",
797 | "浴帽",
798 | "浴帘",
799 | "滑雪板",
800 | "滑雪面罩",
801 | "睡袋",
802 | "滑尺",
803 | "滑动门",
804 | "角子老虎机",
805 | "潜水通气管",
806 | "摩托雪橇;雪地机动车",
807 | "扫雪机",
808 | "皂液器",
809 | "足球",
810 | "袜子",
811 | "碟式太阳能",
812 | "宽边帽",
813 | "汤碗",
814 | "空格键",
815 | "空间加热器",
816 | "航天飞机",
817 | "锅铲;做饭的铲子",
818 | "快艇",
819 | "蜘蛛网",
820 | "纺锤;手纺用的绕线杆",
821 | "跑车",
822 | "聚光灯",
823 | "舞台",
824 | "蒸汽机车",
825 | "钢拱桥",
826 | "钢滚筒",
827 | "听诊器",
828 | "女用披肩",
829 | "石头墙",
830 | "秒表",
831 | "火炉",
832 | "过滤器",
833 | "有轨电车",
834 | "担架",
835 | "沙发床",
836 | "佛塔",
837 | "潜艇",
838 | "套装",
839 | "日晷",
840 | "太阳镜",
841 | "太阳镜",
842 | "防晒霜",
843 | "悬索桥",
844 | "拖把",
845 | "运动衫",
846 | "游泳裤",
847 | "秋千",
848 | "开关",
849 | "注射器;吸管",
850 | "台灯",
851 | "坦克",
852 | "录音机",
853 | "茶壶",
854 | "泰迪",
855 | "电视",
856 | "网球;打网球的球",
857 | "茅草",
858 | "幕布",
859 | "顶针",
860 | "打谷机;脱粒机",
861 | "宝座",
862 | "瓦屋顶",
863 | "烤面包机",
864 | "烟草店",
865 | "马桶",
866 | "火炬",
867 | "图腾柱",
868 | "拖车;牵引车",
869 | "玩具店",
870 | "拖拉机",
871 | "半挂汽车",
872 | "托盘",
873 | "风衣",
874 | "三轮车",
875 | "三体船",
876 | "三脚架",
877 | "凯旋门",
878 | "无轨电车",
879 | "长号",
880 | "浴盆",
881 | "旋转式栅门",
882 | "打字机键盘",
883 | "伞",
884 | "独轮车",
885 | "直立式钢琴",
886 | "吸尘器",
887 | "花瓶;装饰瓶",
888 | "拱顶",
889 | "天鹅绒",
890 | "自动售货机",
891 | "法衣;祭衣;祭服",
892 | "高架桥",
893 | "小提琴",
894 | "排球",
895 | "松饼机",
896 | "挂钟",
897 | "钱包;钱夹",
898 | "衣柜衣橱",
899 | "军用飞机",
900 | "洗脸盆",
901 | "洗衣机",
902 | "水瓶",
903 | "水壶",
904 | "水塔",
905 | "威士忌壶",
906 | "哨子",
907 | "假发",
908 | "纱窗",
909 | "百叶窗",
910 | "温莎领带",
911 | "葡萄酒瓶",
912 | "飞机翅膀",
913 | "炒菜锅",
914 | "木勺子;木头勺子",
915 | "毛织品",
916 | "原木栅栏",
917 | "沉船",
918 | "双桅船",
919 | "蒙古包",
920 | "网站;网页",
921 | "漫画",
922 | "纵横字谜",
923 | "路标",
924 | "交通信号灯",
925 | "防尘罩",
926 | "菜单",
927 | "盘子",
928 | "墨西哥鳄梨酱;墨西哥牛油果酱",
929 | "清炖肉汤",
930 | "火锅",
931 | "乳脂蛋糕;英国甜点",
932 | "冰淇淋",
933 | "冰棍;雪糕",
934 | "法式面包",
935 | "百吉饼",
936 | "椒盐脆饼",
937 | "芝士汉堡",
938 | "热狗",
939 | "土豆泥",
940 | "结球甘蓝",
941 | "西兰花;绿菜花",
942 | "菜花;花椰菜",
943 | "西葫芦",
944 | "金丝瓜;意面南瓜;面条瓜",
945 | "绿色小南瓜;青南瓜",
946 | "南瓜",
947 | "黄瓜",
948 | "洋蓟;球蓟",
949 | "甜椒",
950 | "刺棘蓟",
951 | "蘑菇",
952 | "绿苹果",
953 | "草莓",
954 | "橘子",
955 | "柠檬",
956 | "无花果",
957 | "菠萝",
958 | "香蕉",
959 | "菠萝蜜",
960 | "番荔枝",
961 | "石榴",
962 | "干草",
963 | "培根蛋酱意大利面",
964 | "巧克力酱",
965 | "生面;面团",
966 | "瑞士肉包",
967 | "披萨",
968 | "馅饼",
969 | "卷饼",
970 | "红葡萄酒",
971 | "意式浓缩咖啡",
972 | "杯子",
973 | "蛋酒",
974 | "高山",
975 | "泡泡",
976 | "悬崖",
977 | "珊瑚礁",
978 | "间歇泉;间断喷发的温泉",
979 | "湖边",
980 | "岬角;深入海中的狭长高地",
981 | "沙洲",
982 | "沙滩",
983 | "峡谷",
984 | "火山",
985 | "棒球运动员",
986 | "新郎",
987 | "潜水员",
988 | "油菜",
989 | "雏菊",
990 | "杓兰",
991 | "玉米",
992 | "橡子",
993 | "玫瑰果",
994 | "七叶树果实",
995 | "珊瑚菌",
996 | "木耳",
997 | "鹿花菌",
998 | "臭角菇",
999 | "地星",
1000 | "多叶奇果菌",
1001 | "牛肝菌",
1002 | "玉米棒子",
1003 | "卫生纸"
1004 | ]
1005 |
1006 |
1007 |
1008 |
1009 | openai_imagenet_template = [
1010 | lambda c: f'质量差的{c}的照片。',
1011 | lambda c: f'许多{c}的照片。',
1012 | lambda c: f'{c}的雕塑。',
1013 | lambda c: f'难以看到{c}的照片。',
1014 | lambda c: f'{c}的低分辨率照片。',
1015 | lambda c: f'{c}的渲染。',
1016 | lambda c: f'涂鸦{c}。',
1017 | lambda c: f'{c}的糟糕照片。',
1018 | lambda c: f'{c}的裁剪照片。',
1019 | lambda c: f'{c}的纹身。',
1020 | lambda c: f'{c}的刺绣照片。',
1021 | lambda c: f'很难看到{c}的照片。',
1022 | lambda c: f'{c}的明亮照片。',
1023 | lambda c: f'一张干净的{c}的照片。',
1024 | lambda c: f'一张包含{c}的照片。',
1025 | lambda c: f'{c}的深色照片。',
1026 | lambda c: f'{c}的手绘画。',
1027 | lambda c: f'我的{c}的照片。',
1028 | lambda c: f'不自然的{c}的照片。',
1029 | lambda c: f'一张酷的{c}的照片。',
1030 | lambda c: f'{c}的特写照片。',
1031 | lambda c: f'{c}的黑白照片。',
1032 | lambda c: f'一幅{c}的画。',
1033 | lambda c: f'一幅{c}的绘画。',
1034 | lambda c: f'一张{c}的像素照片。',
1035 | lambda c: f'{c}的雕像。',
1036 | lambda c: f'一张{c}的明亮照片。',
1037 | lambda c: f'{c}的裁剪照片。',
1038 | lambda c: f'人造的{c}的照片。',
1039 | lambda c: f'一张关于{c}的照片。',
1040 | lambda c: f'损坏的{c}的jpeg照片。',
1041 | lambda c: f'{c}的模糊照片。',
1042 | lambda c: f'{c}的相片。',
1043 | lambda c: f'一张{c}的好照片。',
1044 | lambda c: f'{c}的渲染照。',
1045 | lambda c: f'视频游戏中的{c}。',
1046 | lambda c: f'一张{c}的照片。',
1047 | lambda c: f'{c}的涂鸦。',
1048 | lambda c: f'{c}的近距离照片。',
1049 | lambda c: f'{c}的折纸。',
1050 | lambda c: f'{c}在视频游戏中。',
1051 | lambda c: f'{c}的草图。',
1052 | lambda c: f'{c}的涂鸦照。',
1053 | lambda c: f'{c}的折纸形状。',
1054 | lambda c: f'低分辨率的{c}的照片。',
1055 | lambda c: f'玩具{c}。',
1056 | lambda c: f'{c}的副本。',
1057 | lambda c: f'{c}的干净的照片。',
1058 | lambda c: f'一张大{c}的照片。',
1059 | lambda c: f'{c}的重现。',
1060 | lambda c: f'一张漂亮的{c}的照片。',
1061 | lambda c: f'一张奇怪的{c}的照片。',
1062 | lambda c: f'模糊的{c}的照片。',
1063 | lambda c: f'卡通{c}。',
1064 | lambda c: f'{c}的艺术作品。',
1065 | lambda c: f'{c}的素描。',
1066 | lambda c: f'刺绣{c}。',
1067 | lambda c: f'{c}的像素照。',
1068 | lambda c: f'{c}的拍照。',
1069 | lambda c: f'{c}的损坏的照片。',
1070 | lambda c: f'高质量的{c}的照片。',
1071 | lambda c: f'毛绒玩具{c}。',
1072 | lambda c: f'漂亮的{c}的照片。',
1073 | lambda c: f'小{c}的照片。',
1074 | lambda c: f'照片是奇怪的{c}。',
1075 | lambda c: f'漫画{c}。',
1076 | lambda c: f'{c}的艺术照。',
1077 | lambda c: f'{c}的图形。',
1078 | lambda c: f'大{c}的照片。',
1079 | lambda c: f'黑白的{c}的照片。',
1080 | lambda c: f'{c}毛绒玩具。',
1081 | lambda c: f'一张{c}的深色照片。',
1082 | lambda c: f'{c}的摄影图。',
1083 | lambda c: f'{c}的涂鸦照。',
1084 | lambda c: f'玩具形状的{c}。',
1085 | lambda c: f'拍了{c}的照片。',
1086 | lambda c: f'酷酷的{c}的照片。',
1087 | lambda c: f'照片里的小{c}。',
1088 | lambda c: f'{c}的刺青。',
1089 | ]
1090 |
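The class names and `openai_imagenet_template` above are meant to be combined into a zero-shot classifier. The usual recipe, as in OpenAI's CLIP ImageNet notebook, is prompt ensembling: fill every template with a class name, encode the captions, L2-normalize each embedding, average over templates, and renormalize the mean. The sketch below shows that recipe under stated assumptions. It copies only two class names and two templates from the lists above. `encode_texts` is a hypothetical stand-in for a real text encoder such as a Taiyi text tower: it returns deterministic pseudo-random vectors so the block runs without model weights, and `EMBED_DIM` is a toy size.

```python
import hashlib

import numpy as np

# Two entries copied from the lists above; in practice use the full lists.
classnames = ["双筒望远镜", "鸟舍"]
templates = [
    lambda c: f'一张{c}的照片。',
    lambda c: f'{c}的素描。',
]

EMBED_DIM = 8  # assumption: toy embedding size for the stand-in encoder


def encode_texts(texts):
    """Hypothetical stand-in for a real text encoder.

    Maps each string to a deterministic pseudo-random vector (seeded by its
    MD5 hash) so the sketch runs without downloading any weights.
    """
    rows = []
    for t in texts:
        seed = int(hashlib.md5(t.encode("utf-8")).hexdigest()[:8], 16)
        rows.append(np.random.default_rng(seed).standard_normal(EMBED_DIM))
    return np.stack(rows)


def build_zeroshot_weights(classnames, templates):
    """Return a (num_classes, dim) matrix of L2-normalized class embeddings."""
    weights = []
    for name in classnames:
        prompts = [t(name) for t in templates]               # fill every template
        emb = encode_texts(prompts)                          # (num_templates, dim)
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)    # normalize each prompt
        mean = emb.mean(axis=0)                              # ensemble over templates
        weights.append(mean / np.linalg.norm(mean))          # renormalize the average
    return np.stack(weights)


W = build_zeroshot_weights(classnames, templates)

# Toy "image" embedding identical to class 1's weight vector.
image_feature = W[1].copy()
logits = 100.0 * image_feature @ W.T  # scaled cosine similarity per class
print(W.shape, int(logits.argmax()))
```

With a real model, `image_feature` would be the normalized output of the image encoder. The argmax over `logits` is the predicted class index.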
--------------------------------------------------------------------------------
/text-image/iterable_tar_unzip.sh:
--------------------------------------------------------------------------------
1 | # for name in */;
2 | # do
3 | name="image_part12/"
4 | for i in "$name"*.tar;
5 | do
6 | mkdir -p "./project/dataset/laion_chinese_cwf/${i%.tar}"
7 | tar xvf "$i" -C "./project/dataset/laion_chinese_cwf/${i%.tar}";
8 | done;
9 | # done
10 |
11 |
12 |
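The script above extracts each `.tar` shard in a part directory into its own folder, named after the archive without the `.tar` suffix. The sketch below does the same loop in Python with the standard-library `tarfile` module. It does not parse `ls` output, and it creates parent directories the way `mkdir -p` does. The helper name `extract_shards` is illustrative and not part of the repository.

```python
import tarfile
from pathlib import Path


def extract_shards(part_dir, out_root):
    """Extract every <part_dir>/*.tar into <out_root>/<part_dir name>/<archive stem>/.

    Mirrors the shell loop: one output directory per archive, named after the
    archive without its .tar suffix. Returns the directories that were filled.
    """
    part_dir, out_root = Path(part_dir), Path(out_root)
    created = []
    for tar_path in sorted(part_dir.glob("*.tar")):
        dest = out_root / part_dir.name / tar_path.stem
        dest.mkdir(parents=True, exist_ok=True)  # equivalent of `mkdir -p`
        with tarfile.open(tar_path) as tf:
            if hasattr(tarfile, "data_filter"):
                # Pythons with extraction filters: reject absolute paths,
                # links escaping dest, device files, and similar entries.
                tf.extractall(dest, filter="data")
            else:
                tf.extractall(dest)  # older Pythons: only extract trusted archives
        created.append(dest)
    return created
```

For the script's layout, the call would be `extract_shards("image_part12", "./project/dataset/laion_chinese_cwf")`. This produces the same `.../laion_chinese_cwf/image_part12/<shard>/` directories as `${i%.tar}`.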
--------------------------------------------------------------------------------
/text-image/save_hg_ckpt.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Convert Roberta-base to Hugging Face format"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import torch\n",
17 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n",
18 | "\n",
19 | "# NOTE: checkpoint trained with open_clip on the Chinese pretrained model\n",
20 | "taiyi_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_4.pt'\n",
21 | "bertconfig = BertConfig.from_pretrained(\"hfl/chinese-roberta-wwm-ext\", num_labels=512)\n",
22 | "my_transformer = BertForSequenceClassification.from_pretrained(\"hfl/chinese-roberta-wwm-ext\", config=bertconfig)\n",
23 | "mytokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n",
24 | "\n",
25 | "# NOTE: the keys must be renamed before loading\n",
26 | "state_dict_of_bert = torch.load(taiyi_path, map_location='cpu')['state_dict']\n",
27 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n",
28 | "my_transformer.load_state_dict(bert_weights)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# Save both the model and the tokenizer vocabulary, then upload them to Hugging Face\n",
38 | "my_transformer.save_pretrained('./CLIP-roberta')\n",
39 | "mytokenizer.save_pretrained('./CLIP-roberta')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "total = sum([param.nelement() for param in my_transformer.parameters()])\n",
49 | "print(\"Number of parameters: %.2fM\" % (total/1e6))"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "# Roberta-large version"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "import torch\n",
66 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n",
67 | "\n",
68 | "# NOTE: checkpoint trained on the Chinese pretrained model\n",
69 | "taiyi_path = './open_clip/src/logs/2022_07_18-18_39_51-model_ViT-L-14-lr_1e-05-b_224-j_8-p_amp/checkpoints/epoch_7.pt'\n",
70 | "bertconfig = BertConfig.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\", num_labels=768)\n",
71 | "my_transformer = BertForSequenceClassification.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\", config=bertconfig)\n",
72 | "mytokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\")\n",
73 | "\n",
74 | "\n",
75 | "state_dict_of_bert = torch.load(taiyi_path, map_location='cpu')['state_dict']\n",
76 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n",
77 | "my_transformer.load_state_dict(bert_weights)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "# Save both the model and the tokenizer vocabulary, then upload them to Hugging Face\n",
87 | "my_transformer.save_pretrained('./Taiyi-CLIP-Roberta-large-326M-Chinese')\n",
88 | "mytokenizer.save_pretrained('./Taiyi-CLIP-Roberta-large-326M-Chinese')"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "total = sum([param.nelement() for param in my_transformer.parameters()])\n",
98 | "print(\"Number of parameters: %.2fM\" % (total/1e6))\n"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "from PIL import Image\n",
108 | "import requests\n",
109 | "import clip\n",
110 | "import torch\n",
111 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n",
112 | "from transformers import CLIPProcessor, CLIPModel\n",
113 | "import numpy as np\n",
114 | "\n",
115 | "query_texts = [\"一只猫\", \"一只狗\", '两只猫', '两只老虎', '一只老虎']  # input texts ('a cat', 'a dog', 'two cats', 'two tigers', 'a tiger'); replace freely\n",
116 | "# Load the Taiyi Chinese text encoder\n",
117 | "text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n",
118 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval()\n",
119 | "text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']\n",
120 | "\n",
121 | "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"  # replace with the URL of any image\n",
122 | "# Load the CLIP image encoder\n",
123 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\") \n",
124 | "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
125 | "image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors=\"pt\")\n",
126 | "\n",
127 | "with torch.no_grad():\n",
128 | " image_features = clip_model.get_image_features(**image)\n",
129 | " text_features = text_encoder(text).logits\n",
130 | " # L2-normalize the features\n",
131 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n",
132 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n",
133 | " # Cosine similarity, scaled by logit_scale (the learned scaling factor)\n",
134 | " logit_scale = clip_model.logit_scale.exp()\n",
135 | " logits_per_image = logit_scale * image_features @ text_features.t()\n",
136 | " logits_per_text = logits_per_image.t()\n",
137 | " probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
138 | " print(np.around(probs, 3))\n"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "# ViT-H with the dimension-matched Roberta-Large"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "import torch\n",
155 | "from transformers import BertModel, BertTokenizer\n",
156 | "\n",
157 | "# NOTE load from local path\n",
158 | "local_path = './scripts_t2i/open_clip_new/src/logs/2022_09_16-23_03_14-model_ViT-H-14-lr_5e-05-b_256-j_32-p_amp/checkpoints/epoch_21.pt'\n",
159 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\").cuda().eval()\n",
160 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n",
161 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n",
162 | "text_encoder.load_state_dict(bert_weights)\n",
163 | "tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "# Save both the model and the tokenizer vocabulary, then upload them to Hugging Face\n",
173 | "text_encoder.save_pretrained('./fengshen/Taiyi-CLIP-Roberta-326M-ViT-H-Chinese')\n",
174 | "tokenizer.save_pretrained('./fengshen/Taiyi-CLIP-Roberta-326M-ViT-H-Chinese')"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "# ViT-L --- Roberta-base"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "import torch\n",
191 | "from transformers import BertModel, BertTokenizer\n",
192 | "\n",
193 | "# NOTE load from local path\n",
194 | "local_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_4.pt'\n",
195 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext\").cuda().eval()\n",
196 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n",
197 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n",
198 | "text_encoder.load_state_dict(bert_weights)\n",
199 | "tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "# Save both the model and the tokenizer vocabulary, then upload them to Hugging Face\n",
209 | "text_encoder.save_pretrained('./project/temp_weights/vit-l-roberta-base')\n",
210 | "tokenizer.save_pretrained('./project/temp_weights/vit-l-roberta-base')"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "total = sum([param.nelement() for param in text_encoder.parameters()])\n",
220 | "print(\"Number of parameters: %.2fM\" % (total/1e6))"
221 | ]
222 | }
223 | ],
224 | "metadata": {
225 | "kernelspec": {
226 | "display_name": "Python 3.9.13 ('base')",
227 | "language": "python",
228 | "name": "python3"
229 | },
230 | "language_info": {
231 | "codemirror_mode": {
232 | "name": "ipython",
233 | "version": 3
234 | },
235 | "file_extension": ".py",
236 | "mimetype": "text/x-python",
237 | "name": "python",
238 | "nbconvert_exporter": "python",
239 | "pygments_lexer": "ipython3",
240 | "version": "3.9.13"
241 | },
242 | "orig_nbformat": 4,
243 | "vscode": {
244 | "interpreter": {
245 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa"
246 | }
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 2
251 | }
252 |
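Every conversion cell above follows the same pattern. It takes the open_clip checkpoint's `state_dict` and keeps only the text-encoder entries, which sit under `module.transformer.` in these checkpoints. The `module.` part is what `DistributedDataParallel` adds when a wrapped model is saved. The cell then strips that prefix and loads the result into a fresh Hugging Face model. The sketch below turns that remapping into a reusable helper. The notebook uses a substring test plus `str.replace`; this variant deliberately matches the prefix only at the start of each key. It runs on a toy `nn.Linear` "checkpoint" instead of a real one. The names `extract_prefixed` and `count_params_millions` are illustrative, not part of the repository.

```python
from collections import OrderedDict

import torch
from torch import nn


def extract_prefixed(state_dict, prefix="module.transformer."):
    """Return the entries whose key starts with `prefix`, with the prefix removed."""
    return OrderedDict(
        (k[len(prefix):], v) for k, v in state_dict.items() if k.startswith(prefix)
    )


def count_params_millions(model):
    """Parameter count in millions, as printed by the notebook cells."""
    return sum(p.numel() for p in model.parameters()) / 1e6


# Toy checkpoint: a text tower saved under "module.transformer." plus an
# unrelated visual-tower entry that must be filtered out.
text_tower = nn.Linear(4, 2)
checkpoint = {"state_dict": {
    **{f"module.transformer.{k}": v for k, v in text_tower.state_dict().items()},
    "module.visual.proj": torch.zeros(3),
}}

target = nn.Linear(4, 2)
weights = extract_prefixed(checkpoint["state_dict"])
result = target.load_state_dict(weights)  # strict=True: raises on any key mismatch
print(sorted(weights), count_params_millions(target))
```

Keeping `strict=True` (the default) makes a wrong prefix fail loudly instead of silently leaving the model with its pretrained, unaligned weights.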
--------------------------------------------------------------------------------
/text-image/zeroshot_retrieval_evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "def get_metrics(image_features, text_features, labels, logit_scale):\n",
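10 | " # Compute similarities; supports multiple positives per item (e.g. an image with several captions).\n",
11 | " # This matters for img2txt, because one image may correspond to several texts.\n",
12 | " # It is not needed for txt2img (each text usually has exactly one matching image).\n",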
13 | " metrics = {}\n",
14 | " logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()\n",
15 | " logits_per_text = logits_per_image.t().detach().cpu()\n",
16 | "\n",
17 | " logits = {\"image_to_text\": logits_per_image, \"text_to_image\": logits_per_text}\n",
18 | "\n",
19 | " label2idx = {}  # map each label to the indices where it occurs\n",
20 | " repeat_id = []\n",
21 | " for i, label in enumerate(labels):\n",
22 | " if label not in label2idx:\n",
23 | " label2idx[label] = [i]\n",
24 | " else:\n",
25 | " # This label was seen before: record the index so its score can be suppressed when computing txt2img.\n",
26 | " label2idx[label].append(i)\n",
27 | " repeat_id.append(i)\n",
28 | " # print(label2idx) # 标注了每个label的idx\n",
29 | "\n",
30 | " # print('repeat_id:', repeat_id)\n",
31 | " ground_truth = [label2idx[label] for label in labels]\n",
32 | " # print(ground_truth)\n",
33 | "\n",
34 | " for name, logit in logits.items():\n",
35 | " # print(name, logit.shape)\n",
36 | " if name == 'text_to_image':\n",
37 | " logit[:, repeat_id] -= 1e8  # suppress these scores (repeated images are effectively ignored)\n",
38 | " r1_stat, r5_stat, r10_stat = [], [], []\n",
39 | " ranking = torch.argsort(logit, descending=True) # index of the largest element to the smallest\n",
40 | " # print(name, ranking[:, :10])\n",
41 | " for i, each_query in enumerate(ranking[:, :10]):\n",
42 | " for j, q in enumerate(each_query):\n",
43 | " if q in ground_truth[i]:\n",
44 | " if j == 0:\n",
45 | " r1_stat.append(1)\n",
46 | " r5_stat.append(1)\n",
47 | " r10_stat.append(1)\n",
48 | " break\n",
49 | " if j < 5:\n",
50 | " r5_stat.append(1)\n",
51 | " r10_stat.append(1)\n",
52 | " break\n",
53 | " if j < 10:\n",
54 | " r10_stat.append(1)\n",
55 | " break\n",
56 | " print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')\n"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "# COCO-CN"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "from transformers import BertTokenizer\n",
73 | "from torch.utils.data import Dataset, DataLoader\n",
74 | "import os\n",
75 | "from PIL import Image\n",
76 | "class COCO_CN(Dataset):\n",
77 | " def __init__(self, img_root_path='./dataset/coco', \\\n",
78 | " test_img_path='./dataset/coco/coco-cn-version1805v1.1/coco-cn_test.txt', \\\n",
79 | " annot_path = './dataset/coco/coco-cn-version1805v1.1/imageid.human-written-caption.txt', \\\n",
80 | " transform=None):\n",
81 | " self.images = []\n",
82 | " self.captions = []\n",
83 | " self.labels = []\n",
84 | " self.root = img_root_path\n",
85 | " \n",
86 | " test_path = dict()\n",
87 | " with open(test_img_path, 'r') as f:\n",
88 | " for line in f:\n",
89 | " line = line.strip()\n",
90 | " if line not in test_path:\n",
91 | " test_path[line] = 1\n",
92 | " # print(test_path)\n",
93 | "\n",
94 | " with open(annot_path, 'r') as f:\n",
95 | " for line in f:\n",
96 | " line = line.strip().split('\\t')\n",
97 | " key, caption = line[0].split('#')[0], line[1]\n",
98 | " # NOTE: keep only the test-set images\n",
99 | " if key not in test_path:\n",
100 | " continue\n",
101 | " # if line[0].split('#')[-1] != '0':\n",
102 | " # # print(key, line[0].split('#')[-1])\n",
103 | " # continue # keep only one caption per image\n",
104 | " img_path = key + '.jpg'\n",
105 | "\n",
106 | " if 'train' in img_path:\n",
107 | " self.images.append(os.path.join('train2014' ,img_path) )\n",
108 | " else:\n",
109 | " self.images.append(os.path.join('val2014' ,img_path) )\n",
110 | " self.captions.append(caption)\n",
111 | " self.labels.append(key)\n",
112 | " self.transforms = transform\n",
113 | " self.tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n",
114 | "\n",
115 | " # NOTE: large model\n",
116 | " self.context_length = 77\n",
117 | "\n",
118 | " def __len__(self):\n",
119 | " return len(self.images)\n",
120 | "\n",
121 | " def __getitem__(self, idx):\n",
122 | " img_path = str(self.images[idx])\n",
123 | " image = self.transforms(Image.open( os.path.join(self.root, img_path ))) \n",
124 | " text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]\n",
125 | " label = self.labels[idx]\n",
126 | " return image, text, label\n",
127 | "\n",
128 | "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n",
129 | " CenterCrop\n",
130 | "def _convert_to_rgb(image):\n",
131 | " return image.convert('RGB')\n",
132 | "\n",
133 | "def image_transform(\n",
134 | " image_size: int,\n",
135 | " is_train: bool,\n",
136 | " mean=(0.48145466, 0.4578275, 0.40821073),\n",
137 | " std=(0.26862954, 0.26130258, 0.27577711)\n",
138 | "):\n",
139 | " normalize = Normalize(mean=mean, std=std)\n",
140 | " if is_train:\n",
141 | " return Compose([\n",
142 | " RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),\n",
143 | " _convert_to_rgb,\n",
144 | " ToTensor(),\n",
145 | " normalize,\n",
146 | " ])\n",
147 | " else:\n",
148 | " return Compose([\n",
149 | " Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n",
150 | " CenterCrop(image_size),\n",
151 | " _convert_to_rgb,\n",
152 | " ToTensor(),\n",
153 | " normalize,\n",
154 | " ])\n",
155 | "\n",
156 | "val_transform = image_transform(224, False)\n",
157 | "dataset = COCO_CN(transform = val_transform)\n",
158 | "dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "len(dataset)"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "from transformers import BertTokenizer\n",
177 | "from transformers import BertForSequenceClassification\n",
178 | "from transformers import CLIPModel\n",
179 | "import torch\n",
180 | "# NOTE load model\n",
181 | "\n",
182 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese\").cuda().eval()\n",
183 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").cuda().eval() \n",
184 | "\n",
185 | "# text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").cuda().eval()\n",
186 | "# clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").cuda().eval() \n",
187 | "\n",
188 | "\n",
189 | "all_img_features, all_text_features, all_labels = [], [], []\n",
190 | "with torch.no_grad():\n",
191 | " for i, data in enumerate(dataloader):\n",
192 | " images, captions, labels = data\n",
193 | " images = images.cuda()\n",
194 | " captions = captions.cuda()\n",
195 | " all_labels.extend(labels)\n",
196 | " # print(images.shape, captions.shape, labels)\n",
197 | "\n",
198 | " image_features = clip_model.get_image_features(images)\n",
199 | " text_features = text_encoder(captions).logits\n",
200 | " # L2-normalize the features\n",
201 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n",
202 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n",
203 | " all_img_features.append(image_features)\n",
204 | " all_text_features.append(text_features)\n",
205 | " # if i == 10:\n",
206 | " # break\n",
207 | " img_features = torch.cat(all_img_features)\n",
208 | " text_features = torch.cat(all_text_features)\n",
209 | " print(img_features.shape, text_features.shape, len(all_labels))\n",
210 | "get_metrics(img_features, text_features, all_labels, 100) "
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": [
217 | "# flickr30k-CNA"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "from transformers import BertTokenizer\n",
227 | "from torch.utils.data import Dataset, DataLoader\n",
228 | "import os\n",
229 | "from PIL import Image\n",
230 | "class flickr30k_CNA(Dataset):\n",
231 | " def __init__(self, img_root_path='./dataset/mm_data/Flickr30k-CNA/flickr30k/images', \\\n",
232 | " text_annot_path='./dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt', \\\n",
233 | " transform=None):\n",
234 | " self.images = []\n",
235 | " self.captions = []\n",
236 | " self.labels = []\n",
237 | " self.root = img_root_path\n",
238 | " with open(text_annot_path, 'r') as f:\n",
239 | " for line in f:\n",
240 | " line = line.strip().split('\\t')\n",
241 | " key, caption = line[0].split('#')[0], line[1]\n",
242 | " img_path = key + '.jpg'\n",
243 | " self.images.append(img_path)\n",
244 | " self.captions.append(caption)\n",
245 | " self.labels.append(key)\n",
246 | " self.transforms = transform\n",
247 | " self.tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n",
248 | "\n",
249 | " # NOTE: large model\n",
250 | " self.context_length = 77\n",
251 | "\n",
252 | " def __len__(self):\n",
253 | " return len(self.images)\n",
254 | "\n",
255 | " def __getitem__(self, idx):\n",
256 | " img_path = str(self.images[idx])\n",
257 | " image = self.transforms(Image.open( os.path.join(self.root, img_path ))) \n",
258 | " text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]\n",
259 | " label = self.labels[idx]\n",
260 | " return image, text, label\n",
261 | "\n",
262 | "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n",
263 | " CenterCrop\n",
264 | "def _convert_to_rgb(image):\n",
265 | " return image.convert('RGB')\n",
266 | "\n",
267 | "def image_transform(\n",
268 | " image_size: int,\n",
269 | " is_train: bool,\n",
270 | " mean=(0.48145466, 0.4578275, 0.40821073),\n",
271 | " std=(0.26862954, 0.26130258, 0.27577711)\n",
272 | "):\n",
273 | " normalize = Normalize(mean=mean, std=std)\n",
274 | " if is_train:\n",
275 | " return Compose([\n",
276 | " RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),\n",
277 | " _convert_to_rgb,\n",
278 | " ToTensor(),\n",
279 | " normalize,\n",
280 | " ])\n",
281 | " else:\n",
282 | " return Compose([\n",
283 | " Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n",
284 | " CenterCrop(image_size),\n",
285 | " _convert_to_rgb,\n",
286 | " ToTensor(),\n",
287 | " normalize,\n",
288 | " ])\n",
289 | "\n",
290 | "val_transform = image_transform(224, False)\n",
291 | "img_root = './dataset/mm_data/Flickr30k-CNA/flickr30k/images'\n",
292 | "text_annot_path = './dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt'\n",
293 | "dataset = flickr30k_CNA(img_root, text_annot_path, val_transform)\n",
294 | "dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "from transformers import BertTokenizer\n",
304 | "from transformers import BertForSequenceClassification\n",
305 | "from transformers import CLIPModel\n",
306 | "import torch\n",
307 | "# text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese\").cuda().eval()\n",
308 | "# clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").cuda().eval() \n",
309 | "\n",
310 | "# NOTE large\n",
311 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").cuda().eval()\n",
312 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").cuda().eval() "
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "all_img_features, all_text_features, all_labels = [], [], []\n",
322 | "with torch.no_grad():\n",
323 | " for i, data in enumerate(dataloader):\n",
324 | " images, captions, labels = data\n",
325 | " images = images.cuda()\n",
326 | " captions = captions.cuda()\n",
327 | " all_labels.extend(labels)\n",
328 | " # print(images.shape, captions.shape, labels)\n",
329 | "\n",
330 | " image_features = clip_model.get_image_features(images)\n",
331 | " text_features = text_encoder(captions).logits\n",
332 | " # L2-normalize the features\n",
333 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n",
334 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n",
335 | " all_img_features.append(image_features)\n",
336 | " all_text_features.append(text_features)\n",
337 | " # if i == 10:\n",
338 | " # break\n",
339 | " img_features = torch.cat(all_img_features)\n",
340 | " text_features = torch.cat(all_text_features)\n",
341 | " print(img_features.shape, text_features.shape, len(all_labels))"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {},
348 | "outputs": [],
349 | "source": [
350 | "get_metrics(img_features, text_features, all_labels, 100)  # Flickr: only the first occurrence of each of the 1000 images is ranked; the rest are repeats (each image has 5 captions)"
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {},
356 | "source": [
357 | "# non-classification"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "# NOTE load from local path\n",
367 | "from transformers import BertModel\n",
368 | "local_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_5.pt'\n",
369 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext\").cuda().eval()\n",
370 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n",
371 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n",
372 | "text_encoder.load_state_dict(bert_weights)"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "all_img_features, all_text_features, all_labels = [], [], []\n",
382 | "with torch.no_grad():\n",
383 | "    for i, data in enumerate(dataloader):\n",
384 | "        images, captions, labels = data\n",
385 | "        images = images.cuda()\n",
386 | "        captions = captions.cuda()\n",
387 | "        all_labels.extend(labels)\n",
388 | "        # print(images.shape, captions.shape, labels)\n",
389 | "\n",
390 | "        image_features = clip_model.get_image_features(images)\n",
391 | "        text_features = text_encoder(captions).pooler_output  # same as [1] for BertModel\n",
392 | "        # L2-normalize so that dot products are cosine similarities\n",
393 | "        image_features = image_features / image_features.norm(dim=1, keepdim=True)\n",
394 | "        text_features = text_features / text_features.norm(dim=1, keepdim=True)\n",
395 | "        all_img_features.append(image_features)\n",
396 | "        all_text_features.append(text_features)\n",
397 | "        # if i == 10:\n",
398 | "        #     break\n",
399 | "    img_features = torch.cat(all_img_features)\n",
400 | "    text_features = torch.cat(all_text_features)\n",
401 | "    print(img_features.shape, text_features.shape, len(all_labels))\n",
402 | "get_metrics(img_features, text_features, all_labels, 100)  # use only the first 1000 images; the rest are repeats (each Flickr image has 5 captions)"
403 | ]
404 | }
405 | ],
406 | "metadata": {
407 | "kernelspec": {
408 | "display_name": "Python 3.9.13 ('base')",
409 | "language": "python",
410 | "name": "python3"
411 | },
412 | "language_info": {
413 | "codemirror_mode": {
414 | "name": "ipython",
415 | "version": 3
416 | },
417 | "file_extension": ".py",
418 | "mimetype": "text/x-python",
419 | "name": "python",
420 | "nbconvert_exporter": "python",
421 | "pygments_lexer": "ipython3",
422 | "version": "3.9.13"
423 | },
424 | "orig_nbformat": 4,
425 | "vscode": {
426 | "interpreter": {
427 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa"
428 | }
429 | }
430 | },
431 | "nbformat": 4,
432 | "nbformat_minor": 2
433 | }
434 |
--------------------------------------------------------------------------------
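Both evaluation cells in the notebook above L2-normalize the image and text features and then pass them to `get_metrics`. That function is defined earlier in the notebook and is not reproduced here. The following is a rough, illustrative NumPy sketch of how recall@K for image-to-text and text-to-image retrieval can be computed from such normalized features.

Note that `recall_at_k` and `l2_normalize` are hypothetical helpers, not the notebook's `get_metrics`. The sketch also makes a simplifying assumption: exactly one matching caption per image, aligned by row index. The notebook's Flickr setup is different, since it has 5 captions per image and passes explicit labels.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm, like `x / x.norm(dim=1, keepdim=True)` in PyTorch."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def recall_at_k(img_feats: np.ndarray, txt_feats: np.ndarray, k: int) -> tuple[float, float]:
    """Hypothetical helper: return (image->text, text->image) recall@k.

    Assumes row i of `img_feats` and row i of `txt_feats` are the only
    matching pair, i.e. one caption per image.
    """
    img = l2_normalize(img_feats)
    txt = l2_normalize(txt_feats)
    sim = img @ txt.T  # cosine similarities, shape (n_images, n_texts)
    targets = np.arange(sim.shape[0])[:, None]  # correct match index per query

    # Indices of the k highest-scoring candidates per query (stable sort for deterministic ties).
    top_i2t = np.argsort(-sim, axis=1, kind="stable")[:, :k]
    top_t2i = np.argsort(-sim.T, axis=1, kind="stable")[:, :k]

    # A query is a hit if its correct match appears anywhere in its top k.
    i2t = float((top_i2t == targets).any(axis=1).mean())
    t2i = float((top_t2i == targets).any(axis=1).mean())
    return i2t, t2i


feats = np.eye(4)
print(recall_at_k(feats, feats, k=1))  # (1.0, 1.0)
swapped = feats[[1, 0, 2, 3]]  # captions 0 and 1 exchanged
print(recall_at_k(feats, swapped, k=1))  # (0.5, 0.5)
```

Because both feature sets are unit-normalized, the matrix product is exactly the cosine-similarity matrix. Ranking each row (or column) therefore gives the retrieval order for each image (or caption) query. To handle several captions per image, the single-index hit test would need to compare against each query's full label set instead.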