├── result.png
├── methodoverview.jpg
├── model
│   ├── MobileNet_v2.pth
│   ├── backbone_vgg16.py
│   ├── MobileNet_v2.py
│   ├── model_vgg16.py
│   └── model_MobileNet_2.py
├── demo.py
├── dataset
│   └── dataset.py
├── tool
│   ├── bbox_to_point.py
│   ├── gwhdcoco_count.py
│   ├── split_dataset.py
│   └── mergecoco.py
├── README.md
├── Visualizations
│   ├── grad_cam.py
│   └── utils.py
├── test.py
└── train.py

/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/CSNet/HEAD/result.png
--------------------------------------------------------------------------------
/methodoverview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/CSNet/HEAD/methodoverview.jpg
--------------------------------------------------------------------------------
/model/MobileNet_v2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/CSNet/HEAD/model/MobileNet_v2.pth
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from torchvision import transforms

from model.model_vgg16 import *


# Define the inference device
device = torch.device("cuda")

# Build the network and load the trained weights
model = Multi_Granularity(device=device)
model.load_state_dict(torch.load("best_weight/CSNet_best_weight.pth"))  # path to the weight file
model.to(device)
model.eval()

img_path = "D:/python_project/practice/count_w/MLP/test/0addc041-a6b6-4643-8e10-8b9c51e932f1.png"  # path to the input image
image = Image.open(img_path)
image = image.convert("RGB")
transform = transforms.ToTensor()
image = transform(image) * 255
Resize = transforms.Resize([512, 512])
image = Resize(image)
image = torch.reshape(image, (1, 3, 512, 512)).to(device)

with torch.no_grad():
    output = model(image)
print(output)
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms


class Countgwhd(Dataset):
    def __init__(self, img_path, ann_path, resize_shape):
        self.img_path = img_path
        self.ann_path = pd.read_csv(ann_path)
        self.shape = resize_shape
        self.transform = transforms.Resize([self.shape, self.shape])

    def __getitem__(self, idx):
        # Join the image directory with the file name from the annotation CSV
        img_path = os.path.join(self.img_path, self.ann_path.iloc[idx, 0])
        # Load as an RGB tensor scaled to [0, 255]
        image = Image.open(img_path)
        image = image.convert("RGB")
        totensor = transforms.ToTensor()
        image = totensor(image) * 255
        label = self.ann_path.iloc[idx, 1]
        image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.ann_path)
--------------------------------------------------------------------------------
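`Countgwhd` returns one `(image, count)` pair per row of the annotation CSV, so it plugs directly into a PyTorch `DataLoader`. A minimal sketch, assuming the `gwhd/` directory layout and the `image_name,count` CSV columns that `train.py` and `tool/split_dataset.py` imply:

```python
from torch.utils.data import DataLoader

from dataset.dataset import Countgwhd

# Assumed layout: gwhd/train/ holds the images and
# gwhd/annotations/train.csv holds image_name,count rows.
train_dataset = Countgwhd(img_path="gwhd/train/",
                          ann_path="gwhd/annotations/train.csv",
                          resize_shape=512)
loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

images, counts = next(iter(loader))
print(images.shape)  # torch.Size([16, 3, 512, 512])
```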
/tool/bbox_to_point.py:
--------------------------------------------------------------------------------
import json
import os

import numpy as np
import scipy.io as sio

path = "gwhd_2021/val"  # image directory of the dataset
files = os.listdir(path)  # list of the file names in the directory
fp = open("gwhd_2021/annotations/" + "val.json", "r")  # load the COCO annotation file
json_data = json.load(fp)
image_id = 0

bbox = []
for name in files:
    points = []
    # Find the image id that matches the current file name
    for image in json_data["images"]:
        if image["file_name"] == name:
            image_id = image["id"]
            break
    # Convert every box of this image into its centre point
    for annotation in json_data["annotations"]:
        if annotation["image_id"] == image_id:
            bbox = annotation["bbox"]
            point = []
            point.append(bbox[0] + bbox[2] // 2)
            point.append(bbox[1] + bbox[3] // 2)
            points.append(point)

    data_inner = {"location": points, "number": len(points)}
    print(len(points))
    image_info = np.zeros((1,), dtype=object)
    image_info[0] = data_inner

    mat_name = name.split(".")[0] + '.mat'
    sio.savemat(os.path.join("gwhd_2021/annotations", "val_density", mat_name), {'image_info': image_info})
    print("done")
--------------------------------------------------------------------------------
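The conversion relies on the COCO box convention `[x, y, w, h]` (top-left corner plus width and height), so each wheat head's point label is its box centre. A quick sanity check of the formula:

```python
bbox = [100, 40, 30, 50]           # x, y, w, h in pixels
center = (bbox[0] + bbox[2] // 2,  # 100 + 15 -> 115
          bbox[1] + bbox[3] // 2)  # 40 + 25 -> 65
assert center == (115, 65)
```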
/README.md:
--------------------------------------------------------------------------------
# CSNet

**Count-Supervised Network (CSNet)** counts wheat ears using only count-level supervision, without box or point annotations.

paper: CSNet: A Count-supervised Network via Multiscale MLP-Mixer for Wheat Ear Counting

## The Overview of CSNet
![](methodoverview.jpg)


## About Data
We use the Global Wheat Head Detection dataset ([GWHD](http://www.global-wheat.com/gwhd.html)) for training, where the count labels are obtained by summing the number of target boxes per image.

## Code Structure
`train.py` To train the model.

`test.py` To test the model.

`demo.py` To predict a single image.

`model/model_vgg16.py` The structure of the network with a VGG16 backbone.

`model/model_MobileNet_2.py` The structure of the network with a MobileNetV2 backbone.

`model/backbone_vgg16.py` The first ten convolutional layers of VGG16.

`tool/gwhdcoco_count.py` To convert the COCO labels in the GWHD dataset to count labels (see the aggregation sketch after this file's listing below).

`tool/mergecoco.py` To combine multiple COCO JSON label files into one.

## Training
```shell
python train.py --batch_size=16 --epoch=1000 --lr=1e-4 --device="cuda"
```
## Testing
```shell
python test.py --data_root "/home/liyaoxi/data/gwhd/" --device="cuda"
```

## Model Weights


## Result
![](result.png)
--------------------------------------------------------------------------------
/model/backbone_vgg16.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torchsummary import summary


class Backbone_VGG16(nn.Module):
    def __init__(self):
        super(Backbone_VGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.features(x)
        return x


if __name__ == '__main__':
    vgg16 = Backbone_VGG16()
    input = torch.ones((2, 3, 512, 512))
    summary(vgg16, (3, 512, 512))
    output = vgg16(input)
    print(output.shape)
--------------------------------------------------------------------------------
/tool/gwhdcoco_count.py:
--------------------------------------------------------------------------------
import json
import os
import csv


# Load a JSON file
def load_json(json_path):
    with open(json_path) as f:
        file_json = json.load(f)
    return file_json


# Convert the box-annotated dataset into a point/count dataset
def transform_count(file_path):
    files = next(os.walk(file_path))[2]
    # Loop over every JSON file
    for i in range(len(files)):
        json_path = file_path + str(files[i])
        json_file = load_json(json_path)
        images_len = len(json_file['images'])
        ann_len = len(json_file['annotations'])

        # Loop over every image
        for j in range(images_len):
            image_name = json_file['images'][j]['file_name']
            image_id = json_file['images'][j]['id']
            # Create one CSV of box-centre points per image
            f = open('annotation/train/{}.csv'.format(image_name.split('.')[0]), 'w', encoding='UTF-8', newline='')
            writer = csv.writer(f)
            dot = ("x", "y")
            writer.writerow(dot)
            for k in range(ann_len):
                if json_file['annotations'][k]['image_id'] == image_id:
                    spot = json_file['annotations'][k]['bbox']
                    dot = (spot[0] + 0.5 * spot[2], spot[1] + 0.5 * spot[3])
                    writer.writerow(dot)
            f.close()
            print('Wrote image {}'.format(j + 1))
        print('Finished file {}'.format(files[i]))
    print('All done!')


file_path = "annotation/"
if __name__ == "__main__":
    transform_count(file_path)
--------------------------------------------------------------------------------
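The script above writes one point CSV per image, while `Countgwhd` expects a single CSV of `image_name,count` rows; the count label for an image is simply the number of point rows in its file. A hedged sketch of that aggregation step (this helper is not part of the repo, and the `.png` extension and output path are assumptions):

```python
import csv
import os

# Hypothetical helper: collapse per-image point CSVs into one
# image_name,count file of the kind train.py reads.
def points_to_counts(point_dir, out_csv):
    with open(out_csv, "w", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(["image_name", "count"])
        for name in sorted(os.listdir(point_dir)):
            with open(os.path.join(point_dir, name)) as f:
                n_points = sum(1 for _ in f) - 1  # drop the x,y header row
            writer.writerow([name.replace(".csv", ".png"), n_points])


points_to_counts("annotation/train/", "annotations/train.csv")
```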
/Visualizations/grad_cam.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

from Visualizations.utils import GradCAM, show_cam_on_image
from model.model_vgg16 import *


# Run from the repo root, e.g. python -m Visualizations.grad_cam
def main():
    # Define the inference device
    device = torch.device("cuda")
    # Build the network and load the trained weights
    model = Multi_Granularity(device=device)
    model.load_state_dict(torch.load("best_weight/CSNet_best_weight.pth"))
    model.to(device)
    target_layers = [model.backbone.features[-2]]

    data_transform = transforms.Compose([transforms.ToTensor()])
    Resize = transforms.Resize([512, 512])
    # Load the image
    img_path = "/count_w/MLP/test/figure1_7.png"
    img = Image.open(img_path).convert('RGB')
    img = Resize(img)

    # [N, C, H, W]
    img_tensor = data_transform(img) * 255
    input_tensor = torch.reshape(img_tensor, (1, 3, 512, 512))  # expand the batch dimension

    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
    target = torch.reshape(torch.tensor([4.]), (1, 1)).to(device)  # fill in the ground-truth count

    grayscale_cam = cam(input_tensor=input_tensor, target=target)
    img = np.array(img, dtype=np.uint8)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(img.astype(dtype=np.float32) / 255.,
                                      grayscale_cam,
                                      use_rgb=True)
    plt.imshow(visualization)
    plt.axis("off")
    # plt.savefig("figure1_7_cam.png", bbox_inches="tight", pad_inches=0)
    plt.show()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import math

import torch
from torch.utils.data import DataLoader

from dataset.dataset import Countgwhd
from model.model_vgg16 import *


def main():
    print("start testing")
    parser = argparse.ArgumentParser('Set parameters for test', add_help=False)
    parser.add_argument('--device', default="cuda", type=str)
    parser.add_argument("--img_size", default=512, type=int)
    parser.add_argument("--data_root", default='/home/liyaoxi/mmdetection/data/gwhd/', type=str)
    args = parser.parse_args()

    # Test parameters
    device = args.device
    img_size = args.img_size

    data_root = args.data_root
    test_root = data_root + 'test/'
    test_ann = data_root + 'annotations/test.csv'

    test_dataset = Countgwhd(img_path=test_root, ann_path=test_ann, resize_shape=img_size)

    # Build the network
    model = Multi_Granularity(device=device)
    # NOTE: load trained weights here (e.g. model.load_state_dict(...)) before evaluating
    model.to(device)
    test(test_dataset, model, device)


def test(test_dataset, model, device):
    # Load the dataset
    val_dataloader = DataLoader(test_dataset, batch_size=1)
    model.eval()

    mae = 0.0
    mse = 0.0
    i = 0
    for data in val_dataloader:
        i = i + 1
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            output = model(imgs)
            count = torch.sum(output).item()

        gt_count = torch.sum(targets).item()
        mae += abs(gt_count - count)
        mse += abs(gt_count - count) * abs(gt_count - count)

        print("ground truth: {} \t prediction: {}".format(gt_count, count))

    mae = mae * 1.0 / i
    mse = math.sqrt(mse / i)
    print("Test result: MAE: {} \t MSE: {}".format(mae, mse))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
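`test.py` and the validation loop in `train.py` report the same two metrics over the N evaluated images; note that, despite the variable name `mse`, the printed value is the root of the mean squared error:

```latex
\mathrm{MAE} = \frac{1}{N}\sum_{i=1}^{N}\lvert c_i - \hat{c}_i\rvert, \qquad
\mathrm{MSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(c_i - \hat{c}_i\right)^2}
```

where c_i is the ground-truth count and ĉ_i the predicted count for image i.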
/tool/split_dataset.py:
--------------------------------------------------------------------------------
import os
import random
import copy
import csv
import shutil


def subset(alist, idxs):
    '''
    Take the subset of list alist given by the indices in idxs.
    alist: list
    idxs: list
    '''
    sub_list = []
    for idx in idxs:
        sub_list.append(alist[idx])

    return sub_list


def split_list(alist, rate=0.2, shuffle=True):
    '''
    Split off a sublist from alist.
    shuffle: whether to shuffle before splitting, default True
    rate: the fraction of elements that goes into the returned sublist
    '''

    index = list(range(len(alist)))  # keep the indices

    # Optionally shuffle the index list
    if shuffle:
        random.shuffle(index)

    elem_num = int(len(alist) * rate)  # number of elements in the split
    sub_lists = []

    for i in range(elem_num):
        sub_lists.append(copy.deepcopy(alist[index[i]]))

    return sub_lists


def SplitDataset(path):
    file_names = []
    # Read the names of all files in the directory
    for file_name in os.listdir(path):
        file_names.append(file_name)
    # Randomly pick the test split
    sub_list = split_list(file_names)
    # Rewrite the dataset annotations
    all_reader = csv.reader(file_all)
    train_writer = csv.writer(file_train)
    test_writer = csv.writer(file_test)

    i = 0
    for line in all_reader:
        i = i + 1
        print(line)
        if line[0] == "image_name":
            train_writer.writerow(line)
            test_writer.writerow(line)
        else:
            if line[0] in sub_list:
                test_writer.writerow(line)
                shutil.copyfile(os.path.join(path, line[0]), os.path.join(path_test, line[0]))
            else:
                train_writer.writerow(line)
                shutil.copyfile(os.path.join(path, line[0]), os.path.join(path_train, line[0]))
        print("done {}".format(i))
    file_all.close()
    file_train.close()
    file_test.close()
    print("All done")


path = '../gwhd'
path_train = '../gwhd/train'
path_test = '../gwhd/test'
file_path = "../gwhd/"
file_all = open(file_path + "all.csv", "r", encoding="utf-8-sig")
file_train = open(file_path + "train.csv", "w", newline="")
file_test = open(file_path + "test.csv", "w", newline="")
if __name__ == "__main__":
    SplitDataset(path)
--------------------------------------------------------------------------------
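`split_list` shuffles with Python's global RNG, so two runs of `SplitDataset` produce two different train/test splits. If a reproducible split is wanted, seeding the RNG first is enough; a minimal sketch (the seed value is arbitrary):

```python
import random

random.seed(42)  # fix the shuffle order
names = ["img_{}.png".format(i) for i in range(10)]
test_names = split_list(names, rate=0.2)
print(test_names)  # the same 2 names on every run
```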
/tool/mergecoco.py:
--------------------------------------------------------------------------------
import os
import json


def load_json(filenamejson):
    with open(filenamejson) as f:
        raw_data = json.load(f)
    return raw_data


def merge_coco_json(file_path):
    file_count = 0
    files = next(os.walk(file_path))[2]
    for x in range(len(files)):
        # Count the files
        file_count = file_count + 1
        # Build the file path
        filenamejson = file_path + str(files[x])
        # Read the file
        if x == 0:
            # The first file becomes the root
            root_data = load_json(filenamejson)
        else:
            raw_data = load_json(filenamejson)

            # Append the images, offsetting their ids past the root's
            root_images_len = len(root_data['images'])
            raw_images_len = len(raw_data['images'])
            for i in range(raw_images_len):
                raw_data['images'][i]['id'] = int(raw_data['images'][i]['id']) + int(root_images_len)
            root_data['images'].extend(raw_data['images'])

            # Append the annotations, offsetting both their own ids and their image ids
            root_annotations_len = len(root_data['annotations'])
            raw_annotations_len = len(raw_data['annotations'])
            for j in range(raw_annotations_len):
                raw_data['annotations'][j]['id'] = int(raw_data['annotations'][j]['id']) + int(root_annotations_len)
                raw_data['annotations'][j]['image_id'] = int(raw_data['annotations'][j]['image_id']) + int(
                    root_images_len)
            root_data['annotations'].extend(raw_data['annotations'])

    # Deduplicate the categories
    temp = []
    for m in root_data["categories"]:
        if m not in temp:
            temp.append(m)
    root_data["categories"] = temp
    print("Processed {0} JSON files".format(file_count))
    print("Found {0} categories".format(len(root_data["categories"])))

    # Write out the merged file
    json_str = json.dumps(root_data)
    with open('merge.json', 'w') as json_file:
        json_file.write(json_str)

    print("Done!")


file_path = "C:\\Users\\liyaoxi\\Desktop\\data\\"  # directory of the files to merge
if __name__ == "__main__":
    merge_coco_json(file_path)
--------------------------------------------------------------------------------
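The id offsets are what keep the merged file consistent: each incoming image id is shifted by the number of images already merged, and each annotation keeps pointing at its shifted image. A worked case, assuming ids run densely from 0 within each file (which the `len()`-based offsets require):

```python
# Root already holds 3 images (ids 0-2) and 10 annotations (ids 0-9).
# An incoming annotation {"id": 4, "image_id": 1} becomes:
new_ann_id = 4 + 10    # -> 14, unique among the merged annotation ids
new_image_id = 1 + 3   # -> 4, matching its image's shifted id
```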
/model/MobileNet_v2.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
from torchsummary import summary


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def make_divisible(x, divisible_by=8):
    import numpy as np
    return int(np.ceil(x * 1. / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        # last_channel = 1280
        last_channel = 512  # reduced so the output matches the 512-channel VGG16 features
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building the last layer
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # the classifier head is dropped; the network is used as a feature extractor
        # self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # x = x.mean(3).mean(2)
        # x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = MobileNetV2()
    # net.load_state_dict(torch.load("MobileNet_v2.pth"))
    summary(net, (3, 512, 512))
--------------------------------------------------------------------------------
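`make_divisible` rounds a scaled channel count up to the nearest multiple of 8, which keeps width-multiplied layers hardware friendly. A few worked values:

```python
make_divisible(32 * 1.0)   # 32, already a multiple of 8
make_divisible(96 * 1.4)   # 134.4 rounds up to 136
make_divisible(24 * 0.75)  # 18 rounds up to 24
```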
/model/model_vgg16.py:
--------------------------------------------------------------------------------
from torch import nn
from fightingcv_attention.attention.CBAM import CBAMBlock

# Run as a module from the repo root (python -m model.model_vgg16) so this import resolves
from model.backbone_vgg16 import *


class mlp_block(nn.Module):
    def __init__(self, in_channels, mlp_dim, drop_ratio=0.):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_channels, mlp_dim),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(mlp_dim, in_channels),
            nn.Dropout(drop_ratio)
        )

    def forward(self, x):
        x = self.block(x)
        return x


class mlp_layer(nn.Module):
    def __init__(self, seq_length_s, hidden_size_c, dc, ds, drop=0.):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size_c)
        # The two blocks act on the rows and the columns of the S x C input
        # respectively, so their in_channels differ.
        self.token_mixing = mlp_block(in_channels=seq_length_s, mlp_dim=int(dc * seq_length_s), drop_ratio=drop)
        self.channel_mixing = mlp_block(in_channels=hidden_size_c, mlp_dim=int(ds * hidden_size_c), drop_ratio=drop)

    def forward(self, x):
        x1 = self.ln(x)
        x2 = x1.transpose(1, 2)  # transpose to the token dimension
        x3 = self.token_mixing(x2)
        x4 = x3.transpose(1, 2)

        y1 = x + x4  # skip-connection
        y2 = self.ln(y1)
        y3 = self.channel_mixing(y2)
        y = y1 + y3

        return y


# Parameters follow Table 1 of the MLP-Mixer paper
class mlp_mixer(nn.Module):
    def __init__(self,
                 in_channels=3,
                 layer_num=4,
                 patch_size=32,
                 hidden_size_c=768,
                 seq_length_s=49,
                 dc=0.5,
                 ds=4,
                 drop=0.
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.seq_length_s = seq_length_s
        self.dc = dc
        self.ds = ds

        self.ln = nn.LayerNorm(self.hidden_size_c)

        # Patch splitting and embedding, implemented as one strided convolution
        self.proj = nn.Conv2d(self.in_channels, self.hidden_size_c, kernel_size=self.patch_size,
                              stride=self.patch_size)

        # Stack several mixer layers
        # NOTE: ds and dc are passed in swapped positions relative to the
        # mlp_layer signature; kept as in the original implementation so
        # released weights still load.
        self.mixer_layer = nn.ModuleList([])
        for _ in range(self.layer_num):
            self.mixer_layer.append(mlp_layer(seq_length_s, hidden_size_c, ds, dc, drop))

    # Forward pass
    def forward(self, x):
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        for mixer_layer in self.mixer_layer:
            x = mixer_layer(x)
        x = self.ln(x)
        return x


class Multi_Granularity(nn.Module):
    def __init__(self,
                 layer_num=4,
                 hidden_size_c=256,
                 img_size=32,
                 device="cuda"
                 ):
        super().__init__()
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.img_size = img_size
        self.device = device
        self.mlp_mixer = nn.ModuleList([])
        self.mlp_layer = nn.ModuleList([])
        self.patch = [16, 8, 4, 32, 64]
        self.seq = [16, 64, 256, 256]
        self.seq_all = 336
        # self.conv = nn.Conv2d(512, 128, 1)

        for i in range(3):
            self.mlp_mixer.append(mlp_mixer(in_channels=512, layer_num=self.layer_num, patch_size=self.patch[i],
                                            hidden_size_c=self.hidden_size_c, seq_length_s=self.seq[i]))

        for _ in range(self.layer_num):
            self.mlp_layer.append(mlp_layer(hidden_size_c=self.hidden_size_c, seq_length_s=self.seq_all, dc=0.5, ds=4))

        self.CounterHead = nn.Sequential(
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(self.hidden_size_c * self.seq_all, 512),
            # nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
            # nn.Linear(10, 1),
            nn.AvgPool1d((10)),
            nn.ReLU()
        )
        # self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.backbone = Backbone_VGG16()
        self.backbone.to(device=self.device)
        self.backbone.load_state_dict(torch.load("VGG16_10.pth", map_location=self.device), strict=False)
        # for param in self.backbone.parameters():
        #     param.requires_grad = False

        self.cbam = CBAMBlock(channel=512, reduction=16, kernel_size=7)
        self.cbam.to(device=self.device)

    def forward(self, x):
        # Extract VGG16 features, then apply CBAM attention
        x1 = self.backbone(x)
        x1 = self.cbam(x1)
        # x1 = self.conv(x1)

        x_coarse = self.mlp_mixer[0](x1)
        x_middle = self.mlp_mixer[1](x1)
        x_fine = self.mlp_mixer[2](x1)

        x_all = torch.cat([x_coarse, x_middle, x_fine], 1)
        for mlp_layer in self.mlp_layer:
            x_all = mlp_layer(x_all)
        x_cout = self.CounterHead(x_all)
        return x_cout


if __name__ == '__main__':
    neck = Multi_Granularity(device="cpu")
    # summary(neck, (3, 512, 512))
    input = torch.ones((16, 3, 512, 512))
    imgs = torch.ones((16, 3, 512, 512))
    output = neck(input)
    # print(output)
--------------------------------------------------------------------------------
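The hard-coded `patch`, `seq` and `seq_all` values are consistent with a 512×512 input: the truncated VGG16 contains three max-pools, so the mixers receive a 64×64×512 feature map, and a patch size p yields (64/p)² tokens:

```latex
S_p = \left(\frac{64}{p}\right)^{2}:\quad S_{16}=16,\; S_{8}=64,\; S_{4}=256,
\qquad S_{\text{all}} = 16 + 64 + 256 = 336
```

Concatenating the three granularities along the token dimension therefore gives the 336-token sequence consumed by the shared `mlp_layer` stack and the counter head.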
/model/model_MobileNet_2.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from fightingcv_attention.attention.CBAM import CBAMBlock

# Run as a module from the repo root (python -m model.model_MobileNet_2) so this import resolves
from model.MobileNet_v2 import *


class mlp_block(nn.Module):
    def __init__(self, in_channels, mlp_dim, drop_ratio=0.):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_channels, mlp_dim),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(mlp_dim, in_channels),
            nn.Dropout(drop_ratio)
        )

    def forward(self, x):
        x = self.block(x)
        return x


class mlp_layer(nn.Module):
    def __init__(self, seq_length_s, hidden_size_c, dc, ds, drop=0.):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size_c)
        # The two blocks act on the rows and the columns of the S x C input
        # respectively, so their in_channels differ.
        self.token_mixing = mlp_block(in_channels=seq_length_s, mlp_dim=int(dc * seq_length_s), drop_ratio=drop)
        self.channel_mixing = mlp_block(in_channels=hidden_size_c, mlp_dim=int(ds * hidden_size_c), drop_ratio=drop)

    def forward(self, x):
        x1 = self.ln(x)
        x2 = x1.transpose(1, 2)  # transpose to the token dimension
        x3 = self.token_mixing(x2)
        x4 = x3.transpose(1, 2)

        y1 = x + x4  # skip-connection
        y2 = self.ln(y1)
        y3 = self.channel_mixing(y2)
        y = y1 + y3

        return y


# Parameters follow Table 1 of the MLP-Mixer paper
class mlp_mixer(nn.Module):
    def __init__(self,
                 in_channels=3,
                 layer_num=4,
                 patch_size=32,
                 hidden_size_c=768,
                 seq_length_s=49,
                 dc=0.5,
                 ds=4,
                 drop=0.
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.seq_length_s = seq_length_s
        self.dc = dc
        self.ds = ds

        self.ln = nn.LayerNorm(self.hidden_size_c)

        # Patch splitting and embedding, implemented as one strided convolution
        self.proj = nn.Conv2d(self.in_channels, self.hidden_size_c, kernel_size=self.patch_size,
                              stride=self.patch_size)

        # Stack several mixer layers (ds and dc swapped as in model_vgg16.py)
        self.mixer_layer = nn.ModuleList([])
        for _ in range(self.layer_num):
            self.mixer_layer.append(mlp_layer(seq_length_s, hidden_size_c, ds, dc, drop))

    # Forward pass
    def forward(self, x):
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        for mixer_layer in self.mixer_layer:
            x = mixer_layer(x)
        x = self.ln(x)
        return x


class Multi_Granularity(nn.Module):
    def __init__(self,
                 layer_num=4,
                 hidden_size_c=256,
                 img_size=32,
                 device="cuda"
                 ):
        super().__init__()
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.img_size = img_size
        self.device = device
        self.mlp_mixer = nn.ModuleList([])
        self.mlp_layer = nn.ModuleList([])
        self.patch = [16, 8, 4, 32, 64]
        self.seq = [16, 64, 256, 256]
        self.seq_all = 336
        # self.conv = nn.Conv2d(512, 128, 1)

        for i in range(3):
            self.mlp_mixer.append(mlp_mixer(in_channels=512, layer_num=self.layer_num, patch_size=self.patch[i],
                                            hidden_size_c=self.hidden_size_c, seq_length_s=self.seq[i]))

        for _ in range(self.layer_num):
            self.mlp_layer.append(mlp_layer(hidden_size_c=self.hidden_size_c, seq_length_s=self.seq_all, dc=0.5, ds=4))

        self.CounterHead = nn.Sequential(
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(self.hidden_size_c * self.seq_all, 512),
            # nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
            # nn.Linear(10, 1),
            nn.AvgPool1d((10)),
            nn.ReLU()
        )
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.backbone = MobileNetV2()
        self.backbone.to(device=self.device)
        self.backbone.load_state_dict(torch.load("model/MobileNet_v2.pth", map_location=self.device), strict=False)
        # for param in self.backbone.parameters():
        #     param.requires_grad = False

        self.cbam = CBAMBlock(channel=512, reduction=16, kernel_size=7)
        self.cbam.to(device=self.device)

    def forward(self, x):
        # Extract MobileNetV2 features, apply CBAM, then upsample to 64x64
        x1 = self.backbone(x)
        x1 = self.cbam(x1)
        x1 = self.upsample(x1)
        # x1 = self.conv(x1)

        x_coarse = self.mlp_mixer[0](x1)
        x_middle = self.mlp_mixer[1](x1)
        x_fine = self.mlp_mixer[2](x1)

        x_all = torch.cat([x_coarse, x_middle, x_fine], 1)
        for mlp_layer in self.mlp_layer:
            x_all = mlp_layer(x_all)
        x_cout = self.CounterHead(x_all)
        return x_cout


if __name__ == '__main__':
    neck = Multi_Granularity(device="cpu")
    summary(neck, (3, 512, 512))
    input = torch.ones((16, 3, 512, 512))
    imgs = torch.ones((16, 3, 512, 512))
    output = neck(input)
    print(output)
--------------------------------------------------------------------------------
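Compared with the VGG16 variant, the forward pass here has one extra `upsample` step. The truncated MobileNetV2 above (a stride-2 stem plus three stride-2 stages) downsamples by 16, so a 512×512 input gives a 32×32×512 map; nearest-neighbour upsampling by 2 restores the 64×64 grid, and the patch sizes 16/8/4 again produce 16 + 64 + 256 = 336 tokens.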
/train.py:
--------------------------------------------------------------------------------
import argparse
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import math

from dataset.dataset import Countgwhd
from model.model_vgg16 import *


workers = 32
save_file = "weight_best.pth"
save_summary = "gwhd_2020"


def main():
    print("start training")
    parser = argparse.ArgumentParser('Set parameters for training', add_help=False)
    parser.add_argument('--device', default="cuda", type=str)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument("--epoch", default=1000, type=int)
    parser.add_argument("--img_size", default=512, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    args = parser.parse_args()

    # Training parameters
    device = args.device
    learning_rate = args.lr
    # Network hyperparameters
    batch_size = args.batch_size
    epoch = args.epoch
    img_size = args.img_size

    # Dataset paths
    data_root = '/home/liyaoxi/data/gwhd/gwhd_2021/'
    train_root = data_root + 'train/'
    train_ann = data_root + 'annotations/train.csv'
    val_root = data_root + 'val/'
    val_ann = data_root + 'annotations/val.csv'
    test_root = data_root + 'test/'
    test_ann = data_root + 'annotations/test.csv'

    # Build the datasets
    train_dataset = Countgwhd(img_path=train_root, ann_path=train_ann, resize_shape=img_size)
    val_dataset = Countgwhd(img_path=val_root, ann_path=val_ann, resize_shape=img_size)
    test_dataset = Countgwhd(img_path=test_root, ann_path=test_ann, resize_shape=img_size)

    # Build the network
    model = Multi_Granularity(device=device)
    model.to(device)

    # Loss function: L1 on the predicted counts, summed over the batch
    loss_fn = nn.L1Loss(reduction="sum")
    loss_fn = loss_fn.to(device)

    # Optimizer and schedule
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800], gamma=0.1)

    torch.set_num_threads(workers)
    writer = SummaryWriter("log/" + save_summary)
    best = 10  # save a checkpoint only once the validation MAE drops below this
    val_epoch = 0
    # Training loop
    for i in range(epoch):
        start1 = time.time()
        print("----------epoch: {}, lr: {}----------".format(i + 1, optimizer.param_groups[0]['lr']))
        loss_one_epoch = train(train_dataset, model, loss_fn, optimizer, lr_scheduler, batch_size, device)
        writer.add_scalar("train_loss", loss_one_epoch, i)
        end1 = time.time()
        print("Time for this epoch: {} min \n\n".format((end1 - start1) / 60))

        # Validate every 5 epochs once the warm-up is past
        if (i + 1) % 5 == 0 and i >= 9:
            val_epoch = val_epoch + 1
            start2 = time.time()
            print("----------start validation----------")
            prec = val(val_dataset, model, device)
            if prec < best:
                best = prec
                torch.save(model.state_dict(), save_file)
            end2 = time.time()
            print("Validation time: {} min".format((end2 - start2) / 60))
            print("Best MAE so far: {}".format(best))
            writer.add_scalar("val_mae", prec, val_epoch)

    print("\n----------start testing----------")
    test(test_dataset, model, device, best)
    writer.close()


def train(train_dataset, model, loss_fn, optimizer, lr_scheduler, batch_size, device):
    # Load the dataset
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                                  num_workers=workers)
    model.train()
    loss_ave = 0
    print_freq = 20
    data_num = 0
    for data in train_dataloader:
        data_num = data_num + 1
        imgs, targets = data

        imgs = imgs.to(device)
        targets = targets.float().to(device)

        outputs = model(imgs)
        outputs = torch.reshape(outputs, [-1])

        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ave = loss_ave + loss.item()
        if data_num % print_freq == 0:
            print("---loss: {}---".format(loss.item()))
    print("----------average loss for this epoch: {} ----------\n".format(loss_ave / data_num))
    lr_scheduler.step()
    return loss_ave / data_num


def val(val_dataset, model, device):
    # Load the dataset
    val_dataloader = DataLoader(val_dataset, batch_size=1)
    model.eval()

    mae = 0.0
    mse = 0.0
    i = 0

    for data in val_dataloader:
        i = i + 1
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            output = model(imgs)
            count = torch.sum(output).item()

        gt_count = torch.sum(targets).item()
        mae += abs(gt_count - count)
        mse += abs(gt_count - count) * abs(gt_count - count)

        if i % 15 == 0:
            print("ground truth: {} \t prediction: {}".format(gt_count, count))

    mae = mae * 1.0 / i
    mse = math.sqrt(mse / i)
    print("Validation result: MAE: {} \t MSE: {}".format(mae, mse))

    return mae


def test(test_dataset, model, device, best):
    # Reload the best checkpoint if one was saved
    if best < 10:
        model.load_state_dict(torch.load(save_file))
    val_dataloader = DataLoader(test_dataset, batch_size=1)
    model.eval()

    mae = 0.0
    mse = 0.0
    i = 0
    for data in val_dataloader:
        i = i + 1
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            output = model(imgs)
            count = torch.sum(output).item()

        gt_count = torch.sum(targets).item()
        mae += abs(gt_count - count)
        mse += abs(gt_count - count) * abs(gt_count - count)

        print("ground truth: {} \t prediction: {}".format(gt_count, count))

    mae = mae * 1.0 / i
    mse = math.sqrt(mse / i)
    print("Test result: MAE: {} \t MSE: {}".format(mae, mse))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
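The only supervision in this loop is the count itself: the objective is the batch-summed L1 distance between predicted and ground-truth counts,

```latex
\mathcal{L} = \sum_{i \in \mathcal{B}} \lvert \hat{c}_i - c_i \rvert
```

which is what `nn.L1Loss(reduction="sum")` computes over the flattened model outputs.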
/Visualizations/utils.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from torch import nn


class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(
                    self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_'
                                     'full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(
                        self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients are computed in reverse order
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()
class GradCAM:
    def __init__(self,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    """ Get a vector of weights for every channel in the target layer.
        Methods that return weights channels,
        will typically need to only implement this function. """

    @staticmethod
    def get_cam_weights(grads):
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target):
        # The model is a count regressor, so the CAM target is the L1 count loss
        loss_fn = nn.L1Loss(reduction="sum")
        loss = loss_fn(output, target)
        return loss

    def get_cam_image(self, activations, grads):
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)

        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer

        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            cam[cam < 0] = 0  # works like mute the min-max scale in the function of scale_cam_image
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        result = []
        for img in cam:
            img = img - np.min(img)
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target=None):

        if self.cuda:
            input_tensor = input_tensor.cuda()

        # Forward pass to obtain the network output (the predicted count)
        output = self.activations_and_grads(input_tensor)
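        # The L1 loss below is differentiated with respect to the target
        # layer's activations; the hooks registered above capture those
        # gradients, which get_cam_weights later averages into the
        # per-channel weights of the CAM.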
        self.model.zero_grad()
        loss = self.get_loss(output, target)
        loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(input_tensor)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        self.activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def center_crop_img(img: np.ndarray, size: int):
    h, w, c = img.shape

    if w == h == size:
        return img

    if w < h:
        ratio = size / w
        new_w = size
        new_h = int(h * ratio)
    else:
        ratio = size / h
        new_h = size
        new_w = int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    if new_w == size:
        h = (new_h - size) // 2
        img = img[h: h + size]
    else:
        w = (new_w - size) // 2
        img = img[:, w: w + size]

    return img
--------------------------------------------------------------------------------
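For reference, `show_cam_on_image` only needs a float RGB image in [0, 1] and a CAM mask of the same spatial size in [0, 1]. A self-contained sketch with random data (no CSNet weights involved):

```python
import numpy as np

from Visualizations.utils import show_cam_on_image

img = np.random.rand(512, 512, 3).astype(np.float32)  # RGB in [0, 1]
mask = np.random.rand(512, 512).astype(np.float32)    # CAM in [0, 1]
overlay = show_cam_on_image(img, mask, use_rgb=True)
print(overlay.shape, overlay.dtype)  # (512, 512, 3) uint8
```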