├── README.md └── visualization ├── features_ops.py └── image_ops.py /README.md: -------------------------------------------------------------------------------- 1 | # 目前实现的操作 2 | ## Feature Operations 3 | 1. 获取指定层的特征图 4 | 2. 指定层特征图可视化(带频域可视化) 5 | 3. 特征图之间相关性 6 | 7 | ## Image Operations 8 | 1. 图像放缩 9 | 2. 图像直方图 10 | 3. 残差图 11 | 4. 伪彩色图 12 | 5. 不同模型参数量对比图 13 | 6. 三维散点图 14 | 7. ensemble测试 -------------------------------------------------------------------------------- /visualization/features_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # ===========================获取指定层的特征图======================================= # 4 | feas = [] 5 | def extract_features(model:torch.nn.Module): 6 | def get_fesa(module, input, output): 7 | feas.append(output) 8 | 9 | for name,sub in model.named_modules(): 10 | if name == 'layers.5.residual_group.blocks': 11 | sub.register_forward_hook(get_fesa) 12 | 13 | # ===========================指定层特征图可视化(带频域可视化)======================================= # 14 | def visual_features(features: torch.Tensor, fre=False, path='./'): 15 | from einops import rearrange 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | output = features.data.float().cpu().squeeze() 19 | C,H,W = output.shape 20 | if fre: 21 | # heatmap_fre_tot = np.zeros((H,W,3)) 22 | magnitude_tot = np.zeros((H,W)) 23 | for i in range(output.shape[0]): 24 | f = np.fft.fft2(output[i,:,:]) 25 | f = np.fft.fftshift(f) 26 | magnitude = 20*np.log(np.abs(f)+1) 27 | magnitude_tot += magnitude 28 | magnitude_tot = ((magnitude_tot - magnitude_tot.min())/(magnitude_tot.max()-magnitude_tot.min())) 29 | magnitude_tot = (magnitude_tot * 255.0).round().astype(np.uint8) 30 | output_tot = np.zeros((H,W)) 31 | for i in range(output.shape[0]): 32 | tmp = output[i,:,:] 33 | output_tot += tmp.numpy() 34 | output_tot = ((output_tot - output_tot.min())/(output_tot.max()-output_tot.min())) 35 | output_tot = (output_tot * 255.0).round().astype(np.uint8) 36 | plt.imshow(output_tot, cmap=plt.cm.jet) 37 | plt.savefig(path+'.png') 38 | if fre: 39 | # cv2.imwrite(path+'_fre.png', heatmap_fre_tot) 40 | plt.imshow(magnitude_tot, cmap=plt.cm.jet) 41 | # plt.colorbar() 42 | plt.savefig(path+'_fre.png') 43 | plt.clf() 44 | 45 | # ===========================特征图之间相关性======================================= # 46 | def get_the_mae(features): 47 | b, c, = features.shape[0], features.shape[1] 48 | diff = torch.zeros((c,c)) 49 | for i in range(c): 50 | for j in range(c): 51 | diff[i][j] = (torch.mean(torch.square(features[0,i,:,:] - features[0,j,:,:]))) 52 | 53 | import seaborn as sns 54 | import matplotlib.pyplot as plt 55 | 56 | plt.figure(figsize=(10,8)) 57 | xlabels = ['F1','F2','F3','F4','F5','F6','F7','F8','F9','F10','F11','F12','F13','F14','F15','F16'] 58 | ylabels = ['F1','F2','F3','F4','F5','F6','F7','F8','F9','F10','F11','F12','F13','F14','F15','F16'] 59 | sns.heatmap(diff, xticklabels=xlabels, yticklabels=ylabels, fmt='.2f', annot=True) 60 | plt.title('MSE between feature maps') 61 | plt.margins(0,0) 62 | plt.savefig('sns_heatmap_cmap.jpg', dpi=300) 63 | 64 | return diff 65 | -------------------------------------------------------------------------------- /visualization/image_ops.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from torchvision.transforms import transforms 7 | 8 | # ===========================图像放缩======================================= # 9 | def scale_image(path, scale): 10 | img = Image.open(path) 11 | w, h = img.size 12 | img_x2 = img.resize((w//scale, h//scale), resample=Image.BICUBIC) 13 | img_x2.save(path.split('.')[0]+'_scale.png') 14 | 15 | # ===========================直方图======================================= # 16 | def draw_histogram(path): 17 | plt.figure() 18 | labels = ['1', '2', '3', '4', '5'] 19 | y1 = [0.1348,0.9999,0.9938,0.9975,0.0711,1.0000,1.0000,1.0000,0.0361,0.8649,0.1108,1.0000,0.0104,0.2340,1.0000,0.1005].sort() 20 | y2 = [0.9936,0.6699,0.8878,0.0479,0.2529,0.8340,0.4752,0.2570,0.9115,0.2424,0.3269,0.4807,0.8444,0.9860,0.7761,0.1936].sort() 21 | y3 = [0.1160,1.0000,0.2762,1.0000,0.1739,1.0000,0.3655,0.3386,0.2902,0.2745,0.2755,1.0000,1.0000,0.1414,0.7498,0.3521].sort() 22 | y4 = [0.4027,0.5063,0.3615,0.1793,0.3555,0.6465,0.5135,1.0000,0.4005,0.3201,0.6962,0.2700,0.5110,0.9001,0.5695,0.6446].sort() 23 | y5 = [0.2748,0.9998,0.3185,0.8779,0.9924,0.4448,0.9998,0.1454,0.3088,0.1400,0.9908,0.1009,0.1388,0.1151,0.1300,0.2037].sort() 24 | x1 = [i+1 for i in range(len(y1))] 25 | x2 = [i+x1[-1]+2 for i in range(len(y1))] 26 | x3 = [i+x2[-1]+2 for i in range(len(y1))] 27 | x4 = [i+x3[-1]+2 for i in range(len(y1))] 28 | x5 = [i+x4[-1]+2 for i in range(len(y1))] 29 | plt.bar(x1, y1, color='red') 30 | plt.bar(x2, y2, color='green') 31 | plt.bar(x3, y3, color='blue') 32 | plt.bar(x4, y4, color='orange') 33 | plt.bar(x5, y5, color='yellow') 34 | 35 | plt.xticks([i*len(y1) + 9+i for i in range(len(labels))], labels) 36 | plt.ylim(0, 1.2) 37 | plt.ylabel('Sparsity') 38 | plt.savefig(path, dpi = 300) 39 | 40 | # ===========================残差图======================================= # 41 | def draw_residual_imgs(path1, path2, path_out): 42 | img_hr,_,_ = Image.open(path1).convert('YCbCr').split() 43 | img_sr,_,_ = Image.open(path2).convert('YCbCr').split() 44 | 45 | residual = transforms.ToPILImage()(torch.abs(transforms.ToTensor()(img_sr) - transforms.ToTensor()(img_hr))) 46 | residual.save(path_out) 47 | 48 | # ===========================伪彩色图======================================= # 49 | def heat_imgs(path): 50 | plt.figure() 51 | img = plt.imread(path) 52 | img = img/255. 53 | plt.imshow(img, cmap = plt.cm.jet) 54 | plt.colorbar() 55 | plt.savefig(path.split('.')[0]+'_heat.png') 56 | 57 | # ===========================不同模型参数量对比图======================================= # 58 | def model_cmp(): 59 | x = [1.57, 52.70, 6.00, 5.50, 2.07, 2.26, 4.55, 14.00, 29.90, 91.20] 60 | y = [37.27, 36.66, 37.00, 37.06, 37.21, 37.36, 36.83, 37.38, 37.52, 37.53] 61 | params = [9.90, 24.00, 12.46, 25.00, 9.03, 14.63, 21.18, 60.00, 813.00, 412.00] 62 | colors=list(np.arange(1,len(params)+1)/len(params)) 63 | params = np.array(params) 64 | area = np.pi * 16 * 20 * params/(np.pi * 4 * 20 ) 65 | plt.figure() 66 | plt.xlabel('Number of MACs (G)') 67 | plt.ylabel('PSNR (db)') 68 | plt.scatter(x, y, alpha=0.8, s=area, c=colors) 69 | plt.grid() 70 | plt.ylim(36.5, 37.7) 71 | plt.xlim(0, 100) 72 | # plt.legend() 73 | plt.annotate('SGSR-M5', (1.57,37.27), (1.34+2.5,37.27-0.08), weight="bold", color="r", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="r")) 74 | plt.annotate('SRCNN', (52.70,36.66), (52.70-4.0,36.66+0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 75 | plt.annotate('FSRCNN', (6.00,37.00), (6.00+1.0,37.00-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 76 | plt.annotate('MOREMNAS-C', (5.50,37.06), (5.50+4.0,37.06-0.02), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 77 | plt.annotate('SESR-M3', (2.07,37.21), (2.07-2.0,37.21-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 78 | # plt.annotate('ICCV21', (3.10,37.32), (3.10+2.0,37.32-0.08), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 79 | plt.annotate('SGSR-M8', (2.22,37.36), (2.22-1.0,37.36+0.1), weight="bold", color="r", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="r")) 80 | plt.annotate('ESPCN', (4.55,36.83), (4.55-1.0,36.83-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 81 | plt.annotate('TPSR-NoGAN', (14.00,37.38), (14.00-1.0,37.38-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 82 | # plt.annotate('VDSR', (612.60,37.53), (612.60-1.0,37.53-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 83 | plt.annotate('LapSRN', (29.90,37.52), (29.90-1.0,37.52-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 84 | plt.annotate('CARN-M', (91.20,37.53), (91.20-4.0,37.53-0.1), weight="bold", color="b", arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color="b")) 85 | plt.savefig('./samples/modelsPK.png', dpi = 300) 86 | 87 | # ============================三维散点图======================================== # 88 | def draw_3D_figs(): 89 | fig = plt.figure() 90 | # ax = fig.add_subplot(111, projection='3d') 91 | ax = fig.gca(projection='3d') 92 | ax.view_init(elev=14, azim=-34) 93 | ax.invert_xaxis() 94 | 95 | # For each set of style and range settings, plot n random points in the box 96 | # defined by x in [23, 32], y in [0, 100], z in [zlow, zhigh]. 97 | x = [ 52.70, 6.00, 5.50, 2.07, 4.55, 14.00, 3.10, 2.34] 98 | y = [ 36.66, 37.00, 37.06, 37.21, 36.83, 37.38, 37.32, 37.33] 99 | params = [ 24.00, 12.46, 25.00, 9.03, 21.18, 60.00, 14.00, 10.20] 100 | 101 | x_our = [2.26] 102 | y_our = [37.36] 103 | param_our = [14.63] 104 | # for c, m, zlow, zhigh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]: 105 | # xs = randrange(n, 23, 32) 106 | # ys = randrange(n, 0, 100) 107 | # zs = randrange(n, zlow, zhigh) 108 | ax.scatter(x, params, y, c='b', marker='o') 109 | ax.scatter(x_our, param_our, y_our, c='r', marker='*', s=80, depthshade=False) 110 | ax.text( 2.26,14.93, 37.39, "SRGFS(Ours)") 111 | ax.text( 52.90,24.30, 36.66, "SRCNN(2014)") 112 | ax.text( 6.00,12.46, 37.02, "FSRCNN(2016)") 113 | ax.text( 5.50,25.20, 37.00, "MOREMNAS(2020)") 114 | ax.text( 2.37,9.23, 37.11, "SESR(2021)") 115 | ax.text( 5.10,21.58, 36.83, "ESPCN(2016)") 116 | ax.text( 14.00,60.00, 37.38, "TPSR-NoGAN(2019)") 117 | ax.text( 3.45,14.50, 37.27, "ICCV(2021)") 118 | ax.text( 2.84,10.50, 37.23, "ACMM(2021)") 119 | 120 | label = ["SRCNN", "FSRCNN", "MOREMNAS", "SESR", "ESPCN", "TPSR-NoGAN", "ICCV", "ACMMM", "SRGFS(Ours)"] 121 | 122 | ax.set_xlabel('Params (K)') 123 | ax.set_ylabel('Number of MACs (G)') 124 | ax.set_zlabel('PSNR (dB)') 125 | 126 | plt.savefig('./samples/modelsPK_3d.png', dpi = 300) 127 | 128 | # ===========================ensemble测试======================================= # 129 | def ensemble(path): 130 | img_hr,_,_ = Image.open(path).convert('YCbCr').split() 131 | 132 | # 8种旋转方式 133 | img_hr.save('./samples/{}.png'.format('img_0')) 134 | img_90 = img_hr.rotate(90) 135 | img_90.rotate(-90).save('./samples/{}.png'.format('img_-90')) 136 | img_90.save('./samples/{}.png'.format('img_90')) 137 | img_180 = img_hr.rotate(180) 138 | img_180.save('./samples/{}.png'.format('img_180')) 139 | img_180.rotate(-180).save('./samples/{}.png'.format('img_-180')) 140 | img_270 = img_hr.rotate(270) 141 | img_270.save('./samples/{}.png'.format('img_270')) 142 | img_270.rotate(-270).save('./samples/{}.png'.format('img_-270')) 143 | img_hr_flip = img_hr.transpose(Image.FLIP_TOP_BOTTOM) 144 | img_hr_flip.save('./samples/{}.png'.format('img_hr_flip_0')) 145 | img_hr_flip_90 = img_hr_flip.rotate(90) 146 | img_hr_flip_90.save('./samples/{}.png'.format('img_hr_flip_90')) 147 | img_hr_flip_180 = img_hr_flip.rotate(180) 148 | img_hr_flip_180.save('./samples/{}.png'.format('img_hr_flip_180')) 149 | img_hr_flip_270 = img_hr_flip.rotate(270) 150 | img_hr_flip_270.save('./samples/{}.png'.format('img_hr_flip_270')) 151 | 152 | img_hr_flip.transpose(Image.FLIP_TOP_BOTTOM).save('./samples/{}.png'.format('img_hr_flip_-0')) 153 | img_hr_flip_90.rotate(-90).transpose(Image.FLIP_TOP_BOTTOM).save('./samples/{}.png'.format('img_hr_flip_-90')) 154 | img_hr_flip_180.rotate(-180).transpose(Image.FLIP_TOP_BOTTOM).save('./samples/{}.png'.format('img_hr_flip_-180')) 155 | img_hr_flip_270.rotate(-270).transpose(Image.FLIP_TOP_BOTTOM).save('./samples/{}.png'.format('img_hr_flip_-270')) 156 | --------------------------------------------------------------------------------