├── .gitignore ├── LICENSE ├── README.md ├── detection ├── coco2yolo.py ├── coco_eval.py ├── vis_yolo_gt_dt.py └── yolo2coco.py └── text-image ├── convert_diffusers_to_original_stable_diffusion.py ├── data_filter ├── data_filter_demo.ipynb ├── wukong_filter.py └── wukong_reader.py ├── fid_clip_score ├── .gitignore ├── coco_sample_generator.py ├── compute_fid.ipynb ├── fid_clip_coco.ipynb ├── fid_clip_coco_cn.ipynb ├── run_generator.sh └── run_generator_cn.sh ├── imagenet_CN_zeroshot_data.py ├── iterable_tar_unzip.sh ├── save_hg_ckpt.ipynb └── zeroshot_retrieval_evaluation.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Weifeng Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Some Scripts For DEEP LEARNING
3 | 
4 | # 1. detection
5 | ## yolo2coco.py
6 | Converts a YOLO-format dataset into COCO format. `$ROOT_PATH` is the root directory, and the data must be organized as follows:
7 | 
8 | ```bash
9 | └── $ROOT_PATH
10 | 
11 | ├── classes.txt
12 | 
13 | ├── images
14 | 
15 | └── labels
16 | ```
17 | 
18 | - `classes.txt` declares the classes, one class per line.
19 | 
20 | - `images` contains all the images (`png` and `jpg` formats are currently supported).
21 | 
22 | - `labels` contains all the labels (`txt` files with the **same name** as the corresponding image).
23 | 
24 | Once the folders are set up, run `python yolo2coco.py --root_dir $ROOT_PATH`; the generated `annotations` folder will then appear.
25 | 
26 | **Arguments**
27 | - `--root_dir` path to the root directory `$ROOT_PATH`.
28 | - `--save_path` if the dataset is not split, this names the output file; defaults to `train.json`.
29 | - `--random_split` random split; if specified, `train.json`, `val.json` and `test.json` are written to the `annotations` folder (default split ratio 8:1:1).
30 | - `--split_by_file` custom split; if specified, `train.json`, `val.json` and `test.json` are written to the `annotations` folder. `./train.txt ./val.txt ./test.txt` must exist under `$ROOT_PATH`; these three files define the training, validation and test sets. **Note**: they should contain image file names, not absolute image paths. (You can also adapt the reading logic at line 43 of the script; for simplicity, keeping all images in one place is recommended.)
31 | 
32 | 
33 | ## coco2yolo.py
34 | 
35 | Reads COCO-format JSON annotations and writes labels that YOLO can train on.
36 | 
37 | **Note that the category ids of the official COCO2017 dataset are not contiguous**, which causes problems when YOLO reads them, so they need to be remapped; this script maps the ids, in ascending order, to 0~79. (Custom datasets are remapped as well.)
38 | 
39 | Run: `python coco2yolo.py --json_path $JSON_FILE_PATH --save_path $LABEL_SAVE_PATH`
40 | 
41 | - `$JSON_FILE_PATH` is the path to the JSON file.
42 | - `$LABEL_SAVE_PATH` is the output directory (defaults to `./labels` under the working directory).
43 | 
44 | 
45 | 
46 | 
47 | 
48 | ## vis_yolo_gt_dt.py
49 | Draws GT and predictions together in the same image. `$DT_DIR` is the directory of prediction labels, which must have the same file names as the GT labels. `$ROOT_PATH` layout:
50 | 
51 | ```bash
52 | └── $ROOT_PATH
53 | 
54 | ├── classes.txt
55 | 
56 | ├── images
57 | 
58 | └── labels
59 | ```
60 | 
61 | Run `python vis_yolo_gt_dt.py --root $ROOT_PATH --dt $DT_DIR`; the results are written to the `outputs` folder.
62 | 
63 | - `classes.txt` and `images` are required.
64 | - `labels` is optional; without it, only the `$DT_DIR` predictions are shown.
65 | - If `$DT_DIR` is not given, only the ground-truth labels are shown.
66 | 
67 | ## coco_eval.py
68 | 
69 | Evaluates detection results, targeting the output of **yolov5** (the `--save-json` flag of its test script generates `best_predictions.json`), which is not directly compatible with cocoapi, so this script adapts it. Run:
70 | 
71 | `python coco_eval.py --gt $GT_PATH --dt $DT_PATH --yolov5`
72 | 
73 | - `--gt` the test-set ground truth in JSON format; if you don't have it, convert with `yolo2coco.py` above.
74 | - `--dt` the predictions of the detection network, loaded with cocoapi's `loadRes`, so they must be in the corresponding format.
75 | - `--yolov5` converts the results produced by the official yolov5 code into cocoapi-compatible results.
76 | 
77 | # 2. text-image
78 | ## zeroshot_retrieval_evaluation.ipynb
79 | Evaluation metrics for retrieval models (top-K recall); supports many-to-many retrieval (e.g. one image with multiple captions, or one text matching multiple images).
80 | ## fid_clip_score
81 | Plots the FID-CLIP Score curve for text2image models.
-------------------------------------------------------------------------------- /detection/coco2yolo.py: --------------------------------------------------------------------------------
1 | """
2 | 2021/1/24
3 | COCO 格式的数据集转化为 YOLO 格式的数据集,源代码采取遍历方式,太慢,
4 | 这里改进了一下时间复杂度,从O(nm)改为O(n+m),但是牺牲了一些内存占用
5 | --json_path 输入的json文件路径
6 | --save_path 保存的文件夹名字,默认为当前目录下的labels。
7 | """
8 | 
9 | import os
10 | import json
11 | from tqdm import tqdm
12 | import argparse
13 | 
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--json_path', default='./instances_val2017.json',type=str, help="input: coco format(json)")
16 | parser.add_argument('--save_path', default='./labels', type=str, help="specify where to save the output dir of labels")
17 | arg = parser.parse_args()
18 | 
19 | def convert(size, box):
20 | dw = 1. / (size[0])
21 | dh = 1.
/ (size[1]) 22 | x = box[0] + box[2] / 2.0 23 | y = box[1] + box[3] / 2.0 24 | w = box[2] 25 | h = box[3] 26 | 27 | x = x * dw 28 | w = w * dw 29 | y = y * dh 30 | h = h * dh 31 | return (x, y, w, h) 32 | 33 | if __name__ == '__main__': 34 | json_file = arg.json_path # COCO Object Instance 类型的标注 35 | ana_txt_save_path = arg.save_path # 保存的路径 36 | 37 | data = json.load(open(json_file, 'r')) 38 | if not os.path.exists(ana_txt_save_path): 39 | os.makedirs(ana_txt_save_path) 40 | 41 | id_map = {} # coco数据集的id不连续!重新映射一下再输出! 42 | for i, category in enumerate(data['categories']): 43 | id_map[category['id']] = i 44 | 45 | # 通过事先建表来降低时间复杂度 46 | max_id = 0 47 | for img in data['images']: 48 | max_id = max(max_id, img['id']) 49 | # 注意这里不能写作 [[]]*(max_id+1),否则列表内的空列表共享地址 50 | img_ann_dict = [[] for i in range(max_id+1)] 51 | for i, ann in enumerate(data['annotations']): 52 | img_ann_dict[ann['image_id']].append(i) 53 | 54 | for img in tqdm(data['images']): 55 | filename = img["file_name"] 56 | img_width = img["width"] 57 | img_height = img["height"] 58 | img_id = img["id"] 59 | head, tail = os.path.splitext(filename) 60 | ana_txt_name = head + ".txt" # 对应的txt名字,与jpg一致 61 | f_txt = open(os.path.join(ana_txt_save_path, ana_txt_name), 'w') 62 | '''for ann in data['annotations']: 63 | if ann['image_id'] == img_id: 64 | box = convert((img_width, img_height), ann["bbox"]) 65 | f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))''' 66 | # 这里可以直接查表而无需重复遍历 67 | for ann_id in img_ann_dict[img_id]: 68 | ann = data['annotations'][ann_id] 69 | box = convert((img_width, img_height), ann["bbox"]) 70 | f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3])) 71 | f_txt.close() 72 | 73 | # 旧版,很慢hhh 74 | # """ 75 | # COCO 格式的数据集转化为 YOLO 格式的数据集 76 | # --json_path 输入的json文件路径 77 | # --save_path 保存的文件夹名字,默认为当前目录下的labels。 78 | # """ 79 | 80 | # import os 81 | # import json 82 | # from tqdm import tqdm 83 | # import argparse 84 | 85 | # parser = argparse.ArgumentParser() 86 | # parser.add_argument('--json_path', default='./instances_val2017.json',type=str, help="input: coco format(json)") 87 | # parser.add_argument('--save_path', default='./labels', type=str, help="specify where to save the output dir of labels") 88 | # arg = parser.parse_args() 89 | 90 | # def convert(size, box): 91 | # dw = 1. / (size[0]) 92 | # dh = 1. / (size[1]) 93 | # x = box[0] + box[2] / 2.0 94 | # y = box[1] + box[3] / 2.0 95 | # w = box[2] 96 | # h = box[3] 97 | 98 | # x = x * dw 99 | # w = w * dw 100 | # y = y * dh 101 | # h = h * dh 102 | # return (x, y, w, h) 103 | 104 | # if __name__ == '__main__': 105 | # json_file = arg.json_path # COCO Object Instance 类型的标注 106 | # ana_txt_save_path = arg.save_path # 保存的路径 107 | 108 | # data = json.load(open(json_file, 'r')) 109 | # if not os.path.exists(ana_txt_save_path): 110 | # os.makedirs(ana_txt_save_path) 111 | 112 | # id_map = {} # coco数据集的id不连续!重新映射一下再输出! 
113 | # with open(os.path.join(ana_txt_save_path, 'classes.txt'), 'w') as f: 114 | # # 写入classes.txt 115 | # for i, category in enumerate(data['categories']): 116 | # f.write(f"{category['name']}\n") 117 | # id_map[category['id']] = i 118 | # # print(id_map) 119 | 120 | # for img in tqdm(data['images']): 121 | # filename = img["file_name"] 122 | # img_width = img["width"] 123 | # img_height = img["height"] 124 | # img_id = img["id"] 125 | # head, tail = os.path.splitext(filename) 126 | # ana_txt_name = head + ".txt" # 对应的txt名字,与jpg一致 127 | # f_txt = open(os.path.join(ana_txt_save_path, ana_txt_name), 'w') 128 | # for ann in data['annotations']: 129 | # if ann['image_id'] == img_id: 130 | # box = convert((img_width, img_height), ann["bbox"]) 131 | # f_txt.write("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3])) 132 | # f_txt.close() 133 | -------------------------------------------------------------------------------- /detection/coco_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pycocotools.coco import COCO 4 | from pycocotools.cocoeval import COCOeval 5 | import os 6 | import time 7 | 8 | def transform_yolov5_result(result, filename2id): 9 | f = open(result ,'r',encoding='utf-8') 10 | dts = json.load(f) 11 | output_dts = [] 12 | for dt in dts: 13 | dt['image_id'] = filename2id[dt['image_id']+'.jpg'] 14 | dt['category_id'] # id对应好,coco格式和yolo格式的category_id可能不同。 15 | output_dts.append(dt) 16 | with open('temp.json', 'w') as f: 17 | json.dump(output_dts, f) 18 | 19 | def coco_evaluate(gt_path, dt_path, yolov5_flag): 20 | cocoGt = COCO(gt_path) 21 | imgIds = cocoGt.getImgIds() 22 | gts = cocoGt.loadImgs(imgIds) 23 | filename2id = {} 24 | 25 | for gt in gts: 26 | filename2id[gt['file_name']] = gt['id'] 27 | print("NUM OF TEST IMAGES: ",len(filename2id)) 28 | 29 | if yolov5_flag: 30 | transform_yolov5_result(dt_path, filename2id) 31 | cocoDt = cocoGt.loadRes('temp.json') 32 | else: 33 | cocoDt = cocoGt.loadRes(dt_path) 34 | cocoEval = COCOeval(cocoGt, cocoDt, "bbox") 35 | cocoEval.evaluate() 36 | cocoEval.accumulate() 37 | cocoEval.summarize() 38 | if yolov5_flag: 39 | os.remove('temp.json') 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--gt", type=str, help="Assign the groud true path.", default=None) 44 | parser.add_argument("--dt", type=str, help="Assign the detection result path.", default=None) 45 | parser.add_argument("--yolov5",action='store_true',help="fix yolov5 output bug", default=None) 46 | 47 | args = parser.parse_args() 48 | gt_path = args.gt 49 | dt_path = args.dt 50 | if args.yolov5: 51 | coco_evaluate(gt_path, dt_path, True) 52 | else: 53 | coco_evaluate(gt_path, dt_path, False) 54 | -------------------------------------------------------------------------------- /detection/vis_yolo_gt_dt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from glob import glob 4 | import random 5 | import matplotlib.pyplot as plt 6 | import argparse 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--root',type=str ,default='', help="which should include ./images and ./labels and classes.txt") 12 | parser.add_argument('--dt',type=str ,default='' , help="yolo format results of detection, include ./labels") 13 | parser.add_argument('--conf' , type=float ,default=0.5, help="visulization 
conf thres") 14 | arg = parser.parse_args() 15 | 16 | colorlist = [] 17 | # 5^3种颜色。 18 | for i in range(25,256,50): 19 | for j in range(25,256,50): 20 | for k in range(25,256,50): 21 | colorlist.append((i,j,k)) 22 | random.shuffle(colorlist) 23 | 24 | def plot_bbox(img_path, img_dir, out_dir, gt=None ,dt=None, cls2label=None, line_thickness=None): 25 | img = cv2.imread(os.path.join(img_dir, img_path)) 26 | height, width,_ = img.shape 27 | tl = line_thickness or round(0.002 * (width + height) / 2) + 1 # line/font thickness 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | if gt: 30 | tf = max(tl - 1, 1) # font thickness 31 | with open(gt,'r') as f: 32 | annotations = f.readlines() 33 | # print(annotations) 34 | for ann in annotations: 35 | ann = list(map(float,ann.split())) 36 | ann[0] = int(ann[0]) 37 | # print(ann) 38 | cls,x,y,w,h = ann 39 | color = colorlist[cls] 40 | c1, c2 = (int((x-w/2)*width),int((y-h/2)*height)), (int((x+w/2)*width), int((y+h/2)*height)) 41 | cv2.rectangle(img, c1, c2, color, thickness=tl*2, lineType=cv2.LINE_AA) 42 | # 类别名称显示 43 | cv2.putText(img, str(cls2label[cls]), (c1[0], c1[1] - 2), 0, tl / 4, color, thickness=tf, lineType=cv2.LINE_AA) 44 | if dt: 45 | with open(dt,'r') as f: 46 | annotations = f.readlines() 47 | # print(annotations) 48 | for ann in annotations: 49 | ann = list(map(float,ann.split())) 50 | ann[0] = int(ann[0]) 51 | # print(ann) 52 | if len(ann) == 6: 53 | cls,x,y,w,h,conf = ann 54 | if conf < arg.conf: 55 | # thres = 0.5 56 | continue 57 | elif len(ann) == 5: 58 | cls,x,y,w,h = ann 59 | color = colorlist[len(colorlist) - cls - 1] 60 | 61 | c1, c2 = (int((x-w/2)*width), int((y-h/2)*height)), (int((x+w/2)*width), int((y+h/2)*height)) 62 | cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) 63 | 64 | # # cls label 65 | tf = max(tl - 1, 1) # font thickness 66 | t_size = cv2.getTextSize(cls2label[cls], 0, fontScale=tl / 3, thickness=tf)[0] 67 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 68 | # cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled 69 | if len(ann) == 6: 70 | cv2.putText(img, str(round(conf,2)), (c1[0], c1[1] - 2), 0, tl / 4, color, thickness=tf, lineType=cv2.LINE_AA) 71 | cv2.imwrite(os.path.join(out_dir,img_path),img) 72 | 73 | if __name__ == "__main__": 74 | root_path = arg.root 75 | pred_path = arg.dt 76 | img_dir = os.path.join(root_path,'images') 77 | GT_dir = os.path.join(root_path,'labels') 78 | DT_dir = os.path.join(pred_path) 79 | out_dir = os.path.join(root_path,'outputs') 80 | cls_dir = os.path.join(root_path,'classes.txt') 81 | cls_dict = {} 82 | 83 | if not os.path.exists(img_dir): 84 | raise Exception("image dir {} do not exist!".format(img_dir)) 85 | if not os.path.exists(cls_dir): 86 | raise Exception("class dir {} do not exist!".format(cls_dir)) 87 | else: 88 | with open(cls_dir,'r') as f: 89 | classes = f.readlines() 90 | for i in range(len(classes)): 91 | cls_dict[i] = classes[i].strip() 92 | print("class map:", cls_dict) 93 | if not os.path.exists(out_dir): 94 | os.mkdir(out_dir) 95 | if not os.path.exists(GT_dir): 96 | print(f"WARNNING: {GT_dir} ,GT NOT Available!") 97 | if not os.path.exists(DT_dir): 98 | print(f"WARNNING: {DT_dir} ,DT NOT Available!") 99 | for each_img in tqdm(os.listdir(img_dir)): 100 | gt = None 101 | dt = None 102 | if os.path.exists(os.path.join(GT_dir,each_img.replace('jpg','txt'))): 103 | gt = os.path.join(GT_dir,each_img.replace('jpg','txt')) 104 | if os.path.exists(os.path.join(DT_dir,each_img.replace('jpg','txt'))): 105 | dt = 
os.path.join(DT_dir,each_img.replace('jpg','txt')) 106 | 107 | plot_bbox(each_img, img_dir, out_dir, gt, dt, cls2label=cls_dict) 108 | -------------------------------------------------------------------------------- /detection/yolo2coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | YOLO 格式的数据集转化为 COCO 格式的数据集 3 | --root_dir 输入根路径 4 | --save_path 保存文件的名字(没有random_split时使用) 5 | --random_split 有则会随机划分数据集,然后再分别保存为3个文件。 6 | --split_by_file 按照 ./train.txt ./val.txt ./test.txt 来对数据集进行划分。 7 | """ 8 | 9 | import os 10 | import cv2 11 | import json 12 | from tqdm import tqdm 13 | from sklearn.model_selection import train_test_split 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--root_dir', default='./data',type=str, help="root path of images and labels, include ./images and ./labels and classes.txt") 18 | parser.add_argument('--save_path', type=str,default='./train.json', help="if not split the dataset, give a path to a json file") 19 | parser.add_argument('--random_split', action='store_true', help="random split the dataset, default ratio is 8:1:1") 20 | parser.add_argument('--split_by_file', action='store_true', help="define how to split the dataset, include ./train.txt ./val.txt ./test.txt ") 21 | 22 | arg = parser.parse_args() 23 | 24 | def train_test_val_split_random(img_paths,ratio_train=0.8,ratio_test=0.1,ratio_val=0.1): 25 | # 这里可以修改数据集划分的比例。 26 | assert int(ratio_train+ratio_test+ratio_val) == 1 27 | train_img, middle_img = train_test_split(img_paths,test_size=1-ratio_train, random_state=233) 28 | ratio=ratio_val/(1-ratio_train) 29 | val_img, test_img =train_test_split(middle_img,test_size=ratio, random_state=233) 30 | print("NUMS of train:val:test = {}:{}:{}".format(len(train_img), len(val_img), len(test_img))) 31 | return train_img, val_img, test_img 32 | 33 | def train_test_val_split_by_files(img_paths, root_dir): 34 | # 根据文件 train.txt, val.txt, test.txt(里面写的都是对应集合的图片名字) 来定义训练集、验证集和测试集 35 | phases = ['train', 'val', 'test'] 36 | img_split = [] 37 | for p in phases: 38 | define_path = os.path.join(root_dir, f'{p}.txt') 39 | print(f'Read {p} dataset definition from {define_path}') 40 | assert os.path.exists(define_path) 41 | with open(define_path, 'r') as f: 42 | img_paths = f.readlines() 43 | # img_paths = [os.path.split(img_path.strip())[1] for img_path in img_paths] # NOTE 取消这句备注可以读取绝对地址。 44 | img_split.append(img_paths) 45 | return img_split[0], img_split[1], img_split[2] 46 | 47 | 48 | def yolo2coco(arg): 49 | root_path = arg.root_dir 50 | print("Loading data from ",root_path) 51 | 52 | assert os.path.exists(root_path) 53 | originLabelsDir = os.path.join(root_path, 'labels') 54 | originImagesDir = os.path.join(root_path, 'images') 55 | with open(os.path.join(root_path, 'classes.txt')) as f: 56 | classes = f.read().strip().split() 57 | # images dir name 58 | indexes = os.listdir(originImagesDir) 59 | 60 | if arg.random_split or arg.split_by_file: 61 | # 用于保存所有数据的图片信息和标注信息 62 | train_dataset = {'categories': [], 'annotations': [], 'images': []} 63 | val_dataset = {'categories': [], 'annotations': [], 'images': []} 64 | test_dataset = {'categories': [], 'annotations': [], 'images': []} 65 | 66 | # 建立类别标签和数字id的对应关系, 类别id从0开始。 67 | for i, cls in enumerate(classes, 0): 68 | train_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'}) 69 | val_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'}) 70 | test_dataset['categories'].append({'id': i, 
'name': cls, 'supercategory': 'mark'}) 71 | 72 | if arg.random_split: 73 | print("spliting mode: random split") 74 | train_img, val_img, test_img = train_test_val_split_random(indexes,0.8,0.1,0.1) 75 | elif arg.split_by_file: 76 | print("spliting mode: split by files") 77 | train_img, val_img, test_img = train_test_val_split_by_files(indexes, root_path) 78 | else: 79 | dataset = {'categories': [], 'annotations': [], 'images': []} 80 | for i, cls in enumerate(classes, 0): 81 | dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'}) 82 | 83 | # 标注的id 84 | ann_id_cnt = 0 85 | for k, index in enumerate(tqdm(indexes)): 86 | # 支持 png jpg 格式的图片。 87 | txtFile = index.replace('images','txt').replace('.jpg','.txt').replace('.png','.txt') 88 | # 读取图像的宽和高 89 | im = cv2.imread(os.path.join(root_path, 'images/') + index) 90 | height, width, _ = im.shape 91 | if arg.random_split or arg.split_by_file: 92 | # 切换dataset的引用对象,从而划分数据集 93 | if index in train_img: 94 | dataset = train_dataset 95 | elif index in val_img: 96 | dataset = val_dataset 97 | elif index in test_img: 98 | dataset = test_dataset 99 | # 添加图像的信息 100 | dataset['images'].append({'file_name': index, 101 | 'id': k, 102 | 'width': width, 103 | 'height': height}) 104 | if not os.path.exists(os.path.join(originLabelsDir, txtFile)): 105 | # 如没标签,跳过,只保留图片信息。 106 | continue 107 | with open(os.path.join(originLabelsDir, txtFile), 'r') as fr: 108 | labelList = fr.readlines() 109 | for label in labelList: 110 | label = label.strip().split() 111 | x = float(label[1]) 112 | y = float(label[2]) 113 | w = float(label[3]) 114 | h = float(label[4]) 115 | 116 | # convert x,y,w,h to x1,y1,x2,y2 117 | H, W, _ = im.shape 118 | x1 = (x - w / 2) * W 119 | y1 = (y - h / 2) * H 120 | x2 = (x + w / 2) * W 121 | y2 = (y + h / 2) * H 122 | # 标签序号从0开始计算, coco2017数据集标号混乱,不管它了。 123 | cls_id = int(label[0]) 124 | width = max(0, x2 - x1) 125 | height = max(0, y2 - y1) 126 | dataset['annotations'].append({ 127 | 'area': width * height, 128 | 'bbox': [x1, y1, width, height], 129 | 'category_id': cls_id, 130 | 'id': ann_id_cnt, 131 | 'image_id': k, 132 | 'iscrowd': 0, 133 | # mask, 矩形是从左上角点按顺时针的四个顶点 134 | 'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]] 135 | }) 136 | ann_id_cnt += 1 137 | 138 | # 保存结果 139 | folder = os.path.join(root_path, 'annotations') 140 | if not os.path.exists(folder): 141 | os.makedirs(folder) 142 | if arg.random_split or arg.split_by_file: 143 | for phase in ['train','val','test']: 144 | json_name = os.path.join(root_path, 'annotations/{}.json'.format(phase)) 145 | with open(json_name, 'w') as f: 146 | if phase == 'train': 147 | json.dump(train_dataset, f) 148 | elif phase == 'val': 149 | json.dump(val_dataset, f) 150 | elif phase == 'test': 151 | json.dump(test_dataset, f) 152 | print('Save annotation to {}'.format(json_name)) 153 | else: 154 | json_name = os.path.join(root_path, 'annotations/{}'.format(arg.save_path)) 155 | with open(json_name, 'w') as f: 156 | json.dump(dataset, f) 157 | print('Save annotation to {}'.format(json_name)) 158 | 159 | if __name__ == "__main__": 160 | 161 | yolo2coco(arg) -------------------------------------------------------------------------------- /text-image/convert_diffusers_to_original_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 
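# Example invocation (illustrative paths; --model_path, --checkpoint_path and --half are the
# flags defined by this script's argparse further below):
#   python convert_diffusers_to_original_stable_diffusion.py \
#       --model_path ./path/to/diffusers_pipeline \
#       --checkpoint_path ./stable_diffusion.ckpt \
#       --half   # optional: save the merged state dict in fp16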
4 | 5 | import argparse 6 | import os.path as osp 7 | 8 | import torch 9 | 10 | 11 | # =================# 12 | # UNet Conversion # 13 | # =================# 14 | 15 | unet_conversion_map = [ 16 | # (stable-diffusion, HF Diffusers) 17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 21 | ("input_blocks.0.0.weight", "conv_in.weight"), 22 | ("input_blocks.0.0.bias", "conv_in.bias"), 23 | ("out.0.weight", "conv_norm_out.weight"), 24 | ("out.0.bias", "conv_norm_out.bias"), 25 | ("out.2.weight", "conv_out.weight"), 26 | ("out.2.bias", "conv_out.bias"), 27 | ] 28 | 29 | unet_conversion_map_resnet = [ 30 | # (stable-diffusion, HF Diffusers) 31 | ("in_layers.0", "norm1"), 32 | ("in_layers.2", "conv1"), 33 | ("out_layers.0", "norm2"), 34 | ("out_layers.3", "conv2"), 35 | ("emb_layers.1", "time_emb_proj"), 36 | ("skip_connection", "conv_shortcut"), 37 | ] 38 | 39 | unet_conversion_map_layer = [] 40 | # hardcoded number of downblocks and resnets/attentions... 41 | # would need smarter logic for other networks. 42 | for i in range(4): 43 | # loop over downblocks/upblocks 44 | 45 | for j in range(2): 46 | # loop over resnets/attentions for downblocks 47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 50 | 51 | if i < 3: 52 | # no attention layers in down_blocks.3 53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 56 | 57 | for j in range(3): 58 | # loop over resnets/attentions for upblocks 59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 62 | 63 | if i > 0: 64 | # no attention layers in up_blocks.0 65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 68 | 69 | if i < 3: 70 | # no downsample in down_blocks.3 71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 74 | 75 | # no upsample in up_blocks.3 76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 79 | 80 | hf_mid_atn_prefix = "mid_block.attentions.0." 81 | sd_mid_atn_prefix = "middle_block.1." 82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 83 | 84 | for j in range(2): 85 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 86 | sd_mid_res_prefix = f"middle_block.{2*j}." 87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 88 | 89 | 90 | def convert_unet_state_dict(unet_state_dict): 91 | # buyer beware: this is a *brittle* function, 92 | # and correct output requires that all of these pieces interact in 93 | # the exact order in which I have arranged them. 
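# Added note: the HF -> SD key mapping below is built in three passes:
#   1) exact full-key renames from unet_conversion_map,
#   2) substring renames inside resnet blocks (unet_conversion_map_resnet),
#   3) block-prefix renames (unet_conversion_map_layer);
# the state dict is then rebuilt with the resulting SD key names.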
94 | mapping = {k: k for k in unet_state_dict.keys()} 95 | for sd_name, hf_name in unet_conversion_map: 96 | mapping[hf_name] = sd_name 97 | for k, v in mapping.items(): 98 | if "resnets" in k: 99 | for sd_part, hf_part in unet_conversion_map_resnet: 100 | v = v.replace(hf_part, sd_part) 101 | mapping[k] = v 102 | for k, v in mapping.items(): 103 | for sd_part, hf_part in unet_conversion_map_layer: 104 | v = v.replace(hf_part, sd_part) 105 | mapping[k] = v 106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 107 | return new_state_dict 108 | 109 | 110 | # ================# 111 | # VAE Conversion # 112 | # ================# 113 | 114 | vae_conversion_map = [ 115 | # (stable-diffusion, HF Diffusers) 116 | ("nin_shortcut", "conv_shortcut"), 117 | ("norm_out", "conv_norm_out"), 118 | ("mid.attn_1.", "mid_block.attentions.0."), 119 | ] 120 | 121 | for i in range(4): 122 | # down_blocks have two resnets 123 | for j in range(2): 124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 125 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 127 | 128 | if i < 3: 129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 130 | sd_downsample_prefix = f"down.{i}.downsample." 131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 132 | 133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 134 | sd_upsample_prefix = f"up.{3-i}.upsample." 135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 136 | 137 | # up_blocks have three resnets 138 | # also, up blocks in hf are numbered in reverse from sd 139 | for j in range(3): 140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 143 | 144 | # this part accounts for mid blocks in both the encoder and the decoder 145 | for i in range(2): 146 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 147 | sd_mid_res_prefix = f"mid.block_{i+1}." 
148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 149 | 150 | 151 | vae_conversion_map_attn = [ 152 | # (stable-diffusion, HF Diffusers) 153 | ("norm.", "group_norm."), 154 | ("q.", "query."), 155 | ("k.", "key."), 156 | ("v.", "value."), 157 | ("proj_out.", "proj_attn."), 158 | ] 159 | 160 | 161 | def reshape_weight_for_sd(w): 162 | # convert HF linear weights to SD conv2d weights 163 | return w.reshape(*w.shape, 1, 1) 164 | 165 | 166 | def convert_vae_state_dict(vae_state_dict): 167 | mapping = {k: k for k in vae_state_dict.keys()} 168 | for k, v in mapping.items(): 169 | for sd_part, hf_part in vae_conversion_map: 170 | v = v.replace(hf_part, sd_part) 171 | mapping[k] = v 172 | for k, v in mapping.items(): 173 | if "attentions" in k: 174 | for sd_part, hf_part in vae_conversion_map_attn: 175 | v = v.replace(hf_part, sd_part) 176 | mapping[k] = v 177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 178 | weights_to_convert = ["q", "k", "v", "proj_out"] 179 | for k, v in new_state_dict.items(): 180 | for weight_name in weights_to_convert: 181 | if f"mid.attn_1.{weight_name}.weight" in k: 182 | print(f"Reshaping {k} for SD format") 183 | new_state_dict[k] = reshape_weight_for_sd(v) 184 | return new_state_dict 185 | 186 | 187 | # =========================# 188 | # Text Encoder Conversion # 189 | # =========================# 190 | # pretty much a no-op 191 | 192 | 193 | def convert_text_enc_state_dict(text_enc_dict): 194 | return text_enc_dict 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser() 199 | 200 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 201 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 203 | 204 | args = parser.parse_args() 205 | 206 | assert args.model_path is not None, "Must provide a model path!" 207 | 208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 209 | 210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 213 | 214 | # Convert the UNet model 215 | unet_state_dict = torch.load(unet_path, map_location="cpu") 216 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 218 | 219 | # Convert the VAE model 220 | vae_state_dict = torch.load(vae_path, map_location="cpu") 221 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 223 | 224 | # Convert the text encoder model 225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") 226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 227 | text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 228 | 229 | # Put together new checkpoint 230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 231 | if args.half: 232 | state_dict = {k: v.half() for k, v in state_dict.items()} 233 | state_dict = {"state_dict": state_dict} 234 | torch.save(state_dict, args.checkpoint_path) -------------------------------------------------------------------------------- /text-image/data_filter/data_filter_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 集成了水印、美学、CLIP模型,用于给图文质量打分" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pytorch_lightning as pl\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "import torch\n", 20 | "import timm\n", 21 | "from torchvision import transforms as T\n", 22 | "import open_clip\n", 23 | "import torch\n", 24 | "from transformers import BertModel, BertTokenizer\n", 25 | "from PIL import Image\n", 26 | "\n", 27 | "class AestheticsMLP(pl.LightningModule):\n", 28 | " # 美学判别器是基于CLIP的基础上接了一个MLP\n", 29 | " def __init__(self, input_size, xcol='emb', ycol='avg_rating'):\n", 30 | " super().__init__()\n", 31 | " self.input_size = input_size\n", 32 | " self.xcol = xcol\n", 33 | " self.ycol = ycol\n", 34 | " self.layers = nn.Sequential(\n", 35 | " nn.Linear(self.input_size, 1024),\n", 36 | " #nn.ReLU(),\n", 37 | " nn.Dropout(0.2),\n", 38 | " nn.Linear(1024, 128),\n", 39 | " #nn.ReLU(),\n", 40 | " nn.Dropout(0.2),\n", 41 | " nn.Linear(128, 64),\n", 42 | " #nn.ReLU(),\n", 43 | " nn.Dropout(0.1),\n", 44 | "\n", 45 | " nn.Linear(64, 16),\n", 46 | " #nn.ReLU(),\n", 47 | "\n", 48 | " nn.Linear(16, 1)\n", 49 | " )\n", 50 | "\n", 51 | " def forward(self, x):\n", 52 | " return self.layers(x)\n", 53 | "\n", 54 | " def training_step(self, batch, batch_idx):\n", 55 | " x = batch[self.xcol]\n", 56 | " y = batch[self.ycol].reshape(-1, 1)\n", 57 | " x_hat = self.layers(x)\n", 58 | " loss = F.mse_loss(x_hat, y)\n", 59 | " return loss\n", 60 | " \n", 61 | " def validation_step(self, batch, batch_idx):\n", 62 | " x = batch[self.xcol]\n", 63 | " y = batch[self.ycol].reshape(-1, 1)\n", 64 | " x_hat = self.layers(x)\n", 65 | " loss = F.mse_loss(x_hat, y)\n", 66 | " return loss\n", 67 | "\n", 68 | " def configure_optimizers(self):\n", 69 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", 70 | " return optimizer\n", 71 | "\n", 72 | "\n", 73 | "class WaterMarkModel(nn.Module):\n", 74 | " def __init__(self, model_path='./watermark_model_v1.pt'):\n", 75 | " super(WaterMarkModel, self).__init__()\n", 76 | " # model definition\n", 77 | " self.model = timm.create_model(\n", 78 | " 'efficientnet_b3a', pretrained=True, num_classes=2)\n", 79 | "\n", 80 | " self.model.classifier = nn.Sequential(\n", 81 | " # 1536 is the orginal in_features\n", 82 | " nn.Linear(in_features=1536, out_features=625),\n", 83 | " nn.ReLU(), # ReLu to be the activation function\n", 84 | " nn.Dropout(p=0.3),\n", 85 | " nn.Linear(in_features=625, out_features=256),\n", 86 | " nn.ReLU(),\n", 87 | " nn.Linear(in_features=256, out_features=2),\n", 88 | " )\n", 89 | " self.model.load_state_dict(torch.load(model_path))\n", 90 | " def forward(self, x):\n", 91 | " return self.model(x)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 
| "source": [ 100 | "class FilterSystem:\n", 101 | " def __init__(\n", 102 | " self, \n", 103 | " clip_model_path=\"IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese\",\n", 104 | " aesthetics_model_path=\"./ava+logos-l14-linearMSE.pth\",\n", 105 | " watermark_model_path=\"./watermark_model_v1.pt\"\n", 106 | " ):\n", 107 | " self.clip_model_path = clip_model_path\n", 108 | " self.aesthetics_model_path = aesthetics_model_path\n", 109 | " self.watermark_model_path = watermark_model_path\n", 110 | "\n", 111 | " def init_clip_model(self, ):\n", 112 | " # 此处初始化clip模型,返回模型、tokenizer、processor\n", 113 | " text_encoder = BertModel.from_pretrained(self.clip_model_path).eval().cuda()\n", 114 | " text_tokenizer = BertTokenizer.from_pretrained(self.clip_model_path)\n", 115 | " clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')\n", 116 | " clip_model = clip_model.eval().cuda()\n", 117 | " self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = text_encoder, text_tokenizer, clip_model, processor\n", 118 | " print(\"clip model loaded\")\n", 119 | " return None\n", 120 | "\n", 121 | " def init_aesthetics_model(self, ):\n", 122 | " # 此处初始化美学模型\n", 123 | " self.aesthetics_model = AestheticsMLP(768)\n", 124 | " self.aesthetics_model.load_state_dict(torch.load(self.aesthetics_model_path))\n", 125 | " self.aesthetics_model.eval().cuda()\n", 126 | " print(\"aesthetics model loaded\")\n", 127 | " return None\n", 128 | "\n", 129 | " def init_watermark_model(self, ):\n", 130 | " self.watermark_model = WaterMarkModel(self.watermark_model_path)\n", 131 | " self.watermark_model.eval().cuda()\n", 132 | " self.watermark_processor = T.Compose([\n", 133 | " T.Resize((256, 256)),\n", 134 | " T.ToTensor(),\n", 135 | " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 136 | " ])\n", 137 | " print(\"watermark model loaded\")\n", 138 | " return None\n", 139 | "\n", 140 | " def get_image_feature(self, images):\n", 141 | " # 此处返回图像的特征向量\n", 142 | " if isinstance(images, list):\n", 143 | " images = torch.stack([self.processor(image) for image in images]).cuda()\n", 144 | " elif isinstance(images, torch.Tensor):\n", 145 | " images = images.cuda()\n", 146 | "\n", 147 | " with torch.no_grad():\n", 148 | " image_features = self.clip_model.encode_image(images)\n", 149 | " image_features /= image_features.norm(dim=1, keepdim=True)\n", 150 | " return image_features\n", 151 | " \n", 152 | " def get_text_feature(self, text):\n", 153 | " # 此处返回文本的特征向量\n", 154 | " if isinstance(text, list) or isinstance(text, str):\n", 155 | " text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()\n", 156 | " elif isinstance(text, torch.Tensor):\n", 157 | " text = text.cuda()\n", 158 | "\n", 159 | " with torch.no_grad():\n", 160 | " text_features = self.text_encoder(text)[1]\n", 161 | " text_features /= text_features.norm(dim=1, keepdim=True)\n", 162 | " return text_features\n", 163 | "\n", 164 | " def calculate_clip_score(self, features1, features2):\n", 165 | " # 此处2个特征向量的相似度,输入可以是 图片+文本、文本+文本、图片+图片。\n", 166 | " # 返回的是相似度矩阵,维度为 f1.shape[0] * f2.shape[0]\n", 167 | " score_matrix = features1 @ features2.t()\n", 168 | " return score_matrix\n", 169 | "\n", 170 | " def get_aesthetics_score(self, features):\n", 171 | " # 此处返回美学分数,传入的是CLIP的feature, 先计算get_image_feature在传入此函数~(模型是ViT-L-14)\n", 172 | " with torch.no_grad():\n", 173 | " scores = self.aesthetics_model(features)\n", 174 | " scores = scores[:, 0].detach().cpu().numpy()\n", 175 | " return 
scores\n", 176 | " \n", 177 | " def get_watermark_score(self, images):\n", 178 | " if isinstance(images, list):\n", 179 | " images = torch.stack([self.watermark_processor(image) for image in images]).cuda()\n", 180 | " elif isinstance(images, torch.Tensor):\n", 181 | " images = images.cuda()\n", 182 | " with torch.no_grad():\n", 183 | " pred = self.watermark_model(images)\n", 184 | " watermark_scores = F.softmax(pred, dim=1)[:,0].detach().cpu().numpy()\n", 185 | "\n", 186 | " return watermark_scores" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "## 小规模数据测试" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "demo = FilterSystem()\n", 203 | "demo.init_clip_model()\n", 204 | "demo.init_aesthetics_model()\n", 205 | "demo.init_watermark_model()" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "image_path = './demo_images/watermark_example.png'\n", 215 | "image_path2 = './demo_images/mengna.jpg'\n", 216 | "image_path3 = './demo_images/shuiyin.jpg'\n", 217 | "image_path4 = './demo_images/1.jpg'\n", 218 | "image_demo = [Image.open(image_path).convert('RGB'), Image.open(image_path2).convert('RGB'), Image.open(image_path3).convert('RGB'), Image.open(image_path4).convert('RGB')]\n", 219 | "image_feature = demo.get_image_feature(image_demo,) # 计算图片特征,传入图片列表,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 220 | "aes_score = demo.get_aesthetics_score(image_feature) # 计算美学分数,传入图片特征,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 221 | "print(aes_score)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "text_demo = ['一副很美的画','港口小船', '蒙娜丽莎'] # 这里也可以只有一个文本,也就是query\n", 231 | "text_feature = demo.get_text_feature(text_demo) # 计算文本特征,传入文本列表\n", 232 | "similarity = demo.calculate_clip_score(image_feature, text_feature) # 计算相似度\n", 233 | "print(similarity)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "watermark_score = demo.get_watermark_score(image_demo)\n", 243 | "print(watermark_score)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "## 读取处理保存(单个进程)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# data setting\n", 260 | "root_path = \"./project/dataset/laion_chinese_cwf/image_part00\"\n", 261 | "all_folders = sorted(os.listdir(root_path))" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "# model setting\n", 271 | "filter_model = FilterSystem()\n", 272 | "filter_model.init_clip_model()\n", 273 | "filter_model.init_aesthetics_model()\n", 274 | "filter_model.init_watermark_model()" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "from model import FilterSystem\n", 284 | "from dataset import TxtDataset\n", 285 | "import os\n", 286 | "from torch.utils.data import DataLoader\n", 287 | "from tqdm import tqdm\n", 288 | "from PIL import Image\n", 289 | "import numpy as np\n", 290 | "import pandas as pd\n", 291 | "\n", 292 | "def 
sub_process(filter_model, each_folder_path):\n", 293 | " each_dataset = TxtDataset(each_folder_path)\n", 294 | " each_dataloader = DataLoader(each_dataset, batch_size=8, shuffle=False, num_workers=8)\n", 295 | "\n", 296 | " image_paths = []\n", 297 | " aes_scores = []\n", 298 | " clip_scores = []\n", 299 | " watermark_scores = []\n", 300 | " for iii, (batch_image_paths, texts,) in enumerate(tqdm(each_dataloader)):\n", 301 | " images = [Image.open(each_image_path).convert(\"RGB\") for each_image_path in batch_image_paths]\n", 302 | " image_paths.extend(batch_image_paths)\n", 303 | "\n", 304 | " image_features = filter_model.get_image_feature(images,) # 计算图片特征,传入图片列表,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 305 | " aes_score = filter_model.get_aesthetics_score(image_features) # 计算美学分数,传入图片特征,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 306 | " aes_scores.extend(aes_score)\n", 307 | "\n", 308 | " text_features = filter_model.get_text_feature(list(texts)) # 计算文本特征,传入文本列表\n", 309 | " clip_score = filter_model.calculate_clip_score(image_features, text_features) # 计算相似度\n", 310 | " clip_scores.extend(torch.diagonal(clip_score).detach().cpu().numpy()) # 需要取对角线,只需要自己和对应文本的相似度\n", 311 | "\n", 312 | " watermark_score = filter_model.get_watermark_score(images) # 计算水印分数,传入图片路径列表\n", 313 | " watermark_scores.extend(watermark_score)\n", 314 | " \n", 315 | " # print('aes_score:', aes_score, '\\n',\n", 316 | " # 'clip_score:', clip_score, '\\n',\n", 317 | " # 'watermark_score:', watermark_score, '\\n',\n", 318 | " # 'image_paths:', image_paths, '\\n',\n", 319 | " # 'texts:', texts)\n", 320 | " \n", 321 | " score_pd = pd.DataFrame({'image_path': image_paths, 'aes_score': aes_scores, 'clip_score': clip_scores, 'watermark_score': watermark_scores})\n", 322 | " score_pd.to_csv(os.path.join(each_folder_path, 'score.csv'), index=False)\n", 323 | " print('save score.csv in {}'.format(each_folder_path), '\\n', '-'*20)\n", 324 | "\n", 325 | "for each_folder in all_folders[:10]:\n", 326 | " each_folder_path = os.path.join(root_path, each_folder)\n", 327 | " sub_process(filter_model, each_folder_path)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "from model import FilterSystem\n", 337 | "from dataset import TxtDataset\n", 338 | "import os\n", 339 | "from torch.utils.data import DataLoader\n", 340 | "from tqdm import tqdm\n", 341 | "from PIL import Image\n", 342 | "import numpy as np\n", 343 | "import pandas as pd\n", 344 | "from concurrent.futures import ProcessPoolExecutor, wait, ALL_COMPLETED\n", 345 | "\n", 346 | "p = ProcessPoolExecutor(max_workers=4)\n", 347 | "\n", 348 | "def sub_process(filter_model, each_folder_path):\n", 349 | " each_dataset = TxtDataset(each_folder_path)\n", 350 | " each_dataloader = DataLoader(each_dataset, batch_size=8, shuffle=False, num_workers=8)\n", 351 | "\n", 352 | " image_paths = []\n", 353 | " aes_scores = []\n", 354 | " clip_scores = []\n", 355 | " watermark_scores = []\n", 356 | " for iii, (batch_image_paths, texts,) in enumerate(each_dataloader):\n", 357 | " images = [Image.open(each_image_path) for each_image_path in batch_image_paths]\n", 358 | " image_paths.extend(batch_image_paths)\n", 359 | "\n", 360 | " image_features = filter_model.get_image_feature(images,) # 计算图片特征,传入图片列表,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 361 | " aes_score = filter_model.get_aesthetics_score(image_features) # 计算美学分数,传入图片特征,一般而言,可以在数据库保存这个东西,用于响应文本query\n", 362 | " aes_scores.extend(aes_score)\n", 363 | "\n", 364 | " 
text_features = filter_model.get_text_feature(list(texts)) # 计算文本特征,传入文本列表\n", 365 | " clip_score = filter_model.calculate_clip_score(image_features, text_features) # 计算相似度\n", 366 | " clip_scores.extend(torch.diagonal(clip_score).detach().cpu().numpy()) # 需要取对角线,只需要自己和对应文本的相似度\n", 367 | "\n", 368 | " watermark_score = filter_model.get_watermark_score(images) # 计算水印分数,传入图片路径列表\n", 369 | " watermark_scores.extend(watermark_score)\n", 370 | " \n", 371 | " # print('aes_score:', aes_score, '\\n',\n", 372 | " # 'clip_score:', clip_score, '\\n',\n", 373 | " # 'watermark_score:', watermark_score, '\\n',\n", 374 | " # 'image_paths:', image_paths, '\\n',\n", 375 | " # 'texts:', texts)\n", 376 | " \n", 377 | " score_pd = pd.DataFrame({'image_path': image_paths, 'aes_score': aes_scores, 'clip_score': clip_scores, 'watermark_score': watermark_scores})\n", 378 | " score_pd.to_csv(os.path.join(each_folder_path, 'score.csv'), index=False)\n", 379 | " print('save score.csv in {}'.format(each_folder_path), '\\n', '-'*20)\n", 380 | "\n", 381 | "for each_folder in all_folders[:10]:\n", 382 | " each_folder_path = os.path.join(root_path, each_folder)\n", 383 | " f1 = p.submit(sub_process, model_pool[0], each_folder_path)\n", 384 | " f2 = p.submit(sub_process, model_pool[1], each_folder_path)\n", 385 | " f3 = p.submit(sub_process, model_pool[2], each_folder_path)\n", 386 | " f4 = p.submit(sub_process, model_pool[3], each_folder_path)\n", 387 | " res = wait([f1, f2, f3, f4], return_when=ALL_COMPLETED)\n", 388 | "p.shutdown()\n" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "# 用model pool来开启4进程跑\n", 398 | "model_pool = [FilterSystem() for i in range(4)]\n", 399 | "for model in model_pool:\n", 400 | " model.init_clip_model()\n", 401 | " model.init_aesthetics_model()\n", 402 | " model.init_watermark_model()" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "print(aes_scores, clip_scores, watermark_scores)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "print('image_paths:', image_paths, '\\n', 'texts:', texts)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "# pytorch lightning + multi process." 
428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 1, 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "import pytorch_lightning as pl\n", 437 | "\n", 438 | "class ScoreSystem(pl.LightningModule):\n", 439 | " def __init__(Self):\n", 440 | " super().__init__()\n", 441 | " self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = self.init_clip_model()\n", 442 | " self.aesthetics_model = self.init_aesthetics_model()\n", 443 | " self.watermark_model, self.watermark_processor = self.init_watermark_model()\n", 444 | "\n", 445 | " def init_clip_model(self):\n", 446 | " text_encoder = BertModel.from_pretrained(self.clip_model_path).eval().cuda()\n", 447 | " text_tokenizer = BertTokenizer.from_pretrained(self.clip_model_path)\n", 448 | " clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')\n", 449 | " clip_model = clip_model.eval().cuda()\n", 450 | " print(\"clip model loaded\")\n", 451 | " return text_encoder, text_tokenizer, clip_model, processor\n", 452 | "\n", 453 | " def init_aesthetics_model(self, ):\n", 454 | " # 此处初始化美学模型\n", 455 | " aesthetics_model = AestheticsMLP(768)\n", 456 | " aesthetics_model.load_state_dict(torch.load(self.aesthetics_model_path)).eval().cuda()\n", 457 | " print(\"aesthetics model loaded\")\n", 458 | " return aesthetics_model\n", 459 | "\n", 460 | " def init_watermark_model(self, ):\n", 461 | " watermark_model = WaterMarkModel(self.watermark_model_path)\n", 462 | " watermark_model.eval().cuda()\n", 463 | " watermark_processor = T.Compose([\n", 464 | " T.Resize((256, 256)),\n", 465 | " T.ToTensor(),\n", 466 | " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 467 | " ])\n", 468 | " print(\"watermark model loaded\")\n", 469 | " return watermark_model, watermark_processor\n", 470 | "\n", 471 | " def get_image_feature(self, images):\n", 472 | " # 此处返回图像的特征向量\n", 473 | " if isinstance(images, list):\n", 474 | " images = torch.stack([self.processor(image) for image in images]).cuda()\n", 475 | " elif isinstance(images, torch.Tensor):\n", 476 | " images = images.cuda()\n", 477 | "\n", 478 | " with torch.no_grad():\n", 479 | " image_features = self.clip_model.encode_image(images)\n", 480 | " image_features /= image_features.norm(dim=1, keepdim=True)\n", 481 | " return image_features\n", 482 | " \n", 483 | " def get_text_feature(self, text):\n", 484 | " # 此处返回文本的特征向量\n", 485 | " if isinstance(text, list) or isinstance(text, str):\n", 486 | " text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()\n", 487 | " elif isinstance(text, torch.Tensor):\n", 488 | " text = text.cuda()\n", 489 | "\n", 490 | " with torch.no_grad():\n", 491 | " text_features = self.text_encoder(text)[1]\n", 492 | " text_features /= text_features.norm(dim=1, keepdim=True)\n", 493 | " return text_features\n", 494 | "\n", 495 | " def calculate_clip_score(self, features1, features2):\n", 496 | " # 此处2个特征向量的相似度,输入可以是 图片+文本、文本+文本、图片+图片。\n", 497 | " # 返回的是相似度矩阵,维度为 f1.shape[0] * f2.shape[0]\n", 498 | " score_matrix = features1 @ features2.t()\n", 499 | " return score_matrix\n", 500 | "\n", 501 | " def get_aesthetics_score(self, features):\n", 502 | " # 此处返回美学分数,传入的是CLIP的feature, 先计算get_image_feature在传入此函数~(模型是ViT-L-14)\n", 503 | " with torch.no_grad():\n", 504 | " scores = self.aesthetics_model(features)\n", 505 | " scores = scores[:, 0].detach().cpu().numpy()\n", 506 | " return scores\n", 507 | " \n", 508 | " def get_watermark_score(self, images):\n", 509 | " if 
isinstance(images, list):\n", 510 | " images = torch.stack([self.watermark_processor(image) for image in images]).cuda()\n", 511 | " elif isinstance(images, torch.Tensor):\n", 512 | " images = images.cuda()\n", 513 | " with torch.no_grad():\n", 514 | " pred = self.watermark_model(images)\n", 515 | " watermark_scores = F.softmax(pred, dim=1)[:,0].detach().cpu().numpy()\n", 516 | "\n", 517 | " return watermark_scores\n", 518 | "\n", 519 | " def predict_step(self, batch, batch_idx):\n", 520 | " images, texts = batch \n", 521 | " # TODO 这里要么传入处理后的2种图片,要么传入纯图片,然后在下面的函数处理。(目前是传入纯图片)\n", 522 | " image_features = self.get_image_feature(images)\n", 523 | " text_features = self.get_text_feature(texts)\n", 524 | " clip_scores = self.calculate_clip_score(image_features, text_features)\n", 525 | " aes_scores = self.get_aesthetics_score(image_features)\n", 526 | " watermark_scores = self.get_watermark_score(images)\n", 527 | " return clip_scores, aes_scores, watermark_scores\n", 528 | "\n", 529 | " def on_predict_epoch_end(self, outputs):\n", 530 | " # 此处返回所有预测结果\n", 531 | " clip_scores = torch.cat([output[0] for output in outputs], dim=0)\n", 532 | " aes_scores = torch.cat([output[1] for output in outputs], dim=0)\n", 533 | " watermark_scores = torch.cat([output[2] for output in outputs], dim=0)\n", 534 | " return clip_scores, aes_scores, watermark_scores" 535 | ] 536 | } 537 | ], 538 | "metadata": { 539 | "kernelspec": { 540 | "display_name": "Python 3.9.13 ('base')", 541 | "language": "python", 542 | "name": "python3" 543 | }, 544 | "language_info": { 545 | "codemirror_mode": { 546 | "name": "ipython", 547 | "version": 3 548 | }, 549 | "file_extension": ".py", 550 | "mimetype": "text/x-python", 551 | "name": "python", 552 | "nbconvert_exporter": "python", 553 | "pygments_lexer": "ipython3", 554 | "version": "3.9.13" 555 | }, 556 | "orig_nbformat": 4, 557 | "vscode": { 558 | "interpreter": { 559 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa" 560 | } 561 | } 562 | }, 563 | "nbformat": 4, 564 | "nbformat_minor": 2 565 | } 566 | -------------------------------------------------------------------------------- /text-image/data_filter/wukong_filter.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from torch.utils.data import Dataset, ConcatDataset 3 | from torchvision import transforms 4 | import os 5 | from PIL import Image 6 | from concurrent.futures import ProcessPoolExecutor 7 | import json 8 | import torch 9 | from transformers import BertModel 10 | import open_clip 11 | import numpy as np 12 | from transformers import BertTokenizer 13 | import pandas as pd 14 | from tqdm import tqdm 15 | import argparse 16 | 17 | 18 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 19 | parser.add_argument( 20 | "--part", 21 | type=int, 22 | default=0, 23 | required=True, 24 | ) 25 | args = parser.parse_args() 26 | 27 | 28 | class CsvDataset(Dataset): 29 | def __init__(self, input_filename, transforms, input_root, tokenizer, img_key, caption_key, sep="\t"): 30 | # logging.debug(f'Loading csv data from {input_filename}.') 31 | print(f'Loading csv data from {input_filename}.') 32 | self.images = [] 33 | self.captions = [] 34 | if input_filename.endswith('.csv'): 35 | df = pd.read_csv(input_filename, index_col=0) 36 | df = df[df['used'] == 1] 37 | self.images.extend(df[img_key].tolist()) 38 | self.captions.extend(df[caption_key].tolist()) 39 | # NOTE 中文的tokenizer 40 | self.tokenizer = tokenizer 41 | 
self.context_length = 77 42 |         self.root = input_root 43 |         self.transforms = transforms 44 | 45 |     def __len__(self): 46 |         return len(self.images) 47 | 48 |     def __getitem__(self, idx): 49 |         img_path = str(self.images[idx]) 50 |         image = self.transforms(Image.open(os.path.join(self.root, img_path))) 51 |         text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0] 52 |         return image, text, img_path 53 | 54 | 55 | text_encoder = BertModel.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese").eval().cuda() 56 | clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') 57 | clip_model = clip_model.eval().cuda() 58 | text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese") 59 | 60 | 61 | input_filename = './project/dataset/wukong/release' 62 | preprocess_fn = processor 63 | input_root = './project/dataset/wukong/images' 64 | tokenizer = text_tokenizer 65 | all_csvs = sorted(os.listdir(input_filename)) 66 | 67 | for i in range(len(all_csvs)*args.part//5, len(all_csvs)*(args.part+1)//5): 68 |     # the CSV list is split into 5 parts; this run handles part args.part 69 |     each_csv_path = os.path.join(input_filename, all_csvs[i]) 70 |     dataset = CsvDataset(each_csv_path, preprocess_fn, input_root, tokenizer, img_key="name", caption_key="caption") 71 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True) 72 | 73 |     df = pd.read_csv(each_csv_path, index_col=0) 74 |     df = df[df['used'] == 1] 75 |     scores = [] 76 |     for iii, (image, text, image_path) in enumerate(tqdm(dataloader)): 77 |         # print(image.shape, text.shape) 78 |         with torch.no_grad(): 79 |             image = image.cuda() 80 |             text = text.cuda() 81 |             # print(image.shape, text.shape) 82 |             image_features = clip_model.encode_image(image) 83 |             text_features = text_encoder(text)[1] 84 | 85 |             # print(image_features.shape, text_features.shape) 86 |             # normalize the features 87 |             image_features = image_features / image_features.norm(dim=1, keepdim=True) 88 |             text_features = text_features / text_features.norm(dim=1, keepdim=True) 89 |             score_each_pair = image_features @ text_features.t() 90 | 91 |             scores.extend(torch.diagonal(score_each_pair).detach().cpu().numpy()) 92 |         # break 93 |     df['score'] = scores 94 |     df.to_csv(each_csv_path.replace(all_csvs[i], 'score'+all_csvs[i]), index=False) 95 |     print('saving score to', each_csv_path.replace(all_csvs[i], 'score'+all_csvs[i])) 96 | -------------------------------------------------------------------------------- /text-image/data_filter/wukong_reader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, ConcatDataset 2 | from torchvision import transforms 3 | import os 4 | from PIL import Image 5 | from concurrent.futures import ProcessPoolExecutor 6 | import json 7 | import torch 8 | from transformers import BertModel 9 | import open_clip 10 | import numpy as np 11 | from transformers import BertTokenizer 12 | import pandas as pd 13 | from tqdm import tqdm 14 | import argparse 15 | import torch 16 | # NOTE: to speed up data loading, keep the plain Dataset here and parallelize the per-CSV loading outside the class (30 min -> 3 min) 17 | class CsvDataset(Dataset): 18 |     def __init__(self, input_filename, input_root, img_key, caption_key, transforms=None, thres=0.2, sep="\t"): 19 |         # logging.debug(f'Loading csv data from {input_filename}.') 20 |         print(f'Loading csv data from {input_filename}.') 21 |         self.images = [] 22 |         self.captions = [] 23 | 24 |         if 
input_filename.endswith('.csv'): 25 |             # print(f"Load Data from {input_filename}") 26 |             df = pd.read_csv(input_filename, index_col=0) 27 |             df = df[df['used'] == 1] 28 |             df = df[df['score'] > thres] 29 |             self.images.extend(df[img_key].tolist()) 30 |             self.captions.extend(df[caption_key].tolist()) 31 | 32 |         # NOTE: Chinese tokenizer 33 |         self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") 34 | 35 |         self.context_length = 77 36 |         self.root = input_root 37 |         self.transforms = transforms 38 | 39 |     def __len__(self): 40 |         return len(self.images) 41 | 42 |     def __getitem__(self, idx): 43 |         img_path = str(self.images[idx]) 44 |         image = self.transforms(Image.open(os.path.join(self.root, img_path))) 45 |         text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0] 46 |         return image, text 47 | 48 | 49 | def process_pool_read_csv_dataset(input_root, input_filename, thres=0.20): 50 |     # here input_filename is a directory that contains the scored CSV files 51 |     all_csvs = os.listdir(input_filename) 52 | 53 |     csv_with_score = [each for each in all_csvs if 'score' in each] 54 |     all_datasets = [] 55 |     res = [] 56 |     p = ProcessPoolExecutor(max_workers=24) 57 |     for i in range(len(csv_with_score)): 58 |         each_csv_path = os.path.join(input_filename, csv_with_score[i]) 59 |         print(i, each_csv_path) 60 |         res.append(p.submit(CsvDataset, each_csv_path, input_root, img_key="name", caption_key="caption", thres=thres)) 61 |     p.shutdown() 62 |     for future in res: 63 |         all_datasets.append(future.result()) 64 |     dataset = ConcatDataset(all_datasets) 65 |     return dataset 66 | 67 | 68 | tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese", model_max_length=512) 69 | input_filename = './project/dataset/wukong/release' # directory that holds the CSV annotation files 70 | input_root = './project/dataset/wukong/images' 71 | dataset = process_pool_read_csv_dataset(input_root, input_filename, thres=0.22) 72 | 73 | print(len(dataset)) -------------------------------------------------------------------------------- /text-image/fid_clip_score/.gitignore: -------------------------------------------------------------------------------- 1 | /output* 2 | -------------------------------------------------------------------------------- /text-image/fid_clip_score/coco_sample_generator.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import pandas as pd 3 | import os 4 | from diffusers import StableDiffusionPipeline 5 | from argparse import ArgumentParser 6 | from tqdm import tqdm 7 | from multiprocessing import Process 8 | 9 | parser = ArgumentParser() 10 | parser.add_argument('--coco_path', type=str, default='../dataset/coco') 11 | parser.add_argument('--coco_cache_file', type=str, default='../dataset/coco/subset.parquet') 12 | parser.add_argument('--output_path', type=str, default='./output') 13 | parser.add_argument('--model_path', type=str, default='../pretrained_models/stable-diffusion-v1-4') 14 | parser.add_argument('--sample_step', type=int, default=20) 15 | parser.add_argument('--guidance_scale', type=float, default=1.5) 16 | parser.add_argument('--batch_size', type=int, default=2) 17 | args = parser.parse_args() 18 | 19 | 20 | class COCOCaptionSubset(Dataset): 21 |     def __init__(self, path, transform=None): 22 |         self.df = pd.read_parquet(path) 23 | 24 |     def __len__(self): 25 |         return len(self.df) 26 | 27 |     def __getitem__(self, idx): 28 |         row = 
self.df.iloc[idx] 29 |         return row['file_name'], row['caption'] 30 | 31 | def save_images(images, image_paths, output_path): 32 |     for i, image_path in enumerate(image_paths): 33 |         image_path = image_path.replace('/', '_') 34 |         image_path = os.path.join(output_path, image_path) 35 |         images[i].save(image_path) 36 | 37 | if __name__ == '__main__': 38 |     # testing 39 |     coco_path = args.coco_path 40 |     # coco_cache_file = f'{coco_path}/subset.parquet' # sampled subsets 41 |     cocosubset = COCOCaptionSubset(args.coco_cache_file) 42 |     cocosubsetloader = DataLoader(cocosubset, batch_size=args.batch_size, shuffle=False, num_workers=8) 43 | 44 |     # load the t2i model 45 |     stable_diffusion = StableDiffusionPipeline.from_pretrained(args.model_path, requires_safety_checker=False).to('cuda') 46 | 47 |     sample_step = args.sample_step 48 |     guidance_scale = args.guidance_scale 49 | 50 | 51 |     output_path = os.path.join( 52 |         args.output_path, 53 |         f'./gs{guidance_scale}_ss{sample_step}' 54 |     ) 55 |     os.makedirs(output_path, exist_ok=True) 56 | 57 |     for i, (image_paths, captions) in enumerate(tqdm(cocosubsetloader)): 58 |         outputs = stable_diffusion(list(captions), num_inference_steps=sample_step, guidance_scale=guidance_scale).images 59 |         p = Process(target=save_images, args=(outputs, image_paths, output_path)) 60 |         p.start() -------------------------------------------------------------------------------- /text-image/fid_clip_score/compute_fid.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# FID metric computation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "FID (same): -0.001\n", 20 | "FID (different): 486.117\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import numpy as np\n", 26 | "from scipy.linalg import sqrtm\n", 27 | "def calculate_fid(act1, act2):\n", 28 | "    # calculate mean and covariance statistics\n", 29 | "    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)\n", 30 | "    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)\n", 31 | "    # print(mu1.shape, mu2.shape, sigma1.shape, sigma2.shape)\n", 32 | "    # calculate sum squared difference between means\n", 33 | "    ssdiff = np.sum((mu1 - mu2)**2.0)\n", 34 | "    # print(ssdiff)\n", 35 | "    # calculate sqrt of product between cov\n", 36 | "    covmean = sqrtm(sigma1.dot(sigma2)) # sqrtm also works when the product has negative eigenvalues (the result is complex)\n", 37 | "    # print(covmean)\n", 38 | "    # check and correct imaginary numbers from sqrt\n", 39 | "    if np.iscomplexobj(covmean):\n", 40 | "        covmean = covmean.real\n", 41 | "    # calculate score\n", 42 | "    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)\n", 43 | "    return fid\n", 44 | "\n", 45 | "# define two collections of activations\n", 46 | "act1 = np.random.rand(2, 2048)\n", 47 | "act2 = np.random.rand(3, 2048)\n", 48 | "# fid between act1 and act1\n", 49 | "fid = calculate_fid(act1, act1)\n", 50 | "print('FID (same): %.3f' % fid) # should be 0.0\n", 51 | "# fid between act1 and act2\n", 52 | "fid = calculate_fid(act1, act2)\n", 53 | "print('FID (different): %.3f' % fid) # should be > 0.0" 54 | ] 55 | } 56 | ], 57 | "metadata": { 58 | "kernelspec": { 59 | "display_name": "Python 3.9.13 ('base')", 60 | "language": "python", 61 | "name": "python3" 62 | }, 63 | "language_info": { 64 | "codemirror_mode": { 65 | "name": "ipython", 66 | "version": 3 67 | }, 68 | "file_extension": ".py", 69 | 
"mimetype": "text/x-python", 70 | "name": "python", 71 | "nbconvert_exporter": "python", 72 | "pygments_lexer": "ipython3", 73 | "version": "3.9.13" 74 | }, 75 | "orig_nbformat": 4, 76 | "vscode": { 77 | "interpreter": { 78 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa" 79 | } 80 | } 81 | }, 82 | "nbformat": 4, 83 | "nbformat_minor": 2 84 | } 85 | -------------------------------------------------------------------------------- /text-image/fid_clip_score/fid_clip_coco.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "reference: https://wandb.ai/dalle-mini/dalle-mini/reports/CLIP-score-vs-FID-pareto-curves--VmlldzoyMDYyNTAy" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Sampling data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# load data\n", 24 | "import json\n", 25 | "coco_path = '/home/tiger/project/dataset/coco'\n", 26 | "data_file = f'{coco_path}/annotations/captions_val2014.json'\n", 27 | "data = json.load(open(data_file))\n", 28 | "\n", 29 | "\n", 30 | "# merge images and annotations\n", 31 | "import pandas as pd\n", 32 | "images = data['images']\n", 33 | "annotations = data['annotations']\n", 34 | "df = pd.DataFrame(images)\n", 35 | "df_annotations = pd.DataFrame(annotations)\n", 36 | "df = df.merge(pd.DataFrame(annotations), how='left', left_on='id', right_on='image_id')\n", 37 | "\n", 38 | "\n", 39 | "# keep only the relevant columns\n", 40 | "df = df[['file_name', 'caption']]\n", 41 | "\n", 42 | "\n", 43 | "# shuffle the dataset\n", 44 | "df = df.sample(frac=1)\n", 45 | "\n", 46 | "\n", 47 | "# remove duplicate images\n", 48 | "df = df.drop_duplicates(subset='file_name')\n", 49 | "\n", 50 | "\n", 51 | "# create a random subset\n", 52 | "n_samples = 10000\n", 53 | "df_sample = df.sample(n_samples)\n", 54 | "\n", 55 | "\n", 56 | "# save the sample to a parquet file\n", 57 | "df_sample.to_parquet(f'{coco_path}/subset.parquet')\n", 58 | "\n", 59 | "\n", 60 | "# copy the images to reference folder\n", 61 | "from pathlib import Path\n", 62 | "import shutil\n", 63 | "subset_path = Path(f'{coco_path}/subset')\n", 64 | "subset_path.mkdir(exist_ok=True)\n", 65 | "for i, row in df_sample.iterrows():\n", 66 | " path = f'{coco_path}/val2014/' + row['file_name']\n", 67 | " shutil.copy(path, f'{coco_path}/subset/')\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# center crop the images\n", 77 | "def center_crop_images(folder, output_folder, size):\n", 78 | " # coco images are not square, so we need to center crop them\n", 79 | " from PIL import Image\n", 80 | " import os\n", 81 | " os.makedirs(output_folder, exist_ok=True)\n", 82 | " for file in os.listdir(folder):\n", 83 | " image_path = os.path.join(folder, file)\n", 84 | " image = Image.open(image_path)\n", 85 | " width, height = image.size\n", 86 | " left = (width - size) / 2 if width > size else 0\n", 87 | " top = (height - size) / 2 if height > size else 0\n", 88 | " right = (width + size) / 2 if width > size else width\n", 89 | " bottom = (height + size) / 2 if height > size else height\n", 90 | " image = image.crop((left, top, right, bottom))\n", 91 | " image = image.resize((size, size)) # resize non-square images\n", 92 | " 
image.save(os.path.join(output_folder, file))\n", 93 | "\n", 94 | "folder_name = '/home/tiger/project/dataset/coco/subset'\n", 95 | "output_folder = '/home/tiger/project/dataset/coco/subset_cropped'\n", 96 | "center_crop_images(folder_name, output_folder, 320)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "# Load subset as dataloader" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# load the subset\n", 113 | "from torch.utils.data import Dataset, DataLoader\n", 114 | "import pandas as pd \n", 115 | "\n", 116 | "class COCOCaptionSubset(Dataset):\n", 117 | " def __init__(self, path, transform=None):\n", 118 | " self.df = pd.read_parquet(path)\n", 119 | "\n", 120 | " def __len__(self):\n", 121 | " return len(self.df)\n", 122 | "\n", 123 | " def __getitem__(self, idx):\n", 124 | " row = self.df.iloc[idx]\n", 125 | " return row['file_name'], row['caption']\n", 126 | "\n", 127 | "# testing \n", 128 | "coco_path = '/home/tiger/project/dataset/coco'\n", 129 | "coco_cache_file = f'{coco_path}/subset.parquet' # sampled subsets\n", 130 | "cocosubset = COCOCaptionSubset(coco_cache_file)\n", 131 | "cocosubsetloader = DataLoader(cocosubset, batch_size=64, shuffle=False, num_workers=8)\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "# Generating Images Via T2I Model" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# demo inference, use coco_sample_generator.py to generate more\n", 155 | "# load the t2i model\n", 156 | "# from diffusers import StableDiffusionPipeline\n", 157 | "# stable_diffusion = StableDiffusionPipeline.from_pretrained(\"/home/tiger/project/pretrained_models/stable-diffusion-v1-4\").to('cuda') \n", 158 | "\n", 159 | "# sample_step = 20\n", 160 | "# guidance_scale = 1.5\n", 161 | "\n", 162 | "# import os\n", 163 | "\n", 164 | "# output_path = f'./output_gs{guidance_scale}_ss{sample_step}'\n", 165 | "# os.makedirs(output_path, exist_ok=True)\n", 166 | "\n", 167 | "# for i, (image_paths, captions) in enumerate(cocosubsetloader):\n", 168 | "# outputs = stable_diffusion(list(captions), num_inference_steps=sample_step, guidance_scale=guidance_scale).images\n", 169 | "# for j, image_path in enumerate(image_paths):\n", 170 | "# image_path = image_path.replace('/', '_')\n", 171 | "# image_path = os.path.join(output_path, image_path)\n", 172 | "# outputs[j].save(image_path)\n", 173 | "# break" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 1, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "import torch\n", 188 | "device = torch.device('cuda')\n", 189 | "\n", 190 | "coco_subset_crop_path = '/home/tiger/project/dataset/coco/subset_cropped'\n", 191 | "output_root = '/home/tiger/project/position-guided-t2i/output'\n", 192 | "output_paths = [os.path.join(output_root, out) for out in sorted(os.listdir(output_root))]\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 2, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stderr", 202 | "output_type": "stream", 203 | "text": [ 204 | 
"/home/tiger/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n", 205 | " warnings.warn(\n", 206 | "/home/tiger/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.\n", 207 | " warnings.warn(msg)\n", 208 | "100%|██████████| 50/50 [00:13<00:00, 3.58it/s]\n", 209 | "100%|██████████| 50/50 [00:16<00:00, 3.06it/s]\n" 210 | ] 211 | }, 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "/home/tiger/project/position-guided-t2i/output/gs1.5_ss20 22.765903388613765\n" 217 | ] 218 | }, 219 | { 220 | "name": "stderr", 221 | "output_type": "stream", 222 | "text": [ 223 | "100%|██████████| 50/50 [00:09<00:00, 5.50it/s]\n", 224 | "100%|██████████| 50/50 [00:16<00:00, 3.03it/s]\n" 225 | ] 226 | }, 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "/home/tiger/project/position-guided-t2i/output/gs2.0_ss20 18.159921113816665\n" 232 | ] 233 | }, 234 | { 235 | "name": "stderr", 236 | "output_type": "stream", 237 | "text": [ 238 | "100%|██████████| 50/50 [00:10<00:00, 4.95it/s]\n", 239 | "100%|██████████| 50/50 [00:15<00:00, 3.14it/s]\n" 240 | ] 241 | }, 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "/home/tiger/project/position-guided-t2i/output/gs3.0_ss20 15.94397287378655\n" 247 | ] 248 | }, 249 | { 250 | "name": "stderr", 251 | "output_type": "stream", 252 | "text": [ 253 | "100%|██████████| 50/50 [00:09<00:00, 5.18it/s]\n", 254 | "100%|██████████| 50/50 [00:15<00:00, 3.14it/s]\n" 255 | ] 256 | }, 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "/home/tiger/project/position-guided-t2i/output/gs4.0_ss20 16.315106185605657\n" 262 | ] 263 | }, 264 | { 265 | "name": "stderr", 266 | "output_type": "stream", 267 | "text": [ 268 | "100%|██████████| 50/50 [00:09<00:00, 5.17it/s]\n", 269 | "100%|██████████| 50/50 [00:16<00:00, 3.00it/s]\n" 270 | ] 271 | }, 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "/home/tiger/project/position-guided-t2i/output/gs5.0_ss20 17.35088805364785\n" 277 | ] 278 | }, 279 | { 280 | "name": "stderr", 281 | "output_type": "stream", 282 | "text": [ 283 | "100%|██████████| 50/50 [00:09<00:00, 5.15it/s]\n", 284 | "100%|██████████| 50/50 [00:16<00:00, 3.01it/s]\n" 285 | ] 286 | }, 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "/home/tiger/project/position-guided-t2i/output/gs6.0_ss20 17.933771904354728\n" 292 | ] 293 | }, 294 | { 295 | "name": "stderr", 296 | "output_type": "stream", 297 | "text": [ 298 | "100%|██████████| 50/50 [00:09<00:00, 5.21it/s]\n", 299 | "100%|██████████| 50/50 [00:16<00:00, 3.06it/s]\n" 300 | ] 301 | }, 302 | { 303 | "name": "stdout", 304 | "output_type": "stream", 305 | "text": [ 306 | "/home/tiger/project/position-guided-t2i/output/gs7.0_ss20 19.059673548019532\n" 307 | ] 308 | }, 309 | { 310 | "name": "stderr", 311 | "output_type": "stream", 312 | "text": [ 313 | "100%|██████████| 50/50 [00:09<00:00, 5.18it/s]\n", 314 | "100%|██████████| 50/50 [00:17<00:00, 2.90it/s]\n" 315 | ] 316 | }, 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | 
"/home/tiger/project/position-guided-t2i/output/gs8.0_ss20 20.12984543749127\n" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "# fid score\n", 327 | "\n", 328 | "# !pip install pytorch_fid\n", 329 | "# !python -m pytorch_fid /home/tiger/project/dataset/coco/subset_cropped /home/tiger/project/position-guided-t2i/output/gs2.0_ss20\n", 330 | "\n", 331 | "from pytorch_fid.fid_score import calculate_fid_given_paths\n", 332 | "\n", 333 | "fids = []\n", 334 | "for output_path in output_paths:\n", 335 | " fid_value = calculate_fid_given_paths([coco_subset_crop_path, output_path], batch_size=200, device=device, dims=2048, num_workers=8)\n", 336 | " fids.append(fid_value)\n", 337 | " print(output_path, fid_value)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 3, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer\n", 347 | "from PIL import Image\n", 348 | "import numpy as np\n", 349 | "from tqdm import tqdm\n", 350 | "\n", 351 | "\n", 352 | "def load_clip_model(model_path='openai/clip-vit-large-patch14'):\n", 353 | " # text_encoder = BertModel.from_pretrained(model_path).eval().cuda()\n", 354 | " # text_tokenizer = BertTokenizer.from_pretrained(model_path)\n", 355 | " clip_model = CLIPModel.from_pretrained(model_path)\n", 356 | " processor = CLIPProcessor.from_pretrained(model_path)\n", 357 | " tokenizer = CLIPTokenizer.from_pretrained(model_path)\n", 358 | "\n", 359 | " clip_model = clip_model.eval().cuda()\n", 360 | " return clip_model, processor, tokenizer\n", 361 | "\n", 362 | "\n", 363 | "def clip_score(clip_model, processor, tokenizer, dataloader, output_image_path):\n", 364 | " all_image_features = []\n", 365 | " all_text_features = []\n", 366 | " for (i, (image_paths, captions)) in enumerate(tqdm(dataloader)):\n", 367 | " # print(image_paths, captions)\n", 368 | " text_inputs = tokenizer(list(captions), padding=True, return_tensors=\"pt\").to('cuda')\n", 369 | " text_features = clip_model.get_text_features(**text_inputs)\n", 370 | " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n", 371 | " text_features = text_features.detach().cpu().numpy()\n", 372 | " all_text_features.append(text_features)\n", 373 | "\n", 374 | " # vit 速度比较龟\n", 375 | " images = [Image.open(os.path.join( output_image_path , image_path)) for image_path in image_paths]\n", 376 | " image_inputs = processor(images = images, return_tensors=\"pt\").to('cuda')\n", 377 | " image_features = clip_model.get_image_features(**image_inputs)\n", 378 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 379 | " image_features = image_features.detach().cpu().numpy()\n", 380 | " all_image_features.append(image_features)\n", 381 | "\n", 382 | " # NOTE testing 等太久了,抽样吧... 
需要全部的话,把这个 if 去掉\n", 383 | " if i == 10:\n", 384 | " break\n", 385 | "\n", 386 | " all_text_features = np.concatenate(all_text_features, axis=0)\n", 387 | " all_image_features = np.concatenate(all_image_features, axis=0)\n", 388 | " mean_similarity = (all_image_features @ all_text_features.T).diagonal().mean()\n", 389 | " return mean_similarity\n" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 6, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "name": "stderr", 399 | "output_type": "stream", 400 | "text": [ 401 | " 6%|▋ | 10/157 [00:15<03:45, 1.54s/it]\n" 402 | ] 403 | }, 404 | { 405 | "name": "stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "/home/tiger/project/position-guided-t2i/output/gs1.5_ss20 0.23490335\n" 409 | ] 410 | }, 411 | { 412 | "name": "stderr", 413 | "output_type": "stream", 414 | "text": [ 415 | " 6%|▋ | 10/157 [00:14<03:39, 1.50s/it]\n" 416 | ] 417 | }, 418 | { 419 | "name": "stdout", 420 | "output_type": "stream", 421 | "text": [ 422 | "/home/tiger/project/position-guided-t2i/output/gs2.0_ss20 0.24406949\n" 423 | ] 424 | }, 425 | { 426 | "name": "stderr", 427 | "output_type": "stream", 428 | "text": [ 429 | " 6%|▋ | 10/157 [00:15<03:41, 1.51s/it]\n" 430 | ] 431 | }, 432 | { 433 | "name": "stdout", 434 | "output_type": "stream", 435 | "text": [ 436 | "/home/tiger/project/position-guided-t2i/output/gs3.0_ss20 0.25112092\n" 437 | ] 438 | }, 439 | { 440 | "name": "stderr", 441 | "output_type": "stream", 442 | "text": [ 443 | " 6%|▋ | 10/157 [00:14<03:39, 1.49s/it]\n" 444 | ] 445 | }, 446 | { 447 | "name": "stdout", 448 | "output_type": "stream", 449 | "text": [ 450 | "/home/tiger/project/position-guided-t2i/output/gs4.0_ss20 0.25709876\n" 451 | ] 452 | }, 453 | { 454 | "name": "stderr", 455 | "output_type": "stream", 456 | "text": [ 457 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]\n" 458 | ] 459 | }, 460 | { 461 | "name": "stdout", 462 | "output_type": "stream", 463 | "text": [ 464 | "/home/tiger/project/position-guided-t2i/output/gs5.0_ss20 0.25781947\n" 465 | ] 466 | }, 467 | { 468 | "name": "stderr", 469 | "output_type": "stream", 470 | "text": [ 471 | " 6%|▋ | 10/157 [00:14<03:40, 1.50s/it]\n" 472 | ] 473 | }, 474 | { 475 | "name": "stdout", 476 | "output_type": "stream", 477 | "text": [ 478 | "/home/tiger/project/position-guided-t2i/output/gs6.0_ss20 0.2593051\n" 479 | ] 480 | }, 481 | { 482 | "name": "stderr", 483 | "output_type": "stream", 484 | "text": [ 485 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]\n" 486 | ] 487 | }, 488 | { 489 | "name": "stdout", 490 | "output_type": "stream", 491 | "text": [ 492 | "/home/tiger/project/position-guided-t2i/output/gs7.0_ss20 0.26007786\n" 493 | ] 494 | }, 495 | { 496 | "name": "stderr", 497 | "output_type": "stream", 498 | "text": [ 499 | " 6%|▋ | 10/157 [00:14<03:37, 1.48s/it]" 500 | ] 501 | }, 502 | { 503 | "name": "stdout", 504 | "output_type": "stream", 505 | "text": [ 506 | "/home/tiger/project/position-guided-t2i/output/gs8.0_ss20 0.2596085\n" 507 | ] 508 | }, 509 | { 510 | "name": "stderr", 511 | "output_type": "stream", 512 | "text": [ 513 | "\n" 514 | ] 515 | } 516 | ], 517 | "source": [ 518 | "clip_model_path=\"/home/tiger/project/pretrained_models/clip-vit-large-patch14\"\n", 519 | "clip_model, processor, tokenizer = load_clip_model(clip_model_path)\n", 520 | "clip_scores = []\n", 521 | "for output_path in output_paths:\n", 522 | " clip_score_each = clip_score(clip_model, processor, tokenizer, cocosubsetloader, output_path) # 3min ....\n", 523 | " print(output_path, 
clip_score_each)\n", 524 | "    clip_scores.append(clip_score_each)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 7, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "image/png": "<base64-encoded PNG omitted: matplotlib line chart of CLIP score (x-axis) vs. FID (y-axis) across the guidance scales 1.5-8.0>", 535 | "text/plain": [ 536 | "
" 537 | ] 538 | }, 539 | "metadata": {}, 540 | "output_type": "display_data" 541 | } 542 | ], 543 | "source": [ 544 | "# plot clip score as x-axis, fid score as y-axis, line chart\n", 545 | "import matplotlib.pyplot as plt\n", 546 | "plt.plot(clip_scores, fids, 'o-')\n", 547 | "plt.xlabel('clip score')\n", 548 | "plt.ylabel('fid score')\n", 549 | "plt.show()" 550 | ] 551 | } 552 | ], 553 | "metadata": { 554 | "kernelspec": { 555 | "display_name": "Python 3.9.13 ('base')", 556 | "language": "python", 557 | "name": "python3" 558 | }, 559 | "language_info": { 560 | "codemirror_mode": { 561 | "name": "ipython", 562 | "version": 3 563 | }, 564 | "file_extension": ".py", 565 | "mimetype": "text/x-python", 566 | "name": "python", 567 | "nbconvert_exporter": "python", 568 | "pygments_lexer": "ipython3", 569 | "version": "3.9.13" 570 | }, 571 | "orig_nbformat": 4, 572 | "vscode": { 573 | "interpreter": { 574 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa" 575 | } 576 | } 577 | }, 578 | "nbformat": 4, 579 | "nbformat_minor": 2 580 | } 581 | -------------------------------------------------------------------------------- /text-image/fid_clip_score/run_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps aux | grep -E 'run_watch.sh|watch.py' |awk '{print $2}' | xargs kill -9 # kill previous watchdog 3 | guidance_scales=(1.5 2.0 3.0 4.0 5.0 6.0 7.0 8.0) 4 | for i in {0..7} 5 | do 6 | echo ${i} 7 | CUDA_VISIBLE_DEVICES=${i} nohup python coco_sample_generator.py --guidance_scale ${guidance_scales[${i}]} --batch_size 16 --sample_step 20 > stable_generator.log 2>&1 & 8 | done 9 | wait 10 | bash ~/release_watchdog.sh # start watchdog -------------------------------------------------------------------------------- /text-image/fid_clip_score/run_generator_cn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps aux | grep -E 'run_watch.sh|watch.py' |awk '{print $2}' | xargs kill -9 # kill previous watchdog 3 | guidance_scales=(1.5 2.0 3.0 4.0 5.0 6.0 7.0 8.0) 4 | for i in {0..7} 5 | do 6 | echo ${i} 7 | CUDA_VISIBLE_DEVICES=${i} nohup python coco_sample_generator.py --model_path ../pretrained_models/stable_cn --coco_cache_file ../dataset/coco/subset_cn.parquet --output_path ./output_cn --guidance_scale ${guidance_scales[${i}]} --batch_size 16 --sample_step 20 > stable_generator.log 2>&1 & 8 | done 9 | wait 10 | bash ~/release_watchdog.sh # start watchdog -------------------------------------------------------------------------------- /text-image/imagenet_CN_zeroshot_data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | imagenet_classnames = [ 4 | "丁鲷", 5 | "金鱼", 6 | "大白鲨", 7 | "虎鲨", 8 | "锤头鲨", 9 | "电鳐", 10 | "黄貂鱼", 11 | "公鸡", 12 | "母鸡", 13 | "鸵鸟", 14 | "燕雀", 15 | "金翅雀", 16 | "家朱雀", 17 | "灯芯草雀", 18 | "靛蓝雀", 19 | "蓝鹀", 20 | "夜莺", 21 | "松鸦", 22 | "喜鹊", 23 | "山雀", 24 | "河鸟", 25 | "鸢(猛禽)", 26 | "秃头鹰", 27 | "秃鹫", 28 | "大灰猫头鹰", 29 | "欧洲火蝾螈", 30 | "普通蝾螈", 31 | "水蜥", 32 | "斑点蝾螈", 33 | "蝾螈", 34 | "牛蛙", 35 | "树蛙", 36 | "尾蛙", 37 | "红海龟", 38 | "皮革龟", 39 | "泥龟", 40 | "淡水龟", 41 | "箱龟", 42 | "带状壁虎", 43 | "普通鬣蜥", 44 | "美国变色龙", 45 | "鞭尾蜥蜴", 46 | "飞龙科蜥蜴", 47 | "褶边蜥蜴", 48 | "鳄鱼蜥蜴", 49 | "毒蜥", 50 | "绿蜥蜴", 51 | "非洲变色龙", 52 | "科莫多蜥蜴", 53 | "非洲鳄", 54 | "美国鳄鱼", 55 | "三角龙", 56 | "雷蛇", 57 | "环蛇", 58 | "希腊蛇", 59 | "绿蛇", 60 | "国王蛇", 61 | "袜带蛇", 62 | "水蛇", 63 | "藤蛇", 64 | "夜蛇", 65 | "大蟒蛇", 66 | "岩石蟒蛇", 67 | "印度眼镜蛇", 68 | "绿曼巴", 69 | "海蛇", 70 | 
"角腹蛇", 71 | "菱纹响尾蛇", 72 | "角响尾蛇", 73 | "三叶虫", 74 | "盲蜘蛛", 75 | "蝎子", 76 | "黑金花园蜘蛛", 77 | "谷仓蜘蛛", 78 | "花园蜘蛛", 79 | "黑寡妇蜘蛛", 80 | "狼蛛", 81 | "狼蜘蛛", 82 | "壁虱", 83 | "蜈蚣", 84 | "黑松鸡", 85 | "松鸡", 86 | "披肩鸡", 87 | "草原鸡", 88 | "孔雀", 89 | "鹌鹑", 90 | "鹧鸪", 91 | "非洲灰鹦鹉", 92 | "金刚鹦鹉", 93 | "硫冠鹦鹉", 94 | "短尾鹦鹉", 95 | "褐翅鸦鹃", 96 | "食蜂鸟;蜂虎", 97 | "犀鸟", 98 | "蜂鸟", 99 | "鹟䴕", 100 | "巨嘴鸟;大嘴鸟", 101 | "野鸭", 102 | "红胸秋沙鸭", 103 | "鹅", 104 | "黑天鹅", 105 | "大象", 106 | "针鼹鼠", 107 | "鸭嘴兽", 108 | "沙袋鼠", 109 | "考拉", 110 | "袋熊", 111 | "水母", 112 | "海葵", 113 | "脑珊瑚", 114 | "扁形虫扁虫", 115 | "线虫", 116 | "海螺", 117 | "蜗牛", 118 | "鼻涕虫", 119 | "海蛞蝓;海参", 120 | "石鳖", 121 | "鹦鹉螺", 122 | "珍宝蟹", 123 | "石蟹", 124 | "招潮蟹", 125 | "帝王蟹", 126 | "美国龙虾", 127 | "大螯虾", 128 | "小龙虾", 129 | "寄居蟹", 130 | "等足目动物(明虾和螃蟹近亲)", 131 | "白鹳", 132 | "黑鹳", 133 | "鹭", 134 | "火烈鸟", 135 | "小蓝鹭", 136 | "美国鹭", 137 | "麻鸦", 138 | "鹤", 139 | "秧鹤", 140 | "欧洲水鸡", 141 | "沼泽泥母鸡", 142 | "鸨", 143 | "红翻石鹬", 144 | "红背鹬", 145 | "红脚鹬", 146 | "半蹼鹬", 147 | "蛎鹬", 148 | "鹈鹕", 149 | "国王企鹅", 150 | "信天翁", 151 | "灰鲸", 152 | "杀人鲸", 153 | "海牛", 154 | "海狮", 155 | "吉娃娃", 156 | "日本狆犬", 157 | "马尔济斯犬", 158 | "狮子狗", 159 | "西施犬", 160 | "布莱尼姆猎犬", 161 | "巴比狗", 162 | "玩具犬", 163 | "罗得西亚长背猎狗", 164 | "阿富汗猎犬", 165 | "巴吉度猎犬", 166 | "比格犬", 167 | "侦探犬", 168 | "蓝色快狗", 169 | "黑褐猎浣熊犬", 170 | "沃克猎犬", 171 | "英国猎狐犬", 172 | "美洲赤狗", 173 | "俄罗斯猎狼犬", 174 | "爱尔兰猎狼犬", 175 | "意大利灰狗", 176 | "惠比特犬", 177 | "依比沙猎犬", 178 | "挪威猎犬", 179 | "奥达猎犬", 180 | "沙克犬", 181 | "苏格兰猎鹿犬", 182 | "威玛猎犬", 183 | "斯塔福德郡斗牛犬", 184 | "美国斯塔福德郡梗", 185 | "贝德灵顿梗", 186 | "边境梗", 187 | "凯丽蓝梗", 188 | "爱尔兰梗", 189 | "诺福克梗", 190 | "诺维奇梗", 191 | "约克犬;约克夏梗犬", 192 | "刚毛猎狐梗", 193 | "莱克兰梗", 194 | "锡利哈姆梗", 195 | "艾尔谷犬", 196 | "凯恩梗", 197 | "澳大利亚梗", 198 | "丹迪丁蒙梗", 199 | "波士顿梗", 200 | "迷你雪纳瑞犬", 201 | "巨型雪纳瑞犬", 202 | "标准雪纳瑞犬", 203 | "苏格兰梗犬", 204 | "西藏梗", 205 | "丝毛梗", 206 | "爱尔兰软毛梗犬", 207 | "西高地白梗", 208 | "拉萨阿普索犬", 209 | "平毛寻回犬", 210 | "卷毛寻回犬", 211 | "金毛猎犬", 212 | "拉布拉多猎犬", 213 | "乞沙比克猎犬", 214 | "德国短毛指示犬", 215 | "维兹拉犬", 216 | "英国塞特犬", 217 | "爱尔兰雪达犬", 218 | "戈登雪达犬", 219 | "布列塔尼犬猎犬", 220 | "黄毛", 221 | "英国史宾格犬", 222 | "威尔士史宾格犬", 223 | "可卡犬", 224 | "萨塞克斯猎犬", 225 | "爱尔兰水猎犬", 226 | "哥威斯犬", 227 | "舒柏奇犬", 228 | "比利时牧羊犬", 229 | "马里努阿犬", 230 | "伯瑞犬", 231 | "凯尔皮犬", 232 | "匈牙利牧羊犬", 233 | "老英国牧羊犬", 234 | "喜乐蒂牧羊犬", 235 | "牧羊犬", 236 | "边境牧羊犬", 237 | "法兰德斯牧牛狗", 238 | "罗特韦尔犬", 239 | "德国牧羊犬", 240 | "多伯曼犬", 241 | "鹿犬;迷你杜宾犬", 242 | "大瑞士山地犬", 243 | "伯恩山犬", 244 | "阿策尔山犬", 245 | "恩特尔布赫山犬", 246 | "拳师狗", 247 | "斗牛獒", 248 | "藏獒", 249 | "法国斗牛犬", 250 | "大丹犬", 251 | "圣伯纳德狗", 252 | "爱斯基摩犬", 253 | "阿拉斯加雪橇犬", 254 | "哈士奇", 255 | "达尔马提亚", 256 | "狮毛狗", 257 | "巴辛吉狗", 258 | "八哥犬", 259 | "莱昂贝格狗", 260 | "纽芬兰犬", 261 | "大白熊犬", 262 | "萨摩耶犬", 263 | "博美犬", 264 | "松狮", 265 | "凯斯犬", 266 | "布鲁塞尔格林芬犬", 267 | "彭布洛克威尔士科基犬", 268 | "威尔士柯基犬", 269 | "玩具贵宾犬", 270 | "迷你贵宾犬", 271 | "标准贵宾犬", 272 | "墨西哥无毛犬", 273 | "灰狼", 274 | "白狼", 275 | "红太狼", 276 | "狼", 277 | "澳洲野狗", 278 | "豺", 279 | "非洲猎犬", 280 | "鬣狗", 281 | "红狐狸", 282 | "沙狐", 283 | "北极狐狸", 284 | "灰狐狸", 285 | "虎斑猫", 286 | "山猫", 287 | "波斯猫", 288 | "暹罗猫", 289 | "埃及猫", 290 | "美洲狮", 291 | "猞猁", 292 | "豹子", 293 | "雪豹", 294 | "美洲虎", 295 | "狮子", 296 | "老虎", 297 | "猎豹", 298 | "棕熊", 299 | "美洲黑熊", 300 | "冰熊", 301 | "懒熊", 302 | "獴", 303 | "猫鼬", 304 | "虎甲虫", 305 | "瓢虫", 306 | "土鳖虫", 307 | "天牛", 308 | "龟甲虫", 309 | "粪甲虫", 310 | "犀牛甲虫", 311 | "象甲", 312 | "苍蝇", 313 | "蜜蜂", 314 | "蚂蚁", 315 | "蚱蜢", 316 | "蟋蟀", 317 | "竹节虫", 318 | "蟑螂", 319 | "螳螂", 320 | "蝉", 321 | "叶蝉", 322 | "草蜻蛉", 323 | "蜻蜓", 324 | "豆娘", 325 | "优红蛱蝶", 326 | "小环蝴蝶", 327 | "君主蝴蝶", 328 | "菜粉蝶", 329 | "白蝴蝶", 330 | "灰蝶", 
331 | "海星", 332 | "海胆", 333 | "海黄瓜;海参", 334 | "野兔", 335 | "兔", 336 | "安哥拉兔", 337 | "仓鼠", 338 | "刺猬", 339 | "黑松鼠", 340 | "土拨鼠", 341 | "海狸", 342 | "豚鼠", 343 | "栗色马", 344 | "斑马", 345 | "猪", 346 | "野猪", 347 | "疣猪", 348 | "河马", 349 | "牛", 350 | "水牛", 351 | "野牛", 352 | "公羊", 353 | "大角羊", 354 | "山羊", 355 | "狷羚", 356 | "黑斑羚", 357 | "瞪羚", 358 | "阿拉伯单峰骆驼", 359 | "骆驼", 360 | "黄鼠狼", 361 | "水貂", 362 | "臭猫", 363 | "黑足鼬", 364 | "水獭", 365 | "臭鼬", 366 | "獾", 367 | "犰狳", 368 | "树懒", 369 | "猩猩", 370 | "大猩猩", 371 | "黑猩猩", 372 | "长臂猿", 373 | "合趾猿长臂猿", 374 | "长尾猴", 375 | "赤猴", 376 | "狒狒", 377 | "恒河猴", 378 | "白头叶猴", 379 | "疣猴", 380 | "长鼻猴", 381 | "狨(美洲产小型长尾猴)", 382 | "卷尾猴", 383 | "吼猴", 384 | "伶猴", 385 | "蜘蛛猴", 386 | "松鼠猴", 387 | "马达加斯加环尾狐猴", 388 | "大狐猴", 389 | "印度大象", 390 | "非洲象", 391 | "小熊猫", 392 | "大熊猫", 393 | "杖鱼", 394 | "鳗鱼", 395 | "银鲑", 396 | "三色刺蝶鱼", 397 | "海葵鱼", 398 | "鲟鱼", 399 | "雀鳝", 400 | "狮子鱼", 401 | "河豚", 402 | "算盘", 403 | "长袍", 404 | "学位袍", 405 | "手风琴", 406 | "原声吉他", 407 | "航空母舰", 408 | "客机", 409 | "飞艇", 410 | "祭坛", 411 | "救护车", 412 | "水陆两用车", 413 | "模拟时钟", 414 | "蜂房", 415 | "围裙", 416 | "垃圾桶", 417 | "攻击步枪", 418 | "背包", 419 | "面包店", 420 | "平衡木", 421 | "热气球", 422 | "圆珠笔", 423 | "创可贴", 424 | "班卓琴", 425 | "栏杆", 426 | "杠铃", 427 | "理发师的椅子", 428 | "理发店", 429 | "牲口棚", 430 | "晴雨表", 431 | "圆筒", 432 | "园地小车", 433 | "棒球", 434 | "篮球", 435 | "婴儿床", 436 | "巴松管", 437 | "游泳帽", 438 | "沐浴毛巾", 439 | "浴缸", 440 | "沙滩车", 441 | "灯塔", 442 | "烧杯", 443 | "熊皮高帽", 444 | "啤酒瓶", 445 | "啤酒杯", 446 | "钟塔", 447 | "(小儿用的)围嘴", 448 | "串联自行车", 449 | "比基尼", 450 | "装订册", 451 | "双筒望远镜", 452 | "鸟舍", 453 | "船库", 454 | "双人雪橇", 455 | "饰扣式领带", 456 | "阔边女帽", 457 | "书橱", 458 | "书店", 459 | "瓶盖", 460 | "弓箭", 461 | "蝴蝶结领结", 462 | "铜制牌位", 463 | "奶罩", 464 | "防波堤", 465 | "铠甲", 466 | "扫帚", 467 | "桶", 468 | "扣环", 469 | "防弹背心", 470 | "动车", 471 | "肉铺", 472 | "出租车", 473 | "大锅", 474 | "蜡烛", 475 | "大炮", 476 | "独木舟", 477 | "开瓶器", 478 | "开衫", 479 | "车镜", 480 | "旋转木马", 481 | "木匠的工具包", 482 | "纸箱", 483 | "车轮", 484 | "取款机", 485 | "盒式录音带", 486 | "卡带播放器", 487 | "城堡", 488 | "双体船", 489 | "CD播放器", 490 | "大提琴", 491 | "移动电话", 492 | "铁链", 493 | "围栏", 494 | "链甲", 495 | "电锯", 496 | "箱子", 497 | "梳妆台", 498 | "编钟", 499 | "中国橱柜", 500 | "圣诞袜", 501 | "教堂", 502 | "电影院", 503 | "切肉刀", 504 | "悬崖屋", 505 | "斗篷", 506 | "木屐", 507 | "鸡尾酒调酒器", 508 | "咖啡杯", 509 | "咖啡壶", 510 | "螺旋结构(楼梯)", 511 | "组合锁", 512 | "电脑键盘", 513 | "糖果", 514 | "集装箱船", 515 | "敞篷车", 516 | "瓶塞钻", 517 | "短号", 518 | "牛仔靴", 519 | "牛仔帽", 520 | "摇篮", 521 | "起重机", 522 | "头盔", 523 | "板条箱", 524 | "小儿床", 525 | "砂锅", 526 | "槌球", 527 | "拐杖", 528 | "胸甲", 529 | "大坝", 530 | "书桌", 531 | "台式电脑", 532 | "有线电话", 533 | "尿布湿", 534 | "数字时钟", 535 | "数字手表", 536 | "餐桌板", 537 | "抹布", 538 | "洗碗机", 539 | "盘式制动器", 540 | "码头", 541 | "狗拉雪橇", 542 | "圆顶", 543 | "门垫", 544 | "钻井平台", 545 | "鼓", 546 | "鼓槌", 547 | "哑铃", 548 | "荷兰烤箱", 549 | "电风扇", 550 | "电吉他", 551 | "电力机车", 552 | "组合电视柜", 553 | "信封", 554 | "浓缩咖啡机", 555 | "扑面粉", 556 | "女用长围巾", 557 | "文件", 558 | "消防船", 559 | "消防车", 560 | "火炉栏", 561 | "旗杆", 562 | "长笛", 563 | "折叠椅", 564 | "橄榄球头盔", 565 | "叉车", 566 | "喷泉", 567 | "钢笔", 568 | "有四根帷柱的床", 569 | "运货车厢", 570 | "圆号", 571 | "煎锅", 572 | "裘皮大衣", 573 | "垃圾车", 574 | "防毒面具", 575 | "汽油泵", 576 | "高脚杯", 577 | "卡丁车", 578 | "高尔夫球", 579 | "高尔夫球车", 580 | "狭长小船", 581 | "锣", 582 | "礼服", 583 | "钢琴", 584 | "温室", 585 | "散热器格栅", 586 | "杂货店", 587 | "断头台", 588 | "小发夹", 589 | "头发喷雾", 590 | "半履带装甲车", 591 | "锤子", 592 | "大篮子", 593 | "手摇鼓风机", 594 | "手提电脑", 595 | "手帕", 596 | "硬盘", 597 | "口琴", 598 | "竖琴", 599 | "收割机", 600 | "斧头", 601 | "手枪皮套", 602 | "家庭影院", 603 | "蜂窝", 604 | "钩爪", 605 
| "衬裙", 606 | "单杠", 607 | "马车", 608 | "沙漏", 609 | "iPod", 610 | "熨斗", 611 | "南瓜灯笼", 612 | "牛仔裤", 613 | "吉普车", 614 | "T恤衫", 615 | "拼图", 616 | "人力车", 617 | "操纵杆", 618 | "和服", 619 | "护膝", 620 | "蝴蝶结", 621 | "大褂", 622 | "长柄勺", 623 | "灯罩", 624 | "笔记本电脑", 625 | "割草机", 626 | "镜头盖", 627 | "开信刀", 628 | "图书馆", 629 | "救生艇", 630 | "点火器", 631 | "豪华轿车", 632 | "远洋班轮", 633 | "唇膏", 634 | "平底便鞋", 635 | "洗剂", 636 | "扬声器", 637 | "放大镜", 638 | "锯木厂", 639 | "磁罗盘", 640 | "邮袋", 641 | "信箱", 642 | "女游泳衣", 643 | "有肩带浴衣", 644 | "窨井盖", 645 | "沙球(一种打击乐器)", 646 | "马林巴木琴", 647 | "面膜", 648 | "火柴", 649 | "花柱", 650 | "迷宫", 651 | "量杯", 652 | "药箱", 653 | "巨石", 654 | "麦克风", 655 | "微波炉", 656 | "军装", 657 | "奶桶", 658 | "迷你巴士", 659 | "迷你裙", 660 | "面包车", 661 | "导弹", 662 | "连指手套", 663 | "搅拌钵", 664 | "活动房屋(由汽车拖拉的)", 665 | "T型发动机小汽车", 666 | "调制解调器", 667 | "修道院", 668 | "显示器", 669 | "电瓶车", 670 | "砂浆", 671 | "学士", 672 | "清真寺", 673 | "蚊帐", 674 | "摩托车", 675 | "山地自行车", 676 | "登山帐", 677 | "鼠标", 678 | "捕鼠器", 679 | "搬家货车", 680 | "动物的口套", 681 | "金属钉子", 682 | "颈托", 683 | "项链", 684 | "乳头(瓶)", 685 | "笔记本", 686 | "方尖碑", 687 | "双簧管", 688 | "陶笛", 689 | "里程表", 690 | "滤油器", 691 | "风琴", 692 | "示波器", 693 | "罩裙", 694 | "牛车", 695 | "氧气面罩", 696 | "包装", 697 | "船桨", 698 | "明轮", 699 | "挂锁", 700 | "画笔", 701 | "睡衣", 702 | "宫殿", 703 | "排箫", 704 | "纸巾", 705 | "降落伞", 706 | "双杠", 707 | "公园长椅", 708 | "停车收费表", 709 | "客车", 710 | "露台", 711 | "付费电话", 712 | "基座", 713 | "铅笔盒", 714 | "卷笔刀", 715 | "香水(瓶)", 716 | "培养皿", 717 | "复印机", 718 | "拨弦片", 719 | "尖顶头盔", 720 | "用尖板条连成的尖桩篱栅", 721 | "皮卡", 722 | "桥墩", 723 | "存钱罐", 724 | "药瓶", 725 | "枕头", 726 | "乒乓球", 727 | "风车", 728 | "海盗船", 729 | "水罐", 730 | "木工刨", 731 | "天文馆", 732 | "塑料袋", 733 | "板架", 734 | "犁型铲雪机", 735 | "手压皮碗泵", 736 | "宝丽来相机", 737 | "电线杆", 738 | "警车", 739 | "雨披", 740 | "台球桌", 741 | "充气饮料瓶", 742 | "花盆", 743 | "陶工旋盘", 744 | "电钻", 745 | "祈祷垫", 746 | "打印机", 747 | "监狱", 748 | "炮弹", 749 | "投影仪", 750 | "冰球", 751 | "沙包", 752 | "小钱袋;手袋", 753 | "羽管笔", 754 | "被子", 755 | "赛车", 756 | "球拍", 757 | "散热器", 758 | "收音机", 759 | "射电望远镜", 760 | "雨桶", 761 | "休闲车", 762 | "卷轴", 763 | "反射式照相机", 764 | "冰箱", 765 | "遥控器", 766 | "餐厅", 767 | "左轮手枪", 768 | "步枪", 769 | "摇椅", 770 | "电转烤肉架", 771 | "橡皮", 772 | "橄榄球", 773 | "直尺", 774 | "跑步鞋", 775 | "保险柜", 776 | "安全别针", 777 | "盐瓶(调味用)", 778 | "凉鞋", 779 | "纱笼", 780 | "萨克斯管", 781 | "剑鞘", 782 | "秤", 783 | "校车", 784 | "帆船", 785 | "记分牌", 786 | "屏幕", 787 | "螺丝", 788 | "螺丝刀", 789 | "安全带", 790 | "缝纫机", 791 | "盾牌", 792 | "皮鞋店", 793 | "障子", 794 | "购物篮", 795 | "购物车", 796 | "铁锹", 797 | "浴帽", 798 | "浴帘", 799 | "滑雪板", 800 | "滑雪面罩", 801 | "睡袋", 802 | "滑尺", 803 | "滑动门", 804 | "角子老虎机", 805 | "潜水通气管", 806 | "摩托雪橇;雪地机动车", 807 | "扫雪机", 808 | "皂液器", 809 | "足球", 810 | "袜子", 811 | "碟式太阳能", 812 | "宽边帽", 813 | "汤碗", 814 | "空格键", 815 | "空间加热器", 816 | "航天飞机", 817 | "锅铲;做饭的铲子", 818 | "快艇", 819 | "蜘蛛网", 820 | "纺锤;手纺用的绕线杆", 821 | "跑车", 822 | "聚光灯", 823 | "舞台", 824 | "蒸汽机车", 825 | "钢拱桥", 826 | "钢滚筒", 827 | "听诊器", 828 | "女用披肩", 829 | "石头墙", 830 | "秒表", 831 | "火炉", 832 | "过滤器", 833 | "有轨电车", 834 | "担架", 835 | "沙发床", 836 | "佛塔", 837 | "潜艇", 838 | "套装", 839 | "日晷", 840 | "太阳镜", 841 | "太阳镜", 842 | "防晒霜", 843 | "悬索桥", 844 | "拖把", 845 | "运动衫", 846 | "游泳裤", 847 | "秋千", 848 | "开关", 849 | "注射器;吸管", 850 | "台灯", 851 | "坦克", 852 | "录音机", 853 | "茶壶", 854 | "泰迪", 855 | "电视", 856 | "网球;打网球的球", 857 | "茅草", 858 | "幕布", 859 | "顶针", 860 | "打谷机;脱粒机", 861 | "宝座", 862 | "瓦屋顶", 863 | "烤面包机", 864 | "烟草店", 865 | "马桶", 866 | "火炬", 867 | "图腾柱", 868 | "拖车;牵引车", 869 | "玩具店", 870 | "拖拉机", 871 | "半挂汽车", 872 | "托盘", 873 | "风衣", 874 | "三轮车", 875 | "三体船", 876 | "三脚架", 877 
| "凯旋门", 878 | "无轨电车", 879 | "长号", 880 | "浴盆", 881 | "旋转式栅门", 882 | "打字机键盘", 883 | "伞", 884 | "独轮车", 885 | "直立式钢琴", 886 | "吸尘器", 887 | "花瓶;装饰瓶", 888 | "拱顶", 889 | "天鹅绒", 890 | "自动售货机", 891 | "法衣;祭衣;祭服", 892 | "高架桥", 893 | "小提琴", 894 | "排球", 895 | "松饼机", 896 | "挂钟", 897 | "钱包;钱夹", 898 | "衣柜衣橱", 899 | "军用飞机", 900 | "洗脸盆", 901 | "洗衣机", 902 | "水瓶", 903 | "水壶", 904 | "水塔", 905 | "威士忌壶", 906 | "哨子", 907 | "假发", 908 | "纱窗", 909 | "百叶窗", 910 | "温莎领带", 911 | "葡萄酒瓶", 912 | "飞机翅膀", 913 | "炒菜锅", 914 | "木勺子;木头勺子", 915 | "毛织品", 916 | "原木栅栏", 917 | "沉船", 918 | "双桅船", 919 | "蒙古包", 920 | "网站;网页", 921 | "漫画", 922 | "纵横字谜", 923 | "路标", 924 | "交通信号灯", 925 | "防尘罩", 926 | "菜单", 927 | "盘子", 928 | "墨西哥鳄梨酱;墨西哥牛油果酱", 929 | "清炖肉汤", 930 | "火锅", 931 | "乳脂蛋糕;英国甜点", 932 | "冰淇淋", 933 | "冰棍;雪糕", 934 | "法式面包", 935 | "百吉饼", 936 | "椒盐脆饼", 937 | "芝士汉堡", 938 | "热狗", 939 | "土豆泥", 940 | "结球甘蓝", 941 | "西兰花;绿菜花", 942 | "菜花;花椰菜", 943 | "西葫芦", 944 | "金丝瓜;意面南瓜;面条瓜", 945 | "绿色小南瓜;青南瓜", 946 | "南瓜", 947 | "黄瓜", 948 | "洋蓟;球蓟", 949 | "甜椒", 950 | "刺棘蓟", 951 | "蘑菇", 952 | "绿苹果", 953 | "草莓", 954 | "橘子", 955 | "柠檬", 956 | "无花果", 957 | "菠萝", 958 | "香蕉", 959 | "菠萝蜜", 960 | "番荔枝", 961 | "石榴", 962 | "干草", 963 | "培根蛋酱意大利面", 964 | "巧克力酱", 965 | "生面;面团", 966 | "瑞士肉包", 967 | "披萨", 968 | "馅饼", 969 | "卷饼", 970 | "红葡萄酒", 971 | "意式浓缩咖啡", 972 | "杯子", 973 | "蛋酒", 974 | "高山", 975 | "泡泡", 976 | "悬崖", 977 | "珊瑚礁", 978 | "间歇泉;间断喷发的温泉", 979 | "湖边", 980 | "岬角;深入海中的狭长高地", 981 | "沙洲", 982 | "沙滩", 983 | "峡谷", 984 | "火山", 985 | "棒球运动员", 986 | "新郎", 987 | "潜水员", 988 | "油菜", 989 | "雏菊", 990 | "杓兰", 991 | "玉米", 992 | "橡子", 993 | "玫瑰果", 994 | "七叶树果实", 995 | "珊瑚菌", 996 | "木耳", 997 | "鹿花菌", 998 | "臭角菇", 999 | "地星", 1000 | "多叶奇果菌", 1001 | "牛肝菌", 1002 | "玉米棒子", 1003 | "卫生纸" 1004 | ] 1005 | 1006 | 1007 | 1008 | 1009 | openai_imagenet_template = [ 1010 | lambda c: f'质量差的{c}的照片。', 1011 | lambda c: f'许多{c}的照片。', 1012 | lambda c: f'{c}的雕塑。', 1013 | lambda c: f'难以看到{c}的照片。', 1014 | lambda c: f'{c}的低分辨率照片。', 1015 | lambda c: f'{c}的渲染。', 1016 | lambda c: f'涂鸦{c}。', 1017 | lambda c: f'{c}的糟糕照片。', 1018 | lambda c: f'{c}的裁剪照片。', 1019 | lambda c: f'{c}的纹身。', 1020 | lambda c: f'{c}的刺绣照片。', 1021 | lambda c: f'很难看到{c}的照片。', 1022 | lambda c: f'{c}的明亮照片。', 1023 | lambda c: f'一张干净的{c}的照片。', 1024 | lambda c: f'一张包含{c}的照片。', 1025 | lambda c: f'{c}的深色照片。', 1026 | lambda c: f'{c}的手绘画。', 1027 | lambda c: f'我的{c}的照片。', 1028 | lambda c: f'不自然的{c}的照片。', 1029 | lambda c: f'一张酷的{c}的照片。', 1030 | lambda c: f'{c}的特写照片。', 1031 | lambda c: f'{c}的黑白照片。', 1032 | lambda c: f'一幅{c}的画。', 1033 | lambda c: f'一幅{c}的绘画。', 1034 | lambda c: f'一张{c}的像素照片。', 1035 | lambda c: f'{c}的雕像。', 1036 | lambda c: f'一张{c}的明亮照片。', 1037 | lambda c: f'{c}的裁剪照片。', 1038 | lambda c: f'人造的{c}的照片。', 1039 | lambda c: f'一张关于{c}的照片。', 1040 | lambda c: f'损坏的{c}的jpeg照片。', 1041 | lambda c: f'{c}的模糊照片。', 1042 | lambda c: f'{c}的相片。', 1043 | lambda c: f'一张{c}的好照片。', 1044 | lambda c: f'{c}的渲染照。', 1045 | lambda c: f'视频游戏中的{c}。', 1046 | lambda c: f'一张{c}的照片。', 1047 | lambda c: f'{c}的涂鸦。', 1048 | lambda c: f'{c}的近距离照片。', 1049 | lambda c: f'{c}的折纸。', 1050 | lambda c: f'{c}在视频游戏中。', 1051 | lambda c: f'{c}的草图。', 1052 | lambda c: f'{c}的涂鸦照。', 1053 | lambda c: f'{c}的折纸形状。', 1054 | lambda c: f'低分辨率的{c}的照片。', 1055 | lambda c: f'玩具{c}。', 1056 | lambda c: f'{c}的副本。', 1057 | lambda c: f'{c}的干净的照片。', 1058 | lambda c: f'一张大{c}的照片。', 1059 | lambda c: f'{c}的重现。', 1060 | lambda c: f'一张漂亮的{c}的照片。', 1061 | lambda c: f'一张奇怪的{c}的照片。', 1062 | lambda c: f'模糊的{c}的照片。', 1063 | lambda c: f'卡通{c}。', 1064 | lambda c: f'{c}的艺术作品。', 1065 | lambda c: f'{c}的素描。', 1066 | lambda c: 
f'刺绣{c}。', 1067 | lambda c: f'{c}的像素照。', 1068 | lambda c: f'{c}的拍照。', 1069 | lambda c: f'{c}的损坏的照片。', 1070 | lambda c: f'高质量的{c}的照片。', 1071 | lambda c: f'毛绒玩具{c}。', 1072 | lambda c: f'漂亮的{c}的照片。', 1073 | lambda c: f'小{c}的照片。', 1074 | lambda c: f'照片是奇怪的{c}。', 1075 | lambda c: f'漫画{c}。', 1076 | lambda c: f'{c}的艺术照。', 1077 | lambda c: f'{c}的图形。', 1078 | lambda c: f'大{c}的照片。', 1079 | lambda c: f'黑白的{c}的照片。', 1080 | lambda c: f'{c}毛绒玩具。', 1081 | lambda c: f'一张{c}的深色照片。', 1082 | lambda c: f'{c}的摄影图。', 1083 | lambda c: f'{c}的涂鸦照。', 1084 | lambda c: f'玩具形状的{c}。', 1085 | lambda c: f'拍了{c}的照片。', 1086 | lambda c: f'酷酷的{c}的照片。', 1087 | lambda c: f'照片里的小{c}。', 1088 | lambda c: f'{c}的刺青。', 1089 | ] 1090 | -------------------------------------------------------------------------------- /text-image/iterable_tar_unzip.sh: -------------------------------------------------------------------------------- 1 | # for name in `ls -d */`; 2 | # do; 3 | name="image_part12/" 4 | for i in `ls $name*.tar`; 5 | do 6 | mkdir ./project/dataset/laion_chinese_cwf/${i%.tar} 7 | tar xvf $i -C ./project/dataset/laion_chinese_cwf/${i%.tar}; 8 | done; 9 | # done 10 | 11 | 12 | -------------------------------------------------------------------------------- /text-image/save_hg_ckpt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Roberta-base 转换为hugging face版" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n", 18 | "\n", 19 | "# NOTE 使用open_clip 中文pretrain model训练的结果\n", 20 | "taiyi_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_4.pt'\n", 21 | "bertconfig = BertConfig.from_pretrained(\"hfl/chinese-roberta-wwm-ext\", num_labels=512)\n", 22 | "my_transformer = BertForSequenceClassification.from_pretrained(\"hfl/chinese-roberta-wwm-ext\", config=bertconfig)\n", 23 | "mytokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n", 24 | "\n", 25 | "# NOTE 需要改名加载\n", 26 | "state_dict_of_bert = torch.load(taiyi_path)['state_dict']\n", 27 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n", 28 | "my_transformer.load_state_dict(bert_weights)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# 同时保存模型和词表格。然后把这个上传到huggingface上面去\n", 38 | "my_transformer.save_pretrained('./CLIP-roberta')\n", 39 | "mytokenizer.save_pretrained('./CLIP-roberta')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "total = sum([param.nelement() for param in my_transformer.parameters()])\n", 49 | "print(\"Number of parameter: %.2fM\" % (total/1e6))" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Roberta-large 版" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import torch\n", 66 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n", 67 | "\n", 68 | "# NOTE 使用中文pretrain model训练的结果\n", 69 | "taiyi_path = 
'./open_clip/src/logs/2022_07_18-18_39_51-model_ViT-L-14-lr_1e-05-b_224-j_8-p_amp/checkpoints/epoch_7.pt'\n", 70 | "bertconfig = BertConfig.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\", num_labels=768)\n", 71 | "my_transformer = BertForSequenceClassification.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\", config=bertconfig)\n", 72 | "mytokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\")\n", 73 | "\n", 74 | "\n", 75 | "state_dict_of_bert = torch.load(taiyi_path)['state_dict']\n", 76 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n", 77 | "my_transformer.load_state_dict(bert_weights)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# 同时保存模型和词表格。然后把这个上传到huggingface上面去\n", 87 | "my_transformer.save_pretrained('./Taiyi-CLIP-Roberta-large-326M-Chinese')\n", 88 | "mytokenizer.save_pretrained('./Taiyi-CLIP-Roberta-large-326M-Chinese')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "total = sum([param.nelement() for param in my_transformer.parameters()])\n", 98 | "print(\"Number of parameter: %.2fM\" % (total/1e6))\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from PIL import Image\n", 108 | "import requests\n", 109 | "import clip\n", 110 | "import torch\n", 111 | "from transformers import BertForSequenceClassification, BertConfig, BertTokenizer\n", 112 | "from transformers import CLIPProcessor, CLIPModel\n", 113 | "import numpy as np\n", 114 | "\n", 115 | "query_texts = [\"一只猫\", \"一只狗\",'两只猫', '两只老虎','一只老虎'] # 这里是输入文本的,可以随意替换。\n", 116 | "# 加载Taiyi 中文 text encoder\n", 117 | "text_tokenizer = BertTokenizer.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\")\n", 118 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").eval()\n", 119 | "text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']\n", 120 | "\n", 121 | "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # 这里可以换成任意图片的url\n", 122 | "# 加载CLIP的image encoder\n", 123 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\") \n", 124 | "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n", 125 | "image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors=\"pt\")\n", 126 | "\n", 127 | "with torch.no_grad():\n", 128 | " image_features = clip_model.get_image_features(**image)\n", 129 | " text_features = text_encoder(text).logits\n", 130 | " # 归一化\n", 131 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n", 132 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n", 133 | " # 计算余弦相似度 logit_scale是尺度系数\n", 134 | " logit_scale = clip_model.logit_scale.exp()\n", 135 | " logits_per_image = logit_scale * image_features @ text_features.t()\n", 136 | " logits_per_text = logits_per_image.t()\n", 137 | " probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n", 138 | " print(np.around(probs, 3))\n" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "# ViT-H, 维度对应的Roberta-Large" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | 
"execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "import torch\n", 155 | "from transformers import BertModel, BertTokenizer\n", 156 | "\n", 157 | "# NOTE load from local path\n", 158 | "local_path = './scripts_t2i/open_clip_new/src/logs/2022_09_16-23_03_14-model_ViT-H-14-lr_5e-05-b_256-j_32-p_amp/checkpoints/epoch_21.pt'\n", 159 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext-large\").cuda().eval()\n", 160 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n", 161 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n", 162 | "text_encoder.load_state_dict(bert_weights)\n", 163 | "tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# 同时保存模型和词表格。然后把这个上传到huggingface上面去\n", 173 | "text_encoder.save_pretrained('./fengshen/Taiyi-CLIP-Roberta-326M-ViT-H-Chinese')\n", 174 | "tokenizer.save_pretrained('./fengshen/Taiyi-CLIP-Roberta-326M-ViT-H-Chinese')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "# ViT-L --- Roberta-base" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "import torch\n", 191 | "from transformers import BertModel, BertTokenizer\n", 192 | "\n", 193 | "# NOTE load from local path\n", 194 | "local_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_4.pt'\n", 195 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext\").cuda().eval()\n", 196 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n", 197 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n", 198 | "text_encoder.load_state_dict(bert_weights)\n", 199 | "tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# 同时保存模型和词表格。然后把这个上传到huggingface上面去\n", 209 | "text_encoder.save_pretrained('./project/temp_weights/vit-l-roberta-base')\n", 210 | "tokenizer.save_pretrained('./project/temp_weights/vit-l-roberta-base')" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "total = sum([param.nelement() for param in text_encoder.parameters()])\n", 220 | "print(\"Number of parameter: %.2fM\" % (total/1e6))" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "Python 3.9.13 ('base')", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.9.13" 241 | }, 242 | "orig_nbformat": 4, 243 | "vscode": { 244 | "interpreter": { 245 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa" 246 | } 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | 
-------------------------------------------------------------------------------- /text-image/zeroshot_retrieval_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "def get_metrics(image_features, text_features, labels, logit_scale):\n", 10 | " # 计算相似度,支持多个样本的情况(比如一个图片有多个caption)\n", 11 | " # img2txt计算的时候要用到,因为一张图片可能对应多个文本。\n", 12 | " # txt2img计算的时候不需要(一般一个text只有一个对应图片)\n", 13 | " metrics = {}\n", 14 | " logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()\n", 15 | " logits_per_text = logits_per_image.t().detach().cpu()\n", 16 | "\n", 17 | " logits = {\"image_to_text\": logits_per_image, \"text_to_image\": logits_per_text}\n", 18 | "\n", 19 | " label2idx = {} # 计算label到idx的映射。\n", 20 | " repeat_id = []\n", 21 | " for i, label in enumerate(labels):\n", 22 | " if label not in label2idx:\n", 23 | " label2idx[label] = [i]\n", 24 | " else:\n", 25 | " # 表示该index的标签出现过,记录这个index,后续算txt2img分数的时候,这些index的权值要降低。\n", 26 | " label2idx[label].append(i)\n", 27 | " repeat_id.append(i)\n", 28 | " # print(label2idx) # 标注了每个label的idx\n", 29 | "\n", 30 | " # print('repeat_id:', repeat_id)\n", 31 | " ground_truth = [label2idx[label] for label in labels]\n", 32 | " # print(ground_truth)\n", 33 | "\n", 34 | " for name, logit in logits.items():\n", 35 | " # print(name, logit.shape)\n", 36 | " if name == 'text_to_image':\n", 37 | " logit[:, repeat_id] -= 1e8 # 这部分的分数要降低。(重复出现的图片,直接忽略)\n", 38 | " r1_stat, r5_stat, r10_stat = [], [], []\n", 39 | " ranking = torch.argsort(logit, descending=True) # index of the largest element to the smallest\n", 40 | " # print(name, ranking[:, :10])\n", 41 | " for i, each_query in enumerate(ranking[:, :10]):\n", 42 | " for j, q in enumerate(each_query):\n", 43 | " if q in ground_truth[i]:\n", 44 | " if j == 0:\n", 45 | " r1_stat.append(1)\n", 46 | " r5_stat.append(1)\n", 47 | " r10_stat.append(1)\n", 48 | " break\n", 49 | " if j < 5:\n", 50 | " r5_stat.append(1)\n", 51 | " r10_stat.append(1)\n", 52 | " break\n", 53 | " if j < 10:\n", 54 | " r10_stat.append(1)\n", 55 | " break\n", 56 | " print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')\n" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "# COCO-CN" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "from transformers import BertTokenizer\n", 73 | "from torch.utils.data import Dataset\n", 74 | "from torch.utils.data import DataLoader\n", 75 | "from PIL import Image\n", 76 | "class COCO_CN(Dataset):\n", 77 | " def __init__(self, img_root_path='./dataset/coco', \\\n", 78 | " test_img_path='./dataset/coco/coco-cn-version1805v1.1/coco-cn_test.txt', \\\n", 79 | " annot_path = './dataset/coco/coco-cn-version1805v1.1/imageid.human-written-caption.txt', \\\n", 80 | " transform=None):\n", 81 | " self.images = []\n", 82 | " self.captions = []\n", 83 | " self.labels = []\n", 84 | " self.root = img_root_path\n", 85 | " \n", 86 | " test_path = dict()\n", 87 | " with open(test_img_path, 'r') as f:\n", 88 | " for line in f:\n", 89 | " line = line.strip()\n", 90 | " if line not in test_path:\n", 91 | " test_path[line] = 1\n", 92 | " # print(test_path)\n", 93 | "\n", 94 | " with open(annot_path, 'r') as f:\n", 95 | " for line in f:\n", 96 | " 
line = line.strip().split('\\t')\n", 97 | " key, caption = line[0].split('#')[0], line[1]\n", 98 | " # NOTE 只保留test set的\n", 99 | " if key not in test_path:\n", 100 | " continue\n", 101 | " # if line[0].split('#')[-1] != '0':\n", 102 | " # # print(key, line[0].split('#')[-1])\n", 103 | " # continue # 只保留一句\n", 104 | " img_path = key + '.jpg'\n", 105 | "\n", 106 | " if 'train' in img_path:\n", 107 | " self.images.append(os.path.join('train2014' ,img_path) )\n", 108 | " else:\n", 109 | " self.images.append(os.path.join('val2014' ,img_path) )\n", 110 | " self.captions.append(caption)\n", 111 | " self.labels.append(key)\n", 112 | " self.transforms = transform\n", 113 | " self.tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n", 114 | "\n", 115 | " # NOTE large 模型\n", 116 | " self.context_length = 77\n", 117 | "\n", 118 | " def __len__(self):\n", 119 | " return len(self.images)\n", 120 | "\n", 121 | " def __getitem__(self, idx):\n", 122 | " img_path = str(self.images[idx])\n", 123 | " image = self.transforms(Image.open( os.path.join(self.root, img_path ))) \n", 124 | " text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]\n", 125 | " label = self.labels[idx]\n", 126 | " return image, text, label\n", 127 | "\n", 128 | "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n", 129 | " CenterCrop\n", 130 | "def _convert_to_rgb(image):\n", 131 | " return image.convert('RGB')\n", 132 | "\n", 133 | "def image_transform(\n", 134 | " image_size: int,\n", 135 | " is_train: bool,\n", 136 | " mean=(0.48145466, 0.4578275, 0.40821073),\n", 137 | " std=(0.26862954, 0.26130258, 0.27577711)\n", 138 | "):\n", 139 | " normalize = Normalize(mean=mean, std=std)\n", 140 | " if is_train:\n", 141 | " return Compose([\n", 142 | " RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),\n", 143 | " _convert_to_rgb,\n", 144 | " ToTensor(),\n", 145 | " normalize,\n", 146 | " ])\n", 147 | " else:\n", 148 | " return Compose([\n", 149 | " Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n", 150 | " CenterCrop(image_size),\n", 151 | " _convert_to_rgb,\n", 152 | " ToTensor(),\n", 153 | " normalize,\n", 154 | " ])\n", 155 | "\n", 156 | "val_transform = image_transform(224, False)\n", 157 | "dataset = COCO_CN(transform = val_transform)\n", 158 | "dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "len(dataset)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "from transformers import BertTokenizer\n", 177 | "from transformers import BertForSequenceClassification\n", 178 | "from transformers import CLIPModel\n", 179 | "import torch\n", 180 | "# NOTE load model\n", 181 | "\n", 182 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese\").cuda().eval()\n", 183 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").cuda().eval() \n", 184 | "\n", 185 | "# text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").cuda().eval()\n", 186 | "# clip_model = 
CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").cuda().eval() \n", 187 | "\n", 188 | "\n", 189 | "all_img_features, all_text_features, all_labels = [], [], []\n", 190 | "with torch.no_grad():\n", 191 | " for i, data in enumerate(dataloader):\n", 192 | " images, captions, labels = data\n", 193 | " images = images.cuda()\n", 194 | " captions = captions.cuda()\n", 195 | " all_labels.extend(labels)\n", 196 | " # print(images.shape, captions.shape, labels)\n", 197 | "\n", 198 | " image_features = clip_model.get_image_features(images)\n", 199 | " text_features = text_encoder(captions).logits\n", 200 | " # 归一化\n", 201 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n", 202 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n", 203 | " all_img_features.append(image_features)\n", 204 | " all_text_features.append(text_features)\n", 205 | " # if i == 10:\n", 206 | " # break\n", 207 | " img_features = torch.cat(all_img_features)\n", 208 | " text_features = torch.cat(all_text_features)\n", 209 | " print(img_features.shape, text_features.shape, len(all_labels))\n", 210 | "get_metrics(img_features, text_features, all_labels, 100) " 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "# flickr30k-CNA" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "from transformers import BertTokenizer\n", 227 | "from torch.utils.data import Dataset\n", 228 | "from torch.utils.data import DataLoader\n", 229 | "from PIL import Image\n", 230 | "class flickr30k_CNA(Dataset):\n", 231 | " def __init__(self, img_root_path='./dataset/mm_data/Flickr30k-CNA/flickr30k/images', \\\n", 232 | " text_annot_path='./dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt', \\\n", 233 | " transform=None):\n", 234 | " self.images = []\n", 235 | " self.captions = []\n", 236 | " self.labels = []\n", 237 | " self.root = img_root_path\n", 238 | " with open(text_annot_path, 'r') as f:\n", 239 | " for line in f:\n", 240 | " line = line.strip().split('\\t')\n", 241 | " key, caption = line[0].split('#')[0], line[1]\n", 242 | " img_path = key + '.jpg'\n", 243 | " self.images.append(img_path)\n", 244 | " self.captions.append(caption)\n", 245 | " self.labels.append(key)\n", 246 | " self.transforms = transform\n", 247 | " self.tokenizer = BertTokenizer.from_pretrained(\"hfl/chinese-roberta-wwm-ext\")\n", 248 | "\n", 249 | " # NOTE large 模型\n", 250 | " self.context_length = 77\n", 251 | "\n", 252 | " def __len__(self):\n", 253 | " return len(self.images)\n", 254 | "\n", 255 | " def __getitem__(self, idx):\n", 256 | " img_path = str(self.images[idx])\n", 257 | " image = self.transforms(Image.open( os.path.join(self.root, img_path ))) \n", 258 | " text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]\n", 259 | " label = self.labels[idx]\n", 260 | " return image, text, label\n", 261 | "\n", 262 | "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n", 263 | " CenterCrop\n", 264 | "def _convert_to_rgb(image):\n", 265 | " return image.convert('RGB')\n", 266 | "\n", 267 | "def image_transform(\n", 268 | " image_size: int,\n", 269 | " is_train: bool,\n", 270 | " mean=(0.48145466, 0.4578275, 0.40821073),\n", 271 | " std=(0.26862954, 0.26130258, 0.27577711)\n", 272 | "):\n", 273 | 
" normalize = Normalize(mean=mean, std=std)\n", 274 | " if is_train:\n", 275 | " return Compose([\n", 276 | " RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),\n", 277 | " _convert_to_rgb,\n", 278 | " ToTensor(),\n", 279 | " normalize,\n", 280 | " ])\n", 281 | " else:\n", 282 | " return Compose([\n", 283 | " Resize(image_size, interpolation=InterpolationMode.BICUBIC),\n", 284 | " CenterCrop(image_size),\n", 285 | " _convert_to_rgb,\n", 286 | " ToTensor(),\n", 287 | " normalize,\n", 288 | " ])\n", 289 | "\n", 290 | "val_transform = image_transform(224, False)\n", 291 | "img_root = '/dataset/mm_data/Flickr30k-CNA/flickr30k/images'\n", 292 | "text_annot_path = './dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt'\n", 293 | "dataset = flickr30k_CNA(img_root, text_annot_path, val_transform)\n", 294 | "dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "from transformers import BertTokenizer\n", 304 | "from transformers import BertForSequenceClassification\n", 305 | "from transformers import CLIPModel\n", 306 | "import torch\n", 307 | "# text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese\").cuda().eval()\n", 308 | "# clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").cuda().eval() \n", 309 | "\n", 310 | "# NOTE large\n", 311 | "text_encoder = BertForSequenceClassification.from_pretrained(\"IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese\").cuda().eval()\n", 312 | "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").cuda().eval() " 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "all_img_features, all_text_features, all_labels = [], [], []\n", 322 | "with torch.no_grad():\n", 323 | " for i, data in enumerate(dataloader):\n", 324 | " images, captions, labels = data\n", 325 | " images = images.cuda()\n", 326 | " captions = captions.cuda()\n", 327 | " all_labels.extend(labels)\n", 328 | " # print(images.shape, captions.shape, labels)\n", 329 | "\n", 330 | " image_features = clip_model.get_image_features(images)\n", 331 | " text_features = text_encoder(captions).logits\n", 332 | " # 归一化\n", 333 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n", 334 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n", 335 | " all_img_features.append(image_features)\n", 336 | " all_text_features.append(text_features)\n", 337 | " # if i == 10:\n", 338 | " # break\n", 339 | " img_features = torch.cat(all_img_features)\n", 340 | " text_features = torch.cat(all_text_features)\n", 341 | " print(img_features.shape, text_features.shape, len(all_labels))" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "get_metrics(img_features, text_features, all_labels, 100) # 图片取前1000张,因为后面的是重复的(每张图片对应5个caption)Flickr" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "# non-classification" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "# NOTE load from local path\n", 367 | "from transformers import 
BertModel\n", 368 | "local_path = './project/open_clip_new/src/logs/2022_11_18-21_15_16-model_ViT-L-14-lr_0.0001-b_640-j_16-p_amp/checkpoints/epoch_5.pt'\n", 369 | "text_encoder = BertModel.from_pretrained(\"hfl/chinese-roberta-wwm-ext\").cuda().eval()\n", 370 | "state_dict_of_bert = torch.load(local_path)['state_dict']\n", 371 | "bert_weights = {k.replace('module.transformer.',''):v for k,v in state_dict_of_bert.items() if 'module.transformer' in k}\n", 372 | "text_encoder.load_state_dict(bert_weights)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "all_img_features, all_text_features, all_labels = [], [], []\n", 382 | "with torch.no_grad():\n", 383 | " for i, data in enumerate(dataloader):\n", 384 | " images, captions, labels = data\n", 385 | " images = images.cuda()\n", 386 | " captions = captions.cuda()\n", 387 | " all_labels.extend(labels)\n", 388 | " # print(images.shape, captions.shape, labels)\n", 389 | "\n", 390 | " image_features = clip_model.get_image_features(images)\n", 391 | " text_features = text_encoder(captions)[1]\n", 392 | " # 归一化\n", 393 | " image_features = image_features / image_features.norm(dim=1, keepdim=True)\n", 394 | " text_features = text_features / text_features.norm(dim=1, keepdim=True)\n", 395 | " all_img_features.append(image_features)\n", 396 | " all_text_features.append(text_features)\n", 397 | " # if i == 10:\n", 398 | " # break\n", 399 | " img_features = torch.cat(all_img_features)\n", 400 | " text_features = torch.cat(all_text_features)\n", 401 | " print(img_features.shape, text_features.shape, len(all_labels))\n", 402 | "get_metrics(img_features, text_features, all_labels, 100) # 图片取前1000张,因为后面的是重复的(每张图片对应5个caption)Flickr" 403 | ] 404 | } 405 | ], 406 | "metadata": { 407 | "kernelspec": { 408 | "display_name": "Python 3.9.13 ('base')", 409 | "language": "python", 410 | "name": "python3" 411 | }, 412 | "language_info": { 413 | "codemirror_mode": { 414 | "name": "ipython", 415 | "version": 3 416 | }, 417 | "file_extension": ".py", 418 | "mimetype": "text/x-python", 419 | "name": "python", 420 | "nbconvert_exporter": "python", 421 | "pygments_lexer": "ipython3", 422 | "version": "3.9.13" 423 | }, 424 | "orig_nbformat": 4, 425 | "vscode": { 426 | "interpreter": { 427 | "hash": "4cc247672a8bfe61dc951074f9ca89ab002dc0f7e14586a8bb0828228bebeefa" 428 | } 429 | } 430 | }, 431 | "nbformat": 4, 432 | "nbformat_minor": 2 433 | } 434 | --------------------------------------------------------------------------------
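Two small points about the retrieval notebook above. First, the dataset cells call os.path.join and the first cell's get_metrics calls torch.argsort, so import os and import torch must already be in scope when those cells are executed. Second, the text feature is read in two different ways depending on how the checkpoint was exported: the released Taiyi-CLIP checkpoints are BertForSequenceClassification models whose CLIP-aligned projection comes out as .logits, while the "non-classification" path loads a plain BertModel, overwrites its weights with the local open_clip checkpoint, and takes element [1] of the forward output (the pooled feature). A minimal side-by-side sketch, assuming only the model names already used in this notebook:

import torch
from transformers import BertTokenizer, BertModel, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
ids = tokenizer(["一只猫"], return_tensors="pt", padding=True)["input_ids"]

# (a) Sequence-classification export: the projected text feature is the `logits` field.
taiyi = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese").eval()
with torch.no_grad():
    text_feat_a = taiyi(ids).logits            # e.g. [1, 512], matching CLIP ViT-B/32 image features

# (b) Plain BertModel: the text feature is the pooled output, index 1 of the model output.
#     Only meaningful after loading the fine-tuned open_clip weights, as the notebook does above.
bert = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").eval()
with torch.no_grad():
    text_feat_b = bert(ids)[1]                 # pooler_output, e.g. [1, 768]

# Either variant is L2-normalized before the cosine-similarity / recall computation in get_metrics.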