├── README.md
├── environment.yml
├── models
│   └── metafusion_net.py
├── test.py
├── utils
│   └── dataloader.py
└── weight
    └── model_weight.pth

/README.md:
--------------------------------------------------------------------------------
# MetaFusion in PyTorch
A PyTorch implementation of "MetaFusion: Infrared and Visible Image Fusion via Meta-Feature Embedding from Object Detection".

# Requirements
- python 3.7
- pytorch 1.8.1
- opencv-python (cv2) 4.5.5

# Test
Run the following command to fuse a directory of infrared/visible image pairs:

> python test.py --test_ir_root IR_IMAGE_PATH --test_vis_root VIS_IMAGE_PATH --save_path RESULT_IMAGE_PATH
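For quick experiments, the fusion network can also be called directly from Python. A minimal sketch (illustrative sizes and paths, assuming the repository layout above; `test.py` remains the reference entry point):

```python
import torch
from models.metafusion_net import FusionNet

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = FusionNet(block_num=3, feature_out=False).to(device).eval()
model.load_state_dict(torch.load('./weight/model_weight.pth', map_location=device))

ir = torch.rand(1, 1, 384, 512, device=device)       # infrared image, 1 channel
vis_rgb = torch.rand(1, 3, 384, 512, device=device)  # visible image, RGB
vis_bri = torch.rand(1, 1, 384, 512, device=device)  # visible brightness (HSV V channel)

with torch.no_grad():
    _, w = model(torch.cat([ir, vis_rgb], dim=1))    # w: (1, 2, H, W) weight maps
fused_bri = w[:, 0:1] * ir + w[:, 1:2] * vis_bri     # per-pixel weighted fusion
```

The two weight maps come from a sigmoid, so each lies in [0, 1], but they are not constrained to sum to one.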
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: fusedet
channels:
  - pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/simpleitk
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - aiohttp=3.8.1=py37h7f8727e_1
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - async-timeout=4.0.1=pyhd3eb1b0_0
  - asynctest=0.13.0=py_0
  - attrs=21.4.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py37h06a4308_0
  - brotlipy=0.7.0=py37h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.18.1=h7f8727e_0
  - ca-certificates=2022.4.26=h06a4308_0
  - certifi=2021.10.8=py37h06a4308_2
  - cffi=1.15.0=py37hd667e15_1
  - click=8.0.4=py37h06a4308_0
  - cryptography=3.4.8=py37hd23ed53_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - dataclasses=0.8=pyh6d0b6a4_7
  - ffmpeg=4.2.2=h20bf706_0
  - freetype=2.11.0=h70c0345_0
  - frozenlist=1.2.0=py37h7f8727e_0
  - giflib=5.2.1=h7b6447c_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - idna=3.3=pyhd3eb1b0_0
  - importlib-metadata=4.11.3=py37h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - joblib=1.1.0=pyhd3eb1b0_0
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgfortran-ng=7.5.0=ha8ba4b0_17
  - libgfortran4=7.5.0=ha8ba4b0_17
  - libgomp=9.3.0=h5101ec6_17
  - libidn2=2.3.2=h7f8727e_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.19.1=h4ff587b_0
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.1.0=h2733197_1
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libvpx=1.7.0=h439df22_0
  - libwebp=1.2.0=h89dd481_0
  - lz4-c=1.9.3=h295c915_1
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py37h7f8727e_0
  - mkl_fft=1.3.1=py37hd3c417c_0
  - mkl_random=1.2.2=py37h51133e4_0
  - multidict=5.2.0=py37h7f8727e_2
  - ncurses=6.3=h7f8727e_2
  - nettle=3.7.3=hbbd107a_1
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - numpy=1.21.5=py37he7a7128_2
  - numpy-base=1.21.5=py37hf524024_2
  - oauthlib=3.2.0=pyhd3eb1b0_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1n=h7f8727e_0
  - pillow=9.0.1=py37h22f2fdc_0
  - pip=21.2.2=py37h06a4308_0
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyjwt=2.1.0=py37h06a4308_0
  - pyopenssl=21.0.0=pyhd3eb1b0_1
  - pysocks=1.7.1=py37_1
  - python=3.7.13=h12debd9_0
  - pytorch=1.8.1=py3.7_cuda10.2_cudnn7.6.5_0
  - readline=8.1.2=h7f8727e_1
  - requests=2.27.1=pyhd3eb1b0_0
  - scikit-learn=1.0.2=py37h51133e4_1
  - scipy=1.7.3=py37hc147768_0
  - setuptools=61.2.0=py37h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.38.3=hc218d9a_0
  - tensorboardx=2.2=pyhd3eb1b0_0
  - threadpoolctl=2.2.0=pyh0d69192_0
  - tk=8.6.11=h1ccaba5_0
  - torchaudio=0.8.1=py37
  - torchvision=0.9.1=py37_cu102
  - typing-extensions=4.1.1=hd3eb1b0_0
  - typing_extensions=4.1.1=pyh06a4308_0
  - urllib3=1.26.9=py37h06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - x264=1!157.20191217=h7b6447c_0
  - xz=5.2.5=h7f8727e_1
  - yarl=1.6.3=py37h27cfd23_0
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==1.0.0
    - backcall==0.2.0
    - cachetools==5.0.0
    - charset-normalizer==2.0.12
    - cycler==0.11.0
    - decorator==5.1.1
    - fonttools==4.33.3
    - google-auth==2.6.6
    - google-auth-oauthlib==0.4.6
    - grpcio==1.44.0
    - imageio==2.19.5
    - ipython==7.34.0
    - jedi==0.18.1
    - kiwisolver==1.4.2
    - kornia==0.6.4
    - markdown==3.3.6
    - matplotlib==3.5.1
    - matplotlib-inline==0.1.3
    - networkx==2.6.3
    - opencv-python==4.5.5.64
    - packaging==21.3
    - pandas==1.3.5
    - parso==0.8.3
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - prompt-toolkit==3.0.30
    - protobuf==3.20.1
    - psutil==5.9.1
    - ptyprocess==0.7.0
    - pyasn1-modules==0.2.8
    - pycocotools==2.0.6
    - pygments==2.12.0
    - pyparsing==3.0.8
    - python-dateutil==2.8.2
    - pytz==2022.1
    - pywavelets==1.3.0
    - pyyaml==6.0
    - requests-oauthlib==1.3.1
    - rsa==4.8
    - scikit-image==0.19.3
    - seaborn==0.11.2
    - tensorboard==2.9.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - thop==0.0.31-2005241907
    - tifffile==2021.11.2
    - tqdm==4.64.0
    - traitlets==5.3.0
    - wcwidth==0.2.5
    - werkzeug==2.1.2
    - zipp==3.8.0
prefix: /home/xie/anaconda3/envs/fusedet
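# To recreate this environment: conda env create -f environment.yml
# (the mirror channels above and the machine-specific `prefix` path may need
# adjusting for your setup).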
--------------------------------------------------------------------------------
/models/metafusion_net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init


def to_var(x, requires_grad=True):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)


class MetaModule(nn.Module):
    # Adapted from Adrien Ecoffet, https://github.com/AdrienLE
    # A module whose weights can be replaced by arbitrary (non-leaf) tensors,
    # so that an inner-loop gradient step stays differentiable.
    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        # One SGD step on every parameter; unless first_order is set, the
        # updated weights keep their autograd dependence on the gradients.
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                if grad is not None:
                    tmp = param_t - lr_inner * grad
                    self.set_param(self, name_t, tmp)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)


class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)

        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.kernel_size = ignore.kernel_size

        # weight/bias are registered as buffers rather than nn.Parameters so
        # that set_param can overwrite them with non-leaf tensors.
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


def _weights_init(m):
    if isinstance(m, MetaConv2d):
        init.kaiming_normal_(m.weight)


class FusionNet(MetaModule):
    def __init__(self, block_num, feature_out):
        super(FusionNet, self).__init__()
        block1 = []
        self.feature_out = feature_out
        for i in range(block_num):
            if i == 0:
                # input is the 1-channel IR image concatenated with the
                # 3-channel visible RGB image
                block1.append(FusionBlock(in_block=4, out_block=64, k_size=3))
            elif i == 1:
                block1.append(FusionBlock(in_block=128, out_block=128, k_size=3))
            else:
                block1.append(FusionBlock(in_block=256, out_block=128, k_size=3))
        self.block1 = nn.Sequential(*block1)

        if block_num == 1:
            self.block2_in = 128
        else:
            self.block2_in = 256
        self.block2 = nn.Sequential(
            nn.Conv2d(self.block2_in, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 2, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        if self.feature_out:
            block_out = []
            for i in range(len(self.block1)):
                x = self.block1[i](x)
                block_out.append(x.clone())

            x = self.block2(x)
            return block_out, x  # block features, ordered from shallow to deep
        else:
            x = self.block1(x)
            x = self.block2(x)

            return None, x


class FusionBlock(MetaModule):
    def __init__(self, in_block, out_block, k_size=3):
        super(FusionBlock, self).__init__()
        self.conv1_1 = MetaConv2d(
            in_channels=in_block,
            out_channels=out_block,
            kernel_size=k_size,
            stride=1,
            padding=(k_size - 1) // 2,
            bias=True
        )
        self.conv1_2 = MetaConv2d(
            in_channels=out_block,
            out_channels=out_block,
            kernel_size=k_size,
            stride=1,
            padding=(k_size - 1) // 2,
            bias=True
        )
        self.conv1_0_00 = MetaConv2d(
            in_channels=out_block,
            out_channels=out_block,
            kernel_size=k_size,
            stride=1,
            padding=(k_size - 1) // 2,
            bias=True
        )
        self.conv1_0_01 = MetaConv2d(
            in_channels=out_block,
            out_channels=out_block,
            kernel_size=k_size,
            stride=1,
            padding=(k_size - 1) // 2,
            bias=True
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1_1(x)
        x = self.relu(x)
        x = self.conv1_2(x)
        x = self.relu(x)

        # two parallel conv branches whose outputs are concatenated,
        # doubling the channel count to 2 * out_block
        x0 = self.conv1_0_00(x)
        x0 = self.relu(x0)
        x1 = self.conv1_0_01(x)
        x1 = self.relu(x1)

        return torch.cat([x0, x1], dim=1)
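
# ---------------------------------------------------------------------------
# Illustrative self-test (added for exposition; not part of the original
# pipeline): one differentiable inner-loop SGD step via
# MetaModule.update_params, the mechanism that lets a later (meta) loss
# backpropagate through a weight update.
if __name__ == '__main__':
    block = FusionBlock(in_block=4, out_block=8, k_size=3)
    out = block(to_var(torch.rand(1, 4, 32, 32), requires_grad=False))
    loss = out.mean()
    grads = torch.autograd.grad(loss, list(block.params()), create_graph=True)
    block.update_params(lr_inner=1e-3, source_params=grads)
    # The conv weights are now non-leaf tensors that still depend on `grads`.
    print('weights are non-leaf after update:', not next(block.params()).is_leaf)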
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os

import cv2
from tqdm import tqdm
import numpy as np
import torch

from models.metafusion_net import FusionNet as FusionNetwork
from utils.dataloader import get_test_loader

device = 'cuda:0'


def test(test_loader, model, checkpoint, save_path):
    val_save_path = save_path
    os.makedirs(val_save_path, exist_ok=True)
    model.eval()
    model.load_state_dict(torch.load(checkpoint), strict=True)
    tqdm.write('load from {}'.format(checkpoint))
    with torch.no_grad():

        for i, (irimage, visimage_rgb, visimage_bri, visimage_clr, image_name) in enumerate(tqdm(test_loader), start=1):
            ir_image = irimage.to(device)
            visimage_rgb = visimage_rgb.to(device)
            visimage_bri = visimage_bri.to(device)

            # the network predicts two sigmoid weight maps; the fused
            # brightness is a per-pixel weighted sum of the IR image and
            # the visible brightness (V) channel
            _, res_weight = model(torch.cat([ir_image, visimage_rgb], dim=1))
            fus_img = res_weight[:, 0, :, :] * ir_image + res_weight[:, 1, :, :] * visimage_bri

            # recombine the fused brightness with the original hue/saturation
            # channels, then convert HSV back to an RGB-ordered image
            # (the reshape assumes batch size 1)
            bri = fus_img.detach().cpu().numpy() * 255
            bri = bri.reshape([fus_img.size()[2], fus_img.size()[3]])
            bri = np.clip(bri, 0, 255)

            clr = visimage_clr.numpy().squeeze().transpose(1, 2, 0) * 255
            clr = np.concatenate((clr, bri.reshape(fus_img.size()[2], fus_img.size()[3], 1)), axis=2)
            clr = cv2.cvtColor(clr.astype(np.uint8), cv2.COLOR_HSV2RGB)

            if 'TNO' in image_name[0]:
                cv2.imwrite(
                    os.path.join(val_save_path, os.path.split(image_name[0])[1]),
                    clr)
            else:
                cv2.imwrite(
                    os.path.join(val_save_path, os.path.split(image_name[0])[1][:-4] + '.jpg'),
                    clr)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='./weight/model_weight.pth', help='fusion network weight')
    parser.add_argument('--blocks', type=int, default=3, help='number of fusion blocks')
    parser.add_argument('--test_ir_root', type=str, required=True, help='root directory of the test ir images')
    parser.add_argument('--test_vis_root', type=str, required=True, help='root directory of the test vis images')
    parser.add_argument('--save_path', type=str, default='./res/', help='the fusion results will be saved here')

    opt = parser.parse_args()

    # build the model
    fusion_net = FusionNetwork(block_num=opt.blocks, feature_out=False).to(device)
    print(fusion_net)

    # load data
    tqdm.write('load data...')

    test_loader = get_test_loader(
        ir_root=opt.test_ir_root,
        vis_root=opt.test_vis_root,
        batchsize=1,
        shuffle=False
    )

    test(
        test_loader=test_loader,
        model=fusion_net,
        checkpoint=opt.checkpoint,
        save_path=opt.save_path
    )
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
import os

import cv2
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


def get_test_loader(ir_root, vis_root, batchsize=1, testsize=320,
                    shuffle=False, num_workers=8, pin_memory=True):
    dataset = TestFusionDataset(ir_root=ir_root, vis_root=vis_root, testsize=testsize)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
    return data_loader


def rgb_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        # all inputs are resized to a fixed 512x384 (width x height)
        img = img.resize((1024 // 2, 768 // 2), Image.BILINEAR)
        return img.convert('RGB')


class TestFusionDataset(data.Dataset):
    def __init__(self, ir_root, vis_root, testsize):
        self.testsize = testsize
        # collect file names
        self.irimages = [os.path.join(ir_root, f) for f in os.listdir(ir_root)
                         if f.endswith(('.jpg', '.png', '.bmp'))]
        self.visimages = [os.path.join(vis_root, f) for f in os.listdir(vis_root)
                          if f.endswith(('.jpg', '.png', '.bmp'))]

        # sort so that ir/vis pairs line up
        self.irimages = sorted(self.irimages)
        self.visimages = sorted(self.visimages)

        # transforms
        self.img_transform = transforms.Compose([transforms.ToTensor()])
        self.toPIL = transforms.ToPILImage()
        self.size = len(self.visimages)
        if len(self.visimages) != len(self.irimages):
            raise ValueError('The number of ir and vis images differs.')

    def __getitem__(self, index):
        # read images
        irimage = self.gray_loader(self.irimages[index])
        visimage_rgb = rgb_loader(self.visimages[index])
        visimage_bri, visimage_clr = self.bri_clr_loader(self.visimages[index])

        visimage_bri = self.toPIL(visimage_bri)
        visimage_clr = self.toPIL(visimage_clr)

        irimage = self.img_transform(irimage)
        visimage_rgb = self.img_transform(visimage_rgb)
        visimage_bri = self.img_transform(visimage_bri)
        visimage_clr = self.img_transform(visimage_clr)

        return irimage, visimage_rgb, visimage_bri, visimage_clr, self.irimages[index]

    def bri_clr_loader(self, path):
        # split the visible image into HSV brightness (V) and hue/saturation.
        # note: cv2.imread returns BGR, so COLOR_RGB2HSV here actually converts
        # BGR-as-RGB; test.py inverts it the same way (HSV2RGB + imwrite in
        # BGR), so the channel order round-trips.
        img1 = cv2.imread(path)
        img1 = cv2.resize(img1, (1024 // 2, 768 // 2),
                          interpolation=cv2.INTER_LINEAR)
        img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2HSV)
        color = img1[:, :, 0:2]
        brightness = img1[:, :, 2]
        return brightness, color

    def gray_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.resize((1024 // 2, 768 // 2), Image.BILINEAR)
            return img.convert('L')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.resize((1024 // 2, 768 // 2), Image.BILINEAR)
            return img.convert('L')

    def __len__(self):
        return self.size
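
# ---------------------------------------------------------------------------
# Illustrative usage (hypothetical directory paths; the loader is normally
# constructed by test.py). Each batch yields the IR tensor, the visible RGB
# tensor, the HSV brightness (V) channel, the HSV hue/saturation channels,
# and the IR file path.
if __name__ == '__main__':
    loader = get_test_loader(ir_root='data/ir', vis_root='data/vis',
                             batchsize=1, shuffle=False, num_workers=0)
    ir, vis_rgb, vis_bri, vis_clr, name = next(iter(loader))
    print(ir.shape, vis_rgb.shape, vis_bri.shape, vis_clr.shape)
    # e.g. torch.Size([1, 1, 384, 512]) torch.Size([1, 3, 384, 512])
    #      torch.Size([1, 1, 384, 512]) torch.Size([1, 2, 384, 512])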
--------------------------------------------------------------------------------
/weight/model_weight.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdzhao123/MetaFusion/a8c5eafc84df9399e35602a00e3651a7d3d5c7d1/weight/model_weight.pth
--------------------------------------------------------------------------------