├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── README.md └── optvis.py /0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/0.png -------------------------------------------------------------------------------- /1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/1.png -------------------------------------------------------------------------------- /2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/2.png -------------------------------------------------------------------------------- /3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/3.png -------------------------------------------------------------------------------- /4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/4.png -------------------------------------------------------------------------------- /5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elichen/Feature-visualization/HEAD/5.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Feature-visualization 2 | Deep learning CNN feature visualization 3 | - A Pytorch / Fast.ai port of https://github.com/tensorflow/lucid 4 | 5 | ![](0.png) 6 | ![](1.png) 7 | ![](2.png) 8 | ![](3.png) 9 | ![](4.png) 10 | ![](5.png) 11 | -------------------------------------------------------------------------------- /optvis.py: 
import numpy as np
import torch
from torch import tensor
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision import transforms
import fastai.vision as vision


def init_fft_buf(h, w, rand_sd=0.01, **kwargs):
    """Create a random spectrum buffer for an ``h`` x ``w`` three-channel image.

    The buffer is drawn from a normal distribution with standard deviation
    ``rand_sd`` and has shape ``(1, 3, h, w // 2 + 1, 2)``; it is the layout
    consumed by ``fft_to_rgb`` (one-sided spectrum, trailing axis of size 2
    for the real/imaginary parts).  Extra keyword arguments are ignored.

    Returns a float32 tensor on the CUDA device.
    """
    img_buf = np.random.normal(size=(1, 3, h, w // 2 + 1, 2), scale=rand_sd).astype(np.float32)
    spectrum_t = tensor(img_buf).float().cuda()
    return spectrum_t


def get_fft_scale(h, w, decay_power=.75, **kwargs):
    """Return the per-frequency scale applied to the spectrum buffer.

    The scale is ``1 / max(|f| ** (2 * decay_power), 1 / (max(w, h) * d))``
    evaluated on the one-sided frequency grid, shaped
    ``(1, 1, h, w // 2 + 1, 1)`` so it broadcasts against a buffer created by
    ``init_fft_buf``.  Extra keyword arguments are ignored.

    Returns a float32 tensor on the CUDA device.
    """
    d = .5 ** .5  # set center frequency scale to 1
    fy = np.fft.fftfreq(h, d=d)[:, None]
    # The one-sided spectrum has w // 2 + 1 columns for odd AND even widths
    # (that is what init_fft_buf allocates and what a real FFT returns).
    # Keeping an extra column for odd w made the scale un-broadcastable
    # against the buffer.
    fx = np.fft.fftfreq(w, d=d)[: w // 2 + 1]
    freqs = (fx * fx + fy * fy) ** decay_power
    # Clamp from below so the zero frequency does not divide by zero.
    scale = 1.0 / np.maximum(freqs, 1.0 / (max(w, h) * d))
    scale = tensor(scale).float()[None, None, ..., None].cuda()
    return scale


def fft_to_rgb(h, w, t, **kwargs):
    """Turn a scaled spectrum buffer ``t`` into a spatial ``h`` x ``w`` image.

    ``t`` is multiplied by ``get_fft_scale`` and passed through a normalized
    2-D inverse real FFT.  ``kwargs`` are forwarded to ``get_fft_scale``.
    """
    scale = get_fft_scale(h, w, **kwargs)
    t = scale * t
    if hasattr(torch, "irfft"):
        # Legacy API (removed in PyTorch 1.8).
        return torch.irfft(t, 2, normalized=True, signal_sizes=(h, w))
    # Newer PyTorch: same transform through the torch.fft module, which
    # works on complex tensors instead of a trailing (real, imag) axis.
    spectrum = torch.view_as_complex(t.contiguous())
    return torch.fft.irfftn(spectrum, s=(h, w), dim=(-2, -1), norm="ortho")


def rgb_to_fft(h, w, t, **kwargs):
    """Inverse of ``fft_to_rgb``: spatial image ``t`` to an unscaled spectrum buffer.

    Applies a normalized 2-D real FFT and divides by ``get_fft_scale``.
    ``kwargs`` are forwarded to ``get_fft_scale``.
    """
    if hasattr(torch, "rfft"):
        # Legacy API (removed in PyTorch 1.8).
        t = torch.rfft(t, normalized=True, signal_ndim=2)
    else:
        t = torch.view_as_real(torch.fft.rfftn(t, dim=(-2, -1), norm="ortho"))
    scale = get_fft_scale(h, w, **kwargs)
    t = t / scale
    return t


def color_correlation_normalized():
    """Return the 3x3 color-correlation matrix, normalized by its largest column norm.

    Returns a float32 tensor on the CUDA device.
    """
    color_correlation_svd_sqrt = np.asarray([[0.26, 0.09, 0.02],
                                             [0.27, 0.00, -0.05],
                                             [0.27, -0.09, 0.03]]).astype(np.float32)
    max_norm_svd_sqrt = np.max(np.linalg.norm(color_correlation_svd_sqrt, axis=0))
    normalized = tensor(color_correlation_svd_sqrt / max_norm_svd_sqrt).cuda()
    return normalized


def lucid_colorspace_to_rgb(t):
    """Map a ``(N, 3, H, W)`` tensor from the decorrelated color space to RGB.

    The channel axis is moved last, multiplied by the transposed
    color-correlation matrix, and moved back.
    """
    t_flat = t.permute(0, 2, 3, 1)
    t_flat = torch.matmul(t_flat, color_correlation_normalized().T)
    t = t_flat.permute(0, 3, 1, 2)
    return t


def rgb_to_lucid_colorspace(t):
    """Inverse of ``lucid_colorspace_to_rgb`` (multiplies by the inverse matrix)."""
    t_flat = t.permute(0, 2, 3, 1)
    inverse = torch.inverse(color_correlation_normalized().T)
    t_flat = torch.matmul(t_flat, inverse)
    t = t_flat.permute(0, 3, 1, 2)
    return t


def imagenet_mean_std():
    """Return the per-channel ``(mean, std)`` tensors used for normalization (on CUDA)."""
    return (tensor([0.485, 0.456, 0.406]).cuda(),
            tensor([0.229, 0.224, 0.225]).cuda())


def denormalize(x):
    """Undo ``normalize``: ``x * std + mean`` with per-channel broadcasting."""
    mean, std = imagenet_mean_std()
    return x.float() * std[..., None, None] + mean[..., None, None]


def normalize(x):
    """Apply per-channel normalization: ``(x - mean) / std``."""
    mean, std = imagenet_mean_std()
    return (x - mean[..., None, None]) / std[..., None, None]


def image_buf_to_rgb(h, w, img_buf, **kwargs):
    """Render a spectrum buffer as a displayable ``(3, h, w)`` image in ``[0, 1]``.

    The buffer is detached, converted with ``fft_to_rgb`` and
    ``lucid_colorspace_to_rgb``, squashed with a sigmoid, and the first
    (only) batch element is returned.  ``kwargs`` are forwarded to
    ``fft_to_rgb``.
    """
    img = img_buf.detach()
    img = fft_to_rgb(h, w, img, **kwargs)
    img = lucid_colorspace_to_rgb(img)
    img = torch.sigmoid(img)
    img = img[0]
    return img


def show_rgb(img, label=None, ax=None, dpi=25, **kwargs):
    """Display a ``(C, H, W)`` image tensor with matplotlib.

    If ``ax`` is not given, a new figure sized ``(W / dpi, H / dpi)`` is
    created and shown immediately; otherwise the image is drawn on ``ax``
    and showing the figure is left to the caller.  ``label`` becomes the
    axes title.  Extra keyword arguments are ignored.
    """
    plt_show = ax is None
    if plt_show:
        _, ax = plt.subplots(figsize=(img.shape[2] / dpi, img.shape[1] / dpi))
    x = img.cpu().permute(1, 2, 0).numpy()
    ax.imshow(x)
    ax.axis('off')
    ax.set_title(label)
    if plt_show:
        plt.show()


def gpu_affine_grid(size):
    """Build an identity sampling grid on the GPU wrapped in a fastai ``FlowField``.

    ``size`` is a ``(C, H, W)`` tuple.  The grid has shape ``(1, H, W, 2)``
    with x coordinates in ``[..., 0]`` and y coordinates in ``[..., 1]``,
    each spanning ``[-1, 1]``.
    """
    size = ((1,) + size)
    N, C, H, W = size
    grid = torch.FloatTensor(N, H, W, 2).cuda()
    linear_points = torch.linspace(-1, 1, W) if W > 1 else tensor([-1.])
    grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, :, 0])
    linear_points = torch.linspace(-1, 1, H) if H > 1 else tensor([-1.])
    grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, :, 1])
    return vision.FlowField(size[2:], grid)


