├── .gitignore ├── README.md ├── SRM_Kernels.npy ├── YeNet.py ├── main.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | checkpoints/* 3 | .ipynb_checkpoints/ 4 | log_reader.ipynb 5 | logs 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YeNet-Pytorch 2 | Pytorch implementation of "Deep Learning Hierarchical Representations for Image Steganalysis" by Jian Ye, Jiangqun Ni and Yang Yi 3 | 4 | ## Dataset 5 | training and validation must contains there own cover and stego / beta maps images directory, stego images must have the same name than there corresponding cover 6 | 7 | ## Publication 8 | [The publication can be found here](http://ieeexplore.ieee.org/document/7937836/) 9 | -------------------------------------------------------------------------------- /SRM_Kernels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Caenorst/YeNet-Pytorch/02703fc43360a09487a1232622c98bc6a545c9db/SRM_Kernels.npy -------------------------------------------------------------------------------- /YeNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.parameter import Parameter 8 | import torch.nn.functional as F 9 | 10 | SRM_npy = np.load('SRM_Kernels.npy') 11 | 12 | class SRM_conv2d(nn.Module): 13 | def __init__(self, stride=1, padding=0): 14 | super(SRM_conv2d, self).__init__() 15 | self.in_channels = 1 16 | self.out_channels = 30 17 | self.kernel_size = (5, 5) 18 | if isinstance(stride, int): 19 | self.stride = (stride, stride) 20 | else: 21 | self.stride = stride 22 | if isinstance(padding, int): 23 | self.padding = (padding, padding) 24 | else: 25 | self.padding = padding 26 | self.dilation = (1,1) 27 | self.transpose = False 28 | self.output_padding = (0,) 29 | self.groups = 1 30 | self.weight = Parameter(torch.Tensor(30, 1, 5, 5), \ 31 | requires_grad=True) 32 | self.bias = Parameter(torch.Tensor(30), \ 33 | requires_grad=True) 34 | self.reset_parameters() 35 | 36 | def reset_parameters(self): 37 | self.weight.data.numpy()[:] = SRM_npy 38 | self.bias.data.zero_() 39 | 40 | def forward(self, input): 41 | return F.conv2d(input, self.weight, self.bias, \ 42 | self.stride, self.padding, self.dilation, \ 43 | self.groups) 44 | 45 | class ConvBlock(nn.Module): 46 | def __init__(self, in_channels, out_channels, kernel_size=3, \ 47 | stride=1, with_bn=False): 48 | super(ConvBlock, self).__init__() 49 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, \ 50 | stride) 51 | self.relu = nn.ReLU() 52 | self.with_bn = with_bn 53 | if with_bn: 54 | self.norm = nn.BatchNorm2d(out_channels) 55 | else: 56 | self.norm = lambda x: x 57 | self.reset_parameters() 58 | 59 | def forward(self, x): 60 | return self.norm(self.relu(self.conv(x))) 61 | 62 | def reset_parameters(self): 63 | nn.init.xavier_uniform(self.conv.weight) 64 | self.conv.bias.data.fill_(0.2) 65 | if self.with_bn: 66 | self.norm.reset_parameters() 67 | 68 | class YeNet(nn.Module): 69 | def __init__(self, with_bn=False, threshold=3): 70 | super(YeNet, self).__init__() 71 | self.with_bn = with_bn 72 | self.preprocessing = SRM_conv2d(1, 0) 73 | self.TLU = nn.Hardtanh(-threshold, threshold, True) 74 | if with_bn: 75 | self.norm1 = nn.BatchNorm2d(30) 76 | else: 77 | self.norm1 = lambda x: x 78 | self.block2 = ConvBlock(30, 30, 3, with_bn=self.with_bn) 79 | self.block3 = ConvBlock(30, 30, 3, with_bn=self.with_bn) 80 | self.block4 = ConvBlock(30, 30, 3, with_bn=self.with_bn) 81 | self.pool1 = nn.AvgPool2d(2, 2) 82 | self.block5 = ConvBlock(30, 32, 5, with_bn=self.with_bn) 83 | self.pool2 = nn.AvgPool2d(3, 2) 84 | self.block6 = ConvBlock(32, 32, 5, with_bn=self.with_bn) 85 | self.pool3 = nn.AvgPool2d(3, 2) 86 | self.block7 = ConvBlock(32, 32, 5, with_bn=self.with_bn) 87 | self.pool4 = nn.AvgPool2d(3, 2) 88 | self.block8 = ConvBlock(32, 16, 3, with_bn=self.with_bn) 89 | self.block9 = ConvBlock(16, 16, 3, 3, with_bn=self.with_bn) 90 | self.ip1 = nn.Linear(3 * 3 * 16, 2) 91 | self.reset_parameters() 92 | 93 | def forward(self, x): 94 | x = x.float() 95 | x = self.preprocessing(x) 96 | x = self.TLU(x) 97 | x = self.norm1(x) 98 | x = self.block2(x) 99 | x = self.block3(x) 100 | x = self.block4(x) 101 | x = self.pool1(x) 102 | x = self.block5(x) 103 | x = self.pool2(x) 104 | x = self.block6(x) 105 | x = self.pool3(x) 106 | x = self.block7(x) 107 | x = self.pool4(x) 108 | x = self.block8(x) 109 | x = self.block9(x) 110 | x = x.view(x.size(0), -1) 111 | x = self.ip1(x) 112 | return x 113 | 114 | def reset_parameters(self): 115 | for mod in self.modules(): 116 | if isinstance(mod, SRM_conv2d) or \ 117 | isinstance(mod, nn.BatchNorm2d) or \ 118 | isinstance(mod, ConvBlock): 119 | mod.reset_parameters() 120 | elif isinstance(mod, nn.Linear): 121 | nn.init.normal(mod.weight, 0. ,0.01) 122 | mod.bias.data.zero_() 123 | 124 | def accuracy(outputs, labels): 125 | _, argmax = torch.max(outputs, 1) 126 | return (labels == argmax.squeeze()).float().mean() 127 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import shutil 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from torchvision import transforms 11 | import utils 12 | import YeNet 13 | 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch implementation of YeNet') 16 | parser.add_argument('train_cover_dir', type=str, metavar='PATH', 17 | help='path of directory containing all ' + 18 | 'training cover images') 19 | parser.add_argument('train_stego_dir', type=str, metavar='PATH', 20 | help='path of directory containing all ' + 21 | 'training stego images or beta maps') 22 | parser.add_argument('valid_cover_dir', type=str, metavar='PATH', 23 | help='path of directory containing all ' + 24 | 'validation cover images') 25 | parser.add_argument('valid_stego_dir', type=str, metavar='PATH', 26 | help='path of directory containing all ' + 27 | 'validation stego images or beta maps') 28 | parser.add_argument('--batch-size', type=int, default=32, metavar='N', 29 | help='input batch size for training (default: 32)') 30 | parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', 31 | help='input batch size for testing (default: 32)') 32 | parser.add_argument('--epochs', type=int, default=1000, metavar='N', 33 | help='number of epochs to train (default: 1000)') 34 | parser.add_argument('--lr', type=float, default=4e-1, metavar='LR', 35 | help='learning rate (default: 4e-1)') 36 | parser.add_argument('--use-batch-norm', action='store_true', default=False, 37 | help='use batch normalization after each activation,' + 38 | ' also disable pair constraint (default: False)') 39 | parser.add_argument('--embed-otf', action='store_true', default=False, 40 | help='use beta maps and embed on the fly instead' + 41 | ' of use stego images (default: False)') 42 | parser.add_argument('--no-cuda', action='store_true', default=False, 43 | help='disables CUDA training') 44 | parser.add_argument('--gpu', type=int, default=0, 45 | help='index of gpu used (default: 0)') 46 | parser.add_argument('--seed', type=int, default=1, metavar='S', 47 | help='random seed (default: 1)') 48 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 49 | help='how many batches to wait before logging training status') 50 | # TODO: use a format to store logs (tensorboard ?) 51 | # parser.add_argument('--log-path', type=str, default='logs/training.log', 52 | # metavar='PATH', help='path to generated log file') 53 | args = parser.parse_args() 54 | arch = 'YeNet_with_bn' if args.use_batch_norm else 'YeNet' 55 | args.cuda = not args.no_cuda and torch.cuda.is_available() 56 | torch.manual_seed(args.seed) 57 | if args.cuda: 58 | torch.cuda.manual_seed(args.seed) 59 | torch.cuda.set_device(args.gpu) 60 | else: 61 | args.gpu = None 62 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 63 | 64 | train_transform = transforms.Compose([ 65 | utils.RandomRot(), 66 | utils.RandomFlip(), 67 | utils.ToTensor() 68 | ]) 69 | 70 | valid_transform = transforms.Compose([ 71 | utils.ToTensor() 72 | ]) 73 | 74 | print("Generate loaders...") 75 | train_loader = utils.DataLoaderStego(args.train_cover_dir, args.train_stego_dir, 76 | embedding_otf=args.embed_otf, shuffle=True, 77 | pair_constraint=not(args.use_batch_norm), 78 | batch_size=args.batch_size, 79 | transform=train_transform, 80 | num_workers=kwargs['num_workers'], 81 | pin_memory=kwargs['pin_memory']) 82 | 83 | valid_loader = utils.DataLoaderStego(args.valid_cover_dir, args.valid_stego_dir, 84 | embedding_otf=False, shuffle=False, 85 | pair_constraint=True, 86 | batch_size=args.test_batch_size, 87 | transform=valid_transform, 88 | num_workers=kwargs['num_workers'], 89 | pin_memory=kwargs['pin_memory']) 90 | print('train_loader have {} iterations, valid_loader have {} iterations'.format( 91 | len(train_loader), len(valid_loader))) 92 | # valid_loader = train_loader 93 | print("Generate model") 94 | net = YeNet.YeNet(with_bn=args.use_batch_norm) 95 | 96 | print(net) 97 | print("Generate loss and optimizer") 98 | if args.cuda: 99 | net.cuda() 100 | criterion = nn.CrossEntropyLoss().cuda() 101 | else: 102 | criterion = nn.CrossEntropyLoss().cuda() 103 | optimizer = optim.Adadelta(net.parameters(), lr=args.lr, rho=0.95, eps=1e-8, 104 | weight_decay=5e-4) 105 | _time = time.time() 106 | 107 | def train(epoch): 108 | net.train() 109 | running_loss = 0. 110 | running_accuracy = 0. 111 | for batch_idx, data in enumerate(train_loader): 112 | images, labels = Variable(data['images']), Variable(data['labels']) 113 | if args.cuda: 114 | images, labels = images.cuda(), labels.cuda() 115 | optimizer.zero_grad() 116 | outputs = net(images) 117 | accuracy = YeNet.accuracy(outputs, labels).data[0] 118 | running_accuracy += accuracy 119 | loss = criterion(outputs, labels) 120 | running_loss += loss.data[0] 121 | loss.backward() 122 | optimizer.step() 123 | if (batch_idx + 1) % args.log_interval == 0: 124 | running_accuracy /= args.log_interval 125 | running_loss /= args.log_interval 126 | print(('\nTrain epoch: {} [{}/{}]\tAccuracy: ' + 127 | '{:.2f}%\tLoss: {:.6f}').format( 128 | epoch, batch_idx + 1, len(train_loader), 129 | 100 * running_accuracy, running_loss)) 130 | running_loss = 0. 131 | running_accuracy = 0. 132 | net.train() 133 | 134 | def valid(): 135 | net.eval() 136 | valid_loss = 0. 137 | valid_accuracy = 0. 138 | correct = 0 139 | for data in valid_loader: 140 | # break 141 | images, labels = Variable(data['images']), Variable(data['labels']) 142 | if args.cuda: 143 | images, labels = images.cuda(), labels.cuda() 144 | outputs = net(images) 145 | valid_loss += criterion(outputs, labels).data[0] 146 | valid_accuracy += YeNet.accuracy(outputs, labels).data[0] 147 | valid_loss /= len(valid_loader) 148 | valid_accuracy /= len(valid_loader) 149 | print('\nTest set: Loss: {:.4f}, Accuracy: {:.2f}%)\n'.format( 150 | valid_loss, 100 * valid_accuracy)) 151 | return valid_loss, valid_accuracy 152 | 153 | def save_checkpoint(state, is_best, filename='checkpoints/checkpoint.pth.tar'): 154 | torch.save(state, filename) 155 | if is_best: 156 | shutil.copyfile(filename, 'checkpoints/model_best.pth.tar') 157 | 158 | best_accuracy = 0. 159 | for epoch in range(1, args.epochs + 1): 160 | print("Epoch:", epoch) 161 | print("Train") 162 | train(epoch) 163 | print("Time:", time.time() - _time) 164 | print("Test") 165 | _, accuracy = valid() 166 | if accuracy > best_accuracy: 167 | best_accuracy = accuracy 168 | is_best = True 169 | else: 170 | is_best = False 171 | print("Time:", time.time() - _time) 172 | save_checkpoint({ 173 | 'epoch': epoch, 174 | 'arch': arch, 175 | 'state_dict': net.state_dict(), 176 | 'best_prec1': accuracy, 177 | 'optimizer': optimizer.state_dict(), 178 | }, is_best) 179 | 180 | 181 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | from glob import glob 6 | import itertools 7 | import torch.multiprocessing as multiprocessing 8 | from torch.utils.data.dataset import Dataset 9 | from torch.utils.data.dataloader import DataLoader, DataLoaderIter 10 | from torch.utils.data.sampler import Sampler, SequentialSampler, \ 11 | RandomSampler 12 | from torchvision import transforms 13 | from PIL import Image 14 | from scipy import io, misc 15 | 16 | class DatasetNoPair(Dataset): 17 | def __init__(self, cover_dir, stego_dir, embedding_otf=False, 18 | transform=None): 19 | self.cover_dir = cover_dir 20 | self.stego_dir = stego_dir 21 | self.cover_list = [x.split('/')[-1] for x in glob(cover_dir + '/*')] 22 | self.transform = transform 23 | self.embedding_otf = embedding_otf 24 | assert len(self.cover_list) != 0, "cover_dir is empty" 25 | # stego_list = ['.'.join(x.split('/')[-1].split('.')[:-1]) 26 | # for x in glob(stego_dir + '/*')] 27 | 28 | def __len__(self): 29 | return len(self.cover_list) * 2 30 | 31 | def __getitem__(self, idx): 32 | idx = int(idx) 33 | cover_idx = (idx - (idx % 2)) / 2 34 | if idx % 2 == 0: 35 | labels = np.zeros((1,1), dtype='int32') 36 | cover_path = os.path.join(self.cover_dir, 37 | self.cover_list[cover_idx]) 38 | images = misc.imread(cover_path) 39 | elif self.embedding_otf: 40 | labels = np.ones((1,1), dtype='int32') 41 | cover_path = os.path.join(self.cover_dir, 42 | self.cover_list[cover_idx]) 43 | cover = misc.imread(cover_path) 44 | beta_path = os.path.join(self.stego_dir, \ 45 | '.'.join(self.cover_list[cover_idx]. \ 46 | split('.')[:-1]) + '.mat') 47 | beta_map = io.loadmat(beta_path)['pChange'] 48 | rand_arr = np.random.rand(cover.shape[0], cover.shape[1]) 49 | images = np.copy(cover) 50 | inf_map = rand_arr < (beta_map / 2.) 51 | images[np.logical_and(cover != 255, inf_map)] += 1 52 | inf_map[:,:] = rand_arr > 1 - (beta_map / 2.) 53 | images[np.logical_and(cover != 0, inf_map)] -= 1 54 | else: 55 | labels = np.ones((1,1), dtype='int32') 56 | stego_path = os.path.join(self.stego_dir, 57 | self.cover_list[cover_idx]) 58 | images = misc.imread(stego_path) 59 | samples = {'images': images[None,:,:,None], 'labels': labels} 60 | if self.transform: 61 | samples = self.transform(samples) 62 | return samples 63 | 64 | class DatasetPair(Dataset): 65 | def __init__(self, cover_dir, stego_dir, embedding_otf=False, 66 | transform=None): 67 | self.cover_dir = cover_dir 68 | self.stego_dir = stego_dir 69 | self.cover_list = [x.split('/')[-1] for x in glob(cover_dir + '/*')] 70 | self.transform = transform 71 | self.embedding_otf = embedding_otf 72 | assert len(self.cover_list) != 0, "cover_dir is empty" 73 | # stego_list = ['.'.join(x.split('/')[-1].split('.')[:-1]) 74 | # for x in glob(stego_dir + '/*')] 75 | 76 | def __len__(self): 77 | return len(self.cover_list) 78 | 79 | def __getitem__(self, idx): 80 | idx = int(idx) 81 | labels = np.array([0,1], dtype='int32') 82 | cover_path = os.path.join(self.cover_dir, 83 | self.cover_list[idx]) 84 | cover = Image.open(cover_path) 85 | images = np.empty((2, cover.size[0], cover.size[1], 1), 86 | dtype='uint8') 87 | images[0,:,:,0] = np.array(cover) 88 | if self.embedding_otf: 89 | images[1,:,:,0] = np.copy(images[0,:,:,0]) 90 | beta_path = os.path.join(self.stego_dir, \ 91 | '.'.join(self.cover_list[idx]. \ 92 | split('.')[:-1]) + '.mat') 93 | beta_map = io.loadmat(beta_path)['pChange'] 94 | rand_arr = np.random.rand(cover.size[0], cover.size[1]) 95 | inf_map = rand_arr < (beta_map / 2.) 96 | images[1,np.logical_and(images[0,:,:,0] != 255, inf_map),0] += 1 97 | inf_map[:,:] = rand_arr > 1 - (beta_map / 2.) 98 | images[1,np.logical_and(images[0,:,:,0] != 0, inf_map),0] -= 1 99 | else: 100 | stego_path = os.path.join(self.stego_dir, 101 | self.cover_list[idx]) 102 | images[1,:,:,0] = misc.imread(stego_path) 103 | samples = {'images': images, 'labels': labels} 104 | if self.transform: 105 | samples = self.transform(samples) 106 | return samples 107 | 108 | class RandomBalancedSampler(Sampler): 109 | def __init__(self, data_source): 110 | self.data_source = data_source 111 | 112 | def __iter__(self): 113 | cover_perm = [x * 2 for x in torch.randperm( \ 114 | len(self.data_source) / 2).long()] 115 | stego_perm = [x * 2 + 1 for x in torch.randperm( \ 116 | len(self.data_source) / 2).long()] 117 | # idx_list = torch.randperm(len(self.data_source) / 2).long() 118 | # cover_perm = [x * 2 for x in idx_list] 119 | # stego_perm = [x * 2 + 1 for x in idx_list] 120 | return iter(it.next() for it in \ 121 | itertools.cycle([iter(cover_perm), iter(stego_perm)])) 122 | 123 | def __len__(self): 124 | return len(self.data_source) 125 | 126 | class DataLoaderIterWithReshape(DataLoaderIter): 127 | def next(self): 128 | if self.num_workers == 0: # same-process loading 129 | indices = next(self.sample_iter) # may raise StopIteration 130 | batch = self._reshape(self.collate_fn( 131 | [self.dataset[i] for i in indices])) 132 | if self.pin_memory: 133 | batch = pin_memory_batch(batch) 134 | return batch 135 | 136 | # check if the next sample has already been generated 137 | if self.rcvd_idx in self.reorder_dict: 138 | batch = self.reorder_dict.pop(self.rcvd_idx) 139 | return self._reshape(self._process_next_batch(batch)) 140 | 141 | if self.batches_outstanding == 0: 142 | self._shutdown_workers() 143 | raise StopIteration 144 | 145 | while True: 146 | assert (not self.shutdown and self.batches_outstanding > 0) 147 | idx, batch = self.data_queue.get() 148 | self.batches_outstanding -= 1 149 | if idx != self.rcvd_idx: 150 | # store out-of-order samples 151 | self.reorder_dict[idx] = batch 152 | continue 153 | return self._reshape(self._process_next_batch(batch)) 154 | 155 | def _reshape(self, batch): 156 | images, labels = batch['images'], batch['labels'] 157 | shape = list(images.size()) 158 | return {'images': images.view(shape[0] * shape[1], *shape[2:]), 159 | 'labels': labels.view(-1)} 160 | 161 | 162 | class DataLoaderStego(DataLoader): 163 | def __init__(self, cover_dir, stego_dir, embedding_otf=False, 164 | shuffle=False, pair_constraint=False, batch_size=1, 165 | transform=None, num_workers=0, pin_memory=False): 166 | self.pair_constraint = pair_constraint 167 | self.embedding_otf = embedding_otf 168 | if pair_constraint and batch_size % 2 == 0: 169 | dataset = DatasetPair(cover_dir, stego_dir, embedding_otf, 170 | transform) 171 | _batch_size = batch_size / 2 172 | else: 173 | dataset = DatasetNoPair(cover_dir, stego_dir, embedding_otf, 174 | transform) 175 | _batch_size = batch_size 176 | if pair_constraint: 177 | if shuffle: 178 | sampler = RandomSampler(dataset) 179 | else: 180 | sampler = SequentialSampler(dataset) 181 | else: 182 | sampler = RandomBalancedSampler(dataset) 183 | super(DataLoaderStego, self). \ 184 | __init__(dataset, _batch_size, None, sampler, \ 185 | None, num_workers, pin_memory=pin_memory, drop_last=True) 186 | self.shuffle = shuffle 187 | 188 | def __iter__(self): 189 | return DataLoaderIterWithReshape(self) 190 | # if self.pair_constraint: 191 | # return DataLoaderIterWithReshape(self) 192 | # else: 193 | # return DataLoaderIter(self) 194 | 195 | class ToTensor(object): 196 | def __call__(self, samples): 197 | images, labels = samples['images'], samples['labels'] 198 | images = images.transpose((0,3,1,2)) 199 | # images = (images.transpose((0,3,1,2)).astype('float32') / 127.5) - 1. 200 | return {'images': torch.from_numpy(images), 201 | 'labels': torch.from_numpy(labels).long()} 202 | 203 | class RandomRot(object): 204 | def __call__(self, samples): 205 | images = samples['images'] 206 | rot = random.randint(0,3) 207 | return {'images': np.rot90(images, rot, axes=[1,2]).copy(), 208 | 'labels': samples['labels']} 209 | 210 | class RandomFlip(object): 211 | def __call__(self, samples): 212 | if random.random() < 0.5: 213 | images = samples['images'] 214 | return {'images': np.flip(images, axis=2).copy(), 215 | 'labels': samples['labels']} 216 | else: 217 | return samples --------------------------------------------------------------------------------