├── code
│   ├── common
│   │   ├── __init__.py
│   │   ├── transforms.py
│   │   ├── tools.py
│   │   └── utils.py
│   ├── models
│   │   ├── unet_model.py
│   │   ├── unet_parts.py
│   │   └── mvssnet.py
│   ├── loss.py
│   ├── inference.py
│   ├── inference_model_ensemble.py
│   ├── train.py
│   ├── th-search.py
│   └── dataset.py
├── run.sh
├── requirements.txt
├── Dockerfile
└── README.md

--------------------------------------------------------------------------------
/code/common/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
python3 code/inference_model_ensemble.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
albumentations==1.1.0
Cython==0.29.24
future==0.18.2
loguru==0.5.3
numpy==1.21.4
opencv-contrib-python==4.5.4.60
opencv-python==4.5.4.58
opencv-python-headless==4.5.5.62
Pillow==8.4.0
segmentation-models-pytorch==0.2.1
torch==1.8.1
torchvision==0.9.1
tqdm==4.62.3
matplotlib==3.4.3

--------------------------------------------------------------------------------
/code/common/transforms.py:
--------------------------------------------------------------------------------
import pdb
import numpy as np
import torch


def img_to_tensor(im, normalize=None):
    """HxWxC uint8 image -> CHW float tensor in [0, 1], optionally normalized.

    Local replacement for the old albumentations.pytorch.functional.img_to_tensor
    helper (the original import was commented out, leaving this name undefined).
    """
    tensor = torch.from_numpy(np.moveaxis(im, -1, 0).astype(np.float32) / 255.0)
    if normalize is not None:
        mean = torch.tensor(normalize["mean"]).view(-1, 1, 1)
        std = torch.tensor(normalize["std"]).view(-1, 1, 1)
        tensor = (tensor - mean) / std
    return tensor


def direct_val(imgs):
    # ImageNet mean/std, matching the pretrained encoders
    normalize = {"mean": [0.485, 0.456, 0.406],
                 "std": [0.229, 0.224, 0.225]}
    if len(imgs) != 1:
        pdb.set_trace()
    imgs = img_to_tensor(imgs[0], normalize).unsqueeze(0)
    return imgs
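
# --- Hedged usage sketch (my addition; the file name is a placeholder) ------
#   import cv2
#   img = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
#   batch = direct_val([img])  # normalized float tensor of shape [1, 3, H, W]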
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM ubuntu:20.04
LABEL maintainer="liyingxuan <913797866@qq.com>"

RUN apt-get update
RUN apt-get upgrade -y

# Install Python 3 (Ubuntu 20.04 ships Python 3.8 as python3)
RUN apt-get install -y python3 python3-pip
RUN apt-get install -y libgl1-mesa-dev
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y libglib2.0-dev

ADD . /
WORKDIR /

RUN pip3 install --upgrade pip
RUN pip3 install -r requirements.txt

CMD ["sh", "run.sh"]

--------------------------------------------------------------------------------
/code/common/tools.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
import torch.utils.data
import torchvision.transforms as transforms
from common.transforms import direct_val
import pdb
debug = 0


transform_pil = transforms.Compose([
    transforms.ToPILImage(),
])


def run_model(model, inputs):
    output = model(inputs)
    return output


def inference_single(img, model, th=0):
    model.eval()
    with torch.no_grad():
        img = img.reshape((-1, img.shape[-3], img.shape[-2], img.shape[-1]))
        img = direct_val(img)
        img = img.cuda()

        _, seg = run_model(model, img)  # the model returns an (edge, seg) tuple
        seg = torch.sigmoid(seg).detach().cpu()
        if torch.isnan(seg).any() or torch.isinf(seg).any():
            max_score = 0.0
        else:
            max_score = torch.max(seg).numpy()
        seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))]

    if len(seg) != 1:
        pdb.set_trace()
    else:
        fake_seg = seg[0]
    if th == 0:
        return fake_seg, max_score
    fake_seg = 255.0 * (fake_seg > 255 * th)
    fake_seg = fake_seg.astype(np.uint8)

    return fake_seg, max_score
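
# --- Hedged usage sketch (my addition; paths and threshold are placeholders) -
# inference_single expects an HxWxC uint8 numpy image and a model whose
# forward pass returns an (edge, seg) tuple, e.g. MVSS-Net:
#   from models.mvssnet import get_mvss
#   model = get_mvss(backbone='resnet50', pretrained_base=True, nclass=1,
#                    sobel=True, constrain=True, n_input=3).cuda()
#   mask, score = inference_single(img, model, th=0.5)  # uint8 mask, max prob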
--------------------------------------------------------------------------------
/code/models/unet_model.py:
--------------------------------------------------------------------------------
from models.unet_parts import *
import torch.nn as nn


class Unet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(Unet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = U_down(64, 128)
        self.down2 = U_down(128, 256)
        self.down3 = U_down(256, 512)
        self.down4 = U_down(512, 512)
        self.up1 = U_up(1024, 256)
        self.up2 = U_up(512, 128)
        self.up3 = U_up(256, 64)
        self.up4 = U_up(128, 64)
        self.out = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x


class Res_Unet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(Res_Unet, self).__init__()
        self.down = RU_first_down(n_channels, 32)
        self.down1 = RU_down(32, 64)
        self.down2 = RU_down(64, 128)
        self.down3 = RU_down(128, 256)
        self.down4 = RU_down(256, 256)
        self.up1 = RU_up(512, 128)
        self.up2 = RU_up(256, 64)
        self.up3 = RU_up(128, 32)
        self.up4 = RU_up(64, 32)
        self.out = outconv(32, n_classes)

    def forward(self, x):
        x1 = self.down(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x


class Ringed_Res_Unet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(Ringed_Res_Unet, self).__init__()
        self.down = RRU_first_down(n_channels, 32)
        self.down1 = RRU_down(32, 64)
        self.down2 = RRU_down(64, 128)
        self.down3 = RRU_down(128, 256)
        self.down4 = RRU_down(256, 256)
        self.up1 = RRU_up(512, 128)
        self.up2 = RRU_up(256, 64)
        self.up3 = RRU_up(128, 32)
        self.up4 = RRU_up(64, 32)
        self.out = outconv(32, n_classes)

    def forward(self, x):
        x1 = self.down(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x


class Ringed_Res_Unet_Slim(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(Ringed_Res_Unet_Slim, self).__init__()
        self.down = RRU_first_down(n_channels, 32)
        self.down1 = RRU_down(32, 32)
        self.down2 = RRU_down(32, 32)
        self.up3 = RRU_up(64, 32)
        self.up4 = RRU_up(64, 32)
        self.out = outconv(32, n_classes)

    def forward(self, x):
        x1 = self.down(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up3(x3, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tianchi - Real-Scene Tampered Image Detection Challenge

Team name: 欢乐摸鱼~

Team members: 今天也要摸鱼呀~, liyingxuan, 瑶瑶子可可爱爱

## Environment

OS: Ubuntu 20.04

GPU: 2080Ti

## Test

The code layout and environment setup follow the code standard and Docker requirements of the Tianchi semi-final.

- You can directly pull the image I pushed to Aliyun (public repository):

  registry.cn-hangzhou.aliyuncs.com/liyingxuan/tianchi_lyx:final

- Or rebuild the image yourself:

```bash
docker build -t tianchi .
```

Run (remember to mount your storage):

```bash
docker run --gpus all -v <host_data_dir>:<container_data_dir> tianchi
```

Single-model inference:

```bash
python code/inference.py --model unet --weights <path/to/checkpoint>
```

Multi-model inference (edit the checkpoint paths inside the script first):

```bash
python code/inference_model_ensemble.py
```

## Train

See "Solution" below for the details of the approach.

```bash
cd code
python train.py --model unet --work-dir ./work_dir/unet/
```

## Solution

The training set is split with five-fold cross-validation.

### baseline

- Tried mvssnet and rrunet, and finally chose a unet++ style network with efficientnet-b5 as the encoder
- Loss: Dice loss plus BCE loss
- Images are resized to 512x512
- AdamW optimizer
- The 4,000 training images are split five-fold; during training, the checkpoint with the best validation score is kept

In the preliminary round, the best single-model score was about 2350 for Unet, about 1900 for mvssnet, and about 2100 for rrunet. Switching the encoder to efficientnet-b7 would do better, but my 2080Ti is too small and training is very slow. A minimal sketch of this baseline training step is shown below.
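
A hedged sketch of the baseline described above (run from inside `code/`, like `train.py`; the learning rate, the equal loss weighting, and the shapes in the comments are my assumptions, not values taken from the repo):

```python
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from loss import DiceLoss  # the Dice loss defined in code/loss.py

model = smp.Unet('efficientnet-b5', classes=1, activation='sigmoid').cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dice_loss, bce_loss = DiceLoss(), nn.BCELoss()

def train_step(img, mask):
    # img: [B, 3, 512, 512] float, mask: [B, 1, 512, 512] with values in {0, 1}
    model.train()
    seg = model(img.cuda())  # already a probability map (activation='sigmoid')
    loss = dice_loss(seg, mask.cuda()) + bce_loss(seg, mask.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```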
### Sliding-window cropping

- Resizing an image throws away too much detail, which hurts tamper-region detection, so a sliding window crops the image into 512x512 patches with an overlap of 128
- At inference time the image is cropped the same way, and predictions in the overlapping regions are averaged
- Sometimes the global context of the whole image is needed for a good decision: let $M_{resize}$ be the prediction from the image resized to 512, and $M_{slice}$ the prediction assembled from sliding-window crops; the final prediction is $\alpha M_{resize} + (1-\alpha)M_{slice}$
- The threshold $th$ is tuned by trying different values on the validation set and keeping the best ($\alpha$ is searched the same way); the search code is `th-search.py`

Using only the 2350-point model (trained on resized images), but inferring with $0.7M_{resize} + 0.3M_{slice}$ and a threshold of 0.3, the single-model score reached 2560.

Later the training set was cropped the same way; with $M = 0.8M_{resize} + 0.2M_{slice}$ and a threshold of 0.4, the single model scored roughly 2700.

### Data augmentation

- Added the sea3 training data (the effect is small: roughly +40-50 points for a single model over the baseline)
- Re-tampered the images a second time: randomly copy a patch from the same image, randomly copy a patch from another image, and random erasing; 3,000 re-tampered images were generated offline (thanks to Dave for the open-sourced scheme)
- Alternatively, augment online (the `ManiDatasetAug` class in `dataset.py`)

Offline and online augmentation perform about the same, roughly +30-40 points.

### Model ensemble

- Average the predictions of 2-3 models, then threshold

This usually adds anywhere from 50 to 150 points over a single model, but sometimes backfires; it comes down to luck.

There was no time to train all five folds; with five models the gain might be larger.

### Semi-supervised learning

Pseudo-label the test set with the highest-scoring submitted model and, borrowing from Tri-training, select part of the test set into the training set. Concretely (a selection sketch for one round follows this list):

- Run model0, model1 and model2 (trained on fold0, fold1 and fold2) on the test samples to get predictions pre0, pre1 and pre2, thresholded into res0, res1 and res2
- For model0, find the $x$ images where res1 and res2 have the highest IoU, fuse the corresponding pre1 and pre2, binarize, and add them to model0's new training set
- model1 and model2 do the same, each selecting $x$ test images with their pseudo labels
- Iterate, gradually increasing $x$; I used 500, 1000, 1500, 2000, 2500

The gain is substantial: 200+ points in the preliminary round (probably not yet saturated) and 100+ in the semi-final. But once $x$ reached 3000 the score started to drop, which suggests the pseudo labels are not that clean.
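
A hedged sketch of one selection round of the scheme above (all names are mine; `pre1`/`pre2` are lists of probability maps in [0, 1] from the other two folds, `res1`/`res2` their thresholded binary masks; the fusion threshold 0.5 is an assumption):

```python
import numpy as np

def iou(a, b, eps=1e-8):
    a, b = a > 0, b > 0
    return (a & b).sum() / ((a | b).sum() + eps)

def select_pseudo_labels(pre1, pre2, res1, res2, x, th=0.5):
    """Pick the x test images on which the other two folds agree most
    (highest IoU), fuse their probability maps, and binarize."""
    agreement = [iou(r1, r2) for r1, r2 in zip(res1, res2)]
    picked = np.argsort(agreement)[::-1][:x]       # most-agreed images first
    pseudo = {}
    for i in picked:
        fused = (pre1[i] + pre2[i]) / 2.0          # average the two folds
        pseudo[i] = (fused > th).astype(np.uint8)  # binary pseudo label
    return pseudo
```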
### Final solution

- efficientnet-b5, Unet
- Add the sea3 training data
- Generate 3,000 re-tampered images offline, or augment online
- Semi-supervised: pseudo-label the test set and train iteratively
- Ensemble two models: one trained at 768 resolution (resize) and one trained on 512 sliding-window slices ($0.8M_{resize} + 0.2M_{slice}$)

### Scores

Preliminary round: 2961/3000 images

Semi-final: 1440/2000 images

### Shortcomings

- Early on we knew little about semantic segmentation and underestimated the backbone. Judging from the top teams' discussion, a swin Unet or ConvNeXt-XL baseline alone scores very high as a single model; both the top1 and top5 semi-final solutions used large models as their baseline, reaching about 2900 in the preliminary round with no extra tricks.
- More advanced semi-supervised schemes are worth trying. top1's approach: treat predictions above 0.7 as tampered, below 0.3 as untampered, and 0.3-0.7 as low-confidence; build a mask and ignore the low-confidence region when computing the loss, then train semi-supervised on all 4,000 test images.

## Review of the top solutions

Brute force at its finest.

Since Tianchi does not require contestants to publish their solutions, I have only seen top1 and top5 so far.

Both top1 and top5 used a large model such as swin-v2 or ConvNeXt-L, raised the resolution to 1280, and added semi-supervised training.

## References

- [On data augmentation - Tianchi tech forum (aliyun.com)](https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.3.5de127f0Rf2Aep&postId=348449)
- [qubvel/segmentation_models.pytorch: Segmentation models with pretrained backbones. PyTorch. (github.com)](https://github.com/qubvel/segmentation_models.pytorch)

## Acknowledgements

Finally, thanks to the organizers for hosting this competition, giving us a platform to sharpen our skills, and patiently answering our questions in the group. Thanks to everyone who open-sourced their solutions, which let us, newcomers to this kind of task, get up to speed quickly. And thanks to Tianchi for the opportunity to exchange and learn; I met very reliable teammates to study and improve with.

--------------------------------------------------------------------------------
/code/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F            # used by the Lovasz losses below
from torch.autograd import Variable        # used by the Lovasz losses below
from itertools import filterfalse          # used by the mean() helper below

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, input, target):
        N = target.shape[0]

        smooth = 1

        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)

        intersection = input_flat * target_flat
        N_dice_eff = (2*intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
        # N_dice_eff = (2*intersection.sum(1) + smooth) / (torch.pow(input_flat,2).sum(1) + torch.pow(target_flat,2).sum(1) + smooth)

        loss = 1 - N_dice_eff.sum() / N
        return loss

class ClsLoss(nn.Module):
    def __init__(self):
        super(ClsLoss, self).__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, input, target):
        N = target.shape[0]
        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)
        input_cls = torch.max(input_flat, 1)[0]
        target_cls = torch.max(target_flat, 1)[0]
        loss = self.bceloss(input_cls, target_cls)
        return loss

class PixelClsLoss(nn.Module):
    def __init__(self):
        super(PixelClsLoss, self).__init__()
        self.bceloss = nn.BCELoss()
    def forward(self, input, target):
        input_flat = input.flatten()
        target_flat = target.flatten()
        return self.bceloss(input_flat, target_flat)

class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        self.dice_loss = DiceLoss()

    def forward(self, input, target):
        # the edge branch predicts at 1/4 resolution, so downsample the target
        target = nn.functional.interpolate(target, scale_factor=0.25, mode='bilinear')
        return self.dice_loss(input, target)

class MSELoss(nn.Module):
    def __init__(self):
        super(MSELoss, self).__init__()
        self.mseloss = torch.nn.MSELoss()
    def forward(self, input, target):
        return self.mseloss(input, target)
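
# Hedged sketch (my addition, not in the original file): the README trains the
# Unet with Dice loss plus BCE loss; a combined criterion could look like this,
# with the equal weighting being an assumption.
class BCEDiceLoss(nn.Module):
    def __init__(self):
        super(BCEDiceLoss, self).__init__()
        self.dice = DiceLoss()
        self.bce = nn.BCELoss()

    def forward(self, input, target):
        # `input` must already be a probability map (the smp models in this
        # repo are built with activation='sigmoid')
        return self.dice(input, target) + self.bce(input, target)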
def isnan(x):
    # NaN is the only value that is not equal to itself
    return x != x

def mean(l, ignore_nan=False, empty=0):
    """Nan-tolerant mean, needed by the per-image Lovasz losses below."""
    l = iter(l)
    if ignore_nan:
        l = filterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss

def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss

def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels

# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)

def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels

--------------------------------------------------------------------------------
/code/common/utils.py:
--------------------------------------------------------------------------------
import sys
sys.path.insert(0, '..')
import cv2
import time
import numpy as np
import collections

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


def read_annotations(data_path):
    lines = map(str.strip, open(data_path).readlines())
    data = []
    for line in lines:
        temp = line.split()
        if len(temp) == 1:
            # test annotations carry only the image path, no mask or label
            sample_path = temp[0]
            mask_path = 'None'
            label = -1
        else:
            sample_path, mask_path, label = temp
            label = int(int(label) > 0)
        data.append((sample_path, mask_path, label))
    return data


def str2bool(in_str):
    if in_str in [1, "1", "t", "True", "true"]:
        return True
    elif in_str in [0, "0", "f", "False", "false", "none"]:
        return False


def calculate_img_score(pd, gt):
    seg_inv, gt_inv = np.logical_not(pd), np.logical_not(gt)
    true_pos = float(np.logical_and(pd, gt).sum())
    false_pos = np.logical_and(pd, gt_inv).sum()
    false_neg = np.logical_and(seg_inv, gt).sum()
    true_neg = float(np.logical_and(seg_inv, gt_inv).sum())
    acc = (true_pos + true_neg) / (true_pos + true_neg + false_neg + false_pos + 1e-6)
    sen = true_pos / (true_pos + false_neg + 1e-6)
    spe = true_neg / (true_neg + false_pos + 1e-6)
    f1 = 2 * sen * spe / (sen + spe)
    return acc, sen, spe, f1, true_pos, true_neg, false_pos, false_neg


def calculate_pixel_f1(pd, gt):
    if np.max(pd) == np.max(gt) and np.max(pd) == 0:
        # both prediction and ground truth are empty: count as a perfect match
        return 1.0, 0.0, 0.0
    seg_inv, gt_inv = np.logical_not(pd), np.logical_not(gt)
    true_pos = float(np.logical_and(pd, gt).sum())
    false_pos = np.logical_and(pd, gt_inv).sum()
    false_neg = np.logical_and(seg_inv, gt).sum()
    f1 = 2 * true_pos / (2 * true_pos + false_pos + false_neg + 1e-6)
    precision = true_pos / (true_pos + false_pos + 1e-6)
    recall = true_pos / (true_pos + false_neg + 1e-6)
    return f1, precision, recall

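# --- Hedged usage sketch (my addition): calculate_pixel_f1 takes flattened
# binary masks. For pd = [1,1,0,0] and gt = [1,0,1,0]: tp=1, fp=1, fn=1, so
# f1 = 2*1/(2*1+1+1) = 0.5, precision = 0.5, recall = 0.5.
#   f1, precision, recall = calculate_pixel_f1(pred_mask.flatten(),
#                                              gt_mask.flatten())
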
63 | class Progbar(object): 64 | """Displays a progress bar. 65 | # Arguments 66 | target: Total number of steps expected, None if unknown. 67 | width: Progress bar width on screen. 68 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 69 | stateful_metrics: Iterable of string names of metrics that 70 | should *not* be averaged over time. Metrics in this list 71 | will be displayed as-is. All others will be averaged 72 | by the progbar before display. 73 | interval: Minimum visual progress update interval (in seconds). 74 | """ 75 | 76 | def __init__(self, target, width=30, verbose=1, interval=0.05, 77 | stateful_metrics=None): 78 | self.target = target 79 | self.width = width 80 | self.verbose = verbose 81 | self.interval = interval 82 | if stateful_metrics: 83 | self.stateful_metrics = set(stateful_metrics) 84 | else: 85 | self.stateful_metrics = set() 86 | 87 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 88 | sys.stdout.isatty()) or 89 | 'ipykernel' in sys.modules) 90 | self._total_width = 0 91 | self._seen_so_far = 0 92 | self._values = collections.OrderedDict() 93 | self._start = time.time() 94 | self._last_update = 0 95 | 96 | def update(self, current, values=None): 97 | """Updates the progress bar. 98 | # Arguments 99 | current: Index of current step. 100 | values: List of tuples: 101 | `(name, value_for_last_step)`. 102 | If `name` is in `stateful_metrics`, 103 | `value_for_last_step` will be displayed as-is. 104 | Else, an average of the metric over time will be displayed. 105 | """ 106 | values = values or [] 107 | for k, v in values: 108 | if k not in self.stateful_metrics: 109 | if k not in self._values: 110 | self._values[k] = [v * (current - self._seen_so_far), 111 | current - self._seen_so_far] 112 | else: 113 | self._values[k][0] += v * (current - self._seen_so_far) 114 | self._values[k][1] += (current - self._seen_so_far) 115 | else: 116 | self._values[k] = v 117 | self._seen_so_far = current 118 | 119 | now = time.time() 120 | info = ' - %.0fs' % (now - self._start) 121 | if self.verbose == 1: 122 | if (now - self._last_update < self.interval and 123 | self.target is not None and current < self.target): 124 | return 125 | 126 | prev_total_width = self._total_width 127 | if self._dynamic_display: 128 | sys.stdout.write('\b' * prev_total_width) 129 | sys.stdout.write('\r') 130 | else: 131 | sys.stdout.write('\n') 132 | 133 | if self.target is not None: 134 | numdigits = int(np.floor(np.log10(self.target))) + 1 135 | barstr = '%%%dd/%d [' % (numdigits, self.target) 136 | bar = barstr % current 137 | prog = float(current) / self.target 138 | prog_width = int(self.width * prog) 139 | if prog_width > 0: 140 | bar += ('=' * (prog_width - 1)) 141 | if current < self.target: 142 | bar += '>' 143 | else: 144 | bar += '=' 145 | bar += ('.' 
* (self.width - prog_width)) 146 | bar += ']' 147 | else: 148 | bar = '%7d/Unknown' % current 149 | 150 | self._total_width = len(bar) 151 | sys.stdout.write(bar) 152 | 153 | if current: 154 | time_per_unit = (now - self._start) / current 155 | else: 156 | time_per_unit = 0 157 | if self.target is not None and current < self.target: 158 | eta = time_per_unit * (self.target - current) 159 | if eta > 3600: 160 | eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60) 161 | elif eta > 60: 162 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 163 | else: 164 | eta_format = '%ds' % eta 165 | 166 | info = ' - ETA: %s' % eta_format 167 | else: 168 | if time_per_unit >= 1: 169 | info += ' %.0fs/step' % time_per_unit 170 | elif time_per_unit >= 1e-3: 171 | info += ' %.0fms/step' % (time_per_unit * 1e3) 172 | else: 173 | info += ' %.0fus/step' % (time_per_unit * 1e6) 174 | 175 | for k in self._values: 176 | info += ' - %s:' % k 177 | if isinstance(self._values[k], list): 178 | avg = np.mean( 179 | self._values[k][0] / max(1, self._values[k][1])) 180 | if abs(avg) > 1e-3: 181 | info += ' %.4f' % avg 182 | else: 183 | info += ' %.4e' % avg 184 | else: 185 | info += ' %s' % self._values[k] 186 | 187 | self._total_width += len(info) 188 | if prev_total_width > self._total_width: 189 | info += (' ' * (prev_total_width - self._total_width)) 190 | 191 | if self.target is not None and current >= self.target: 192 | info += '\n' 193 | 194 | sys.stdout.write(info) 195 | sys.stdout.flush() 196 | 197 | elif self.verbose == 2: 198 | if self.target is None or current >= self.target: 199 | for k in self._values: 200 | info += ' - %s:' % k 201 | avg = np.mean( 202 | self._values[k][0] / max(1, self._values[k][1])) 203 | if avg > 1e-3: 204 | info += ' %.4f' % avg 205 | else: 206 | info += ' %.4e' % avg 207 | info += '\n' 208 | 209 | sys.stdout.write(info) 210 | sys.stdout.flush() 211 | 212 | self._last_update = now 213 | 214 | def add(self, n, values=None): 215 | self.update(self._seen_so_far + n, values) 216 | 217 | 218 | class AverageMeter(object): 219 | """Computes and stores the average and current value""" 220 | 221 | def __init__(self): 222 | self.reset() 223 | 224 | def reset(self): 225 | self.val = 0 226 | self.avg = 0 227 | self.sum = 0 228 | self.count = 0 229 | 230 | def update(self, val, n=1): 231 | self.val = val 232 | self.sum += val * n 233 | self.count += n 234 | self.avg = self.sum / self.count 235 | 236 | def __str__(self): 237 | """String representation for logging 238 | """ 239 | # for values that should be recorded exactly e.g. 
iteration number 240 | if self.count == 0: 241 | return str(self.val) 242 | # for stats 243 | return '%.4f (%.4f)' % (self.val, self.avg) 244 | 245 | 246 | if __name__ == "__main__": 247 | pass -------------------------------------------------------------------------------- /code/models/unet_parts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # ~~~~~~~~~~ U-Net ~~~~~~~~~~ 7 | 8 | class U_double_conv(nn.Module): 9 | def __init__(self, in_ch, out_ch): 10 | super(U_double_conv, self).__init__() 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 13 | nn.BatchNorm2d(out_ch), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | return x 23 | 24 | 25 | class inconv(nn.Module): 26 | def __init__(self, in_ch, out_ch): 27 | super(inconv, self).__init__() 28 | self.conv = U_double_conv(in_ch, out_ch) 29 | 30 | def forward(self, x): 31 | x = self.conv(x) 32 | return x 33 | 34 | 35 | class U_down(nn.Module): 36 | def __init__(self, in_ch, out_ch): 37 | super(U_down, self).__init__() 38 | self.mpconv = nn.Sequential( 39 | nn.MaxPool2d(kernel_size=2, stride=2), 40 | U_double_conv(in_ch, out_ch) 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.mpconv(x) 45 | return x 46 | 47 | 48 | class U_up(nn.Module): 49 | def __init__(self, in_ch, out_ch, bilinear=True): 50 | super(U_up, self).__init__() 51 | 52 | # would be a nice idea if the upsampling could be learned too, 53 | # but my machine do not have enough memory to handle all those weights 54 | if bilinear: 55 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 56 | else: 57 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 58 | 59 | self.conv = U_double_conv(in_ch, out_ch) 60 | 61 | def forward(self, x1, x2): 62 | x1 = self.up(x1) 63 | diffX = x2.size()[2] - x1.size()[2] 64 | diffY = x2.size()[3] - x1.size()[3] 65 | 66 | x1 = F.pad(x1, (diffY, 0, 67 | diffX, 0)) 68 | x = torch.cat([x2, x1], dim=1) 69 | 70 | x = self.conv(x) 71 | return x 72 | 73 | 74 | # ~~~~~~~~~~ RU-Net ~~~~~~~~~~ 75 | 76 | class RU_double_conv(nn.Module): 77 | def __init__(self, in_ch, out_ch): 78 | super(RU_double_conv, self).__init__() 79 | self.conv = nn.Sequential( 80 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 81 | nn.BatchNorm2d(out_ch), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 84 | nn.BatchNorm2d(out_ch)) 85 | 86 | def forward(self, x): 87 | x = self.conv(x) 88 | return x 89 | 90 | 91 | class RU_first_down(nn.Module): 92 | def __init__(self, in_ch, out_ch): 93 | super(RU_first_down, self).__init__() 94 | self.conv = RU_double_conv(in_ch, out_ch) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.res_conv = nn.Sequential( 97 | nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False), 98 | nn.BatchNorm2d(out_ch)) 99 | 100 | def forward(self, x): 101 | # the first ring conv 102 | ft1 = self.conv(x) 103 | print(ft1.shape, self.res_conv(x).shape) 104 | r1 = self.relu(ft1 + self.res_conv(x)) 105 | 106 | return r1 107 | 108 | 109 | class RU_down(nn.Module): 110 | def __init__(self, in_ch, out_ch): 111 | super(RU_down, self).__init__() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.conv = RU_double_conv(in_ch, out_ch) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.res_conv = 
nn.Sequential( 116 | nn.Conv2d(in_ch, out_ch, 1, bias=False), 117 | nn.BatchNorm2d(out_ch)) 118 | 119 | def forward(self, x): 120 | x = self.maxpool(x) 121 | # the first ring conv 122 | ft1 = self.conv(x) 123 | r1 = self.relu(ft1 + self.res_conv(x)) 124 | 125 | return r1 126 | 127 | 128 | class RU_up(nn.Module): 129 | def __init__(self, in_ch, out_ch, bilinear=False): 130 | super(RU_up, self).__init__() 131 | # would be a nice idea if the upsampling could be learned too, 132 | # but my machine do not have enough memory to handle all those weights 133 | # nn.Upsample hasn't weights to learn, but nn.ConvTransposed2d has weights to learn. 134 | if bilinear: 135 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 136 | else: 137 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 138 | 139 | self.conv = RU_double_conv(in_ch, out_ch) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.res_conv = nn.Sequential( 142 | nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False), 143 | nn.GroupNorm(32, out_ch)) 144 | 145 | def forward(self, x1, x2): 146 | x1 = self.up(x1) 147 | diffX = x2.size()[2] - x1.size()[2] 148 | diffY = x2.size()[3] - x1.size()[3] 149 | 150 | x1 = F.pad(x1, (diffY, 0, 151 | diffX, 0)) 152 | x = torch.cat([x2, x1], dim=1) 153 | 154 | # the first ring conv 155 | ft1 = self.conv(x) 156 | r1 = self.relu(self.res_conv(x) + ft1) 157 | 158 | return r1 159 | 160 | 161 | # ~~~~~~~~~~ RRU-Net ~~~~~~~~~~ 162 | 163 | class RRU_double_conv(nn.Module): 164 | def __init__(self, in_ch, out_ch): 165 | super(RRU_double_conv, self).__init__() 166 | self.conv = nn.Sequential( 167 | nn.Conv2d(in_ch, out_ch, 3, padding=2, dilation=2), 168 | nn.GroupNorm(32, out_ch), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d(out_ch, out_ch, 3, padding=2, dilation=2), 171 | nn.GroupNorm(32, out_ch) 172 | ) 173 | 174 | def forward(self, x): 175 | x = self.conv(x) 176 | return x 177 | 178 | 179 | class RRU_first_down(nn.Module): 180 | def __init__(self, in_ch, out_ch): 181 | super(RRU_first_down, self).__init__() 182 | self.conv = RRU_double_conv(in_ch, out_ch) 183 | self.relu = nn.ReLU(inplace=True) 184 | 185 | self.res_conv = nn.Sequential( 186 | nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False), 187 | nn.GroupNorm(32, out_ch) 188 | ) 189 | self.res_conv_back = nn.Sequential( 190 | nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False) 191 | ) 192 | 193 | def forward(self, x): 194 | # the first ring conv 195 | ft1 = self.conv(x) 196 | r1 = self.relu(ft1 + self.res_conv(x)) 197 | # the second ring conv 198 | ft2 = self.res_conv_back(r1) 199 | x = torch.mul(1 + F.sigmoid(ft2), x) 200 | # the third ring conv 201 | ft3 = self.conv(x) 202 | r3 = self.relu(ft3 + self.res_conv(x)) 203 | 204 | return r3 205 | 206 | 207 | class RRU_down(nn.Module): 208 | def __init__(self, in_ch, out_ch): 209 | super(RRU_down, self).__init__() 210 | self.conv = RRU_double_conv(in_ch, out_ch) 211 | self.relu = nn.ReLU(inplace=True) 212 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 213 | 214 | self.res_conv = nn.Sequential( 215 | nn.Conv2d(in_ch, out_ch, 1, bias=False), 216 | nn.GroupNorm(32, out_ch)) 217 | self.res_conv_back = nn.Sequential( 218 | nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False)) 219 | 220 | def forward(self, x): 221 | x = self.pool(x) 222 | # the first ring conv 223 | ft1 = self.conv(x) 224 | r1 = self.relu(ft1 + self.res_conv(x)) 225 | # the second ring conv 226 | ft2 = self.res_conv_back(r1) 227 | x = torch.mul(1 + F.sigmoid(ft2), x) 228 | # the third ring conv 229 | 
ft3 = self.conv(x)
        r3 = self.relu(ft3 + self.res_conv(x))

        return r3


class RRU_up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=False):
        super(RRU_up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.Sequential(
                nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2),
                nn.GroupNorm(32, in_ch // 2))

        self.conv = RRU_double_conv(in_ch, out_ch)
        self.relu = nn.ReLU(inplace=True)

        self.res_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.GroupNorm(32, out_ch))
        self.res_conv_back = nn.Sequential(
            nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False))

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffY, 0,
                        diffX, 0))

        x = self.relu(torch.cat([x2, x1], dim=1))

        # the first ring conv
        ft1 = self.conv(x)
        r1 = self.relu(self.res_conv(x) + ft1)
        # the second ring conv
        ft2 = self.res_conv_back(r1)
        x = torch.mul(1 + torch.sigmoid(ft2), x)
        # the third ring conv
        ft3 = self.conv(x)
        r3 = self.relu(ft3 + self.res_conv(x))

        return r3


# !!!!!!!!!!!! Universal functions !!!!!!!!!!!!

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

--------------------------------------------------------------------------------
/code/inference.py:
--------------------------------------------------------------------------------
import os
import io
import pdb
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

from models.mvssnet import get_mvss
from models.unet_model import Ringed_Res_Unet
import segmentation_models_pytorch as smp
from common.tools import inference_single
from common.utils import AverageMeter
from loss import DiceLoss
# calculate_batch_score is defined in train.py
from train import calculate_batch_score

import argparse

from dataset import ManiDataset
from tqdm import tqdm

def parse_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--root', type=str, default='/data1/datasets/Image-Manipulation-Detection/test/')
    parser.add_argument('--model', type=str, default='None')
    parser.add_argument('--th', type=float, default=0.5)
    parser.add_argument('--weights', type=str, default='./ckpt/mvssnet_casia.pt')
    parser.add_argument('--save-dir', type=str, default='./images/')
    args = parser.parse_args()
    return args

def init_model(model_type, pretrained_ckpt=None):
    if model_type == 'mvssnet':
        model = get_mvss(backbone='resnet50',
                         pretrained_base=True,
                         nclass=1,
                         sobel=True,
                         constrain=True,
                         n_input=3,
                         )
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
        print("load model :{}".format(pretrained_ckpt))
    elif model_type == 'rrunet':
        model = Ringed_Res_Unet(n_channels=3, n_classes=1)
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
    elif model_type == 'unet':
        model = smp.Unet('efficientnet-b5', classes=1, activation='sigmoid')
        checkpoint = 
torch.load(pretrained_ckpt, map_location='cpu') 48 | 49 | model.load_state_dict(checkpoint, strict=True) 50 | model = model.cuda() 51 | model.eval() 52 | return model 53 | 54 | def cal_sliding_params(img_h, img_w): 55 | # 计算需要裁剪成几块 56 | col, row = 1, 1 57 | while (512*col - (col-1)*128) < img_h: 58 | col += 1 59 | while (512*row - (row-1)*128) < img_w: 60 | row += 1 61 | return col, row 62 | 63 | def img_slide_window(img, col, row): 64 | imgs = [] 65 | # 计算 overlape 66 | delta_x, delta_y = 0, 0 67 | if row > 1: 68 | delta_x = int((row*512-img.shape[-1])/(row-1)) 69 | if col > 1: 70 | delta_y = int((col*512-img.shape[-2])/(col-1)) 71 | 72 | for i in range(col): 73 | for j in range(row): 74 | begin_h = 512*i - max(0, i)*delta_y 75 | begin_w = 512*j - max(0, j)*delta_x 76 | 77 | if begin_h + 512 > img.shape[-2]: 78 | begin_h = img.shape[-2] - 512 79 | if begin_w + 512 > img.shape[-1]: 80 | begin_w = img.shape[-1] - 512 81 | slide = img[:, :, begin_h:begin_h+512, begin_w:begin_w+512].squeeze(0) 82 | imgs.append(slide) 83 | # print(begin_h, begin_w, begin_h+512, begin_w+512, img.shape) 84 | return torch.stack(imgs, dim=0) 85 | 86 | def merge_slides_result(segs, col, row, img_shape): 87 | count = torch.zeros([1, img_shape[2], img_shape[3]]).cuda() 88 | seg = torch.zeros([1, img_shape[2], img_shape[3]]).cuda() 89 | 90 | # 计算 overlape 91 | delta_x, delta_y = 0, 0 92 | if row > 1: 93 | delta_x = int((row*512-img_shape[-1])/(row-1)) 94 | if col > 1: 95 | delta_y = int((col*512-img_shape[-2])/(col-1)) 96 | 97 | # print(col, row) 98 | for i in range(col): 99 | for j in range(row): 100 | begin_h = 512*i - max(0, i)*delta_y 101 | begin_w = 512*j - max(0, j)*delta_x 102 | 103 | if begin_h + 512 > img_shape[-2]: 104 | begin_h = img_shape[-2] - 512 105 | if begin_w + 512 > img_shape[-1]: 106 | begin_w = img_shape[-1] - 512 107 | seg[:, begin_h:begin_h+512, begin_w:begin_w+512] += segs[i*row+j] 108 | count[:, begin_h:begin_h+512, begin_w:begin_w+512] += 1.0 109 | seg = seg / count 110 | return seg.unsqueeze(0) 111 | 112 | def slide_window_val_unet(model, valloader, th=0.5): 113 | losses = AverageMeter() 114 | scores = AverageMeter() 115 | ious = AverageMeter() 116 | model.eval() 117 | 118 | dice_loss = DiceLoss() 119 | # dice_loss = PixelClsLoss() 120 | 121 | for img, label in tqdm(valloader): 122 | img, label = img.cuda(), label.cuda() #[3,h,w] 123 | # print(img.shape, label.shape) 124 | with torch.no_grad(): 125 | img_h, img_w = img.shape[-2], img.shape[-1] 126 | col, row = cal_sliding_params(img_h, img_w) 127 | imgs = img_slide_window(img, col, row) 128 | # print(imgs.shape) 129 | seg = model(imgs) 130 | # seg = torch.sigmoid(seg) 131 | seg = merge_slides_result(seg, col, row, img.shape) 132 | loss = dice_loss(seg, label).item() 133 | 134 | losses.update(loss, 1) 135 | f1, iou = calculate_batch_score(seg, label, th) 136 | scores.update(f1, 1) 137 | ious.update(iou, 1) 138 | 139 | return losses.avg, scores.avg, ious.avg 140 | 141 | def TTA_val(model, valloader, th=0.5): 142 | losses = AverageMeter() 143 | scores = AverageMeter() 144 | ious = AverageMeter() 145 | model.eval() 146 | 147 | dice_loss = DiceLoss() 148 | # dice_loss = PixelClsLoss() 149 | 150 | for img, label in tqdm(valloader): 151 | img, label = img.cuda(), label.cuda() #[3,h,w] 152 | # print(img.shape, label.shape) 153 | 154 | # 滑窗计算 155 | with torch.no_grad(): 156 | img_h, img_w = img.shape[-2], img.shape[-1] 157 | col, row = cal_sliding_params(img_h, img_w) 158 | imgs = img_slide_window(img, col, row) 159 | # print(imgs.shape) 160 | seg = 
model(imgs) 161 | # seg = torch.sigmoid(seg) 162 | seg = merge_slides_result(seg, col, row, img.shape) 163 | loss = dice_loss(seg, label).item() 164 | 165 | # 直接resize 计算 166 | transform = transforms.Compose([ 167 | transforms.Resize([512,512]) 168 | ]) 169 | invtransform = transforms.Compose([ 170 | transforms.Resize([img.shape[-2], img.shape[-1]]) 171 | ]) 172 | with torch.no_grad(): 173 | seg_resize = model(transform(img)) 174 | seg_resize = invtransform(seg_resize) 175 | seg = seg*0.4 + seg_resize*0.6 176 | 177 | losses.update(loss, 1) 178 | f1, iou = calculate_batch_score(seg, label, th) 179 | scores.update(f1, 1) 180 | ious.update(iou, 1) 181 | 182 | return losses.avg, scores.avg, ious.avg 183 | 184 | def TTA_inference_single(model, img, th=0.3, alpha=0.3): 185 | transform_pil = transforms.Compose([ 186 | transforms.ToPILImage(), 187 | ]) 188 | img = img.cuda().view(-1, img.shape[0], img.shape[1], img.shape[2]) 189 | # print(img) 190 | 191 | # 滑窗检测 192 | with torch.no_grad(): 193 | img_h, img_w = img.shape[-2], img.shape[-1] 194 | col, row = cal_sliding_params(img_h, img_w) 195 | imgs = img_slide_window(img, col, row) 196 | # print(imgs.shape) 197 | seg = model(imgs) 198 | # seg = torch.sigmoid(seg) 199 | seg = merge_slides_result(seg, col, row, img.shape) 200 | 201 | transform = transforms.Compose([ 202 | transforms.Resize([512,512]) 203 | ]) 204 | invtransform = transforms.Compose([ 205 | transforms.Resize([img.shape[-2], img.shape[-1]]) 206 | ]) 207 | with torch.no_grad(): 208 | seg_resize = model(transform(img)) 209 | seg_resize = invtransform(seg_resize) 210 | seg = seg*alpha + seg_resize*(1-alpha) 211 | seg = seg.detach().cpu() 212 | 213 | if torch.isnan(seg).any() or torch.isinf(seg).any(): 214 | max_score = 0.0 215 | else: 216 | max_score = torch.max(seg).numpy() 217 | seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))] 218 | # print(seg) 219 | 220 | if len(seg) != 1: 221 | pdb.set_trace() 222 | else: 223 | fake_seg = seg[0] 224 | if th == 0: 225 | return fake_seg, max_score 226 | 227 | # fake_seg = 255.0 * (fake_seg > 255 * th) 228 | # fake_seg = 255.0 * fake_seg 229 | fake_seg = fake_seg.astype(np.uint8) 230 | 231 | # print(fake_seg.shape) 232 | return fake_seg, max_score 233 | 234 | def TTA_inference_two(model, model1, img, th=0.3, alpha=0.3): 235 | # model slidewindow, model1 resize 236 | transform_pil = transforms.Compose([ 237 | transforms.ToPILImage(), 238 | ]) 239 | img = img.cuda().view(-1, img.shape[0], img.shape[1], img.shape[2]) 240 | # print(img) 241 | 242 | # 滑窗检测 243 | with torch.no_grad(): 244 | img_h, img_w = img.shape[-2], img.shape[-1] 245 | col, row = cal_sliding_params(img_h, img_w) 246 | imgs = img_slide_window(img, col, row) 247 | # print(imgs.shape) 248 | seg = model(imgs) 249 | # seg = torch.sigmoid(seg) 250 | seg = merge_slides_result(seg, col, row, img.shape) 251 | 252 | transform = transforms.Compose([ 253 | transforms.Resize([512,512]) 254 | ]) 255 | invtransform = transforms.Compose([ 256 | transforms.Resize([img.shape[-2], img.shape[-1]]) 257 | ]) 258 | with torch.no_grad(): 259 | seg_resize = model1(transform(img)) 260 | seg_resize = invtransform(seg_resize) 261 | seg = seg*alpha + seg_resize*(1-alpha) 262 | seg = seg.detach().cpu() 263 | 264 | if torch.isnan(seg).any() or torch.isinf(seg).any(): 265 | max_score = 0.0 266 | else: 267 | max_score = torch.max(seg).numpy() 268 | seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))] 269 | # print(seg) 270 | 271 | if len(seg) != 1: 272 | pdb.set_trace() 273 | else: 274 | 
fake_seg = seg[0]
    if th == 0:
        return fake_seg, max_score

    fake_seg = 255.0 * (fake_seg > 255 * th)
    fake_seg = fake_seg.astype(np.uint8)

    return fake_seg, max_score

if __name__ == '__main__':
    args = parse_args()
    model = init_model(args.model, args.weights)

    # init dataset
    testset = ManiDataset([args.root], ['test.txt'], mode='test', resize=False)

    for img, img_name, fake_size in tqdm(testset):
        seg, max_score = TTA_inference_single(model, img, th=args.th, alpha=0.3)
        # seg, max_score = TTA_inference_two(model, model1, img, th=args.th, alpha=0.35)
        seg = cv2.resize(seg, (fake_size[1], fake_size[0]))
        cv2.imwrite("{}/{}.png".format(args.save_dir, img_name), seg)

--------------------------------------------------------------------------------
/code/inference_model_ensemble.py:
--------------------------------------------------------------------------------
import os
import io
import pdb
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

from models.mvssnet import get_mvss
from models.unet_model import Ringed_Res_Unet
from common.tools import inference_single
from common.utils import AverageMeter
from loss import DiceLoss
# calculate_batch_score is defined in train.py
from train import calculate_batch_score
import segmentation_models_pytorch as smp

import argparse

from dataset import ManiDataset
from tqdm import tqdm

def parse_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--root', type=str, default='/data1/datasets/Image-Manipulation-Detection/test/')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--th', type=float, default=0.5)
    parser.add_argument('--weights', type=str, default='./ckpt/mvssnet_casia.pt')
    parser.add_argument('--save-dir', type=str, default='./images/')
    args = parser.parse_args()
    return args

def init_model(model_type, pretrained_ckpt=None):
    if model_type == 'mvssnet':
        model = get_mvss(backbone='resnet50',
                         pretrained_base=True,
                         nclass=1,
                         sobel=True,
                         constrain=True,
                         n_input=3,
                         )
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
        print("load model :{}".format(pretrained_ckpt))
    elif model_type == 'rrunet':
        model = Ringed_Res_Unet(n_channels=3, n_classes=1)
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
        print("load model :{}".format(pretrained_ckpt))
    elif model_type == 'unet':
        model = smp.Unet('efficientnet-b5', classes=1, activation='sigmoid')
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
        print("load model :{}".format(pretrained_ckpt))
    elif model_type == 'linknet':
        model = smp.Linknet('efficientnet-b5', classes=1, activation='sigmoid')
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
        print("load model :{}".format(pretrained_ckpt))

    model.load_state_dict(checkpoint, strict=True)
    model = model.cuda()
    model.eval()
    return model

def inference_single_rru(model, img, th=0):
    transform_pil = transforms.Compose([
        transforms.ToPILImage(),
    ])
    img = img.cuda().view(-1, img.shape[0], img.shape[1], img.shape[2])
    # 
print(img) 66 | with torch.no_grad(): 67 | seg = model(img) 68 | # print(seg) 69 | seg = torch.sigmoid(seg).detach().cpu() 70 | # print(seg) 71 | 72 | if torch.isnan(seg).any() or torch.isinf(seg).any(): 73 | max_score = 0.0 74 | else: 75 | max_score = torch.max(seg).numpy() 76 | seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))] 77 | # print(seg) 78 | 79 | if len(seg) != 1: 80 | pdb.set_trace() 81 | else: 82 | fake_seg = seg[0] 83 | if th == 0: 84 | return fake_seg, max_score 85 | # fake_seg = 255.0 * (fake_seg > 255 * th) 86 | # fake_seg = fake_seg.astype(np.uint8) 87 | 88 | # print(fake_seg.shape) 89 | return fake_seg, max_score 90 | 91 | def inference_single_unet(model, img, th=0): 92 | transform_pil = transforms.Compose([ 93 | transforms.ToPILImage(), 94 | ]) 95 | torch_resize = transforms.Resize([512, 512]) 96 | img = torch_resize(img) 97 | 98 | img = img.cuda().view(-1, img.shape[0], img.shape[1], img.shape[2]) 99 | # print(img) 100 | with torch.no_grad(): 101 | seg = model(img) 102 | # print(seg) 103 | seg = seg.detach().cpu() 104 | # print(seg) 105 | 106 | if torch.isnan(seg).any() or torch.isinf(seg).any(): 107 | max_score = 0.0 108 | else: 109 | max_score = torch.max(seg).numpy() 110 | seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))] 111 | # print(seg) 112 | 113 | if len(seg) != 1: 114 | pdb.set_trace() 115 | else: 116 | fake_seg = seg[0] 117 | if th == 0: 118 | return fake_seg, max_score 119 | # fake_seg = 255.0 * (fake_seg > 255 * th) 120 | # fake_seg = fake_seg.astype(np.uint8) 121 | 122 | # print(fake_seg.shape) 123 | return fake_seg, max_score 124 | 125 | def cal_sliding_params(img_h, img_w): 126 | # 计算需要裁剪成几块 127 | col, row = 1, 1 128 | while (512*col - (col-1)*128) < img_h: 129 | col += 1 130 | while (512*row - (row-1)*128) < img_w: 131 | row += 1 132 | return col, row 133 | 134 | def img_slide_window(img, col, row): 135 | imgs = [] 136 | # 计算 overlape 137 | delta_x, delta_y = 0, 0 138 | if row > 1: 139 | delta_x = int((row*512-img.shape[-1])/(row-1)) 140 | if col > 1: 141 | delta_y = int((col*512-img.shape[-2])/(col-1)) 142 | 143 | for i in range(col): 144 | for j in range(row): 145 | begin_h = 512*i - max(0, i)*delta_y 146 | begin_w = 512*j - max(0, j)*delta_x 147 | 148 | if begin_h + 512 > img.shape[-2]: 149 | begin_h = img.shape[-2] - 512 150 | if begin_w + 512 > img.shape[-1]: 151 | begin_w = img.shape[-1] - 512 152 | slide = img[:, :, begin_h:begin_h+512, begin_w:begin_w+512].squeeze(0) 153 | imgs.append(slide) 154 | # print(begin_h, begin_w, begin_h+512, begin_w+512, img.shape) 155 | return torch.stack(imgs, dim=0) 156 | 157 | def merge_slides_result(segs, col, row, img_shape): 158 | count = torch.zeros([1, img_shape[2], img_shape[3]]).cuda() 159 | seg = torch.zeros([1, img_shape[2], img_shape[3]]).cuda() 160 | 161 | # 计算 overlape 162 | delta_x, delta_y = 0, 0 163 | if row > 1: 164 | delta_x = int((row*512-img_shape[-1])/(row-1)) 165 | if col > 1: 166 | delta_y = int((col*512-img_shape[-2])/(col-1)) 167 | 168 | # print(col, row) 169 | for i in range(col): 170 | for j in range(row): 171 | begin_h = 512*i - max(0, i)*delta_y 172 | begin_w = 512*j - max(0, j)*delta_x 173 | 174 | if begin_h + 512 > img_shape[-2]: 175 | begin_h = img_shape[-2] - 512 176 | if begin_w + 512 > img_shape[-1]: 177 | begin_w = img_shape[-1] - 512 178 | seg[:, begin_h:begin_h+512, begin_w:begin_w+512] += segs[i*row+j] 179 | count[:, begin_h:begin_h+512, begin_w:begin_w+512] += 1.0 180 | seg = seg / count 181 | return seg.unsqueeze(0) 182 | 183 | def 
slide_window_val_unet(model, valloader, th=0.5): 184 | losses = AverageMeter() 185 | scores = AverageMeter() 186 | ious = AverageMeter() 187 | model.eval() 188 | 189 | dice_loss = DiceLoss() 190 | # dice_loss = PixelClsLoss() 191 | 192 | for img, label in tqdm(valloader): 193 | img, label = img.cuda(), label.cuda() #[3,h,w] 194 | # print(img.shape, label.shape) 195 | with torch.no_grad(): 196 | img_h, img_w = img.shape[-2], img.shape[-1] 197 | col, row = cal_sliding_params(img_h, img_w) 198 | imgs = img_slide_window(img, col, row) 199 | # print(imgs.shape) 200 | seg = model(imgs) 201 | # seg = torch.sigmoid(seg) 202 | seg = merge_slides_result(seg, col, row, img.shape) 203 | loss = dice_loss(seg, label).item() 204 | 205 | losses.update(loss, 1) 206 | f1, iou = calculate_batch_score(seg, label, th) 207 | scores.update(f1, 1) 208 | ious.update(iou, 1) 209 | 210 | return losses.avg, scores.avg, ious.avg 211 | 212 | def TTA_val(model, valloader, th=0.5, alpha=0.3): 213 | losses = AverageMeter() 214 | scores = AverageMeter() 215 | ious = AverageMeter() 216 | model.eval() 217 | 218 | dice_loss = DiceLoss() 219 | # dice_loss = PixelClsLoss() 220 | 221 | for img, label in tqdm(valloader): 222 | img, label = img.cuda(), label.cuda() #[3,h,w] 223 | # print(img.shape, label.shape) 224 | 225 | # 滑窗计算 226 | with torch.no_grad(): 227 | img_h, img_w = img.shape[-2], img.shape[-1] 228 | col, row = cal_sliding_params(img_h, img_w) 229 | imgs = img_slide_window(img, col, row) 230 | # print(imgs.shape) 231 | seg = model(imgs) 232 | # seg = torch.sigmoid(seg) 233 | seg = merge_slides_result(seg, col, row, img.shape) 234 | loss = dice_loss(seg, label).item() 235 | 236 | # 直接resize 计算 237 | transform = transforms.Compose([ 238 | transforms.Resize([512,512]) 239 | ]) 240 | invtransform = transforms.Compose([ 241 | transforms.Resize([img.shape[-2], img.shape[-1]]) 242 | ]) 243 | with torch.no_grad(): 244 | seg_resize = model(transform(img)) 245 | seg_resize = invtransform(seg_resize) 246 | seg = seg*alpha + seg_resize*(1-alpha) 247 | 248 | losses.update(loss, 1) 249 | f1, iou = calculate_batch_score(seg, label, th) 250 | scores.update(f1, 1) 251 | ious.update(iou, 1) 252 | 253 | return losses.avg, scores.avg, ious.avg 254 | 255 | 256 | def TTA_inference_single_unet(model, img, th=0, alpha=0.8): 257 | transform_pil = transforms.Compose([ 258 | transforms.ToPILImage(), 259 | ]) 260 | img = img.cuda().view(-1, img.shape[0], img.shape[1], img.shape[2]) 261 | # print(img) 262 | 263 | # 滑窗检测 264 | with torch.no_grad(): 265 | img_h, img_w = img.shape[-2], img.shape[-1] 266 | col, row = cal_sliding_params(img_h, img_w) 267 | imgs = img_slide_window(img, col, row) 268 | # print(imgs.shape) 269 | seg = model(imgs) 270 | # seg = torch.sigmoid(seg) 271 | seg = merge_slides_result(seg, col, row, img.shape) 272 | 273 | transform = transforms.Compose([ 274 | transforms.Resize([512,512]) 275 | ]) 276 | invtransform = transforms.Compose([ 277 | transforms.Resize([img.shape[-2], img.shape[-1]]) 278 | ]) 279 | with torch.no_grad(): 280 | seg_resize = model(transform(img)) 281 | seg_resize = invtransform(seg_resize) 282 | seg = seg*alpha + seg_resize*(1-alpha) 283 | seg = seg.detach().cpu() 284 | 285 | seg = seg - seg.min() 286 | seg = seg / seg.max() 287 | 288 | if torch.isnan(seg).any() or torch.isinf(seg).any(): 289 | max_score = 0.0 290 | else: 291 | max_score = torch.max(seg).numpy() 292 | seg = [np.array(transform_pil(seg[i])) for i in range(len(seg))] 293 | 294 | 295 | if len(seg) != 1: 296 | pdb.set_trace() 297 | else: 298 | 
fake_seg = seg[0]
299 |     if th == 0:
300 |         return fake_seg, max_score
301 | 
302 |     # fake_seg = 255.0 * (fake_seg > 255 * th)
303 |     # fake_seg = fake_seg.astype(np.uint8)
304 | 
305 |     # print(fake_seg.shape)
306 |     return fake_seg, max_score
307 | 
308 | 
309 | if __name__ == '__main__':
310 |     args = parse_args()
311 |     weight1 = "./work_dir/2859-unet-5fold0-aug/weights/best.pth.tar" # fold 0
312 |     weight2 = "./work_dir/2859-unet-5fold1-aug/weights/best.pth.tar" # fold 1
313 |     weight3 = "./work_dir/2859-unet-5fold2-aug/weights/best.pth.tar" # fold 2
314 |     weight4 = "./work_dir/2859-unet-5fold3-aug/weights/best.pth.tar" # fold 3
315 |     model1 = init_model('unet', weight1)
316 |     model2 = init_model('unet', weight2)
317 |     model3 = init_model('unet', weight3)
318 |     model4 = init_model('unet', weight4)
319 | 
320 |     # init dataset
321 |     testset = ManiDataset([args.root], ['test.txt'], mode='test', resize=True)
322 | 
323 |     for img, img_name, fake_size in tqdm(testset):
324 |         # print(img.shape, img_name, fake_size)
325 |         seg1, max_score1 = TTA_inference_single_unet(model1, img, th=args.th, alpha=0.3)
326 |         seg2, max_score2 = TTA_inference_single_unet(model2, img, th=args.th, alpha=0.3)
327 |         seg3, max_score3 = TTA_inference_single_unet(model3, img, th=args.th, alpha=0.3)
328 |         seg4, max_score4 = TTA_inference_single_unet(model4, img, th=args.th, alpha=0.3)
329 | 
330 |         seg1 = cv2.resize(seg1, (fake_size[1], fake_size[0]))
331 |         seg2 = cv2.resize(seg2, (fake_size[1], fake_size[0]))
332 |         seg3 = cv2.resize(seg3, (fake_size[1], fake_size[0]))
333 |         seg4 = cv2.resize(seg4, (fake_size[1], fake_size[0]))
334 |         # print(seg1.shape)
335 |         seg = seg1*0.25 + seg2*0.25 + seg3*0.25 + seg4*0.25
336 | 
337 |         seg = 255.0 * (seg > 255 * args.th)
338 |         seg = seg.astype(np.uint8)
339 |         _, seg = cv2.threshold(seg, int(255*args.th), 255, cv2.THRESH_BINARY)
340 |         cv2.imwrite("{}/{}.png".format(args.save_dir, img_name), seg)
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import torch
4 | import cv2
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | 
9 | from models.mvssnet import get_mvss
10 | from models.unet_model import Ringed_Res_Unet, Ringed_Res_Unet_Slim
11 | import segmentation_models_pytorch as smp
12 | # from common.tools import inference_single
13 | from common.utils import calculate_pixel_f1, calculate_img_score, AverageMeter
14 | from loss import ClsLoss, DiceLoss, PixelClsLoss, EdgeLoss
15 | from loss import *
16 | 
17 | import torch.utils.data as data
18 | import argparse
19 | from tqdm import tqdm
20 | from dataset import ManiDataset, ManiDatasetAug
21 | 
22 | from loguru import logger
23 | 
24 | def parse_args():
25 |     parser = argparse.ArgumentParser(description='train')
26 |     parser.add_argument('--root', type=str, default='/data1/datasets/Image-Manipulation-Detection/train/')
27 |     parser.add_argument('--model', type=str, default='unet')
28 |     parser.add_argument('--pretrained-ckpt', type=str, default=None)
29 |     parser.add_argument('--th', type=float, default=0.5)
30 |     parser.add_argument('--epoch', type=int, default=200)
31 |     parser.add_argument('--batchsize', type=int, default=1)
32 |     # parser.add_argument("--model_name", type=str, help="Path to the pretrained model", default="ckpt/mvssnet.pth")
33 |     parser.add_argument('--work-dir', type=str, default='./work_dir/')
34 |     args = parser.parse_args()
35 |     return args
36 | 
37 | def 
init_model(model_type, pretrained_ckpt=None):
38 |     if model_type == 'mvssnet':
39 |         model = get_mvss(backbone='resnet50',
40 |                          pretrained_base=True,
41 |                          nclass=1,
42 |                          sobel=True,
43 |                          constrain=True,
44 |                          n_input=3,
45 |                          )
46 |         # TODO: initialize with pretrained_ckpt
47 |         # checkpoint = torch.load("./ckpt/mvssnet_tianchi.pt", map_location='cpu')
48 |     elif model_type == 'rrunet':
49 |         model = Ringed_Res_Unet(n_channels=3, n_classes=1)
50 |         # checkpoint = torch.load("./work_dir/rru-diceloss-fold4/weights/last.pth.tar", map_location='cpu')
51 |         # model.load_state_dict(checkpoint, strict=True)
52 |     elif model_type == 'rrunet-slim':
53 |         model = Ringed_Res_Unet_Slim(n_channels=3, n_classes=1)
54 |     elif model_type == 'unet':
55 |         model = smp.Unet('tu-eca_swinnext26ts_256', classes=1, activation='sigmoid')
56 |     elif model_type == 'linknet':
57 |         model = smp.Linknet('efficientnet-b5', classes=1, activation='sigmoid')
58 | 
59 |     if pretrained_ckpt is not None:
60 |         checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
61 |         model.load_state_dict(checkpoint, strict=True)
62 |     model = model.cuda()
63 |     return model
64 | 
65 | def train_mvss(model, trainloader, optimizer):
66 |     losses = AverageMeter()
67 |     f1_score = AverageMeter()
68 |     ious = AverageMeter()
69 | 
70 |     model.train()
71 | 
72 |     dice_loss = DiceLoss()
73 |     cls_loss = ClsLoss()
74 |     alpha = 0.16
75 |     beta = 0.04
76 |     edge_loss = EdgeLoss()
77 |     pixel_bceloss = PixelClsLoss()
78 | 
79 |     for img, label in tqdm(trainloader):
80 |         img, label = img.cuda(), label.cuda()
81 |         # print(img.shape, label.shape)
82 |         edge, seg = model(img)
83 |         # print(edge.shape, seg.shape)
84 |         seg = torch.sigmoid(seg)
85 |         loss = alpha*dice_loss(seg, label) + beta*cls_loss(seg, label) + (1-alpha-beta)*edge_loss(edge, label)
86 |         # loss = pixel_bceloss(seg, label)
87 |         f1, iou = calculate_batch_score(seg, label)
88 |         optimizer.zero_grad()
89 |         loss.backward()
90 |         optimizer.step()
91 | 
92 |         losses.update(loss.item(), img.shape[0])
93 |         f1_score.update(f1, img.shape[0])
94 |         ious.update(iou, img.shape[0])
95 | 
96 |     return losses.avg, f1_score.avg, ious.avg
97 | 
98 | def calculate_iou(segs, labels, eps=1e-8):
99 |     intersection = segs * labels
100 |     iou = intersection.sum()/(segs.sum()+labels.sum()-intersection.sum() + eps)
101 |     return iou
102 | 
103 | def calculate_batch_score(segs, labels):
104 |     batch_size = segs.shape[0]
105 |     batch_f1, batch_iou = 0.0, 0.0
106 |     for i in range(batch_size):
107 |         pd = segs[i]
108 |         gt = labels[i]
109 |         fake_seg = pd.detach().cpu().numpy()
110 |         fake_gt = gt.detach().cpu().numpy()
111 |         fake_seg = np.where(fake_seg<0.5, 0.0, 1.0)
112 |         # print(fake_seg.shape, fake_gt.shape)
113 |         f1, p, r = calculate_pixel_f1(fake_seg.flatten(),fake_gt.flatten())
114 |         batch_f1 += f1
115 |         iou = calculate_iou(fake_seg.flatten(),fake_gt.flatten())
116 |         batch_iou += iou
117 |     return batch_f1 / batch_size, batch_iou / batch_size
118 | 
119 | 
120 | def val_mvss(model, valloader):
121 |     losses = AverageMeter()
122 |     scores = AverageMeter()
123 |     ious = AverageMeter()
124 |     model.eval()
125 | 
126 |     dice_loss = DiceLoss()
127 |     alpha = 0.16
128 |     cls_loss = ClsLoss()
129 |     beta = 0.04
130 |     edge_loss = EdgeLoss()
131 | 
132 |     pixel_bceloss = PixelClsLoss()
133 | 
134 |     for img, label in tqdm(valloader):
135 |         img, label = img.cuda(), label.cuda()
136 |         # print(img.shape, label.shape)
137 |         with torch.no_grad():
138 |             edge, seg = model(img)
139 |             seg = torch.sigmoid(seg)
140 |             loss = alpha*dice_loss(seg, label).item() + beta*cls_loss(seg,label).item() + 
(1-alpha-beta)*edge_loss(edge, label).item()
141 |             # loss = pixel_bceloss(seg, label).item()
142 |         losses.update(loss, 1)
143 |         f1, iou = calculate_batch_score(seg, label)
144 |         scores.update(f1, 1)
145 |         ious.update(iou, 1)
146 | 
147 |     return losses.avg, scores.avg, ious.avg
148 | 
149 | def train_rru(model, trainloader, optimizer):
150 |     losses = AverageMeter()
151 |     f1_score = AverageMeter()
152 |     ious = AverageMeter()
153 | 
154 |     model.train()
155 | 
156 |     dice_loss = DiceLoss()
157 |     bce_loss = PixelClsLoss()
158 | 
159 |     for img, label in tqdm(trainloader):
160 |         img, label = img.cuda(), label.cuda()
161 |         # print(img.shape, label.shape)
162 |         seg = model(img)
163 |         # print(edge.shape, seg.shape)
164 |         seg = torch.sigmoid(seg)
165 |         loss = dice_loss(seg, label)*0.3 + bce_loss(seg, label)*0.7
166 | 
167 |         f1, iou = calculate_batch_score(seg, label)
168 |         optimizer.zero_grad()
169 |         loss.backward()
170 |         optimizer.step()
171 | 
172 |         losses.update(loss.item(), img.shape[0])
173 |         f1_score.update(f1, img.shape[0])
174 |         ious.update(iou, img.shape[0])
175 | 
176 |     return losses.avg, f1_score.avg, ious.avg
177 | 
178 | def val_rru(model, valloader):
179 |     losses = AverageMeter()
180 |     scores = AverageMeter()
181 |     ious = AverageMeter()
182 |     model.eval()
183 | 
184 |     dice_loss = DiceLoss()
185 |     # dice_loss = PixelClsLoss()
186 | 
187 |     for img, label in tqdm(valloader):
188 |         img, label = img.cuda(), label.cuda()
189 |         # print(img.shape, label.shape)
190 |         with torch.no_grad():
191 |             seg = model(img)
192 |             seg = torch.sigmoid(seg)
193 |             loss = dice_loss(seg, label).item()
194 | 
195 |         losses.update(loss, 1)
196 |         f1, iou = calculate_batch_score(seg, label)
197 |         scores.update(f1, 1)
198 |         ious.update(iou, 1)
199 | 
200 |     return losses.avg, scores.avg, ious.avg
201 | 
202 | def train_unet(model, trainloader, optimizer):
203 |     losses = AverageMeter()
204 |     f1_score = AverageMeter()
205 |     ious = AverageMeter()
206 | 
207 |     global best_score
208 | 
209 |     model.train()
210 | 
211 |     dice_loss = DiceLoss()
212 |     bce_loss = PixelClsLoss()
213 | 
214 |     i = 0
215 | 
216 |     for img, label in tqdm(trainloader):
217 |         img, label = img.cuda(), label.cuda()
218 |         # print(img.shape, label.shape)
219 |         seg = model(img)
220 |         # # print(edge.shape, seg.shape)
221 |         # seg = torch.sigmoid(seg)
222 |         loss = dice_loss(seg, label)*0.3 + bce_loss(seg, label)*0.7
223 |         # loss = lovasz_softmax(seg, label)
224 | 
225 |         f1, iou = calculate_batch_score(seg, label)
226 |         optimizer.zero_grad()
227 |         loss.backward()
228 |         optimizer.step()
229 | 
230 |         losses.update(loss.item(), img.shape[0])
231 |         f1_score.update(f1, img.shape[0])
232 |         ious.update(iou, img.shape[0])
233 | 
234 |     return losses.avg, f1_score.avg, ious.avg
235 | 
236 | def val_unet(model, valloader):
237 |     losses = AverageMeter()
238 |     scores = AverageMeter()
239 |     ious = AverageMeter()
240 |     model.eval()
241 | 
242 |     dice_loss = DiceLoss()
243 |     # dice_loss = PixelClsLoss()
244 | 
245 |     for img, label in tqdm(valloader):
246 |         img, label = img.cuda(), label.cuda()
247 |         # print(img.shape, label.shape)
248 |         with torch.no_grad():
249 |             seg = model(img)
250 |             # seg = torch.sigmoid(seg)
251 |             loss = dice_loss(seg, label).item()
252 | 
253 |         losses.update(loss, 1)
254 |         f1, iou = calculate_batch_score(seg, label)
255 |         scores.update(f1, 1)
256 |         ious.update(iou, 1)
257 | 
258 |     return losses.avg, scores.avg, ious.avg
259 | 
260 | def cal_sliding_params(img_h, img_w):
261 |     # work out how many 512-px tiles are needed along each axis
262 |     col, row = 1, 1
263 |     while (512*col - (col-1)*128) < img_h:
264 |         col += 1
265 |     while 
(512*row - (row-1)*128) < img_w:
266 |         row += 1
267 |     return col, row
268 | 
269 | def img_slide_window(img, col, row):
270 |     imgs = []
271 |     # compute the overlap between adjacent tiles
272 |     delta_x, delta_y = 0, 0
273 |     if row > 1:
274 |         delta_x = int((row*512-img.shape[-1])/(row-1))
275 |     if col > 1:
276 |         delta_y = int((col*512-img.shape[-2])/(col-1))
277 | 
278 |     for i in range(col):
279 |         for j in range(row):
280 |             begin_h = 512*i - max(0, i)*delta_y
281 |             begin_w = 512*j - max(0, j)*delta_x
282 | 
283 |             if begin_h + 512 > img.shape[-2]:
284 |                 begin_h = img.shape[-2] - 512
285 |             if begin_w + 512 > img.shape[-1]:
286 |                 begin_w = img.shape[-1] - 512
287 |             slide = img[:, :, begin_h:begin_h+512, begin_w:begin_w+512].squeeze(0)
288 |             imgs.append(slide)
289 |             # print(begin_h, begin_w, begin_h+512, begin_w+512, img.shape)
290 |     return torch.stack(imgs, dim=0)
291 | 
292 | def merge_slides_result(segs, col, row, img_shape):
293 |     count = torch.zeros([1, img_shape[2], img_shape[3]]).cuda()
294 |     seg = torch.zeros([1, img_shape[2], img_shape[3]]).cuda()
295 | 
296 |     # compute the overlap between adjacent tiles
297 |     delta_x, delta_y = 0, 0
298 |     if row > 1:
299 |         delta_x = int((row*512-img_shape[-1])/(row-1))
300 |     if col > 1:
301 |         delta_y = int((col*512-img_shape[-2])/(col-1))
302 | 
303 |     # print(col, row)
304 |     for i in range(col):
305 |         for j in range(row):
306 |             begin_h = 512*i - max(0, i)*delta_y
307 |             begin_w = 512*j - max(0, j)*delta_x
308 | 
309 |             if begin_h + 512 > img_shape[-2]:
310 |                 begin_h = img_shape[-2] - 512
311 |             if begin_w + 512 > img_shape[-1]:
312 |                 begin_w = img_shape[-1] - 512
313 |             seg[:, begin_h:begin_h+512, begin_w:begin_w+512] += segs[i*row+j]
314 |             count[:, begin_h:begin_h+512, begin_w:begin_w+512] += 1.0
315 |     seg = seg / count
316 |     return seg.unsqueeze(0)
317 | 
318 | def slide_window_val_unet(model, valloader):
319 |     losses = AverageMeter()
320 |     scores = AverageMeter()
321 |     ious = AverageMeter()
322 |     model.eval()
323 | 
324 |     dice_loss = DiceLoss()
325 |     # dice_loss = PixelClsLoss()
326 | 
327 |     for img, label in tqdm(valloader):
328 |         img, label = img.cuda(), label.cuda() #[3,h,w]
329 |         # print(img.shape, label.shape)
330 |         with torch.no_grad():
331 |             img_h, img_w = img.shape[-2], img.shape[-1]
332 |             col, row = cal_sliding_params(img_h, img_w)
333 |             imgs = img_slide_window(img, col, row)
334 |             # print(imgs.shape)
335 |             seg = model(imgs)
336 |             # seg = torch.sigmoid(seg)
337 |             seg = merge_slides_result(seg, col, row, img.shape)
338 |             # print(seg.shape, label.shape)
339 |             loss = dice_loss(seg, label).item()
340 |             # print(loss)
341 |             # exit()
342 | 
343 |         losses.update(loss, 1)
344 |         f1, iou = calculate_batch_score(seg, label)
345 |         scores.update(f1, 1)
346 |         ious.update(iou, 1)
347 |     # print(losses.avg)
348 |     return losses.avg, scores.avg, ious.avg
349 | 
350 | def save_checkpoint(state, work_dir, name):
351 |     filepath = os.path.join(work_dir, "weights/", name + '.pth.tar')
352 |     torch.save(state, filepath)
353 | 
354 | def set_work_dir(work_dir):
355 |     if not os.path.exists(work_dir):
356 |         os.mkdir(work_dir)
357 |         os.mkdir(os.path.join(work_dir, "weights/"))
358 |     logger.add(os.path.join(work_dir, "info.log"))
359 |     logger.info("Training begin")
360 | 
361 | if __name__ == '__main__':
362 |     args = parse_args()
363 | 
364 |     set_work_dir(args.work_dir)
365 | 
366 |     # init model
367 |     model = init_model(args.model, args.pretrained_ckpt)
368 | 
369 |     root = ["/data1/datasets/Image-Manipulation-Detection/train/", ]
370 |     split = ["train-split0.txt"]
371 | 
372 |     # init dataset
373 |     trainset = ManiDatasetAug(root, split=split, w=512, h=512, 
mode='train') 374 | trainloader = data.DataLoader(trainset, batch_size=args.batchsize, shuffle=True, num_workers=4) 375 | valset = ManiDataset([args.root], split=["val-split0.txt"], w=512, h=512, mode='val') 376 | valloader = data.DataLoader(valset, batch_size=args.batchsize*2, shuffle=False, num_workers=4) 377 | # print(len(trainset),len(valset)) 378 | 379 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9,0.999), eps=1e-08, weight_decay=0.0) 380 | # optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, betas=(0.9,0.999), eps=1e-08, weight_decay=0.005) 381 | 382 | global best_score 383 | best_score = 0.0 384 | for epoch in range(args.epoch): 385 | # epoch_loss, train_f1, train_iou = train_rru(model, trainloader, optimizer) 386 | epoch_loss, train_f1, train_iou = train_unet(model, trainloader, optimizer) 387 | 388 | eval_loss, eval_f1, eval_iou = val_unet(model, valloader) 389 | # eval_loss, eval_f1, eval_iou = val_rru(model, valloader) 390 | # eval_loss, eval_f1, eval_iou = slide_window_val_unet(model, valloader) 391 | 392 | logger.info("epoch: {}, train loss: {}, train f1: {}, train iou: {}, train score: {}, val loss: {}, val f1: {}, val iou: {}, val score: {}".format(epoch, epoch_loss, train_f1, train_iou, train_iou+train_f1, eval_loss, eval_f1, eval_iou, eval_iou+eval_f1)) 393 | save_checkpoint(model.state_dict(), args.work_dir, "last") 394 | 395 | if(eval_f1 + eval_iou > best_score): 396 | best_score = eval_f1 + eval_iou 397 | save_checkpoint(model.state_dict(), args.work_dir, "best") 398 | logger.info("save_checkpoint, best_score: {}".format(best_score)) -------------------------------------------------------------------------------- /code/th-search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | 9 | from models.mvssnet import get_mvss 10 | from models.unet_model import Ringed_Res_Unet, Ringed_Res_Unet_Slim 11 | import segmentation_models_pytorch as smp 12 | # from common.tools import inference_single 13 | from common.utils import calculate_pixel_f1, calculate_img_score, AverageMeter 14 | from loss import ClsLoss, DiceLoss, PixelClsLoss, EdgeLoss 15 | from loss import * 16 | 17 | import torchvision.transforms as transforms 18 | import torch.utils.data as data 19 | import argparse 20 | from tqdm import tqdm 21 | from dataset import ManiDataset 22 | 23 | from loguru import logger 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='train') 27 | parser.add_argument('--root', type=str, default='/data1/datasets/Image-Manipulation-Detection/train/') 28 | parser.add_argument('--model', type=str, default='None') 29 | parser.add_argument('--th', type=float, default=0.5) 30 | parser.add_argument('--epoch', type=int, default=200) 31 | parser.add_argument('--batchsize', type=int, default=1) 32 | # parser.add_argument("--model_name", type=str, help="Path to the pretrained model", default="ckpt/mvssnet.pth") 33 | parser.add_argument('--work-dir', type=str, default='./work_dir/') 34 | parser.add_argument('--weights', type=str, default='None') 35 | args = parser.parse_args() 36 | return args 37 | 38 | def init_model(model_type, pretrained_ckpt=None): 39 | if model_type == 'mvssnet': 40 | model = get_mvss(backbone='resnet50', 41 | pretrained_base=True, 42 | nclass=1, 43 | sobel=True, 44 | constrain=True, 45 | n_input=3, 46 | ) 47 | # TODO: initialize with 
pretrained_ckpt
48 |         # checkpoint = torch.load("./ckpt/mvssnet_tianchi.pt", map_location='cpu')
49 |     elif model_type == 'rrunet':
50 |         model = Ringed_Res_Unet(n_channels=3, n_classes=1)
51 |         # checkpoint = torch.load("./work_dir/rru-diceloss-fold4/weights/last.pth.tar", map_location='cpu')
52 |         # model.load_state_dict(checkpoint, strict=True)
53 |     elif model_type == 'rrunet-slim':
54 |         model = Ringed_Res_Unet_Slim(n_channels=3, n_classes=1)
55 |     elif model_type == 'unet':
56 |         model = smp.Unet('efficientnet-b5', classes=1, activation='sigmoid')
57 |     elif model_type == 'linknet':
58 |         model = smp.Linknet('efficientnet-b5', classes=1, activation='sigmoid')
59 | 
60 |     if pretrained_ckpt is not None:
61 |         checkpoint = torch.load(pretrained_ckpt, map_location='cpu')
62 |         model.load_state_dict(checkpoint, strict=True)
63 |     model = model.cuda()
64 |     return model
65 | 
66 | def train_mvss(model, trainloader, optimizer):
67 |     losses = AverageMeter()
68 |     f1_score = AverageMeter()
69 |     ious = AverageMeter()
70 | 
71 |     model.train()
72 | 
73 |     dice_loss = DiceLoss()
74 |     cls_loss = ClsLoss()
75 |     alpha = 0.16
76 |     beta = 0.04
77 |     edge_loss = EdgeLoss()
78 |     pixel_bceloss = PixelClsLoss()
79 | 
80 |     for img, label in tqdm(trainloader):
81 |         img, label = img.cuda(), label.cuda()
82 |         # print(img.shape, label.shape)
83 |         edge, seg = model(img)
84 |         # print(edge.shape, seg.shape)
85 |         seg = torch.sigmoid(seg)
86 |         loss = alpha*dice_loss(seg, label) + beta*cls_loss(seg, label) + (1-alpha-beta)*edge_loss(edge, label)
87 |         # loss = pixel_bceloss(seg, label)
88 |         f1, iou = calculate_batch_score(seg, label)
89 |         optimizer.zero_grad()
90 |         loss.backward()
91 |         optimizer.step()
92 | 
93 |         losses.update(loss.item(), img.shape[0])
94 |         f1_score.update(f1, img.shape[0])
95 |         ious.update(iou, img.shape[0])
96 | 
97 |     return losses.avg, f1_score.avg, ious.avg
98 | 
99 | def calculate_iou(segs, labels, eps=1e-8):
100 |     intersection = segs * labels
101 |     iou = intersection.sum()/(segs.sum()+labels.sum()-intersection.sum() + eps)
102 |     return iou
103 | 
104 | def calculate_batch_score(segs, labels, th=0.5):
105 |     batch_size = segs.shape[0]
106 |     batch_f1, batch_iou = 0.0, 0.0
107 |     for i in range(batch_size):
108 |         pd = segs[i]
109 |         gt = labels[i]
110 |         fake_seg = pd.detach().cpu().numpy()
111 |         fake_gt = gt.detach().cpu().numpy()
112 |         fake_seg = np.where(fake_seg < th, 0.0, 1.0)
113 |         # print(fake_seg.shape, fake_gt.shape)
114 |         f1, p, r = calculate_pixel_f1(fake_seg.flatten(), fake_gt.flatten())
115 |         batch_f1 += f1
116 |         iou = calculate_iou(fake_seg.flatten(), fake_gt.flatten())
117 |         batch_iou += iou
118 |     return batch_f1 / batch_size, batch_iou / batch_size
207 | def img_slide_window(img, col, row):
208 |     imgs = []
209 |     # compute the overlap between adjacent tiles
210 |     delta_x, delta_y = 0, 0
211 |     if row > 1:
212 |         delta_x = int((row*512-img.shape[-1])/(row-1))
213 |     if col > 1:
214 |         delta_y = int((col*512-img.shape[-2])/(col-1))
215 | 
216 |     for i in range(col):
217 |         for j in range(row):
218 |             begin_h = 512*i - max(0, i)*delta_y
219 |             begin_w = 512*j - max(0, j)*delta_x
220 | 
221 |             if begin_h + 512 > img.shape[-2]:
222 |                 begin_h = img.shape[-2] - 512
223 |             if begin_w + 512 > img.shape[-1]:
224 |                 begin_w = img.shape[-1] - 512
225 |             slide = img[:, :, begin_h:begin_h+512, begin_w:begin_w+512].squeeze(0)
226 |             imgs.append(slide)
227 |             # print(begin_h, begin_w, begin_h+512, begin_w+512, img.shape)
228 |     return torch.stack(imgs, dim=0)
229 | 
230 | def merge_slides_result(segs, col, row, img_shape):
231 |     count = torch.zeros([1, img_shape[2], img_shape[3]]).cuda()
232 |     seg = torch.zeros([1, img_shape[2], img_shape[3]]).cuda()
233 | 
234 |     # compute the overlap between adjacent tiles
235 |     delta_x, delta_y = 0, 0
236 |     if row > 1:
237 |         delta_x = int((row*512-img_shape[-1])/(row-1))
238 |     if col > 1:
239 |         delta_y = int((col*512-img_shape[-2])/(col-1))
240 | 
241 |     # print(col, row)
242 |     for i in range(col):
243 |         for j in range(row):
244 |             begin_h = 512*i - max(0, i)*delta_y
245 |             begin_w = 512*j - max(0, 
j)*delta_x
246 | 
247 |             if begin_h + 512 > img_shape[-2]:
248 |                 begin_h = img_shape[-2] - 512
249 |             if begin_w + 512 > img_shape[-1]:
250 |                 begin_w = img_shape[-1] - 512
251 |             seg[:, begin_h:begin_h+512, begin_w:begin_w+512] += segs[i*row+j]
252 |             count[:, begin_h:begin_h+512, begin_w:begin_w+512] += 1.0
253 |     seg = seg / count
254 |     return seg.unsqueeze(0)
255 | 
256 | def slide_window_val_unet(model, valloader, th=0.5):
257 |     losses = AverageMeter()
258 |     scores = AverageMeter()
259 |     ious = AverageMeter()
260 |     model.eval()
261 | 
262 |     dice_loss = DiceLoss()
263 |     # dice_loss = PixelClsLoss()
264 | 
265 |     for img, label in tqdm(valloader):
266 |         img, label = img.cuda(), label.cuda() #[3,h,w]
267 |         # print(img.shape, label.shape)
268 |         with torch.no_grad():
269 |             img_h, img_w = img.shape[-2], img.shape[-1]
270 |             col, row = cal_sliding_params(img_h, img_w)
271 |             imgs = img_slide_window(img, col, row)
272 |             # print(imgs.shape)
273 |             seg = model(imgs)
274 |             # seg = torch.sigmoid(seg)
275 |             seg = merge_slides_result(seg, col, row, img.shape)
276 |             loss = dice_loss(seg, label).item()
277 | 
278 |         losses.update(loss, 1)
279 |         f1, iou = calculate_batch_score(seg, label, th)
280 |         scores.update(f1, 1)
281 |         ious.update(iou, 1)
282 | 
283 |     return losses.avg, scores.avg, ious.avg
284 | 
285 | def TTA_val(model, valloader, th=0.5, alpha=0.3):
286 |     losses = AverageMeter()
287 |     scores = AverageMeter()
288 |     ious = AverageMeter()
289 |     model.eval()
290 | 
291 |     dice_loss = DiceLoss()
292 |     # dice_loss = PixelClsLoss()
293 | 
294 |     for img, label in tqdm(valloader):
295 |         img, label = img.cuda(), label.cuda() #[3,h,w]
296 |         # print(img.shape, label.shape)
297 | 
298 |         # sliding-window pass
299 |         with torch.no_grad():
300 |             img_h, img_w = img.shape[-2], img.shape[-1]
301 |             col, row = cal_sliding_params(img_h, img_w)
302 |             imgs = img_slide_window(img, col, row)
303 |             # print(imgs.shape)
304 |             seg = model(imgs)
305 |             # seg = torch.sigmoid(seg)
306 |             seg = merge_slides_result(seg, col, row, img.shape)
307 |             loss = dice_loss(seg, label).item()
308 | 
309 |         # whole-image resize pass
310 |         transform = transforms.Compose([
311 |             transforms.Resize([512,512])
312 |         ])
313 |         invtransform = transforms.Compose([
314 |             transforms.Resize([img.shape[-2], img.shape[-1]])
315 |         ])
316 |         with torch.no_grad():
317 |             seg_resize = model(transform(img))
318 |             seg_resize = invtransform(seg_resize)
319 |             seg = seg*alpha + seg_resize*(1-alpha)
320 | 
321 |         # min-max normalization
322 |         seg = seg - seg.min()
323 |         seg = seg / seg.max()
324 | 
325 |         losses.update(loss, 1)
326 |         f1, iou = calculate_batch_score(seg, label, th)
327 |         scores.update(f1, 1)
328 |         ious.update(iou, 1)
329 | 
330 |     return losses.avg, scores.avg, ious.avg
331 | 
332 | def TTA_val2(model, model2, valloader, th=0.5, alpha=0.3):
333 |     # model: sliding-window branch
334 |     # model2: whole-image resize branch
335 | 
336 |     losses = AverageMeter()
337 |     scores = AverageMeter()
338 |     ious = AverageMeter()
339 |     model.eval()
340 |     model2.eval()
341 | 
342 |     dice_loss = DiceLoss()
343 |     # dice_loss = PixelClsLoss()
344 | 
345 |     for img, label in tqdm(valloader):
346 |         img, label = img.cuda(), label.cuda() #[3,h,w]
347 |         # print(img.shape, label.shape)
348 | 
349 |         # sliding-window pass
350 |         with torch.no_grad():
351 |             img_h, img_w = img.shape[-2], img.shape[-1]
352 |             col, row = cal_sliding_params(img_h, img_w)
353 |             imgs = img_slide_window(img, col, row)
354 |             # print(imgs.shape)
355 |             seg = model(imgs)
356 |             # seg = torch.sigmoid(seg)
357 |             seg = merge_slides_result(seg, col, row, img.shape)
358 |             loss = dice_loss(seg, label).item()
359 | 
360 |         # whole-image resize pass
361 | 
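# the resize branch below runs the second model on the whole image; the final
# map blends the two views as alpha*sliding_window + (1-alpha)*resize, so
# alpha=0.3 leans 70% on the whole-image prediction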
transform = transforms.Compose([
362 |             transforms.Resize([512,512])
363 |         ])
364 |         invtransform = transforms.Compose([
365 |             transforms.Resize([img.shape[-2], img.shape[-1]])
366 |         ])
367 |         with torch.no_grad():
368 |             seg_resize = model2(transform(img))
369 |             seg_resize = invtransform(seg_resize)
370 |             seg = seg*alpha + seg_resize*(1-alpha)
371 | 
372 | 
373 |         losses.update(loss, 1)
374 |         f1, iou = calculate_batch_score(seg, label, th)
375 |         scores.update(f1, 1)
376 |         ious.update(iou, 1)
377 | 
378 |     return losses.avg, scores.avg, ious.avg
379 | 
380 | 
381 | if __name__ == '__main__':
382 |     args = parse_args()
383 | 
384 |     # init model
385 |     model = init_model(args.model, args.weights)
386 |     # model = init_model('unet', "./work_dir/2859-unet-5fold2-sw512-aug/weights/best.pth.tar")
387 |     # model1 = init_model('unet', "./work_dir/2859-unet-5fold0-aug/weights/best.pth.tar")
388 | 
389 |     # root = ["/data1/datasets/Image-Manipulation-Detection/train-slide-window-512/"]
390 |     # split = ["train-split4-slide-window-512.txt",]
391 | 
392 |     root = ["/data1/datasets/Image-Manipulation-Detection/train/", ]
393 |     split = ["train-split2.txt"]
394 | 
395 |     # init dataset
396 |     trainset = ManiDataset(root, split=split, w=512, h=512, mode='train')
397 |     trainloader = data.DataLoader(trainset, batch_size=args.batchsize, shuffle=True, num_workers=4)
398 |     valset = ManiDataset([args.root], split=["val-split2.txt"], w=512, h=512, mode='val-rawsize')
399 |     valloader = data.DataLoader(valset, batch_size=args.batchsize, shuffle=False, num_workers=4)
400 |     # print(len(trainset),len(valset))
401 | 
402 |     best_score = 0.0
403 |     ths = [0.1, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
404 |     # ths = [0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]
405 |     # ths = [0.3]
406 |     for th in ths:
407 |         # eval_loss, eval_f1, eval_iou = val_unet(model, valloader, th=th)
408 |         # eval_loss, eval_f1, eval_iou = val_rru(model, valloader)
409 |         # eval_loss, eval_f1, eval_iou = slide_window_val_unet(model, valloader, th=th)
410 |         eval_loss, eval_f1, eval_iou = TTA_val(model, valloader, th=th, alpha=0.3)
411 |         # eval_loss, eval_f1, eval_iou = TTA_val2(model, model1, valloader, th=0.3, alpha=th)
412 |         print("th: {}, loss: {}, f1: {}, iou: {}, score: {}".format(th, eval_loss, eval_f1, eval_iou, eval_f1+eval_iou))
--------------------------------------------------------------------------------
/code/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 | from PIL import Image
10 | import cv2
11 | import os
12 | import random
13 | import albumentations as albu
15 | import math
16 | 
17 | class ManiDataset(torch.utils.data.Dataset):
18 |     def __init__(self, root, split, w=512, h=512, mode="train", resize=True, crop=False):
19 |         # print(root, split)
20 |         assert len(root) == len(split)
21 | 
22 |         self.splits = []
23 |         for i, s in enumerate(split):
24 |             self.splits.append(os.path.join(root[i], s))
25 | 
26 |         self.roots = root
27 |         self.imgs = []
28 |         self.labels = []
29 |         self.w = w
30 |         self.h = h
31 | 
32 |         self.mode = mode
33 |         self.resize = resize
34 |         self.crop = crop
35 | 
36 |         self.setup(self.roots, self.splits)
37 | 
38 |         self.transform = transforms.Compose([
39 |             transforms.ToTensor(),
40 | 
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
41 |         ])
42 |         self.mask_transform = transforms.Compose([
43 |             transforms.ToTensor()
44 |         ])
45 | 
46 |         if self.mode == 'train' or self.mode == 'semi-train':
47 |             self.albu = albu.Compose([
48 |                 albu.RandomBrightnessContrast(p=0.5),
49 |                 albu.OneOf([
50 |                     albu.ImageCompression(quality_lower=20, quality_upper=50, p=0.5),
51 |                     albu.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=0.5),
52 |                     albu.GaussNoise(var_limit=(100, 150), p=0.5)
53 |                 ], p=0.5),
54 |                 albu.OneOf([
55 |                     albu.Rotate(limit=[90,90], p=0.5),
56 |                     albu.Rotate(limit=[270,270], p=0.5),
57 |                 ], p=0.5),
58 |                 albu.RandomCrop(self.h, self.w, p=1),
59 |                 albu.Resize(self.h, self.w, p=1),
60 |             ])
61 |         elif self.mode == 'val':
62 |             self.albu = albu.Compose([
63 |                 albu.Resize(self.h, self.w, p=1),
64 |             ])
65 | 
66 |     def setup(self, roots, splits):
67 |         for i, split in enumerate(splits):
68 |             root = roots[i]
69 |             with open(split, 'r') as f:
70 |                 while True:
71 |                     line = f.readline().strip("\n")
72 |                     if line:
73 |                         self.imgs.append(os.path.join(root, "img/", line+".jpg"))
74 |                         if self.mode == 'train' or self.mode == 'val' or self.mode == 'val-rawsize':
75 |                             self.labels.append(os.path.join(root, "anno/", line+".png"))
76 |                     else:
77 |                         break
78 |         # print(self.imgs)
79 | 
80 |     def __getitem__(self, index):
81 |         fake = cv2.imread(self.imgs[index])
82 |         fake = cv2.cvtColor(fake, cv2.COLOR_BGR2RGB)
83 |         fake_size = fake.shape
84 | 
85 |         if self.mode == 'train' or self.mode == 'val':
86 |             label = cv2.imread(self.labels[index], cv2.IMREAD_GRAYSCALE)
87 |             if(fake_size[0] < self.h or fake_size[1] < self.w):
88 |                 fake = cv2.resize(fake, (self.h,self.w))
89 |                 label = cv2.resize(label, (self.h,self.w))
90 |             else:
91 |                 augmented = self.albu(image=fake, mask=label)
92 |                 fake, label = augmented['image'], augmented['mask']
93 |             _, label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
94 | 
95 |             img = self.transform(fake)
96 |             label = self.mask_transform(label)
97 |             return img.float(), label.float()
98 | 
99 |         elif self.mode == "val-rawsize":
100 |             label = cv2.imread(self.labels[index], cv2.IMREAD_GRAYSCALE)
101 |             _, label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
102 | 
103 |             img = self.transform(fake)
104 |             label = self.mask_transform(label)
105 |             return img.float(), label.float()
106 | 
107 |         elif self.mode == 'test':
108 |             if self.resize:
109 |                 fake = cv2.resize(fake, (self.h, self.w))
110 |             img = self.transform(fake)
111 |             img_name = self.imgs[index].split("/")[-1].split('.')[0]
112 |             return img.float(), img_name, fake_size
113 |         elif self.mode == 'semi-train':
114 |             augmented = self.albu(image=fake)
115 |             fake = augmented['image']
116 |             img = self.transform(fake)
117 |             return img.float()
118 | 
119 |     def __len__(self):
120 |         return len(self.imgs)
121 | 
122 | def cal_new_mask(new_img, img, mask):
123 |     """
124 |     new_img: the image after secondary tampering
125 |     img: the original image from the training set
126 |     mask: the mask before secondary tampering, 0-255
127 |     """
128 |     diff_img = cv2.absdiff(new_img, img)
129 |     diff = np.linalg.norm(diff_img, ord=np.inf, axis=2)
130 |     # print(diff.shape, mask.shape)
131 |     _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
132 | 
133 |     new_mask = diff + mask
134 |     new_mask = np.clip(new_mask, 0, 255)
135 | 
136 |     return new_mask
137 | 
138 | def rand_bbox(size):
139 |     # size is in OpenCV (h, w) order
140 |     W = size[1]
141 |     H = size[0]
142 | 
143 |     cut_rat_w = random.random()*0.1 + 0.05
144 |     cut_rat_h = random.random()*0.1 + 0.05
145 | 
146 |     cut_w = int(W * cut_rat_w)
147 |     cut_h = int(H * cut_rat_h)
148 | 
149 |     cx = 
np.random.randint(W)
150 |     cy = np.random.randint(H)
151 | 
152 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)  # top-left
153 |     bby1 = np.clip(cy - cut_h // 2, 0, H)  # top-left
154 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)  # bottom-right
155 |     bby2 = np.clip(cy + cut_h // 2, 0, H)  # bottom-right
156 | 
157 |     return bbx1, bby1, bbx2, bby2
158 | 
159 | def copy_move(img1, img2, msk, is_plot=False):
160 |     img = img1.copy()
161 |     size = img.shape # h,w,c
162 |     W = size[1]
163 |     H = size[0]
164 | 
165 |     if img2 is None: # copy-paste from the image itself
166 |         bbx1, bby1, bbx2, bby2 = rand_bbox(img.shape)
167 | 
168 |         x_move = random.randrange(-bbx1, (W - bbx2))
169 |         y_move = random.randrange(-bby1, (H - bby2))
170 | 
171 |         img[bby1+y_move:bby2+y_move, bbx1+x_move:bbx2+x_move, :] = img[bby1:bby2, bbx1:bbx2, :]
172 | 
173 |     else: # copy-paste from another image
174 |         bbx1, bby1, bbx2, bby2 = rand_bbox(img2.shape)
175 | 
176 |         x_move = random.randrange(-bbx1, (W - bbx2))
177 |         y_move = random.randrange(-bby1, (H - bby2))
178 | 
179 |         img[bby1+y_move:bby2+y_move, bbx1+x_move:bbx2+x_move, :] = img2[bby1:bby2, bbx1:bbx2, :]
180 | 
181 |     """
182 |     Changed dave's code here, which derived the mask directly from the pasted region, because sometimes an identical-looking region is cropped and pasted back;
183 |     instead the re-tampered image is diffed against the original, and the changed pixels are overlaid onto the old mask.
184 |     """
185 |     msk = cal_new_mask(img, img1, msk)
186 |     if is_plot: # outline the re-tampered region, mainly for debugging; remember to set this to False when generating images
187 |         img = cv2.rectangle(img, pt1=[bbx1+x_move, bby1+y_move], pt2=[bbx2+x_move, bby2+y_move], color=(255,0,0), thickness=3)
188 | 
189 |     return np.uint8(img), np.uint8(msk)
190 | 
191 | def erase(img1, msk, is_plot=False):
192 |     img = img1.copy()
193 |     size = img.shape # h,w,c
194 |     W = size[1]
195 |     H = size[0]
196 | 
197 |     def midpoint(x1, y1, x2, y2):
198 |         x_mid = int((x1 + x2)/2)
199 |         y_mid = int((y1 + y2)/2)
200 |         return (x_mid, y_mid)
201 | 
202 |     bbx1, bby1, bbx2, bby2 = rand_bbox(img.shape)
203 |     # print(bbx1, bby1, bbx2, bby2)
204 | 
205 |     x_mid0, y_mid0 = midpoint(bbx1, bby1, bbx1, bby2)
206 |     x_mid1, y_mid1 = midpoint(bbx2, bby1, bbx2, bby2)
207 |     thickness = int(math.sqrt((bby2-bby1)**2))
208 | 
209 |     mask_ = np.zeros(img.shape[:2], dtype="uint8")
210 |     cv2.line(mask_, (x_mid0, y_mid0), (x_mid1, y_mid1), 255, thickness)
211 | 
212 |     # cv2.imwrite("mask_.jpg", mask_)
213 |     img = cv2.inpaint(img, mask_, 7, cv2.INPAINT_NS)
214 | 
215 |     msk = cal_new_mask(img1, img, msk)
216 | 
217 |     if is_plot:
218 |         img = cv2.rectangle(img, pt1=[bbx1, bby1], pt2=[bbx2, bby2], color=(255,0,0), thickness=3)
219 | 
220 |     return np.uint8(img), np.uint8(msk)
221 | 
222 | def mosaic_effect(img):
223 |     img = img.numpy().transpose(1,2,0)
224 |     h, w, n = img.shape
225 |     # size = random.randint(5, 20)  # mosaic block size
226 |     size = 9
227 |     for i in range(0, h - size, size):
228 |         for j in range(0, w - size, size):
229 |             rect = [j, i, size, size]
230 |             color = img[i, j].tolist()
231 |             left_up = (rect[0], rect[1])
232 |             right_down = (rect[0]+size, rect[1]+size)
233 |             cv2.rectangle(img, left_up, right_down, color, -1)
234 |     return torch.from_numpy(img).permute(2, 0, 1)
235 | 
236 | def mosaic(img, msk):
237 |     resize = albu.Resize(512,512)(image=img, mask=msk)
238 |     img = torch.from_numpy(resize['image']).permute(2, 0, 1)
239 |     msk = torch.from_numpy(resize['mask'])
240 |     size = img.size()
241 |     if len(size) == 4:
242 |         W = size[2]
243 |         H = size[3]
244 |     elif len(size) == 3:
245 |         W = size[1]
246 |         H = size[2]
247 |     else:
248 |         raise Exception
249 | 
250 |     bbx1, bby1, bbx2, bby2 = rand_bbox(img.size())
251 | 
252 |     x_move = random.randrange(-bbx1, (W - bbx2))
253 |     y_move = random.randrange(-bby1, (H - bby2))
254 | 
255 |     img[:, bbx1+x_move:bbx2+x_move, bby1+y_move:bby2+y_move] = 
mosaic_effect(img[:, bbx1:bbx2, bby1:bby2])
256 |     msk[bbx1+x_move:bbx2+x_move, bby1+y_move:bby2+y_move] = torch.ones_like(msk[bbx1:bbx2, bby1:bby2])*255
257 | 
258 |     # img = img.numpy().transpose(1,2,0)
259 |     img = cv2.rectangle(img.numpy().transpose(1,2,0), pt1=[bby1+y_move, bbx1+x_move], pt2=[bby2+y_move, bbx2+x_move], color=(255,0,0), thickness=3)
260 |     msk = msk.numpy()
261 | 
262 |     return img, msk
263 | 
264 | class ManiDatasetAug(torch.utils.data.Dataset):
265 |     def __init__(self, root, split, w=512, h=512, mode="train", resize=True, crop=False):
266 |         # print(root, split)
267 |         assert len(root) == len(split)
268 | 
269 |         self.splits = []
270 |         for i, s in enumerate(split):
271 |             self.splits.append(os.path.join(root[i], s))
272 | 
273 |         self.roots = root
274 |         self.imgs = []
275 |         self.labels = []
276 |         self.w = w
277 |         self.h = h
278 | 
279 |         self.mode = mode
280 |         self.resize = resize
281 |         self.crop = crop
282 | 
283 |         self.setup(self.roots, self.splits)
284 | 
285 |         self.transform = transforms.Compose([
286 |             transforms.ToTensor(),
287 |             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
288 |         ])
289 |         self.mask_transform = transforms.Compose([
290 |             transforms.ToTensor()
291 |         ])
292 | 
293 |         if self.mode == 'train' or self.mode == 'semi-train':
294 |             self.albu = albu.Compose([
295 |                 albu.RandomBrightnessContrast(p=0.5),
296 |                 albu.OneOf([
297 |                     albu.ImageCompression(quality_lower=20, quality_upper=50, p=0.5),
298 |                     albu.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=0.5),
299 |                     albu.GaussNoise(var_limit=(100, 150), p=0.5)
300 |                 ], p=0.5),
301 |                 albu.OneOf([
302 |                     albu.Rotate(limit=[90,90], p=0.5),
303 |                     albu.Rotate(limit=[270,270], p=0.5),
304 |                 ], p=0.5),
305 |                 # albu.RandomCrop(self.h, self.w, p=1),
306 |                 albu.RandomResizedCrop(512, 512, scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), p=1),
307 |                 albu.Resize(self.h, self.w, p=1),
308 |             ])
309 |         elif self.mode == 'val':
310 |             self.albu = albu.Compose([
311 |                 albu.Resize(self.h, self.w, p=1),
312 |             ])
313 | 
314 |     def setup(self, roots, splits):
315 |         for i, split in enumerate(splits):
316 |             root = roots[i]
317 |             with open(split, 'r') as f:
318 |                 while True:
319 |                     line = f.readline().strip("\n")
320 |                     if line:
321 |                         self.imgs.append(os.path.join(root, "img/", line+".jpg"))
322 |                         if self.mode == 'train' or self.mode == 'val' or self.mode == 'val-rawsize':
323 |                             self.labels.append(os.path.join(root, "anno/", line+".png"))
324 |                     else:
325 |                         break
326 |         # print(self.imgs)
327 | 
328 |     def __getitem__(self, index):
329 |         fake = cv2.imread(self.imgs[index])
330 |         fake = cv2.cvtColor(fake, cv2.COLOR_BGR2RGB)
331 |         fake_size = fake.shape
332 | 
333 |         if self.mode == 'train':
334 |             label = cv2.imread(self.labels[index], cv2.IMREAD_GRAYSCALE)
335 | 
336 |             p = random.randint(0,3)
337 |             # print(p)
338 |             if p == 2:
339 |                 # random copy-move within the image itself
340 |                 fake, label = copy_move(fake, None, label)
341 |             elif p == 3:
342 |                 # random copy-move from another image
343 |                 img2 = cv2.imread(self.imgs[random.randint(0, len(self.imgs)-1)])
344 |                 img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
345 |                 fake, label = copy_move(fake, img2, label)
346 |             elif p == 1:
347 |                 # random erase (inpainting)
348 |                 fake, label = erase(fake, label)
349 | 
350 |             # # random mosaic:
351 |             # # fake, label = mosaic(fake, label)
352 | 
353 |             if(fake_size[0] < self.h or fake_size[1] < self.w):
354 |                 fake = cv2.resize(fake, (self.h,self.w))
355 |                 label = cv2.resize(label, (self.h,self.w))
356 |             else:
357 |                 augmented = self.albu(image=fake, mask=label)
358 |                 fake, label = augmented['image'], augmented['mask']
359 |             _, 
label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
360 | 
361 |             # cv2.imwrite("aug_img.jpg", fake)
362 |             # cv2.imwrite("aug_mask.jpg", label)
363 | 
364 |             img = self.transform(fake)
365 |             label = self.mask_transform(label)
366 | 
367 |             return img.float(), label.float()
368 | 
369 |         elif self.mode == 'val':
370 |             label = cv2.imread(self.labels[index], cv2.IMREAD_GRAYSCALE)
371 |             if(fake_size[0] < self.h or fake_size[1] < self.w):
372 |                 fake = cv2.resize(fake, (self.h,self.w))
373 |                 label = cv2.resize(label, (self.h,self.w))
374 |             else:
375 |                 augmented = self.albu(image=fake, mask=label)
376 |                 fake, label = augmented['image'], augmented['mask']
377 |             _, label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
378 | 
379 |             img = self.transform(fake)
380 |             label = self.mask_transform(label)
381 |             return img.float(), label.float()
382 | 
383 |         elif self.mode == "val-rawsize":
384 |             label = cv2.imread(self.labels[index], cv2.IMREAD_GRAYSCALE)
385 |             _, label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
386 | 
387 |             img = self.transform(fake)
388 |             label = self.mask_transform(label)
389 |             return img.float(), label.float()
390 | 
391 |         elif self.mode == 'test':
392 |             if self.resize:
393 |                 fake = cv2.resize(fake, (self.h, self.w))
394 |             img = self.transform(fake)
395 |             img_name = self.imgs[index].split("/")[-1].split('.')[0]
396 |             return img.float(), img_name, fake_size
397 | 
398 |     def __len__(self):
399 |         return len(self.imgs)
400 | 
401 | if __name__ == "__main__":
402 |     """ quick test of the dataset classes """
403 |     root = "/data1/datasets/Image-Manipulation-Detection/train"
404 |     trainset = ManiDatasetAug([root], split=["train.txt"], w=512, h=512, )
405 |     trainloader = data.DataLoader(trainset, batch_size=1, shuffle=True)
406 | 
407 |     for img, mask in trainloader:
408 |         print(img.shape, mask.shape)
409 |         print(img)
410 |         print(mask.max())
411 |         break
--------------------------------------------------------------------------------
/code/models/mvssnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 | import torchvision
6 | import numpy as np
7 | 
8 | 
9 | def get_sobel(in_chan, out_chan):
10 |     filter_x = np.array([
11 |         [1, 0, -1],
12 |         [2, 0, -2],
13 |         [1, 0, -1],
14 |     ]).astype(np.float32)
15 |     filter_y = np.array([
16 |         [1, 2, 1],
17 |         [0, 0, 0],
18 |         [-1, -2, -1],
19 |     ]).astype(np.float32)
20 | 
21 |     filter_x = filter_x.reshape((1, 1, 3, 3))
22 |     filter_x = np.repeat(filter_x, in_chan, axis=1)
23 |     filter_x = np.repeat(filter_x, out_chan, axis=0)
24 | 
25 |     filter_y = filter_y.reshape((1, 1, 3, 3))
26 |     filter_y = np.repeat(filter_y, in_chan, axis=1)
27 |     filter_y = np.repeat(filter_y, out_chan, axis=0)
28 | 
29 |     filter_x = torch.from_numpy(filter_x)
30 |     filter_y = torch.from_numpy(filter_y)
31 |     filter_x = nn.Parameter(filter_x, requires_grad=False)
32 |     filter_y = nn.Parameter(filter_y, requires_grad=False)
33 |     conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
34 |     conv_x.weight = filter_x
35 |     conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
36 |     conv_y.weight = filter_y
37 |     sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
38 |     sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
39 |     return sobel_x, sobel_y
40 | 
41 | 
42 | def run_sobel(conv_x, conv_y, input):
43 |     g_x = conv_x(input)
44 |     g_y = conv_y(input)
45 |     g = torch.sqrt(torch.pow(g_x, 2) + 
torch.pow(g_y, 2)) 46 | return torch.sigmoid(g) * input 47 | 48 | 49 | def rgb2gray(rgb): 50 | b, g, r = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] 51 | gray = 0.2989*r + 0.5870*g + 0.1140*b 52 | gray = torch.unsqueeze(gray, 1) 53 | return gray 54 | 55 | 56 | class BayarConv2d(nn.Module): 57 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding=0): 58 | self.in_channels = in_channels 59 | self.out_channels = out_channels 60 | self.kernel_size = kernel_size 61 | self.stride = stride 62 | self.padding = padding 63 | self.minus1 = (torch.ones(self.in_channels, self.out_channels, 1) * -1.000) 64 | 65 | super(BayarConv2d, self).__init__() 66 | # only (kernel_size ** 2 - 1) trainable params as the center element is always -1 67 | self.kernel = nn.Parameter(torch.rand(self.in_channels, self.out_channels, kernel_size ** 2 - 1), 68 | requires_grad=True) 69 | 70 | 71 | def bayarConstraint(self): 72 | self.kernel.data = self.kernel.permute(2, 0, 1) 73 | self.kernel.data = torch.div(self.kernel.data, self.kernel.data.sum(0)) 74 | self.kernel.data = self.kernel.permute(1, 2, 0) 75 | ctr = self.kernel_size ** 2 // 2 76 | real_kernel = torch.cat((self.kernel[:, :, :ctr], self.minus1.to(self.kernel.device), self.kernel[:, :, ctr:]), dim=2) 77 | real_kernel = real_kernel.reshape((self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)) 78 | return real_kernel 79 | 80 | def forward(self, x): 81 | x = F.conv2d(x, self.bayarConstraint(), stride=self.stride, padding=self.padding) 82 | return x 83 | 84 | 85 | model_urls = { 86 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 87 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 88 | } 89 | 90 | 91 | def conv3x3(in_planes, out_planes, stride=1): 92 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 93 | padding=1, bias=False) 94 | 95 | 96 | class Bottleneck(nn.Module): 97 | expansion = 4 98 | 99 | def __init__(self, inplanes, planes, stride=1, downsample=None, rate=1): 100 | super(Bottleneck, self).__init__() 101 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(planes) 103 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 104 | padding=rate, dilation=rate, bias=False) 105 | self.bn2 = nn.BatchNorm2d(planes) 106 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 107 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.downsample = downsample 110 | self.stride = stride 111 | 112 | def forward(self, x): 113 | residual = x 114 | 115 | out = self.conv1(x) 116 | out = self.bn1(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv2(out) 120 | out = self.bn2(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv3(out) 124 | out = self.bn3(out) 125 | 126 | if self.downsample is not None: 127 | residual = self.downsample(x) 128 | 129 | out += residual 130 | out = self.relu(out) 131 | 132 | return out 133 | 134 | 135 | class ResNet(nn.Module): 136 | def __init__(self, block, layers, num_classes=1000, n_input=3): 137 | self.inplanes = 64 138 | super(ResNet, self).__init__() 139 | self.conv1 = nn.Conv2d(n_input, 64, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = nn.BatchNorm2d(64) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | 
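# layer2/layer3 halve the resolution once each; layer4 below keeps stride 1
# and uses dilation instead (DeepLabV3-style), so features stay at 1/16 scale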
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
146 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
147 |         rates = [1, 2, 4]
148 |         self.layer4 = self._make_deeplabv3_layer(block, 512, layers[3], rates=rates, stride=1)  # stride 2 => stride 1
149 |         self.avgpool = nn.AvgPool2d(7, stride=1)
150 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
151 |         for m in self.modules():
152 |             if isinstance(m, nn.Conv2d):
153 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
154 |             elif isinstance(m, nn.BatchNorm2d):
155 |                 nn.init.constant_(m.weight, 1)
156 |                 nn.init.constant_(m.bias, 0)
157 | 
158 |     def _make_layer(self, block, planes, blocks, stride=1):
159 |         downsample = None
160 |         if stride != 1 or self.inplanes != planes * block.expansion:
161 |             downsample = nn.Sequential(
162 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
163 |                           kernel_size=1, stride=stride, bias=False),
164 |                 nn.BatchNorm2d(planes * block.expansion),
165 |             )
166 | 
167 |         layers = []
168 |         layers.append(block(self.inplanes, planes, stride, downsample))
169 |         self.inplanes = planes * block.expansion
170 |         for i in range(1, blocks):
171 |             layers.append(block(self.inplanes, planes))
172 | 
173 |         return nn.Sequential(*layers)
174 | 
175 |     def _make_deeplabv3_layer(self, block, planes, blocks, rates, stride=1):
176 |         downsample = None
177 |         if stride != 1 or self.inplanes != planes * block.expansion:
178 |             downsample = nn.Sequential(
179 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
180 |                           kernel_size=1, stride=stride, bias=False),
181 |                 nn.BatchNorm2d(planes * block.expansion),
182 |             )
183 | 
184 |         layers = []
185 |         layers.append(block(self.inplanes, planes, stride, downsample))
186 |         self.inplanes = planes * block.expansion
187 |         for i in range(1, blocks):
188 |             layers.append(block(self.inplanes, planes, rate=rates[i]))
189 | 
190 |         return nn.Sequential(*layers)
191 | 
192 |     def forward(self, x):
193 |         x = self.conv1(x)
194 |         x = self.bn1(x)
195 |         x = self.relu(x)
196 |         x = self.maxpool(x)
197 | 
198 |         x = self.layer1(x)
199 |         x = self.layer2(x)
201 |         x = self.layer3(x)
202 |         x = self.layer4(x)
203 | 
204 |         x = self.avgpool(x)
205 |         x = x.view(x.size(0), -1)
206 |         x = self.fc(x)
207 | 
208 |         return x
209 | 
210 | 
211 | def resnet(pretrained=False, layers=[3,4,6,3], backbone='resnet50', n_input=3, **kwargs):
212 |     """Constructs a ResNet-50 model. 
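    layer4 uses dilated convolutions (rates 1/2/4) with stride 1, keeping the output stride at 16.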
213 |     Args:
214 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
215 |     """
216 |     model = ResNet(Bottleneck, layers, n_input=n_input, **kwargs)
217 |     if not pretrained:
218 |         return model
219 |     pretrain_dict = model_zoo.load_url(model_urls[backbone])
220 |     try:
221 |         model.load_state_dict(pretrain_dict, strict=False)
222 |     except Exception:
223 |         print("dropping conv1 weights")  # conv1 shape differs when n_input != 3
224 |         model_dict = {}
225 |         for k, v in pretrain_dict.items():
226 |             if 'conv1' not in k:
227 |                 model_dict[k] = v
228 |         model.load_state_dict(model_dict, strict=False)
229 |     print("load pretrain success")
230 |     return model
231 | 
232 | class ResNet50(nn.Module):
233 |     def __init__(self, pretrained=True, n_input=3):
234 |         """Declare all needed layers."""
235 |         super(ResNet50, self).__init__()
236 |         self.model = resnet(n_input=n_input, pretrained=pretrained, layers=[3, 4, 6, 3], backbone='resnet50')
237 |         self.relu = self.model.relu  # Place a hook
238 | 
239 |         layers_cfg = [4, 5, 6, 7]
240 |         self.blocks = []
241 |         for i, num_this_layer in enumerate(layers_cfg):
242 |             self.blocks.append(list(self.model.children())[num_this_layer])
243 | 
244 |     def base_forward(self, x):
245 |         feature_map = []
246 |         x = self.model.conv1(x)
247 |         x = self.model.bn1(x)
248 |         x = self.model.relu(x)
249 |         x = self.model.maxpool(x)
250 | 
251 |         for i, block in enumerate(self.blocks):
252 |             x = block(x)
253 |             feature_map.append(x)
254 | 
255 |         out = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], -1)
256 | 
257 |         return feature_map, out
258 | 
259 | 
260 | class ERB(nn.Module):
261 |     def __init__(self, in_channels, out_channels):
262 |         super(ERB, self).__init__()
263 |         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
264 |         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
265 |         self.relu = nn.ReLU()
266 |         self.bn = nn.BatchNorm2d(out_channels)
267 |         self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
268 | 
269 |     def forward(self, x, relu=True):
270 |         x = self.conv1(x)
271 |         res = self.conv2(x)
272 |         res = self.bn(res)
273 |         res = self.relu(res)
274 |         res = self.conv3(res)
275 |         if relu:
276 |             return self.relu(x + res)
277 |         else:
278 |             return x+res
279 | 
280 | 
281 | class MVSSNet(ResNet50):
282 |     def __init__(self, nclass, aux=False, sobel=False, constrain=False, n_input=3, **kwargs):
283 |         super(MVSSNet, self).__init__(pretrained=True, n_input=n_input)
284 |         self.num_class = nclass
285 |         self.aux = aux
286 | 
287 |         self.__setattr__('exclusive', ['head'])
288 | 
289 |         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
290 |         self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
291 |         self.sobel = sobel
292 |         self.constrain = constrain
293 | 
294 |         self.erb_db_1 = ERB(256, self.num_class)
295 |         self.erb_db_2 = ERB(512, self.num_class)
296 |         self.erb_db_3 = ERB(1024, self.num_class)
297 |         self.erb_db_4 = ERB(2048, self.num_class)
298 | 
299 |         self.erb_trans_1 = ERB(self.num_class, self.num_class)
300 |         self.erb_trans_2 = ERB(self.num_class, self.num_class)
301 |         self.erb_trans_3 = ERB(self.num_class, self.num_class)
302 | 
303 |         if self.sobel:
304 |             print("----------use sobel-------------")
305 |             self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
306 |             self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
307 |             self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
308 |             self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
309 | 
310 |         if self.constrain:
311 |             print("----------use constrain-------------")
312 | 
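# noise branch: a Bayar constrained conv over the grayscale image extracts
# noise residuals, a second ResNet50 encodes them, and its last feature map
# is concatenated with the RGB c4, hence the 2048+2048 head input below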
self.noise_extractor = ResNet50(n_input=3, pretrained=True) 313 | self.constrain_conv = BayarConv2d(in_channels=1, out_channels=3, padding=2) 314 | self.head = _DAHead(2048+2048, self.num_class, aux, **kwargs) 315 | else: 316 | self.head = _DAHead(2048, self.num_class, aux, **kwargs) 317 | 318 | def forward(self, x): 319 | size = x.size()[2:] 320 | input_ = x.clone() 321 | feature_map, _ = self.base_forward(input_) 322 | c1, c2, c3, c4 = feature_map 323 | 324 | if self.sobel: 325 | res1 = self.erb_db_1(run_sobel(self.sobel_x1, self.sobel_y1, c1)) 326 | res1 = self.erb_trans_1(res1 + self.upsample(self.erb_db_2(run_sobel(self.sobel_x2, self.sobel_y2, c2)))) 327 | res1 = self.erb_trans_2(res1 + self.upsample_4(self.erb_db_3(run_sobel(self.sobel_x3, self.sobel_y3, c3)))) 328 | res1 = self.erb_trans_3(res1 + self.upsample_4(self.erb_db_4(run_sobel(self.sobel_x4, self.sobel_y4, c4))), relu=False) 329 | 330 | else: 331 | res1 = self.erb_db_1(c1) 332 | res1 = self.erb_trans_1(res1 + self.upsample(self.erb_db_2(c2))) 333 | res1 = self.erb_trans_2(res1 + self.upsample_4(self.erb_db_3(c3))) 334 | res1 = self.erb_trans_3(res1 + self.upsample_4(self.erb_db_4(c4)), relu=False) 335 | 336 | if self.constrain: 337 | x = rgb2gray(x) 338 | x = self.constrain_conv(x) 339 | constrain_features, _ = self.noise_extractor.base_forward(x) 340 | constrain_feature = constrain_features[-1] 341 | c4 = torch.cat([c4, constrain_feature], dim=1) 342 | 343 | outputs = [] 344 | 345 | x = self.head(c4) 346 | x0 = F.interpolate(x[0], size, mode='bilinear', align_corners=True) 347 | outputs.append(x0) 348 | 349 | if self.aux: 350 | x1 = F.interpolate(x[1], size, mode='bilinear', align_corners=True) 351 | x2 = F.interpolate(x[2], size, mode='bilinear', align_corners=True) 352 | outputs.append(x1) 353 | outputs.append(x2) 354 | 355 | return res1, x0 356 | 357 | 358 | class _PositionAttentionModule(nn.Module): 359 | """ Position attention module""" 360 | 361 | def __init__(self, in_channels, **kwargs): 362 | super(_PositionAttentionModule, self).__init__() 363 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) 364 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) 365 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1) 366 | self.alpha = nn.Parameter(torch.zeros(1)) 367 | self.softmax = nn.Softmax(dim=-1) 368 | 369 | def forward(self, x): 370 | batch_size, _, height, width = x.size() 371 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) 372 | feat_c = self.conv_c(x).view(batch_size, -1, height * width) 373 | attention_s = self.softmax(torch.bmm(feat_b, feat_c)) 374 | feat_d = self.conv_d(x).view(batch_size, -1, height * width) 375 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) 376 | out = self.alpha * feat_e + x 377 | 378 | return out 379 | 380 | 381 | class _ChannelAttentionModule(nn.Module): 382 | """Channel attention module""" 383 | 384 | def __init__(self, **kwargs): 385 | super(_ChannelAttentionModule, self).__init__() 386 | self.beta = nn.Parameter(torch.zeros(1)) 387 | self.softmax = nn.Softmax(dim=-1) 388 | 389 | def forward(self, x): 390 | batch_size, _, height, width = x.size() 391 | feat_a = x.view(batch_size, -1, height * width) 392 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1) 393 | attention = torch.bmm(feat_a, feat_a_transpose) 394 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention 395 | attention = self.softmax(attention_new) 396 | 397 | 
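# weight the flattened (C x HW) features by the C x C channel-attention map,
# then restore the spatial layout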
feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width) 398 | out = self.beta * feat_e + x 399 | 400 | return out 401 | 402 | 403 | class _DAHead(nn.Module): 404 | def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 405 | super(_DAHead, self).__init__() 406 | self.aux = aux 407 | inter_channels = in_channels // 4 408 | self.conv_p1 = nn.Sequential( 409 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 410 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 411 | nn.ReLU(True) 412 | ) 413 | self.conv_c1 = nn.Sequential( 414 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 415 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 416 | nn.ReLU(True) 417 | ) 418 | self.pam = _PositionAttentionModule(inter_channels, **kwargs) 419 | self.cam = _ChannelAttentionModule(**kwargs) 420 | self.conv_p2 = nn.Sequential( 421 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 422 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 423 | nn.ReLU(True) 424 | ) 425 | self.conv_c2 = nn.Sequential( 426 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 427 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 428 | nn.ReLU(True) 429 | ) 430 | self.out = nn.Sequential( 431 | nn.Dropout(0.1), 432 | nn.Conv2d(inter_channels, nclass, 1) 433 | ) 434 | if aux: 435 | self.conv_p3 = nn.Sequential( 436 | nn.Dropout(0.1), 437 | nn.Conv2d(inter_channels, nclass, 1) 438 | ) 439 | self.conv_c3 = nn.Sequential( 440 | nn.Dropout(0.1), 441 | nn.Conv2d(inter_channels, nclass, 1) 442 | ) 443 | 444 | def forward(self, x): 445 | feat_p = self.conv_p1(x) 446 | feat_p = self.pam(feat_p) 447 | feat_p = self.conv_p2(feat_p) 448 | 449 | feat_c = self.conv_c1(x) 450 | feat_c = self.cam(feat_c) 451 | feat_c = self.conv_c2(feat_c) 452 | 453 | feat_fusion = feat_p + feat_c 454 | 455 | outputs = [] 456 | fusion_out = self.out(feat_fusion) 457 | outputs.append(fusion_out) 458 | if self.aux: 459 | p_out = self.conv_p3(feat_p) 460 | c_out = self.conv_c3(feat_c) 461 | outputs.append(p_out) 462 | outputs.append(c_out) 463 | 464 | return tuple(outputs) 465 | 466 | 467 | def get_mvss(backbone='resnet50', pretrained_base=True, nclass=1, sobel=True, n_input=3, constrain=True, **kwargs): 468 | model = MVSSNet(nclass, backbone=backbone, 469 | pretrained_base=pretrained_base, 470 | sobel=sobel, 471 | n_input=n_input, 472 | constrain=constrain, 473 | **kwargs) 474 | return model 475 | 476 | 477 | if __name__ == '__main__': 478 | img = torch.randn(2, 3, 512, 512) 479 | model = get_mvss(sobel=True, n_input=3, constrain=True) 480 | edge, outputs = model(img) 481 | print(outputs.shape) 482 | print(edge.shape) 483 | --------------------------------------------------------------------------------
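For reference, a minimal standalone sketch (not part of the repo; the helper names n_tiles and tile_starts are made up here) of the tiling arithmetic that cal_sliding_params and img_slide_window implement: 512-px windows with at least 128 px of overlap, spread evenly and clamped to the image border.

def n_tiles(length):
    # smallest n with 512*n - (n-1)*128 >= length, i.e. at least 128 px overlap
    n = 1
    while 512 * n - (n - 1) * 128 < length:
        n += 1
    return n

def tile_starts(length, n):
    if n == 1:
        return [0]
    delta = int((n * 512 - length) / (n - 1))  # overlap between neighbours
    # same formula as img_slide_window, with the last tile clamped inside
    return [min(512 * i - i * delta, length - 512) for i in range(n)]

if __name__ == "__main__":
    # e.g. a 900x1400 image needs 3x4 tiles, and the last vertical start is
    # 388, so that tile ends exactly at 388 + 512 = 900
    print(n_tiles(900), n_tiles(1400))   # 3 4
    print(tile_starts(900, 3))           # [0, 194, 388]
    print(tile_starts(1400, 4))          # [0, 296, 592, 888]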