├── .gitignore
├── README.md
├── config
│   ├── __init__.py
│   └── conf.py
├── main.py
├── static
│   └── Constrained-CNN.png
└── utils
    ├── __init__.py
    ├── data.py
    ├── helper.py
    └── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea*
2 | .DS_Store
3 | *pyc
4 | __pycache__
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Constrained-CNN
2 | 
3 | This repository is a PyTorch implementation of the paper "Constrained Convolutional Neural Networks: A New Approach Towards General Purpose Image Manipulation Detection".
4 | 
5 | ![Constrained-CNN](https://github.com/grasses/Constrained-CNN/blob/master/static/Constrained-CNN.png?raw=true)
6 | 
7 | Note: this is not the official implementation of Constrained CNN. The paper is available here: [https://ieeexplore.ieee.org/document/8335799](https://ieeexplore.ieee.org/document/8335799)
8 | 
9 | ## Usage
10 | 
11 | This repository only includes an implementation of the Constrained CNN model, so you need to complete the dataloader yourself (see main.py: 58 and `extract_data()` in utils/data.py).
12 | 1. Download a dataset and complete the dataloader code.
13 | 2. Run `python main.py`
14 | 
15 | 
16 | ## License
17 | 
18 | This library is under the GPL V3 license. For the full copyright and license information, please view the LICENSE file that was distributed with this source code.
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grasses/Constrained-CNN/1d255beebab3850144d98fe204edfc8c372fca06/config/__init__.py
--------------------------------------------------------------------------------
/config/conf.py:
--------------------------------------------------------------------------------
1 | import os, torch, random
2 | random.seed(100)
3 | torch.manual_seed(100)
4 | torch.cuda.manual_seed(100)
5 | 
6 | class Conf(object):
7 |     ROOT = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
8 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | 
10 |     total_class = 2
11 |     total_epoch = 10
12 |     batch_size = 1
13 |     learning_rate = 0.001
14 | 
15 |     data_path = os.path.join(ROOT, "dataset")
16 |     model_path = os.path.join(ROOT, "model")
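
main.py (next file) resolves the `--config` flag by importing `config.<name>` and reading its `Conf` class, so hyperparameters can be overridden without editing conf.py. A minimal sketch of such an override, assuming a hypothetical `config/exp1.py` (the file name and values are illustrative, not part of the repository):

```python
# config/exp1.py -- hypothetical override config (illustrative, not in the repository)
from config.conf import Conf as BaseConf

class Conf(BaseConf):
    # override only the hyperparameters that differ from config/conf.py
    total_epoch = 30
    batch_size = 16
    learning_rate = 0.0005
```

Running `python main.py --config exp1` would then load this class instead of the default `config.conf.Conf`.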
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from utils.data import Data
6 | from utils.helper import Helper
7 | from utils.model import MISLnet as Model
8 | np.random.seed(100)
9 | torch.manual_seed(100)
10 | torch.cuda.manual_seed(100)
11 | 
12 | def get_args():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--config", default="conf")
15 |     return parser.parse_args()
16 | 
17 | def training(conf, model, loader):
18 |     print("\n-> start training model!")
19 |     model.train()
20 |     optimizer = torch.optim.Adam(model.parameters(), lr=conf.learning_rate)
21 |     for epoch in range(conf.total_epoch):
22 |         for step, (x, y) in enumerate(loader):
23 |             # global step index across epochs (len(loader) = steps per epoch)
24 |             global_step = epoch * len(loader) + step
25 |             x, y = x.to(conf.device), y.to(conf.device)
26 |             logist, output = model(x)
27 |             loss = F.cross_entropy(logist, y)  # cross_entropy expects raw logits, not the softmax output
28 |             optimizer.zero_grad()
29 |             loss.backward()
30 |             optimizer.step()
31 |             pred = output.argmax(dim=1)
32 |             correct = pred.eq(y.view_as(pred)).cpu().sum().item()
33 |             acc = 100.0 * correct / y.size(0)  # use the actual batch size (the last batch may be smaller)
34 |             print("-> training epoch={:d} step={:d} loss={:.3f} acc={:.3f}%".format(epoch, global_step, loss.item(), acc))
35 | 
36 | def testing(conf, model, test_loader):
37 |     model.eval()
38 |     correct, test_loss = 0, 0
39 |     with torch.no_grad():
40 |         for x, y in test_loader:
41 |             x, y = x.to(conf.device), y.to(conf.device)
42 |             logist, output = model(x)
43 |             test_loss += F.cross_entropy(logist, y, reduction="sum").item()
44 |             pred = output.argmax(dim=1, keepdim=True)
45 |             correct += pred.eq(y.view_as(pred)).sum().item()
46 |     test_loss /= len(test_loader.dataset)
47 |     acc = 100. * correct / len(test_loader.dataset)
48 |     print("-> testing loss={:.3f} acc={:.3f}%".format(test_loss, acc))
49 |     return test_loss, acc
50 | 
51 | def main():
52 |     args = get_args()
53 |     conf = __import__("config." + args.config, globals(), locals(), ["Conf"]).Conf
54 |     helper = Helper(conf=conf)  # creates the dataset/model directories if they are missing
55 | 
56 |     data = Data(conf)
57 |     data.load_data()
58 |     # you need to setup: data.train_loader / data.test_loader with a real dataset
59 | 
60 |     model = Model(conf).to(conf.device)
61 |     print(model)
62 |     training(conf, model, data.train_loader)
63 |     testing(conf, model, data.test_loader)
64 | 
65 | if __name__ == "__main__":
66 |     main()
--------------------------------------------------------------------------------
/static/Constrained-CNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grasses/Constrained-CNN/1d255beebab3850144d98fe204edfc8c372fca06/static/Constrained-CNN.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grasses/Constrained-CNN/1d255beebab3850144d98fe204edfc8c372fca06/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset, DataLoader
3 | 
4 | class ImgDataset(Dataset):
5 |     def __init__(self, x, y):
6 |         self.x = np.array(x, dtype=np.float32)
7 |         self.y = np.array(y, dtype=np.int64)
8 |     def __getitem__(self, index):
9 |         return self.x[index], self.y[index]
10 |     def __len__(self):
11 |         return len(self.x)
12 | 
13 | class Data():
14 |     def __init__(self, conf):
15 |         self.conf = conf
16 |         self.data_path = conf.data_path
17 |         self.train_loader = None
18 |         self.test_loader = None
19 | 
20 |     def extract_data(self, kind="train"):
21 |         # placeholder: load your own images and labels here (a sketch follows below)
22 |         images = np.random.randn(100, 1, 256, 256)
23 |         labels = np.random.randint(0, self.conf.total_class, size=(100,))
24 |         return images, labels
25 | 
26 |     def load_data(self, batch_size=0):
27 |         print("-> load data from: {}".format(self.data_path))
28 |         if not batch_size:
29 |             batch_size = self.conf.batch_size
30 |         x, y = self.extract_data(kind="train")
31 |         self.train_loader = DataLoader(ImgDataset(x, y), batch_size=batch_size, shuffle=True)
32 |         x, y = self.extract_data(kind="test")
33 |         self.test_loader = DataLoader(ImgDataset(x, y), batch_size=batch_size, shuffle=False)
34 |         return self
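
As the README notes, `extract_data()` above is only a placeholder that returns random tensors, so real data loading has to be filled in by the user. Below is a minimal sketch of one possible replacement that reads grayscale images from per-class subfolders such as `dataset/train/0/` and `dataset/train/1/`; the folder layout, the use of Pillow, and the fixed 256x256 resize are all assumptions, not part of the repository:

```python
# A hypothetical extract_data() for utils/data.py.
# Assumes images are stored as dataset/<kind>/<label>/*.png with one subfolder per class.
import os
import numpy as np
from PIL import Image  # assumes Pillow is installed

def extract_data(self, kind="train"):
    images, labels = [], []
    base = os.path.join(self.data_path, kind)
    for label in sorted(os.listdir(base)):                 # each subfolder name is a class id
        class_dir = os.path.join(base, label)
        if not os.path.isdir(class_dir):
            continue
        for fname in sorted(os.listdir(class_dir)):
            img = Image.open(os.path.join(class_dir, fname)).convert("L")  # grayscale
            img = img.resize((256, 256))                   # the model expects 1x256x256 inputs
            images.append(np.asarray(img, dtype=np.float32)[None, :, :] / 255.0)
            labels.append(int(label))
    return np.stack(images), np.array(labels, dtype=np.int64)
```

Whatever loading strategy is used, it only has to match the shapes of the placeholder: `(N, 1, 256, 256)` float images and `(N,)` integer labels.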
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | class Helper():
4 |     def __init__(self, conf):
5 |         self.conf = conf
6 |         self.init()
7 | 
8 |     def init(self):
9 |         # make sure the model and dataset directories exist
10 |         check_path = [self.conf.model_path, self.conf.data_path]
11 |         for path in check_path:
12 |             if not os.path.exists(path):
13 |                 os.makedirs(path)
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | 
4 | __author__ = 'homeway'
5 | __copyright__ = 'Copyright © 2019/11/22, homeway'
6 | 
7 | import torch, torch.nn as nn, torch.nn.functional as F
8 | from config.conf import Conf
9 | 
10 | class ModelBase(nn.Module):
11 |     def __init__(self, name, created_time):
12 |         super(ModelBase, self).__init__()
13 |         self.name = name
14 |         self.created_time = created_time
15 | 
16 |     def copy_params(self, state_dict):
17 |         own_state = self.state_dict()
18 |         for (name, param) in state_dict.items():
19 |             if name in own_state:
20 |                 own_state[name].copy_(param.clone())
21 | 
22 |     def boost_params(self, scale=1.0):
23 |         if scale == 1.0:
24 |             return self.state_dict()
25 |         for (name, param) in self.state_dict().items():
26 |             self.state_dict()[name].copy_((scale * param).clone())
27 |         return self.state_dict()
28 | 
29 |     # self - x
30 |     def sub_params(self, x):
31 |         own_state = self.state_dict()
32 |         for (name, param) in x.items():
33 |             if name in own_state:
34 |                 own_state[name].copy_(own_state[name] - param)
35 | 
36 |     # self + x
37 |     def add_params(self, x):
38 |         a = self.state_dict()
39 |         for (name, param) in x.items():
40 |             if name in a:
41 |                 a[name].copy_(a[name] + param)
42 | 
43 | class MISLnet(ModelBase):
44 |     def __init__(self, conf=Conf, name=None, created_time=None):
45 |         super(MISLnet, self).__init__(f'{name}_ConstrainedCNN', created_time)
46 | 
47 |         self.register_parameter("const_weight", None)
48 |         self.const_weight = nn.Parameter(torch.randn(size=[3, 1, 5, 5]), requires_grad=True)
49 |         self.conv1 = nn.Conv2d(3, 96, 7, stride=2, padding=4)
50 |         self.conv2 = nn.Conv2d(96, 64, 5, stride=1, padding=2)
51 |         self.conv3 = nn.Conv2d(64, 64, 5, stride=1, padding=2)
52 |         self.conv4 = nn.Conv2d(64, 128, 1, stride=1)
53 |         self.fc1 = nn.Linear(6272, 200)
54 |         self.fc2 = nn.Linear(200, 200)
55 |         self.fc3 = nn.Linear(200, conf.total_class)
56 |         self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
57 |         self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=2)
58 | 
59 |     def normalized_F(self):
60 |         # enforce the constrained-conv rule: non-center weights sum to 1, center weight is -1
61 |         central_pixel = (self.const_weight.data[:, 0, 2, 2])
62 |         for i in range(3):
63 |             summed = self.const_weight.data[i].sum() - central_pixel[i]
64 |             self.const_weight.data[i] /= summed
65 |             self.const_weight.data[i, 0, 2, 2] = -1.0
66 | 
67 |     def forward(self, x):
68 |         # Constrained convolution layer
69 |         self.normalized_F()
70 |         x = F.conv2d(x, self.const_weight)
71 |         # Conventional convolution layers
72 |         x = self.conv1(x)
73 |         x = self.max_pool(torch.tanh(x))
74 |         x = self.conv2(x)
75 |         x = self.max_pool(torch.tanh(x))
76 |         x = self.conv3(x)
77 |         x = self.max_pool(torch.tanh(x))
78 |         x = self.conv4(x)
79 |         x = self.avg_pool(torch.tanh(x))
80 |         # Fully connected classifier
81 |         x = torch.flatten(x, 1)
82 |         x = self.fc1(x)
83 |         x = torch.tanh(x)
84 |         x = self.fc2(x)
85 |         x = torch.tanh(x)
86 |         logist = self.fc3(x)
87 |         output = F.softmax(logist, dim=1)
88 |         return logist, output
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     model = MISLnet(conf=Conf, name="testing")
93 |     print(model)
--------------------------------------------------------------------------------
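
A quick sanity check of the constrained layer (a sketch, not part of the repository; it assumes the repository root is on `PYTHONPATH`): after `normalized_F()` runs inside `forward()`, every 5x5 filter of `const_weight` should have a center weight of -1 while its remaining weights sum to 1.

```python
import torch
from config.conf import Conf
from utils.model import MISLnet

model = MISLnet(conf=Conf, name="sanity")
logits, probs = model(torch.randn(2, 1, 256, 256))  # forward() re-normalizes const_weight
print(logits.shape, probs.shape)                    # both torch.Size([2, 2])

w = model.const_weight.data
for i in range(w.size(0)):
    center = w[i, 0, 2, 2].item()
    others = w[i].sum().item() - center
    # expected: center = -1.0, sum of the other weights ~ 1.0 (up to floating point)
    print("filter {}: center={:.1f} sum(others)={:.4f}".format(i, center, others))
```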