def lucid_transforms(img, jitter=None, scale=.5, degrees=45, **kwargs):
    """Apply a random pad / jitter / scale / rotate / jitter pipeline to ``img``.

    ``img`` is squeezed and wrapped in a fastai ``Image``; the result is
    returned with a leading batch axis and the original ``(h, w)`` size.

    Args:
        jitter: padding amount; defaults to ``min(h, w) // 2``.  The first
            random crop keeps ``int(jitter * 2 / 3)`` extra pixels per side
            length, the final random crop returns to ``(h, w)``.
        scale: zoom factor is drawn from 11 evenly spaced values in
            ``[1 - scale, 1 + scale]``.
        degrees: rotation is drawn from the integers in
            ``[-degrees, degrees]`` plus ``degrees // 2`` extra zeros
            (biasing the draw towards no rotation).
        **kwargs: ignored.
    """
    h, w = img.shape[-2], img.shape[-1]
    if jitter is None:
        jitter = min(h, w) // 2
    fastai_image = vision.Image(img.squeeze())

    # pad
    fastai_image._flow = gpu_affine_grid(fastai_image.shape)
    vision.transform.pad()(fastai_image, jitter)

    # jitter
    first_jitter = int((jitter * (2 / 3)))
    vision.transform.crop_pad()(fastai_image,
                                (h + first_jitter, w + first_jitter),
                                row_pct=np.random.rand(), col_pct=np.random.rand())

    # scale
    percent = scale * 100  # scale up to integer to avoid float repr errors
    scale_factors = [(100 - percent + percent / 5. * i) / 100 for i in range(11)]
    rand_scale = scale_factors[int(np.random.rand() * len(scale_factors))]
    fastai_image._flow = gpu_affine_grid(fastai_image.shape)
    vision.transform.zoom()(fastai_image, rand_scale)

    # rotate
    rotate_factors = list(range(-degrees, degrees + 1)) + degrees // 2 * [0]
    rand_rotate = rotate_factors[int(np.random.rand() * len(rotate_factors))]
    fastai_image._flow = gpu_affine_grid(fastai_image.shape)
    vision.transform.rotate()(fastai_image, rand_rotate)

    # jitter
    vision.transform.crop_pad()(fastai_image, (h, w), row_pct=np.random.rand(), col_pct=np.random.rand())

    return fastai_image.data[None, :]


