├── README.md
└── VisDemo.py

/README.md:
--------------------------------------------------------------------------------
# Detr-heat-map-visualization
Per-head position heat-map / output heat-map visualization for DETR.

Example result:

![1](https://user-images.githubusercontent.com/49540854/163091461-abfcfb51-f85f-46b2-9bfc-e7d431e60a61.jpg)
--------------------------------------------------------------------------------

/VisDemo.py:
--------------------------------------------------------------------------------
# #------------------------------------------------------------#
# Visualizing DETR attention:
#   spatial attention weight : (cq + oq) * pk
#   combined attention weight: (cq + oq) * (memory + pk)
# where:
#   pk:     positional encoding of the backbone feature map
#   oq:     the trained object queries
#   cq:     the query output by the self-attention block of the
#           last decoder layer
#   memory: the encoder output
# #------------------------------------------------------------#
# With only minor modifications, this also reproduces the Fig. 1
# feature maps of Conditional DETR.
# #------------------------------------------------------------#
# Code adapted from: https://github.com/facebookresearch/detr/tree/colab
# #------------------------------------------------------------#

from PIL import Image
import requests
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as T
from torch.nn.functional import linear, softmax

torch.set_grad_enabled(False)


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# load the pretrained model from torch hub
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
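
# #------------------------------------------------------------#
# Note (a clarifying remark on PyTorch internals, not in the
# original script): an nn.MultiheadAttention module with a single
# embed_dim packs its q/k/v projections row-wise into one
# in_proj_weight of shape [3*embed_dim, embed_dim] (and an
# in_proj_bias of shape [3*embed_dim]):
#   rows [0:256]   -> q projection
#   rows [256:512] -> k projection
#   rows [512:768] -> v projection
# The parameter extraction below and the manual projections later
# in this script rely on this layout.
# #------------------------------------------------------------#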

# extract the trained parameters
for name, parameters in model.named_parameters():
    # the trained object queries (oq in the header note): [100,256]
    if name == 'query_embed.weight':
        pq = parameters
    # packed q/k/v linear weight and bias of the cross-attention module
    # in the last decoder layer: [256*3,256], [768]
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':
        in_proj_weight = parameters
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':
        in_proj_bias = parameters

# download an image from the web
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)

# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)

# propagate through the model
outputs = model(img)

# keep only predictions with 0.9+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

# use lists to store the outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []      # stores cq: the content query from the last decoder layer
pk = []      # stores pk: the positional encoding of the feature map
memory = []  # stores memory: the encoder output feature map

# register hooks
hooks = [
    # the feature map of the last ResNet stage
    model.backbone[-2].register_forward_hook(
        lambda self, input, output: conv_features.append(output)
    ),
    # the encoder output (image feature map) memory
    model.transformer.encoder.register_forward_hook(
        lambda self, input, output: memory.append(output)
    ),
    # the self-attention weights of the last encoder layer
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda self, input, output: enc_attn_weights.append(output[1])
    ),
    # the cross-attention weights of the last decoder layer
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
        lambda self, input, output: dec_attn_weights.append(output[1])
    ),
    # cq: the self-attention output of the last decoder layer (after norm1)
    model.transformer.decoder.layers[-1].norm1.register_forward_hook(
        lambda self, input, output: cq.append(output)
    ),
    # pk: the positional encoding of the image feature map
    model.backbone[-1].register_forward_hook(
        lambda self, input, output: pk.append(output)
    ),
]

# propagate through the model
outputs = model(img)

# remove the hooks once they have served their purpose
for hook in hooks:
    hook.remove()

# don't need the lists anymore
conv_features = conv_features[0]        # [1,2048,25,34]
enc_attn_weights = enc_attn_weights[0]  # [1,850,850] : [N,L,S]
dec_attn_weights = dec_attn_weights[0]  # [1,100,850] : [N,L,S] --> [batch, tgt_len, src_len]
memory = memory[0]                      # [850,1,256]

cq = cq[0]  # self-attention output of the last decoder layer: [100,1,256]
pk = pk[0]  # [1,256,25,34]

# assemble q and k from the position embeddings
pk = pk.flatten(-2).permute(2, 0, 1)  # [1,256,850] --> [850,1,256]
pq = pq.unsqueeze(1)                  # [100,256] --> [100,1,256] (batch size 1)
q = pq + cq
#------------------------------------------------------#
# 1) k = pk          visualizes (cq + oq)*pk
# 2) k = pk + memory visualizes (cq + oq)*(memory + pk)
# Readers can try both.
#------------------------------------------------------#
k = pk
# k = pk + memory
#------------------------------------------------------#
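
#------------------------------------------------------#
# Note (added for clarity): the block below re-implements the
# q/k half of PyTorch's F.multi_head_attention_forward for
# embed_dim=256, num_heads=8 (head_dim=32): project q and k with
# the packed in_proj slices, scale q by head_dim**-0.5, split
# into heads, take the dot product, and softmax over the 850
# source positions. Only the attention weights are needed here,
# so the v projection and the output projection are skipped.
#------------------------------------------------------#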
# project q and k through the linear layers; code follows nn.MultiheadAttention
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
q = linear(q, _w, _b)

_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
k = linear(k, _w, _b)

# nn.MultiheadAttention scales by head_dim**-0.5 (256 // 8 = 32), not embed_dim**-0.5
scaling = float(32) ** -0.5
q = q * scaling
q = q.contiguous().view(100, 8, 32).transpose(0, 1)     # [8,100,32]
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)      # [8,850,32]
attn_output_weights = torch.bmm(q, k.transpose(1, 2))   # [8,100,850]

attn_output_weights = attn_output_weights.view(1, 8, 100, 850)
attn_output_weights = softmax(attn_output_weights, dim=-1)

# keep the per-head maps for later visualization, then average over the 8 heads
attn_every_heads = attn_output_weights                # [1,8,100,850]
attn_output_weights = attn_output_weights.mean(dim=1) # [1,100,850]

#-----------#
# Visualization
#-----------#
# get the feature map shape
h, w = conv_features['0'].tensors.shape[-2:]

# 10 rows per detection: decoder attention, box + class, then the 8 per-head maps
fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28))
colors = COLORS * 100

for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
    # the decoder cross-attention weights for this query
    ax = ax_i[0]
    ax.imshow(dec_attn_weights[0, idx].view(h, w))
    ax.axis('off')
    ax.set_title(f'query id: {idx.item()}', fontsize=30)
    # the predicted box and class
    ax = ax_i[1]
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='blue', linewidth=3))
    ax.axis('off')
    ax.set_title(CLASSES[probas[idx].argmax()], fontsize=30)
    # the position heat map of each of the 8 heads
    for head in range(2, 2 + 8):
        ax = ax_i[head]
        ax.imshow(attn_every_heads[0, head - 2, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'head:{head - 2}', fontsize=30)
fig.tight_layout()  # automatically adjust the subplots to fill the figure
plt.show()
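
#------------------------------------------------------#
# Optional sanity check (a sketch, not part of the original
# script): with `k = pk + memory` selected above, the manually
# computed, head-averaged weights should reproduce the
# cross-attention weights captured by the hook, since DETR's
# decoder cross-attention uses q = cq + oq and k = memory + pk.
# Uncomment together with `k = pk + memory`:
# print(torch.allclose(attn_output_weights, dec_attn_weights, atol=1e-5))
#------------------------------------------------------#
--------------------------------------------------------------------------------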