├── charset
│   ├── FS16.txt
│   ├── demo_characters.txt
│   ├── check_overlap.py
│   ├── get_remain.py
│   ├── TRAIN800.txt
│   ├── TEST5646.txt
│   └── GB2312_CN6763.txt
├── .gitignore
├── functions
│   ├── __init__.py
│   └── modulated_deform_conv_func.py
├── modules
│   ├── __init__.py
│   └── modulated_deform_conv.py
├── requirements.txt
├── scripts
│   ├── 03b_cluster_get_cf_basis.sh
│   ├── 03c_copy_basis_subset.sh
│   ├── 03e_init_cf_env.sh
│   ├── 02a_run_ddp.sh
│   ├── 01a_gen_date.sh
│   ├── 04c_cal_mean_scores.sh
│   ├── 04b_get_scores.sh
│   ├── 01b_copy_subset.sh
│   ├── basis
│   │   ├── copy_basis_imgs.py
│   │   ├── select_base_font.py
│   │   └── get_basis_simple.py
│   ├── 03a_get_content_embeddings.sh
│   ├── visualization
│   │   ├── format_data.py
│   │   ├── viz.py
│   │   ├── plot_nn.py
│   │   ├── plot_nn_pair.py
│   │   └── make_sprite.py
│   ├── 03f_run_ddp_cf.sh
│   ├── data_preparation
│   │   └── gen_subset.py
│   ├── 03d_cal_cf_weights.sh
│   ├── option_run_inf_dgfont.sh
│   └── 04a_run_inf_cf.sh
├── src
│   ├── vision.cpp
│   ├── cpu
│   │   ├── deform_conv_cpu.h
│   │   ├── deform_psroi_pooling_cpu.h
│   │   ├── modulated_deform_conv_cpu.h
│   │   ├── deform_cpu.cpp
│   │   ├── deform_psroi_pooling_cpu.cpp
│   │   └── modulated_deform_cpu.cpp
│   ├── cuda
│   │   ├── deform_conv_cuda.h
│   │   ├── deform_psroi_pooling_cuda.h
│   │   └── modulated_deform_conv_cuda.h
│   ├── deform_conv.h
│   ├── modulated_deform_conv.h
│   └── deform_psroi_pooling.h
├── eval
│   ├── lpips_2imgs.py
│   ├── lpips_2dirs.py
│   ├── cal_mean.py
│   ├── eval_utils.py
│   ├── get_scores.py
│   ├── get_scores_test.py
│   └── eval_2dirs.py
├── tools
│   ├── phl.py
│   ├── pkl.py
│   ├── wdl.py
│   ├── label_smooth.py
│   ├── utils.py
│   ├── hsic.py
│   ├── ops.py
│   └── abl_allinone.py
├── oss_client.py
├── models
│   ├── guidingNet.py
│   ├── discriminator.py
│   └── blocks.py
├── font2img.py
├── train
│   ├── train_style_vec.py
│   └── train.py
├── validation
│   ├── validation.py
│   └── validation_cf.py
├── README.md
├── datasets
│   └── datasetgetter.py
├── cal_cf_weight.py
└── collect_content_embeddings.py

/charset/FS16.txt:
--------------------------------------------------------------------------------
歌袁战余造限汽沪频距防封界蚊乃难
--------------------------------------------------------------------------------
/charset/demo_characters.txt:
--------------------------------------------------------------------------------
苟利国家生死以岂因祸福避趋之
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/
output/
/basis/

__pycache__/
*.pyc
--------------------------------------------------------------------------------
/functions/__init__.py:
--------------------------------------------------------------------------------
from .modulated_deform_conv_func import ModulatedDeformConvFunction
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
from .modulated_deform_conv import ModulatedDeformConv, _ModulatedDeformConv, ModulatedDeformConvPack
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.0
tqdm
numpy
opencv-python
scipy
matplotlib
pillow
tensorboardX
scikit-image
scikit-learn
pytorch-fid
lpips
pandas
kornia
--------------------------------------------------------------------------------
/scripts/03b_cluster_get_cf_basis.sh:
--------------------------------------------------------------------------------
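# Cluster the content embeddings collected in step 03a and select n_basis
# representative fonts as the content-fusion basis. get_basis_simple.py writes
# the chosen font ids to basis/<model>_<item>_basis_240_id_<n_basis>.txt,
# which 03c (image copying) and 03f (CF training) read back.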
n_basis=10
model_name=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306
item=180
content_fm=output/embeddings/embedding_${model_name}_${item}/c_src.pth

CUDA_VISIBLE_DEVICES=0 python scripts/basis/get_basis_simple.py \
    -c ${content_fm} -lbs 10 -nb ${n_basis} -m ${model_name}_${item}
--------------------------------------------------------------------------------
/scripts/03c_copy_basis_subset.sh:
--------------------------------------------------------------------------------
base_n=10

for FLAG in TRAIN800 TEST5646 FS16
do
    basis=basis/B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_basis_240_id_${base_n}.txt
    in_folder=data/imgs/Seen240_S80F50_${FLAG}
    out_folder=data/imgs/BASIS_S80F50_${FLAG}

    python scripts/basis/copy_basis_imgs.py -b ${basis} -i ${in_folder} -o ${out_folder}
done
--------------------------------------------------------------------------------
/scripts/03e_init_cf_env.sh:
--------------------------------------------------------------------------------
set -x
item=180
model_base_folder=output/models/logs
model=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306
model_cf=CF_from_${model}_${item}
ckpt=model_${item}.ckpt

mkdir -p ${model_base_folder}/${model_cf}
cp ${model_base_folder}/${model}/${ckpt} ${model_base_folder}/${model_cf}/
echo ${ckpt} > ${model_base_folder}/${model_cf}/checkpoint.txt
--------------------------------------------------------------------------------
/charset/check_overlap.py:
--------------------------------------------------------------------------------
import click

@click.command()
@click.argument('f1', type=click.File('r'))
@click.argument('f2', type=click.File('r'))
def check(f1, f2):
    t1 = set(f1.readlines()[0])
    t2 = set(f2.readlines()[0])
    print(len(t1), len(t2))
    print('Overlap:', len(t1 & t2))

if __name__ == "__main__":
    check()
--------------------------------------------------------------------------------
/scripts/02a_run_ddp.sh:
--------------------------------------------------------------------------------
mkdir -p output/models

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nproc_per_node=2 --use_env main.py \
    --img_size 80 \
    --data_path data/imgs/Seen240_S80F50_TRAIN800 \
    --lr 1e-4 \
    --output_k 240 \
    --batch_size 16 \
    --iters 1000 \
    --epoch 200 \
    --val_num 10 \
    --baseline_idx 0 \
    --save_path output/models \
    --model_name B0_K240BS32I1000E200_LR1e-4-wdl0.01 \
    --ddp \
    --wdl --w_wdl 0.01 \
    --no_val
    # --load_model CF-Font/output/models/logs/B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306
--------------------------------------------------------------------------------
/scripts/01a_gen_date.sh:
--------------------------------------------------------------------------------
img_size=80    # 128
chara_size=50  # 80
chara=charset/GB2312_CN6763.txt  # charset/test.txt

font_basefolder=data/fonts
out_basefolder=data/imgs
mkdir -p $out_basefolder

for font_set in Seen240 Unseen60
do
    font_folder=${font_basefolder}/Font_$font_set
    out_folder=${out_basefolder}/${font_set}_S${img_size}F${chara_size}_FULL
    mkdir $out_folder

    python font2img.py --ttf_path $font_folder \
        --img_size $img_size \
        --chara_size $chara_size \
        --chara $chara \
        --save_path $out_folder
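    # font2img.py renders every character in $chara with each font file in
    # $font_folder, producing one id_<k> folder per font with one image per
    # character (0000.png, 0001.png, ... in charset order).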
done
--------------------------------------------------------------------------------
/scripts/04c_cal_mean_scores.sh:
--------------------------------------------------------------------------------
set -x
gid=${1:-"0"}

for option in seen unseen
do

    if [ $option == 'seen' ]; then
        rst_name=CF_from_B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_200_top-1_ft10_wdl0.01_lr0.01
        font_len=240
    elif [ $option == 'unseen' ]; then
        rst_name=unseen_CF_from_B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_200_top-1_ft10_wdl0.01_lr0.01
        font_len=60
    fi
    pred_path=output/test_rsts/${rst_name}
    CUDA_VISIBLE_DEVICES=${gid} python eval/cal_mean.py \
        -f ${pred_path}/a_scores/ \
        -k ${font_len}
        # -j 1 3 4  ## optionally skip some font ids, e.g. the basis fonts
done
--------------------------------------------------------------------------------
/scripts/04b_get_scores.sh:
--------------------------------------------------------------------------------
set -x
gid=${1:-"0"}

for option in seen unseen
do

    if [ $option == 'seen' ]; then
        rst_name=CF_from_B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_200_top-1_ft10_wdl0.01_lr0.01
        gt_path=data/imgs/Seen240_S80F50_TEST5646
    elif [ $option == 'unseen' ]; then
        rst_name=unseen_CF_from_B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_200_top-1_ft10_wdl0.01_lr0.01
        gt_path=data/imgs/Unseen60_S80F50_TEST5646
    fi
    pred_path=output/test_rsts/${rst_name}
    CUDA_VISIBLE_DEVICES=${gid} python eval/get_scores_test.py \
        -gt ${gt_path} \
        -pred ${pred_path} \
        --gpu
        # -m l1
done
--------------------------------------------------------------------------------
/charset/get_remain.py:
--------------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-f', '--full', type=str, required=True, help='Full chars')
parser.add_argument('-t', '--train', type=str, required=True, help='Train chars')
parser.add_argument('-o', '--out', type=str, required=True, help='save chars')
args = parser.parse_args()

with open(args.full, 'r') as f:
    chars_f = f.readline()
with open(args.train, 'r') as f:
    chars_t = f.readline()

chars_f = set(chars_f)
chars_t = set(chars_t)
chars_o = sorted(list(chars_f - chars_t))
print(chars_o, len(chars_o))

with open(args.out, 'w') as f:
    f.write(''.join(chars_o))
--------------------------------------------------------------------------------
/src/vision.cpp:
--------------------------------------------------------------------------------

#include "deform_psroi_pooling.h"
#include "deform_conv.h"
#include "modulated_deform_conv.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
  m.def("deform_conv_backward", &deform_conv_backward, "deform_conv_backward");
  m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward");
  m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward");
  m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward");
  m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward");
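  // These bindings are built into the `DCN` extension module that
  // functions/modulated_deform_conv_func.py imports (see `import DCN` there).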
| } 14 | -------------------------------------------------------------------------------- /scripts/01b_copy_subset.sh: -------------------------------------------------------------------------------- 1 | for font_set in Seen240_S80F50 Unseen60_S80F50 2 | do 3 | full_fp=data/imgs/${font_set}_FULL 4 | train_fp=data/imgs/${font_set}_TRAIN800 5 | test_fp=data/imgs/${font_set}_TEST5646 6 | fewshot_fp=data/imgs/${font_set}_FS16 7 | 8 | full_ch=charset/GB2312_CN6763.txt 9 | train_ch=charset/TRAIN800.txt 10 | test_ch=charset/TEST5646.txt 11 | fewshot_ch=charset/FS16.txt # in train_ch 12 | 13 | python scripts/data_preparation/gen_subset.py -i ${full_fp} -o ${train_fp} -ic ${full_ch} -oc ${train_ch} 14 | python scripts/data_preparation/gen_subset.py -i ${full_fp} -o ${test_fp} -ic ${full_ch} -oc ${test_ch} 15 | python scripts/data_preparation/gen_subset.py -i ${full_fp} -o ${fewshot_fp} -ic ${full_ch} -oc ${fewshot_ch} 16 | done -------------------------------------------------------------------------------- /scripts/basis/copy_basis_imgs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import shutil 4 | import os 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-i', '--input', type=str, required=True) 8 | parser.add_argument('-o', '--output', type=str, required=True) 9 | parser.add_argument('-b', '--basis', type=str, required=True) 10 | 11 | args = parser.parse_args() 12 | 13 | i = args.input 14 | o = args.output 15 | # basis = np.load(args.basis) 16 | basis = np.loadtxt(args.basis, dtype='int') 17 | 18 | if not os.path.exists(o): os.mkdir(o) 19 | 20 | # check 21 | print(basis) 22 | 23 | for ido, idi in enumerate(basis): 24 | src = os.path.join(i, f"id_{idi}") 25 | dst = os.path.join(o, f"id_{ido}") 26 | print(src, '-->', dst) 27 | shutil.copytree(src, dst) 28 | -------------------------------------------------------------------------------- /charset/TRAIN800.txt: -------------------------------------------------------------------------------- 1 | 歌袁战余造限汽沪频距防封界蚊乃难积迈台饥信料移细认势抗术钱圆良八银响丹亲店朽款杆佳超唱峦群士护端沾若按曾峰倍花凯异范黑货富贬皇芥型岔鲁觉受购粪选环沉模沿处志径列压杜石另停旷解乍刚排右级普长杯盾至空冲斗叮尼并兵城金叔妈例丐肥武捆虾屿朔协视伪再踪架巴岭裹奏消器往枕权直双查卷谈缺除掌劳科操露味革贞爱升丑言蚂钢枝齿率注伟财营轮思沧房投保狱丧斯研食秆住咏还衣旺阳布码存药村针随本身院白杂伙粉画回齐板被先虽材行略菜谁刺友秋伞船卖伯芦省构易丝况库批次建春蚀宣则记楚特养沫令族杀维乎击捡丢剧征格千胡酸师贤脱孙怎副蚁胶审肉溶越宁游述球罪缓丽危改虏确六规品色乱远仗继忠团任愿含专卧接索周功茫套奔枉迅沼价务青创侵握伤什知暗染棉提持层半显源王万静荷何感乙致肾争亚际跑失景统乡治基永根类讳快胞奇清鸡祖案粒简刘织绝轻语企供尚慢位音课朵宽太岩沽期钓波紧渐围昌重阿艺触酒单逐拿切流费庐江老镇守官咐迫枚病队训煤穿俊每瓶铁党置厚贝娘整曲杠早配顺才促似陆伏即沙乏福足光决犯七察映丰希判侍仲鱼乓纸笑入茶活死仿艇姜钳粮丘氢兴更弱乐降考读张庆包府丛今风夜射撤柔支晶仓印市岳管轰百领香绿常厂续火苗欲首眼将量善够皮担走呀庄鸣己急岸世矿口吗放极讨买九仅植给苦杰典伦尝议占华海否亿帝落效举助驼写演钞岛图强荣互汉据灯稳尔然转待湖冈笔龙垂求母门树害跳差资柏承地斤组状减英居标商阻荧红参识桥软五贩短计检哪仁章班职育丙诗县片帮岖从推枪振克告济沦息贵亮东区报久试叶芝罗服名涉号欧析液座跟盐容划业氧混纪书订坚影取河引共式热输客敌李剂称办较带攻导掉卫诵井刮著沮验初神丸脊补帜培父角题某搞木讥旬抓聚指央调干田盟凭独找刑胜观农鞭乔车裂依获密假换免绍广乒滑需厨便土吸宜通州序延近集照洋烈修雷束连南杼始伊柿岗甲垄叫宗古莉马执庙贡望安别终黄见坑真矛俘温局充席耐许元站玉果增虎山策庇速坐坏固女川答奋轴乌筛床尽岁众饺草休路季妻破装证洲态故她严京伍念声运牧拉做施控末字附交夫牛呼室 -------------------------------------------------------------------------------- /scripts/03a_get_content_embeddings.sh: -------------------------------------------------------------------------------- 1 | k=240 2 | item=180 3 | img_size=80 4 | model_base=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306 5 | model=output/models/logs/${model_base}/model_${item}.ckpt 6 | 7 | # data=data/imgs/Seen240_S80F50_TRAIN800 8 | # !!! The more characters used to cluster, the better. 9 | # However, since the memory limitation, you can also use a subset of data_K240_S80F50_TRAIN800 to cluster. 10 | # like random choose 50 characters or simply use few-shot 16 characters. 
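# A sketch of one way to build such a random subset (charset/RAND50.txt and the
# RAND50 folder names are hypothetical; gen_subset.py and its flags are the
# ones shipped in this repo):
#   python -c "import random; cs = open('charset/TRAIN800.txt').read().strip(); \
#              open('charset/RAND50.txt', 'w').write(''.join(random.sample(list(cs), 50)))"
#   python scripts/data_preparation/gen_subset.py -i data/imgs/Seen240_S80F50_TRAIN800 \
#       -o data/imgs/Seen240_S80F50_RAND50 -ic charset/TRAIN800.txt -oc charset/RAND50.txt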
11 | data=data/imgs/Seen240_S80F50_FS16 12 | 13 | CUDA_VISIBLE_DEVICES=0 python collect_content_embeddings.py --img_size ${img_size} \ 14 | --data_path ${data} \ 15 | --output_k ${k} \ 16 | --batch_size 32 \ 17 | --load_model ${model} \ 18 | --save_path output/embeddings/embedding_${model_base}_${item} \ 19 | --baseline_idx 0 \ 20 | --n_atts ${k} \ 21 | --no_skip -------------------------------------------------------------------------------- /scripts/basis/select_base_font.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import shutil 4 | import os 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-i', '--input', type=str, required=True) 8 | parser.add_argument('-o', '--output', type=str, required=True) 9 | parser.add_argument('-b', '--basis', type=str, required=True) 10 | 11 | args = parser.parse_args() 12 | 13 | i = args.input 14 | o = args.output 15 | basis = np.load(args.basis) 16 | 17 | if not os.path.exists(o): os.mkdir(o) 18 | fns = sorted(os.listdir(i)) 19 | valid_ends = ['ttf', 'otf', 'ttc'] 20 | checker = lambda i: any([i.lower().endswith(end) for end in valid_ends]) 21 | fns = np.array([fn for fn in fns if checker(fn)]) 22 | 23 | print(len(fns)) 24 | print(basis) 25 | print(fns[basis]) 26 | 27 | for fn in fns[basis]: 28 | shutil.copyfile(os.path.join(i, fn), os.path.join(o, fn)) -------------------------------------------------------------------------------- /scripts/visualization/format_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | import glob 6 | from tqdm import tqdm 7 | 8 | # n = 400 9 | # path = 'data_400_hard_50' 10 | # out_p = 'data_400_hard_50_2x2_format' 11 | 12 | n = 221 13 | path = '../../dataset/Data_K221/data_K221_S128F80_Train800_Base50/' 14 | out_p = 'data_221_S128F80_Base50_2x2_format' 15 | 16 | os.makedirs(out_p, exist_ok=True) 17 | 18 | for i in tqdm(range(n)): 19 | folder = os.path.join(path, f'id_{i}') 20 | img_out_p = os.path.join(out_p, '{:04}.png'.format(i)) 21 | img_fns = glob.glob(os.path.join(folder, '*.png')) 22 | imgs = [cv2.imread(img_fn) for img_fn in sorted(img_fns)] # 50 23 | # hw = int(len(imgs) ** 0.5) 24 | hw = 2 25 | ih, iw, _ = imgs[0].shape 26 | imgs = imgs[:hw*hw] 27 | imgs = np.array(imgs).reshape((hw,hw,ih,iw,3)).transpose((0,2,1,3,4)).reshape((hw*ih,hw*iw,3)) 28 | cv2.imwrite(img_out_p, imgs) 29 | -------------------------------------------------------------------------------- /scripts/03f_run_ddp_cf.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | size=80 4 | item=180 5 | k=240 6 | basis_n=10 7 | data=data/imgs/Seen240_S80F50_TRAIN800 8 | model_base=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306 9 | model_name=CF_from_${model_base}_${item} 10 | base_idxs="basis/${model_base}_${item}_basis_${k}_id_${basis_n}.txt" 11 | base_ws="basis/${model_base}_${item}_basis_${k}_id_${basis_n}_ws_${k}x${basis_n}_t0.01.pth" 12 | 13 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \ 14 | --nproc_per_node=2 --use_env main.py \ 15 | --content_fusion \ 16 | --img_size ${size} \ 17 | --data_path ${data} \ 18 | --lr 1e-4 \ 19 | --output_k ${k} \ 20 | --batch_size 16 \ 21 | --iters 1000 \ 22 | --epoch 200 \ 23 | --val_num 10 \ 24 | --baseline_idx 0 \ 25 | --save_path output/models \ 26 | --load_model ${model_name} \ 27 | --base_idxs ${base_idxs} --base_ws ${base_ws} \ 28 | --ddp \ 29 | 
--no_val \ 30 | --wdl --w_wdl 0.01 -------------------------------------------------------------------------------- /eval/lpips_2imgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import lpips 3 | 4 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 5 | parser.add_argument('-p0','--path0', type=str, default='./imgs/ex_ref.png') 6 | parser.add_argument('-p1','--path1', type=str, default='./imgs/ex_p0.png') 7 | parser.add_argument('-v','--version', type=str, default='0.1') 8 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 9 | 10 | opt = parser.parse_args() 11 | 12 | ## Initializing the model 13 | loss_fn = lpips.LPIPS(net='alex',version=opt.version) 14 | 15 | if(opt.use_gpu): 16 | loss_fn.cuda() 17 | 18 | # Load images 19 | img0 = lpips.im2tensor(lpips.load_image(opt.path0)) # RGB image from [-1,1] 20 | img1 = lpips.im2tensor(lpips.load_image(opt.path1)) 21 | 22 | if(opt.use_gpu): 23 | img0 = img0.cuda() 24 | img1 = img1.cuda() 25 | 26 | # Compute distance 27 | dist01 = loss_fn.forward(img0, img1) 28 | print('Distance: %.3f'%dist01) 29 | -------------------------------------------------------------------------------- /tools/phl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def naive_dt(mask_fg, max_n = -1): # fg 0 bg 1-n 5 | b,c,h,w = mask_fg.shape 6 | assert c == 1 7 | mask_distance = mask_fg * 1.0 8 | mask_n = h * w 9 | bg_n = 1 10 | i = 1 11 | kernel = torch.tensor([[0.,1.,0.], [1.,1.,1.], [0.,1.,0.]], device=mask_fg.device)[None, None] 12 | while bg_n > 0 and i != max_n: 13 | mask_fg = torch.nn.functional.conv2d(mask_fg.float(), kernel, padding=1) > 0 14 | mask_distance += 1.0 * mask_fg 15 | bg_n = torch.logical_not(mask_fg).sum() 16 | i += 1 17 | return i - mask_distance 18 | 19 | def PHL(f1, f2, thres=0.5): # N,1,H,W fg1 bg0 20 | # f1, f2 in [-1, 1] fg -1 bg 1 21 | # Pseudo_Hamming_Loss 22 | mask1_fg = f1 < thres # black 23 | mask2_fg = f2 < thres # TODO Gradient??? 
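    # Where exactly one of the two masks is foreground, the loss below charges
    # that pixel (roughly) its distance-transform value w.r.t. the other
    # image's foreground, weighted by the soft foreground probabilities:
    # a Hamming distance softened by spatial distance, hence "pseudo Hamming".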
    prob1_fg = (1-f1) / 2
    prob2_fg = (1-f2) / 2
    dis1 = naive_dt(mask1_fg) * prob1_fg
    dis2 = naive_dt(mask2_fg) * prob2_fg
    mask1minus2 = 1.0 * mask1_fg - 1.0 * mask2_fg
    mask1not2 = mask1minus2 > 0
    mask2not1 = mask1minus2 < 0
    dis = torch.zeros_like(dis1)
    dis[mask2not1] = dis1[mask2not1]
    dis[mask1not2] = dis2[mask1not2]
    return dis
--------------------------------------------------------------------------------
/tools/pkl.py:
--------------------------------------------------------------------------------
import torch
import kornia as K

def kl_1d(a, b):  # a, b: [N, C]; returns a scalar (kl_div with reduction='mean')
    eps = 1e-8
    return torch.nn.functional.kl_div(torch.log(a + eps), b, reduction='mean')

def PKL(f1, f2):  # N,1,H,W
    # f1, f2 in [-1, 1]: fg -1, bg 1
    # projected KL-divergence loss: KL between the H/W projection profiles
    # (including rotated copies), not the W-distance loss in wdl.py
    f1 = (1-f1) / 2  # fg 1 bg 0
    f2 = (1-f2) / 2  # fg 1 bg 0
    B = f1.shape[0]
    f1_0 = f1.sum((1,2))  # N,W
    f2_0 = f2.sum((1,2))  # N,W
    loss_0 = kl_1d(f1_0, f2_0)
    f1_1 = f1.sum((1,3))  # N,H
    f2_1 = f2.sum((1,3))  # N,H
    loss_1 = kl_1d(f1_1, f2_1)
    losses = [loss_0, loss_1]
    for angle in [15., 30., 45., 60., 75.]:
        f1r = K.geometry.rotate(f1, angle * torch.ones(B, device=f1.device))
        f2r = K.geometry.rotate(f2, angle * torch.ones(B, device=f1.device))
        f1r_0 = f1r.sum((1,2))  # N,W
        f2r_0 = f2r.sum((1,2))  # N,W
        lossr_0 = kl_1d(f1r_0, f2r_0)
        losses.append(lossr_0)
        f1r_1 = f1r.sum((1,3))  # N,H
        f2r_1 = f2r.sum((1,3))  # N,H
        lossr_1 = kl_1d(f1r_1, f2r_1)
        losses.append(lossr_1)
    loss = torch.stack(losses).mean()

    return loss
--------------------------------------------------------------------------------
/eval/lpips_2dirs.py:
--------------------------------------------------------------------------------
import argparse
import os
import lpips

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0')
parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1')
parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt')
parser.add_argument('-v','--version', type=str, default='0.1')
parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')

opt = parser.parse_args()

## Initializing the model
loss_fn = lpips.LPIPS(net='alex', version=opt.version)
if opt.use_gpu:
    loss_fn.cuda()

# crawl directories
f = open(opt.out, 'w')
files = os.listdir(opt.dir0)

for file in files:
    if os.path.exists(os.path.join(opt.dir1, file)):
        # Load images
        img0 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir0, file)))  # RGB image from [-1,1]
        img1 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir1, file)))

        if opt.use_gpu:
            img0 = img0.cuda()
            img1 = img1.cuda()

        # Compute distance
        dist01 = loss_fn.forward(img0, img1)
        print('%s: %.3f' % (file, dist01))
        f.writelines('%s: %.6f\n' % (file, dist01))

f.close()
--------------------------------------------------------------------------------
/tools/wdl.py:
--------------------------------------------------------------------------------
import torch
import kornia as K

def w_dis_1d(a, b):  # [N, C]
    a_cdf = torch.cumsum(a, dim=1)  # [N, C]
    a_cdf = a_cdf / a_cdf[:,-1:]
    b_cdf = torch.cumsum(b, dim=1)
    b_cdf = b_cdf / b_cdf[:,-1:]
    return (a_cdf - b_cdf).abs().sum(1)  # [N]

def WDL(f1, f2):  # N,1,H,W
    # f1, f2 in [-1, 1]: fg -1, bg 1
    # W-distance loss over the H/W (and rotated) projection profiles
    f1 = (1-f1) / 2  # fg 1 bg 0
    f2 = (1-f2) / 2  # fg 1 bg 0
    B = f1.shape[0]
    f1_0 = f1.sum((1,2))  # N,W
    f2_0 = f2.sum((1,2))  # N,W
    loss_0 = w_dis_1d(f1_0, f2_0)  # N
    f1_1 = f1.sum((1,3))  # N,H
    f2_1 = f2.sum((1,3))  # N,H
    loss_1 = w_dis_1d(f1_1, f2_1)  # N
    losses = [loss_0, loss_1]
    for angle in [15., 30., 45., 60., 75.]:
        f1r = K.geometry.rotate(f1, angle * torch.ones(B, device=f1.device))
        f2r = K.geometry.rotate(f2, angle * torch.ones(B, device=f1.device))
        f1r_0 = f1r.sum((1,2))  # N,W
        f2r_0 = f2r.sum((1,2))  # N,W
        lossr_0 = w_dis_1d(f1r_0, f2r_0)  # N
        losses.append(lossr_0)
        f1r_1 = f1r.sum((1,3))  # N,H
        f2r_1 = f2r.sum((1,3))  # N,H
        lossr_1 = w_dis_1d(f1r_1, f2r_1)  # N
        losses.append(lossr_1)
    loss = torch.stack(losses).mean()

    return loss
--------------------------------------------------------------------------------
/scripts/data_preparation/gen_subset.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import argparse
import shutil
import os

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, required=True)
parser.add_argument('-o', '--output', type=str, required=True)
parser.add_argument('-ic', '--input_chara', type=str, required=True)
parser.add_argument('-oc', '--output_chara', type=str, required=True)

args = parser.parse_args()

def get_charas(path):
    with open(path, encoding='utf-8') as f:
        characters = f.read()
    return list(characters)

def safe_mkdir(path):
    if not os.path.exists(path):
        os.mkdir(path)

ip = args.input
op = args.output
ic = get_charas(args.input_chara)
oc = get_charas(args.output_chara)

assert len(set(oc) - set(ic)) == 0
ic_mapper = {c: i for i, c in enumerate(ic)}

safe_mkdir(op)

for sub_folder in tqdm(sorted(os.listdir(ip))):
    if not os.path.isdir(os.path.join(ip, sub_folder)): continue
    safe_mkdir(os.path.join(op, sub_folder))
    for out_idx, out_char in enumerate(oc):
        in_idx = ic_mapper[out_char]
        src = os.path.join(ip, sub_folder, '%04d.png' % in_idx)
        dst = os.path.join(op, sub_folder, '%04d.png' % out_idx)
        shutil.copyfile(src, dst)
--------------------------------------------------------------------------------
/scripts/03d_cal_cf_weights.sh:
--------------------------------------------------------------------------------
for option in seen unseen
do

    item=180
    k=240

    basis_len=10
    img_size=80
    basis_n=10
    temperature=0.01

    # model_base="B0_K240BS8x4I1000E200_LR1e-4_Pytorch181_20220423-163142"
    model_base=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306
    model=output/models/logs/${model_base}/model_${item}.ckpt

    basis_fn=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_basis_240_id_10

    if [ $option == 'seen' ]; then
        font_len=240
        # data=data/imgs/Seen240_S80F50_TRAIN800
        # basis_data=data/imgs/BASIS_S80F50_TRAIN800
        data=data/imgs/Seen240_S80F50_FS16
        save_fn=basis/${basis_fn}_ws_${font_len}x10_t${temperature}.pth
    elif [ $option == 'unseen' ]; then
        font_len=60
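        # The basis itself stays the one selected on the seen fonts (03b); for
        # the unseen split only the per-font fusion weights are re-estimated,
        # from the 16 few-shot glyphs of each unseen font.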
data=data/imgs/Unseen60_S80F50_FS16 28 | save_fn=basis/${basis_fn}_unseen_ws_${font_len}x10_t${temperature}.pth 29 | fi 30 | basis_data=data/imgs/BASIS_S80F50_FS16 31 | 32 | CUDA_VISIBLE_DEVICES=1 python cal_cf_weight.py \ 33 | --img_size ${img_size} \ 34 | --data_path ${data} \ 35 | --basis_path ${basis_data} \ 36 | --output_k ${k} \ 37 | --load_model ${model} \ 38 | --font_len ${font_len} \ 39 | --basis_len ${basis_len} \ 40 | --baseline_idx 0 \ 41 | -t ${temperature} \ 42 | --save_fn ${save_fn} 43 | done -------------------------------------------------------------------------------- /src/cpu/deform_conv_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | deform_conv_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const int kernel_h, 10 | const int kernel_w, 11 | const int stride_h, 12 | const int stride_w, 13 | const int pad_h, 14 | const int pad_w, 15 | const int dilation_h, 16 | const int dilation_w, 17 | const int group, 18 | const int deformable_group, 19 | const int im2col_step); 20 | 21 | std::vector 22 | deform_conv_cpu_backward(const at::Tensor &input, 23 | const at::Tensor &weight, 24 | const at::Tensor &bias, 25 | const at::Tensor &offset, 26 | const at::Tensor &grad_output, 27 | const int kernel_h, 28 | const int kernel_w, 29 | const int stride_h, 30 | const int stride_w, 31 | const int pad_h, 32 | const int pad_w, 33 | const int dilation_h, 34 | const int dilation_w, 35 | const int group, 36 | const int deformable_group, 37 | const int im2col_step); 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/cuda/deform_conv_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | deform_conv_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const int kernel_h, 10 | const int kernel_w, 11 | const int stride_h, 12 | const int stride_w, 13 | const int pad_h, 14 | const int pad_w, 15 | const int dilation_h, 16 | const int dilation_w, 17 | const int group, 18 | const int deformable_group, 19 | const int im2col_step); 20 | 21 | std::vector 22 | deform_conv_cuda_backward(const at::Tensor &input, 23 | const at::Tensor &weight, 24 | const at::Tensor &bias, 25 | const at::Tensor &offset, 26 | const at::Tensor &grad_output, 27 | const int kernel_h, 28 | const int kernel_w, 29 | const int stride_h, 30 | const int stride_w, 31 | const int pad_h, 32 | const int pad_w, 33 | const int dilation_h, 34 | const int dilation_w, 35 | const int group, 36 | const int deformable_group, 37 | const int im2col_step); 38 | 39 | -------------------------------------------------------------------------------- /src/cpu/deform_psroi_pooling_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | 5 | std::tuple 6 | deform_psroi_pooling_cpu_forward(const at::Tensor &input, 7 | const at::Tensor &bbox, 8 | const at::Tensor &trans, 9 | const int no_trans, 10 | const float spatial_scale, 11 | const int output_dim, 12 | const int group_size, 13 | const int pooled_size, 14 | const int part_size, 15 | const int sample_per_part, 16 | const float trans_std); 17 | 18 | std::tuple 19 | deform_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 20 | const at::Tensor &input, 21 | 
const at::Tensor &bbox, 22 | const at::Tensor &trans, 23 | const at::Tensor &top_count, 24 | const int no_trans, 25 | const float spatial_scale, 26 | const int output_dim, 27 | const int group_size, 28 | const int pooled_size, 29 | const int part_size, 30 | const int sample_per_part, 31 | const float trans_std); -------------------------------------------------------------------------------- /src/cuda/deform_psroi_pooling_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | std::tuple 5 | deform_psroi_pooling_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &bbox, 7 | const at::Tensor &trans, 8 | const int no_trans, 9 | const float spatial_scale, 10 | const int output_dim, 11 | const int group_size, 12 | const int pooled_size, 13 | const int part_size, 14 | const int sample_per_part, 15 | const float trans_std); 16 | 17 | std::tuple 18 | deform_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 19 | const at::Tensor &input, 20 | const at::Tensor &bbox, 21 | const at::Tensor &trans, 22 | const at::Tensor &top_count, 23 | const int no_trans, 24 | const float spatial_scale, 25 | const int output_dim, 26 | const int group_size, 27 | const int pooled_size, 28 | const int part_size, 29 | const int sample_per_part, 30 | const float trans_std); -------------------------------------------------------------------------------- /src/cpu/modulated_deform_conv_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | modulated_deform_conv_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int group, 19 | const int deformable_group, 20 | const int im2col_step); 21 | 22 | std::vector 23 | modulated_deform_conv_cpu_backward(const at::Tensor &input, 24 | const at::Tensor &weight, 25 | const at::Tensor &bias, 26 | const at::Tensor &offset, 27 | const at::Tensor &mask, 28 | const at::Tensor &grad_output, 29 | const int kernel_h, 30 | const int kernel_w, 31 | const int stride_h, 32 | const int stride_w, 33 | const int pad_h, 34 | const int pad_w, 35 | const int dilation_h, 36 | const int dilation_w, 37 | const int group, 38 | const int deformable_group, 39 | const int im2col_step); 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/cuda/modulated_deform_conv_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | modulated_deform_conv_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int group, 19 | const int deformable_group, 20 | const int im2col_step); 21 | 22 | std::vector 23 | modulated_deform_conv_cuda_backward(const at::Tensor &input, 24 | const at::Tensor &weight, 25 | const at::Tensor &bias, 26 | const at::Tensor &offset, 27 | const at::Tensor &mask, 28 | const at::Tensor &grad_output, 
29 | const int kernel_h, 30 | const int kernel_w, 31 | const int stride_h, 32 | const int stride_w, 33 | const int pad_h, 34 | const int pad_w, 35 | const int dilation_h, 36 | const int dilation_w, 37 | const int group, 38 | const int deformable_group, 39 | const int im2col_step); 40 | 41 | -------------------------------------------------------------------------------- /src/cpu/deform_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | at::Tensor 8 | deform_conv_cpu_forward(const at::Tensor &input, 9 | const at::Tensor &weight, 10 | const at::Tensor &bias, 11 | const at::Tensor &offset, 12 | const int kernel_h, 13 | const int kernel_w, 14 | const int stride_h, 15 | const int stride_w, 16 | const int pad_h, 17 | const int pad_w, 18 | const int dilation_h, 19 | const int dilation_w, 20 | const int group, 21 | const int deformable_group, 22 | const int im2col_step) 23 | { 24 | AT_ERROR("Not implement on cpu"); 25 | } 26 | 27 | std::vector 28 | deform_conv_cpu_backward(const at::Tensor &input, 29 | const at::Tensor &weight, 30 | const at::Tensor &bias, 31 | const at::Tensor &offset, 32 | const at::Tensor &grad_output, 33 | const int kernel_h, 34 | const int kernel_w, 35 | const int stride_h, 36 | const int stride_w, 37 | const int pad_h, 38 | const int pad_w, 39 | const int dilation_h, 40 | const int dilation_w, 41 | const int group, 42 | const int deformable_group, 43 | const int im2col_step) 44 | { 45 | AT_ERROR("Not implement on cpu"); 46 | } 47 | 48 | -------------------------------------------------------------------------------- /oss_client.py: -------------------------------------------------------------------------------- 1 | import oss2 2 | import cv2 3 | import numpy as np 4 | import urllib.request 5 | import os 6 | 7 | from itertools import islice 8 | # import requests 9 | 10 | class OSSCTD(object): 11 | def __init__(self): 12 | host_name = 'xxx' 13 | bucket_name = 'xxx' 14 | auth = oss2.Auth('xxx', 'xxx') 15 | self.bucket = oss2.Bucket(auth, host_name, bucket_name) 16 | self.url_prefix = 'xxx' 17 | 18 | def read_file(self, file_path, auth_check=False): 19 | if auth_check: 20 | return self.bucket.get_object(file_path).read() 21 | else: 22 | url = os.path.join(self.url_prefix, file_path) 23 | return urllib.request.urlopen(url).read() 24 | 25 | def read_image(self, file_path, mode=cv2.IMREAD_UNCHANGED, auth_check=False): 26 | img_data = self.read_file(file_path, auth_check=auth_check) 27 | img_data = np.asarray(bytearray(img_data), dtype='uint8') 28 | img = cv2.imdecode(img_data, mode) 29 | return img 30 | 31 | def write_file(self, local_file, remote_file): 32 | with open(local_file, 'rb') as fin: 33 | data = fin.read() 34 | self.bucket.put_object(remote_file, data) 35 | 36 | def fetch_file(self, remote_file, local_file): 37 | self.bucket.get_object_to_file(remote_file, local_file) 38 | 39 | # def requests_write_file(self, src, dst): 40 | # with open(src, 'wb') as f: 41 | # url = os.path.join(self.url_prefix, dst) 42 | # response = requests.get(url) 43 | # f.write(response.content) 44 | 45 | def showFiles(self, bucket): 46 | print("Show All Files:") 47 | for b in islice(oss2.ObjectIterator(bucket, prefix='xxx'), None): 48 | print(b.key) -------------------------------------------------------------------------------- /src/cpu/deform_psroi_pooling_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | 
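// Like deform_cpu.cpp above, these CPU entry points are stubs that only raise
// an error; the usable implementations are the CUDA kernels declared under
// src/cuda/.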
std::tuple 8 | deform_psroi_pooling_cpu_forward(const at::Tensor &input, 9 | const at::Tensor &bbox, 10 | const at::Tensor &trans, 11 | const int no_trans, 12 | const float spatial_scale, 13 | const int output_dim, 14 | const int group_size, 15 | const int pooled_size, 16 | const int part_size, 17 | const int sample_per_part, 18 | const float trans_std) 19 | { 20 | AT_ERROR("Not implement on cpu"); 21 | } 22 | 23 | std::tuple 24 | deform_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 25 | const at::Tensor &input, 26 | const at::Tensor &bbox, 27 | const at::Tensor &trans, 28 | const at::Tensor &top_count, 29 | const int no_trans, 30 | const float spatial_scale, 31 | const int output_dim, 32 | const int group_size, 33 | const int pooled_size, 34 | const int part_size, 35 | const int sample_per_part, 36 | const float trans_std) 37 | { 38 | AT_ERROR("Not implement on cpu"); 39 | } -------------------------------------------------------------------------------- /src/cpu/modulated_deform_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | at::Tensor 8 | modulated_deform_conv_cpu_forward(const at::Tensor &input, 9 | const at::Tensor &weight, 10 | const at::Tensor &bias, 11 | const at::Tensor &offset, 12 | const at::Tensor &mask, 13 | const int kernel_h, 14 | const int kernel_w, 15 | const int stride_h, 16 | const int stride_w, 17 | const int pad_h, 18 | const int pad_w, 19 | const int dilation_h, 20 | const int dilation_w, 21 | const int group, 22 | const int deformable_group, 23 | const int im2col_step) 24 | { 25 | AT_ERROR("Not implement on cpu"); 26 | } 27 | 28 | std::vector 29 | modulated_deform_conv_cpu_backward(const at::Tensor &input, 30 | const at::Tensor &weight, 31 | const at::Tensor &bias, 32 | const at::Tensor &offset, 33 | const at::Tensor &mask, 34 | const at::Tensor &grad_output, 35 | const int kernel_h, 36 | const int kernel_w, 37 | const int stride_h, 38 | const int stride_w, 39 | const int pad_h, 40 | const int pad_w, 41 | const int dilation_h, 42 | const int dilation_w, 43 | const int group, 44 | const int deformable_group, 45 | const int im2col_step) 46 | { 47 | AT_ERROR("Not implement on cpu"); 48 | } 49 | 50 | -------------------------------------------------------------------------------- /scripts/option_run_inf_dgfont.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | item=180 # 200 4 | style_font_len=240 5 | 6 | # Make data for DG-Font (Copy the content font to $data/id_${font_len}) 7 | # The data file tree like follows: (few-shot 16, target content 5646) 8 | # 9 | # . 10 | # ├── data 11 | # │   ├── fonts 12 | # │   └── imgs 13 | # │   ├── Seen240_S80F50_FS16_DGFONT 14 | # │   | ├── id_0 15 | # │   │   │   ├── 0000.png 16 | # │   │   │   ├── 0001.png 17 | # │   │   │   ├── ... 18 | # │   │   │   └── 0015.png 19 | # │   | ├── id_1 20 | # │   | ├── ... 21 | # │   | ├── id_239 22 | # │   │   │   ├── 0000.png 23 | # │   │   │   ├── 0001.png 24 | # │   │   │   ├── ... 25 | # │   │   │   └── 0015.png 26 | # │   | └── id_240 27 | # │   │      ├── 0000.png 28 | # │   │      ├── 0001.png 29 | # │   │      ├── ... 30 | # │   │      └── 5645.png 31 | # │   └── ... 32 | # ├── charset 33 | # └── ... 
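# (id_0 ... id_239 hold the 16 few-shot glyphs of each style font; the extra
#  folder id_240 holds the 5646 target-content glyphs of the source font.)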
# Here is an example (in folder `data/imgs`):
# > cp -rf Seen240_S80F50_FS16 Seen240_S80F50_FS16_DGFONT
# > cp -rf Seen240_S80F50_TEST5646/id_0 Seen240_S80F50_FS16_DGFONT/id_240

output_k=240
py_file=inf_with_style_ft.py
img_size=80
model_base=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306

data=data/imgs/Seen240_S80F50_FS16_DGFONT
model=output/models/logs/${model_base}/model_${item}.ckpt
save_path="output/test_rsts/dgfont_${model_base}_${item}"

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 \
    --use_env --master_port 34544 ${py_file} \
    --img_size ${img_size} \
    --data_path ${data} \
    --output_k ${output_k} \
    --load_model ${model} \
    --save_path ${save_path} \
    --font_len ${style_font_len} \
    --baseline_idx 0 \
    --sty_batch_size 40
--------------------------------------------------------------------------------
/scripts/visualization/viz.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector
import click, os
import torch

@click.command()
@click.option('--name', default="FontStyle_228", help='Name of visualisation')
@click.option('--sprite_size', default=128, help='Size of sprite')
@click.option('--feature_dir', default='feature.npy', help='folder containing style.pth / c_src.pth (or a codebook .npy)')
@click.option('--sprite', help='Name of sprites file')
@click.option('--prefix', default='viz', help='log for tensorboard')
@click.option('--suffix', default='', help='log for tensorboard')
def main(name, sprite_size, feature_dir, sprite, prefix, suffix):
    config = projector.ProjectorConfig()

    for feature, feature_suf in [['style.pth', '_s'], ['c_src.pth', '_c']]:
        feature = os.path.join(feature_dir, feature)
        # assert sprite in ['rgb', 'alpha']
        if feature.endswith('npy'):
            codebook = np.load(feature)
        else:
            codebook = torch.load(feature).numpy()
        codebook = tf.Variable(tf.convert_to_tensor(codebook, dtype=tf.float32))
        ckpt = tf.train.Checkpoint(embedding=codebook)
        logdir = os.path.join('viz', prefix + '_' + name)
        ckpt.save(os.path.join(logdir, f'embedding{suffix}{feature_suf}.ckpt'))


        # You can add multiple embeddings. Here we add only one.
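        # Register the checkpointed variable with the projector. tf.train.Checkpoint
        # stores each tensor under <attribute>/.ATTRIBUTES/VARIABLE_VALUE, so that
        # is the tensor_name the projector config must reference.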
        embedding = config.embeddings.add()
        embedding.tensor_name = f"embedding{suffix}{feature_suf}/.ATTRIBUTES/VARIABLE_VALUE"
        embedding.sprite.image_path = os.path.join(os.getcwd(), sprite)
        embedding.sprite.single_image_dim.extend([sprite_size, sprite_size])


    projector.visualize_embeddings(logdir, config)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/tools/label_smooth.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn


##
# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        Same usage method as nn.CrossEntropyLoss:
            >>> criteria = LabelSmoothSoftmaxCEV1()
            >>> logits = torch.randn(8, 19, 384, 384)  # nchw, float/half
            >>> lbs = torch.randint(0, 19, (8, 384, 384))  # nhw, int64_t
            >>> loss = criteria(logits, lbs)
        '''
        # overcome ignored label
        logits = logits.float()  # use fp32 to avoid nan
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label.eq(self.lb_ignore)
            n_valid = ignore.eq(0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * lb_one_hot, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss
--------------------------------------------------------------------------------
/scripts/visualization/plot_nn.py:
--------------------------------------------------------------------------------
import cv2
# import matplotlib.pyplot as plt
import glob
import torch
import numpy as np
from tqdm import tqdm

imgs_400 = [cv2.imread(fn) for fn in sorted(glob.glob('./data/data_400_hard_50_2x2_format/*png'))]

sv = torch.load('embedding_net400_d400/style.pth')

sv_dis = (sv[None,...] - sv[:,None,:]).abs().mean(-1)  # 400x400
sv_dis += torch.eye(400)*1000  # mask out the diagonal (self-distance)

score, sv_idx = torch.min(sv_dis, 1)

out_imgs = []
for i in range(400):
    img_src = imgs_400[i]
    img_nn = imgs_400[sv_idx[i]]
    img_cat = np.hstack([img_src, img_nn])
    img_cat[-2:] = img_cat[:2] = img_cat[:,-2:] = img_cat[:,:2] = 0
    out_imgs.append(img_cat)

# 20x20 grid
ni,nj = 20, 20

out_imgs = np.array(out_imgs)
out_img = np.vstack([np.hstack(out_imgs[i*ni: (i+1)*ni]) for i in range(nj)])
cv2.imwrite('out.png', out_img)

###
cv = torch.load('embedding_net400_d400/c_src.pth')

cv = cv.reshape([400,-1])

# cv_dis = (cv[None,...]
- cv[:,None,:]).abs().mean(-1) # 400x400 38 | # cv_dis += torch.eye(400)*1000 # kill eye 39 | 40 | # avoid OOM 41 | cv_dis_s = [] 42 | per = 5 43 | for i in tqdm(range(400//per)): 44 | cv_dis = (cv[:,None,:] - cv[i*per:(i+1)*per][None,...]).abs().mean(-1) # [400, 1, k] - [1, 20,k] -> [400, 20, k] -> [400,20] 45 | cv_dis_s.append(cv_dis) 46 | 47 | cv_dis = torch.cat(cv_dis_s, 1) 48 | assert cv_dis.shape[0] == 400 and cv_dis.shape[1] == 400 49 | cv_dis += torch.eye(400)*1000 50 | 51 | score, cv_idx = torch.min(cv_dis, 1) 52 | 53 | out_imgs = [] 54 | for i in range(400): 55 | img_src = imgs_400[i] 56 | img_nn = imgs_400[cv_idx[i]] 57 | img_cat = np.hstack([img_src, img_nn]) 58 | img_cat[-2:] = img_cat[:2] = img_cat[:,-2:] = img_cat[:,:2] = 0 59 | out_imgs.append(img_cat) 60 | 61 | # 8x50 62 | out_imgs = np.array(out_imgs) 63 | out_img = np.vstack([np.hstack(out_imgs[i*ni: (i+1)*ni]) for i in range(nj)]) 64 | cv2.imwrite('out_c.png', out_img) -------------------------------------------------------------------------------- /scripts/visualization/plot_nn_pair.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | # import matplotlib.pyplot as plt 3 | import glob 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | imgs_400 = [cv2.imread(fn) for fn in sorted(glob.glob('./data/data_400_hard_50_2x2_format/*png'))] 9 | 10 | # sv = torch.load('embedding_net400_d400/style.pth') 11 | 12 | # sv_dis = (sv[None,...] - sv[:,None,:]).abs().mean(-1) # 400x400 13 | # sv_dis += torch.eye(400)*1000 # kill eye 14 | 15 | # score, sv_idx = torch.min(sv_dis, 1) 16 | 17 | # out_imgs = [] 18 | # for i in range(400): 19 | # img_src = imgs_400[i] 20 | # img_nn = imgs_400[sv_idx[i]] 21 | # img_cat = np.hstack([img_src, img_nn]) 22 | # img_cat[-2:] = img_cat[:2] = img_cat[:,-2:] = img_cat[:,:2] = 0 23 | # out_imgs.append(img_cat) 24 | 25 | # # 8x50 26 | # out_imgs = np.array(out_imgs) 27 | # out_img = np.vstack([np.hstack(out_imgs[i*8: (i+1)*8]) for i in range(50)]) 28 | # cv2.imwrite('out.png', out_img) 29 | 30 | ### 31 | cv = torch.load('embedding_net400_d400/c_src.pth') 32 | 33 | cv = cv.reshape([400,-1]) 34 | 35 | # cv_dis = (cv[None,...] 
- cv[:,None,:]).abs().mean(-1) # 400x400 36 | # cv_dis += torch.eye(400)*1000 # kill eye 37 | 38 | # avoid OOM 39 | cv_dis_s = [] 40 | per = 5 41 | for i in tqdm(range(400//per)): 42 | cv_dis = (cv[:,None,:] - cv[i*per:(i+1)*per][None,...]).abs().mean(-1) # [400, 1, k] - [1, 20,k] -> [400, 20, k] -> [400,20] 43 | cv_dis_s.append(cv_dis) 44 | 45 | cv_dis = torch.cat(cv_dis_s, 1) 46 | assert cv_dis.shape[0] == 400 and cv_dis.shape[1] == 400 47 | cv_dis += torch.eye(400)*1000 48 | 49 | score, cv_idx = torch.min(cv_dis, 1) 50 | 51 | out_imgs = [] 52 | for i in range(400): 53 | img_src = imgs_400[i] 54 | img_nn = imgs_400[cv_idx[i]] 55 | img_cat = np.hstack([img_src, img_nn]) 56 | img_cat[-2:] = img_cat[:2] = img_cat[:,-2:] = img_cat[:,:2] = 0 57 | out_imgs.append(img_cat) 58 | 59 | # 8x50 60 | out_imgs = np.array(out_imgs) 61 | out_img = np.vstack([np.hstack(out_imgs[i*8: (i+1)*8]) for i in range(50)]) 62 | cv2.imwrite('out_c.png', out_img) -------------------------------------------------------------------------------- /eval/cal_mean.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import time 6 | import pdb 7 | 8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument('-f','--folder', type=str, default='../test/tmp_base') 10 | parser.add_argument('-j','--jump', type=int, default=[], nargs='+') 11 | parser.add_argument('-k','--k', type=int, default=240) 12 | parser.add_argument('-c','--choice', type=str, default='fs') 13 | args = parser.parse_args() 14 | 15 | def get_kv(line, sp=None): 16 | if line[-1] == '\n': line = line[:-1] 17 | try: 18 | k, v = line.split(sp) 19 | v = float(v.lstrip()) 20 | return k, v 21 | except: 22 | return None, None 23 | 24 | scores = {} 25 | 26 | def put_in_dict(k, v, d): 27 | if k is None: return 28 | if k in d.keys(): 29 | d[k].append(v) 30 | else: 31 | d[k] = [v] 32 | 33 | for i in range(args.k): 34 | if i in args.jump: continue 35 | f_fn = os.path.join(args.folder, f'id_{i}_fid.txt') 36 | s_fn = os.path.join(args.folder, f'id_{i}_scores.txt') 37 | # fid 38 | if 'f' in args.choice: 39 | with open(f_fn, 'r') as f: 40 | lines = f.readlines() 41 | if len(lines) == 0: print(i) 42 | for line in lines: 43 | k, v = get_kv(line, sp=':') 44 | put_in_dict(k,v,scores) 45 | # scores 46 | if 's' in args.choice: 47 | with open(s_fn, 'r') as f: 48 | lines = f.readlines() 49 | for line in lines: 50 | if 'Detail' in line: break 51 | k, v = get_kv(line) 52 | put_in_dict(k,v,scores) 53 | 54 | for k, v in scores.items(): 55 | print(f'[{len(v)}]', k, np.mean(v)) 56 | 57 | for k, d in [['l1', 5], ['rmse', 4], ['ssim', 4], ['lpips', 5], ['FID', 4]]: 58 | if k in scores.keys(): 59 | #print('{}'.format(np.mean(scores[k]).round(d)), end=' ') 60 | #print(f'%.{d}f'%(np.mean(scores[k])), end=' ') 61 | print(f'%.{d}f'%(np.mean(scores[k]))) 62 | print('') 63 | -------------------------------------------------------------------------------- /eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lpips 3 | 4 | try: 5 | from skimage.metrics import structural_similarity as ssim 6 | from skimage.metrics import mean_squared_error as mse 7 | from skimage.metrics import peak_signal_noise_ratio as psnr 8 | except: 9 | from skimage.measure import compare_ssim as ssim 10 | from skimage.measure import compare_mse as mse 11 | from skimage.measure import 
compare_psnr as psnr 12 | 13 | class LPIPS(): 14 | def __init__(self, using_gpu=False): 15 | self.model_lpips = lpips.LPIPS(net='alex') 16 | self.using_gpu = using_gpu 17 | if using_gpu: 18 | self.model_lpips.cuda() 19 | 20 | def cal_lpips(self, i0, i1): 21 | img0 = lpips.im2tensor(i0) # [-1, 1] 22 | img1 = lpips.im2tensor(i1) 23 | if self.using_gpu: 24 | img0 = img0.cuda() 25 | img1 = img1.cuda() 26 | 27 | # Compute distance 28 | dist01 = self.model_lpips.forward(img0,img1).flatten() # RGB image from [-1,1] 29 | assert len(dist01) == 1 30 | return dist01[0] 31 | 32 | L1 = lambda in0, in1, data_range=255.: np.mean(np.abs((in0 / data_range - in1 / data_range))) # HWC, [0,1] # Smaller, better 33 | RMSE = lambda in0, in1, data_range=255.: mse(in0 / data_range, in1 / data_range) ** 0.5 # Smaller, better 34 | 35 | # SSIM = lambda in0, in1, data_range=255.: ssim(in0, in1, data_range=data_range, multichannel=True) # Bigger, better 36 | #SSIM = lambda in0, in1, data_range=255.: ssim(in0, in1, data_range=data_range, channel_axis=True) # Bigger, better 37 | 38 | def SSIM(imgs_fake, imgs_real): 39 | mssim0 = ssim(imgs_fake[:,:,0], imgs_real[:,:,0], data_range=255, gaussian_weights=True) 40 | mssim1 = ssim(imgs_fake[:,:,1], imgs_real[:,:,1], data_range=255, gaussian_weights=True) 41 | mssim2 = ssim(imgs_fake[:,:,2], imgs_real[:,:,2], data_range=255, gaussian_weights=True) 42 | mssim = (mssim0 + mssim1 + mssim2)/3 43 | return mssim 44 | 45 | PSNR = lambda gt, pred, data_range=255.: psnr(gt, pred, data_range=data_range) # Bigger, better 46 | 47 | # FID: python -m pytorch_fid path/to/dataset1 path/to/dataset2 48 | -------------------------------------------------------------------------------- /scripts/04a_run_inf_cf.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | for option in unseen 4 | do 5 | item=200 6 | k=240 7 | img_size=80 8 | base_n=10 9 | 10 | py_file="inf_with_style_ft_cf.py" # inf 11 | topk=-1 # -1:use all 1:one-hot 2:top2 12 | ftep=10 # 0 13 | wdl_ft=0.01 14 | wdl=0.01 15 | lr=0.01 16 | t=0.01 17 | 18 | model_base_src=B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306 19 | model_cf=CF_from_${model_base_src}_180 20 | model=output/models/logs/${model_cf}/model_${item}.ckpt 21 | 22 | if [ $option == 'seen' ];then 23 | font_len=240 24 | target_style=data/imgs/Seen240_S80F50_FS16 25 | basis_ws=basis/B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_basis_240_id_10_ws_240x10_t0.01.pth 26 | save_path="output/test_rsts/${model_cf}_${item}_top${topk}_ft${ftep}_wdl${wdl}_lr${lr}" 27 | elif [ $option == 'unseen' ]; then 28 | font_len=60 29 | target_style=data/imgs/Unseen60_S80F50_FS16 30 | basis_ws=basis/B0_K240BS32I1000E200_LR1e-4-wdl0.01_20230426-233306_180_basis_240_id_10_unseen_ws_60x10_t0.01.pth 31 | save_path="output/test_rsts/unseen_${model_cf}_${item}_top${topk}_ft${ftep}_wdl${wdl}_lr${lr}" 32 | fi 33 | 34 | basis_content_folder=data/imgs/BASIS_S80F50_TEST5646 35 | basis_style_ft_folder=data/imgs/BASIS_S80F50_FS16 36 | #load_sv="output/models/fonts/test_rsts/${model_base}_${item}/style_vec.pth" 37 | 38 | CUDA_VISIBLE_DEVICES=1 python -m torch.distributed.launch \ 39 | --nproc_per_node 1 --use_env --master_port 34545 ${py_file} \ 40 | --img_size ${img_size} \ 41 | --data_path ${target_style} \ 42 | --output_k ${k} \ 43 | --load_model ${model} \ 44 | --save_path ${save_path} \ 45 | --font_len ${font_len} \ 46 | --baseline_idx 0 \ 47 | --sty_batch_size 40 \ 48 | --basis_ws ${basis_ws} \ 49 | --top_k ${topk} \ 50 | --basis_folder 
${basis_content_folder} \ 51 | --basis_ft_folder ${basis_style_ft_folder} \ 52 | --ft_epoch ${ftep} \ 53 | --lr ${lr} \ 54 | --wdl --w_wdl ${wdl_ft} 55 | # --load_style ${load_sv} 56 | #--pkl --w_pkl ${wdl_ft} 57 | #--wdl --w_wdl ${wdl_ft} 58 | done -------------------------------------------------------------------------------- /functions/modulated_deform_conv_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Function 10 | from torch.nn.modules.utils import _pair 11 | from torch.autograd.function import once_differentiable 12 | 13 | import DCN 14 | 15 | class ModulatedDeformConvFunction(Function): 16 | @staticmethod 17 | def forward(ctx, input, offset, mask, weight, bias, 18 | stride, padding, dilation, groups, deformable_groups, im2col_step): 19 | ctx.stride = _pair(stride) 20 | ctx.padding = _pair(padding) 21 | ctx.dilation = _pair(dilation) 22 | ctx.kernel_size = _pair(weight.shape[2:4]) 23 | ctx.groups = groups 24 | ctx.deformable_groups = deformable_groups 25 | ctx.im2col_step = im2col_step 26 | output = DCN.modulated_deform_conv_forward(input, weight, bias, 27 | offset, mask, 28 | ctx.kernel_size[0], ctx.kernel_size[1], 29 | ctx.stride[0], ctx.stride[1], 30 | ctx.padding[0], ctx.padding[1], 31 | ctx.dilation[0], ctx.dilation[1], 32 | ctx.groups, 33 | ctx.deformable_groups, 34 | ctx.im2col_step) 35 | ctx.save_for_backward(input, offset, mask, weight, bias) 36 | return output 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, grad_output): 41 | input, offset, mask, weight, bias = ctx.saved_tensors 42 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ 43 | DCN.modulated_deform_conv_backward(input, weight, 44 | bias, 45 | offset, mask, 46 | grad_output, 47 | ctx.kernel_size[0], ctx.kernel_size[1], 48 | ctx.stride[0], ctx.stride[1], 49 | ctx.padding[0], ctx.padding[1], 50 | ctx.dilation[0], ctx.dilation[1], 51 | ctx.groups, 52 | ctx.deformable_groups, 53 | ctx.im2col_step) 54 | 55 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ 56 | None, None, None, None, None, None 57 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, log_dir): 7 | self.last = None 8 | 9 | def scalar_summary(self, tag, value, step): 10 | if self.last and self.last['step'] != step: 11 | print(self.last) 12 | self.last = None 13 | if self.last is None: 14 | self.last = {'step':step,'iter':step,'epoch':1} 15 | self.last[tag] = value 16 | 17 | def images_summary(self, tag, images, step, nrow=8): 18 | """Log a list of images.""" 19 | self.viz.images( 20 | images, 21 | opts=dict(title='%s/%d' % (tag, step), caption='%s/%d' % (tag, step)), 22 | nrow=nrow 23 | ) 24 | 25 | 26 | def makedirs(path): 27 | # if not os.path.exists(path): 28 | os.makedirs(path, exist_ok=True) 29 | 30 | 31 | def save_checkpoint(state, check_list, args, oss_client, epoch=0): 32 | check_file = os.path.join(args.log_dir, 'model_{}.ckpt'.format(epoch)) 33 | torch.save(state, check_file) 34 | check_list.write('model_{}.ckpt\n'.format(epoch)) 35 | if args.on_oss: 36 | print("Saving checkpoint ... 
", flush=True) 37 | print("Saved %d on oss ... " % epoch, end='') 38 | oss_client.write_file(check_file, args.log_dir_oss + '/ckpt/generator_{}.pth'.format(epoch)) 39 | print("Done") 40 | print("Delete local file ... ", end='') 41 | if os.path.exists(check_file): 42 | os.remove(check_file) 43 | print("Done") 44 | 45 | 46 | class AverageMeter(object): 47 | """Computes and stores the average and current value""" 48 | def __init__(self): 49 | self.reset() 50 | 51 | def reset(self): 52 | self.val = 0 53 | self.avg = 0 54 | self.sum = 0 55 | self.count = 0 56 | 57 | def update(self, val, n=1): 58 | self.val = val 59 | self.sum += val * n 60 | self.count += n 61 | self.avg = self.sum / self.count 62 | 63 | 64 | def accuracy(output, target, topk=(1,)): 65 | """Computes the accuracy over the k top predictions for the specified values of k""" 66 | with torch.no_grad(): 67 | maxk = max(topk) 68 | batch_size = target.size(0) 69 | 70 | _, pred = output.topk(maxk, 1, True, True) 71 | pred = pred.t() 72 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 73 | 74 | res = [] 75 | for k in topk: 76 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 77 | res.append(correct_k.mul_(100.0 / batch_size)) 78 | return res 79 | 80 | 81 | def add_logs(args, logger, tag, value, step): 82 | logger.add_scalar(tag, value, step) 83 | -------------------------------------------------------------------------------- /scripts/visualization/make_sprite.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import glob 3 | # import tensorflow as tf 4 | import numpy as np 5 | import click 6 | import json 7 | 8 | 9 | def images_to_sprite(data): 10 | """ 11 | Creates the sprite image along with any necessary padding 12 | Source : https://github.com/tensorflow/tensorflow/issues/6322 13 | Args: 14 | data: NxHxW[x3] tensor containing the images. 15 | Returns: 16 | data: Properly shaped HxWx3 image with any necessary padding. 17 | """ 18 | if len(data.shape) == 3: 19 | data = np.tile(data[..., np.newaxis], (1, 1, 1, 3)) 20 | data = data.astype(np.float32) 21 | min = np.min(data.reshape((data.shape[0], -1)), axis=1) 22 | data = (data.transpose(1, 2, 3, 0) - min).transpose(3, 0, 1, 2) 23 | max = np.max(data.reshape((data.shape[0], -1)), axis=1) 24 | data = (data.transpose(1, 2, 3, 0) / max).transpose(3, 0, 1, 2) 25 | 26 | n = int(np.ceil(np.sqrt(data.shape[0]))) 27 | padding = ((0, n ** 2 - data.shape[0]), (0, 0), 28 | (0, 0)) + ((0, 0),) * (data.ndim - 3) 29 | data = np.pad(data, padding, mode='constant', 30 | constant_values=0) 31 | # Tile the individual thumbnails into an image. 
32 | data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) 33 | + tuple(range(4, data.ndim + 1))) 34 | data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) 35 | data = (data * 255).astype(np.uint8) 36 | return data 37 | 38 | 39 | def populate_img_arr(images_paths): 40 | """ 41 | Get an array of images for a list of image paths 42 | Args: 43 | images_paths: the list of image paths to load 44 | (images are loaded as-is, without resizing or preprocessing) 45 | Returns: 46 | arr: An array of the loaded images 47 | """ 48 | arr = [] 49 | for i, img_path in enumerate(images_paths): 50 | img = Image.open(img_path) 51 | x = np.array(img) 52 | arr.append(x) 53 | arr = np.array(arr) 54 | return arr 55 | 56 | 57 | @click.command() 58 | @click.option('--data', help='Data folder, has to end with /') 59 | @click.option('--sprite_size', default=128, help='Size of sprite') 60 | @click.option('--sprite_name', default="sprites.png", help='Name of sprites file') 61 | def main(data, sprite_size, sprite_name): 62 | if not data.endswith('/'): 63 | raise ValueError('Make sure --data ends with a "/"') 64 | 65 | images_paths = glob.glob(data + "*.jpg") 66 | images_paths.extend(glob.glob(data + "*.JPG")) 67 | images_paths.extend(glob.glob(data + "*.png")) 68 | 69 | raw_imgs = populate_img_arr(sorted(images_paths)) 70 | sprite = Image.fromarray(images_to_sprite(raw_imgs).astype(np.uint8)) 71 | sprite.save(sprite_name) 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /src/deform_conv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/deform_conv_cpu.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/deform_conv_cuda.h" 7 | #endif 8 | 9 | 10 | at::Tensor 11 | deform_conv_forward(const at::Tensor &input, 12 | const at::Tensor &weight, 13 | const at::Tensor &bias, 14 | const at::Tensor &offset, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int group, 24 | const int deformable_group, 25 | const int im2col_step) 26 | { 27 | if (input.type().is_cuda()) 28 | { 29 | #ifdef WITH_CUDA 30 | return deform_conv_cuda_forward(input, weight, bias, offset, 31 | kernel_h, kernel_w, 32 | stride_h, stride_w, 33 | pad_h, pad_w, 34 | dilation_h, dilation_w, 35 | group, 36 | deformable_group, 37 | im2col_step); 38 | #else 39 | AT_ERROR("Not compiled with GPU support"); 40 | #endif 41 | } 42 | AT_ERROR("Not implemented on the CPU"); 43 | } 44 | 45 | std::vector<at::Tensor> 46 | deform_conv_backward(const at::Tensor &input, 47 | const at::Tensor &weight, 48 | const at::Tensor &bias, 49 | const at::Tensor &offset, 50 | const at::Tensor &grad_output, 51 | const int kernel_h, 52 | const int kernel_w, 53 | const int stride_h, 54 | const int stride_w, 55 | const int pad_h, 56 | const int pad_w, 57 | const int dilation_h, 58 | const int dilation_w, 59 | const int group, 60 | const int deformable_group, 61 | const int im2col_step) 62 | { 63 | if (input.type().is_cuda()) 64 | { 65 | #ifdef WITH_CUDA 66 | return deform_conv_cuda_backward(input, 67 | weight, 68 | bias, 69 | offset, 70 | grad_output, 71 | kernel_h, kernel_w, 72 | stride_h, stride_w, 73 | pad_h, pad_w, 74 | dilation_h, dilation_w, 75 | group, 76 | deformable_group, 77 | im2col_step); 78 | #else 79 | AT_ERROR("Not compiled with GPU 
support"); 80 | #endif 81 | } 82 | AT_ERROR("Not implemented on the CPU"); 83 | } 84 | 85 | -------------------------------------------------------------------------------- /src/modulated_deform_conv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/modulated_deform_conv_cpu.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/modulated_deform_conv_cuda.h" 7 | #endif 8 | 9 | 10 | at::Tensor 11 | modulated_deform_conv_forward(const at::Tensor &input, 12 | const at::Tensor &weight, 13 | const at::Tensor &bias, 14 | const at::Tensor &offset, 15 | const at::Tensor &mask, 16 | const int kernel_h, 17 | const int kernel_w, 18 | const int stride_h, 19 | const int stride_w, 20 | const int pad_h, 21 | const int pad_w, 22 | const int dilation_h, 23 | const int dilation_w, 24 | const int group, 25 | const int deformable_group, 26 | const int im2col_step) 27 | { 28 | if (input.type().is_cuda()) 29 | { 30 | #ifdef WITH_CUDA 31 | return modulated_deform_conv_cuda_forward(input, weight, bias, offset, mask, 32 | kernel_h, kernel_w, 33 | stride_h, stride_w, 34 | pad_h, pad_w, 35 | dilation_h, dilation_w, 36 | group, 37 | deformable_group, 38 | im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | modulated_deform_conv_backward(const at::Tensor &input, 48 | const at::Tensor &weight, 49 | const at::Tensor &bias, 50 | const at::Tensor &offset, 51 | const at::Tensor &mask, 52 | const at::Tensor &grad_output, 53 | const int kernel_h, 54 | const int kernel_w, 55 | const int stride_h, 56 | const int stride_w, 57 | const int pad_h, 58 | const int pad_w, 59 | const int dilation_h, 60 | const int dilation_w, 61 | const int group, 62 | const int deformable_group, 63 | const int im2col_step) 64 | { 65 | if (input.type().is_cuda()) 66 | { 67 | #ifdef WITH_CUDA 68 | return modulated_deform_conv_cuda_backward(input, 69 | weight, 70 | bias, 71 | offset, 72 | mask, 73 | grad_output, 74 | kernel_h, kernel_w, 75 | stride_h, stride_w, 76 | pad_h, pad_w, 77 | dilation_h, dilation_w, 78 | group, 79 | deformable_group, 80 | im2col_step); 81 | #else 82 | AT_ERROR("Not compiled with GPU support"); 83 | #endif 84 | } 85 | AT_ERROR("Not implemented on the CPU"); 86 | } 87 | 88 | -------------------------------------------------------------------------------- /eval/get_scores.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import pandas as pd 5 | import cv2 6 | import tqdm 7 | import lpips 8 | import time 9 | 10 | from .eval_utils import LPIPS, L1, RMSE, SSIM 11 | 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | parser.add_argument('-gt','--dir_gt', type=str, default='../test/data_writing_unseen_10_test_200') 14 | parser.add_argument('-pred','--dir_pred', type=str, default='../test/tmp_base') 15 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists') 16 | parser.add_argument('-m','--methods', type=str, default='', nargs='+') 17 | parser.add_argument('--subfolder', action='store_true') 18 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 19 | 20 | opt = parser.parse_args() 21 | 22 | methods_choices = ['l1', 'rmse', 'ssim', 'lpips', 'fid'] 23 | if opt.methods is '': 24 | use_methods = methods_choices 25 | else: 26 | use_methods = opt.methods 27 | assert 
all([mi in methods_choices for mi in use_methods]), 'invalid methods exist' 28 | 29 | def load_image(path): 30 | assert path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg' 31 | return cv2.imread(path)[:,:,::-1] 32 | 33 | ## Initializing the model 34 | if 'lpips' in use_methods: 35 | lpips_model = LPIPS(using_gpu=opt.use_gpu) 36 | 37 | 38 | # crawl directories 39 | # f = open(opt.out,'w') 40 | 41 | files = os.listdir(opt.dir_gt) 42 | 43 | rsts = {} 44 | for i in use_methods: 45 | if i != 'fid': 46 | rsts[i] = [] 47 | num = len(files) 48 | 49 | st = time.time() 50 | fns = [] 51 | for file in tqdm.tqdm(files): 52 | if not file.endswith('.png'): continue 53 | assert os.path.exists(os.path.join(opt.dir_pred,file)) 54 | # Load images 55 | fns.append(file) 56 | img_gt = lpips.load_image(os.path.join(opt.dir_gt,file)) # HWC, RGB, [0, 255] 57 | img_pred = lpips.load_image(os.path.join(opt.dir_pred,file)) # HWC, RGB, [0, 255] 58 | if 'lpips' in use_methods: 59 | rst_lpips = lpips_model.cal_lpips(img_gt, img_pred) 60 | rsts['lpips'].append(rst_lpips) 61 | if 'l1' in use_methods: 62 | rst_l1 = L1(img_gt, img_pred, 255.) 63 | rsts['l1'].append(rst_l1) 64 | if 'rmse' in use_methods: 65 | rst_rmse = RMSE(img_gt, img_pred, 255.) 66 | rsts['rmse'].append(rst_rmse) 67 | if 'ssim' in use_methods: 68 | rst_ssim = SSIM(img_gt, img_pred) 69 | rsts['ssim'].append(rst_ssim) 70 | 71 | tab = pd.DataFrame(rsts, fns) 72 | 73 | print('Mean') 74 | print('=============') 75 | print(tab.mean(0)) 76 | 77 | print('writing to txt:', opt.out + '_scores.txt') 78 | with open(opt.out + '_scores.txt', 'w') as f: 79 | f.write('Mean\n') 80 | f.write('=============\n') 81 | f.write(tab.mean().to_string()) 82 | f.write('\n\nDetail\n') 83 | f.write('=============\n') 84 | f.write(tab.to_string()) 85 | print(f'done! using {time.time() - st}s') 86 | 87 | if 'fid' in use_methods: 88 | st = time.time() 89 | print('Calculate Fid...') 90 | os.system(f'python -m pytorch_fid --batch-size 25 --device cuda:1 {opt.dir_gt} {opt.dir_pred} 1>{opt.out}_fid.txt 2>&1') 91 | print(f'done! 
using {time.time() - st}s') 92 | -------------------------------------------------------------------------------- /models/guidingNet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | try: 5 | from models.blocks import Conv2dBlock, FRN 6 | except: 7 | from blocks import Conv2dBlock, FRN 8 | 9 | 10 | cfg = { 11 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 12 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 13 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 14 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 15 | 'vgg19cut': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'N'], 16 | } 17 | 18 | 19 | class GuidingNet(nn.Module): 20 | def __init__(self, img_size=64, output_k={'cont': 128, 'disc': 10}): 21 | super(GuidingNet, self).__init__() 22 | # network layers setting 23 | self.features = make_layers(cfg['vgg11'], True) 24 | 25 | self.disc = nn.Linear(512, output_k['disc']) 26 | self.cont = nn.Linear(512, output_k['cont']) 27 | 28 | self._initialize_weights() 29 | 30 | def forward(self, x, sty=False): 31 | x = self.features(x) 32 | x = F.adaptive_avg_pool2d(x, (1, 1)) 33 | flat = x.view(x.size(0), -1) 34 | cont = self.cont(flat) 35 | if sty: 36 | return cont 37 | disc = self.disc(flat) 38 | return {'cont': cont, 'disc': disc} 39 | 40 | def _initialize_weights(self): 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 44 | if m.bias is not None: 45 | nn.init.constant_(m.bias, 0) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | nn.init.constant_(m.weight, 1) 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.Linear): 50 | nn.init.normal_(m.weight, 0, 0.01) 51 | nn.init.constant_(m.bias, 0) 52 | 53 | def moco(self, x): 54 | x = self.features(x) 55 | x = F.adaptive_avg_pool2d(x, (1, 1)) 56 | flat = x.view(x.size(0), -1) 57 | cont = self.cont(flat) 58 | return cont 59 | 60 | def iic(self, x): 61 | x = self.features(x) 62 | x = F.adaptive_avg_pool2d(x, (1, 1)) 63 | flat = x.view(x.size(0), -1) 64 | disc = self.disc(flat) 65 | return disc 66 | 67 | 68 | def make_layers(cfg, batch_norm=False): 69 | layers = [] 70 | in_channels = 3 71 | for v in cfg: 72 | if v == 'M': 73 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 74 | else: 75 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 76 | if batch_norm: 77 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)] 78 | else: 79 | layers += [conv2d, nn.ReLU(inplace=False)] 80 | in_channels = v 81 | return nn.Sequential(*layers) 82 | 83 | 84 | if __name__ == '__main__': 85 | import torch 86 | C = GuidingNet(64) 87 | x_in = torch.randn(4, 3, 64, 64) 88 | sty = C.moco(x_in) 89 | cls = C.iic(x_in) 90 | print(sty.shape, cls.shape) 91 | -------------------------------------------------------------------------------- /scripts/basis/get_basis_simple.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import glob 5 | import tqdm 6 | import torch 7 | import argparse 8 | import sklearn 9 | import numpy as np 10 | # from matplotlib import pyplot as plt 11 | from sklearn.cluster import KMeans #, AffinityPropagation, MiniBatchKMeans 12 
| from sklearn.decomposition import PCA 13 | 14 | parser = argparse.ArgumentParser(description='Get ContentFusion basis') 15 | parser.add_argument('-c', '--content', type=str, default='../../../embedding_baseline/c_src.pth', help='path to content embedding') 16 | parser.add_argument('-m', '--model_name', type=str) 17 | parser.add_argument('-if', '--ignore_font', default=[], type=int, nargs='+', help='the font to drop in basis') 18 | parser.add_argument('-ic', '--ignore_char', default=[], type=int, nargs='+', help='the char to drop in basis') 19 | parser.add_argument('-nb', '--basis_number', default=[10], type=int, nargs='+', help='the number of basis') 20 | parser.add_argument('-lbs', '--load_bs', default=1, type=int, help='the batchsize for cal distance') 21 | args = parser.parse_args() 22 | 23 | cvs = torch.load(args.content)#.cpu().numpy() 24 | k, n_samples, _, _, _ = cvs.shape 25 | print(cvs.shape) # e.g. (221, 50, 256, 32, 32) = (fonts, chars, C, H, W) 26 | 27 | # filter out? 28 | if len(args.ignore_font) == 0 and len(args.ignore_char) == 0: 29 | n_samples_remain = n_samples 30 | else: 31 | ignore_font = args.ignore_font 32 | ignore_char = args.ignore_char 33 | ignore_char = torch.tensor(ignore_char) 34 | mask = torch.ones(n_samples, dtype=bool) 35 | mask.scatter_(0, ignore_char, False) 36 | n_samples_remain = mask.sum() 37 | print(f'remain: {n_samples_remain}/{n_samples}') 38 | cvs = cvs[:, mask] 39 | 40 | # get embedding 41 | cvs = cvs.reshape(*cvs.shape[:2], -1) # [k, n_samples_remain, D] with D = C*H*W 42 | # L1 43 | cv_dis_s = [] 44 | per = args.load_bs 45 | assert k%per == 0 46 | for i in tqdm.tqdm(range(k//per)): 47 | cv_dis = (cvs[:,None,:] - cvs[i*per:(i+1)*per][None,...]).abs().mean(-1) # [k,1,n,D] - [1,per,n,D] -> [k,per,n] (L1 over D) 48 | cv_dis_s.append(cv_dis) 49 | 50 | cv_dis = torch.cat(cv_dis_s, 1) 51 | assert cv_dis.shape[0] == k and cv_dis.shape[1] == k 52 | # cv_dis += torch.eye(400)*1000 53 | torch.save(cv_dis, os.path.join(os.path.dirname(args.content), 54 | f'{args.model_name}_cv_dis_{k}x{k}x{n_samples_remain}.pth')) 55 | 56 | cv_dis = cv_dis.mean(-1) 57 | 58 | # kmeans 59 | for nb in tqdm.tqdm(args.basis_number): 60 | kmeans = KMeans(n_clusters=nb, random_state=0).fit(cv_dis) 61 | centers = kmeans.cluster_centers_ # [nb, k] 62 | dis_mat_l1 = np.abs(centers[:,None,:] - cv_dis.numpy()[None, :, :]).mean(-1) # [nb, k] 63 | print(np.min(dis_mat_l1, axis=-1), np.argmin(dis_mat_l1, axis=-1)) 64 | if not os.path.exists('basis'): os.mkdir('basis') 65 | # np.save(f'basis/{args.model_name}_basis_{k}_id_{nb}.npy', np.array(sorted(np.argmin(dis_mat_l1, axis=-1)))) 66 | np.savetxt(f'basis/{args.model_name}_basis_{k}_id_{nb}.txt', np.array(sorted(np.argmin(dis_mat_l1, axis=-1))), newline=' ',fmt='%d') 67 | 68 | # vis 69 | # imgs = [] 70 | # for i in np.array(sorted(np.argmin(dis_mat_l1, axis=-1))): 71 | # img = cv2.imread('data/data_221_S128F80_Base50_2x2_format/{:04}.png'.format(i)) 72 | # imgs.append(img) 73 | 74 | # plt.figure(dpi=200) 75 | # plt.imshow(np.hstack(imgs)) 76 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch.nn.init as init 6 | 7 | import math 8 | 9 | try: 10 | from models.blocks import FRN, ActFirstResBlk 11 | except: 12 | from blocks import FRN, ActFirstResBlk 13 | 14 | 15 | class Discriminator(nn.Module): 16 | """Discriminator: (image 
x, domain y) -> (logit out).""" 17 | def __init__(self, image_size=256, num_domains=2, max_conv_dim=1024): 18 | super(Discriminator, self).__init__() 19 | dim_in = 64 if image_size < 256 else 32 20 | blocks = [] 21 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)] 22 | 23 | repeat_num = int(np.log2(image_size)) - 2 24 | for _ in range(repeat_num): 25 | dim_out = min(dim_in*2, max_conv_dim) 26 | blocks += [ActFirstResBlk(dim_in, dim_in, downsample=False)] 27 | blocks += [ActFirstResBlk(dim_in, dim_out, downsample=True)] 28 | dim_in = dim_out 29 | 30 | blocks += [nn.LeakyReLU(0.2)] 31 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] 32 | blocks += [nn.LeakyReLU(0.2)] 33 | blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)] 34 | self.main = nn.Sequential(*blocks) 35 | 36 | self.apply(weights_init('kaiming')) 37 | 38 | def forward(self, x, y): 39 | """ 40 | Inputs: 41 | - x: images of shape (batch, 3, image_size, image_size). 42 | - y: domain indices of shape (batch). 43 | Output: 44 | - out: logits of shape (batch). 45 | """ 46 | out = self.main(x) 47 | feat = out 48 | out = out.view(out.size(0), -1) # (batch, num_domains) 49 | idx = torch.LongTensor(range(y.size(0))).to(y.device) 50 | out = out[idx, y] # (batch) 51 | return out, feat 52 | 53 | def _initialize_weights(self, mode='fan_in'): 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | nn.init.kaiming_normal_(m.weight, mode=mode, nonlinearity='relu') 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | 60 | 61 | def weights_init(init_type='gaussian'): 62 | def init_fun(m): 63 | classname = m.__class__.__name__ 64 | if (classname.find('Conv') == 0 or classname.find( 65 | 'Linear') == 0) and hasattr(m, 'weight'): 66 | if init_type == 'gaussian': 67 | init.normal_(m.weight.data, 0.0, 0.02) 68 | elif init_type == 'xavier': 69 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 70 | elif init_type == 'kaiming': 71 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 72 | elif init_type == 'orthogonal': 73 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 74 | elif init_type == 'default': 75 | pass 76 | else: 77 | assert 0, "Unsupported initialization: {}".format(init_type) 78 | if hasattr(m, 'bias') and m.bias is not None: 79 | init.constant_(m.bias.data, 0.0) 80 | return init_fun 81 | 82 | 83 | if __name__ == '__main__': 84 | D = Discriminator(64, 10) 85 | x_in = torch.randn(4, 3, 64, 64) 86 | y_in = torch.randint(0, 10, size=(4, )) 87 | out, feat = D(x_in, y_in) 88 | print(out.shape, feat.shape) 89 | -------------------------------------------------------------------------------- /src/deform_psroi_pooling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/deform_psroi_pooling_cpu.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/deform_psroi_pooling_cuda.h" 7 | #endif 8 | 9 | 10 | std::tuple<at::Tensor, at::Tensor> 11 | deform_psroi_pooling_forward(const at::Tensor &input, 12 | const at::Tensor &bbox, 13 | const at::Tensor &trans, 14 | const int no_trans, 15 | const float spatial_scale, 16 | const int output_dim, 17 | const int group_size, 18 | const int pooled_size, 19 | const int part_size, 20 | const int sample_per_part, 21 | const float trans_std) 22 | { 23 | if (input.type().is_cuda()) 24 | { 25 | #ifdef WITH_CUDA 26 | return deform_psroi_pooling_cuda_forward(input, 27 | bbox, 28 | trans, 29 | no_trans, 30 | spatial_scale, 31 | output_dim, 32 | group_size, 33 | pooled_size, 34 | part_size, 35 | sample_per_part, 36 | trans_std); 37 | #else 38 | AT_ERROR("Not 
compiled with GPU support"); 39 | #endif 40 | } 41 | AT_ERROR("Not implemented on the CPU"); 42 | } 43 | 44 | std::tuple<at::Tensor, at::Tensor> 45 | deform_psroi_pooling_backward(const at::Tensor &out_grad, 46 | const at::Tensor &input, 47 | const at::Tensor &bbox, 48 | const at::Tensor &trans, 49 | const at::Tensor &top_count, 50 | const int no_trans, 51 | const float spatial_scale, 52 | const int output_dim, 53 | const int group_size, 54 | const int pooled_size, 55 | const int part_size, 56 | const int sample_per_part, 57 | const float trans_std) 58 | { 59 | if (input.type().is_cuda()) 60 | { 61 | #ifdef WITH_CUDA 62 | return deform_psroi_pooling_cuda_backward(out_grad, 63 | input, 64 | bbox, 65 | trans, 66 | top_count, 67 | no_trans, 68 | spatial_scale, 69 | output_dim, 70 | group_size, 71 | pooled_size, 72 | part_size, 73 | sample_per_part, 74 | trans_std); 75 | #else 76 | AT_ERROR("Not compiled with GPU support"); 77 | #endif 78 | } 79 | AT_ERROR("Not implemented on the CPU"); 80 | } -------------------------------------------------------------------------------- /font2img.py: -------------------------------------------------------------------------------- 1 | from PIL import Image,ImageDraw,ImageFont 2 | import os 3 | import numpy as np 4 | import pathlib 5 | import argparse 6 | 7 | 8 | parser = argparse.ArgumentParser(description='Obtaining characters from .ttf') 9 | parser.add_argument('--ttf_path', type=str, default='../ttf_folder',help='ttf directory') 10 | parser.add_argument('--chara', type=str, default='../chara.txt',help='characters') 11 | parser.add_argument('--save_path', type=str, default='../save_folder',help='images directory') 12 | parser.add_argument('--img_size', type=int, help='The size of generated images') 13 | parser.add_argument('--chara_size', type=int, help='The size of generated characters') 14 | parser.add_argument('--start_id', type=int, default=0, help='The start index for save') 15 | parser.add_argument('--only_id', type=int, default=-1, help='The only index for save') 16 | args = parser.parse_args() 17 | 18 | file_object = open(args.chara,encoding='utf-8') 19 | try: 20 | characters = file_object.read() 21 | finally: 22 | file_object.close() 23 | 24 | 25 | def draw_single_char(ch, font, canvas_size, x_offset, y_offset): 26 | img = Image.new("RGB", (canvas_size, canvas_size), (255, 255, 255)) 27 | draw = ImageDraw.Draw(img) 28 | draw.text((x_offset, y_offset), ch, (0, 0, 0), font=font) 29 | return img 30 | 31 | def draw_example(ch, src_font, canvas_size, x_offset, y_offset): 32 | src_img = draw_single_char(ch, src_font, canvas_size, x_offset, y_offset) 33 | example_img = Image.new("RGB", (canvas_size, canvas_size), (255, 255, 255)) 34 | example_img.paste(src_img, (0, 0)) 35 | return example_img 36 | 37 | data_dir = args.ttf_path 38 | data_root = pathlib.Path(data_dir) 39 | # print(data_root) 40 | 41 | all_image_paths = list(data_root.glob('*.tt*')) + list(data_root.glob('*.TT*')) + list(data_root.glob('*.ot*')) + list(data_root.glob('*.OT*')) 42 | all_image_paths = [str(path) for path in all_image_paths] 43 | all_image_paths = sorted(all_image_paths) 44 | all_image_paths = all_image_paths[args.start_id:] 45 | 46 | seq = list() 47 | # Auto Run 48 | print(len(all_image_paths)) 49 | for (label,item) in zip(range(args.start_id, args.start_id+len(all_image_paths)),all_image_paths): 50 | print(label, item) 51 | if args.only_id == -1 or args.only_id == label: 52 | lrs = [] 53 | tds = [] 54 | # for sample_i in range(100): 55 | for sample_i in range(len(characters)): 56 | src_font = 
ImageFont.truetype(item, size = args.chara_size) 57 | try: 58 | # check pos 59 | # mean 60 | chara_base = characters[sample_i] 61 | img_base = 255 - np.array(draw_example(chara_base, src_font, args.img_size, (args.img_size-args.chara_size)/2, (args.img_size-args.chara_size)/2)) 62 | img_base = img_base.sum(2) 63 | img_base_dim0 = img_base.sum(1) 64 | img_base_dim1 = img_base.sum(0) 65 | pos_dim0 = np.where(img_base_dim0 > 0)[0] 66 | pos_dim1 = np.where(img_base_dim1 > 0)[0] 67 | top, down = pos_dim0.min(), args.img_size - pos_dim0.max() 68 | left, right = pos_dim1.min(), args.img_size - pos_dim1.max() 69 | lr = 0 if abs(left - right) < 2 else right-left 70 | td = 0 if abs(top - down) < 2 else down-top 71 | lrs.append(lr) 72 | tds.append(td) 73 | except: 74 | print(f'Skip check sample {sample_i}') 75 | lr = np.mean(lrs) 76 | td = np.mean(tds) 77 | 78 | try: 79 | if left + right > args.img_size * 2 / 3 or top + down > args.img_size * 2 / 3: 80 | print('!!!', label, item) 81 | except: 82 | lr = td = 0 83 | # exit() 84 | # characters_now = characters[:405] if label == 0 else characters 85 | characters_now = characters 86 | 87 | for (chara,cnt) in zip(characters_now, range(len(characters_now))): 88 | img = draw_example(chara, src_font, args.img_size, (args.img_size-args.chara_size + lr)/2, (args.img_size-args.chara_size + td)/2) 89 | path_full = os.path.join(args.save_path, 'id_%d'%label) 90 | if not os.path.exists(path_full): 91 | os.mkdir(path_full) 92 | img.save(os.path.join(path_full, "%04d.png" % (cnt))) -------------------------------------------------------------------------------- /tools/hsic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: https://github.com/clovaai/rebias/blob/master/criterions/hsic.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def to_numpy(x): 10 | """convert Pytorch tensor to numpy array 11 | """ 12 | return x.clone().detach().cpu().numpy() 13 | 14 | 15 | class HSIC(nn.Module): 16 | """Base class for the finite sample estimator of Hilbert-Schmidt Independence Criterion (HSIC) 17 | ..math:: HSIC (X, Y) := || C_{x, y} ||^2_{HS}, where HSIC (X, Y) = 0 iff X and Y are independent. 18 | 19 | Empirically, we use the finite sample estimator of HSIC (with m observations) by, 20 | (1) biased estimator (HSIC_0) 21 | Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005. 22 | :math: (m - 1)^{-2} tr(KHLH). 23 | where K_{ij} = kernel_x (x_i, x_j), L_{ij} = kernel_y (y_i, y_j), H = I - m^{-1} 1 1^\top (hence, K, L, H are m by m matrices). 24 | (2) unbiased estimator (HSIC_1) 25 | Song, Le, et al. "Feature selection via dependence maximization." 2012. 26 | :math: \frac{1}{m (m - 3)} \bigg[ tr (\tilde K \tilde L) + \frac{1^\top \tilde K 1 1^\top \tilde L 1}{(m-1)(m-2)} - \frac{2}{m-2} 1^\top \tilde K \tilde L 1 \bigg]. 27 | where \tilde K and \tilde L are related to K and L by the diagonal entries of \tilde K_{ij} and \tilde L_{ij} are set to zero. 28 | 29 | Parameters 30 | ---------- 31 | sigma_x : float 32 | the kernel size of the kernel function for X. 33 | sigma_y : float 34 | the kernel size of the kernel function for Y. 35 | algorithm: str ('unbiased' / 'biased') 36 | the algorithm for the finite sample estimator. 'unbiased' is used for our paper. 37 | reduction: not used (for compatibility with other losses). 
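    Example (a minimal sketch; ``feat_a`` and ``feat_b`` are hypothetical
    (batch, dim) feature tensors, and the unbiased estimator needs batch > 3):
        >>> hsic = RbfHSIC(sigma_x=1.0)
        >>> dep = hsic(feat_a, feat_b)  # scalar tensor, near 0 when independent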
38 | """ 39 | def __init__(self, sigma_x, sigma_y=None, algorithm='unbiased', 40 | reduction=None): 41 | super(HSIC, self).__init__() 42 | 43 | if sigma_y is None: 44 | sigma_y = sigma_x 45 | 46 | self.sigma_x = sigma_x 47 | self.sigma_y = sigma_y 48 | 49 | if algorithm == 'biased': 50 | self.estimator = self.biased_estimator 51 | elif algorithm == 'unbiased': 52 | self.estimator = self.unbiased_estimator 53 | else: 54 | raise ValueError('invalid estimator: {}'.format(algorithm)) 55 | 56 | def _kernel_x(self, X): 57 | raise NotImplementedError 58 | 59 | def _kernel_y(self, Y): 60 | raise NotImplementedError 61 | 62 | def biased_estimator(self, input1, input2): 63 | """Biased estimator of Hilbert-Schmidt Independence Criterion 64 | Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005. 65 | """ 66 | K = self._kernel_x(input1) 67 | L = self._kernel_y(input2) 68 | 69 | KH = K - K.mean(0, keepdim=True) 70 | LH = L - L.mean(0, keepdim=True) 71 | 72 | N = len(input1) 73 | 74 | return torch.trace(KH @ LH / (N - 1) ** 2) 75 | 76 | def unbiased_estimator(self, input1, input2): 77 | """Unbiased estimator of Hilbert-Schmidt Independence Criterion 78 | Song, Le, et al. "Feature selection via dependence maximization." 2012. 79 | """ 80 | kernel_XX = self._kernel_x(input1) 81 | kernel_YY = self._kernel_y(input2) 82 | 83 | tK = kernel_XX - torch.diag(kernel_XX) 84 | tL = kernel_YY - torch.diag(kernel_YY) 85 | 86 | N = len(input1) 87 | 88 | hsic = ( 89 | torch.trace(tK @ tL) 90 | + (torch.sum(tK) * torch.sum(tL) / (N - 1) / (N - 2)) 91 | - (2 * torch.sum(tK, 0).dot(torch.sum(tL, 0)) / (N - 2)) 92 | ) 93 | 94 | return hsic / (N * (N - 3)) 95 | 96 | def forward(self, input1, input2, **kwargs): 97 | return self.estimator(input1, input2) 98 | 99 | 100 | class RbfHSIC(HSIC): 101 | """Radial Basis Function (RBF) kernel HSIC implementation. 102 | """ 103 | def _kernel(self, X, sigma): 104 | X = X.view(len(X), -1) 105 | XX = X @ X.t() 106 | X_sqnorms = torch.diag(XX) 107 | X_L2 = -2 * XX + X_sqnorms.unsqueeze(1) + X_sqnorms.unsqueeze(0) 108 | gamma = 1 / (2 * sigma ** 2) 109 | 110 | kernel_XX = torch.exp(-gamma * X_L2) 111 | return kernel_XX 112 | 113 | def _kernel_x(self, X): 114 | return self._kernel(X, self.sigma_x) 115 | 116 | def _kernel_y(self, Y): 117 | return self._kernel(Y, self.sigma_y) 118 | 119 | 120 | class MinusRbfHSIC(RbfHSIC): 121 | """``Minus'' RbfHSIC for the ``max'' optimization. 
122 | """ 123 | def forward(self, input1, input2, **kwargs): 124 | return -self.estimator(input1, input2) 125 | -------------------------------------------------------------------------------- /eval/get_scores_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import pandas as pd 5 | import numpy as np 6 | import cv2 7 | import glob 8 | import tqdm 9 | import lpips 10 | import time 11 | 12 | import pdb 13 | 14 | from eval_utils import LPIPS, L1, RMSE, SSIM 15 | 16 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('-gt','--dir_gt', type=str, default='../test/data_writing_unseen_10_test_200') 18 | parser.add_argument('-pred','--dir_pred', type=str, default='../test/tmp_base') 19 | parser.add_argument('-o','--out', type=str, default='a_scores') 20 | parser.add_argument('-m','--methods', type=str, default='', nargs='+') 21 | parser.add_argument('--subfolder', action='store_true') 22 | parser.add_argument('--gpu', action='store_true', help='turn on flag to use GPU') 23 | parser.add_argument('-j','--jump', type=int, default=[], nargs='+') 24 | parser.add_argument('--only', type=int, default=[], nargs='+') 25 | 26 | opt = parser.parse_args() 27 | 28 | methods_choices = ['l1', 'rmse', 'ssim', 'lpips', 'fid'] 29 | if opt.methods is '': 30 | use_methods = methods_choices 31 | else: 32 | use_methods = opt.methods 33 | assert all([mi in methods_choices for mi in use_methods]), 'invalid mathods exist' 34 | 35 | def load_image(path): 36 | assert path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg' 37 | return cv2.imread(path)[:,:,::-1] 38 | 39 | ## Initializing the model 40 | if 'lpips' in use_methods: 41 | lpips_model = LPIPS(using_gpu=opt.gpu) 42 | 43 | 44 | # crawl directories 45 | # f = open(opt.out,'w') 46 | 47 | folders = sorted(os.listdir(opt.dir_gt), key=lambda k: int(k.lstrip('id_'))) 48 | opt.out = os.path.join(opt.dir_pred, opt.out) 49 | print(opt.out) 50 | # exit() 51 | if not os.path.exists(opt.out): 52 | os.mkdir(opt.out) 53 | 54 | rsts_all = {i:[] for i in use_methods if i != 'fid'} 55 | 56 | fs_bar = tqdm.tqdm(folders) 57 | for folder in fs_bar: 58 | if int(folder.lstrip('id_')) in opt.jump: 59 | continue 60 | if int(folder.lstrip('id_')) in opt.only or len(opt.only) == 0: 61 | dir_gt = os.path.join(opt.dir_gt, folder) 62 | if not os.path.isdir(dir_gt): continue 63 | dir_pred = os.path.join(opt.dir_pred, folder) 64 | 65 | # rename 66 | fns = glob.glob(os.path.join(dir_pred, '*png')) 67 | for fn in fns: 68 | new_base_fn = '{:04}.png'.format(int(os.path.basename(fn).split('.')[0])) 69 | # print(fn, new_base_fn) 70 | os.rename(fn, os.path.join(dir_pred, new_base_fn)) 71 | # exit() 72 | 73 | out = os.path.join(opt.out, folder) 74 | 75 | files = sorted(os.listdir(dir_gt)) 76 | rsts = {i:[] for i in use_methods if i != 'fid'} 77 | num = len(files) 78 | 79 | st = time.time() 80 | fns = [] 81 | for file_g in tqdm.tqdm(files, leave=False): 82 | if not file_g.endswith('.png'): continue 83 | file_p = file_g 84 | assert os.path.exists(os.path.join(dir_pred,file_p)) 85 | # Load images 86 | fns.append(file_g) 87 | img_gt = lpips.load_image(os.path.join(dir_gt,file_g)) # HWC, RGB, [0, 255] 88 | img_pred = lpips.load_image(os.path.join(dir_pred,file_p)) # HWC, RGB, [0, 255] 89 | 90 | if 'lpips' in use_methods: 91 | rst_lpips = lpips_model.cal_lpips(img_gt, img_pred) 92 | 
rsts['lpips'].append(rst_lpips.cpu().detach().numpy()) 93 | rsts_all['lpips'].append(rst_lpips.cpu().detach().numpy()) 94 | if 'l1' in use_methods: 95 | rst_l1 = L1(img_gt, img_pred, 255.) 96 | rsts['l1'].append(rst_l1) 97 | rsts_all['l1'].append(rst_l1) 98 | if 'rmse' in use_methods: 99 | rst_rmse = RMSE(img_gt, img_pred, 255.) 100 | rsts['rmse'].append(rst_rmse) 101 | rsts_all['rmse'].append(rst_rmse) 102 | if 'ssim' in use_methods: 103 | rst_ssim = SSIM(img_gt, img_pred) 104 | rsts['ssim'].append(rst_ssim) 105 | rsts_all['ssim'].append(rst_ssim) 106 | 107 | tab = pd.DataFrame(rsts, fns) 108 | 109 | # print('Mean') 110 | # print('=============') 111 | # pdb.set_trace() 112 | # tm = tab.mean(0) 113 | # info = list(zip(tm.index, tm.to_numpy())) 114 | fs_bar.set_postfix({k:np.mean(v) for k,v in rsts_all.items()}) 115 | # print(tab.mean(0)) 116 | 117 | # print('writing to txt:', out + '_scores.txt') 118 | with open(out + '_scores.txt', 'w') as f: 119 | f.write('Mean\n') 120 | f.write('=============\n') 121 | f.write(tab.mean().to_string()) 122 | f.write('\n\nDetail\n') 123 | f.write('=============\n') 124 | f.write(tab.to_string()) 125 | # print(f'done! using {time.time() - st}s') 126 | 127 | if 'fid' in use_methods: 128 | st = time.time() 129 | # print('Calculate Fid...') 130 | os.system(f'python -m pytorch_fid {dir_gt} {dir_pred} 1>{out}_fid.txt 2>/dev/null') # 2>&1 131 | # print(f'done! using {time.time() - st}s') 132 | 133 | if len(opt.only) == 0: 134 | with open(os.path.join(opt.out, 'all_scores.txt'), 'w') as f: 135 | f.write('Mean\n') 136 | f.write('=============\n') 137 | for k, v in rsts_all.items(): 138 | print(k, np.mean(v)) 139 | f.write(f'{k}: {np.mean(v)}\n') 140 | -------------------------------------------------------------------------------- /charset/TEST5646.txt: -------------------------------------------------------------------------------- 1 | 
一丁丈三上下丌不与且丕丞两丨个丫丬中串丶为主丿乇么义之乜乞也习乩乳乾了予事二亍于亏云亓亘些亟亠亡亢亦产亨亩享亭亳亵人亻仂仃仄仆仇仉介仍仑仔仕他付仙仝仞仟仡代以仨仪仫们仰仳仵件份伉伎伐优会伛传伢伥伧伫估伲伴伶伸伺伽佃但低佐佑体佗佘佚佛作佝佞佟你佣佤佥佧佩佬佯佰佴佶佻佼佾使侃侄侈侉侏侑侔侗侠侣侦侧侨侩侪侬侮俄俅俎俏俐俑俗俚俜俞俟俣俦俨俩俪俭俯俱俳俸俺俾倌倏倒倔倘倚倜倡倥倦倨倩倪倬倭倮债值倾偃偈偌偎偕健偬偶偷偻偾偿傀傅傈傍傣傥傧储傩催傲傺像僖僚僦僧僬僭僮僳僵僻儆儇儋儒儡儿兀允兄兆兑兔兕兖兢全公兮兰关其具兹兼兽冂内冉册冒冕冖冗军冠冢冤冥冫冬冯冰冱冶冷冻冼冽净凄准凇凉凋凌凑凛几凡凤凫凰凳凵凶凸凹出凼函刀刁刂刃分刈刊刍刎刖删刨利刭到刳刷券刹刽刿剀剁剃削剌前剐剑剔剖剜剞剡剥剪割剽剿劁劂劈劐劓力劝加劢劣动努劫劬劭励劲勃勇勉勋勐勒勖勘募勤勰勹勺勾勿匀匆匈匍匏匐匕化匙匚匝匠匡匣匦匪匮匹医匿十卅午卉卑卒卓博卜卞卟卡卢卣卤卦卩卮卯却卵卸卺卿厄厅历厉厌厍厕厘原厢厣厥厦厩厮厶去叁又叉及反发变叙叛叟叠句叨叩只召叭可叱史叵司叹叻叼叽吁吃各吆合吉吊同后吏吐向吒吓吕吖君吝吞吟吠吡吣吧吨吩听吭吮启吱吲吴吵吹吻吼吾呃呆呈呋呐呒呓呔呕呖呗员呙呛呜呢呤呦呱呲呵呶呷呸呻命咀咂咄咆咋和咎咒咔咕咖咙咚咛咝咣咤咦咧咨咩咪咫咬咭咯咱咴咸咻咽咿哀哂哄哆哇哈哉哎哏哐哑哒哔哕哗哙哚哜哝哞哟哥哦哧哨哩哭哮哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛唠唣唤唧唪唬售唯唰唳唷唼唾唿啁啃啄啉啊啐啕啖啜啡啤啥啦啧啪啭啮啵啶啷啸啻啼啾喀喁喂喃喇喈喊喋喏喑喔喘喙喜喝喟喧喱喳喵喷喹喻喽喾嗄嗅嗉嗍嗑嗒嗓嗔嗖嗜嗝嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘛嘞嘟嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噌噍噎噔噗噘噙噜噢噤噪噫噬噱噶噻噼嚅嚆嚎嚏嚓嚯囊囔囗囚四囝囟因囡囤囫园困囱囵囹国囿圃圄圈圊圜圣在圩圪圬圭圮圯圳圹圻圾址坂均坊坌坍坎块坛坜坝坞坟坠坡坤坦坨坩坪坫坭坯坳坶坷坻坼垃垅垆垌垒垛垠垡垢垣垤垦垧垩垫垭垮垲垴垸埂埃埋埏埒埔埕埘埙埚埝域埠埤埯埴埸埽堀堂堆堇堋堍堑堕堙堞堡堤堪堰堵塄塌塍塑塔塘塞塥填塬塾墁境墅墉墒墓墚墟墨墩墼壁壕壬壮壳壶壹夂备夏夔夕外夙多夤夥大天夭夯头夷夸夹夺夼奁奂奄奈奉奎契奕奖奘奚奠奢奥奴奶奸好妁如妃妄妆妇妊妍妒妓妖妗妙妞妣妤妥妨妩妪妫妮妯妲妹妾姆姊姐姑姒姓委姗姘姚姝姣姥姨姹姻姿威娃娄娅娇娈娉娌娑娓娜娟娠娣娥娩娱娲娴娶娼婀婆婉婊婚婢婧婪婴婵婶婷婺婿媒媚媛媪媲媳媸媾嫁嫂嫉嫌嫒嫔嫖嫘嫜嫠嫡嫣嫦嫩嫫嬉嬖嬗嬲嬴嬷孀子孑孓孔孕孚孛孜孝孟孢孤孥学孪孬孰孱孳孵孺宀它宄宅宇宋完宏宓宕宙定宛宝实宠宥宦宪宫宰宴宵家宸宾宿寂寄寅寇寐寒寓寝寞寡寤寥寨寮寰寸对寺寻寿尉尊小少尕尖尘尜尢尤尥尬就尸尹尺尻尾尿屁屈屉届屋屎屏屐屑屙属屠屡屣屦屮屯屹屺岂岈岌岍岐岑岘岙岚岜岢岣岫岬岱岵岷岽岿峁峄峋峒峙峡峤峥峨峪峭峻崂崃崆崇崎崔崖崛崞崤崦崧崩崭崮崴崽崾嵇嵋嵌嵘嵛嵝嵩嵫嵬嵯嵴嶂嶙嶝巅巍巛巡巢工左巧巨巩巫巯已巳巷巽巾币帅帆帏帐帑帔帕帖帘帙帚帛帧帱帷帻帼帽幂幄幅幌幔幕幛幞幡幢平年幺幻幼幽庀庋庑应底庖庞废庠庥度庭庳庵庶庸庹庾廉廊廑廒廓廖廨廪廴廷廾廿开弁弃弄弈弋弑弓弗弘弛弟弥弦弧弩弪弭弯弹弼彐归当录彖彗彘彡形彤彦彩彪彬彭彰彳彷役彻彼徂徇很徉徊律後徐徒徕得徘徙徜御徨循徭微徵徼徽心忄必忆忉忌忍忏忐忑忒忖忘忙忡忤忧忪忭忮忱忸忻忽忾忿怀怂怃怄怅怆怊怍怏怒怔怕怖怙怛怜怠怡怦性怨怩怪怫怯怵怼怿恁恂恃恋恍恐恒恕恙恚恝恢恣恤恧恨恩恪恫恬恰恳恶恸恹恺恻恼恽恿悃悄悉悌悍悒悔悖悚悛悝悟悠患悦您悫悬悭悯悱悲悴悸悼情惆惊惋惑惕惚惝惟惦惧惨惩惫惬惭惮惯惰想惴惶惹惺愀愁愆愈愉愍意愕愚愠愣愤愦愧愫慈慊慌慎慑慝慧慨慰慵憎憔憝憧憨憩憬憷憾懂懈懊懋懑懒懔懦懿戆戈戊戋戌戍戎戏成我戒戕或戗戚戛戟戡戢戤戥截戬戮戳戴户戽戾所扃扇扈扉手扌扎扑扒打扔托扛扣扦扩扪扫扭扮扯扰扳扶扼技抄抉把抑抒抖折抚抟抠抡抢抨披抬抱抵抹抻押抽抿拂拄拆拇拈拊拌拍拎拐拒拓拔拖拗拘拙拚招拜拟拢拥拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾挂挈挎挑挖挚挛挝挞挟挡挢挣挤挥挨挪挫挲挹挺挽捂捃捅捉捋捌捍捎捏捐捕捞损捣捧捩捭捱捶捺捻掀掂掇授掊掎掏掐掖掘掠掩掬掮掰掳掴掷掸掺掼掾揄揆揉揍揎描插揖揞揠揣揩揪揭揲援揶揸揿搀搁搂搅搋搏搐搓搔搛搜搠搡搦搪搬搭搴携搽搿摁摄摅摆摇摈摊摒摔摘摞摧摩摭摸摹摺撂撄撅撑撕撖撙撞撩撬播撮撰撵撷撸撺撼擀擂擅擎擐擒擗擘擞擢擤擦攀攉攒攫攮攴攵收攸政敉敏救敕敖教敛敞敢敦敫敬数敲敷文斋斌斐斑斓斛斜斟斡斥斧斩斫新方於旁旃旄旅旆旋旌旎旒旖旗无既日旦旧旨旭旮旯旰旱时昀昂昃昆昊明昏昕昙昝星昧昨昭是昱昴昵昶昼晁晃晋晌晏晒晔晕晖晗晚晟晡晤晦晨晰晴晷智晾暂暄暇暌暑暖暝暧暨暮暴暹暾曙曜曝曰曳曷曹曼替最月有朊朋朐朕朗朝朦未札朱朴机杈杉杌杏杓杖杞条来杩杪杭杲杳杵杷松枇枋林枘枞枢枣枥枧枨枫枭枯枰枳枵枷枸柁柃柄柑柒柘柙柚柜柝柞柠柢柩柬柯柰柱柳柴柽栀栅栈栉栊栋栌栎栏栓栖栗栝校栩株栲栳样栽栾桀桁桂桃桄桅框桉桊桌桎桐桑桓桔桕桠桢档桤桦桧桩桫桴桶桷梁梃梅梆梏梓梗梢梦梧梨梭梯械梳梵棂棋棍棒棕棘棚棠森棰棱棵棹棺棼椁椅椋椎椐椒椟椠椤椭椰椹椽椿楂楔楗楝楞楠楣楦楫楮楱楷楸楹楼榀概榆榇榈榉榍榔榕榘榛榜榧榨榫榭榱榴榷榻槁槊槌槎槐槔槟槠槭槲槽槿樊樗樘樟横樱樵樽樾橄橇橐橘橙橛橡橥橱橹橼檀檄檎檐檑檗檠檩檫檬欠欢欣欤欷欹欺歃歆歇歉歙止正此步歧歪歹歼殁殂殃殄殆殉殊殍殒殓殖殚殛殡殪殳殴殷殿毅毋毒毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氤氨氩氪氮氯氰氲水氵氽汀汁汆汇汊汐汔汕汗汛汜汝汞池污汨汩汪汰汲汴汶汹汾沁沂沃沅沆沈沌沏沐沓沔沛沟没沣沤沥沩沭沱沲沸油泄泅泉泊泌泐泓泔法泖泗泛泞泠泡泣泥泪泫泮泯泰泱泳泵泶泷泸泺泻泼泽泾洁洄洇洌洎洒洗洙洚洛洞津洧洪洫洮洱洳洵洹洼洽浃浈浊测浍浏浑浒浓浔浙浚浜浞浠浣浦浩浪浮浯浴浸浼涂涅涌涎涑涓涔涕涛涝涞涟涠涡涣涤润涧涨涪涫涮涯涵涸涿淀淄淅淆淇淋淌淑淖淘淙淝淞淠淡淤淦淫淬淮淳淹淼渊渌渍渎渑渔渖渗渚渝渠渡渣渤渥渫渭港渲渴渺湃湄湍湎湓湔湘湛湟湫湮湾湿溃溅溆溉溏溘溜溟溥溧溪溯溱溲溴溷溺溻溽滁滂滇滋滏滓滔滗滚滞滟滠满滢滤滦滨滩滴滹漂漆漉漏漓漕漠漤漩漪漫漭漯漱漳漶漾潆潇潋潍潘潜潞潢潦潭潮潲潴潸潺潼澄澈澉澌澍澎澜澡澧澳澶澹激濂濉濑濒濞濠濡濮濯瀑瀚瀛瀵瀹灌灏灬灭灰灵灶灸灼灾灿炅炉炊炎炒炔炕炖炙炜炝炫炬炭炮炯炱炳炷炸点炻炽烀烁烂烃烊烘烙烛烟烤烦烨烩烬烯烷烹烽焉焊焐焓焕焖焘焙焚焦焯焰焱煊煌煎煜煞煦煨煮煲煳煺煽熄熊熔熘熟熠熨熬熳熵熹燃燎燔燠燥燧燮燹爆爨爪爬爰爷爸爹爻爿版牌牍牒牙牝牟牡牢牦物牮牯牲牵牺牾牿犁犄犊犋犍犒犟犬犭犰犴犷犸犹狁狂狃狄狈狍狎狐狒狗狙狞狠狡狨狩狭狮狯狰狲狳狴狷狸狺狻狼猁猃猊猓猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猱猷猸猹猾猿獍獐獒獗獠獬獭獾玄玎玑玖玛玟玢玩玫玮现玲玳玷玺玻珀珂珈珉珊珍珏珐珑珙珞珠珥珧珩珲琅理琉琊琏琚琢琥琦琨琪琬琮琰琳琴琵琶琼瑁瑕瑗瑙瑚瑛瑜瑞瑟瑭瑰瑶瑷瑾璀璁璃璇璋璎璐璜璞璧璩璺瓒瓜瓞瓠瓢瓣瓦瓮瓯瓴瓷瓿甄甏甑甓甘甙甚甜生甥用甩甫甬甭甯由申电男甸町甾畀畈畋畎畏畔留畚畛畜畦番畲畴畸畹畿疃疆疋疏疒疔疖疗疙疚疝疟疠疣疤疥疫疬疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄症痈痉痊痍痒痔痕痖痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩瘪瘫瘭瘰瘳瘴瘵瘼瘾瘿癀癃癌癍癔癖癜癞癣癫癯癸登皂的皆皈皋皎皑皓皖皙皤皱皲皴皿盂盅盆盈盍盎盏盒盔盖盗盘盛盥目盯盱盲相盹盼眄眇眈眉看眍眙眚眠眢眦眨眩眭眯眵眶眷眸眺着睁睃睇睐睑睚睛睡睢督睥睦睨睬睹睽睿瞀瞄瞅瞌瞍瞎瞑瞒瞟瞠瞧瞩瞪瞬瞰瞳瞵瞻瞽瞿矍矗矜矢矣矧矩矫矬矮矶矸矽矾砂砉砌砍砑砒砖砗砘砚砜砝砟砣砥砦砧砩砬砭砰砷砸砹砺砻砼砾础硅硇硌硎硐硒硕硖硝硪硫硬硭硷硼碇碉碌碍碎碑碓碗碘碚碛碜碟碡碣碧碰碱碲碳碴碹磁磅磉磊磋磐磔磕磙磨磬磲磴磷磺礁礅礓礞礤礴示礻礼社祀祁祆祈祉祓祗祚祛祜祝祟祠祢祥祧票祭祯祷祸祺禀禁禄禅禊禚禧禹禺离禽禾秀私秃种秒秕秘租秣秤秦秧秩秫秭秸秽稀稂稃稆程稍税稔稗稚稞稠稣稷稹稻稼稽稿穆穴究穷穸穹窀突窃窄窆窈窍窑窒窕窖窗窘窜窝窟窠窥窦窨窬窭窿立竖竞竟竣童竦竭竹竺竽竿笃笄笆笈笊笋笏笕笙笛笞笠笤笥符笨笪笫第笮笱笳笸笺笼笾筅筇等筋筌筏筐筑筒筘筚筝筠筢筮筱筲筵筷筹筻签箅箍箐箔箕箜箝箢箦箧箨箩箪箫箬箭箱箴箸篁篆篑篓篙篚篝篥篦篪篱篷篾簇簋簌簏簟簦簧簪簿籁米籴籼籽粑粕粗粘粜粝粞粟粢粤粥粱粳粹粼粽精糁糅糈糊糌糍糕糖糗糙糜糟糨糯糸系紊素紫累絮絷綦綮縻繁繇纛纟纠纡纣纤纥约纨纩纫纬纭纯纰纱纲纳纵纶纷纹纺纽纾线绀绁绂绅绉绊绋绌绎经绐绑绒结绔绗绘绚绛络绞绠绡绢绣绥绦绨绩绪绫绮绯绰绱绲绳绵绶绷绸绺绻综绽绾缀缁缂缃缄缅缇缈缉缋缌缍缏缒缔缕缗缘缙缚缛缜缝缟缠缡缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸罂罄罅罐网罘罚罟罡罢罨罩罱
署罴罹罾羁羊羌美羔羚羝羞羟羡羧羯羰羸羹羼羽羿翁翅翊翌翎翔翕翟翠翡翥翦翮翰翱翳翻翼耀耄者耆耋而耍耒耔耕耖耗耘耙耜耠耢耥耦耧耨耩耪耱耳耵耶耷耸耻耽耿聂聃聆聊聋聍聒联聘聩聱聿肀肃肄肆肇肋肌肓肖肘肚肛肜肝肟股肢肤肩肪肫肭肮肯肱肴肷肺肼肽肿胀胁胂胃胄胆胍胎胖胗胙胚胛胝胥胧胨胩胪胫胬胭胯胰胱胳胴胸胺胼能脂脆脉脍脎脏脐脑脒脓脔脖脘脚脞脬脯脲脶脸脾腆腈腋腌腐腑腓腔腕腙腚腠腥腧腩腭腮腰腱腴腺腻腼腽腾腿膀膂膈膊膏膑膘膛膜膝膣膦膨膪膳膺膻臀臁臂臆臊臌臣臧自臬臭臻臼臾舀舁舂舄舅舌舍舐舒舛舜舟舡舢舣舨航舫般舭舯舰舱舳舴舵舶舷舸舻舾艄艉艋艏艘艚艟艨艮艰艳艴艹艽艾艿节芄芈芊芋芍芎芏芑芒芗芘芙芜芟芡芤芨芩芪芫芬芭芮芯芰芳芴芷芸芹芽芾苁苄苇苈苊苋苌苍苎苏苑苒苓苔苕苘苛苜苞苟苠苡苣苤苫苯苴苷苹苻茁茂茄茅茆茇茈茉茌茎茏茑茔茕茗茚茛茜茧茨茬茭茯茱茳茴茵茸茹茺茼荀荃荆荇荏荐荑荒荔荚荜荞荟荠荤荥荦荨荩荪荫荬荭荮荸荻荼荽莅莆莎莒莓莘莛莜莞莠莨莩莪莫莰莱莲莳莴莶莸莹莺莼莽菀菁菅菇菊菌菏菔菖菘菝菟菠菡菥菩菪菰菱菲菸菹菽萁萃萄萆萋萌萍萎萏萑萘萜萝萤萦萧萨萱萸萼葆葑葙葚葛葜葡董葩葫葬葭葱葳葵葶葸葺蒂蒇蒈蒉蒋蒌蒗蒙蒜蒡蒯蒲蒴蒸蒹蒺蒽蒿蓁蓄蓉蓊蓍蓐蓑蓓蓖蓠蓣蓥蓦蓬蓰蓼蓿蔌蔑蔓蔗蔚蔟蔡蔫蔬蔹蔺蔻蔼蕃蕈蕉蕊蕖蕞蕤蕨蕲蕴蕺蕾薄薅薇薏薜薪薮薯薷薹藁藏藐藓藕藩藻藿蘑蘧蘩蘸蘼虍虐虑虔虚虞虢虫虬虮虱虹虺虻虼虿蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚴蚵蚶蚺蛀蛄蛆蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜒蜓蜕蜗蜘蜚蜜蜞蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝠蝣蝤蝥蝰蝴蝶蝻蝼蝽蝾螂螃螅螈螋融螓螗螟螨螫螬螭螯螳螵螺蟀蟆蟋蟑蟒蟓蟛蟠蟥蟮蟹蟾蠃蠊蠓蠕蠖蠛蠼血衄衅衍衔街衙衢衤表衩衫衬衮衰衲衷衽衾衿袂袄袈袋袍袒袖袜袢袤袭袱袷袼裁裆裉裎裒裔裕裘裙裟裢裣裤裥裨裰裱裳裴裸裼裾褂褐褒褓褚褛褡褥褪褫褰褶襁襞襟襦襻西要覃觅觇觊觋觌觎觏觐觑觖觚觜觥觫觯訇訾詈詹誉誊誓謇謦警譬讠讣讦讧让讪讫讯讲讴讵讶讷讹论讼讽设访诀诂诃评诅诈诉诊诋诌词诎诏译诒诓诔诖诘诙诚诛诜话诞诟诠诡询诣诤详诧诨诩诫诬诮误诰诱诲诳说诶请诸诹诺诼诽诿谀谂谄谅谆谇谊谋谌谍谎谏谐谑谒谓谔谕谖谗谘谙谚谛谜谟谠谡谢谣谤谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷豁豆豇豉豌豕豚象豢豪豫豳豸豹豺貂貅貉貊貌貔貘负责败账贪贫贮贯贰贱贲贳贴贶贷贸贺贻贼贽贾贿赀赁赂赃赆赇赈赉赊赋赌赍赎赏赐赔赕赖赘赙赚赛赝赞赠赡赢赣赤赦赧赫赭赳赴赵赶起趁趄趋趑趔趟趣趱趴趵趸趺趼趾趿跃跄跆跋跌跎跏跖跗跚跛跞跣跤跨跪跫跬跸跹跺跻跽踅踉踊踌踏踔踝踞踟踢踣踩踮踯踱踵踹踺踽蹀蹂蹄蹇蹈蹉蹊蹋蹑蹒蹙蹦蹬蹭蹯蹰蹲蹴蹶蹼蹿躁躅躇躏躜躞躬躯躲躺軎轧轨轩轫轭轱轲轳轵轶轷轸轹轺轼载轾轿辁辂辄辅辆辇辈辉辊辍辎辏辐辑辔辕辖辘辙辚辛辜辞辟辣辨辩辫辰辱辶边辽达迁迂迄过迎迓返迕这进违迟迢迤迥迦迨迩迪迭迮迳迷迸迹追退送适逃逄逅逆逊逋逍透逑递途逖逗逛逝逞逡逢逦逭逯逵逶逸逻逼逾遁遂遄遇遏遐遑遒道遗遘遛遢遣遥遨遭遮遴遵遽避邀邂邃邈邑邓邕邗邙邛邝邡邢那邦邪邬邮邯邰邱邳邴邵邸邹邺邻邾郁郄郅郇郊郎郏郐郑郓郗郛郜郝郡郢郦郧部郫郭郯郴郸都郾鄂鄄鄙鄞鄢鄣鄯鄱鄹酉酊酋酌酎酏酐酗酚酝酞酡酢酣酤酥酩酪酬酮酯酰酱酲酴酵酶酷酹酽酾酿醅醇醉醌醍醐醑醒醚醛醢醣醪醭醮醯醴醵采釉释里野釜銎銮鋈錾鍪鎏鏊鏖鐾鑫钅钆钇钉钊钋钌钍钎钏钐钒钔钕钗钙钚钛钜钝钟钠钡钣钤钥钦钧钨钩钪钫钬钭钮钯钰钲钴钵钶钷钸钹钺钻钼钽钾钿铀铂铃铄铅铆铉铊铋铌铍铎铐铑铒铕铖铗铘铛铜铝铞铟铠铡铢铣铤铥铧铨铩铪铫铬铭铮铯铰铱铲铳铵铷铸铹铺铼铽链铿销锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐锑锒锓锔锕锖锗锘锚锛锝锞锟锡锢锣锤锥锦锨锩锪锫锬锭键锯锰锱锲锴锵锶锷锸锹锺锼锾锿镀镁镂镄镅镆镉镊镌镍镎镏镐镑镓镔镖镗镘镙镛镜镝镞镟镡镢镣镤镥镦镧镨镩镪镫镬镭镯镰镱镲镳闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼闽闾阀阁阃阄阅阆阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阙阚阜阝阡阢阪阮阱阴阵阶阼阽陀陂陇陈陉陋陌陕陛陟陡陧陨险陪陬陲陴陵陶陷隅隆隈隋隍隐隔隗隙障隧隰隹隼隽雀雁雄雅雇雉雌雎雏雒雕雠雨雩雪雯雳零雹雾霁霄霆震霈霉霍霎霏霓霖霜霞霪霭霹霾靓靖靛非靠靡面靥靳靴靶靼鞅鞋鞍鞑鞒鞔鞘鞠鞣鞫鞯鞲韦韧韩韪韫韬韭韵韶页顶顷顸项须顼顽顾顿颀颁颂颃预颅颇颈颉颊颌颍颓颔颖颗颚颛颜额颞颟颠颡颢颤颥颦颧飑飒飓飕飘飙飚飞飧飨餍餮饕饣饨饩饪饫饬饭饮饰饱饲饴饵饷饼饽饿馀馁馄馅馆馇馈馊馋馍馏馐馑馒馔馕馗馘馨驭驮驯驰驱驳驴驵驶驷驸驹驺驻驽驾驿骀骂骄骅骆骈骊骋骏骐骑骒骓骖骘骚骛骜骝骞骟骠骡骢骣骤骨骰骱骶骷骺骼髀髁髂髅髋髌髑髓高髟髡髦髫髭髯髹髻鬃鬈鬏鬓鬟鬯鬲鬻鬼魁魂魃魄魅魇魈魉魏魑魔鱿鲂鲅鲆鲇鲈鲋鲍鲎鲐鲑鲒鲔鲕鲚鲛鲜鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨鲩鲫鲭鲮鲰鲱鲲鲳鲴鲵鲶鲷鲸鲺鲻鲼鲽鳃鳄鳅鳇鳋鳌鳍鳎鳏鳐鳓鳔鳕鳗鳘鳙鳜鳝鳞鳟鳢鸟鸠鸢鸥鸦鸨鸩鸪鸫鸬鸭鸯鸱鸲鸳鸵鸶鸷鸸鸹鸺鸽鸾鸿鹁鹂鹃鹄鹅鹆鹇鹈鹉鹋鹌鹎鹏鹑鹕鹗鹘鹚鹛鹜鹞鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹰鹱鹳鹾鹿麂麇麈麋麒麓麝麟麦麴麸麻麽麾黉黍黏黔默黛黜黝黟黠黢黥黩黪黯黹黻黼黾鼋鼍鼎鼐鼓鼗鼙鼠鼢鼬鼯鼷鼹鼻鼽鼾齄齑龀龃龄龅龆龇龈龉龊龋龌龚龛龟龠 -------------------------------------------------------------------------------- /modules/modulated_deform_conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import math 8 | from torch import nn 9 | from torch.nn import init 10 | from torch.nn.modules.utils import _pair 11 | 12 | from functions.modulated_deform_conv_func import ModulatedDeformConvFunction 13 | 14 | class ModulatedDeformConv(nn.Module): 15 | 16 | def __init__(self, in_channels, out_channels, 17 | kernel_size, stride, padding, dilation=1, groups=1, deformable_groups=1, im2col_step=64, bias=True): 18 | super(ModulatedDeformConv, self).__init__() 19 | 20 | if in_channels % groups != 0: 21 | raise ValueError('in_channels {} must be divisible by groups {}'.format(in_channels, groups)) 22 | if out_channels % groups != 0: 23 | raise ValueError('out_channels {} must be divisible by groups {}'.format(out_channels, groups)) 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.kernel_size = _pair(kernel_size) 28 | self.stride = _pair(stride) 29 | self.padding = _pair(padding) 30 | self.dilation = _pair(dilation) 31 | self.groups = groups 32 | self.deformable_groups = deformable_groups 33 | self.im2col_step = im2col_step 34 | self.use_bias = 
bias 35 | 36 | self.weight = nn.Parameter(torch.Tensor( 37 | out_channels, in_channels//groups, *self.kernel_size)) 38 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 39 | self.reset_parameters() 40 | if not self.use_bias: 41 | self.bias.requires_grad = False 42 | 43 | def reset_parameters(self): 44 | n = self.in_channels 45 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 46 | if self.bias is not None: 47 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 48 | bound = 1 / math.sqrt(fan_in) 49 | init.uniform_(self.bias, -bound, bound) 50 | 51 | def forward(self, input, offset, mask): 52 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 53 | offset.shape[1] 54 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 55 | mask.shape[1] 56 | return ModulatedDeformConvFunction.apply(input, offset, mask, 57 | self.weight, 58 | self.bias, 59 | self.stride, 60 | self.padding, 61 | self.dilation, 62 | self.groups, 63 | self.deformable_groups, 64 | self.im2col_step) 65 | 66 | _ModulatedDeformConv = ModulatedDeformConvFunction.apply 67 | 68 | class ModulatedDeformConvPack(ModulatedDeformConv): 69 | 70 | def __init__(self, in_channels, out_channels, 71 | kernel_size, stride, padding, 72 | dilation=1, groups=1, deformable_groups=1, double=False, im2col_step=64, bias=True, lr_mult=0.1): 73 | super(ModulatedDeformConvPack, self).__init__(in_channels, out_channels, 74 | kernel_size, stride, padding, dilation, groups, deformable_groups, im2col_step, bias) 75 | 76 | out_channels = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 77 | if double == False: 78 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 79 | out_channels, 80 | kernel_size=self.kernel_size, 81 | stride=self.stride, 82 | padding=self.padding, 83 | bias=True) 84 | else: 85 | self.conv_offset_mask = nn.Conv2d(self.in_channels*2, 86 | out_channels, 87 | kernel_size=self.kernel_size, 88 | stride=self.stride, 89 | padding=self.padding, 90 | bias=True) 91 | self.conv_offset_mask.lr_mult = lr_mult 92 | self.init_offset() 93 | 94 | def init_offset(self): 95 | self.conv_offset_mask.weight.data.zero_() 96 | self.conv_offset_mask.bias.data.zero_() 97 | 98 | def forward(self, input_offset, input_real): 99 | out = self.conv_offset_mask(input_offset) 100 | o1, o2, mask = torch.chunk(out, 3, dim=1) 101 | offset = torch.cat((o1, o2), dim=1) 102 | mask = torch.sigmoid(mask) 103 | return ModulatedDeformConvFunction.apply(input_real, offset, mask, 104 | self.weight, 105 | self.bias, 106 | self.stride, 107 | self.padding, 108 | self.dilation, 109 | self.groups, 110 | self.deformable_groups, 111 | self.im2col_step), offset 112 | 113 | -------------------------------------------------------------------------------- /train/train_style_vec.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch.nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.optim 6 | import torch.utils.data 7 | import torch.utils.data.distributed 8 | from tools.utils import * 9 | from tools.ops import compute_grad_gp, update_average, copy_norm_params, queue_data, dequeue_data, \ 10 | average_gradients, calc_adv_loss, calc_contrastive_loss, calc_recon_loss, calc_abl 11 | from tools.hsic import RbfHSIC 12 | 13 | 14 | def add_indp_fact_loss(self, *exp_pairs): 15 | pairs = [] 16 | for _exp1, _exp2 in exp_pairs: 17 | _pairs = [(F.adaptive_avg_pool2d(_exp1[:, i], 1).squeeze(), 18 | 
F.adaptive_avg_pool2d(_exp2[:, i], 1).squeeze()) 19 | for i in range(_exp1.shape[1])] 20 | pairs += _pairs 21 | 22 | crit = RbfHSIC(1) 23 | losses = [crit(*pair) for pair in pairs] 24 | return losses 25 | 26 | def trainGAN(data_loader, networks, opts, epoch, args, additional, \ 27 | detach=False, quantize=False, style_con=False, reconstruction_losses=False, hsic_loss=False, abl=False): 28 | # avg meter 29 | d_losses = AverageMeter() 30 | d_advs = AverageMeter() 31 | d_gps = AverageMeter() 32 | 33 | g_losses = AverageMeter() 34 | g_advs = AverageMeter() 35 | g_imgrecs = AverageMeter() 36 | g_rec = AverageMeter() 37 | 38 | moco_losses = AverageMeter() 39 | 40 | # set nets 41 | D = networks['D'] if not args.distributed else networks['D'].module 42 | G = networks['G'] if not args.distributed else networks['G'].module 43 | C = networks['C'] if not args.distributed else networks['C'].module 44 | G_EMA = networks['G_EMA'] if not args.distributed else networks['G_EMA'].module 45 | C_EMA = networks['C_EMA'] if not args.distributed else networks['C_EMA'].module 46 | # set opts 47 | d_opt = opts['D'] 48 | g_opt = opts['G'] 49 | c_opt = opts['C'] 50 | # switch to train mode 51 | D.train() 52 | G.train() 53 | C.train() 54 | C_EMA.train() 55 | G_EMA.train() 56 | 57 | logger = additional['logger'] 58 | 59 | 60 | # summary writer 61 | train_it = iter(data_loader) 62 | 63 | t_train = trange(0, args.iters, initial=0, total=args.iters) 64 | 65 | for i in t_train: 66 | try: 67 | imgs, y_org = next(train_it) 68 | except: 69 | train_it = iter(data_loader) 70 | imgs, y_org = next(train_it) 71 | 72 | x_org = imgs 73 | 74 | 75 | x_ref_idx = torch.randperm(x_org.size(0)) 76 | 77 | x_org = x_org.to(torch.cuda.current_device()) 78 | y_org = y_org.to(torch.cuda.current_device()) 79 | x_ref_idx = x_ref_idx.to(torch.cuda.current_device()) 80 | 81 | x_ref = x_org.clone() 82 | x_ref = x_ref[x_ref_idx] 83 | 84 | training_mode = 'GAN' 85 | 86 | # Train G 87 | s_src = C.moco(x_org) 88 | 89 | c_src, skip1, skip2 = G.cnt_encoder(x_org) 90 | x_rec, _ = G.decode(c_src, s_src, skip1, skip2) 91 | 92 | g_imgrec = calc_recon_loss(x_rec, x_org) 93 | 94 | g_loss = args.w_rec * g_imgrec 95 | 96 | # abl 97 | if abl: 98 | g_img_abl = calc_abl(x_rec, x_org) 99 | # print(f"abl:{g_img_abl} g_adv:{g_adv} g_imgrec:{g_imgrec} g_conrec:{g_conrec} offset_loss:{offset_loss}") 100 | if g_img_abl is not None: 101 | g_loss += args.w_rec * g_img_abl 102 | 103 | g_opt.zero_grad() 104 | c_opt.zero_grad() 105 | g_loss.backward() 106 | if args.distributed: 107 | average_gradients(G) 108 | average_gradients(C) 109 | c_opt.step() 110 | g_opt.step() 111 | 112 | ################## 113 | # END Train GANs # 114 | ################## 115 | 116 | 117 | if epoch >= args.ema_start: 118 | training_mode = training_mode + "_EMA" 119 | update_average(G_EMA, G) 120 | update_average(C_EMA, C) 121 | 122 | torch.cuda.synchronize() 123 | 124 | # with torch.no_grad(): 125 | # if epoch >= args.separated: 126 | # d_losses.update(d_loss.item(), x_org.size(0)) 127 | # d_advs.update(d_adv.item(), x_org.size(0)) 128 | # d_gps.update(d_gp.item(), x_org.size(0)) 129 | 130 | # g_losses.update(g_loss.item(), x_org.size(0)) 131 | # g_advs.update(g_adv.item(), x_org.size(0)) 132 | # g_imgrecs.update(g_imgrec.item(), x_org.size(0)) 133 | # g_rec.update(g_conrec.item(), x_org.size(0)) 134 | 135 | # moco_losses.update(offset_loss.item(), x_org.size(0)) 136 | 137 | # if (i + 1) % args.log_step == 0 and (args.gpu == 0 or args.gpu == '0') and logger is not None and args.local_rank == 0: 138 | 
# summary_step = epoch * args.iters + i 139 | # add_logs(args, logger, 'D/LOSS', d_losses.avg, summary_step) 140 | # add_logs(args, logger, 'D/ADV', d_advs.avg, summary_step) 141 | # add_logs(args, logger, 'D/GP', d_gps.avg, summary_step) 142 | 143 | # add_logs(args, logger, 'G/LOSS', g_losses.avg, summary_step) 144 | # add_logs(args, logger, 'G/ADV', g_advs.avg, summary_step) 145 | # add_logs(args, logger, 'G/IMGREC', g_imgrecs.avg, summary_step) 146 | # add_logs(args, logger, 'G/conrec', g_rec.avg, summary_step) 147 | 148 | # add_logs(args, logger, 'C/OFFSET', moco_losses.avg, summary_step) 149 | 150 | # print('Epoch: [{}/{}] [{}/{}] MODE[{}] Avg Loss: D[{d_losses.avg:.2f}] G[{g_losses.avg:.2f}] '.format(epoch + 1, args.epochs, i+1, args.iters, 151 | # training_mode, d_losses=d_losses, g_losses=g_losses)) 152 | 153 | # copy_norm_params(G_EMA, G) 154 | # copy_norm_params(C_EMA, C) 155 | 156 | -------------------------------------------------------------------------------- /eval/eval_2dirs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import pandas as pd 5 | import cv2 6 | import tqdm 7 | import lpips 8 | import time 9 | 10 | from tqdm.cli import main 11 | 12 | from eval_utils import LPIPS, L1, RMSE, SSIM 13 | 14 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('-gt','--dir_gt', type=str, default='./imgs/ex_dir0') 16 | parser.add_argument('-pred','--dir_pred', type=str, default='./imgs/ex_dir1') 17 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists') 18 | parser.add_argument('-m','--methods', type=str, default='', nargs='+') 19 | parser.add_argument('--subfolder', action='store_true') 20 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 21 | parser.add_argument('--standalone', action='store_true', help='write standalone metrics (e.g. FID) to their own output file') 22 | parser.add_argument('--check', action='store_true', help='check for and skip all-white (empty) ground-truth images') 23 | parser.add_argument('-t','--tmp_folder', type=str, default='/tmp', help='tmp folder used by the empty-file check') 24 | 25 | opt = parser.parse_args() 26 | 27 | methods_choices = ['l1', 'rmse', 'ssim', 'lpips', 'fid'] 28 | methods_standalone = ['fid'] 29 | if opt.methods == '': 30 | use_methods = methods_choices 31 | else: 32 | use_methods = opt.methods 33 | assert all([mi in methods_choices for mi in use_methods]), 'invalid methods exist' 34 | 35 | def load_image(path): 36 | assert path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg' 37 | return cv2.imread(path)[:,:,::-1] 38 | 39 | def easy_log(msg, flag=None, fn=None, append=False, on_screen=True, pd_df=None): 40 | if flag == '=': 41 | msg += '\n=======================' 42 | elif flag == '-': 43 | msg += '\n-------------' 44 | elif flag == '!': 45 | msg += '!!!\n!!!\n' 46 | 47 | if on_screen: 48 | print(msg) 49 | if pd_df is not None: print(pd_df) 50 | if fn: 51 | with open(fn, 'a' if append else 'w') as f: 52 | msg = '{:}\n{:}\n'.format(msg, pd_df.to_string()) if pd_df is not None else msg+'\n' 53 | f.write(msg) 54 | 55 | 56 | if __name__ == '__main__': 57 | methods_allinone = set(use_methods) - set(methods_standalone) 58 | methods_standalone = set(use_methods) - methods_allinone 59 | 60 | files = sorted(os.listdir(opt.dir_gt)) 61 | for file in files: 62 | assert os.path.exists(os.path.join(opt.dir_pred,file)) 63 | 64 | empty_files = [] 65 | remain_files = [] 66 | 
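    # With --check, ground-truth images that are entirely white (mean 255) are
    # treated as empty renders: they are dropped from the evaluation list, and
    # the surviving gt/pred pairs are copied into fresh tmp folders below so
    # that FID and the per-image metrics run over matched file sets only.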
67 | exist_empty_file = False 68 | if opt.check: 69 | for file in tqdm.tqdm(files): 70 | img = cv2.imread(os.path.join(opt.dir_gt,file)) 71 | if img.mean() == 255: # all white 72 | empty_files.append(file) 73 | else: 74 | remain_files.append(file) 75 | exist_empty_file = len(empty_files) > 0 76 | if exist_empty_file: 77 | easy_log('Found {} empty imgs: {}'.format(len(empty_files), empty_files), flag = '!', fn = opt.out + '_scores.txt') 78 | files = remain_files 79 | gt_true, pred_true = opt.dir_gt, opt.dir_pred 80 | gt_fake, pred_fake = os.path.basename(opt.dir_gt), os.path.basename(opt.dir_pred) 81 | if gt_fake == pred_fake: gt_fake = 'gt_' + gt_fake 82 | gt_fake, pred_fake = os.path.join(opt.tmp_folder, gt_fake), os.path.join(opt.tmp_folder, pred_fake) 83 | opt.dir_gt, opt.dir_pred = gt_fake, pred_fake 84 | if os.path.exists(opt.dir_gt) or os.path.exists(opt.dir_pred): 85 | ans = input(f'Delete <{gt_fake}> and <{pred_fake}>? [y/n] ') 86 | assert ans == 'y' 87 | os.system(f'rm -rf {gt_fake}; rm -rf {pred_fake}') 88 | assert not os.path.exists(opt.dir_gt) and not os.path.exists(opt.dir_pred) 89 | os.mkdir(opt.dir_gt) 90 | os.mkdir(opt.dir_pred) 91 | for file in files: 92 | shutil.copy(os.path.join(gt_true, file), os.path.join(opt.dir_gt, file)) 93 | shutil.copy(os.path.join(pred_true, file), os.path.join(opt.dir_pred, file)) 94 | 95 | 96 | 97 | if len(methods_standalone) > 0: 98 | easy_log('Standalone Metrics', flag = '=', fn = opt.out + '_scores.txt', append=exist_empty_file) 99 | # Fid 100 | if 'fid' in use_methods: 101 | suffix = '_fid.txt' if opt.standalone else '_scores.txt' 102 | st = time.time() 103 | print('Calculate Fid...') 104 | print('writing to txt:', opt.out + suffix) 105 | os.system(f'python -m pytorch_fid {opt.dir_gt} {opt.dir_pred} --batch-size 100 | tee -a {opt.out}{suffix}') # 2>&1 | tee 106 | print('done! using {:.2f}s'.format(time.time() - st)) 107 | 108 | # ALL in one 109 | if len(methods_allinone) > 0: 110 | if 'lpips' in use_methods: 111 | lpips_model = LPIPS(using_gpu=opt.use_gpu) 112 | 113 | easy_log('\nAllInOne Metrics', flag = '=', fn = opt.out + '_scores.txt', append=True) 114 | 115 | rsts = {} 116 | for i in use_methods: 117 | if i != 'fid': 118 | rsts[i] = [] 119 | num = len(files) 120 | 121 | st = time.time() 122 | fns = [] 123 | for file in tqdm.tqdm(files): 124 | if os.path.exists(os.path.join(opt.dir_pred,file)): 125 | # Load images 126 | fns.append(file) 127 | img_gt = lpips.load_image(os.path.join(opt.dir_gt,file)) # HWC, RGB, [0, 255] 128 | img_pred = lpips.load_image(os.path.join(opt.dir_pred,file)) # HWC, RGB, [0, 255] 129 | if 'lpips' in use_methods: 130 | rst_lpips = lpips_model.cal_lpips(img_gt, img_pred) 131 | rsts['lpips'].append(rst_lpips.detach().cpu().numpy()) 132 | if 'l1' in use_methods: 133 | rst_l1 = L1(img_gt, img_pred, 255.) 134 | rsts['l1'].append(rst_l1) 135 | if 'rmse' in use_methods: 136 | rst_rmse = RMSE(img_gt, img_pred, 255.) 137 | rsts['rmse'].append(rst_rmse) 138 | if 'ssim' in use_methods: 139 | rst_ssim = SSIM(img_gt, img_pred) 140 | rsts['ssim'].append(rst_ssim) 141 | 142 | tab = pd.DataFrame(rsts, fns) 143 | 144 | easy_log('Mean', flag = '-', fn = opt.out + '_scores.txt', append=True, pd_df=tab.mean(0)) 145 | print('-------------\n') 146 | easy_log('Detail', flag = '-', fn = opt.out + '_scores.txt', append=True, pd_df=tab, on_screen=False) 147 | 148 | print('writing to txt:', opt.out + '_scores.txt') 149 | print('done! 
-------------------------------------------------------------------------------- /tools/ops.py: -------------------------------------------------------------------------------- 1 | from torch import autograd 2 | import torch 3 | import torch.distributed as dist 4 | from torch.nn import functional as F 5 | 6 | from .phl import PHL 7 | from .wdl import WDL 8 | from .pkl import PKL 9 | 10 | from .abl_allinone import ABL 11 | cal_abl = ABL(max_clip_dist=4., max_N_ratio=1/50) 12 | 13 | 14 | def calc_abl(predict, target, with_iou = False): # may return None; the caller must handle that case 15 | # RGB -> gray 16 | predict = predict.mean(1, keepdim=True) 17 | target = target.mean(1) 18 | #target = torch.where(target > 0.5, 1, 0) 19 | target = (target > 0.5) * 1.0 20 | logits = torch.cat([1-predict, predict], dim=1) 21 | abl = cal_abl(logits, target) 22 | return abl 23 | 24 | def compute_grad_gp(d_out, x_in, is_patch=False): 25 | batch_size = x_in.size(0) 26 | grad_dout = autograd.grad( 27 | outputs=d_out.sum() if not is_patch else d_out.mean(), inputs=x_in, 28 | create_graph=True, retain_graph=True, only_inputs=True)[0] 29 | grad_dout2 = grad_dout.pow(2) 30 | assert (grad_dout2.size() == x_in.size()) 31 | reg = grad_dout2.sum() / batch_size 32 | return reg 33 | 34 | 35 | def compute_grad_gp_wgan(D, x_real, x_fake, gpu): 36 | alpha = torch.rand(x_real.size(0), 1, 1, 1).cuda(gpu) 37 | 38 | x_interpolate = ((1 - alpha) * x_real + alpha * x_fake).detach() 39 | x_interpolate.requires_grad = True 40 | d_inter_logit = D(x_interpolate) 41 | grad = torch.autograd.grad(d_inter_logit, x_interpolate, 42 | grad_outputs=torch.ones_like(d_inter_logit), create_graph=True)[0] 43 | 44 | norm = grad.view(grad.size(0), -1).norm(p=2, dim=1) 45 | 46 | d_gp = ((norm - 1) ** 2).mean() 47 | return d_gp 48 | 49 | 50 | def update_average(model_tgt, model_src, beta=0.999): 51 | with torch.no_grad(): 52 | param_dict_src = dict(model_src.named_parameters()) 53 | for p_name, p_tgt in model_tgt.named_parameters(): 54 | p_src = param_dict_src[p_name] 55 | assert (p_src is not p_tgt) 56 | p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src) 57 | 58 | 59 | def copy_norm_params(model_tgt, model_src): 60 | with torch.no_grad(): 61 | src_state_dict = model_src.state_dict() 62 | tgt_state_dict = model_tgt.state_dict() 63 | names = [name for name, _ in model_tgt.named_parameters()] 64 | for n in names: 65 | del src_state_dict[n] 66 | tgt_state_dict.update(src_state_dict) 67 | model_tgt.load_state_dict(tgt_state_dict, strict=False) 68 | 69 | 70 | def calc_iic_loss(x_out, x_tf_out, lamb=1.0, EPS=1e-10): 71 | # has had softmax applied 72 | _, k = x_out.size() 73 | p_i_j = compute_joint(x_out, x_tf_out) 74 | assert (p_i_j.size() == (k, k)) 75 | 76 | p_i = p_i_j.sum(dim=1).view(k, 1).expand(k, k) 77 | p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k) # but should be same, symmetric 78 | 79 | # avoid NaN losses. 
Effect will get cancelled out by p_i_j tiny anyway 80 | p_i_j[(p_i_j < EPS).data] = EPS 81 | p_j[(p_j < EPS).data] = EPS 82 | p_i[(p_i < EPS).data] = EPS 83 | 84 | loss = - p_i_j * (torch.log(p_i_j) \ 85 | - lamb * torch.log(p_j) \ 86 | - lamb * torch.log(p_i)) 87 | 88 | loss = loss.sum() 89 | 90 | return loss 91 | 92 | 93 | def compute_joint(x_out, x_tf_out): 94 | # produces variable that requires grad (since args require grad) 95 | 96 | bn, k = x_out.size() 97 | assert (x_tf_out.size(0) == bn and x_tf_out.size(1) == k) 98 | 99 | p_i_j = x_out.unsqueeze(2) * x_tf_out.unsqueeze(1) # bn, k, k 100 | p_i_j = p_i_j.sum(dim=0) # k, k 101 | p_i_j = (p_i_j + p_i_j.t()) / 2. # symmetrise 102 | p_i_j = p_i_j / p_i_j.sum() # normalise 103 | 104 | return p_i_j 105 | 106 | 107 | def calc_recon_loss(predict, target): # L1 108 | return torch.mean(torch.abs(predict - target)) 109 | 110 | def calc_wdl(predict, target): 111 | return torch.mean(WDL(predict.mean(dim=1, keepdim=True), target.mean(dim=1, keepdim=True))) 112 | 113 | def calc_pkl(predict, target): 114 | return torch.mean(PKL(predict.mean(dim=1, keepdim=True), target.mean(dim=1, keepdim=True))) 115 | 116 | def calc_pseudo_hamming_loss(predict, target, thres=0): 117 | return torch.mean(PHL(predict.mean(dim=1, keepdim=True), target.mean(dim=1, keepdim=True), thres)) 118 | 119 | def calc_contrastive_loss(args, query, key, queue, temp=0.07): 120 | N = query.shape[0] 121 | K = queue.shape[0] 122 | 123 | zeros = torch.zeros(N, dtype=torch.long).cuda(args.gpu) 124 | key = key.detach() 125 | logit_pos = torch.bmm(query.view(N, 1, -1), key.view(N, -1, 1)) 126 | logit_neg = torch.mm(query.view(N, -1), queue.t().view(-1, K)) 127 | 128 | logit = torch.cat([logit_pos.view(N, 1), logit_neg], dim=1) 129 | 130 | loss = F.cross_entropy(logit / temp, zeros) 131 | 132 | return loss 133 | 134 | 135 | def calc_adv_loss(logit, mode): 136 | assert mode in ['d_real', 'd_fake', 'g'] 137 | if mode == 'd_real': 138 | loss = F.relu(1.0 - logit).mean() 139 | elif mode == 'd_fake': 140 | loss = F.relu(1.0 + logit).mean() 141 | else: 142 | loss = -logit.mean() 143 | 144 | return loss 145 | 146 | 147 | def queue_data(data, k): 148 | return torch.cat([data, k], dim=0) 149 | 150 | 151 | def dequeue_data(data, K=1024): 152 | if len(data) > K: 153 | return data[-K:] 154 | else: 155 | return data 156 | 157 | 158 | def initialize_queue(model_k, device, train_loader, feat_size=128): 159 | queue = torch.zeros((0, feat_size), dtype=torch.float) 160 | queue = queue.to(device) 161 | 162 | for _, (data, _) in enumerate(train_loader): 163 | x_k = data[1] 164 | x_k = x_k.cuda(device) 165 | outs = model_k(x_k) 166 | k = outs['cont'] 167 | k = k.detach() 168 | queue = queue_data(queue, k) 169 | queue = dequeue_data(queue, K=1024) 170 | break 171 | return queue 172 | 173 | 174 | def average_gradients(model): 175 | size = float(dist.get_world_size()) 176 | for param in model.parameters(): 177 | # Handle unused parameters 178 | if param.grad is None: 179 | continue 180 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 181 | param.grad.data /= size 182 | -------------------------------------------------------------------------------- /charset/GB2312_CN6763.txt: -------------------------------------------------------------------------------- 1 | 
一丁七万丈三上下丌不与丐丑专且丕世丘丙业丛东丝丞丢两严丧丨个丫丬中丰串临丶丸丹为主丽举丿乃久乇么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乾了予争事二亍于亏云互亓五井亘亚些亟亠亡亢交亥亦产亨亩享京亭亮亲亳亵人亻亿什仁仂仃仄仅仆仇仉今介仍从仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伉伊伍伎伏伐休众优伙会伛伞伟传伢伤伥伦伧伪伫伯估伲伴伶伸伺似伽佃但位低住佐佑体何佗佘余佚佛作佝佞佟你佣佤佥佧佩佬佯佰佳佴佶佻佼佾使侃侄侈侉例侍侏侑侔侗供依侠侣侥侦侧侨侩侪侬侮侯侵便促俄俅俊俎俏俐俑俗俘俚俜保俞俟信俣俦俨俩俪俭修俯俱俳俸俺俾倌倍倏倒倔倘候倚倜借倡倥倦倨倩倪倬倭倮债值倾偃假偈偌偎偏偕做停健偬偶偷偻偾偿傀傅傈傍傣傥傧储傩催傲傺傻像僖僚僦僧僬僭僮僳僵僻儆儇儋儒儡儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六兮兰共关兴兵其具典兹养兼兽冀冁冂内冈冉册再冒冕冖冗写军农冠冢冤冥冫冬冯冰冱冲决况冶冷冻冼冽净凄准凇凉凋凌减凑凛凝几凡凤凫凭凯凰凳凵凶凸凹出击凼函凿刀刁刂刃分切刈刊刍刎刑划刖列刘则刚创初删判刨利别刭刮到刳制刷券刹刺刻刽刿剀剁剂剃削剌前剐剑剔剖剜剞剡剥剧剩剪副割剽剿劁劂劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劾势勃勇勉勋勐勒勖勘募勤勰勹勺勾勿匀包匆匈匍匏匐匕化北匙匚匝匠匡匣匦匪匮匹区医匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卩卫卮卯印危即却卵卷卸卺卿厂厄厅历厉压厌厍厕厘厚厝原厢厣厥厦厨厩厮厶去县叁参又叉及友双反发叔取受变叙叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吾呀呃呆呈告呋呐呒呓呔呕呖呗员呙呛呜呢呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咋和咎咏咐咒咔咕咖咙咚咛咝咣咤咦咧咨咩咪咫咬咭咯咱咳咴咸咻咽咿哀品哂哄哆哇哈哉哌响哎哏哐哑哒哓哔哕哗哙哚哜哝哞哟哥哦哧哨哩哪哭哮哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛唠唢唣唤唧唪唬售唯唰唱唳唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啬啭啮啵啶啷啸啻啼啾喀喁喂喃善喇喈喉喊喋喏喑喔喘喙喜喝喟喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌嗍嗑嗒嗓嗔嗖嗜嗝嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘛嘞嘟嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚅嚆嚎嚏嚓嚣嚯嚷嚼囊囔囗囚四囝回囟因囡团囤囫园困囱围囵囹固国图囿圃圄圆圈圉圊圜土圣在圩圪圬圭圮圯地圳圹场圻圾址坂均坊坌坍坎坏坐坑块坚坛坜坝坞坟坠坡坤坦坨坩坪坫坭坯坳坶坷坻坼垂垃垄垅垆型垌垒垓垛垠垡垢垣垤垦垧垩垫垭垮垲垴垸埂埃埋城埏埒埔埕埘埙埚埝域埠埤埭埯埴埸培基埽堀堂堆堇堋堍堑堕堙堞堠堡堤堪堰堵塄塌塍塑塔塘塞塥填塬塾墀墁境墅墉墒墓墙墚增墟墨墩墼壁壅壑壕壤士壬壮声壳壶壹夂处备复夏夔夕外夙多夜够夤夥大天太夫夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奔奕奖套奘奚奠奢奥女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妙妞妣妤妥妨妩妪妫妮妯妲妹妻妾姆姊始姐姑姒姓委姗姘姚姜姝姣姥姨姬姹姻姿威娃娄娅娆娇娈娉娌娑娓娘娜娟娠娣娥娩娱娲娴娶娼婀婆婉婊婕婚婢婧婪婴婵婶婷婺婿媒媚媛媪媲媳媵媸媾嫁嫂嫉嫌嫒嫔嫖嫘嫜嫠嫡嫣嫦嫩嫫嫱嬉嬖嬗嬲嬴嬷孀子孑孓孔孕字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽宀宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宪宫宰害宴宵家宸容宽宾宿寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝尢尤尥尧尬就尴尸尹尺尻尼尽尾尿局屁层居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屮屯山屹屺屿岁岂岈岌岍岐岑岔岖岗岘岙岚岛岜岢岣岩岫岬岭岱岳岵岷岸岽岿峁峄峋峒峙峡峤峥峦峨峪峭峰峻崂崃崆崇崎崔崖崛崞崤崦崧崩崭崮崴崽崾嵇嵊嵋嵌嵘嵛嵝嵩嵫嵬嵯嵴嶂嶙嶝嶷巅巍巛川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝带帧席帮帱帷常帻帼帽幂幄幅幌幔幕幛幞幡幢干平年并幸幺幻幼幽广庀庄庆庇床庋序庐庑库应底庖店庙庚府庞废庠庥度座庭庳庵庶康庸庹庾廉廊廑廒廓廖廛廨廪廴延廷建廾廿开弁异弃弄弈弊弋式弑弓引弗弘弛弟张弥弦弧弩弪弭弯弱弹强弼彀彐归当录彖彗彘彝彡形彤彦彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律後徐徒徕得徘徙徜御徨循徭微徵德徼徽心忄必忆忉忌忍忏忐忑忒忖志忘忙忝忠忡忤忧忪快忭忮忱念忸忻忽忾忿怀态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悉悌悍悒悔悖悚悛悝悟悠患悦您悫悬悭悯悱悲悴悸悻悼情惆惊惋惑惕惘惚惜惝惟惠惦惧惨惩惫惬惭惮惯惰想惴惶惹惺愀愁愆愈愉愍愎意愕愚感愠愣愤愦愧愫愿慈慊慌慎慑慕慝慢慧慨慰慵慷憋憎憔憝憧憨憩憬憷憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我戒戕或戗战戚戛戟戡戢戤戥截戬戮戳戴户戽戾房所扁扃扇扈扉手扌才扎扑扒打扔托扛扣扦执扩扪扫扬扭扮扯扰扳扶批扼找承技抄抉把抑抒抓投抖抗折抚抛抟抠抡抢护报抨披抬抱抵抹抻押抽抿拂拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙拚招拜拟拢拣拥拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挖挚挛挝挞挟挠挡挢挣挤挥挨挪挫振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捱捶捷捺捻掀掂掇授掉掊掌掎掏掐排掖掘掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍揎描提插揖揞揠握揣揩揪揭揲援揶揸揽揿搀搁搂搅搋搌搏搐搓搔搛搜搞搠搡搦搪搬搭搴携搽搿摁摄摅摆摇摈摊摒摔摘摞摧摩摭摸摹摺撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擗擘擞擢擤擦攀攉攒攘攥攫攮支攴攵收攸改攻放政故效敉敌敏救敕敖教敛敝敞敢散敦敫敬数敲整敷文斋斌斐斑斓斗料斛斜斟斡斤斥斧斩斫断斯新方於施旁旃旄旅旆旋旌旎族旒旖旗无既日旦旧旨早旬旭旮旯旰旱时旷旺昀昂昃昆昊昌明昏易昔昕昙昝星映春昧昨昭是昱昴昵昶昼显晁晃晋晌晏晒晓晔晕晖晗晚晟晡晤晦晨普景晰晴晶晷智晾暂暄暇暌暑暖暗暝暧暨暮暴暹暾曙曛曜曝曦曩曰曲曳更曷曹曼曾替最月有朊朋服朐朔朕朗望朝期朦木未末本札术朱朴朵机朽杀杂权杆杈杉杌李杏材村杓杖杜杞束杠条来杨杩杪杭杯杰杲杳杵杷杼松板极构枇枉枋析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枳枵架枷枸柁柃柄柏某柑柒染柔柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柽柿栀栅标栈栉栊栋栌栎栏树栓栖栗栝校栩株栲栳样核根格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桴桶桷梁梃梅梆梏梓梗梢梦梧梨梭梯械梳梵检棂棉棋棍棒棕棘棚棠棣森棰棱棵棹棺棼椁椅椋植椎椐椒椟椠椤椭椰椴椹椽椿楂楔楗楚楝楞楠楣楦楫楮楱楷楸楹楼榀概榄榆榇榈榉榍榔榕榘榛榜榧榨榫榭榱榴榷榻槁槊槌槎槐槔槛槟槠槭槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橘橙橛橡橥橱橹橼檀檄檎檐檑檗檠檩檫檬欠次欢欣欤欧欲欷欹欺款歃歆歇歉歌歙止正此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殪殳殴段殷殿毁毂毅毋母每毒毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮氯氰氲水氵永氽汀汁求汆汇汉汊汐汔汕汗汛汜汝汞江池污汤汨汩汪汰汲汴汶汹汽汾沁沂沃沅沆沈沉沌沏沐沓沔沙沛沟没沣沤沥沦沧沩沪沫沭沮沱沲河沸油治沼沽沾沿泄泅泉泊泌泐泓泔法泖泗泛泞泠泡波泣泥注泪泫泮泯泰泱泳泵泶泷泸泺泻泼泽泾洁洄洇洋洌洎洒洗洙洚洛洞津洧洪洫洮洱洲洳洵洹活洼洽派流浃浅浆浇浈浊测浍济浏浑浒浓浔浙浚浜浞浠浣浦浩浪浮浯浴海浸浼涂涅消涉涌涎涑涓涔涕涛涝涞涟涠涡涣涤润涧涨涩涪涫涮涯液涵涸涿淀淄淅淆淇淋淌淑淖淘淙淝淞淠淡淤淦淫淬淮深淳混淹添淼清渊渌渍渎渐渑渔渖渗渚渝渠渡渣渤渥温渫渭港渲渴游渺湃湄湍湎湓湔湖湘湛湟湫湮湾湿溃溅溆溉溏源溘溜溟溢溥溧溪溯溱溲溴溶溷溺溻溽滁滂滇滋滏滑滓滔滕滗滚滞滟滠满滢滤滥滦滨滩滴滹漂漆漉漏漓演漕漠漤漩漪漫漭漯漱漳漶漾潆潇潋潍潘潜潞潢潦潭潮潲潴潸潺潼澄澈澉澌澍澎澜澡澧澳澶澹激濂濉濑濒濞濠濡濮濯瀑瀚瀛瀣瀵瀹灌灏灞火灬灭灯灰灵灶灸灼灾灿炀炅炉炊炎炒炔炕炖炙炜炝炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈烊烘烙烛烟烤烦烧烨烩烫烬热烯烷烹烽焉焊焐焓焕焖焘焙焚焦焯焰焱然煅煊煌煎煜煞煤煦照煨煮煲煳煸煺煽熄熊熏熔熘熙熟熠熨熬熳熵熹燃燎燔燕燠燥燧燮燹爆爝爨爪爬爰爱爵父爷爸爹爻爽爿片版牌牍牒牖牙牛牝牟牡牢牦牧物牮牯牲牵特牺牾牿犀犁犄犊犋犍犏犒犟犬犭犯犰犴状犷犸犹狁狂狃狄狈狍狎狐狒狗狙狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猊猎猓猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猱猴猷猸猹猾猿獍獐獒獗獠獬獭獯獾玄率玉王玎玑玖玛玟玢玩玫玮环现玲玳玷玺玻珀珂珈珉珊珍珏珐珑珙珞珠珥珧珩班珲球琅理琉琊琏琐琚琛琢琥琦琨琪琬琮琰琳琴琵琶琼瑁瑕瑗瑙瑚瑛瑜瑞瑟瑭瑰瑶瑷瑾璀璁璃璇璋璎璐璜璞璧璨璩璺瓒瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓿甄甍甏甑甓甘甙甚甜生甥用甩甫甬甭甯田由甲申电男甸町画甾畀畅畈畋界畎畏畔留畚畛畜略畦番畲畴
畸畹畿疃疆疋疏疑疒疔疖疗疙疚疝疟疠疡疣疤疥疫疬疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒痔痕痖痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癜癞癣癫癯癸登白百皂的皆皇皈皋皎皑皓皖皙皤皮皱皲皴皿盂盅盆盈益盍盎盏盐监盒盔盖盗盘盛盟盥目盯盱盲直相盹盼盾省眄眇眈眉看眍眙眚真眠眢眦眨眩眭眯眵眶眷眸眺眼着睁睃睇睐睑睚睛睡睢督睥睦睨睫睬睹睽睾睿瞀瞄瞅瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞬瞰瞳瞵瞻瞽瞿矍矗矛矜矢矣知矧矩矫矬短矮石矶矸矽矾矿砀码砂砉砌砍砑砒研砖砗砘砚砜砝砟砣砥砦砧砩砬砭砰破砷砸砹砺砻砼砾础硅硇硌硎硐硒硕硖硗硝硪硫硬硭确硷硼碇碉碌碍碎碑碓碗碘碚碛碜碟碡碣碥碧碰碱碲碳碴碹碾磁磅磉磊磋磐磔磕磙磨磬磲磴磷磺礁礅礓礞礤礴示礻礼社祀祁祆祈祉祓祖祗祚祛祜祝神祟祠祢祥祧票祭祯祷祸祺禀禁禄禅禊福禚禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣秤秦秧秩秫秭积称秸移秽稀稂稃稆程稍税稔稗稚稞稠稣稳稷稹稻稼稽稿穆穑穗穰穴究穷穸穹空穿窀突窃窄窆窈窍窑窒窕窖窗窘窜窝窟窠窥窦窨窬窭窳窿立竖站竞竟章竣童竦竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笱笳笸笺笼笾筅筇等筋筌筏筐筑筒答策筘筚筛筝筠筢筮筱筲筵筷筹筻签简箅箍箐箔箕算箜箝管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篱篷篼篾簇簋簌簏簖簟簦簧簪簸簿籀籁籍米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮粱粲粳粹粼粽精糁糅糇糈糊糌糍糕糖糗糙糜糟糠糨糯糸系紊素索紧紫累絮絷綦綮縻繁繇纂纛纟纠纡红纣纤纥约级纨纩纪纫纬纭纯纰纱纲纳纵纶纷纸纹纺纽纾线绀绁绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绗绘给绚绛络绝绞统绠绡绢绣绥绦继绨绩绪绫续绮绯绰绱绲绳维绵绶绷绸绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缋缌缍缎缏缑缒缓缔缕编缗缘缙缚缛缜缝缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罐网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罹罾羁羊羌美羔羚羝羞羟羡群羧羯羰羲羸羹羼羽羿翁翅翊翌翎翔翕翘翟翠翡翥翦翩翮翰翱翳翻翼耀老考耄者耆耋而耍耐耒耔耕耖耗耘耙耜耠耢耥耦耧耨耩耪耱耳耵耶耷耸耻耽耿聂聃聆聊聋职聍聒联聘聚聩聪聱聿肀肃肄肆肇肉肋肌肓肖肘肚肛肜肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肺肼肽肾肿胀胁胂胃胄胆背胍胎胖胗胙胚胛胜胝胞胡胤胥胧胨胩胪胫胬胭胯胰胱胲胳胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脬脯脱脲脶脸脾腆腈腊腋腌腐腑腓腔腕腙腚腠腥腧腩腭腮腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膛膜膝膣膦膨膪膳膺膻臀臁臂臃臆臊臌臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舡舢舣舨航舫般舭舯舰舱舳舴舵舶舷舸船舻舾艄艇艉艋艏艘艚艟艨艮良艰色艳艴艹艺艽艾艿节芄芈芊芋芍芎芏芑芒芗芘芙芜芝芟芡芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芽芾苁苄苇苈苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苫苯英苴苷苹苻茁茂范茄茅茆茇茈茉茌茎茏茑茔茕茗茚茛茜茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼荀荃荆荇草荏荐荑荒荔荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莛莜莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽菀菁菅菇菊菌菏菔菖菘菜菝菟菠菡菥菩菪菰菱菲菸菹菽萁萃萄萆萋萌萍萎萏萑萘萜萝萤营萦萧萨萱萸萼落葆葑著葙葚葛葜葡董葩葫葬葭葱葳葵葶葸葺蒂蒇蒈蒉蒋蒌蒎蒗蒙蒜蒡蒯蒲蒴蒸蒹蒺蒽蒿蓁蓄蓉蓊蓍蓐蓑蓓蓖蓝蓟蓠蓣蓥蓦蓬蓰蓼蓿蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕙蕞蕤蕨蕲蕴蕹蕺蕻蕾薄薅薇薏薛薜薤薨薪薮薯薰薷薹藁藉藏藐藓藕藜藤藩藻藿蘅蘑蘖蘧蘩蘸蘼虍虎虏虐虑虔虚虞虢虫虬虮虱虹虺虻虼虽虾虿蚀蚁蚂蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚴蚵蚶蚺蛀蛄蛆蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜒蜓蜕蜗蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝙蝠蝣蝤蝥蝮蝰蝴蝶蝻蝼蝽蝾螂螃螅螈螋融螓螗螟螨螫螬螭螯螳螵螺螽蟀蟆蟊蟋蟑蟒蟓蟛蟠蟥蟪蟮蟹蟾蠃蠊蠓蠕蠖蠛蠡蠢蠲蠹蠼血衄衅行衍衔街衙衡衢衣衤补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袈袋袍袒袖袜袢袤被袭袱袷袼裁裂装裆裉裎裒裔裕裘裙裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂褊褐褒褓褙褚褛褡褥褪褫褰褴褶襁襄襞襟襦襻西要覃覆见观规觅视觇览觉觊觋觌觎觏觐觑角觖觚觜觞解觥触觫觯觳言訇訾詈詹誉誊誓謇謦警譬讠计订讣认讥讦讧讨让讪讫训议讯记讲讳讴讵讶讷许讹论讼讽设访诀证诂诃评诅识诈诉诊诋诌词诎诏译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵诶请诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谘谙谚谛谜谝谟谠谡谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷豁豆豇豉豌豕豚象豢豪豫豳豸豹豺貂貅貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赓赔赕赖赘赙赚赛赜赝赞赠赡赢赣赤赦赧赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趱足趴趵趸趺趼趾趿跃跄跆跋跌跎跏跑跖跗跚跛距跞跟跣跤跨跪跫跬路跳践跷跸跹跺跻跽踅踉踊踌踏踔踝踞踟踢踣踩踪踬踮踯踱踵踹踺踽蹀蹁蹂蹄蹇蹈蹉蹊蹋蹑蹒蹙蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹿躁躅躇躏躐躔躜躞身躬躯躲躺軎车轧轨轩轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辁辂较辄辅辆辇辈辉辊辋辍辎辏辐辑输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱辶边辽达迁迂迄迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹追退送适逃逄逅逆选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逵逶逸逻逼逾遁遂遄遇遍遏遐遑遒道遗遘遛遢遣遥遨遭遮遴遵遽避邀邂邃邈邋邑邓邕邗邙邛邝邡邢那邦邪邬邮邯邰邱邳邴邵邶邸邹邺邻邾郁郄郅郇郊郎郏郐郑郓郗郛郜郝郡郢郦郧部郫郭郯郴郸都郾鄂鄄鄙鄞鄢鄣鄯鄱鄹酃酆酉酊酋酌配酎酏酐酒酗酚酝酞酡酢酣酤酥酩酪酬酮酯酰酱酲酴酵酶酷酸酹酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醣醪醭醮醯醴醵醺采釉释里重野量金釜鉴銎銮鋈錾鍪鎏鏊鏖鐾鑫钅钆钇针钉钊钋钌钍钎钏钐钒钓钔钕钗钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钶钷钸钹钺钻钼钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铐铑铒铕铖铗铘铙铛铜铝铞铟铠铡铢铣铤铥铧铨铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐锑锒锓锔锕锖锗锘错锚锛锝锞锟锡锢锣锤锥锦锨锩锪锫锬锭键锯锰锱锲锴锵锶锷锸锹锺锻锼锾锿镀镁镂镄镅镆镇镉镊镌镍镎镏镐镑镒镓镔镖镗镘镙镛镜镝镞镟镡镢镣镤镥镦镧镨镩镪镫镬镭镯镰镱镲镳镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼闽闾阀阁阂阃阄阅阆阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阙阚阜阝队阡阢阪阮阱防阳阴阵阶阻阼阽阿陀陂附际陆陇陈陉陋陌降限陔陕陛陟陡院除陧陨险陪陬陲陴陵陶陷隅隆隈隋隍随隐隔隗隘隙障隧隰隳隶隹隼隽难雀雁雄雅集雇雉雌雍雎雏雒雕雠雨雩雪雯雳零雷雹雾需霁霄霆震霈霉霍霎霏霓霖霜霞霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靳靴靶靼鞅鞋鞍鞑鞒鞔鞘鞠鞣鞫鞭鞯鞲鞴韦韧韩韪韫韬韭音韵韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颌颍颏颐频颓颔颖颗题颚颛颜额颞颟颠颡颢颤颥颦颧风飑飒飓飕飘飙飚飞食飧飨餍餐餮饔饕饣饥饧饨饩饪饫饬饭饮饯饰饱饲饴饵饶饷饺饼饽饿馀馁馄馅馆馇馈馊馋馍馏馐馑馒馓馔馕首馗馘香馥馨马驭驮驯驰驱驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骄骅骆骇骈骊骋验骏骐骑骒骓骖骗骘骚骛骜骝骞骟骠骡骢骣骤骥骧骨骰骱骶骷骸骺骼髀髁髂髅髋髌髑髓高髟髡髦髫髭髯髹髻鬃鬈鬏鬓鬟鬣鬯鬲鬻鬼魁魂魃魄魅魇魈魉魍魏魑魔鱼鱿鲁鲂鲅鲆鲇鲈鲋鲍鲎鲐鲑鲒鲔鲕鲚鲛鲜鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨鲩鲫鲭鲮鲰鲱鲲鲳鲴鲵鲶鲷鲸鲺鲻鲼鲽鳃鳄鳅鳆鳇鳊鳋鳌鳍鳎鳏鳐鳓鳔鳕鳖鳗鳘鳙鳜鳝鳞鳟鳢鸟鸠鸡鸢鸣鸥鸦鸨鸩鸪鸫鸬鸭鸯鸱鸲鸳鸵鸶鸷鸸鸹鸺鸽鸾鸿鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹎鹏鹑鹕鹗鹘鹚鹛鹜鹞鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹰鹱鹳鹾鹿麂麇麈麋麒麓麝麟麦麴麸麻麽麾黄黉黍黎黏黑黔默黛黜黝黟黠黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼓鼗鼙鼠鼢鼬鼯鼷鼹鼻鼽鼾齄齐齑齿龀龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠 -------------------------------------------------------------------------------- /validation/validation.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch.nn 3 | import torch.nn.parallel 4 | import torch.optim 5 | import torch.utils.data 6 | import 
torch.utils.data.distributed 7 | import torchvision.utils as vutils 8 | import torch.nn.functional as F 9 | 10 | import numpy as np 11 | 12 | try: 13 | from tqdm import tqdm 14 | except ImportError: 15 | # If tqdm is not available, provide a mock version of it 16 | def tqdm(x): 17 | return x 18 | 19 | from scipy import linalg 20 | 21 | from tools.utils import * 22 | 23 | 24 | def validateUN(data_loader, networks, epoch, args, additional=None, oss_client=None, with_name=None): 25 | # set nets 26 | D = networks['D'] if not args.distributed else networks['D'].module 27 | G = networks['G'] if not args.distributed else networks['G'].module 28 | C = networks['C'] if not args.distributed else networks['C'].module 29 | C_EMA = networks['C_EMA'] if not args.distributed else networks['C_EMA'].module 30 | G_EMA = networks['G_EMA'] if not args.distributed else networks['G_EMA'].module 31 | # switch to eval mode 32 | D.eval() 33 | G.eval() 34 | C.eval() 35 | C_EMA.eval() 36 | G_EMA.eval() 37 | # data loader 38 | val_dataset = data_loader['TRAINSET'] 39 | val_loader = data_loader['VAL'] 40 | 41 | x_each_cls = [] 42 | with torch.no_grad(): 43 | val_tot_tars = torch.tensor(val_dataset.targets) 44 | for cls_idx in range(len(args.att_to_use)): 45 | tmp_cls_set = (val_tot_tars == args.att_to_use[cls_idx]).nonzero()[-args.val_num:] 46 | tmp_ds = torch.utils.data.Subset(val_dataset, tmp_cls_set) 47 | tmp_dl = torch.utils.data.DataLoader(tmp_ds, batch_size=49, shuffle=False, 48 | num_workers=4, pin_memory=True, drop_last=False) 49 | tmp_sample = torch.cat([x_.clone() for x_, _ in tmp_dl], dim=0) 50 | x_each_cls.append(tmp_sample) 51 | # print([len(x) for x in x_each_cls]) 52 | 53 | 54 | if epoch >= args.fid_start: 55 | # Reference guided 56 | with torch.no_grad(): 57 | # Just a buffer image (to make a grid) 58 | ones = torch.ones(1, x_each_cls[0].size(1), x_each_cls[0].size(2), x_each_cls[0].size(3)).to(torch.cuda.current_device(), non_blocking=True) 59 | for src_idx in [0]: #tqdm(range(len(args.att_to_use) // args.val_src_reduce)): 60 | x_src = x_each_cls[src_idx][:args.val_batch, :, :, :].to(torch.cuda.current_device(), non_blocking=True) 61 | rnd_idx = torch.randperm(x_each_cls[src_idx].size(0))[:args.val_batch] 62 | x_src_rnd = x_each_cls[src_idx][rnd_idx].to(torch.cuda.current_device(), non_blocking=True) 63 | for ref_idx in range(len(args.att_to_use) // args.val_ref_reduce): 64 | x_res_ema = torch.cat((ones, x_src), 0) 65 | x_rnd_ema = torch.cat((ones, x_src_rnd), 0) 66 | x_ref = x_each_cls[ref_idx][:args.val_batch, :, :, :].to(torch.cuda.current_device(), non_blocking=True) 67 | rnd_idx = torch.randperm(x_each_cls[ref_idx].size(0))[:args.val_batch] 68 | x_ref_rnd = x_each_cls[ref_idx][rnd_idx].to(torch.cuda.current_device(), non_blocking=True) 69 | for sample_idx in range(args.val_batch): 70 | x_ref_tmp = x_ref[sample_idx: sample_idx + 1].repeat((args.val_batch, 1, 1, 1)) 71 | 72 | c_src, skip1, skip2 = G_EMA.cnt_encoder(x_src) 73 | s_ref = C_EMA(x_ref_tmp, sty=True) 74 | x_res_ema_tmp,_ = G_EMA.decode(c_src, s_ref, skip1, skip2) 75 | 76 | x_ref_tmp = x_ref_rnd[sample_idx: sample_idx + 1].repeat((args.val_batch, 1, 1, 1)) 77 | 78 | c_src, skip1, skip2 = G_EMA.cnt_encoder(x_src_rnd) 79 | s_ref = C_EMA(x_ref_tmp, sty=True) 80 | x_rnd_ema_tmp,_ = G_EMA.decode(c_src, s_ref, skip1, skip2) 81 | 82 | x_res_ema_tmp = torch.cat((x_ref[sample_idx: sample_idx + 1], x_res_ema_tmp), 0) 83 | x_res_ema = torch.cat((x_res_ema, x_res_ema_tmp), 0) 84 | 85 | x_rnd_ema_tmp = torch.cat((x_ref_rnd[sample_idx:
sample_idx + 1], x_rnd_ema_tmp), 0) 86 | x_rnd_ema = torch.cat((x_rnd_ema, x_rnd_ema_tmp), 0) 87 | 88 | # vutils.save_image(x_res_ema, os.path.join(args.res_dir, '{}_EMA_{}_{}{}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), normalize=True, 89 | # nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 90 | # vutils.save_image(x_rnd_ema, os.path.join(args.res_dir, '{}_RNDEMA_{}_{}{}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), normalize=True, 91 | # nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 92 | 93 | if args.local_rank == 0: 94 | img_res = vutils.make_grid(x_res_ema, normalize=True, nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 95 | img_ema = vutils.make_grid(x_rnd_ema, normalize=True, nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 96 | vutils.save_image(img_res, os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 97 | vutils.save_image(img_ema, os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 98 | if oss_client is not None: 99 | oss_client.write_file(os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), \ 100 | os.path.join(args.res_dir_oss, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 101 | oss_client.write_file(os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), \ 102 | os.path.join(args.res_dir_oss, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 103 | 104 | os.remove(os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 105 | os.remove(os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) -------------------------------------------------------------------------------- /validation/validation_cf.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch.nn 3 | import torch.nn.parallel 4 | import torch.optim 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import torchvision.utils as vutils 8 | import torch.nn.functional as F 9 | 10 | import numpy as np 11 | 12 | try: 13 | from tqdm import tqdm 14 | except ImportError: 15 | # If tqdm is not available, provide a mock version of it 16 | def tqdm(x): 17 | return x 18 | 19 | from scipy import linalg 20 | 21 | from tools.utils import * 22 | 23 | 24 | def validateUN(data_loader, networks, epoch, args, additional=None, oss_client=None): 25 | # set nets 26 | D = networks['D'] if not args.distributed else networks['D'].module 27 | G = networks['G'] if not args.distributed else networks['G'].module 28 | C = networks['C'] if not args.distributed else networks['C'].module 29 | C_EMA = networks['C_EMA'] if not args.distributed else networks['C_EMA'].module 30 | G_EMA = networks['G_EMA'] if not args.distributed else networks['G_EMA'].module 31 | # switch to eval mode 32 | D.eval() 33 | G.eval() 34 | C.eval() 35 | C_EMA.eval() 36 | G_EMA.eval() 37 | # data loader 38 | val_dataset = data_loader['TRAINSET'] 39 | val_loader = data_loader['VAL'] 40 | 41 | x_each_cls = [] 42 | with torch.no_grad(): 43 | val_tot_tars = torch.tensor(val_dataset.targets) 44 | for cls_idx in range(len(args.att_to_use)): 45 | tmp_cls_set = (val_tot_tars == 
args.att_to_use[cls_idx]).nonzero()[-args.val_num:] 46 | tmp_ds = torch.utils.data.Subset(val_dataset, tmp_cls_set) 47 | tmp_dl = torch.utils.data.DataLoader(tmp_ds, batch_size=args.val_num, shuffle=False, 48 | num_workers=0, pin_memory=True, drop_last=False) 49 | tmp_iter = iter(tmp_dl) 50 | tmp_sample = None 51 | for sample_idx in range(len(tmp_iter)): 52 | imgs, _ = next(tmp_iter) 53 | x_ = imgs 54 | if tmp_sample is None: 55 | tmp_sample = x_.clone() 56 | else: 57 | tmp_sample = torch.cat((tmp_sample, x_), 0) 58 | x_each_cls.append(tmp_sample) 59 | 60 | 61 | if epoch >= args.fid_start: 62 | # Reference guided 63 | with torch.no_grad(): 64 | # Just a buffer image ( to make a grid ) 65 | ones = torch.ones(1, x_each_cls[0].size(1), x_each_cls[0].size(2), x_each_cls[0].size(3)).to(torch.cuda.current_device(), non_blocking=True) 66 | for src_idx in range(len(args.att_to_use) // args.val_src_reduce): 67 | x_src = x_each_cls[src_idx][:args.val_batch, :, :, :].to(torch.cuda.current_device(), non_blocking=True) 68 | rnd_idx = torch.randperm(x_each_cls[src_idx].size(0))[:args.val_batch] 69 | x_src_rnd = x_each_cls[src_idx][rnd_idx].to(torch.cuda.current_device(), non_blocking=True) 70 | for ref_idx in range(len(args.att_to_use) // args.val_ref_reduce): 71 | x_res_ema = torch.cat((ones, x_src), 0) 72 | x_rnd_ema = torch.cat((ones, x_src_rnd), 0) 73 | x_ref = x_each_cls[ref_idx][:args.val_batch, :, :, :].to(torch.cuda.current_device(), non_blocking=True) 74 | rnd_idx = torch.randperm(x_each_cls[ref_idx].size(0))[:args.val_batch] 75 | x_ref_rnd = x_each_cls[ref_idx][rnd_idx].to(torch.cuda.current_device(), non_blocking=True) 76 | for sample_idx in range(args.val_batch): 77 | x_ref_tmp = x_ref[sample_idx: sample_idx + 1].repeat((args.val_batch, 1, 1, 1)) 78 | 79 | c_src, skip1, skip2 = G_EMA.cnt_encoder(x_src) 80 | s_ref = C_EMA(x_ref_tmp, sty=True) 81 | x_res_ema_tmp,_ = G_EMA.decode(c_src, s_ref, skip1, skip2) 82 | 83 | x_ref_tmp = x_ref_rnd[sample_idx: sample_idx + 1].repeat((args.val_batch, 1, 1, 1)) 84 | 85 | c_src, skip1, skip2 = G_EMA.cnt_encoder(x_src_rnd) 86 | s_ref = C_EMA(x_ref_tmp, sty=True) 87 | x_rnd_ema_tmp,_ = G_EMA.decode(c_src, s_ref, skip1, skip2) 88 | 89 | x_res_ema_tmp = torch.cat((x_ref[sample_idx: sample_idx + 1], x_res_ema_tmp), 0) 90 | x_res_ema = torch.cat((x_res_ema, x_res_ema_tmp), 0) 91 | 92 | x_rnd_ema_tmp = torch.cat((x_ref_rnd[sample_idx: sample_idx + 1], x_rnd_ema_tmp), 0) 93 | x_rnd_ema = torch.cat((x_rnd_ema, x_rnd_ema_tmp), 0) 94 | 95 | # vutils.save_image(x_res_ema, os.path.join(args.res_dir, '{}_EMA_{}_{}{}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), normalize=True, 96 | # nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 97 | # vutils.save_image(x_rnd_ema, os.path.join(args.res_dir, '{}_RNDEMA_{}_{}{}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), normalize=True, 98 | # nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 99 | 100 | if args.local_rank == 0: 101 | img_res = vutils.make_grid(x_res_ema, normalize=True, nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 102 | img_ema = vutils.make_grid(x_rnd_ema, normalize=True, nrow=(x_res_ema.size(0) // (x_src.size(0) + 2) + 1)) 103 | vutils.save_image(img_res, os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 104 | vutils.save_image(img_ema, os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 105 | if oss_client is not None: 106 | 
oss_client.write_file(os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), \ 107 | os.path.join(args.res_dir_oss, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 108 | oss_client.write_file(os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx)), \ 109 | os.path.join(args.res_dir_oss, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 110 | 111 | os.remove(os.path.join(args.res_dir, '{:03d}_EMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) 112 | os.remove(os.path.join(args.res_dir, '{:03d}_RNDEMA_{:03d}_{:03d}{:03d}.jpg'.format(args.local_rank, epoch+1, src_idx, ref_idx))) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CF-Font 2 | 3 | ------------- 4 | teaser 5 | 6 | By Chi Wang, Min Zhou, Tiezheng Ge, Yuning Jiang, Hujun Bao, Weiwei Xu* 7 | 8 | This repo is the official implementation of `CF-Font: Content Fusion for Few-Shot Font Generation` (CVF: [Paper link](https://openaccess.thecvf.com/content/CVPR2023/html/Wang_CF-Font_Content_Fusion_for_Few-Shot_Font_Generation_CVPR_2023_paper.html); arXiv: [2303.14017](https://arxiv.org/abs/2303.14017)), accepted by [CVPR 2023](https://cvpr2023.thecvf.com/). 9 | 10 | ## Video demos for Style Interpolation 11 | - A poem demo 12 | 13 | 14 | https://user-images.githubusercontent.com/96471617/227514150-b9ea651f-3859-489b-a24b-5623d806aca8.mp4 15 | 16 | 17 | 18 | - Comparison with DG-Font 19 | 20 | 21 | 22 | https://user-images.githubusercontent.com/96471617/227696956-10b663ad-a0bd-4759-847d-a61d08eb97bf.mp4 23 | 24 | 25 | 26 | ## Dependencies 27 | 28 | Library 29 | ------------- 30 | ``` 31 | pytorch (>=1.0) 32 | tqdm 33 | numpy 34 | opencv-python 35 | scipy 36 | sklearn 37 | matplotlib 38 | pillow 39 | tensorboardX 40 | scikit-image 41 | scikit-learn 42 | pytorch-fid 43 | lpips 44 | pandas 45 | kornia 46 | ``` 47 | 48 | DCN 49 | -------------- 50 | 51 | Please refer to https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 to install the dependencies of deformable convolution. 52 | 53 | Dataset 54 | -------------- 55 | [方正字库 (FounderType)](https://www.foundertype.com/index.php/FindFont/index) provides free font downloads for non-commercial users. 56 | 57 | # How to run 58 | 59 | 1. Prepare dataset 60 | - Put your font files in a folder and the character files in `charset`: 61 | ```bash 62 | . 63 | ├── data 64 | │   └── fonts 65 | │      ├── Font_Seen240 66 | │      │ ├── 000_xxxx.ttf 67 | │      │ ├── 001_xxxx.ttf 68 | │      │ ├── ... 69 | │      │ └── 239_xxxx.ttf 70 | │      └── Font_Unseen60 71 | ├── charset 72 | │   ├── check_overlap.py 73 | │   ├── GB2312_CN6763.txt # all characters used in TRAIN and TEST 74 | │   ├── FS16.txt # Few-Shot characters, should not be included in TESTxxx.txt for fairness (can be included in TRAINxxx.txt) 75 | │   ├── TEST5646.txt # all characters used in TEST 76 | │   └── TRAIN800.txt # all characters used in TRAIN 77 | └── ... 78 | ``` 79 | - Generate the dataset with the full standard Chinese character set (6763 characters) of GB/T 2312: 80 | ```bash 81 | sh scripts/01a_gen_date.sh 82 | ``` 83 | After that, your file tree should be: 84 | ```bash 85 | .
86 | ├── data 87 | │   ├── fonts 88 | │   └── imgs 89 | │   ├── Seen240_S80F50_FULL 90 | │   │ ├── id_0 91 | │   │   │   ├── 0000.png 92 | │   │   │   ├── 0001.png 93 | │   │   │   ├── ... 94 | │   │   │   └── 6762.png 95 | │   │ ├── id_1 96 | │   │ ├── ... 97 | │   │ └── id_239 98 | │   └── Unseen60_S80F50_FULL 99 | ├── charset 100 | └── ... 101 | ``` 102 | - Get subsets with the train, test, and few-shot character lists. 103 | ```bash 104 | sh scripts/01b_copy_subset.sh 105 | ``` 106 | After that, your file tree should be: 107 | ```bash 108 | . 109 | ├── data 110 | │   ├── fonts 111 | │   └── imgs 112 | │   ├── Seen240_S80F50_FS16 113 | │   │ ├── id_0 114 | │   │   │   ├── 0000.png 115 | │   │   │   ├── 0001.png 116 | │   │   │   ├── ... 117 | │   │   │   └── 0015.png 118 | │   │ ├── id_1 119 | │   │ ├── ... 120 | │   │ └── id_239 121 | │   ├── Seen240_S80F50_FULL 122 | │   ├── Seen240_S80F50_TEST5646 123 | │   │ ├── id_0 124 | │   │   │   ├── 0000.png 125 | │   │   │   ├── 0001.png 126 | │   │   │   ├── ... 127 | │   │   │   └── 5645.png 128 | │   │ ├── id_1 129 | │   │ ├── ... 130 | │   │ └── id_239 131 | │   ├── Seen240_S80F50_TRAIN800 132 | │   │ ├── id_0 133 | │   │   │   ├── 0000.png 134 | │   │   │   ├── 0001.png 135 | │   │   │   ├── ... 136 | │   │   │   └── 0799.png 137 | │   │ ├── id_1 138 | │   │ ├── ... 139 | │   │ └── id_239 140 | │   ├── Unseen60_S80F50_FS16 141 | │   ├── Unseen60_S80F50_FULL 142 | │   ├── Unseen60_S80F50_TEST5646 143 | │   └── Unseen60_S80F50_TRAIN800 144 | ├── charset 145 | └── ... 146 | ``` 147 | 2. Train base network 148 | ```bash 149 | # enable PC-WDL with the flag `--wdl` and PC-PKL with the flag `--pkl` 150 | sh scripts/02a_run_ddp.sh 151 | ``` 152 | Optional: to evaluate the network during training, you can run inference with the script `scripts/option_run_inf_dgfont.sh`. 153 | 3. Train CF-Font 154 | - Select the basis fonts. The basis set should preferably contain a standard font, such as `song`. 155 | - Manually. To select the basis manually, put the basis ids (on one line, separated by spaces) in a txt file, like: 156 | ``` 157 | 0 1 2 3 4 5 6 7 8 9 158 | ``` 159 | - By clustering. 160 | ```bash 161 | # Collect content embeddings 162 | sh scripts/03a_get_content_embeddings.sh 163 | 164 | # Obtain basis ids through clustering 165 | sh scripts/03b_cluster_get_cf_basis.sh 166 | ``` 167 | - Get subsets with the basis font ids. 168 | ```bash 169 | sh scripts/03c_copy_basis_subset.sh 170 | ``` 171 | - Get the basis weights for content fusion (a conceptual sketch of this step follows this list): 172 | ```bash 173 | sh scripts/03d_cal_cf_weights.sh 174 | ``` 175 | - Train CF-Font 176 | ```bash 177 | # Make a folder for CF-Font training and copy the pretrained model into it. 178 | sh scripts/03e_init_cf_env.sh 179 | 180 | # Train CF-Font 181 | sh scripts/03f_run_ddp_cf.sh 182 | ``` 183 | 4. Inference and evaluation: 184 | ```bash 185 | # Inference (enable SII with the flag `--ft`) 186 | sh scripts/04a_run_inf_cf.sh 187 | 188 | # Evaluation 189 | ## Get scores for each font 190 | sh scripts/04b_get_scores.sh 191 | 192 | ## Get mean scores (use `-j` to skip unwanted fonts, such as the basis fonts) 193 | sh scripts/04c_cal_mean_scores.sh 194 | ```
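To picture what the content-fusion steps above produce, note that each target character's content feature is blended from the same character rendered in the basis fonts. The snippet below is a minimal conceptual sketch only: the tensor shapes, the softmax-over-negative-distances weighting, and all variable names are illustrative assumptions, and the actual weights are computed by `scripts/03d_cal_cf_weights.sh` (`cal_cf_weight.py`).

```python
import torch

# Illustrative sizes (assumptions): M basis fonts, content feature maps of C x H x W.
M, C, H, W = 10, 256, 20, 20

basis_feats = torch.randn(M, C, H, W)  # content features of one character in the M basis fonts
target_emb = torch.randn(128)          # mean content embedding of the target font (placeholder)
basis_embs = torch.randn(M, 128)       # mean content embeddings of the basis fonts (placeholder)

# One plausible weighting: softmax over negative L2 distances,
# so basis fonts closer to the target contribute more.
dists = torch.norm(basis_embs - target_emb, dim=1)           # (M,)
weights = torch.softmax(-dists, dim=0)                       # (M,), sums to 1

# Content fusion: a convex combination of the basis content features.
fused = (weights.view(M, 1, 1, 1) * basis_feats).sum(dim=0)  # (C, H, W)
print(weights.shape, fused.shape)
```

Conceptually, this fused feature stands in for the single source font's content representation that the generator then decodes.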
195 | 196 | ## Citation 197 | 198 | ```bibtex 199 | @InProceedings{Wang_2023_CVPR, 200 | author = {Wang, Chi and Zhou, Min and Ge, Tiezheng and Jiang, Yuning and Bao, Hujun and Xu, Weiwei}, 201 | title = {CF-Font: Content Fusion for Few-Shot Font Generation}, 202 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 203 | month = {June}, 204 | year = {2023}, 205 | pages = {1858-1867} 206 | } 207 | ``` 208 | 209 | ## Acknowledgements 210 | We would like to thank [Alimama](https://www.alimama.com) (Alibaba Group) and [State Key Lab of CAD&CG](http://www.cad.zju.edu.cn) (Zhejiang University) for their support and advice on our project. Our code is based on [DG-Font](https://github.com/ecnuycxie/DG-Font). 211 | 212 | ## Contact 213 | 214 | If you have any questions, please feel free to contact `wangchi1995@zju.edu.cn`. 215 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class ResBlocks(nn.Module): 7 | def __init__(self, num_blocks, dim, norm, act, pad_type, use_sn=False): 8 | super(ResBlocks, self).__init__() 9 | self.model = nn.ModuleList() 10 | for i in range(num_blocks): 11 | self.model.append(ResBlock(dim, norm=norm, act=act, pad_type=pad_type, use_sn=use_sn)) 12 | self.model = nn.Sequential(*self.model) 13 | 14 | def forward(self, x): 15 | return self.model(x) 16 | 17 | 18 | class ResBlock(nn.Module): 19 | def __init__(self, dim, norm='in', act='relu', pad_type='zero', use_sn=False): 20 | super(ResBlock, self).__init__() 21 | self.model = nn.Sequential(Conv2dBlock(dim, dim, 3, 1, 1, 22 | norm=norm, 23 | act=act, 24 | pad_type=pad_type, use_sn=use_sn), 25 | Conv2dBlock(dim, dim, 3, 1, 1, 26 | norm=norm, 27 | act='none', 28 | pad_type=pad_type, use_sn=use_sn)) 29 | 30 | def forward(self, x): 31 | x_org = x 32 | residual = self.model(x) 33 | out = x_org + 0.1 * residual 34 | return out 35 | 36 | 37 | class ActFirstResBlk(nn.Module): 38 | def __init__(self, dim_in, dim_out, downsample=True): 39 | super(ActFirstResBlk, self).__init__() 40 | self.norm1 = FRN(dim_in) 41 | self.norm2 = FRN(dim_in) 42 | self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) 43 | self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) 44 | self.downsample = downsample 45 | self.learned_sc = (dim_in != dim_out) 46 | if self.learned_sc: 47 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 48 | 49 | def _shortcut(self, x): 50 | if self.learned_sc: 51 | x = self.conv1x1(x) 52 | if self.downsample: 53 | x = F.avg_pool2d(x, 2) 54 | return x 55 | 56 | def _residual(self, x): 57 | x = self.norm1(x) 58 | x = self.conv1(x) 59 | if self.downsample: 60 | x = F.avg_pool2d(x, 2) 61 | x = self.norm2(x) 62 | x = self.conv2(x) 63 | return x 64 | 65 | def forward(self, x): 66 | return torch.rsqrt(torch.tensor(2.0)) * self._shortcut(x) + torch.rsqrt(torch.tensor(2.0)) * self._residual(x) 67 | 68 | 69 | class LinearBlock(nn.Module): 70 | def __init__(self, in_dim, out_dim, norm='none', act='relu', use_sn=False): 71
| super(LinearBlock, self).__init__() 72 | use_bias = True 73 | self.fc = nn.Linear(in_dim, out_dim, bias=use_bias) 74 | if use_sn: 75 | self.fc = nn.utils.spectral_norm(self.fc) 76 | 77 | # initialize normalization 78 | norm_dim = out_dim 79 | if norm == 'bn': 80 | self.norm = nn.BatchNorm1d(norm_dim) 81 | elif norm == 'in': 82 | self.norm = nn.InstanceNorm1d(norm_dim) 83 | elif norm == 'none': 84 | self.norm = None 85 | else: 86 | assert 0, "Unsupported normalization: {}".format(norm) 87 | 88 | # initialize activation 89 | if act == 'relu': 90 | self.activation = nn.ReLU(inplace=True) 91 | elif act == 'lrelu': 92 | self.activation = nn.LeakyReLU(0.2, inplace=True) 93 | elif act == 'tanh': 94 | self.activation = nn.Tanh() 95 | elif act == 'none': 96 | self.activation = None 97 | else: 98 | assert 0, "Unsupported activation: {}".format(act) 99 | 100 | def forward(self, x): 101 | out = self.fc(x) 102 | if self.norm: 103 | out = self.norm(out) 104 | if self.activation: 105 | out = self.activation(out) 106 | return out 107 | 108 | 109 | class Conv2dBlock(nn.Module): 110 | def __init__(self, in_dim, out_dim, ks, st, padding=0, 111 | norm='none', act='relu', pad_type='zero', 112 | use_bias=True, use_sn=False): 113 | super(Conv2dBlock, self).__init__() 114 | self.use_bias = use_bias 115 | 116 | # initialize padding 117 | if pad_type == 'reflect': 118 | self.pad = nn.ReflectionPad2d(padding) 119 | elif pad_type == 'replicate': 120 | self.pad = nn.ReplicationPad2d(padding) 121 | elif pad_type == 'zero': 122 | self.pad = nn.ZeroPad2d(padding) 123 | else: 124 | assert 0, "Unsupported padding type: {}".format(pad_type) 125 | 126 | # initialize normalization 127 | norm_dim = out_dim 128 | if norm == 'bn': 129 | self.norm = nn.BatchNorm2d(norm_dim) 130 | elif norm == 'in': 131 | self.norm = nn.InstanceNorm2d(norm_dim) 132 | elif norm == 'adain': 133 | self.norm = AdaIN2d(norm_dim) 134 | elif norm == 'none': 135 | self.norm = None 136 | else: 137 | assert 0, "Unsupported normalization: {}".format(norm) 138 | 139 | # initialize activation 140 | if act == 'relu': 141 | self.activation = nn.ReLU(inplace=True) 142 | elif act == 'lrelu': 143 | self.activation = nn.LeakyReLU(0.2, inplace=True) 144 | elif act == 'tanh': 145 | self.activation = nn.Tanh() 146 | elif act == 'none': 147 | self.activation = None 148 | else: 149 | assert 0, "Unsupported activation: {}".format(act) 150 | 151 | self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias) 152 | if use_sn: 153 | self.conv = nn.utils.spectral_norm(self.conv) 154 | 155 | def forward(self, x): 156 | x = self.conv(self.pad(x)) 157 | if self.norm: 158 | x = self.norm(x) 159 | if self.activation: 160 | x = self.activation(x) 161 | return x 162 | 163 | 164 | class FRN(nn.Module): 165 | def __init__(self, num_features, eps=1e-6): 166 | super(FRN, self).__init__() 167 | self.tau = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 168 | self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1)) 169 | self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 170 | self.eps = eps 171 | 172 | def forward(self, x): 173 | x = x * torch.rsqrt(torch.mean(x**2, dim=[2, 3], keepdim=True) + self.eps) 174 | return torch.max(self.gamma * x + self.beta, self.tau) 175 | 176 | 177 | class AdaIN2d(nn.Module): 178 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False, track_running_stats=True): 179 | super(AdaIN2d, self).__init__() 180 | self.num_features = num_features 181 | self.eps = eps 182 | self.momentum = momentum 183 | self.affine = affine 184 
| self.track_running_stats = track_running_stats 185 | 186 | if self.affine: 187 | self.weight = nn.Parameter(torch.Tensor(num_features)) 188 | self.bias = nn.Parameter(torch.Tensor(num_features)) 189 | else: 190 | self.weight = None 191 | self.bias = None 192 | 193 | if self.track_running_stats: 194 | self.register_buffer('running_mean', torch.zeros(num_features)) 195 | self.register_buffer('running_var', torch.ones(num_features)) 196 | else: 197 | self.register_buffer('running_mean', None) 198 | self.register_buffer('running_var', None) 199 | 200 | def forward(self, x): 201 | assert self.weight is not None and self.bias is not None, "AdaIN params are None" 202 | N, C, H, W = x.size() 203 | running_mean = self.running_mean.repeat(N) 204 | running_var = self.running_var.repeat(N) 205 | x_ = x.contiguous().view(1, N * C, H * W) 206 | normed = F.batch_norm(x_, running_mean, running_var, 207 | self.weight, self.bias, 208 | True, self.momentum, self.eps) 209 | return normed.view(N, C, H, W) 210 | 211 | def __repr__(self): 212 | return self.__class__.__name__ + '(num_features=' + str(self.num_features) + ')' 213 | 214 | 215 | if __name__ == '__main__': 216 | print("CALL blocks.py") 217 | -------------------------------------------------------------------------------- /datasets/datasetgetter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | import os 4 | import random 5 | import torchvision.transforms as transforms 6 | from datasets.custom_dataset import ImageFolerRemap, CrossdomainFolder, \ 7 | ImageFolerRemapPairCF, ImageFolerRemapUnpairCF, ImageFolerRemapPairbasis, \ 8 | ImageFolerRemapPair, TwoDataset 9 | 10 | class Compose(object): 11 | def __init__(self, tf): 12 | self.tf = tf 13 | 14 | def __call__(self, img): 15 | for t in self.tf: 16 | img = t(img) 17 | return img 18 | 19 | 20 | def get_dataset(args, data_dir=None, class_to_use=None, with_path=False): 21 | 22 | mean = [0.5, 0.5, 0.5] 23 | std = [0.5, 0.5, 0.5] 24 | 25 | normalize = transforms.Normalize(mean=mean, std=std) 26 | transform = Compose([transforms.Resize((args.img_size, args.img_size)), 27 | transforms.ToTensor(), 28 | normalize]) 29 | transform_val = Compose([transforms.Resize((args.img_size, args.img_size)), 30 | transforms.ToTensor(), 31 | normalize]) 32 | 33 | class_to_use = class_to_use or args.att_to_use 34 | remap_table = {k: i for i, k in enumerate(class_to_use)} 35 | if args.local_rank == 0: 36 | print(f'USE CLASSES: {class_to_use}\nLABEL MAP: {remap_table}') 37 | 38 | img_dir = data_dir or args.data_dir 39 | 40 | dataset = ImageFolerRemap(img_dir, transform=transform, remap_table=remap_table, with_path=with_path) 41 | valdataset = ImageFolerRemap(img_dir, transform=transform_val, remap_table=remap_table, with_path=with_path) 42 | # parse classes to use 43 | tot_targets = torch.tensor(dataset.targets) 44 | 45 | if True: # my implementation 46 | train_dataset = {'TRAIN': dataset, 'FULL': dataset} 47 | subset_indices = random.sample(range(len(valdataset)), args.val_num) 48 | val_dataset = torch.utils.data.Subset(valdataset, subset_indices) 49 | else: # DG-Font-style implementation 50 | min_data, max_data = 99999999, 0 51 | train_idx, val_idx = None, None 52 | for k in class_to_use: 53 | tmp_idx = (tot_targets == k).nonzero(as_tuple=False) 54 | train_tmp_idx = tmp_idx[:-args.val_num] 55 | val_tmp_idx = tmp_idx[-args.val_num:] 56 | 57 | if k == class_to_use[0]: 58 | train_idx = train_tmp_idx.clone() 59 | val_idx = 
val_tmp_idx.clone() 60 | else: 61 | train_idx = torch.cat((train_idx, train_tmp_idx)) 62 | val_idx = torch.cat((val_idx, val_tmp_idx)) 63 | 64 | min_data = min(min_data, len(train_tmp_idx)) 65 | max_data = max(max_data, len(train_tmp_idx)) 66 | 67 | train_dataset = torch.utils.data.Subset(dataset, train_idx) 68 | val_dataset = torch.utils.data.Subset(valdataset, val_idx) 69 | train_dataset = {'TRAIN': train_dataset, 'FULL': dataset} 70 | 71 | args.min_data, args.max_data = min_data, max_data 72 | if args.local_rank == 0: 73 | print(f"MIN/MAX DATA: {args.min_data}/{args.max_data}") 74 | 75 | return train_dataset, val_dataset 76 | 77 | def get_full_dataset_ft(args, data_dir=None, class_to_use=None, with_path=False, 78 | ft_ignore_target=-1): 79 | 80 | mean = [0.5, 0.5, 0.5] 81 | std = [0.5, 0.5, 0.5] 82 | 83 | normalize = transforms.Normalize(mean=mean, std=std) 84 | transform = Compose([transforms.Resize((args.img_size, args.img_size)), 85 | transforms.ToTensor(), 86 | normalize]) 87 | class_to_use = class_to_use or args.att_to_use 88 | remap_table = {k: i for i, k in enumerate(class_to_use)} 89 | if args.local_rank == 0: 90 | print(f'USE CLASSES: {class_to_use}\nLABEL MAP: {remap_table}') 91 | 92 | img_dir = data_dir or args.data_dir 93 | dataset = ImageFolerRemap(img_dir, transform=transform, remap_table=remap_table, with_path=with_path) 94 | dataser_ft = ImageFolerRemapPair(img_dir, transform=transform, remap_table=remap_table, ignore_target=ft_ignore_target) 95 | # parse classes to use 96 | # tot_targets = torch.tensor(dataset.targets) 97 | 98 | # min_data, max_data = 99999999, 0 99 | # train_idx, val_idx = None, None 100 | # for k in class_to_use: 101 | # tmp_idx = (tot_targets == k).nonzero(as_tuple=False) 102 | # min_data = min(min_data, len(tmp_idx)) 103 | # max_data = max(max_data, len(tmp_idx)) 104 | full_dataset = {'FULL': dataset, 'FT': dataser_ft} 105 | # args.min_data, args.max_data = min_data, max_data 106 | # if args.local_rank == 0: 107 | # print(f"MIN/MAX DATA: {args.min_data}/{args.max_data}") 108 | return full_dataset 109 | 110 | def get_full_dataset_cfft(args, data_dir=None, base_dir=None, base_ft_dir=None, class_to_use=None, with_path=False, 111 | ft_ignore_target=-1, class_to_use_base=None): 112 | 113 | mean = [0.5, 0.5, 0.5] 114 | std = [0.5, 0.5, 0.5] 115 | 116 | normalize = transforms.Normalize(mean=mean, std=std) 117 | transform = Compose([transforms.Resize((args.img_size, args.img_size)), 118 | transforms.ToTensor(), 119 | normalize]) 120 | class_to_use = class_to_use or args.att_to_use 121 | class_to_use_base = class_to_use_base or args.att_to_use_base 122 | remap_table = {k: i for i, k in enumerate(class_to_use)} 123 | remap_table_base = {k: i for i, k in enumerate(class_to_use_base)} 124 | if args.local_rank == 0: 125 | print(f'USE CLASSES: {class_to_use}\nLABEL MAP: {remap_table}\nBASE LABEL MAP: {remap_table_base}') 126 | 127 | img_dir = data_dir or args.data_dir 128 | img_base_dir = base_dir or args.base_dir 129 | img_base_ft_dir = base_ft_dir or args.base_ft_dir 130 | dataset = ImageFolerRemap(img_dir, transform=transform, remap_table=remap_table, with_path=with_path) 131 | dataset_base = ImageFolerRemapPair(img_base_dir, transform=transform, remap_table=remap_table_base, with_path=with_path) 132 | dataser_full_ft = ImageFolerRemapPair(img_dir, transform=transform, remap_table=remap_table, ignore_target=ft_ignore_target) 133 | dataser_base_ft = ImageFolerRemapPair(img_base_ft_dir, transform=transform, remap_table=remap_table_base, 
ignore_target=ft_ignore_target) 134 | dataser_ft = TwoDataset(dataser_full_ft, dataser_base_ft) 135 | # parse classes to use 136 | tot_targets = torch.tensor(dataset.targets) 137 | 138 | # min_data, max_data = 99999999, 0 139 | # train_idx, val_idx = None, None 140 | # for k in class_to_use: 141 | # tmp_idx = (tot_targets == k).nonzero(as_tuple=False) 142 | # min_data = min(min_data, len(tmp_idx)) 143 | # max_data = max(max_data, len(tmp_idx)) 144 | full_dataset = {'FULL': dataset, 'BASE': dataset_base, 'FT': dataser_ft} 145 | # args.min_data, args.max_data = min_data, max_data 146 | # if args.local_rank == 0: 147 | # print(f"MIN/MAX DATA: {args.min_data}/{args.max_data}") 148 | return full_dataset 149 | 150 | def get_full_dataset(args): 151 | mean = [0.5, 0.5, 0.5] 152 | std = [0.5, 0.5, 0.5] 153 | normalize = transforms.Normalize(mean=mean, std=std) 154 | transform = Compose([transforms.Resize((args.img_size, args.img_size)), 155 | transforms.ToTensor(), 156 | normalize]) 157 | dataset = ImageFolerRemap(args.data_dir, transform=transform, remap_table={k: i for i, k in enumerate(args.att_to_use)}) # remap_table was an undefined name here; build it from args.att_to_use as the other loaders do 158 | return dataset 159 | 160 | 161 | def get_cf_dataset(args): 162 | 163 | mean = [0.5, 0.5, 0.5] 164 | std = [0.5, 0.5, 0.5] 165 | 166 | normalize = transforms.Normalize(mean=mean, std=std) 167 | 168 | transform = Compose([transforms.Resize((args.img_size, args.img_size)), 169 | transforms.ToTensor(), 170 | normalize]) 171 | 172 | class_to_use = args.att_to_use 173 | 174 | if args.local_rank == 0: 175 | print('USE CLASSES', class_to_use) 176 | 177 | # remap labels 178 | remap_table = {} 179 | i = 0 180 | for k in class_to_use: 181 | remap_table[k] = i 182 | i += 1 183 | 184 | if args.local_rank == 0: 185 | print("LABEL MAP:", remap_table) 186 | 187 | img_dir = args.data_dir 188 | 189 | # cf_dataset = ImageFolerRemapPairCF(img_dir, base_idxs=args.base_idxs, base_ws=args.base_ws,transform=transform, remap_table=remap_table, \ 190 | # sample_skip_base=True, sample_N=args.sample_N) 191 | cf_dataset = ImageFolerRemapUnpairCF(img_dir, base_ws=args.base_ws, transform=transform, remap_table=remap_table, top_n=args.base_top_n) 192 | cf_basis_dataset = ImageFolerRemapPairbasis(img_dir, base_idxs=args.base_idxs, base_ws=args.base_ws,transform=transform, remap_table=remap_table) # keep bs 1 193 | 194 | return cf_dataset, cf_basis_dataset -------------------------------------------------------------------------------- /tools/abl_allinone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from scipy.ndimage import distance_transform_edt as distance 7 | # available here: https://github.com/CoinCheung/pytorch-loss/blob/af876e43218694dc8599cc4711d9a5c5e043b1b2/label_smooth.py 8 | from .label_smooth import LabelSmoothSoftmaxCEV1 as LSSCE 9 | from torchvision import transforms 10 | from functools import partial 11 | from operator import itemgetter 12 | 13 | # Tools 14 | def kl_div(a,b): # q,p 15 | return F.softmax(b, dim=1) * (F.log_softmax(b, dim=1) - F.log_softmax(a, dim=1)) 16 | 17 | def one_hot2dist(seg): 18 | res = np.zeros_like(seg) 19 | for i in range(len(seg)): 20 | posmask = seg[i].astype(bool) # np.bool was removed in modern NumPy; use the builtin 21 | if posmask.any(): 22 | negmask = ~posmask 23 | res[i] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask 24 | return res 25 | 26 | def class2one_hot(seg, C): 27 | seg = seg.unsqueeze(dim=0) if len(seg.shape) == 2 else seg 28 | res = torch.stack([seg == c for c in range(C)], 
dim=1).type(torch.int32) 29 | return res 30 | 31 | # Active Boundary Loss 32 | class ABL(nn.Module): 33 | def __init__(self, use_detach=True, max_N_ratio = 1/100, ignore_label = 255, label_smoothing=0.2, max_clip_dist = 20.): 34 | super(ABL, self).__init__() 35 | self.ignore_label = ignore_label 36 | self.label_smoothing = label_smoothing 37 | self.use_detach=use_detach 38 | self.max_N_ratio = max_N_ratio 39 | 40 | self.weight_func = lambda w, max_distance=max_clip_dist: torch.clamp(w, max=max_distance) / max_distance 41 | 42 | # The code for distance map generation is partially referenced from https://github.com/LIVIAETS/boundary-loss 43 | self.dist_map_transform = transforms.Compose([ 44 | lambda img: img.unsqueeze(0), 45 | lambda nd: nd.type(torch.int64), 46 | partial(class2one_hot, C=1), 47 | itemgetter(0), 48 | lambda t: t.cpu().numpy(), 49 | one_hot2dist, 50 | lambda nd: torch.tensor(nd, dtype=torch.float32) 51 | ]) 52 | 53 | if label_smoothing == 0: 54 | self.criterion = nn.CrossEntropyLoss( 55 | weight=None, 56 | ignore_index=ignore_label, 57 | reduction='none' 58 | ) 59 | else: 60 | self.criterion = LSSCE( 61 | reduction='none', 62 | ignore_index=ignore_label, 63 | lb_smooth = label_smoothing 64 | ) 65 | 66 | def logits2boundary(self, logit): 67 | eps = 1e-5 68 | _, _, h, w = logit.shape 69 | max_N = (h*w) * self.max_N_ratio 70 | kl_lr = kl_div(logit[:, :, 1:, :], logit[:, :, :-1, :]).sum(1, keepdim=True) 71 | kl_ud = kl_div(logit[:, :, :, 1:], logit[:, :, :, :-1]).sum(1, keepdim=True) 72 | kl_lr = torch.nn.functional.pad( 73 | kl_lr, [0, 0, 0, 1, 0, 0, 0, 0], mode='constant', value=0) 74 | kl_ud = torch.nn.functional.pad( 75 | kl_ud, [0, 1, 0, 0, 0, 0, 0, 0], mode='constant', value=0) 76 | kl_combine = kl_lr+kl_ud 77 | while True: # avoid the case that full image is the same color 78 | kl_combine_bin = (kl_combine > eps).to(torch.float) 79 | if kl_combine_bin.sum() > max_N: 80 | eps *=1.2 81 | else: 82 | break 83 | #dilate 84 | dilate_weight = torch.ones((1,1,3,3)).cuda() 85 | edge2 = torch.nn.functional.conv2d(kl_combine_bin, dilate_weight, stride=1, padding=1) 86 | edge2 = edge2.squeeze(1) # NCHW->NHW 87 | kl_combine_bin = (edge2 > 0) 88 | return kl_combine_bin 89 | 90 | def gt2boundary(self, gt, ignore_label=-1): # gt NHW 91 | gt_lr = gt[:,1:,:]-gt[:,:-1,:] # NHW 92 | gt_ud = gt[:,:,1:]-gt[:,:,:-1] 93 | gt_lr = torch.nn.functional.pad(gt_lr, [0,0,0,1,0,0], mode='constant', value=0) != 0 94 | gt_ud = torch.nn.functional.pad(gt_ud, [0,1,0,0,0,0], mode='constant', value=0) != 0 95 | gt_combine = gt_lr+gt_ud 96 | del gt_lr 97 | del gt_ud 98 | 99 | # set 'ignore area' to all boundary 100 | gt_combine += (gt==ignore_label) 101 | 102 | return gt_combine > 0 103 | 104 | def get_direction_gt_predkl(self, pred_dist_map, pred_bound, logits): 105 | # NHW,NHW,NCHW 106 | eps = 1e-5 107 | # bound = torch.where(pred_bound) # 3k 108 | bound = torch.nonzero(pred_bound*1) 109 | n,x,y = bound.T 110 | max_dis = 1e5 111 | 112 | logits = logits.permute(0,2,3,1) # NHWC 113 | 114 | pred_dist_map_d = torch.nn.functional.pad(pred_dist_map,(1,1,1,1,0,0),mode='constant', value=max_dis) # NH+2W+2 115 | 116 | logits_d = torch.nn.functional.pad(logits,(0,0,1,1,1,1,0,0),mode='constant') # N(H+2)(W+2)C 117 | logits_d[:,0,:,:] = logits_d[:,1,:,:] # N(H+2)(W+2)C 118 | logits_d[:,-1,:,:] = logits_d[:,-2,:,:] # N(H+2)(W+2)C 119 | logits_d[:,:,0,:] = logits_d[:,:,1,:] # N(H+2)(W+2)C 120 | logits_d[:,:,-1,:] = logits_d[:,:,-2,:] # N(H+2)(W+2)C 121 | 122 | """ 123 | | 4| 0| 5| 124 | | 2| 8| 3| 125 | | 6| 1| 7| 126 | 
""" 127 | x_range = [1, -1, 0, 0, -1, 1, -1, 1, 0] 128 | y_range = [0, 0, -1, 1, 1, 1, -1, -1, 0] 129 | dist_maps = torch.zeros((0,len(x))).cuda() # 8k 130 | kl_maps = torch.zeros((0,len(x))).cuda() # 8k 131 | 132 | kl_center = logits[(n,x,y)] # KC 133 | 134 | for dx, dy in zip(x_range, y_range): 135 | dist_now = pred_dist_map_d[(n,x+dx+1,y+dy+1)] 136 | dist_maps = torch.cat((dist_maps,dist_now.unsqueeze(0)),0) 137 | 138 | if dx != 0 or dy != 0: 139 | logits_now = logits_d[(n,x+dx+1,y+dy+1)] 140 | # kl_map_now = torch.kl_div((kl_center+eps).log(), logits_now+eps).sum(2) # 8KC->8K 141 | if self.use_detach: 142 | logits_now = logits_now.detach() 143 | kl_map_now = kl_div(kl_center, logits_now) 144 | 145 | kl_map_now = kl_map_now.sum(1) # KC->K 146 | kl_maps = torch.cat((kl_maps,kl_map_now.unsqueeze(0)),0) 147 | torch.clamp(kl_maps, min=0.0, max=20.0) 148 | 149 | # direction_gt shound be Nk (8k->K) 150 | direction_gt = torch.argmin(dist_maps, dim=0) 151 | # weight_ce = pred_dist_map[bound] 152 | weight_ce = pred_dist_map[(n,x,y)] 153 | # print(weight_ce) 154 | 155 | # delete if min is 8 (local position) 156 | direction_gt_idx = [direction_gt!=8] 157 | direction_gt = direction_gt[direction_gt_idx] 158 | 159 | 160 | kl_maps = torch.transpose(kl_maps,0,1) 161 | direction_pred = kl_maps[direction_gt_idx] 162 | weight_ce = weight_ce[direction_gt_idx] 163 | 164 | return direction_gt, direction_pred, weight_ce 165 | 166 | def get_dist_maps(self, target): 167 | target_detach = target.clone().detach() 168 | dist_maps = torch.cat([self.dist_map_transform(target_detach[i]) for i in range(target_detach.shape[0])]) 169 | out = -dist_maps 170 | out = torch.where(out>0, out, torch.zeros_like(out)) 171 | 172 | return out 173 | 174 | def forward(self, logits, target): 175 | eps = 1e-10 176 | ph, pw = logits.size(2), logits.size(3) 177 | h, w = target.size(1), target.size(2) 178 | 179 | if ph != h or pw != w: 180 | logits = F.interpolate(input=logits, size=( 181 | h, w), mode='bilinear', align_corners=True) 182 | 183 | gt_boundary = self.gt2boundary(target, ignore_label=self.ignore_label) 184 | 185 | dist_maps = self.get_dist_maps(gt_boundary).cuda() # <-- !!! it will slow down the training, you can move the code of distance map generation to dataloader. 186 | 187 | pred_boundary = self.logits2boundary(logits) 188 | if pred_boundary.sum() < 1: # avoid nan 189 | return None # you should check in the outside. if None, skip this loss. 
-------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch.nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.optim 6 | import torch.utils.data 7 | import torch.utils.data.distributed 8 | from tools.utils import * 9 | from tools.ops import compute_grad_gp, update_average, copy_norm_params, queue_data, dequeue_data, \ 10 | average_gradients, calc_adv_loss, calc_contrastive_loss, calc_recon_loss, \ 11 | calc_abl, calc_pseudo_hamming_loss, calc_wdl, calc_pkl 12 | from tools.hsic import RbfHSIC 13 | 14 | 15 | def add_indp_fact_loss(*exp_pairs): # module-level helper; a stray `self` parameter used to swallow the first pair passed at the call site 16 | pairs = [] 17 | for _exp1, _exp2 in exp_pairs: 18 | _pairs = [(F.adaptive_avg_pool2d(_exp1[:, i], 1).squeeze(), 19 | F.adaptive_avg_pool2d(_exp2[:, i], 1).squeeze()) 20 | for i in range(_exp1.shape[1])] 21 | pairs += _pairs 22 | 23 | crit = RbfHSIC(1) 24 | losses = [crit(*pair) for pair in pairs] 25 | return torch.stack(losses) # stack into a tensor so the caller can take .mean() 26 | 27 | def trainGAN(data_loader, networks, opts, epoch, args, additional, \ 28 | detach=False, quantize=False, style_con=False, \ 29 | reconstruction_losses=False, hsic_loss=False, abl=False,\ 30 | phl=False, wdl=False): 31 | # avg meter 32 | d_losses = AverageMeter() 33 | d_advs = AverageMeter() 34 | d_gps = AverageMeter() 35 | 36 | g_losses = AverageMeter() 37 | g_advs = AverageMeter() 38 | g_imgrecs = AverageMeter() 39 | g_rec = AverageMeter() 40 | 41 | moco_losses = AverageMeter() 42 | 43 | # set nets 44 | D = networks['D'] if not args.distributed else networks['D'].module 45 | G = networks['G'] if not args.distributed else networks['G'].module 46 | C = networks['C'] if not args.distributed else networks['C'].module 47 | G_EMA = networks['G_EMA'] if not args.distributed else networks['G_EMA'].module 48 | C_EMA = networks['C_EMA'] if not args.distributed else networks['C_EMA'].module 49 | # set opts 50 | d_opt = opts['D'] 51 | g_opt = opts['G'] 52 | c_opt = opts['C'] 53 | # switch to train mode 54 | D.train() 55 | G.train() 56 | C.train() 57 | C_EMA.train() 58 | G_EMA.train() 59 | 60 | logger = additional['logger'] 61 | 62 | 63 | # summary writer 64 | train_it = iter(data_loader) 65 | 66 | t_train = trange(0, args.iters, initial=0, total=args.iters) 67 | 68 | for i in t_train: 69 | try: 70 | imgs, y_org = next(train_it) 71 | except StopIteration: 72 | train_it = iter(data_loader) # restart the exhausted iterator so the loop can run indefinitely 73 | print('TrainsetLen', len(data_loader)) 74 | imgs, y_org = 
next(train_it) # images, class_idxs 75 | 76 | x_org = imgs 77 | x_ref_idx = torch.randperm(x_org.size(0)) # shuffle 78 | 79 | # x_org = x_org.cuda(args.gpu) 80 | 81 | # y_org = y_org.cuda(args.gpu) 82 | # x_ref_idx = x_ref_idx.cuda(args.gpu) 83 | 84 | x_org = x_org.to(torch.cuda.current_device()) 85 | y_org = y_org.to(torch.cuda.current_device()) 86 | x_ref_idx = x_ref_idx.to(torch.cuda.current_device()) 87 | 88 | x_ref = x_org.clone() 89 | x_ref = x_ref[x_ref_idx] 90 | 91 | training_mode = 'GAN' 92 | 93 | #################### 94 | # BEGIN Train GANs # 95 | #################### 96 | with torch.no_grad(): 97 | y_ref = y_org.clone() 98 | y_ref = y_ref[x_ref_idx] 99 | s_ref = C.moco(x_ref) 100 | c_src, skip1, skip2 = G.cnt_encoder(x_org) 101 | x_fake, _ = G.decode(c_src, s_ref, skip1, skip2) 102 | 103 | x_ref.requires_grad_() 104 | 105 | d_real_logit, _ = D(x_ref, y_ref) 106 | d_fake_logit, _ = D(x_fake.detach(), y_ref) 107 | 108 | d_adv_real = calc_adv_loss(d_real_logit, 'd_real') 109 | d_adv_fake = calc_adv_loss(d_fake_logit, 'd_fake') 110 | 111 | d_adv = d_adv_real + d_adv_fake 112 | 113 | d_gp = args.w_gp * compute_grad_gp(d_real_logit, x_ref, is_patch=False) 114 | 115 | d_loss = d_adv + d_gp 116 | 117 | d_opt.zero_grad() 118 | d_adv_real.backward(retain_graph=True) 119 | d_gp.backward() 120 | d_adv_fake.backward() 121 | if args.distributed: 122 | average_gradients(D) 123 | d_opt.step() 124 | 125 | # Train G 126 | s_src = C.moco(x_org) 127 | s_ref = C.moco(x_ref) 128 | 129 | c_src, skip1, skip2 = G.cnt_encoder(x_org) 130 | x_fake, offset_loss = G.decode(c_src, s_ref, skip1, skip2) 131 | x_rec, _ = G.decode(c_src, s_src, skip1, skip2) 132 | 133 | g_fake_logit, _ = D(x_fake, y_ref) 134 | g_rec_logit, _ = D(x_rec, y_org) 135 | 136 | g_adv_fake = calc_adv_loss(g_fake_logit, 'g') 137 | g_adv_rec = calc_adv_loss(g_rec_logit, 'g') 138 | 139 | g_adv = g_adv_fake + g_adv_rec 140 | 141 | g_imgrec = calc_recon_loss(x_rec, x_org) if not args.no_l1 else torch.zeros_like(g_adv) 142 | 143 | if phl: 144 | g_imgrec += calc_pseudo_hamming_loss(x_rec, x_org, thres=0) # -1 ~ 1 145 | 146 | if wdl: 147 | g_imgrec += calc_wdl(x_rec, x_org) * args.w_wdl # -1 ~ 1 148 | 149 | if args.pkl: 150 | g_imgrec += calc_pkl(x_rec, x_org) * args.w_pkl # -1 ~ 1 151 | 152 | # TODO Maybe add detach and binary clip? 
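        # The quantize/detach switches below control how gradients flow through the
        # generated image before it is re-encoded for the content-consistency term:
        # quantize snaps x_fake to 8-bit intensity levels, detach blocks gradients entirely.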
153 |         if quantize:
154 |             x_fake = (x_fake * 255).round() / 255  # snap to 8-bit intensity levels
155 | 
156 |         if detach:
157 |             x_fake = x_fake.detach()
158 | 
159 |         # c_x_fake, _, _ = G.cnt_encoder(x_fake)
160 |         c_x_fake, skip1_x_fake, skip2_x_fake = G.cnt_encoder(x_fake)
161 |         g_conrec = calc_recon_loss(c_x_fake, c_src)
162 | 
163 |         g_loss = args.w_adv * g_adv + args.w_rec * g_imgrec + args.w_rec * g_conrec + args.w_off * offset_loss
164 | 
165 |         # abl
166 |         if abl:
167 |             g_img_abl = calc_abl(x_rec, x_org)
168 |             # print(f"abl:{g_img_abl} g_adv:{g_adv} g_imgrec:{g_imgrec} g_conrec:{g_conrec} offset_loss:{offset_loss}")
169 |             if g_img_abl is not None:  # calc_abl returns None when no boundary pixel is predicted; skip the term in that case
170 |                 g_loss += args.w_rec * g_img_abl
171 | 
172 |         if style_con:
173 |             s_x_fake = C.moco(x_fake)
174 |             g_stylerec = calc_recon_loss(s_ref, s_x_fake)
175 |             g_loss = g_loss + args.w_rec * g_stylerec
176 | 
177 |         if reconstruction_losses:
178 |             if not style_con:
179 |                 s_x_fake = C.moco(x_fake)
180 |             c_ref, skip1_ref, skip2_ref = G.cnt_encoder(x_ref)
181 | 
182 |             # NOTE: this cross-reconstruction pairing still needs re-checking
183 |             x_org_fake, _ = G.decode(c_x_fake, s_src.detach(), skip1_x_fake, skip2_x_fake)
184 |             x_ref_fake, _ = G.decode(c_ref.detach(), s_x_fake, skip1_ref.detach(), skip2_ref.detach())
185 | 
186 |             # calc_recon_loss
187 |             g_imgrec_org = calc_recon_loss(x_org_fake, x_org)
188 |             g_imgrec_ref = calc_recon_loss(x_ref_fake, x_ref)
189 |             g_loss = g_loss + args.w_rec * (g_imgrec_org + g_imgrec_ref)
190 | 
191 |         if hsic_loss:  # NOTE: c_ref is only defined when reconstruction_losses is also enabled
192 |             hsic_losses = add_indp_fact_loss(
193 |                 [s_src, c_src],
194 |                 [s_ref, c_ref],
195 |             )
196 |             g_loss = g_loss + args.w_hsic * torch.stack(hsic_losses).mean()  # add_indp_fact_loss returns a list of scalars
197 | 
198 |         g_opt.zero_grad()
199 |         c_opt.zero_grad()
200 |         g_loss.backward()
201 |         if args.distributed:
202 |             average_gradients(G)
203 |             average_gradients(C)
204 |         c_opt.step()
205 |         g_opt.step()
206 | 
207 |         ##################
208 |         # END Train GANs #
209 |         ##################
210 | 
211 | 
212 |         if epoch >= args.ema_start:
213 |             training_mode = training_mode + "_EMA"
214 |             update_average(G_EMA, G)
215 |             update_average(C_EMA, C)
216 | 
217 |         torch.cuda.synchronize()
218 | 
219 |         with torch.no_grad():
220 |             if epoch >= args.separated:
221 |                 d_losses.update(d_loss.item(), x_org.size(0))
222 |                 d_advs.update(d_adv.item(), x_org.size(0))
223 |                 d_gps.update(d_gp.item(), x_org.size(0))
224 | 
225 |                 g_losses.update(g_loss.item(), x_org.size(0))
226 |                 g_advs.update(g_adv.item(), x_org.size(0))
227 |                 g_imgrecs.update(g_imgrec.item(), x_org.size(0))
228 |                 g_rec.update(g_conrec.item(), x_org.size(0))
229 | 
230 |                 moco_losses.update(offset_loss.item(), x_org.size(0))
231 | 
232 |             if (i + 1) % args.log_step == 0 and (args.gpu == 0 or args.gpu == '0') and logger is not None and args.local_rank == 0:
233 |                 summary_step = epoch * args.iters + i
234 |                 add_logs(args, logger, 'D/LOSS', d_losses.avg, summary_step)
235 |                 add_logs(args, logger, 'D/ADV', d_advs.avg, summary_step)
236 |                 add_logs(args, logger, 'D/GP', d_gps.avg, summary_step)
237 | 
238 |                 add_logs(args, logger, 'G/LOSS', g_losses.avg, summary_step)
239 |                 add_logs(args, logger, 'G/ADV', g_advs.avg, summary_step)
240 |                 add_logs(args, logger, 'G/IMGREC', g_imgrecs.avg, summary_step)
241 |                 add_logs(args, logger, 'G/conrec', g_rec.avg, summary_step)
242 | 
243 |                 add_logs(args, logger, 'C/OFFSET', moco_losses.avg, summary_step)
244 | 
245 |                 print('Epoch: [{}/{}] [{}/{}] MODE[{}] Avg Loss: D[{d_losses.avg:.2f}] G[{g_losses.avg:.2f}] '.format(epoch + 1, args.epochs, i + 1, args.iters,
246 |                       training_mode, d_losses=d_losses, g_losses=g_losses))
247 | 
248 |     copy_norm_params(G_EMA, G)
249 |     copy_norm_params(C_EMA, C)
250 | 
251 | 
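For orientation, the generator objective assembled in trainGAN above reduces to the following schematic. This is a reading aid only; the names mirror the variables in this file, and the bracketed terms are added only when the corresponding flag is enabled:

# g_loss = w_adv * (g_adv_fake + g_adv_rec)            adversarial terms
#        + w_rec * g_imgrec                             image reconstruction (PHL/WDL/PKL folded in when enabled)
#        + w_rec * g_conrec                             content-feature consistency
#        + w_off * offset_loss                          DCN offset regularization
#        [+ w_rec * g_img_abl]                          if abl
#        [+ w_rec * g_stylerec]                         if style_con
#        [+ w_rec * (g_imgrec_org + g_imgrec_ref)]      if reconstruction_losses
#        [+ w_hsic * mean(hsic_losses)]                 if hsic_loss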
--------------------------------------------------------------------------------
/cal_cf_weight.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import OrderedDict
3 | 
4 | import numpy as np
5 | 
6 | import time
7 | import torch.nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.distributed as dist
11 | import torch.optim
12 | # import torch.multiprocessing as mp
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | 
16 | from models.generator import Generator
17 | from models.discriminator import Discriminator
18 | from models.guidingNet import GuidingNet
19 | from tools.utils import *
20 | from datasets.datasetgetter import get_dataset
21 | from tqdm import tqdm, trange
22 | 
23 | # from oss_client import OSSCTD
24 | 
25 | # Configuration
26 | parser = argparse.ArgumentParser(description='PyTorch GAN Training')
27 | parser.add_argument('--data_path', type=str, default='data/imgs/Seen240_S80F50_TRAIN800',
28 |                     help='Dataset directory. Please refer to Dataset in README.md')
29 | parser.add_argument('--basis_path', type=str, default='data/imgs/BASIS_S80F50_TRAIN800',
30 |                     help='Basis directory. Please refer to Dataset in README.md')
31 | parser.add_argument('--workers', default=4, type=int, help='the number of workers of the data loader')
32 | parser.add_argument('--font_len', default=5, type=int, help='the number of fonts in the dataset')
33 | parser.add_argument('--basis_len', default=5, type=int, help='the number of basis fonts')
34 | 
35 | parser.add_argument('--sty_dim', default=128, type=int, help='The size of the style vector')
36 | parser.add_argument('--output_k', default=400, type=int, help='Total number of classes to use')
37 | parser.add_argument('--img_size', default=80, type=int, help='Input image size')
38 | 
39 | parser.add_argument('--load_model', default=None, type=str, metavar='PATH', help='path to checkpoint')
40 | parser.add_argument('--baseline_idx', default=2, type=int, help='the index of the baseline. \
41 |                     0: the old baseline. 1: baseline that moves the place of the DCN. 2: add an additional ADAIN conv, based on 1. 
3: delete the last ADAIN, based on 1')
42 | parser.add_argument('--load_style', type=str, default='', help='load style')
43 | 
44 | parser.add_argument('-t', '--temperature', type=float, default=0.01, help='softmax temperature')
45 | 
46 | parser.add_argument('--zero_eye', action='store_true', help='zero out the weight a font assigns to itself')
47 | parser.add_argument('--save_fn', type=str, default='', help='file name for the saved weight tensor')
48 | 
49 | args = parser.parse_args()
50 | args.bs_per_font = 40
51 | args.mini_batch = args.val_num = 6
52 | args.local_rank = 0
53 | 
54 | def main():
55 |     st_main = time.time()
56 |     args.data_dir = args.data_path
57 | 
58 |     args.att_to_use = list(range(args.font_len))
59 |     args.att_to_use_basis = list(range(args.basis_len))
60 | 
61 |     args.epoch_acc = []
62 |     args.epoch_avg_subhead_acc = []
63 |     args.epoch_stats = []
64 | 
65 |     networks, opts = build_model(args)
66 |     load_model(args, networks, opts)
67 |     dataset, _ = get_dataset(args)
68 |     print(args.basis_path)
69 |     basis_dataset, _ = get_dataset(args, data_dir=args.basis_path, class_to_use=args.att_to_use_basis)
70 |     inf(networks, dataset['FULL'], basis_dataset['FULL'], args)
71 |     print('Using ', time.time() - st_main)
72 | 
73 | def load_to_list(ds, att):
74 |     # import pdb; pdb.set_trace()
75 |     # load all data
76 |     each_cls = []
77 |     with torch.no_grad():
78 |         val_tot_tars = torch.tensor(ds.targets)
79 |         with trange(len(att)) as t:
80 |             for cls_idx in t:
81 |                 t.set_description('Loading Data')
82 |                 tmp_cls_set = (val_tot_tars == att[cls_idx]).nonzero()
83 |                 tmp_ds = torch.utils.data.Subset(ds, tmp_cls_set)
84 |                 tmp_dl = torch.utils.data.DataLoader(tmp_ds, batch_size=args.bs_per_font, shuffle=False,
85 |                                                      num_workers=4, pin_memory=True, drop_last=False)
86 |                 cls_now = torch.cat([x.clone() for x, _ in tmp_dl], 0)
87 |                 each_cls.append(cls_now)
88 |                 del tmp_dl
89 |     return each_cls
90 | 
91 | 
92 | def inf(networks, dataset, basis_dataset, args):
93 |     # set nets
94 |     basis_each_cls = load_to_list(basis_dataset, args.att_to_use_basis)
95 |     x_each_cls = load_to_list(dataset, args.att_to_use)
96 |     chars_num = len(dataset) // args.font_len
97 | 
98 |     G_EMA = networks['G_EMA']
99 |     G_EMA.eval()
100 | 
101 |     refs_bar = trange(args.font_len)
102 |     st = time.time()
103 |     tot_idx = 0
104 | 
105 |     ws = []  # will hold one weight vector per font: [font_len, n_basis]
106 | 
107 |     for i in range(len(x_each_cls)):
108 |         assert len(x_each_cls[i]) == chars_num
109 |     for i in range(len(basis_each_cls)):
110 |         assert len(basis_each_cls[i]) == chars_num
111 | 
112 |     basis_each_cls = torch.stack(basis_each_cls)  # [10, 404, 3, h, w]
113 | 
114 |     with torch.no_grad():
115 |         for s_id_now in refs_bar:  # iterate over the fonts whose basis weights we compute
116 |             refs_bar.set_description("Ref")
117 | 
118 |             ws_i = []
119 | 
120 |             # if chars_num % args.mini_batch != 0:
121 |             #     print('cannot be divided without a remainder, set mini_batch to 1')
122 |             #     args.mini_batch = 1
123 |             for cnt_idx in trange(int(np.ceil(chars_num / args.mini_batch)), leave=False):
124 |                 idx_min_now = cnt_idx * args.mini_batch
125 |                 idx_max_now = min(chars_num, (cnt_idx + 1) * args.mini_batch)
126 | 
127 |                 x_src = x_each_cls[s_id_now][idx_min_now:idx_max_now, :, :, :].cuda(non_blocking=True)
128 |                 c_src, skip1, skip2 = G_EMA.cnt_encoder(x_src)
129 |                 x_basis_now = basis_each_cls[:, idx_min_now:idx_max_now].cuda(non_blocking=True)  # [10, val_batch_now, 3, h, w]
130 |                 basis_shape = x_basis_now.shape
131 |                 #c_src_basis, _, _ = G_EMA.cnt_encoder(x_basis_now.reshape(-1, *basis_shape[2:]))
132 |                 #c_src_basis = c_src_basis.reshape(*basis_shape[:2], *c_src_basis.shape[1:])
133 |                 c_src_basis_list = [G_EMA.cnt_encoder(x_basis_now[bi])[0] for bi in range(basis_shape[0])]
134 |                 c_src_basis = torch.stack(c_src_basis_list)
135 |                 weight_now = get_logits_dis(c_src[None, ...], c_src_basis)  # [10]
136 |                 ws_i.append(weight_now)
137 | 
138 |             info = dict()
139 |             info['fps'] = (s_id_now + 1) * chars_num / (time.time() - st)  # characters processed per second
140 |             refs_bar.set_postfix(info)
141 |             logits_i = torch.mean(torch.stack(ws_i), 0) / args.temperature
142 |             if args.zero_eye:
143 |                 logits_i[s_id_now] = -100  # exclude the font itself from its own basis weights
144 |             ws_prob = torch.nn.functional.softmax(logits_i, dim=0)
145 |             ws.append(ws_prob)
146 |     ws = torch.stack(ws).cpu()
147 |     torch.save(ws, args.save_fn)
148 | #################
149 | # Sub functions #
150 | #################
151 | 
152 | def get_logits_dis(tgt_feature, base_feature, opt='l1'):
153 |     assert opt in ['inner', 'l1', 'l2']
154 |     base_n = base_feature.shape[0]
155 |     if opt == 'inner':
156 |         w_inner = tgt_feature * base_feature  # [10,40,262144]
157 |         w_mean = torch.mean(w_inner.reshape(base_n, -1), 1)
158 |     elif opt == 'l1':
159 |         w_l1 = tgt_feature - base_feature  # [10,40,262144]
160 |         w_mean = -torch.mean(w_l1.abs().reshape(base_n, -1), 1)
161 |     elif opt == 'l2':
162 |         w_l2 = (tgt_feature - base_feature)**2  # [10,40,262144]
163 |         w_mean = -torch.mean(w_l2.reshape(base_n, -1), 1)
164 |     return w_mean
165 | 
166 | def print_args(args):
167 |     for arg in vars(args):
168 |         print('{:35}{:20}\n'.format(arg, str(getattr(args, arg))))
169 | 
170 | 
171 | def build_model(args):
172 |     args.to_train = 'CDG'
173 | 
174 |     networks = {}
175 |     opts = {}
176 |     if 'C' in args.to_train:
177 |         networks['C'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
178 |         networks['C_EMA'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
179 |     if 'D' in args.to_train:
180 |         networks['D'] = Discriminator(args.img_size, num_domains=args.output_k)
181 |     if 'G' in args.to_train:
182 |         networks['G'] = Generator(args.img_size, args.sty_dim, use_sn=False, mute=True, baseline_idx=args.baseline_idx)
183 |         networks['G_EMA'] = Generator(args.img_size, args.sty_dim, use_sn=False, mute=True, baseline_idx=args.baseline_idx)
184 | 
185 |     for name, net in networks.items():
186 |         net_tmp = net.cuda()
187 |         networks[name] = net_tmp  # torch.nn.parallel.DistributedDataParallel(net_tmp, device_ids=[local_rank],
188 |                                   #                                           output_device=local_rank)
189 | 
190 |     if 'C' in args.to_train:
191 |         opts['C'] = torch.optim.Adam(networks['C'].parameters(), 1e-4, weight_decay=0.001)
192 |         networks['C_EMA'].load_state_dict(networks['C'].state_dict())
193 |     if 'D' in args.to_train:
194 |         opts['D'] = torch.optim.RMSprop(networks['D'].parameters(), 1e-4, weight_decay=0.0001)
195 |     if 'G' in args.to_train:
196 |         opts['G'] = torch.optim.RMSprop(networks['G'].parameters(), 1e-4, weight_decay=0.0001)
197 | 
198 |     return networks, opts
199 | 
200 | 
201 | def load_model(args, networks, opts):
202 |     if args.load_model is not None:
203 |         load_file = args.load_model
204 |         if os.path.isfile(load_file):
205 |             print("=> loading checkpoint '{}'".format(load_file))
206 |             checkpoint = torch.load(load_file, map_location='cpu')
207 |             args.start_epoch = checkpoint['epoch']
208 | 
209 |             for name, net in networks.items():
210 |                 tmp_keys = next(iter(checkpoint[name + '_state_dict'].keys()))
211 |                 if 'module' in tmp_keys:  # checkpoints saved from DDP models carry a 'module.' prefix
212 |                     tmp_new_dict = OrderedDict()
213 |                     for key, val in checkpoint[name + '_state_dict'].items():
214 |                         tmp_new_dict[key[7:]] = val  # strip the 'module.' prefix
215 |                         # tmp_new_dict[key] = val
216 |                     net.load_state_dict(tmp_new_dict, strict=False)
217 |                     networks[name] = net
218 |                 else:
219 |                     net.load_state_dict(checkpoint[name + '_state_dict'])
220 |                     networks[name] = net
221 | 
222 |             for name, opt in opts.items():
223 |                 opt.load_state_dict(checkpoint[name.lower() + '_optimizer'])
224 |                 opts[name] = opt
225 |             print("=> loaded checkpoint '{}' (epoch {})".format(load_file, checkpoint['epoch']))
226 |         else:
227 |             print("=> no checkpoint found at '{}'".format(args.load_model))
228 | 
229 | if __name__ == '__main__':
230 |     main()
231 | 
--------------------------------------------------------------------------------
/collect_content_embeddings.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | from datetime import datetime
5 | from glob import glob
6 | from shutil import copyfile
7 | from collections import OrderedDict
8 | 
9 | import numpy as np
10 | 
11 | import torch.nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as dist
15 | import torch.optim
16 | # import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | 
20 | from models.generator import Generator
21 | from models.discriminator import Discriminator
22 | from models.guidingNet import GuidingNet
23 | from models.inception import InceptionV3
24 | 
25 | from train.train import trainGAN
26 | 
27 | from validation.validation import validateUN
28 | 
29 | from tools.utils import *
30 | from datasets.datasetgetter import get_dataset
31 | from tools.ops import initialize_queue
32 | 
33 | import torchvision.utils as vutils
34 | from tqdm import tqdm
35 | 
36 | import pdb
37 | 
38 | from tools.ops import compute_grad_gp, update_average, copy_norm_params, queue_data, dequeue_data, \
39 |     average_gradients, calc_adv_loss, calc_contrastive_loss, calc_recon_loss, calc_abl
40 | 
41 | # from oss_client import OSSCTD
42 | 
43 | # Configuration
44 | parser = argparse.ArgumentParser(description='PyTorch GAN Training')
45 | parser.add_argument("--save_path", default='../vis', help="where to store images")
46 | 
47 | parser.add_argument('--data_path', type=str, default='../data',
48 |                     help='Dataset directory. Please refer to Dataset in README.md')
49 | parser.add_argument('--workers', default=4, type=int, help='the number of workers of the data loader')
50 | 
51 | parser.add_argument('--model_name', type=str, default='GAN',
52 |                     help='Prefix of logs and results folders. '
53 |                          'ex) --model_name=ABC generates ABC_20191230-131145 in logs and results')
54 | 
55 | parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
56 | parser.add_argument('--val_batch', default=1, type=int, help='Batch size for validation. '
57 |                     'The result images are stored in the form of a (val_batch, val_batch) grid.')
58 | parser.add_argument('--ref_num', default=10, type=int, help='Number of images as reference')
59 | parser.add_argument('--s_id', default=5, type=int, help='the font id for style reference [style]')
60 | parser.add_argument('--c_id', default=0, type=int, help='the font id for content [content]')
61 | parser.add_argument('--ft_id', default=0, type=int, help='the font id for finetune [finetune]')
62 | 
63 | parser.add_argument('--sty_dim', default=128, type=int, help='The size of the style vector')
64 | parser.add_argument('--output_k', default=400, type=int, help='Total number of classes to use')
65 | parser.add_argument('--img_size', default=80, type=int, help='Input image size')
66 | parser.add_argument('--dims', default=2048, type=int, help='Inception dims for FID')
67 | 
68 | parser.add_argument('--ft_epoch', default=0, type=int, help='the number of epochs for style vector finetune')
69 | 
70 | parser.add_argument('--w_rec', default=0.1, type=float, help='Coefficient of Rec. loss of G')
71 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
72 | 
73 | parser.add_argument('--load_style', type=str, default='', help='load style')
74 | 
75 | parser.add_argument('--abl', action='store_true', help='use ABL')
76 | parser.add_argument('--no_skip', action='store_true', help='do not save skip features')
77 | # parser.add_argument('--vis', action='store_true', help='vis result')
78 | parser.add_argument('--vis', type=str, default='', help='vis result path')
79 | 
80 | parser.add_argument('--load_model', default=None, type=str, metavar='PATH',
81 |                     help='path to latest checkpoint (default: None) '
82 |                          'ex) --load_model GAN_20190101_101010 '
83 |                          'It loads the latest .ckpt file specified in checkpoint.txt in GAN_20190101_101010')
84 | parser.add_argument('--n_atts', default=400, type=int, help='The number of classes (fonts) to use')
85 | 
86 | parser.add_argument('--baseline_idx', default=2, type=int, help='the index of the baseline. \
87 |                     0: the old baseline. 1: baseline that moves the place of the DCN. 2: add an additional ADAIN conv, based on 1. 
3: delete the last ADAIN, based on 1')
88 | 
89 | parser.add_argument('--nocontent', action='store_true', help='skip collecting content embeddings')
90 | 
91 | 
92 | args = parser.parse_args()
93 | args.val_num = 30
94 | args.local_rank = 0
95 | 
96 | n_atts = args.n_atts
97 | def main():
98 |     args.num_cls = args.output_k
99 |     args.data_dir = args.data_path
100 | 
101 |     args.att_to_use = list(range(n_atts))
102 | 
103 |     # IIC statistics
104 |     args.epoch_acc = []
105 |     args.epoch_avg_subhead_acc = []
106 |     args.epoch_stats = []
107 | 
108 |     # build model - return dict
109 |     networks, opts = build_model(args)
110 | 
111 |     # load model if args.load_model is specified
112 |     load_model(args, networks, opts)
113 | 
114 |     # all testing is done during training, so there is no need to call it separately here
115 |     dataset, _ = get_dataset(args)
116 |     inf(dataset, networks, opts, 999, args)
117 | 
118 | 
119 | def inf(dataset, networks, opts, epoch, args):
120 |     # set nets
121 |     if not os.path.exists(args.save_path):
122 |         os.mkdir(args.save_path)
123 | 
124 |     # data loader
125 |     val_dataset = dataset['FULL']
126 | 
127 |     # load all data
128 |     C_EMA = networks['C_EMA']
129 |     G_EMA = networks['G_EMA']
130 |     C_EMA.eval()
131 |     G_EMA.eval()
132 | 
133 |     with torch.no_grad():
134 |         val_tot_tars = torch.tensor(val_dataset.targets)
135 |         s_refs = []
136 |         c_srcs = []
137 |         skip1s = []
138 |         skip2s = []
139 |         for cls_idx in tqdm(range(n_atts)):
140 |             tmp_cls_set = (val_tot_tars == cls_idx).nonzero()
141 |             val_num = len(tmp_cls_set)
142 |             tmp_ds = torch.utils.data.Subset(val_dataset, tmp_cls_set)
143 |             tmp_dl = torch.utils.data.DataLoader(tmp_ds, batch_size=val_num,
144 |                                                  shuffle=False, num_workers=2, pin_memory=True, drop_last=False)
145 |             tmp_iter = iter(tmp_dl)
146 |             tmp_sample = None
147 |             for sample_idx in range(len(tmp_iter)):
148 |                 imgs, _ = next(tmp_iter)
149 |                 x_ = imgs
150 |                 tmp_sample = x_.clone() if tmp_sample is None else torch.cat((tmp_sample, x_), 0)
151 | 
152 |             x_ref = tmp_sample.cuda()  # all reference images of this font
153 |             s_ref = C_EMA(x_ref, sty=True)
154 |             s_refs.append(torch.mean(s_ref.detach().cpu(), dim=0, keepdim=True))  # average style vector over the font's characters
155 | 
156 |             if not args.nocontent:
157 |                 c_src, skip1, skip2 = G_EMA.cnt_encoder(x_ref)
158 |                 # pdb.set_trace()
159 |                 c_srcs.append(c_src.detach().cpu().unsqueeze(0))
160 |                 if not args.no_skip:
161 |                     skip1s.append(skip1.detach().cpu().unsqueeze(0))
162 |                     skip2s.append(skip2.detach().cpu().unsqueeze(0))
163 | 
164 | 
165 | 
166 |         s_ref = torch.cat(s_refs, dim=0)
167 |         ref_fn = os.path.join(args.save_path, 'style.pth')
168 |         torch.save(s_ref, ref_fn)
169 |         if not args.nocontent:
170 |             c_src = torch.cat(c_srcs, dim=0)
171 |             c_src_fn = os.path.join(args.save_path, 'c_src.pth')
172 |             torch.save(c_src, c_src_fn)
173 |             if not args.no_skip:
174 |                 skip1 = torch.cat(skip1s, dim=0)
175 |                 skip2 = torch.cat(skip2s, dim=0)
176 |                 skip1_fn = os.path.join(args.save_path, 'skip1.pth')
177 |                 skip2_fn = os.path.join(args.save_path, 'skip2.pth')
178 |                 torch.save(skip1, skip1_fn)
179 |                 torch.save(skip2, skip2_fn)
180 | 
181 | 
182 | 
183 | 
184 | #################
185 | # Sub functions #
186 | #################
187 | def print_args(args):
188 |     for arg in vars(args):
189 |         print('{:35}{:20}\n'.format(arg, str(getattr(args, arg))))
190 | 
191 | 
192 | def build_model(args):
193 |     args.to_train = 'CG'
194 | 
195 |     networks = {}
196 |     opts = {}
197 |     if 'C' in args.to_train:
198 |         networks['C'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
199 |         networks['C_EMA'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
200 |     if 'D' in 
args.to_train: 201 | networks['D'] = Discriminator(args.img_size, num_domains=args.output_k) 202 | if 'G' in args.to_train: 203 | networks['G'] = Generator(args.img_size, args.sty_dim, use_sn=False, mute=True, baseline_idx=args.baseline_idx) 204 | networks['G_EMA'] = Generator(args.img_size, args.sty_dim, use_sn=False, mute=True, baseline_idx=args.baseline_idx) 205 | 206 | for name, net in networks.items(): 207 | net_tmp = net.cuda() 208 | networks[name] = net_tmp #torch.nn.parallel.DistributedDataParallel(net_tmp ,device_ids=[local_rank], 209 | # output_device=local_rank) 210 | 211 | if 'C' in args.to_train: 212 | opts['C'] = torch.optim.Adam(networks['C'].parameters(), 1e-4, weight_decay=0.001) 213 | networks['C_EMA'].load_state_dict(networks['C'].state_dict()) 214 | if 'D' in args.to_train: 215 | opts['D'] = torch.optim.RMSprop(networks['D'].parameters(), 1e-4, weight_decay=0.0001) 216 | if 'G' in args.to_train: 217 | opts['G'] = torch.optim.RMSprop(networks['G'].parameters(), 1e-4, weight_decay=0.0001) 218 | 219 | return networks, opts 220 | 221 | 222 | def load_model(args, networks, opts): 223 | if args.load_model is not None: 224 | load_file = args.load_model 225 | if os.path.isfile(load_file): 226 | print("=> loading checkpoint '{}'".format(load_file)) 227 | checkpoint = torch.load(load_file, map_location='cpu') 228 | args.start_epoch = checkpoint['epoch'] 229 | 230 | for name, net in networks.items(): 231 | tmp_keys = next(iter(checkpoint[name + '_state_dict'].keys())) 232 | if 'module' in tmp_keys: 233 | tmp_new_dict = OrderedDict() 234 | for key, val in checkpoint[name + '_state_dict'].items(): 235 | tmp_new_dict[key[7:]] = val 236 | # tmp_new_dict[key] = val 237 | net.load_state_dict(tmp_new_dict, strict=False) 238 | networks[name] = net 239 | else: 240 | net.load_state_dict(checkpoint[name + '_state_dict']) 241 | networks[name] = net 242 | 243 | for name, opt in opts.items(): 244 | opt.load_state_dict(checkpoint[name.lower() + '_optimizer']) 245 | opts[name] = opt 246 | print("=> loaded checkpoint '{}' (epoch {})".format(load_file, checkpoint['epoch'])) 247 | else: 248 | print("=> no checkpoint found at '{}'".format(args.load_model)) 249 | 250 | if __name__ == '__main__': 251 | main() 252 | --------------------------------------------------------------------------------
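Taken together, cal_cf_weight.py produces a row-stochastic weight matrix over the basis fonts (one softmax-normalized row per font) and collect_content_embeddings.py dumps averaged style vectors (style.pth) plus per-font content features (c_src.pth). Below is a minimal sketch of how such artifacts can be combined downstream into a fused content feature. The file names, the font index, and the weighted-sum formulation are illustrative assumptions only; the authoritative consumer is the content-fusion inference path (see validation/validation_cf.py and scripts/04a_run_inf_cf.sh):

import torch

# Hypothetical paths: whatever --save_fn was for cal_cf_weight.py, and the
# c_src.pth written by collect_content_embeddings.py run on the basis folder.
ws = torch.load('cf_weight.pth')    # [font_len, n_basis], each row sums to 1
c_basis = torch.load('c_src.pth')   # [n_basis, n_chars, C, H, W] basis content features

font_id = 3                         # placeholder font index
w = ws[font_id]                     # [n_basis] fusion weights for this font

# Content fusion as a weighted sum of the basis content features.
c_fused = (w.view(-1, 1, 1, 1, 1) * c_basis).sum(dim=0)  # [n_chars, C, H, W]
print(c_fused.shape)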