def tensor_stats(t, label=""):
    """Format the mean, std, max and min of tensor ``t`` as a one-line string.

    A non-empty ``label`` is prepended, followed by a space.
    """
    if len(label) > 0:
        label += " "
    return ("%smean:%.2f std:%.2f max:%.2f min:%.2f" % (label, t.mean().item(), t.std().item(), t.max().item(), t.min().item()))


def cossim(act0, act1, cosim_weight=0, **kwargs):
    """Return ``cosim_weight`` times the cosine similarity of two activation tensors.

    Both tensors are treated as flat vectors (sums run over all elements).
    With the default ``cosim_weight=0`` the result is zero.  Extra keyword
    arguments are ignored.
    """
    dot = (act0 * act1).sum()
    mag0 = act0.pow(2).sum().sqrt()
    mag1 = act1.pow(2).sum().sqrt()
    weighted = cosim_weight * dot / (mag0 * mag1)
    return weighted


def visualize_feature(model, layer, feature, start_image=None, last_hook_out=None,
                      size=200, steps=500, lr=0.004, weight_decay=0.1, grad_clip=1,
                      debug=False, frames=10, show=True, **kwargs):
    """Optimize an image so that it maximizes an activation of ``layer``.

    A spectrum buffer is optimized with AdamW; on every step it is rendered
    to RGB, squashed with a sigmoid, normalized, randomly transformed with
    ``lucid_transforms`` and fed through ``model`` while a forward hook
    captures the output of ``layer``.

    Args:
        model: network to run; called as ``model(img.cuda())``.
        layer: module whose forward output is captured.
        feature: index into the first batch element of the captured output
            whose mean is maximized; ``None`` maximizes the mean squared
            activation of the whole output instead.
        start_image: optional image to start from instead of random noise.
        last_hook_out: optional activation from a previous run; when given,
            the loss becomes ``loss + loss * cossim(current, last_hook_out)``.
        size: output size, either an int (square) or an ``(h, w)`` tuple.
        steps: number of optimization steps.
        lr, weight_decay: AdamW hyper-parameters.
        grad_clip: max gradient norm for the buffer.
        debug: when true, show roughly ``frames`` intermediate images.
        frames: number of intermediate images shown in debug mode.
        show: when true (and not ``debug``), show the final image.
        **kwargs: forwarded to the helper functions (e.g. ``rand_sd``,
            ``decay_power``, ``jitter``, ``scale``, ``degrees``,
            ``cosim_weight``, ``dpi``).

    Returns:
        ``(image, activation)``: the final ``(3, h, w)`` image from
        ``image_buf_to_rgb`` and a detached clone of the last captured
        activation (first batch element), or ``None`` for the activation if
        no forward pass was run (``steps <= 0``).
    """
    h, w = size if isinstance(size, tuple) else (size, size)
    if start_image is not None:
        fastai_image = vision.Image(start_image.squeeze())
        fastai_image._flow = gpu_affine_grid((3, h, w))  # resize
        img_buf = fastai_image.data[None, :]
        img_buf = normalize(img_buf)
        img_buf = rgb_to_lucid_colorspace(img_buf)
        img_buf = rgb_to_fft(h, w, img_buf, **kwargs)
    else:
        img_buf = init_fft_buf(h, w, **kwargs)
    img_buf.requires_grad_()
    opt = torch.optim.AdamW([img_buf], lr=lr, weight_decay=weight_decay)

    hook_out = None

    def callback(m, i, o):
        nonlocal hook_out
        hook_out = o

    # Clamp to 1 so that steps < frames does not produce a modulo by zero.
    frame_interval = max(1, int(steps / frames)) if debug else None

    hook = layer.register_forward_hook(callback)
    try:
        for i in range(1, steps + 1):
            opt.zero_grad()

            img = fft_to_rgb(h, w, img_buf, **kwargs)
            img = lucid_colorspace_to_rgb(img)
            pre_sigmoid = img  # kept so stats are only computed on debug frames
            img = torch.sigmoid(img)
            img = normalize(img)
            img = lucid_transforms(img, **kwargs)
            model(img.cuda())
            if feature is None:
                loss = -1 * hook_out[0].pow(2).mean()
            else:
                loss = -1 * hook_out[0][feature].mean()
            if last_hook_out is not None:
                similarity = cossim(hook_out[0], last_hook_out, **kwargs)
                loss = loss + loss * similarity

            loss.backward()
            torch.nn.utils.clip_grad_norm_(img_buf, grad_clip)
            opt.step()

            if debug and i % frame_interval == 0:
                clear_output(wait=True)
                stats = tensor_stats(pre_sigmoid)
                label = f"step: {i} loss: {loss:.2f} stats:{stats}"
                show_rgb(image_buf_to_rgb(h, w, img_buf, **kwargs),
                         label=label, **kwargs)
    finally:
        # Always detach the hook, even if the optimization loop raised.
        hook.remove()

    retval = image_buf_to_rgb(h, w, img_buf, **kwargs)
    if show:
        if not debug:
            show_rgb(retval, **kwargs)
    last_activation = hook_out[0].clone().detach() if hook_out is not None else None
    return retval, last_activation