├── .gitignore
├── requirements.txt
├── data
│   ├── driving
│   │   ├── 0.jpg
│   │   ├── 1161.jpg
│   │   ├── 2481.jpg
│   │   ├── 473.jpg
│   │   ├── 953.jpg
│   │   └── 58162.jpg
│   ├── celeba_crop128
│   │   ├── train
│   │   │   ├── 000001.jpg
│   │   │   ├── 000002.jpg
│   │   │   ├── 000003.jpg
│   │   │   ├── 000004.jpg
│   │   │   ├── 000005.jpg
│   │   │   ├── 000006.jpg
│   │   │   ├── 000007.jpg
│   │   │   ├── 000008.jpg
│   │   │   └── 000009.jpg
│   │   └── celeba_process.py
│   └── README.md
├── frameworks
│   ├── AI2
│   │   └── README.md
│   ├── ExactLine
│   │   └── README.md
│   ├── README.md
│   └── GenProver
│       ├── load_model.py
│       ├── README.md
│       ├── genmodels.py
│       └── components.py
├── experiments
│   ├── rectangle.py
│   ├── synthetic_data.py
│   ├── dataset.py
│   ├── README.md
│   ├── face_recognition.py
│   ├── mutation.py
│   ├── augment_geometrical.py
│   └── model.py
├── implementation
│   ├── LRF.py
│   ├── continuity.py
│   ├── README.md
│   └── independence.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | opencv-python
4 | imgaug
--------------------------------------------------------------------------------
/data/driving/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/0.jpg
--------------------------------------------------------------------------------
/data/driving/1161.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/1161.jpg
--------------------------------------------------------------------------------
/data/driving/2481.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/2481.jpg
--------------------------------------------------------------------------------
/data/driving/473.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/473.jpg
--------------------------------------------------------------------------------
/data/driving/953.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/953.jpg
--------------------------------------------------------------------------------
/data/driving/58162.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/driving/58162.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000001.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000002.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000003.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000004.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000005.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000006.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000007.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000008.jpg
--------------------------------------------------------------------------------
/data/celeba_crop128/train/000009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yuanyuan-Yuan/GCert/HEAD/data/celeba_crop128/train/000009.jpg
--------------------------------------------------------------------------------
/frameworks/AI2/README.md:
--------------------------------------------------------------------------------
1 | # AI2/ERAN
2 | 
3 | The official implementation is provided [here](https://github.com/eth-sri/eran). In our experiments, we use the adaptor provided by [VeriGauge](https://github.com/AI-secure/VeriGauge) to set up AI2/ERAN.
4 | 
5 | [VeriGauge](https://github.com/AI-secure/VeriGauge) and [AI2/ERAN](https://github.com/eth-sri/eran) are well implemented and documented; you can smoothly set up everything following their instructions.
--------------------------------------------------------------------------------
/frameworks/ExactLine/README.md:
--------------------------------------------------------------------------------
1 | # ExactLine
2 | 
3 | We use the ExactLine implementation provided by the authors of GenProver. The source code can be downloaded [here](https://openreview.net/forum?id=HJxRMlrtPH).
4 | 
5 | The implementations of ExactLine and GenProver are almost the same, except that GenProver merges segments in intermediate outputs into boxes/polyhedra. Thus, to use ExactLine, you only need to set
6 | 
7 | ```python
8 | use_clustr = None
9 | ```
10 | 
11 | in the implementation of GenProver.
--------------------------------------------------------------------------------
/frameworks/README.md:
--------------------------------------------------------------------------------
1 | # Frameworks
2 | 
3 | GCert is incorporated into the following three frameworks for certification.
4 | 
5 | - ERAN (a.k.a. 
AI2): AI2: Safety and Robustness Certification of Neural Networks with Abstract Interpretation (IEEE S&P 2018). See [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks/AI2) for setups.
6 | 
7 | - ExactLine: Computing Linear Restrictions of Neural Networks (NeurIPS 2019). See [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks/ExactLine) for setups.
8 | 
9 | - GenProver: Robustness Certification with Generative Models (PLDI 2021). See [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks/GenProver) for setups.
--------------------------------------------------------------------------------
/experiments/rectangle.py:
--------------------------------------------------------------------------------
1 | ############################################################
2 | # This script shows how the minimal enclosing rectangle    #
3 | # is obtained in our first evaluation.                     #
4 | ############################################################
5 | 
6 | import cv2
7 | import numpy as np
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | 
11 | def to_image(tensor):
12 |     ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).squeeze(2).cpu().numpy()
13 |     return Image.fromarray(ndarr).convert('RGB')
14 | 
15 | def to_ndarr(tensor):
16 |     # ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).squeeze(2).cpu().numpy()
17 |     ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
18 |     return ndarr
19 | 
20 | generator = None
21 | z = None
22 | # here, you can use any generator and z
23 | 
24 | G = generator(z)
25 | ndarr = to_ndarr(G).astype(np.uint8)
26 | 
27 | ndarr = cv2.cvtColor(ndarr, cv2.COLOR_GRAY2BGR)
28 | ret, thresh = cv2.threshold(ndarr, 127, 255, cv2.THRESH_BINARY)
29 | contours, hierarchy = cv2.findContours(thresh, 1, 2)[-2:]
30 | 
31 | rect = cv2.minAreaRect(contours[0])
32 | box = cv2.boxPoints(rect)
33 | box = np.int0(box)
34 | cv2.drawContours(ndarr, [box], 0, (0, 255, 0), 2)
35 | 
36 | cv2.imwrite('image_with_rectangle.jpg', ndarr)
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 | 
3 | The following four datasets are considered in our evaluation.
4 | 
5 | - MNIST - We use the dataset provided by Pytorch (see [here](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html)). Note that the original image size is $1 \times 28 \times 28$; we resize the images to $1 \times 32 \times 32$.
6 | 
7 | - CIFAR10 - We use the dataset provided by Pytorch (see [here](https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html)).
8 | 
9 | - Driving - The images can be downloaded [here](https://github.com/SullyChen/driving-datasets). We provide several examples in the `driving` folder. The dataset class can be implemented using the [ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) class in Pytorch (a minimal sketch is given after this list).
10 | 
11 | - CelebA - The official dataset can be downloaded [here](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). After downloading the dataset, you can use `celeba_crop128/celeba_process.py` to process it; the script splits the dataset into different subfolders and crops and resizes the faces to $128 \times 128$. Several processed examples are given in the `celeba_crop128/train` folder. We also provide the mapping between file names and human IDs in `CelebA_ID_to_name.json` and `CelebA_name_to_ID.json`. 
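
For the Driving data, a minimal `ImageFolder`-based loader could look as follows; the directory layout and the image size below are placeholders, not the exact settings of our experiments.

```python
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# `ImageFolder` expects one sub-folder per class, e.g. ./data/driving/all/*.jpg;
# for the driving frames a single dummy sub-folder is enough.
driving_set = ImageFolder(
    root='./data/driving',
    transform=transforms.Compose([
        transforms.Resize((128, 128)),  # placeholder resolution
        transforms.ToTensor(),
    ])
)
```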
12 | The dataset class is implemented in `experiments/dataset.py`. -------------------------------------------------------------------------------- /frameworks/GenProver/load_model.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # This scripts shows how to load your customized model # 3 | # (see examples in `experiments/model.py`) such that it # 4 | # can fit into GenProver/ExactLine. # 5 | ############################################################ 6 | 7 | import genmodels 8 | 9 | def convert_key(num_seq, state_dict): 10 | for key in list(state_dict.keys()): 11 | for i in range(num_seq): 12 | if 'seq_%d' % i in key: 13 | new_s = '%d.net' % i 14 | state_dict[key.replace('seq_%d' % i, new_s)] = state_dict.pop(key) 15 | break 16 | return state_dict 17 | 18 | gen_state_dict = gen_ckpt['generator'] 19 | gen_num_seq = 4 # For `ConvGeneratorSeq32` in `experiments/model.py` 20 | gen_state_dict = convert_key(gen_num_seq, gen_state_dict) 21 | 22 | 23 | cls_state_dict = cls_ckpt['classifier'] 24 | cls_num_seq = 7 # For `F3` and `F4` in `experiments/model.py` 25 | cls_state_dict = convert_key(cls_num_seq, cls_state_dict) 26 | 27 | ########################################################### 28 | # Note that the initialization is different with Pytorch. # 29 | ########################################################### 30 | generator = genmodels.ConvGenerator().infer([50]).to(h.device) 31 | generator.eval() 32 | 33 | classifier = genmodels.F3().infer([1, 32, 32]).to(h.device) 34 | classifier.eval() 35 | 36 | decoder.load_state_dict(gen_state_dict) 37 | classifier.load_state_dict(cls_state_dict) -------------------------------------------------------------------------------- /data/celeba_crop128/celeba_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import PIL 4 | from PIL import Image 5 | 6 | def make_path(path): 7 | if not os.path.exists(path): 8 | os.mkdir(path) 9 | 10 | input_dir = './img_align_celeba/' 11 | output_dir = './celeba_crop128/' 12 | 13 | txt_path = './celeba_partition.txt' 14 | 15 | 16 | def center_crop(img, new_width=128, new_height=128): 17 | width, height = img.size 18 | left = (width - new_width) / 2 19 | top = (height - new_height) / 2 20 | right = (width + new_width) / 2 21 | bottom = (height + new_height) / 2 22 | return img.crop((left, top, right, bottom)) 23 | 24 | split_dic = {} 25 | with open(txt_path, 'r') as f: 26 | lines = f.readlines() 27 | for l in lines: 28 | file_name, split = l.strip().split(' ') 29 | if file_name in split_dic.keys(): 30 | print('Error.') 31 | split_dic[file_name] = int(split) 32 | 33 | sub_list = ['train/', 'val/', 'test/'] 34 | 35 | make_path(output_dir) 36 | for sub in sub_list: 37 | make_path(output_dir + sub) 38 | 39 | file_list = sorted(os.listdir(input_dir)) 40 | for i, file_name in enumerate(tqdm(file_list)): 41 | input_path = input_dir + file_name 42 | split_idx = split_dic[file_name] 43 | output_path = output_dir + sub_list[split_idx] + file_name 44 | img = Image.open(input_path) 45 | out_img = center_crop( 46 | img=img, 47 | new_width=128, 48 | new_height=128 49 | ) 50 | out_img.save(output_path) 51 | -------------------------------------------------------------------------------- /frameworks/GenProver/README.md: -------------------------------------------------------------------------------- 1 | # GenProver 2 | 3 | The official implementation of 
GenProver is provided [here](https://openreview.net/forum?id=HJxRMlrtPH).
4 | 
5 | After downloading the code, you need to modify the following scripts in the project:
6 | 
7 | - `components.py` - GenProver is implemented based on [DiffAI](https://github.com/eth-sri/diffai) and `components.py` re-implements different Pytorch `nn` modules with the `InferModule` of DiffAI. We modified the implementations of several modules (mostly the `BatchNorm` module) to better match the Pytorch implementations. You can replace the original `components.py` with our provided one.
8 | 
9 | - `genmodels.py` - We added implementations (with DiffAI modules) of our models in this script. You can replace the original `genmodels.py` with our provided one.
10 | 
11 | Note that in order to load models trained with Pytorch, you need to do the following:
12 | 
13 | 1. Implement the model following the examples given in `experiments/model.py`. We suggest implementing the model with `nn.Sequential()` and hard-coding the name of each `nn.Sequential()`.
14 | 
15 | 2. Implement every operation as a class inheriting from Pytorch's `nn.Module`. For example, the `torch.cat()` operation should be implemented as `class CatTwo(nn.Module)`; see examples in `experiments/model.py`.
16 | 
17 | 3. Implement the corresponding class following DiffAI in `components.py`. For example, for the `class CatTwo(nn.Module)` in `experiments/model.py`, you should implement a `class CatTwo(InferModule)` in `components.py`; more examples are given in `components.py`.
18 | 
19 | 4. When loading the trained weights, you need to convert the keys in the `state_dict`. We provide the implementation and examples in `load_model.py`.
--------------------------------------------------------------------------------
/implementation/LRF.py:
--------------------------------------------------------------------------------
1 | #####################################################################
2 | # This script provides the implementation of low-rank factorization #
3 | # (a.k.a. robust PCA). The implementation is based on               #
4 | # https://github.com/zhujiapeng/LowRankGAN/blob/master/RobustPCA.py #
5 | #####################################################################
6 | 
7 | import numpy as np
8 | 
9 | class RobustPCA(object):
10 | 
11 |     def __init__(self, M, lamb=1/60):
12 |         self.M = M
13 |         self.S = np.zeros(self.M.shape)  # sparse matrix
14 |         self.L = np.zeros(self.M.shape)  # low-rank matrix
15 |         self.Lamb = np.zeros(self.M.shape)  # Lambda matrix
16 |         # mu is the coefficient used in augmented Lagrangian.
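        # (A larger mu yields smaller shrinkage thresholds `mu_inv` in the
        #  augmented-Lagrangian updates performed in `fit()` below.)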
17 | self.mu = np.prod(self.M.shape) / (4 * np.linalg.norm(self.M, ord=1)) 18 | self.mu_inv = 1 / self.mu 19 | self.iter = 0 20 | self.error = 1e-7 * self.frobenius_norm(self.M) 21 | 22 | if lamb: 23 | self.lamb = lamb 24 | else: 25 | self.lamb = 1 / np.sqrt(np.max(self.M.shape)) 26 | 27 | def reset_iter(self): 28 | """Resets the iteration.""" 29 | self.iter = 0 30 | 31 | @staticmethod 32 | def frobenius_norm(M): 33 | """Computes the Frobenius norm of a given matrix.""" 34 | return np.linalg.norm(M, ord='fro') 35 | 36 | @staticmethod 37 | def shrink(M, tau): 38 | return np.sign(M) * np.maximum((np.abs(M) - tau), np.zeros(M.shape)) 39 | 40 | def svd_threshold(self, M, tau): 41 | U, S, VH = np.linalg.svd(M, full_matrices=False) 42 | return np.dot(U, np.dot(np.diag(self.shrink(S, tau)), VH)) 43 | 44 | def fit(self, max_iter=10000, iter_print=100): 45 | self.reset_iter() 46 | err_i = np.Inf 47 | S_k = self.S 48 | L_k = self.L 49 | Lamb_k = self.Lamb 50 | 51 | while (err_i > self.error) and self.iter < max_iter: 52 | L_k = self.svd_threshold( 53 | self.M - S_k - self.mu_inv * Lamb_k, self.mu_inv) 54 | S_k = self.shrink( 55 | self.M - L_k - self.mu_inv * Lamb_k, self.mu_inv * self.lamb) 56 | Lamb_k = Lamb_k + self.mu * (L_k + S_k - self.M) 57 | err_i = self.frobenius_norm(L_k + S_k - self.M) 58 | self.iter += 1 59 | # if (self.iter % iter_print) == 0: 60 | # print(f'iteration: {self.iter}, error: {err_i}') 61 | 62 | return L_k, S_k -------------------------------------------------------------------------------- /implementation/continuity.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # In this script we show how our continuity regulation # 3 | # is applied on the training stage of generative models. # 4 | ############################################################ 5 | 6 | import random 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | # This script assumes you are using Pytorch 13 | 14 | cuda = False 15 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 16 | 17 | batch_size = 32 18 | latent_dimension = 50 19 | 20 | def continuity(generator): 21 | # if the latent space follows uniform distribution 22 | z1 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 23 | z2 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 24 | # # if the latent space follows normal distribution 25 | # z1 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 26 | # z2 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 27 | G1 = generator(z1) 28 | G2 = generator(z2) 29 | gamma = random.uniform(0, 1) 30 | z = torch.lerp(z1, z2, gamma) 31 | # an `intermediate point` between z1 and z2 32 | G = generator(z) 33 | penality = (gamma * G2 - G - (1 - gamma) * G1).square().mean() 34 | return penality 35 | 36 | 37 | ########################################################## 38 | # Below we show how the function `continuity` is used in # 39 | # standard training process of GAN. 
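# (This sketch additionally assumes `from tqdm import tqdm`; note that the      #
#  `real`/`fake` label tensors below need a shape tuple, i.e.                    #
#  np.ones((batch_size, 1)) and np.zeros((batch_size, 1)).)                      #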
# 40 | ########################################################## 41 | 42 | # you can use any generator 43 | generator = None 44 | discriminator = None 45 | dataloader = None 46 | # replace the above with your custom ones 47 | 48 | 49 | learning_rate = 0.01 50 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate) 51 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate) 52 | 53 | bce = torch.nn.BCELoss() 54 | real = Tensor(np.ones(batch_size, 1)) 55 | fake = Tensor(np.zeros(batch_size, 1)) 56 | 57 | n_epochs = 100 58 | for epoch in range(n_epochs): 59 | generator.train() 60 | discriminator.train() 61 | for i, (images, *_) in enumerate(tqdm(dataloader)): 62 | 63 | images = Tensor(images) 64 | 65 | optimizer_G.zero_grad() 66 | z = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 67 | # or: Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 68 | G = generator(z) 69 | g_loss = bce(discriminator(G), real) 70 | 71 | g_loss += continuity(generator) 72 | # just add this one line :D 73 | 74 | g_loss.backward() 75 | optimizer_G.step() 76 | 77 | optimizer_D.zero_grad() 78 | real_loss = bce(discriminator(images), real) 79 | fake_loss = bce(discriminator(G), fake) 80 | d_loss = (real_loss + fake_loss) / 2 81 | d_loss.backward() 82 | optimizer_D.step() -------------------------------------------------------------------------------- /experiments/synthetic_data.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # This script implements the synthetic dataset class for # 3 | # our ablation study. You can use the SyntheticDataset as # 4 | # a torchvision dataset class. # 5 | ############################################################ 6 | 7 | import os 8 | import json 9 | import cv2 10 | import random 11 | import numpy as np 12 | from PIL import Image 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from torch.utils.data import Dataset 19 | from torch.utils.data import DataLoader 20 | import torchvision.transforms as transforms 21 | 22 | 23 | def synthesize( 24 | use_translate=False, 25 | use_scale=False, 26 | use_rotate=False, 27 | use_color=False, 28 | save_path=None): 29 | start = (8, 8) # (0, 0) --> (16, 16) 30 | shape = (16, 16) # (4, 4) --> (16, 16) 31 | if use_color: 32 | img = np.zeros((32, 32, 3), np.uint8) 33 | B = int(random.randint(0, 1) * 255) 34 | G = int(random.randint(0, 1) * 255) 35 | R = int(random.randint(0, 1) * 255) 36 | color = (B, G, R) 37 | else: 38 | img = np.zeros((32, 32, 1), np.uint8) 39 | color = (255) 40 | 41 | (dx, dy) = (0, 0) 42 | (sx, sy) = (1, 1) 43 | if use_translate: 44 | dx = random.randint(0, 8) * (1 if random.randint(0, 1) > 0 else -1) 45 | dy = random.randint(0, 8) * (1 if random.randint(0, 1) > 0 else -1) 46 | if use_scale: 47 | sx = random.randint(5, 15) / 10 48 | sy = random.randint(5, 15) / 10 49 | 50 | # print((start[0]+dx, start[1]+dy)) 51 | # print((start[0]+dx+shape[0]*sx, start[1]+dy+shape[1]*sy)) 52 | # print((B, G, R)) 53 | 54 | (x1, y1) = (int(start[0] + dx), int(start[1] + dy)) 55 | cv2.rectangle( 56 | img, 57 | pt1=(x1, y1), 58 | pt2=(int(x1 + shape[0] * sx), int(y1 + shape[1] * sy)), 59 | color=color, 60 | thickness=-1 61 | ) 62 | 63 | if use_rotate: 64 | (cols, rows) = (32, 32) 65 | R = random.randint(-9, 9) * 10 66 | M = cv2.getRotationMatrix2D((cols/2, rows/2), R, 1) 67 | img = cv2.warpAffine(img, M, (cols, rows)) 68 | 69 | if 
save_path is not None: 70 | cv2.imwrite(save_path, img) 71 | return img if use_color else img[:, :, 0] 72 | 73 | 74 | def to_PIL(img): 75 | if len(img.shape) == 3: 76 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 77 | im_pil = Image.fromarray(img) 78 | return im_pil 79 | 80 | 81 | class SyntheticDataset(Dataset): 82 | def __init__(self, n_data=50000): 83 | super(SyntheticDataset).__init__() 84 | self.n_data = n_data 85 | self.transform = transforms.Compose([ 86 | transforms.ToTensor(), 87 | ]) 88 | 89 | def __len__(self): 90 | return self.n_data 91 | 92 | def __getitem__(self, idx): 93 | image_pil = to_PIL(synthesize( 94 | use_translate=True, 95 | use_scale=True, 96 | use_rotate=False, 97 | use_color=False 98 | )) 99 | image_tensor = self.transform(image_pil) 100 | # print(image_tensor.size()) 101 | return image_tensor -------------------------------------------------------------------------------- /experiments/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from tqdm import tqdm 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from torch.utils.data import Dataset 13 | from torch.utils.data import DataLoader 14 | import torchvision.transforms as transforms 15 | 16 | 17 | class CelebARecog(Dataset): 18 | def __init__(self, 19 | img_dir='/YOUR_DATA_DIR/celeba_crop128/', 20 | name2id_path='/YOUR_DATA_DIR/CelebA_name_to_ID.json', 21 | id2name_path='/YOUR_DATA_DIR/CelebA_ID_to_name.json', 22 | split='train', 23 | num_tuple=800, 24 | img_size=64): 25 | super(CelebARecog).__init__() 26 | assert split in ['train', 'val', 'test'] 27 | self.transform = transforms.Compose([ 28 | transforms.Resize(img_size), 29 | transforms.CenterCrop(img_size), 30 | transforms.ToTensor() 31 | ]) 32 | with open(name2id_path, 'r') as f: 33 | self.name2id = json.load(f) 34 | with open(id2name_path, 'r') as f: 35 | self.id2name = json.load(f) 36 | self.name_dir = img_dir + split + '/' 37 | self.name_list = sorted(os.listdir(self.name_dir)) 38 | self.pos_pair_list = [] 39 | self.neg_pair_list = [] 40 | for i in tqdm(range(min(num_tuple, len(self.name_list)))): 41 | name = self.name_list[i] 42 | ID = self.name2id[name] 43 | if (len(self.id2name[ID]) == 1) or (len([x for x in self.id2name[ID] if x in self.name_list]) == 0): 44 | continue 45 | same = random.choice(self.id2name[ID]) 46 | while (same not in self.name_list) or (same == name): 47 | same = random.choice(self.id2name[ID]) 48 | diff = random.choice(self.name_list) 49 | while diff in self.id2name[ID]: 50 | diff = random.choice(self.name_list) 51 | self.pos_pair_list.append((name, same)) 52 | self.neg_pair_list.append((name, diff)) 53 | 54 | print('Total %d' % len(self.pos_pair_list)) 55 | 56 | def __len__(self): 57 | return len(self.pos_pair_list) 58 | 59 | def __getitem__(self, idx): 60 | (name, same) = self.pos_pair_list[idx] 61 | (_, diff) = self.neg_pair_list[idx] 62 | img = Image.open(self.name_dir + name) 63 | img = self.transform(img) 64 | img_same = Image.open(self.name_dir + same) 65 | img_same = self.transform(img_same) 66 | img_diff = Image.open(self.name_dir + diff) 67 | img_diff = self.transform(img_diff) 68 | return img, img_same, img_diff 69 | 70 | class CelebAGen(Dataset): 71 | def __init__(self, 72 | img_dir='/YOUR_DATA_DIR/celeba_crop128/', 73 | split='train', 74 | num_img=200000, 75 | img_size=64, 76 | gray=False): 77 | super(CelebAGen).__init__() 78 | assert 
split in ['train', 'val', 'test'] 79 | self.transform = transforms.Compose([ 80 | transforms.Grayscale(num_output_channels=1), 81 | transforms.Resize(img_size), 82 | transforms.CenterCrop(img_size), 83 | transforms.ToTensor() 84 | ]) if gray else transforms.Compose([ 85 | transforms.Resize(img_size), 86 | transforms.CenterCrop(img_size), 87 | transforms.ToTensor() 88 | ]) 89 | self.name_dir = img_dir + split + '/' 90 | self.name_list = sorted(os.listdir(self.name_dir)) 91 | self.name_list = self.name_list[:num_img] 92 | 93 | def __len__(self): 94 | return len(self.name_list) 95 | 96 | def __getitem__(self, idx): 97 | name = self.name_list[idx] 98 | img = Image.open(self.name_dir + name) 99 | img = self.transform(img) 100 | return img -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | This folder provides scripts for our evaluations. 4 | 5 | ## Models and Datasets 6 | 7 | - `dataset.py` - We implement two Pytorch Dataset classes for CelebA. `CelebARecog` is used for training face recognition models. `CelebAGen` is employed for training face image generator. 8 | 9 | - `model.py` - We implement our models in this script. In accordance to requirements of GenProver/ExactLine, the implementations are carefully crafted. See details in [GenProver](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks/GenProver). 10 | 11 | - `face_recognition.py` - Our face recognition model takes a tuple of two images as one input and predicts whether the two faces are from the same person. This script implements how we train the face recognition model. 12 | 13 | See [data](https://github.com/Yuanyuan-Yuan/GCert/tree/main/data) for how to download and process the datasets. 14 | 15 | ## Mutations 16 | 17 | ### Geometrical 18 | 19 | - `augment_geometrical.py` - This script shows how we augment the training data with different geometrical (affine) mutations. In brief, this is achieved by applying the mutation in runtime. 20 | 21 | Pytorch `transforms` module supports randomly applying affine mutations on each input, see implementation below. 22 | 23 | ```python 24 | transforms.RandomAffine( 25 | degrees=30, 26 | # translate=(0.3, 0.3), 27 | # scale=(0.75, 1.2), 28 | # shear=(0.2) 29 | ), 30 | ``` 31 | 32 | We also provide implementations of different mutations in `mutation.py`. Below is the example of rotation. 33 | 34 | ```python 35 | class Rotation(Transformation): 36 | def init_id(self): 37 | self.category = 'geometrical' 38 | self.name = 'rotation' 39 | 40 | def mutate(self, seed): 41 | x = seed['x'] 42 | img = self.torch2cv(x) 43 | ext = self.extent() 44 | rows, cols, ch = img.shape 45 | M = cv2.getRotationMatrix2D((cols/2, rows/2), ext, 1) 46 | x_ = cv2.warpAffine(img, M, (cols, rows)) 47 | return self.cv2torch(x_), seed['z'] 48 | 49 | def extent(self): 50 | ext = np.random.choice(list(range(-180, 180))) 51 | # Set the maximal extent of mutations here 52 | return ext 53 | ``` 54 | 55 | You can also augment the training data with rotation (such that rotation can be decomposed from the latent space of the generative model) in the follow way. 
56 | 
57 | ```python
58 | from mutation import Rotation
59 | rotation = Rotation()
60 | for epoch in range(num_epoch):
61 |     for (image, *_) in dataloader:
62 |         image_, _ = rotation.mutate({'x': image, 'z': None})
63 |         # Then use `image_` to train the generative model
64 | ```
65 | 
66 | ### Perceptual-Level
67 | 
68 | For perceptual-level mutations, since they are extracted from perception variations in natural images, you do not need to do anything extra; just train a standard generative model. See `implementation/independence.py` for how to obtain perceptual-level mutations.
69 | 
70 | ### Stylized
71 | 
72 | For stylized mutations, you need to train the generative model with the cycle-consistency objective proposed in CycleGAN. The official Pytorch implementation of CycleGAN is provided [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). You can smoothly set up everything following the official documentation.
73 | 
74 | For different artistic styles, we use the style files provided [here](https://github.com/rgeirhos/Stylized-ImageNet).
75 | 
76 | For weather filters, we use the simulated filters provided by [imgaug](https://github.com/aleju/imgaug). The implementations are given in `mutation.py`. Below is an example of the fog mutation.
77 | 
78 | ```python
79 | import imgaug.augmenters as iaa
80 | 
81 | class Weather(Transformation):
82 |     def mutate(self, seed):
83 |         x = seed['x']
84 |         img = self.torch2cv(x)
85 |         x_ = self.trans(images=[img])[0]
86 |         return self.cv2torch(x_), seed['z']
87 | 
88 | class Fog(Weather):
89 |     def init_id(self):
90 |         self.category = 'style'
91 |         self.name = 'fog'
92 |         self.trans = iaa.Fog()
93 | ```
94 | 
95 | ## Evaluation Tools
96 | 
97 | - `rectangle.py` - This script shows how to compute the minimal enclosing rectangle used to assess the geometrical properties.
98 | 
99 | - `synthetic_data.py` - This script implements the synthetic dataset of our ablation study. You can directly use the `SyntheticDataset` class as a Pytorch dataset class.
--------------------------------------------------------------------------------
/implementation/README.md:
--------------------------------------------------------------------------------
1 | # Implementation
2 | 
3 | This folder provides implementations and examples of regulating generative models with independence and continuity.
4 | 
5 | ## Continuity
6 | 
7 | To enforce continuity, you need to add an extra training objective. See `implementation/continuity.py` for details. Below, we show how to train a conventional GAN with the continuity regulation.
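Concretely, the penalty encourages the generator to behave affinely along straight lines in the latent space: for random pairs $(z_1, z_2)$ and $\gamma$ sampled from $[0, 1]$, it pushes $G\big((1-\gamma)\, z_1 + \gamma\, z_2\big)$ towards $(1-\gamma)\, G(z_1) + \gamma\, G(z_2)$.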
8 | 9 | ```python 10 | def continuity(generator): 11 | # if the latent space follows uniform distribution 12 | z1 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 13 | z2 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 14 | # # if the latent space follows normal distribution 15 | # z1 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 16 | # z2 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 17 | G1 = generator(z1) 18 | G2 = generator(z2) 19 | gamma = random.uniform(0, 1) 20 | z = torch.lerp(z1, z2, gamma) 21 | # an `intermediate point` between z1 and z2 22 | G = generator(z) 23 | penality = (gamma * G2 - G - (1 - gamma) * G1).square().mean() 24 | return penality 25 | 26 | n_epochs = 100 27 | for epoch in range(n_epochs): 28 | generator.train() 29 | discriminator.train() 30 | for i, (images, *_) in enumerate(tqdm(dataloader)): 31 | 32 | images = Tensor(images) 33 | 34 | optimizer_G.zero_grad() 35 | z = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 36 | # or: Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 37 | G = generator(z) 38 | g_loss = bce(discriminator(G), real) 39 | 40 | g_loss += continuity(generator) 41 | # just add this one line :D 42 | 43 | g_loss.backward() 44 | optimizer_G.step() 45 | 46 | optimizer_D.zero_grad() 47 | real_loss = bce(discriminator(images), real) 48 | fake_loss = bce(discriminator(G), fake) 49 | d_loss = (real_loss + fake_loss) / 2 50 | d_loss.backward() 51 | optimizer_D.step() 52 | ``` 53 | 54 | ## Independence 55 | 56 | The independence is ensured from the following two aspects. 57 | 58 | ### Global Mutations 59 | 60 | For global mutations, different mutations are represented as *orthogonal* directions in the latent space. This is achieved using SVD; see details in `independence.py`. 61 | 62 | Below is an example of getting global mutating directions 63 | 64 | ```python 65 | J = Jacobian(G, z) 66 | # `G` is the generative model and `z` is the latent point 67 | directions = get_direction(J, None) 68 | ``` 69 | 70 | ### Local Mutations 71 | 72 | For local mutations, besides representing different mutations as orthogonal directions, we also ensure that only the selected local region is mutated. This is achieved by projecting mutating directions of the local region into non-mutating directions of the background. 73 | 74 | Before performing local mutation, you need to manualy set the foreground and backgroud indexes. Below is an example of mutating eyes for ffhq images. 75 | 76 | ```python 77 | COORDINATE_ffhq = { 78 | 'left_eye': [120, 95, 20, 38], 79 | 'right_eye': [120, 159, 20, 38], 80 | 'eyes': [120, 128, 20, 115], 81 | 'nose': [142, 131, 40, 46], 82 | 'mouth': [184, 127, 30, 70], 83 | 'chin': [217, 130, 42, 110], 84 | 'eyebrow': [126, 105, 15, 118], 85 | } 86 | 87 | def get_mask_by_coordinates(image_size, coordinate): 88 | """Get mask using the provided coordinates.""" 89 | mask = np.zeros([image_size, image_size], dtype=np.float32) 90 | center_x, center_y = coordinate[0], coordinate[1] 91 | crop_x, crop_y = coordinate[2], coordinate[3] 92 | xx = center_x - crop_x // 2 93 | yy = center_y - crop_y // 2 94 | mask[xx:xx + crop_x, yy:yy + crop_y] = 1. 
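    # coordinate = [center_x, center_y, height, width]: the mask is 1 inside the
    # height-by-width box centred at (center_x, center_y) and 0 elsewhere.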
95 | return mask 96 | 97 | coords = COORDINATE_ffhq['eyes'] 98 | mask = get_mask_by_coordinates(256, coordinate=coords) 99 | foreground_ind = np.where(mask == 1) 100 | background_ind = np.where((1 - mask) == 1) 101 | directions = get_direction(J, None, foreground_ind, background_ind) 102 | ``` 103 | 104 | ## Performing Mutations 105 | 106 | Once you get the mutating directions, you can perform mutations in the following way. 107 | 108 | ```python 109 | delta = 1.0 110 | for i in range(len(directions)): 111 | v = directions[i] 112 | x_ = G(z + delta * v) 113 | ``` 114 | 115 | `delta` controls the extent of the mutation. `x_` is the mutated input using the `i`-th mutating direction. -------------------------------------------------------------------------------- /experiments/face_recognition.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from tqdm import tqdm 5 | import numpy as np 6 | import math 7 | import sys 8 | 9 | import torchvision.transforms as transforms 10 | from torchvision.utils import save_image 11 | 12 | from torch.utils.data import DataLoader 13 | from torchvision import datasets 14 | from torch.autograd import Variable 15 | 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.autograd as autograd 19 | import torch 20 | 21 | from model import * 22 | from dataset import CelebARecog 23 | 24 | os.makedirs('images', exist_ok=True) 25 | os.makedirs('ckpt', exist_ok=True) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--exp_name', type=str, default='face_recog_32', help='experiment name') 29 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') 30 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 31 | parser.add_argument('--num_tuple', type=int, default=10000) 32 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') 33 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 34 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 35 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 36 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 37 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 38 | parser.add_argument('--channels', type=int, default=3, help='number of image channels') 39 | parser.add_argument('--n_critic', type=int, default=3, help='number of training steps for discriminator per iter') 40 | parser.add_argument('--lambda_gp', type=float, default=10, help='loss weight for gradient penalty') 41 | parser.add_argument('--clip_value', type=float, default=0.01, help='lower and upper clip value for disc. 
weights') 42 | # parser.add_argument('--sample_interval', type=int, default=500, help='interval betwen image samples') 43 | parser.add_argument('--save_every', type=int, default=50, help='interval betwen image samples') 44 | args = parser.parse_args() 45 | print(args) 46 | 47 | os.makedirs('images/%s' % args.exp_name, exist_ok=True) 48 | os.makedirs('ckpt/%s' % args.exp_name, exist_ok=True) 49 | 50 | img_shape = (args.channels, args.img_size, args.img_size) 51 | 52 | cuda = True if torch.cuda.is_available() else False 53 | 54 | # classifier = RecogSeq64() 55 | classifier = RecogSeq32() 56 | 57 | if cuda: 58 | classifier.cuda() 59 | 60 | # Configure data loader 61 | 62 | train_set = CelebARecog(split='train', num_tuple=args.num_tuple, img_size=args.img_size) 63 | test_set = CelebARecog(split='test', num_tuple=args.num_tuple, img_size=args.img_size) 64 | 65 | train_loader = torch.utils.data.DataLoader( 66 | train_set, 67 | batch_size=args.batch_size, 68 | shuffle=True, 69 | ) 70 | 71 | test_loader = torch.utils.data.DataLoader( 72 | test_set, 73 | batch_size=args.batch_size, 74 | shuffle=True, 75 | ) 76 | 77 | # Optimizers 78 | optimizer = torch.optim.Adam(classifier.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 79 | 80 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 81 | 82 | 83 | # ---------- 84 | # Training 85 | # ---------- 86 | 87 | def accuracy(pred, target): 88 | is_same = ((pred > 0.5) == target) 89 | return (is_same.sum() / len(is_same)).item() 90 | 91 | mse = nn.MSELoss().cuda() 92 | bce_log = nn.BCEWithLogitsLoss().cuda() 93 | bce = nn.BCELoss().cuda() 94 | 95 | def train(): 96 | loss_list = [] 97 | pos_acc_list = [] 98 | neg_acc_list = [] 99 | classifier.train() 100 | for i, (img, same, diff) in enumerate(tqdm(train_loader)): 101 | 102 | # Configure input 103 | img = Variable(img.type(Tensor)) 104 | same = Variable(same.type(Tensor)) 105 | diff = Variable(diff.type(Tensor)) 106 | 107 | ones = Variable(torch.ones([img.size(0), 1]).type(Tensor)) 108 | zeros = Variable(torch.zeros([img.size(0), 1]).type(Tensor)) 109 | 110 | optimizer.zero_grad() 111 | 112 | pred_same = classifier((img, same)) 113 | pred_diff = classifier((img, diff)) 114 | 115 | loss_same = bce(pred_same, ones) 116 | loss_diff = bce(pred_diff, zeros) 117 | loss = loss_same + loss_diff 118 | 119 | loss.backward() 120 | optimizer.step() 121 | 122 | pos_acc = accuracy(pred_same, ones) 123 | neg_acc = accuracy(pred_diff, zeros) 124 | loss_list.append(loss.item()) 125 | pos_acc_list.append(pos_acc) 126 | neg_acc_list.append(neg_acc) 127 | return loss_list, pos_acc_list, neg_acc_list 128 | 129 | def test(): 130 | with torch.no_grad(): 131 | loss_list = [] 132 | pos_acc_list = [] 133 | neg_acc_list = [] 134 | classifier.eval() 135 | for i, (img, same, diff) in enumerate(tqdm(test_loader)): 136 | 137 | # Configure input 138 | img = Variable(img.type(Tensor)) 139 | same = Variable(same.type(Tensor)) 140 | diff = Variable(diff.type(Tensor)) 141 | 142 | ones = Variable(torch.ones([img.size(0), 1]).type(Tensor)) 143 | zeros = Variable(torch.zeros([img.size(0), 1]).type(Tensor)) 144 | 145 | pred_same = classifier((img, same)) 146 | pred_diff = classifier((img, diff)) 147 | 148 | loss_same = bce(pred_same, ones) 149 | loss_diff = bce(pred_diff, zeros) 150 | loss = loss_same + loss_diff 151 | 152 | pos_acc = accuracy(pred_same, ones) 153 | neg_acc = accuracy(pred_diff, zeros) 154 | loss_list.append(loss.item()) 155 | pos_acc_list.append(pos_acc) 156 | neg_acc_list.append(neg_acc) 157 | return loss_list, 
pos_acc_list, neg_acc_list 158 | 159 | 160 | for epoch in range(args.n_epochs): 161 | loss_list, pos_acc_list, neg_acc_list = train() 162 | print( 163 | '[Epoch %d/%d] [loss: %f] [pos acc: %f] [neg acc: %f]' 164 | % (epoch, args.n_epochs, np.mean(loss_list), np.mean(pos_acc_list), np.mean(neg_acc_list)) 165 | ) 166 | if epoch % args.save_every == 0: 167 | loss_list, pos_acc_list, neg_acc_list = test() 168 | print( 169 | '[Test] [loss: %f] [pos acc: %f] [neg acc: %f]' 170 | % (np.mean(loss_list), np.mean(pos_acc_list), np.mean(neg_acc_list)) 171 | ) 172 | state = { 173 | 'classifier': classifier.state_dict(), 174 | 'optimizer': optimizer.state_dict(), 175 | } 176 | torch.save(state, 'ckpt/%s/%d.ckpt' % (args.exp_name, epoch)) 177 | 178 | loss_list, pos_acc_list, neg_acc_list = test() 179 | print( 180 | '[Test] [loss: %f] [pos acc: %f] [neg acc: %f]' 181 | % (np.mean(loss_list), np.mean(pos_acc_list), np.mean(neg_acc_list)) 182 | ) 183 | state = { 184 | 'classifier': classifier.state_dict(), 185 | 'optimizer': optimizer.state_dict(), 186 | } 187 | torch.save(state, 'ckpt/%s/final.ckpt' % (args.exp_name)) -------------------------------------------------------------------------------- /experiments/mutation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import cv2 4 | import numpy as np 5 | 6 | import imgaug.augmenters as iaa 7 | 8 | import torch 9 | import torch.nn as n 10 | import torch.nn.functional as F 11 | 12 | import tool 13 | 14 | __all__ = [ 15 | 'Noise', 'Brightness', 'Contrast', 'Blur', 16 | 'Translation', 'Scale', 'Rotation', 'Shear', 'Reflection', 17 | 'Cloud', 'Fog', 'Snow', 'Rain', 18 | ] 19 | 20 | class Transformation: 21 | def __init__(self): 22 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | self.init_id() 24 | 25 | def torch2cv(self, tsr): 26 | assert tsr.size(0) == 1 27 | # tsr.max() == 1 and tsr.min() == -1 28 | tsr = tool.general_normalize_inv(tsr) 29 | arr = tsr[0].detach().transpose(0, 2).cpu().numpy() * 255 30 | return arr 31 | 32 | def cv2torch(self, arr): 33 | arr = arr.astype(np.float32) 34 | tsr = torch.from_numpy(arr / 255).transpose(0, 2).unsqueeze(0) 35 | tsr = tool.general_normalize(tsr)#.to(self.device) 36 | return tsr 37 | 38 | def cv_pad(self, src, dst): 39 | H_src, W_src, C = src.shape 40 | H_dst, W_dst = dst.shape[:2] 41 | 42 | top = np.max((H_src - H_dst) // 2, 0) 43 | left = np.max((W_src - W_dst) // 2, 0) 44 | 45 | if len(dst.shape) == 2: 46 | COLOR = 0 47 | padded = np.full((H_src, W_src, C), COLOR, dtype=np.uint8) 48 | padded[top:top+H_dst, left:left+W_dst, 0] = dst 49 | else: 50 | COLOR = (0, 0, 0) 51 | padded = np.full((H_src, W_src, C), COLOR, dtype=np.uint8) 52 | padded[top:top+H_dst, left:left+W_dst] = dst 53 | return padded 54 | 55 | def init_id(self): 56 | raise NotImplementedError 57 | 58 | def mutate(self, x): 59 | raise NotImplementedError 60 | 61 | def extent(self): 62 | raise NotImplementedError 63 | 64 | 65 | class Noise(Transformation): 66 | def init_id(self): 67 | self.category = 'pixel' 68 | self.name = 'noise' 69 | 70 | def mutate(self, seed): 71 | x = seed['x'] 72 | x_ = x + self.extent() * torch.randn(x.size())#.to(self.device) 73 | return x_, seed['z'] 74 | 75 | def extent(self): 76 | return np.random.uniform(0, 1) 77 | 78 | class Brightness(Transformation): 79 | def init_id(self): 80 | self.category = 'pixel' 81 | self.name = 'brightness' 82 | 83 | def mutate(self, seed): 84 | x = seed['x'] 85 | arr = self.torch2cv(x) 86 | 
x_ = cv2.convertScaleAbs(arr, beta=self.extent(), alpha=1) 87 | return self.cv2torch(x_), seed['z'] 88 | 89 | def extent(self): 90 | return 10 + 10 * np.random.randint(7) 91 | 92 | class Contrast(Transformation): 93 | def init_id(self): 94 | self.category = 'pixel' 95 | self.name = 'contrast' 96 | 97 | def mutate(self, seed): 98 | x = seed['x'] 99 | arr = self.torch2cv(x) 100 | x_ = cv2.convertScaleAbs(arr, beta=0, alpha=self.extent()) 101 | return self.cv2torch(x_), seed['z'] 102 | 103 | def extent(self): 104 | return 0.8 + 0.2 * np.random.randint(7) 105 | 106 | class Blur(Transformation): 107 | def init_id(self): 108 | self.category = 'pixel' 109 | self.name = 'blur' 110 | 111 | def mutate(self, seed): 112 | x = seed['x'] 113 | arr = self.torch2cv(x) 114 | x_ = self.extent()(arr) 115 | return self.cv2torch(x_), seed['z'] 116 | 117 | def extent(self): 118 | blr = np.random.choice([ 119 | lambda img: cv2.blur(img, (3, 3)), 120 | lambda img: cv2.blur(img, (4, 4)), 121 | lambda img: cv2.blur(img, (5, 5)), 122 | lambda img: cv2.blur(img, (6, 6)), 123 | lambda img: cv2.GaussianBlur(img, (3, 3), 0), 124 | lambda img: cv2.GaussianBlur(img, (5, 5), 0), 125 | lambda img: cv2.GaussianBlur(img, (7, 7), 0), 126 | lambda img: cv2.medianBlur(img, 3), 127 | lambda img: cv2.medianBlur(img, 5), 128 | lambda img: cv2.bilateralFilter(img, 9, 75, 75), 129 | ]) 130 | return blr 131 | 132 | class Translation(Transformation): 133 | def init_id(self): 134 | self.category = 'geometrical' 135 | self.name = 'translation' 136 | 137 | def mutate(self, seed): 138 | x = seed['x'] 139 | img = self.torch2cv(x) 140 | params = self.extent() 141 | rows, cols, ch = img.shape 142 | M = np.float32([[1, 0, params[0]], [0, 1, params[1]]]) 143 | x_ = cv2.warpAffine(img, M, (cols, rows)) 144 | return self.cv2torch(x_), seed['z'] 145 | 146 | def extent(self): 147 | params = [np.random.randint(-3, 4) * 5, np.random.randint(-3, 4) * 5] 148 | return params 149 | 150 | class Scale(Transformation): 151 | def init_id(self): 152 | self.category = 'geometrical' 153 | self.name = 'scale' 154 | 155 | def mutate(self, seed): 156 | x = seed['x'] 157 | img = self.torch2cv(x) 158 | ext = self.extent() 159 | res = cv2.resize(img, None, fx=ext, fy=ext, interpolation=cv2.INTER_CUBIC) 160 | if ext <= 1: 161 | x_ = self.cv_pad(img, res) 162 | else: 163 | H, W = img.shape[:2] 164 | top = (res.shape[0] - H) // 2 165 | left = (res.shape[1] - W) // 2 166 | x_ = res[top:top+H, left:left+W] 167 | return self.cv2torch(x_), seed['z'] 168 | 169 | def extent(self): 170 | ext = np.random.choice(list(np.arange(0.5, 1.2, 0.05))) 171 | return ext 172 | 173 | class Rotation(Transformation): 174 | def init_id(self): 175 | self.category = 'geometrical' 176 | self.name = 'rotation' 177 | 178 | def mutate(self, seed): 179 | x = seed['x'] 180 | img = self.torch2cv(x) 181 | ext = self.extent() 182 | rows, cols, ch = img.shape 183 | M = cv2.getRotationMatrix2D((cols/2, rows/2), ext, 1) 184 | x_ = cv2.warpAffine(img, M, (cols, rows)) 185 | return self.cv2torch(x_), seed['z'] 186 | 187 | def extent(self): 188 | ext = np.random.choice(list(range(-180, 180))) 189 | return ext 190 | 191 | class Shear(Transformation): 192 | def init_id(self): 193 | self.category = 'geometrical' 194 | self.name = 'shear' 195 | 196 | def mutate(self, seed): 197 | x = seed['x'] 198 | img = self.torch2cv(x) 199 | rows, cols, ch = img.shape 200 | ext = self.extent() 201 | factor = ext * (-1.0) 202 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 203 | dst = cv2.warpAffine(img, M, (cols, rows)) 204 | x_ 
= self.cv_pad(img, dst) 205 | return self.cv2torch(x_), seed['z'] 206 | 207 | def extent(self): 208 | ext = np.random.choice(list(range(-2, 2))) 209 | return ext 210 | 211 | class Reflection(Transformation): 212 | def init_id(self): 213 | self.category = 'geometrical' 214 | self.name = 'reflection' 215 | 216 | def mutate(self, seed): 217 | x = seed['x'] 218 | img = self.torch2cv(x) 219 | x_ = cv2.flip(img, self.extent()) 220 | return self.cv2torch(x_), seed['z'] 221 | 222 | def extent(self): 223 | ext = np.random.randint(-1, 2) 224 | return ext 225 | 226 | class Weather(Transformation): 227 | def mutate(self, seed): 228 | x = seed['x'] 229 | img = self.torch2cv(x) 230 | x_ = self.trans(images=[img])[0] 231 | return self.cv2torch(x_), seed['z'] 232 | 233 | class Cloud(Weather): 234 | def init_id(self): 235 | self.category = 'style' 236 | self.name = 'cloud' 237 | self.trans = iaa.Clouds() 238 | 239 | class Fog(Weather): 240 | def init_id(self): 241 | self.category = 'style' 242 | self.name = 'fog' 243 | self.trans = iaa.Fog() 244 | 245 | class Snow(Weather): 246 | def init_id(self): 247 | self.category = 'style' 248 | self.name = 'snow' 249 | self.trans = iaa.Snowflakes(flake_size=(0.1, 0.4), speed=(0.01, 0.05)) 250 | 251 | class Rain(Weather): 252 | def init_id(self): 253 | self.category = 'style' 254 | self.name = 'rain' 255 | self.trans = iaa.Rain(speed=(0.1, 0.3)) -------------------------------------------------------------------------------- /implementation/independence.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # In this script we show how to get independent mutations. # 3 | # The implementation is based on # 4 | # https://github.com/zhujiapeng/LowRankGAN # 5 | # and https://github.com/zhujiapeng/resefa # 6 | ############################################################ 7 | 8 | import os 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | import torch 13 | from torch.autograd.functional import jacobian 14 | 15 | from LRF import RobustPCA 16 | 17 | def batched_jacobian(f, x): 18 | """Computes the Jacobian of f w.r.t x. 19 | 20 | This is according to the reverse mode autodiff rule, 21 | 22 | sum_i v^b_i dy^b_i / dx^b_j = sum_i x^b_j R_ji v^b_i, 23 | 24 | where: 25 | - b is the batch index from 0 to B - 1 26 | - i, j are the vector indices from 0 to N-1 27 | - v^b_i is a "test vector", which is set to 1 column-wise to obtain the correct 28 | column vectors out ot the above expression. 29 | 30 | :param f: function R^N -> R^N 31 | :param x: torch.tensor of shape [B, N] 32 | :return: Jacobian matrix (torch.tensor) of shape [B, N, N] 33 | """ 34 | x.requires_grad = True 35 | B, N = x.shape 36 | y = f(x) 37 | jacobian = list() 38 | for i in range(N): 39 | v = torch.zeros_like(y) 40 | v[:, i] = 1. 
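        # one-hot "test vector" selecting output coordinate i: the vector-Jacobian
        # product below then returns dy_i/dx (row i of the Jacobian) for every
        # sample in the batch.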
41 | dy_i_dx = torch.autograd.grad( 42 | y, 43 | x, 44 | grad_outputs=v, 45 | retain_graph=True, 46 | create_graph=True, 47 | allow_unused=True 48 | )[0] # shape [B, N] 49 | jacobian.append(dy_i_dx) 50 | 51 | jacobian = torch.stack(jacobian, dim=2).requires_grad_() 52 | 53 | return jacobian 54 | 55 | def Jacobian(G, latent_zs): 56 | jacobians = [] 57 | for idx in tqdm(range(latent_zs.shape[0])): 58 | latent_z = latent_zs[idx:idx+1] 59 | jac_i = jacobian( 60 | func=G, 61 | inputs=latent_z, 62 | create_graph=False, 63 | strict=False 64 | ) 65 | # print('jac_i: ', jac_i.size()) 66 | jacobians.append(jac_i) 67 | jacobians = torch.cat(jacobians, dim=0) 68 | print('jacobians size: ', jacobians.size()) 69 | np_jacobians = jacobians.detach().cpu().numpy() 70 | return np_jacobians 71 | 72 | def Jacobian_Y(G, latent_zs, ys): 73 | jacobians = [] 74 | for idx in tqdm(range(latent_zs.shape[0])): 75 | latent_z = latent_zs[idx:idx+1] 76 | y = ys[idx:idx+1] 77 | jac_i = jacobian( 78 | func=G, 79 | inputs=(latent_z, y), 80 | create_graph=False, 81 | strict=False 82 | ) 83 | # print('jac_i: ', jac_i.size()) 84 | jacobians.append(jac_i[0]) 85 | jacobians = torch.cat(jacobians, dim=0) 86 | print('jacobians size: ', jacobians.size()) 87 | np_jacobians = jacobians.detach().cpu().numpy() 88 | return np_jacobians 89 | 90 | def get_direction(jacobians, save_dir, 91 | foreground_ind=None, 92 | background_ind=None, 93 | lamb=60, 94 | num_relax=0, 95 | max_iter=10000): 96 | # lamb: the coefficient to control the sparsity 97 | # num_relax: factor of relaxation for the non-zeros singular values 98 | image_size = jacobians.shape[2] 99 | z_dim = jacobians.shape[-1] 100 | for ind in tqdm(range(jacobians.shape[0])): 101 | jacobian = jacobians[ind] 102 | if foreground_ind is not None and background_ind is not None: 103 | if len(jacobian.shape) == 4: # [H, W, 1, latent_dim] 104 | jaco_fore = jacobian[foreground_ind[0], foreground_ind[1], 0] 105 | jaco_back = jacobian[background_ind[0], background_ind[1], 0] 106 | elif len(jacobian.shape) == 5: # [channel, H, W, 1, latent_dim] 107 | jaco_fore = jacobian[:, foreground_ind[0], foreground_ind[1], 0] 108 | jaco_back = jacobian[:, background_ind[0], background_ind[1], 0] 109 | else: 110 | raise ValueError(f'Shape of Jacobian is not correct!') 111 | jaco_fore = np.reshape(jaco_fore, [-1, z_dim]) 112 | jaco_back = np.reshape(jaco_back, [-1, z_dim]) 113 | coef_f = 1 / jaco_fore.shape[0] 114 | coef_b = 1 / jaco_back.shape[0] 115 | M_fore = coef_f * jaco_fore.T.dot(jaco_fore) 116 | B_back = coef_b * jaco_back.T.dot(jaco_back) 117 | # R-PCA on foreground 118 | RPCA = RobustPCA(M_fore, lamb=1/lamb) 119 | L_f, _ = RPCA.fit(max_iter=max_iter) 120 | rank_f = np.linalg.matrix_rank(L_f) 121 | # R-PCA on background 122 | RPCA = RobustPCA(B_back, lamb=1/lamb) 123 | L_b, _ = RPCA.fit(max_iter=max_iter) 124 | rank_b = np.linalg.matrix_rank(L_b) 125 | # SVD on the low-rank matrix 126 | _, _, VHf = np.linalg.svd(L_f) 127 | _, _, VHb = np.linalg.svd(L_b) 128 | F_principal = VHf[:rank_f] 129 | relax_subspace = min(max(1, rank_b - num_relax), z_dim-1) 130 | B_null = VHb[rank_b:].T 131 | 132 | F_principal_proj = B_null.dot(B_null.T).dot(F_principal.T) # Projection 133 | F_principal_proj = F_principal_proj.T 134 | F_principal_proj /= np.linalg.norm( 135 | F_principal_proj, axis=1, keepdims=True) 136 | print('direction size: ', F_principal_proj.shape) 137 | if save_dir is not None: 138 | save_name = '%d_direction.npy' % ind 139 | np.save(save_dir + save_name, F_principal_proj) 140 | return 
F_principal_proj 141 | else: 142 | jaco = np.reshape(jacobian, [-1, z_dim]) 143 | coef = 1 / jaco.shape[0] 144 | M = coef * jaco.T.dot(jaco) 145 | 146 | RPCA = RobustPCA(M, lamb=1/lamb) 147 | L, _ = RPCA.fit(max_iter=max_iter) 148 | rank = np.linalg.matrix_rank(L) 149 | _, _, VH = np.linalg.svd(L) 150 | principal = VH[:max(rank, 5)] 151 | print('direction size: ', principal.shape) 152 | if save_dir is not None: 153 | save_name = '%d_direction.npy' % ind 154 | np.save(save_dir + save_name, principal) 155 | return principal 156 | 157 | if __name__ == '__main__': 158 | G = None # The generative model 159 | z = None # The latent point 160 | y = None # The class label 161 | 162 | ######################################################## 163 | # 1. You need to first compute the Jacobian matrix. # 164 | ######################################################## 165 | 166 | # 1.1 For conventional generative model G 167 | J = Jacobian(G, z) 168 | 169 | # 1.2 For class-conditional G (e.g., BigGAN), which generates images of class `y` 170 | J = Jacobian(G, z, y) 171 | 172 | ############################################################# 173 | # 2.1 Then, you can get the mutating directions as follows: # 174 | ############################################################# 175 | directions = get_direction(J, save_dir=None) 176 | 177 | ############################################################################## 178 | # 2.2 For local mutations, you need to manually set the indexes of foreground # 179 | # and background. LowRankGAN authors provide examples of the indexes, see # 180 | # https://github.com/zhujiapeng/resefa/blob/main/coordinate.py and # 181 | # https://github.com/zhujiapeng/LowRankGAN/blob/master/coordinate.py # 182 | ############################################################################### 183 | 184 | COORDINATE_ffhq = { 185 | 'left_eye': [120, 95, 20, 38], 186 | 'right_eye': [120, 159, 20, 38], 187 | 'eyes': [120, 128, 20, 115], 188 | 'nose': [142, 131, 40, 46], 189 | 'mouth': [184, 127, 30, 70], 190 | 'chin': [217, 130, 42, 110], 191 | 'eyebrow': [126, 105, 15, 118], 192 | } 193 | 194 | def get_mask_by_coordinates(image_size, coordinate): 195 | """Get mask using the provided coordinates.""" 196 | mask = np.zeros([image_size, image_size], dtype=np.float32) 197 | center_x, center_y = coordinate[0], coordinate[1] 198 | crop_x, crop_y = coordinate[2], coordinate[3] 199 | xx = center_x - crop_x // 2 200 | yy = center_y - crop_y // 2 201 | mask[xx:xx + crop_x, yy:yy + crop_y] = 1. 202 | return mask 203 | 204 | coords = COORDINATE_ffhq['eyes'] 205 | mask = get_mask_by_coordinates(256, coordinate=coords) 206 | foreground_ind = np.where(mask == 1) 207 | background_ind = np.where((1 - mask) == 1) 208 | directions = get_direction(J, None, foreground_ind, background_ind) 209 | 210 | ################################################################################## 211 | # 3. Once you get the direction, you can perform mutations in the following way. 
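# (For a class-conditional G, also pass the label, e.g. x_ = G(z + delta * v, y).)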
# 212 | ################################################################################## 213 | 214 | delta = 1.0 215 | for i in range(len(directions)): 216 | v = directions[i] 217 | x_ = G(z + delta * v) 218 | # `x_` is the mutated input using the i-th mutating direction 219 | 220 | -------------------------------------------------------------------------------- /experiments/augment_geometrical.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from tqdm import tqdm 5 | import numpy as np 6 | import math 7 | import sys 8 | 9 | import torchvision.transforms as transforms 10 | from torchvision.utils import save_image 11 | 12 | from torch.utils.data import DataLoader 13 | from torchvision import datasets 14 | from torch.autograd import Variable 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.autograd as autograd 20 | 21 | 22 | from model import * 23 | 24 | os.makedirs('images', exist_ok=True) 25 | os.makedirs('ckpt', exist_ok=True) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--exp_name', type=str, default='test', help='experiment name') 29 | parser.add_argument('--cls', type=int, default=0, help='selected class') 30 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs of training') 31 | parser.add_argument('--batch_size', type=int, default=512, help='size of the batches') 32 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') 33 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 34 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 35 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 36 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 37 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 38 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 39 | parser.add_argument('--n_critic', type=int, default=3, help='number of training steps for discriminator per iter') 40 | parser.add_argument('--lambda_gp', type=float, default=10, help='loss weight for gradient penalty') 41 | parser.add_argument('--clip_value', type=float, default=0.01, help='lower and upper clip value for disc. 
weights') 42 | # parser.add_argument('--sample_interval', type=int, default=500, help='interval betwen image samples') 43 | parser.add_argument('--save_every', type=int, default=50, help='interval betwen image samples') 44 | args = parser.parse_args() 45 | print(args) 46 | 47 | os.makedirs('images/%s' % args.exp_name, exist_ok=True) 48 | os.makedirs('ckpt/%s' % args.exp_name, exist_ok=True) 49 | 50 | img_shape = (args.channels, args.img_size, args.img_size) 51 | 52 | cuda = True if torch.cuda.is_available() else False 53 | 54 | # Initialize generator and discriminator 55 | generator = ConvGeneratorSeq() 56 | discriminator = Discriminator() 57 | 58 | if cuda: 59 | generator.cuda() 60 | discriminator.cuda() 61 | 62 | # Configure data loader 63 | os.makedirs('./data/mnist', exist_ok=True) 64 | 65 | dataset_full = datasets.MNIST( 66 | './data/mnist', 67 | train=True, 68 | download=True, 69 | transform=transforms.Compose([ 70 | transforms.RandomAffine( 71 | degrees=30, 72 | # translate=(0.3, 0.3), 73 | # scale=(0.75, 1.2), 74 | # shear=(0.2) 75 | ), 76 | 77 | transforms.Resize(args.img_size), 78 | transforms.ToTensor(), 79 | ] 80 | ), 81 | ) 82 | # Selecting classes 7, 2, 5 and 6 83 | if args.cls != -1: 84 | idx = (dataset_full.targets == args.cls) 85 | dataset_full.targets = dataset_full.targets[idx] 86 | dataset_full.data = dataset_full.data[idx] 87 | 88 | dataloader = torch.utils.data.DataLoader( 89 | dataset_full, 90 | batch_size=args.batch_size, 91 | shuffle=True, 92 | ) 93 | 94 | # Optimizers 95 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 96 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 97 | 98 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 99 | 100 | 101 | ######################################## 102 | # This is original training objective. # 103 | ######################################## 104 | def compute_gradient_penalty(D, real_samples, fake_samples): 105 | '''Calculates the gradient penalty loss for WGAN GP''' 106 | # Random weight term for interpolation between real and fake samples 107 | alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) 108 | # Get random interpolation between real and fake samples 109 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) 110 | d_interpolates = D(interpolates) 111 | fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False) 112 | # Get gradient w.r.t. interpolates 113 | gradients = autograd.grad( 114 | outputs=d_interpolates, 115 | inputs=interpolates, 116 | grad_outputs=fake, 117 | create_graph=True, 118 | retain_graph=True, 119 | only_inputs=True, 120 | )[0] 121 | gradients = gradients.view(gradients.size(0), -1) 122 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 123 | return gradient_penalty 124 | 125 | ######################################## 126 | # Regulation with continuity. 
# 127 | ######################################## 128 | def continuity(generator): 129 | # if the latent space follows uniform distribution 130 | z1 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 131 | z2 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 132 | # # if the latent space follows normal distribution 133 | # z1 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 134 | # z2 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 135 | G1 = generator(z1) 136 | G2 = generator(z2) 137 | gamma = random.uniform(0, 1) 138 | z = torch.lerp(z1, z2, gamma) 139 | # an `intermediate point` between z1 and z2 140 | G = generator(z) 141 | penality = (gamma * G2 - G - (1 - gamma) * G1).square().mean() 142 | return penality 143 | 144 | ######################################## 145 | # Training # 146 | ######################################## 147 | for epoch in range(args.n_epochs): 148 | d_loss_list = [] 149 | g_loss_list = [] 150 | 151 | generator.train() 152 | discriminator.train() 153 | for i, (imgs, *_) in enumerate(tqdm(dataloader)): 154 | 155 | # Configure input 156 | real_imgs = Variable(imgs.type(Tensor)) 157 | 158 | # --------------------- 159 | # Train Discriminator 160 | # --------------------- 161 | optimizer_D.zero_grad() 162 | # Sample noise as generator input 163 | z = Variable(Tensor(np.random.uniform(-1, 1, (imgs.shape[0], args.latent_dim)))) 164 | # Generate a batch of images 165 | fake_imgs = generator(z) 166 | 167 | # Real images 168 | real_validity = discriminator(real_imgs) 169 | # Fake images 170 | fake_validity = discriminator(fake_imgs) 171 | # Gradient penalty 172 | gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data) 173 | # Adversarial loss 174 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + args.lambda_gp * gradient_penalty 175 | 176 | d_loss.backward() 177 | optimizer_D.step() 178 | 179 | optimizer_G.zero_grad() 180 | 181 | # Train the generator every n_critic steps 182 | if i % args.n_critic == 0: 183 | 184 | # ----------------- 185 | # Train Generator 186 | # ----------------- 187 | 188 | # Generate a batch of images 189 | fake_imgs = generator(z) 190 | # Loss measures generator's ability to fool the discriminator 191 | # Train on fake images 192 | fake_validity = discriminator(fake_imgs) 193 | g_loss = -torch.mean(fake_validity) 194 | 195 | continuity_penalty = continuity(generator) 196 | 197 | (g_loss + continuity_penalty).backward(retain_graph=True) 198 | optimizer_G.step() 199 | 200 | d_loss_list.append(d_loss.item()) 201 | g_loss_list.append(g_loss.item()) 202 | 203 | print( 204 | '[Epoch %d/%d] [D loss: %f] [G loss: %f]' 205 | % (epoch, args.n_epochs, np.mean(d_loss_list), np.mean(g_loss_list)) 206 | ) 207 | if epoch % args.save_every == 0: 208 | save_image(fake_imgs.data[:100], 'images/%s/%d_fake.png' % (args.exp_name, epoch), nrow=10, normalize=True) 209 | save_image(real_imgs.data[:100], 'images/%s/%d_real.png' % (args.exp_name, epoch), nrow=10, normalize=True) 210 | generator.eval() 211 | discriminator.eval() 212 | state = { 213 | 'generator': generator.state_dict(), 214 | 'discriminator': discriminator.state_dict(), 215 | 'optimizer_G': optimizer_G.state_dict(), 216 | 'optimizer_D': optimizer_D.state_dict() 217 | } 218 | torch.save(state, 'ckpt/%s/%d.ckpt' % (args.exp_name, epoch)) 219 | 220 | save_image(fake_imgs.data[:100], 'images/%s/final_fake.png' % (args.exp_name), nrow=10, normalize=True) 221 | save_image(real_imgs.data[:100], 
'images/%s/final_real.png' % (args.exp_name), nrow=10, normalize=True) 222 | generator.eval() 223 | discriminator.eval() 224 | state = { 225 | 'generator': generator.state_dict(), 226 | 'discriminator': discriminator.state_dict(), 227 | 'optimizer_G': optimizer_G.state_dict(), 228 | 'optimizer_D': optimizer_D.state_dict() 229 | } 230 | torch.save(state, 'ckpt/%s/final.ckpt' % (args.exp_name)) -------------------------------------------------------------------------------- /frameworks/GenProver/genmodels.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . import components as n 3 | from . import ai 4 | from . import scheduling as S 5 | from . import helpers as h 6 | except: 7 | import components as n 8 | import scheduling as S 9 | import ai 10 | import helpers as h 11 | 12 | def ConvTinyInv(c): 13 | w = int(c[-1] / 2) 14 | return n.Seq(n.InvLeNet([200, w * w * 8], w, [ (8,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init=True), n.View(c)) 15 | 16 | 17 | def ConvSmallInv(c): 18 | w = int(c[-1] / 2) 19 | return n.Seq(n.InvLeNet([400, w * w * 8], w, [ (16,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init=True), n.View(c)) 20 | 21 | def ConvMedInv(c): 22 | w = int(c[-1] / 4) 23 | return n.Seq(n.InvLeNet([1000,2000, w * w * 16], w, [ (64,3,2,1,1) , (32,3,1,1,0), (32,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init = True), n.View(c)) 24 | 25 | 26 | def ConvLargeInv(c): 27 | w = int(c[-1] / 4) 28 | return n.Seq(n.InvLeNet([500, 1000,2000, w * w * 16], w, [ (128,3,2,1,1) , (64,3,1,1,0), (32,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init = True), n.View(c)) 29 | 30 | 31 | def ConvDCInv(c): 32 | w = int(c[-1] / 4) 33 | return n.Seq(n.InvLeNet([w * w * 16], w, [ (128,3,2,1,1) , (64,3,1,1,0), (32,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init = True), n.View(c)) 34 | 35 | def ConvDeepCInv(c): 36 | w = int(c[-1] / 8) 37 | return n.Seq(n.InvLeNet([w * w * 16], w, [ (256,3,2,1,1) , (128,3,1,1,0), (64,3,2,1,1) , (32,3,1,1,0), (16,3,2,1,1) , (h.product(c[:-2]),3,1,1,0) ], ibp_init = True), n.View(c)) 38 | 39 | 40 | def ConvInvTest(c=[1, 32, 32]): 41 | w = int(c[-1] / 4) 42 | return n.Seq( 43 | n.InvLeNet( 44 | [1000, 2000, w * w * 16], 45 | w, 46 | [ 47 | (64,3,2,1,1), (32,3,1,1,0), (32,3,2,1,1), 48 | (h.product(c[:-2]),3,1,1,0) 49 | ], 50 | ibp_init=True 51 | ), 52 | n.View(c) 53 | ) 54 | 55 | def FFNN(layers, last_lin = False, last_zono = False, **kargs): 56 | starts = layers 57 | ends = [] 58 | if last_lin: 59 | ends = ( 60 | [CorrelateAll(only_train=False)] if last_zono else [] 61 | ) + [ 62 | PrintActivation(activation = "Affine"), 63 | Linear(layers[-1], **kargs) 64 | ] 65 | 66 | starts = layers[:-1] 67 | 68 | return Seq( 69 | *( 70 | [ 71 | Seq( 72 | PrintActivation(**kargs), 73 | Linear(s, **kargs), 74 | activation(**kargs) 75 | ) for s in starts 76 | ] + ends 77 | ) 78 | ) 79 | 80 | def InvLeNet(ly, w, conv_layers, bias=True, normal=False, **kargs): 81 | def transfer(tp, lin = False): 82 | return (ConvTranspose2D if lin else ConvTranspose)( 83 | out_channels=tp[0], 84 | kernel_size=tp[1], 85 | stride=tp[2], 86 | padding=tp[3], 87 | out_padding=tp[4], 88 | bias=False, 89 | normal=normal, 90 | **kargs 91 | ) 92 | 93 | return Seq( 94 | FFNN(ly, bias=bias, **kargs), 95 | Unflatten2d(w), 96 | *[ 97 | transfer(s) for s in conv_layers[:-1] 98 | ], 99 | transfer(conv_layers[-1], lin=True) 100 | ) 101 | 102 | def dcgan_upconv(nin, nout, **kargs): 103 | return n.Seq( 104 | n.ConvTranspose2D(nin, nout, 4, 2, 1), 105 | n.BatchNorm(nout), 106 | 
n.ReLU(), 107 | ) 108 | 109 | 110 | def ConvGenerator(**kargs): 111 | nf = 4 112 | dim = 50 113 | nc = 1 114 | return n.Seq( 115 | n.View([dim, 1, 1]), 116 | # n.ConvTranspose(out_channels=nf * 4, kernel_size=4, stride=1, padding=0, activation='ReLU', batch_norm=True, **kargs), 117 | n.Seq( 118 | n.ConvTranspose2D(out_channels=nf * 4, kernel_size=4, stride=1, padding=0, **kargs), 119 | n.BatchNorm(training=False, **kargs), 120 | n.Activation(activation='ReLU', **kargs), 121 | ), 122 | # n.ConvTranspose(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, activation='ReLU', batch_norm=True, **kargs), 123 | n.Seq( 124 | n.ConvTranspose2D(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, **kargs), 125 | n.BatchNorm(training=False, **kargs), 126 | n.Activation(activation='ReLU', **kargs), 127 | ), 128 | # n.ConvTranspose(out_channels=nf, kernel_size=4, stride=2, padding=1, activation='ReLU', batch_norm=True, **kargs), 129 | n.Seq( 130 | n.ConvTranspose2D(out_channels=nf, kernel_size=4, stride=2, padding=1, **kargs), 131 | n.BatchNorm(training=False, **kargs), 132 | n.Activation(activation='ReLU', **kargs), 133 | ), 134 | n.Seq( 135 | n.ConvTranspose2D(out_channels=nc, kernel_size=4, stride=2, padding=1, **kargs), 136 | n.Activation(activation='ReLU', **kargs), 137 | n.Negate(**kargs), 138 | n.AddOne(**kargs), 139 | n.Activation(activation='ReLU', **kargs) 140 | ) 141 | ) 142 | 143 | def ConvGenerator32(**kargs): 144 | nf = 4 145 | dim = 50 146 | nc = 1 147 | return n.Seq( 148 | n.View([dim, 1, 1]), 149 | # n.ConvTranspose(out_channels=nf * 4, kernel_size=4, stride=1, padding=0, activation='ReLU', batch_norm=True, **kargs), 150 | n.Seq( 151 | n.ConvTranspose2D(out_channels=nf * 4, kernel_size=4, stride=1, padding=0, **kargs), 152 | n.BatchNorm(training=False, **kargs), 153 | n.Activation(activation='ReLU', **kargs), 154 | ), 155 | # n.ConvTranspose(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, activation='ReLU', batch_norm=True, **kargs), 156 | n.Seq( 157 | n.ConvTranspose2D(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, **kargs), 158 | n.BatchNorm(training=False, **kargs), 159 | n.Activation(activation='ReLU', **kargs), 160 | ), 161 | # n.ConvTranspose(out_channels=nf, kernel_size=4, stride=2, padding=1, activation='ReLU', batch_norm=True, **kargs), 162 | n.Seq( 163 | n.ConvTranspose2D(out_channels=nf, kernel_size=4, stride=2, padding=1, **kargs), 164 | n.BatchNorm(training=False, **kargs), 165 | n.Activation(activation='ReLU', **kargs), 166 | ), 167 | n.Seq( 168 | n.ConvTranspose2D(out_channels=nc, kernel_size=4, stride=2, padding=1, **kargs), 169 | n.Activation(activation='ReLU', **kargs), 170 | n.Negate(**kargs), 171 | n.AddOne(**kargs), 172 | n.Activation(activation='ReLU', **kargs) 173 | ) 174 | ) 175 | 176 | def Recog32(**kargs): 177 | nf = 16 178 | dim = 100 179 | return n.Seq( 180 | n.CatTwo(), 181 | n.Seq( 182 | n.Conv2D(out_channels=nf, kernel_size=4, stride=2, padding=1, **kargs), 183 | n.BatchNorm(training=False, **kargs), 184 | n.Activation(activation='ReLU', **kargs), 185 | ), 186 | n.Seq( 187 | n.Conv2D(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, **kargs), 188 | n.BatchNorm(training=False, **kargs), 189 | n.Activation(activation='ReLU', **kargs), 190 | ), 191 | n.Seq( 192 | n.Conv2D(out_channels=nf * 4, kernel_size=4, stride=2, padding=1, **kargs), 193 | n.BatchNorm(training=False, **kargs), 194 | n.Activation(activation='ReLU', **kargs), 195 | ), 196 | n.Seq( 197 | n.Conv2D(out_channels=dim, kernel_size=4, stride=1, padding=0, 
**kargs), 198 | n.BatchNorm(training=False, **kargs), 199 | n.Activation(activation='ReLU', **kargs), 200 | ), 201 | n.View([dim]), 202 | n.Seq( 203 | n.Linear(1, **kargs), 204 | n.Activation(activation='Sigmoid', **kargs), 205 | ), 206 | ) 207 | 208 | def F1(**kargs): 209 | n_class = 1 210 | dim = 10 211 | return n.Seq( 212 | n.View([32 * 32]), 213 | n.Seq( 214 | n.Linear(dim, **kargs), 215 | # n.BatchNorm(training=False, **kargs), 216 | n.Activation(activation='ReLU', **kargs), 217 | ), 218 | n.Seq( 219 | n.Linear(dim, **kargs), 220 | # n.BatchNorm(training=False, **kargs), 221 | n.Activation(activation='ReLU', **kargs), 222 | ), 223 | n.Seq( 224 | n.Linear(n_class, **kargs), 225 | ), 226 | ) 227 | 228 | def F2(**kargs): 229 | n_class = 1 230 | dim = 10 231 | nf = 16 232 | return n.Seq( 233 | n.Seq( 234 | n.Conv2D(out_channels=nf, kernel_size=4, stride=2, padding=1, **kargs), 235 | n.BatchNorm(training=False, **kargs), 236 | n.Activation(activation='ReLU', **kargs), 237 | ), 238 | n.Seq( 239 | n.Conv2D(out_channels=nf * 2, kernel_size=4, stride=4, padding=0, **kargs), 240 | n.BatchNorm(training=False, **kargs), 241 | n.Activation(activation='ReLU', **kargs), 242 | ), 243 | n.View([nf * 2 * 4 * 4]), 244 | n.Seq( 245 | n.Linear(n_class, **kargs), 246 | ), 247 | ) 248 | 249 | def F3(**kargs): 250 | n_class = 1 251 | dim = 10 252 | nf = 16 253 | return n.Seq( 254 | n.Seq( 255 | n.Conv2D(out_channels=nf, kernel_size=4, stride=2, padding=1, **kargs), 256 | n.BatchNorm(training=False, **kargs), 257 | n.Activation(activation='ReLU', **kargs), 258 | ), 259 | n.Seq( 260 | n.Conv2D(out_channels=nf * 2, kernel_size=4, stride=2, padding=1, **kargs), 261 | n.BatchNorm(training=False, **kargs), 262 | n.Activation(activation='ReLU', **kargs), 263 | ), 264 | n.Seq( 265 | n.Conv2D(out_channels=nf * 4, kernel_size=4, stride=2, padding=1, **kargs), 266 | n.BatchNorm(training=False, **kargs), 267 | n.Activation(activation='ReLU', **kargs), 268 | ), 269 | n.Seq( 270 | n.Conv2D(out_channels=dim, kernel_size=4, stride=1, padding=0, **kargs), 271 | n.BatchNorm(training=False, **kargs), 272 | n.Activation(activation='ReLU', **kargs), 273 | ), 274 | n.View([dim]), 275 | n.Seq( 276 | n.Linear(dim, **kargs), 277 | # n.BatchNorm(training=False, **kargs), 278 | n.Activation(activation='ReLU', **kargs), 279 | ), 280 | n.Seq( 281 | n.Linear(n_class, **kargs), 282 | ), 283 | ) 284 | -------------------------------------------------------------------------------- /experiments/model.py: -------------------------------------------------------------------------------- 1 | ##################################################################### 2 | # This script provides certified models in our evaluation. # 3 | # The implementation is crafted to support being incorporated # 4 | # into GenProver/ExactLine. 
# 5 | ##################################################################### 6 | 7 | import numpy as np 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class View(nn.Module): 16 | def __init__(self, shape): 17 | super(View, self).__init__() 18 | self.shape = shape 19 | 20 | def forward(self, x): 21 | return x.view([-1] + list(self.shape)) 22 | 23 | class Negate(nn.Module): 24 | def __init__(self): 25 | super(Negate, self).__init__() 26 | 27 | def forward(self, x): 28 | return -x 29 | 30 | class AddOne(nn.Module): 31 | def __init__(self): 32 | super(AddOne, self).__init__() 33 | 34 | def forward(self, x): 35 | return x + 1 36 | 37 | class CatTwo(nn.Module): 38 | def __init__(self): 39 | super(CatTwo, self).__init__() 40 | 41 | def forward(self, tp): 42 | (x, y) = tp 43 | return torch.cat([x, y], dim=1) 44 | 45 | class ParSum(nn.Module): 46 | def __init__(self, net1, net2): 47 | super(ParSum, self).__init__() 48 | self.net1 = net1 49 | self.net2 = net2 50 | 51 | def forward(self, x, just_left = False): 52 | r1 = self.net1(x) 53 | if just_left: 54 | return r1 55 | r2 = self.net2(x) 56 | return r1 + r2 57 | 58 | class ConvGeneratorSeq(nn.Module): 59 | def __init__(self, nc=1, dim=100): 60 | super(ConvGeneratorSeq, self).__init__() 61 | nf = 4 62 | self.dim = dim 63 | self.net = nn.Sequential(OrderedDict([ 64 | ('seq_0', View([dim, 1, 1])), 65 | ('seq_1', 66 | nn.Sequential( 67 | nn.ConvTranspose2d(dim, nf * 4, 4, 1, 0), 68 | nn.BatchNorm2d(nf * 4), 69 | nn.ReLU(), 70 | )), # upc1 71 | ('seq_2', 72 | nn.Sequential( 73 | nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1), 74 | nn.BatchNorm2d(nf * 2), 75 | nn.ReLU() 76 | )), # upc2 77 | ('seq_3', 78 | nn.Sequential( 79 | nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1), 80 | nn.BatchNorm2d(nf), 81 | nn.ReLU() 82 | )), # upc3 83 | ('seq_4', 84 | nn.Sequential( 85 | nn.ConvTranspose2d(nf, nc, 4, 2, 1), 86 | nn.ReLU(), 87 | Negate(), 88 | AddOne(), 89 | nn.ReLU() 90 | )) 91 | ])) 92 | 93 | def forward(self, x): 94 | out = self.net(x) 95 | return out 96 | 97 | class ConvGeneratorSeq32(nn.Module): 98 | def __init__(self, nc=1, dim=100): 99 | super(ConvGeneratorSeq32, self).__init__() 100 | nf = 16 101 | self.dim = dim 102 | self.net = nn.Sequential(OrderedDict([ 103 | ('seq_0', View([dim, 1, 1])), 104 | ('seq_1', 105 | nn.Sequential( 106 | nn.ConvTranspose2d(dim, nf * 4, 4, 1, 0), 107 | nn.BatchNorm2d(nf * 4), 108 | nn.ReLU(), 109 | )), # upc1 110 | ('seq_2', 111 | nn.Sequential( 112 | nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1), 113 | nn.BatchNorm2d(nf * 2), 114 | nn.ReLU() 115 | )), # upc2 116 | ('seq_3', 117 | nn.Sequential( 118 | nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1), 119 | nn.BatchNorm2d(nf), 120 | nn.ReLU() 121 | )), # upc3 122 | ('seq_4', 123 | nn.Sequential( 124 | nn.ConvTranspose2d(nf, nc, 4, 2, 1), 125 | nn.ReLU(), 126 | Negate(), 127 | AddOne(), 128 | nn.ReLU() 129 | )) 130 | ])) 131 | 132 | def forward(self, x): 133 | out = self.net(x) 134 | return out 135 | 136 | class ConvGeneratorSeq64(nn.Module): 137 | def __init__(self, nc=1, dim=100): 138 | super(ConvGeneratorSeq64, self).__init__() 139 | nf = 16 140 | self.dim = dim 141 | self.net = nn.Sequential(OrderedDict([ 142 | ('seq_0', View([dim, 1, 1])), 143 | ('seq_1', 144 | nn.Sequential( 145 | nn.ConvTranspose2d(dim, nf * 4, 4, 1, 0), 146 | nn.BatchNorm2d(nf * 4), 147 | nn.ReLU(), 148 | )), # upc1 149 | ('seq_2', 150 | nn.Sequential( 151 | nn.ConvTranspose2d(nf * 4, nf * 4, 4, 2, 1), 152 | nn.BatchNorm2d(nf * 
4), 153 | nn.ReLU() 154 | )), # upc2 155 | ('seq_3', 156 | nn.Sequential( 157 | nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1), 158 | nn.BatchNorm2d(nf * 2), 159 | nn.ReLU() 160 | )), # upc3 161 | ('seq_4', 162 | nn.Sequential( 163 | nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1), 164 | nn.BatchNorm2d(nf), 165 | nn.ReLU() 166 | )), # upc4 167 | ('seq_5', 168 | nn.Sequential( 169 | nn.ConvTranspose2d(nf, nc, 4, 2, 1), 170 | nn.ReLU(), 171 | Negate(), 172 | AddOne(), 173 | nn.ReLU() 174 | )) 175 | ])) 176 | 177 | def forward(self, x): 178 | out = self.net(x) 179 | return out 180 | 181 | class RecogSeq64(nn.Module): 182 | def __init__(self, nc=3, dim=100): 183 | super(RecogSeq64, self).__init__() 184 | nf = 16 185 | self.dim = dim 186 | self.net = nn.Sequential(OrderedDict([ 187 | ('seq_0', CatTwo()), 188 | # 64 x 64 189 | ('seq_1', 190 | nn.Sequential( 191 | nn.Conv2d(nc * 2, nf, 4, 2, 1), 192 | nn.BatchNorm2d(nf), 193 | nn.ReLU(), 194 | )), # upc1 195 | # 32 x 32 196 | ('seq_2', 197 | nn.Sequential( 198 | nn.Conv2d(nf, nf * 2, 4, 2, 1), 199 | nn.BatchNorm2d(nf * 2), 200 | nn.ReLU() 201 | )), # upc2 202 | # 16 x 16 203 | ('seq_3', 204 | nn.Sequential( 205 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1), 206 | nn.BatchNorm2d(nf * 4), 207 | nn.ReLU() 208 | )), # upc3 209 | # 8 x 8 210 | ('seq_4', 211 | nn.Sequential( 212 | nn.Conv2d(nf * 4, nf * 4, 4, 2, 1), 213 | nn.BatchNorm2d(nf * 4), 214 | nn.ReLU() 215 | )), # upc4 216 | # 4 x 4 217 | ('seq_5', 218 | nn.Sequential( 219 | nn.Conv2d(nf * 4, dim, 4, 1, 0), 220 | nn.BatchNorm2d(dim), 221 | nn.ReLU(), 222 | )), 223 | # 1 x 1 224 | ('seq_6', View([dim])), 225 | ('seq_7', 226 | nn.Sequential( 227 | nn.Linear(dim, 1), 228 | nn.Sigmoid() 229 | )), 230 | ])) 231 | 232 | def forward(self, tp): 233 | out = self.net(tp) 234 | return out 235 | 236 | class RecogSeq32(nn.Module): 237 | def __init__(self, nc=3, dim=100): 238 | super(RecogSeq32, self).__init__() 239 | nf = 16 240 | self.dim = dim 241 | self.net = nn.Sequential(OrderedDict([ 242 | ('seq_0', CatTwo()), 243 | # 32 x 32 244 | ('seq_1', 245 | nn.Sequential( 246 | nn.Conv2d(nc * 2, nf, 4, 2, 1), 247 | nn.BatchNorm2d(nf), 248 | nn.ReLU(), 249 | )), # upc1 250 | # 16 x 16 251 | ('seq_2', 252 | nn.Sequential( 253 | nn.Conv2d(nf, nf * 2, 4, 2, 1), 254 | nn.BatchNorm2d(nf * 2), 255 | nn.ReLU() 256 | )), # upc2 257 | # 8 x 8 258 | ('seq_3', 259 | nn.Sequential( 260 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1), 261 | nn.BatchNorm2d(nf * 4), 262 | nn.ReLU() 263 | )), # upc3 264 | # 4 x 4 265 | ('seq_4', 266 | nn.Sequential( 267 | nn.Conv2d(nf * 4, dim, 4, 1, 0), 268 | nn.BatchNorm2d(dim), 269 | nn.ReLU(), 270 | )), 271 | # 1 x 1 272 | ('seq_5', View([dim])), 273 | ('seq_6', 274 | nn.Sequential( 275 | nn.Linear(dim, 1), 276 | nn.Sigmoid() 277 | )), 278 | ])) 279 | 280 | def forward(self, tp): 281 | out = self.net(tp) 282 | return out 283 | 284 | class F1(nn.Module): 285 | def __init__(self, n_class=10, dim=10): 286 | super(F1, self).__init__() 287 | self.dim = dim 288 | self.net = nn.Sequential(OrderedDict([ 289 | ('seq_0', View([32 * 32])), 290 | ('seq_1', 291 | nn.Sequential( 292 | nn.Linear(32 * 32, dim), 293 | # nn.BatchNorm1d(dim), 294 | nn.ReLU(), 295 | )), 296 | ('seq_2', 297 | nn.Sequential( 298 | nn.Linear(dim, dim), 299 | # nn.BatchNorm1d(dim), 300 | nn.ReLU(), 301 | )), 302 | ('seq_3', 303 | nn.Sequential( 304 | nn.Linear(dim, n_class), 305 | )), 306 | ])) 307 | 308 | def forward(self, x): 309 | out = self.net(x) 310 | return out 311 | 312 | class F2(nn.Module): 313 | def __init__(self, nc=1, n_class=10, dim=10): 314 | 
super(F2, self).__init__() 315 | self.dim = dim 316 | nf = 16 317 | self.net = nn.Sequential(OrderedDict([ 318 | # 32 x 32 319 | ('seq_0', 320 | nn.Sequential( 321 | nn.Conv2d(nc, nf, 4, 2, 1), 322 | nn.BatchNorm2d(nf), 323 | nn.ReLU(), 324 | )), 325 | # 16 x 16 326 | ('seq_1', 327 | nn.Sequential( 328 | nn.Conv2d(nf, nf * 2, 4, 4, 0), 329 | nn.BatchNorm2d(nf * 2), 330 | nn.ReLU(), 331 | )), 332 | # 4 x 4 333 | ('seq_2', View([nf * 2 * 4 * 4])), 334 | ('seq_3', 335 | nn.Sequential( 336 | nn.Linear(nf * 2 * 4 * 4, n_class), 337 | )), 338 | ])) 339 | 340 | def forward(self, x): 341 | out = self.net(x) 342 | return out 343 | 344 | class F3(nn.Module): 345 | def __init__(self, nc=1, n_class=10, dim=10): 346 | super(F3, self).__init__() 347 | self.dim = dim 348 | nf = 16 349 | self.net = nn.Sequential(OrderedDict([ 350 | # 32 x 32 351 | ('seq_0', 352 | nn.Sequential( 353 | nn.Conv2d(nc, nf, 4, 2, 1), 354 | nn.BatchNorm2d(nf), 355 | nn.ReLU(), 356 | )), 357 | # 16 x 16 358 | ('seq_1', 359 | nn.Sequential( 360 | nn.Conv2d(nf, nf * 2, 4, 2, 1), 361 | nn.BatchNorm2d(nf * 2), 362 | nn.ReLU(), 363 | )), 364 | # 8 x 8 365 | ('seq_2', 366 | nn.Sequential( 367 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1), 368 | nn.BatchNorm2d(nf * 4), 369 | nn.ReLU(), 370 | )), 371 | # 4 x 4 372 | ('seq_3', 373 | nn.Sequential( 374 | nn.Conv2d(nf * 4, dim, 4, 1, 0), 375 | nn.BatchNorm2d(dim), 376 | nn.ReLU(), 377 | )), 378 | # 1 x 1 379 | ('seq_4', View([dim])), 380 | ('seq_5', 381 | nn.Sequential( 382 | nn.Linear(dim, dim), 383 | # nn.BatchNorm1d(dim), 384 | nn.ReLU(), 385 | )), 386 | ('seq_6', 387 | nn.Sequential( 388 | nn.Linear(dim, n_class), 389 | )), 390 | ])) 391 | 392 | def forward(self, x): 393 | out = self.net(x) 394 | return out -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GCert 2 | Research Artifact of USENIX Security 2023 Paper: *Precise and Generalized Robustness Certification for Neural Networks* 3 | 4 | Preprint: https://arxiv.org/pdf/2306.06747.pdf 5 | 6 | 7 | ## Installation 8 | 9 | - Build from source code 10 | 11 | ```setup 12 | git clone https://github.com/Yuanyuan-Yuan/GCert 13 | cd GCert 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Structure 18 | 19 | This repo is organized as follows: 20 | 21 | - `implementation` - This folder provides implementations and examples of regulating 22 | generative models with continuity and independence. See detailed documents [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/implementation) 23 | 24 | - `experiments` - This folder provides scripts of our evaluations. See detailed documents [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/experiments) 25 | 26 | - `frameworks` - GCert is incorporated into three conventional certification frameworks (i.e., 27 | AI2/Eran, GenProver, and ExactLine). This folder provides the scripts for configurations; see 28 | detailed documents [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks) 29 | 30 | - `data` - This folder provides scripts for data processing and shows examples of some data samples. See detailed documents [here](https://github.com/Yuanyuan-Yuan/GCert/tree/main/data). 31 | 32 | ## Data 33 | 34 | The following four datasets are considered in our evaluation. 35 | 36 | - MNIST - We use the dataset provided by Pytorch (see [here](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html)). 
Note that the original image size is $1 \times 28 \times 28$. We resize the images to $1 \times 32 \times 32$. 37 | 38 | - CIFAR10 - We use the dataset provided by Pytorch (see [here](https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html)). 39 | 40 | - Driving - The images can be downloaded [here](https://github.com/SullyChen/driving-datasets). We provide several examples in the `data/driving` folder. The dataset class can be implemented using the [ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) class in Pytorch. 41 | 42 | - CelebA - The official dataset can be downloaded [here](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). After downloading the dataset, you can use `data/celeba_crop128/celeba_process.py` to process it, which splits the dataset into different subfolders and crops and resizes the faces to $128 \times 128$. Several processed examples are given in the `data/celeba_crop128/train` folder. We also provide the mapping between file names and human IDs in `data/CelebA_ID_to_name.json` and `data/CelebA_name_to_ID.json`. 43 | The dataset class is implemented in `experiments/dataset.py`. 44 | 45 | ## Frameworks 46 | 47 | GCert is incorporated into the following three frameworks for certification. 48 | 49 | ### GenProver 50 | 51 | The official implementation of GenProver is provided [here](https://openreview.net/forum?id=HJxRMlrtPH). 52 | 53 | After downloading the code, you need to modify the following scripts in the project: 54 | 55 | - `frameworks/GenProver/components.py` - GenProver is implemented based on [DiffAI](https://github.com/eth-sri/diffai), and `components.py` re-implements different Pytorch `nn` modules with the `InferModule` of DiffAI. We modified the implementations of several modules (mostly the `BatchNorm` module) to better fit the implementations in Pytorch. You can replace the original `components.py` with our provided one. 56 | 57 | - `frameworks/GenProver/genmodels.py` - We added implementations (with DiffAI modules) of our models in this script. You can replace the original `genmodels.py` with our provided one. 58 | 59 | Note that in order to load models trained with Pytorch, you need to do the following: 60 | 61 | 1. Implement the model following the examples given in `experiments/model.py`. We suggest implementing the model with `nn.Sequential()` and hard-coding the name for each `nn.Sequential()`. 62 | 63 | 2. Implement every operation as a class inheriting from Pytorch's `nn.Module`. For example, the `torch.cat()` operation should be implemented as `class CatTwo(nn.Module)` in `experiments/model.py`; see examples in `experiments/model.py`. 64 | 65 | 3. Implement the corresponding class following DiffAI in `frameworks/GenProver/components.py`. For example, for the `class CatTwo(nn.Module)` in `experiments/model.py`, you should implement a `class CatTwo(InferModule)` in `components.py`; more examples are given in `components.py`. 66 | 67 | 4. When loading the trained weights, you need to convert the keys in the `state_dict`. We provide the implementation and examples in `frameworks/GenProver/load_model.py`. 68 | 69 | ### ExactLine 70 | 71 | We use the ExactLine implemented by the authors of GenProver. The source code can be downloaded [here](https://openreview.net/forum?id=HJxRMlrtPH). 72 | 73 | The implementations of ExactLine and GenProver are almost the same, except that GenProver merges segments in intermediate outputs as box/polyhedra.
Thus, to use ExactLine, you only need to set 74 | 75 | ```python 76 | use_clustr = None 77 | ``` 78 | 79 | in the implementation of GenProver. 80 | 81 | ### AI2/ERAN 82 | 83 | The official implementation is provided [here](https://github.com/eth-sri/eran). In our experiments, we use the adaptor provided by [VeriGauge](https://github.com/AI-secure/VeriGauge) to set up AI2/ERAN. 84 | 85 | [VeriGauge](https://github.com/AI-secure/VeriGauge) and [AI2/ERAN](https://github.com/eth-sri/eran) are well implemented and documented; you can smoothly set up everything following their instructions. 86 | 87 | ## Implementation 88 | 89 | This folder provides implementations and examples of regulating generative models with independence and continuity. 90 | 91 | ### Continuity 92 | 93 | To enforce continuity, you need to add an extra training objective. See more details in `implementation/continuity.py`. Below, we show how to train a conventional GAN with the continuity regulation. 94 | 95 | ```python 96 | def continuity(generator): 97 | # if the latent space follows a uniform distribution 98 | z1 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 99 | z2 = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 100 | # # if the latent space follows a normal distribution 101 | # z1 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 102 | # z2 = Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 103 | G1 = generator(z1) 104 | G2 = generator(z2) 105 | gamma = random.uniform(0, 1) 106 | z = torch.lerp(z1, z2, gamma) 107 | # an `intermediate point` between z1 and z2 108 | G = generator(z) 109 | penalty = (gamma * G2 - G - (1 - gamma) * G1).square().mean() 110 | return penalty 111 | 112 | n_epochs = 100 113 | for epoch in range(n_epochs): 114 | generator.train() 115 | discriminator.train() 116 | for i, (images, *_) in enumerate(tqdm(dataloader)): 117 | 118 | images = Tensor(images) 119 | 120 | optimizer_G.zero_grad() 121 | z = Tensor(np.random.uniform(-1, 1, (batch_size, latent_dimension))) 122 | # or: Tensor(np.random.normal(0, 1, (batch_size, latent_dimension))) 123 | G = generator(z) 124 | g_loss = bce(discriminator(G), real) 125 | 126 | g_loss += continuity(generator) 127 | # just add this one line :D 128 | 129 | g_loss.backward() 130 | optimizer_G.step() 131 | 132 | optimizer_D.zero_grad() 133 | real_loss = bce(discriminator(images), real) 134 | fake_loss = bce(discriminator(G), fake) 135 | d_loss = (real_loss + fake_loss) / 2 136 | d_loss.backward() 137 | optimizer_D.step() 138 | ``` 139 | 140 | ### Independence 141 | 142 | Independence is ensured from the following two aspects. 143 | 144 | #### Global Mutations 145 | 146 | For global mutations, different mutations are represented as *orthogonal* directions in the latent space. This is achieved using SVD; see details in `implementation/independence.py`. 147 | 148 | Below is an example of getting global mutating directions: 149 | 150 | ```python 151 | J = Jacobian(G, z) 152 | # `G` is the generative model and `z` is the latent point 153 | directions = get_direction(J, None) 154 | ``` 155 | 156 | #### Local Mutations 157 | 158 | For local mutations, besides representing different mutations as orthogonal directions, we also ensure that only the selected local region is mutated. This is achieved by projecting mutating directions of the local region onto the non-mutating directions of the background. 159 | 160 | Before performing local mutations, you need to manually set the foreground and background indexes.
Below is an example of mutating eyes for ffhq images. 161 | 162 | ```python 163 | COORDINATE_ffhq = { 164 | 'left_eye': [120, 95, 20, 38], 165 | 'right_eye': [120, 159, 20, 38], 166 | 'eyes': [120, 128, 20, 115], 167 | 'nose': [142, 131, 40, 46], 168 | 'mouth': [184, 127, 30, 70], 169 | 'chin': [217, 130, 42, 110], 170 | 'eyebrow': [126, 105, 15, 118], 171 | } 172 | 173 | def get_mask_by_coordinates(image_size, coordinate): 174 | """Get mask using the provided coordinates.""" 175 | mask = np.zeros([image_size, image_size], dtype=np.float32) 176 | center_x, center_y = coordinate[0], coordinate[1] 177 | crop_x, crop_y = coordinate[2], coordinate[3] 178 | xx = center_x - crop_x // 2 179 | yy = center_y - crop_y // 2 180 | mask[xx:xx + crop_x, yy:yy + crop_y] = 1. 181 | return mask 182 | 183 | coords = COORDINATE_ffhq['eyes'] 184 | mask = get_mask_by_coordinates(256, coordinate=coords) 185 | foreground_ind = np.where(mask == 1) 186 | background_ind = np.where((1 - mask) == 1) 187 | directions = get_direction(J, None, foreground_ind, background_ind) 188 | ``` 189 | 190 | The coordinates are provided by authors of [LowRankGAN](https://github.com/zhujiapeng/resefa/blob/main/coordinate.py). 191 | 192 | ### Performing Mutations 193 | 194 | Once you get the mutating directions, you can perform mutations in the following way. 195 | 196 | ```python 197 | delta = 1.0 198 | for i in range(len(directions)): 199 | v = directions[i] 200 | x_ = G(z + delta * v) 201 | ``` 202 | 203 | `delta` controls the extent of the mutation. `x_` is the mutated input using the `i`-th mutating direction. 204 | 205 | ## Experiments 206 | 207 | This folder provides scripts for our evaluations. 208 | 209 | ### Models and Datasets 210 | 211 | - `experiments/dataset.py` - We implement two Pytorch Dataset classes for CelebA. `CelebARecog` is used for training face recognition models. `CelebAGen` is employed for training face image generator. 212 | 213 | - `experiments/model.py` - We implement our models in this script. In accordance to requirements of GenProver/ExactLine, the implementations are carefully crafted. See details in [GenProver](https://github.com/Yuanyuan-Yuan/GCert/tree/main/frameworks/GenProver). 214 | 215 | - `experiments/face_recognition.py` - Our face recognition model takes a tuple of two images as one input and predicts whether the two faces are from the same person. This script implements how we train the face recognition model. 216 | 217 | See [data](https://github.com/Yuanyuan-Yuan/GCert/tree/main/data) for how to download and process the datasets. 218 | 219 | ### Mutations 220 | 221 | #### Geometrical 222 | 223 | - `experiments/augment_geometrical.py` - This script shows how we augment the training data with different geometrical (affine) mutations. In brief, this is achieved by applying the mutation in runtime. 224 | 225 | Pytorch `transforms` module supports randomly applying affine mutations on each input, see implementation below. 226 | 227 | ```python 228 | transforms.RandomAffine( 229 | degrees=30, 230 | # translate=(0.3, 0.3), 231 | # scale=(0.75, 1.2), 232 | # shear=(0.2) 233 | ), 234 | ``` 235 | 236 | We also provide implementations of different mutations in `experiments/mutation.py`. Below is the example of rotation. 
237 | 238 | ```python 239 | class Rotation(Transformation): 240 | def init_id(self): 241 | self.category = 'geometrical' 242 | self.name = 'rotation' 243 | 244 | def mutate(self, seed): 245 | x = seed['x'] 246 | img = self.torch2cv(x) 247 | ext = self.extent() 248 | rows, cols, ch = img.shape 249 | M = cv2.getRotationMatrix2D((cols/2, rows/2), ext, 1) 250 | x_ = cv2.warpAffine(img, M, (cols, rows)) 251 | return self.cv2torch(x_), seed['z'] 252 | 253 | def extent(self): 254 | ext = np.random.choice(list(range(-180, 180))) 255 | # Set the maximal extent of mutations here 256 | return ext 257 | ``` 258 | 259 | You can also augment the training data with rotation (such that rotation can be decomposed from the latent space of the generative model) in the following way. 260 | 261 | ```python 262 | from mutation import Rotation 263 | 264 | for epoch in range(num_epoch): 265 | for (image, *_) in dataloader: 266 | image_ = Rotation.mutate(image) 267 | # Then use `image_` to train the generative model 268 | ``` 269 | 270 | #### Perceptual-Level 271 | 272 | For perceptual-level mutations, since they are extracted from the perception variations of natural images, you do not need to do anything; just train a standard generative model. See `implementation/independence.py` for how to obtain perceptual-level mutations. 273 | 274 | #### Stylized 275 | 276 | For stylized mutations, you need to train the generative model following the cycle-consistency objective (proposed in CycleGAN). The official Pytorch implementation of CycleGAN is provided [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). You can smoothly set up everything following the official documents. 277 | 278 | For different artistic styles, we use the style files provided [here](https://github.com/rgeirhos/Stylized-ImageNet). 279 | 280 | For weather filters, we use the simulated filters provided by [imgaug](https://github.com/aleju/imgaug). The implementations are given in `experiments/mutation.py`. Below is an example of the fog mutation. 281 | 282 | ```python 283 | import imgaug.augmenters as iaa 284 | 285 | class Weather(Transformation): 286 | def mutate(self, seed): 287 | x = seed['x'] 288 | img = self.torch2cv(x) 289 | x_ = self.trans(images=[img])[0] 290 | return self.cv2torch(x_), seed['z'] 291 | 292 | class Fog(Weather): 293 | def init_id(self): 294 | self.category = 'style' 295 | self.name = 'fog' 296 | self.trans = iaa.Fog() 297 | ``` 298 | 299 | ### Evaluation Tools 300 | 301 | - `experiments/rectangle.py` - This script implements the calculation of the minimal enclosing rectangle used for assessing geometrical properties. 302 | 303 | - `experiments/synthetic_data.py` - This script implements the synthetic dataset of our ablation study. You can directly use the `SyntheticDataset` class as a Pytorch dataset class. 304 | 305 | ## Acknowledgement 306 | 307 | We sincerely thank the authors of the following projects for open-sourcing their code, which greatly helped us develop GCert.
308 | 309 | - GenProver/DiffAI: https://github.com/eth-sri/diffai 310 | 311 | - AI2/ERAN: https://github.com/eth-sri/eran 312 | 313 | - VeriGauge: https://github.com/AI-secure/VeriGauge 314 | 315 | - ExactLine: https://github.com/95616ARG/SyReNN 316 | 317 | - LowRankGAN: https://github.com/zhujiapeng/LowRankGAN 318 | 319 | ## Citation 320 | 321 | If GCert is helpful for your research, please consider citing our work as follows: 322 | 323 | ```bib 324 | @inproceedings{yuan2023precise, 325 | title={Precise and Generalized Robustness Certification for Neural Networks}, 326 | author={Yuan, Yuanyuan and Wang, Shuai and Su, Zhendong}, 327 | booktitle={32nd USENIX Security Symposium (USENIX Security 23)}, 328 | year={2023} 329 | } 330 | ``` 331 | 332 | If you have any questions, feel free to contact Yuanyuan (yyuanaq@cse.ust.hk). -------------------------------------------------------------------------------- /frameworks/GenProver/components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.distributions import multinomial, categorical 5 | import torch.optim as optim 6 | 7 | import math 8 | 9 | try: 10 | from . import helpers as h 11 | from . import ai 12 | from . import scheduling as S 13 | except: 14 | import helpers as h 15 | import ai 16 | import scheduling as S 17 | 18 | import math 19 | import abc 20 | 21 | from torch.nn.modules.conv import _ConvNd 22 | from enum import Enum 23 | 24 | 25 | class InferModule(nn.Module): 26 | def __init__(self, *args, normal = False, ibp_init = False, **kwargs): 27 | self.args = args 28 | self.kwargs = kwargs 29 | self.infered = False 30 | self.normal = normal 31 | self.ibp_init = ibp_init 32 | 33 | def infer(self, in_shape, global_args = None): 34 | """ this is really actually stateful.
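        The first call is stateful: it runs nn.Module.__init__, records the input
        shape, calls self.init(...) to compute the output shape, and then resets the
        parameters; any later call returns the already-built module unchanged.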
""" 35 | 36 | if self.infered: 37 | return self 38 | self.infered = True 39 | 40 | super(InferModule, self).__init__() 41 | self.inShape = list(in_shape) 42 | self.outShape = list(self.init(list(in_shape), *self.args, global_args = global_args, **self.kwargs)) 43 | if self.outShape is None: 44 | raise "init should set the out_shape" 45 | 46 | self.reset_parameters() 47 | return self 48 | 49 | def reset_parameters(self): 50 | if not hasattr(self,'weight') or self.weight is None: 51 | return 52 | n = h.product(self.weight.size()) / self.outShape[0] 53 | stdv = 1 / math.sqrt(n) 54 | 55 | if self.ibp_init: 56 | torch.nn.init.orthogonal_(self.weight.data) 57 | elif self.normal: 58 | self.weight.data.normal_(0, stdv) 59 | self.weight.data.clamp_(-1, 1) 60 | else: 61 | self.weight.data.uniform_(-stdv, stdv) 62 | 63 | if self.bias is not None: 64 | if self.ibp_init: 65 | self.bias.data.zero_() 66 | elif self.normal: 67 | self.bias.data.normal_(0, stdv) 68 | self.bias.data.clamp_(-1, 1) 69 | else: 70 | self.bias.data.uniform_(-stdv, stdv) 71 | 72 | def clip_norm(self): 73 | if not hasattr(self, "weight"): 74 | return 75 | if not hasattr(self,"weight_g"): 76 | if torch.__version__[0] == "0": 77 | nn.utils.weight_norm(self, dim=None) 78 | else: 79 | nn.utils.weight_norm(self) 80 | 81 | self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm) 82 | 83 | if torch.__version__[0] != "0": 84 | self.weight_v.data.clamp_(-h.max_c_for_norm * 10000,h.max_c_for_norm * 10000) 85 | if hasattr(self, "bias"): 86 | self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000) 87 | 88 | def regularize(self, p): 89 | reg = 0 90 | if torch.__version__[0] == "0": 91 | for param in self.parameters(): 92 | reg += param.norm(p) 93 | else: 94 | if hasattr(self, "weight_g"): 95 | reg += self.weight_g.norm().sum() 96 | reg += self.weight_v.norm().sum() 97 | elif hasattr(self, "weight"): 98 | reg += self.weight.norm().sum() 99 | 100 | if hasattr(self, "bias"): 101 | reg += self.bias.view(-1).norm(p=p).sum() 102 | 103 | return reg 104 | 105 | def remove_norm(self): 106 | if hasattr(self,"weight_g"): 107 | torch.nn.utils.remove_weight_norm(self) 108 | 109 | def showNet(self, t = ""): 110 | print(t + self.__class__.__name__) 111 | 112 | def printNet(self, f): 113 | print(self.__class__.__name__, file=f) 114 | 115 | @abc.abstractmethod 116 | def forward(self, *args, **kargs): 117 | pass 118 | 119 | def __call__(self, *args, onyx=False, **kargs): 120 | if onyx: 121 | return self.forward(*args, onyx=onyx, **kargs) 122 | else: 123 | return super(InferModule, self).__call__(*args, **kargs) 124 | 125 | @abc.abstractmethod 126 | def neuronCount(self): 127 | pass 128 | 129 | def depth(self): 130 | return 0 131 | 132 | 133 | 134 | 135 | class Linear(InferModule): 136 | def init(self, in_shape, out_shape, **kargs): 137 | self.in_neurons = h.product(in_shape) 138 | if isinstance(out_shape, int): 139 | out_shape = [out_shape] 140 | self.out_neurons = h.product(out_shape) 141 | 142 | self.weight = torch.nn.Parameter(torch.Tensor(self.out_neurons, self.in_neurons)) 143 | self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons)) 144 | 145 | return out_shape 146 | 147 | def forward(self, x, **kargs): 148 | s = x.size() 149 | x = x.view(s[0], h.product(s[1:])) 150 | return (x.matmul(self.weight.T) + self.bias).view(s[0], *self.outShape) 151 | 152 | def neuronCount(self): 153 | return 0 154 | 155 | def showNet(self, t = ""): 156 | print(t + "Linear out=" + str(self.out_neurons)) 157 | 158 | def printNet(self, f): 159 | 
print("Linear(" + str(self.out_neurons) + ")" ) 160 | 161 | print(h.printListsNumpy(list(self.weight.transpose(1,0).data)), file= f) 162 | print(h.printNumpy(self.bias), file= f) 163 | 164 | class Activation(InferModule): 165 | def init(self, in_shape, global_args = None, activation = "ReLU", **kargs): 166 | self.activation = [ "ReLU","Sigmoid", "Tanh", "Softplus", "ELU", "SELU"].index(activation) 167 | self.activation_name = activation 168 | return in_shape 169 | 170 | def regularize(self, p): 171 | return 0 172 | 173 | def forward(self, x, **kargs): 174 | return [lambda x:x.relu(), lambda x:x.sigmoid(), lambda x:x.tanh(), lambda x:x.softplus(), lambda x:x.elu(), lambda x:x.selu()][self.activation](x) 175 | 176 | def neuronCount(self): 177 | return h.product(self.outShape) 178 | 179 | def depth(self): 180 | return 1 181 | 182 | def showNet(self, t = ""): 183 | print(t + self.activation_name) 184 | 185 | def printNet(self, f): 186 | pass 187 | 188 | class ReLU(Activation): 189 | pass 190 | 191 | def activation(*args, batch_norm = False, **kargs): 192 | a = Activation(*args, **kargs) 193 | return Seq(BatchNorm(), a) if batch_norm else a 194 | 195 | class Identity(InferModule): # for feigning model equivelence when removing an op 196 | def init(self, in_shape, global_args = None, **kargs): 197 | return in_shape 198 | 199 | def forward(self, x, **kargs): 200 | return x 201 | 202 | def neuronCount(self): 203 | return 0 204 | 205 | def printNet(self, f): 206 | pass 207 | 208 | def regularize(self, p): 209 | return 0 210 | 211 | def showNet(self, *args, **kargs): 212 | pass 213 | 214 | class Negate(Identity): # for feigning model equivelence when removing an op 215 | def init(self, in_shape, global_args = None, **kargs): 216 | return in_shape 217 | 218 | def forward(self, x, **kargs): 219 | return x * (-1) 220 | 221 | def showNet(self, *args, **kargs): 222 | print("Negate") 223 | 224 | class AddOne(Identity): # for feigning model equivelence when removing an op 225 | def init(self, in_shape, global_args = None, **kargs): 226 | return in_shape 227 | 228 | def forward(self, x, **kargs): 229 | return x + 1 230 | 231 | def showNet(self, *args, **kargs): 232 | print("AddOne") 233 | 234 | class Abs(Identity): # for feigning model equivelence when removing an op 235 | def init(self, in_shape, global_args = None, **kargs): 236 | return in_shape 237 | 238 | def forward(self, x, ignore_abs = False, **kargs): 239 | if ignore_abs: 240 | return x 241 | return x.abs() 242 | 243 | def showNet(self, *args, **kargs): 244 | print("Abs") 245 | 246 | class Sq(Identity): # for feigning model equivelence when removing an op 247 | def init(self, in_shape, global_args = None, **kargs): 248 | return in_shape 249 | 250 | def forward(self, x, **kargs): 251 | return x.pow(2) 252 | 253 | def showNet(self, *args, **kargs): 254 | print("Sq") 255 | 256 | 257 | class Dropout(InferModule): 258 | def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs): 259 | self.p = S.Const.initConst(p) 260 | self.use_2d = use_2d 261 | self.alpha_dropout = alpha_dropout 262 | return in_shape 263 | 264 | def forward(self, x, time = 0, **kargs): 265 | if self.training: 266 | with torch.no_grad(): 267 | p = self.p.getVal(time = time) 268 | mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()),p=p, training=True) 269 | if self.alpha_dropout: 270 | with torch.no_grad(): 271 | keep_prob = 1 - p 272 | alpha = -1.7580993408473766 273 | a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5) 274 | 
b = -a * alpha * (1 - keep_prob) 275 | mask = mask * a 276 | return x * mask + b 277 | else: 278 | return x * mask 279 | else: 280 | return x 281 | 282 | def neuronCount(self): 283 | return 0 284 | 285 | def showNet(self, t = ""): 286 | print(t + "Dropout p=" + str(self.p)) 287 | 288 | def printNet(self, f): 289 | print("Dropout(" + str(self.p) + ")" ) 290 | 291 | class PrintActivation(Identity): 292 | def init(self, in_shape, global_args = None, activation = "ReLU", **kargs): 293 | self.activation = activation 294 | return in_shape 295 | 296 | def printNet(self, f): 297 | print(self.activation, file = f) 298 | 299 | class PrintReLU(PrintActivation): 300 | pass 301 | 302 | class Conv2D(InferModule): 303 | 304 | def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs): 305 | self.prev = in_shape 306 | self.in_channels = in_shape[0] 307 | self.out_channels = out_channels 308 | self.kernel_size = kernel_size 309 | self.stride = stride 310 | self.padding = padding 311 | self.activation = activation 312 | self.use_softplus = h.default(global_args, 'use_softplus', False) 313 | 314 | weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size) 315 | self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape)) 316 | if bias: 317 | self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0])) 318 | else: 319 | self.bias = None # h.zeros(weights_shape[0]) 320 | 321 | outshape = h.getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding) 322 | return outshape 323 | 324 | def forward(self, input, **kargs): 325 | return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding ) 326 | 327 | def printNet(self, f): # only complete if we've forwardt stride=1 328 | print("Conv2D", file = f) 329 | sz = list(self.prev) 330 | print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ), file = f) 331 | print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f) 332 | print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f) 333 | 334 | def showNet(self, t = ""): 335 | sz = list(self.prev) 336 | print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding )) 337 | 338 | def neuronCount(self): 339 | return 0 340 | 341 | 342 | class ConvTranspose2D(InferModule): 343 | 344 | def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs): 345 | self.prev = in_shape 346 | self.in_channels = in_shape[0] 347 | self.out_channels = out_channels 348 | self.kernel_size = kernel_size 349 | self.stride = stride 350 | self.padding = padding 351 | self.out_padding = out_padding 352 | self.activation = activation 353 | self.use_softplus = h.default(global_args, 'use_softplus', False) 354 | 355 | weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size) 356 | self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape)) 357 | if bias: 358 | self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[1])) 359 | else: 360 | self.bias = None # h.zeros(weights_shape[0]) 361 | 362 | 
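        # Output spatial size of a transposed convolution (dilation 1):
        #   out = (in - 1) * stride - 2 * padding + kernel_size + out_padding
        # h.getShapeConvTranspose computes the module's output shape accordingly.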
outshape = h.getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding) 363 | return outshape 364 | 365 | def forward(self, input, **kargs): 366 | return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding) 367 | 368 | def printNet(self, f): # only complete if we've forwardt stride=1 369 | print("ConvTranspose2D", file = f) 370 | print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, list(self.kernel_size), list(self.prev) ), file = f) 371 | print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f) 372 | print(h.printNumpy(self.bias), file= f) 373 | 374 | def neuronCount(self): 375 | return 0 376 | 377 | 378 | 379 | class MaxPool2D(InferModule): 380 | def init(self, in_shape, kernel_size, stride = None, **kargs): 381 | self.prev = in_shape 382 | self.kernel_size = kernel_size 383 | self.stride = kernel_size if stride is None else stride 384 | return h.getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), stride) 385 | 386 | def forward(self, x, **kargs): 387 | return x.max_pool2d(self.kernel_size, self.stride) 388 | 389 | def printNet(self, f): 390 | print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format(list(self.stride), list(self.shape[2:]), list(self.prev[1:]+self.prev[:1]) ), file = f) 391 | 392 | def neuronCount(self): 393 | return h.product(self.outShape) 394 | 395 | class AvgPool2D(InferModule): 396 | def init(self, in_shape, kernel_size, stride = None, **kargs): 397 | self.prev = in_shape 398 | self.kernel_size = kernel_size 399 | self.stride = kernel_size if stride is None else stride 400 | out_size = h.getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1) 401 | return out_size 402 | 403 | def forward(self, x, **kargs): 404 | if h.product(x.size()[2:]) == 1: 405 | return x 406 | return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1) 407 | 408 | def printNet(self, f): 409 | print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format(list(self.stride), list(self.shape[2:]), list(self.prev[1:]+self.prev[:1]) ), file = f) 410 | 411 | def neuronCount(self): 412 | return h.product(self.outShape) 413 | 414 | class AdaptiveAvgPool2D(InferModule): 415 | def init(self, in_shape, out_shape, **kargs): 416 | self.prev = in_shape 417 | self.out_shape = list(out_shape) 418 | return [in_shape[0]] + self.out_shape 419 | 420 | def forward(self, x, **kargs): 421 | return x.adaptive_avg_pool2d(self.out_shape) 422 | 423 | def printNet(self, f): 424 | print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(list(self.out_shape), list(self.prev[1:]+self.prev[:1]) ), file = f) 425 | 426 | def neuronCount(self): 427 | return h.product(self.outShape) 428 | 429 | class Normalize(InferModule): 430 | def init(self, in_shape, mean, std, **kargs): 431 | self.mean_v = mean 432 | self.std_v = std 433 | self.mean = h.dten(mean) 434 | self.std = 1 / h.dten(std) 435 | return in_shape 436 | 437 | def forward(self, x, **kargs): 438 | mean_ex = self.mean.view(self.mean.shape[0],1,1).expand(*x.size()[1:]) 439 | std_ex = self.std.view(self.std.shape[0],1,1).expand(*x.size()[1:]) 440 | return (x - mean_ex) * std_ex 441 | 442 | def neuronCount(self): 443 | return 0 444 | 445 | def printNet(self, f): 446 | print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file = f) 447 | 448 | def showNet(self, t = ""): 
449 | print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v)) 450 | 451 | class Flatten(InferModule): 452 | def init(self, in_shape, **kargs): 453 | return h.product(in_shape) 454 | 455 | def forward(self, x, **kargs): 456 | s = x.size() 457 | return x.view(s[0], h.product(s[1:])) 458 | 459 | def neuronCount(self): 460 | return 0 461 | 462 | # class BatchNorm(InferModule): 463 | # def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs): 464 | # self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape)) 465 | # self.beta = torch.nn.Parameter(torch.Tensor(*in_shape)) 466 | # self.eps = eps 467 | # self.track_running_stats = track_running_stats 468 | # self.momentum = momentum 469 | 470 | # self.running_mean = None 471 | # self.running_var = None 472 | 473 | # self.num_batches_tracked = 0 474 | # return in_shape 475 | 476 | # def reset_parameters(self): 477 | # self.gamma.data.fill_(1) 478 | # self.beta.data.zero_() 479 | 480 | # def forward(self, x, **kargs): 481 | # exponential_average_factor = 0.0 482 | # if self.training and self.track_running_stats: 483 | # # TODO: if statement only here to tell the jit to skip emitting this when it is None 484 | # if self.num_batches_tracked is not None: 485 | # self.num_batches_tracked += 1 486 | # if self.momentum is None: # use cumulative moving average 487 | # exponential_average_factor = 1.0 / float(self.num_batches_tracked) 488 | # else: # use exponential moving average 489 | # exponential_average_factor = self.momentum 490 | 491 | # new_mean = x.vanillaTensorPart().detach().mean(dim=0) 492 | # new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False) 493 | # if torch.isnan(new_var * 0).any(): 494 | # return x 495 | # if self.training: 496 | # self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean if self.running_mean is not None else new_mean 497 | # if self.running_var is None: 498 | # self.running_var = new_var 499 | # else: 500 | # q = (1 - exponential_average_factor) * self.running_var 501 | # r = exponential_average_factor * new_var 502 | # self.running_var = q + r 503 | 504 | # if self.track_running_stats and self.running_mean is not None and self.running_var is not None: 505 | # new_mean = self.running_mean 506 | # new_var = self.running_var 507 | 508 | # diver = 1 / (new_var + self.eps).sqrt() 509 | 510 | # if torch.isnan(diver).any(): 511 | # print("Really shouldn't happen ever") 512 | # return x 513 | # else: 514 | # out = (x - new_mean) * diver * self.gamma + self.beta 515 | # return out 516 | 517 | # def neuronCount(self): 518 | # return 0 519 | 520 | class BatchNorm(InferModule): 521 | def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs): 522 | # self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape)) 523 | # self.beta = torch.nn.Parameter(torch.Tensor(*in_shape)) 524 | self.weight = torch.nn.Parameter(torch.Tensor(in_shape[0])) 525 | self.bias = torch.nn.Parameter(torch.Tensor(in_shape[0])) 526 | self.eps = eps 527 | self.track_running_stats = track_running_stats 528 | self.momentum = momentum 529 | 530 | self.running_mean = torch.nn.Parameter(torch.Tensor(in_shape[0])) 531 | self.running_var = torch.nn.Parameter(torch.Tensor(in_shape[0])) 532 | self.num_batches_tracked = torch.nn.Parameter(torch.zeros([])) 533 | # self.running_mean = torch.zeros(in_shape[0]) 534 | # self.running_var = torch.zeros(in_shape[0]) 535 | 536 | # self.num_batches_tracked = 0 537 | return in_shape 538 | 539 | 
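    # The running statistics above are registered as nn.Parameter, so they are saved and
    # restored together with the affine weight/bias through the model's state_dict;
    # reset_parameters() below restores the standard BatchNorm defaults
    # (weight=1, bias=0, zero running mean, unit running variance).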
def reset_parameters(self): 540 | self.weight.data.fill_(1) 541 | self.bias.data.zero_() 542 | self.running_mean.data.fill_(0) 543 | self.running_var.data.fill_(1) 544 | self.num_batches_tracked.data.fill_(0) 545 | 546 | def forward(self, x, **kargs): 547 | exponential_average_factor = 0.0 548 | if self.training and self.track_running_stats: 549 | print('is training') 550 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 551 | if self.num_batches_tracked is not None: 552 | self.num_batches_tracked += 1 553 | if self.momentum is None: # use cumulative moving average 554 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 555 | else: # use exponential moving average 556 | exponential_average_factor = self.momentum 557 | 558 | new_mean = x.vanillaTensorPart().detach().mean(dim=[0, 2, 3]) 559 | new_var = x.vanillaTensorPart().detach().var(dim=[0, 2, 3], unbiased=False) 560 | if torch.isnan(new_var * 0).any(): 561 | return x 562 | if self.training: 563 | if self.running_mean is not None: 564 | self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean 565 | else: 566 | self.running_mean = new_mean 567 | 568 | if self.running_var is not None: 569 | q = (1 - exponential_average_factor) * self.running_var 570 | r = exponential_average_factor * new_var 571 | self.running_var = q + r 572 | else: 573 | self.running_var = new_var 574 | 575 | if self.track_running_stats and self.running_mean is not None and self.running_var is not None: 576 | new_mean = self.running_mean 577 | new_var = self.running_var 578 | 579 | diver = 1 / (new_var + self.eps).sqrt() 580 | 581 | if torch.isnan(diver).any(): 582 | print("Really shouldn't happen ever") 583 | return x 584 | else: 585 | out = (x - new_mean[None, :, None, None]) * diver[None, :, None, None] \ 586 | * self.weight[None, :, None, None] + self.bias[None, :, None, None] 587 | return out 588 | 589 | def neuronCount(self): 590 | return 0 591 | 592 | class Unflatten2d(InferModule): 593 | def init(self, in_shape, w, **kargs): 594 | self.w = w 595 | self.outChan = int(h.product(in_shape) / (w * w)) 596 | 597 | return (self.outChan, self.w, self.w) 598 | 599 | def forward(self, x, **kargs): 600 | s = x.size() 601 | return x.view(s[0], self.outChan, self.w, self.w) 602 | 603 | def neuronCount(self): 604 | return 0 605 | 606 | 607 | class View(InferModule): 608 | def init(self, in_shape, out_shape, **kargs): 609 | assert(h.product(in_shape) == h.product(out_shape)) 610 | return out_shape 611 | 612 | def forward(self, x, **kargs): 613 | s = x.size() 614 | return x.view(s[0], *self.outShape) 615 | 616 | def neuronCount(self): 617 | return 0 618 | 619 | class CatTwo(InferModule): 620 | def init(self, in_shape, **kargs): 621 | [c, w, w] = in_shape 622 | return (c * 2, w, w) 623 | 624 | def forward(self, tp, **kargs): 625 | (x, y) = tp 626 | return x.cat(y, dim=1) 627 | 628 | def neuronCount(self): 629 | return 0 630 | 631 | class Seq(InferModule): 632 | def init(self, in_shape, *layers, **kargs): 633 | self.layers = layers 634 | self.net = nn.Sequential(*layers) 635 | self.prev = in_shape 636 | for s in layers: 637 | in_shape = s.infer(in_shape, **kargs).outShape 638 | return in_shape 639 | 640 | def forward(self, x, **kargs): 641 | 642 | for l in self.layers: 643 | x = l(x, **kargs) 644 | return x 645 | 646 | def clip_norm(self): 647 | for l in self.layers: 648 | l.clip_norm() 649 | 650 | def regularize(self, p): 651 | return sum(n.regularize(p) for n in 
self.layers) 652 | 653 | def remove_norm(self): 654 | for l in self.layers: 655 | l.remove_norm() 656 | 657 | def printNet(self, f): 658 | for l in self.layers: 659 | l.printNet(f) 660 | 661 | def showNet(self, *args, **kargs): 662 | for l in self.layers: 663 | l.showNet(*args, **kargs) 664 | 665 | def neuronCount(self): 666 | return sum([l.neuronCount() for l in self.layers ]) 667 | 668 | def depth(self): 669 | return sum([l.depth() for l in self.layers ]) 670 | 671 | def FFNN(layers, last_lin = False, last_zono = False, **kargs): 672 | starts = layers 673 | ends = [] 674 | if last_lin: 675 | ends = ([CorrelateAll(only_train=False)] if last_zono else []) + [PrintActivation(activation = "Affine"), Linear(layers[-1],**kargs)] 676 | starts = layers[:-1] 677 | 678 | return Seq(*([ Seq(PrintActivation(**kargs), Linear(s, **kargs), activation(**kargs)) for s in starts] + ends)) 679 | 680 | def Conv(*args, **kargs): 681 | return Seq(Conv2D(*args, **kargs), activation(**kargs)) 682 | 683 | def ConvTranspose(*args, **kargs): 684 | return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs)) 685 | 686 | MP = MaxPool2D 687 | 688 | def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs): 689 | def transfer(tp): 690 | if isinstance(tp, InferModule): 691 | return tp 692 | if isinstance(tp[0], str): 693 | return MaxPool2D(*tp[1:]) 694 | return Conv(out_channels = tp[0], kernel_size = tp[1], stride = tp[-1] if len(tp) == 4 else 1, bias=bias, normal=normal, **kargs) 695 | conv = [transfer(s) for s in conv_layers] 696 | return Seq(*conv, FFNN(ly, **kargs, bias=bias)) if len(ly) > 0 else Seq(*conv) 697 | 698 | def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs): 699 | def transfer(tp, lin = False): 700 | return (ConvTranspose2D if lin else ConvTranspose)(out_channels = tp[0], kernel_size = tp[1], stride = tp[2], padding = tp[3], out_padding = tp[4], bias=False, normal=normal, **kargs) 701 | 702 | return Seq(FFNN(ly, bias=bias, **kargs), Unflatten2d(w), *[transfer(s) for s in conv_layers[:-1]], transfer(conv_layers[-1], lin=True) ) 703 | 704 | class FromByteImg(InferModule): 705 | def init(self, in_shape, **kargs): 706 | return in_shape 707 | 708 | def forward(self, x, **kargs): 709 | return x.to_dtype()/ 256. 
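    # FromByteImg maps raw byte-valued pixel data into [0, 1) (division by 256) so that
    # downstream layers receive normalized floating-point inputs.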
710 | 711 | def neuronCount(self): 712 | return 0 713 | 714 | class Skip(InferModule): 715 | def init(self, in_shape, net1, net2, **kargs): 716 | self.net1 = net1.infer(in_shape, **kargs) 717 | self.net2 = net2.infer(in_shape, **kargs) 718 | assert(net1.outShape[1:] == net2.outShape[1:]) 719 | return [ net1.outShape[0] + net2.outShape[0] ] + net1.outShape[1:] 720 | 721 | def forward(self, x, **kargs): 722 | r1 = self.net1(x, **kargs) 723 | r2 = self.net2(x, **kargs) 724 | return r1.cat(r2, dim=1) 725 | 726 | def regularize(self, p): 727 | return self.net1.regularize(p) + self.net2.regularize(p) 728 | 729 | def clip_norm(self): 730 | self.net1.clip_norm() 731 | self.net2.clip_norm() 732 | 733 | def remove_norm(self): 734 | self.net1.remove_norm() 735 | self.net2.remove_norm() 736 | 737 | def neuronCount(self): 738 | return self.net1.neuronCount() + self.net2.neuronCount() 739 | 740 | def printNet(self, f): 741 | print("SkipNet1", file=f) 742 | self.net1.printNet(f) 743 | print("SkipNet2", file=f) 744 | self.net2.printNet(f) 745 | print("SkipCat dim=1", file=f) 746 | 747 | def showNet(self, t = ""): 748 | print(t+"SkipNet1") 749 | self.net1.showNet(" "+t) 750 | print(t+"SkipNet2") 751 | self.net2.showNet(" "+t) 752 | print(t+"SkipCat dim=1") 753 | 754 | class ParSum(InferModule): 755 | def init(self, in_shape, net1, net2, **kargs): 756 | self.net1 = net1.infer(in_shape, **kargs) 757 | self.net2 = net2.infer(in_shape, **kargs) 758 | assert(net1.outShape == net2.outShape) 759 | return net1.outShape 760 | 761 | 762 | 763 | def forward(self, x,just_left = False, **kargs): 764 | r1 = self.net1(x, **kargs) 765 | if just_left: 766 | return r1 767 | r2 = self.net2(x, **kargs) 768 | return x.addPar(r1,r2) 769 | 770 | def clip_norm(self): 771 | self.net1.clip_norm() 772 | self.net2.clip_norm() 773 | 774 | def remove_norm(self): 775 | self.net1.remove_norm() 776 | self.net2.remove_norm() 777 | 778 | def neuronCount(self): 779 | return self.net1.neuronCount() + self.net2.neuronCount() 780 | 781 | def depth(self): 782 | return max(self.net1.depth(), self.net2.depth()) 783 | 784 | def printNet(self, f): 785 | print("ParNet1", file=f) 786 | self.net1.printNet(f) 787 | print("ParNet2", file=f) 788 | self.net2.printNet(f) 789 | print("ParCat dim=1", file=f) 790 | 791 | def showNet(self, t = ""): 792 | print(t + "ParNet1") 793 | self.net1.showNet(" "+t) 794 | print(t + "ParNet2") 795 | self.net2.showNet(" "+t) 796 | print(t + "ParSum") 797 | 798 | class ToZono(Identity): 799 | def init(self, in_shape, customRelu = None, only_train = False, **kargs): 800 | self.customRelu = customRelu 801 | self.only_train = only_train 802 | return in_shape 803 | 804 | def forward(self, x, **kargs): 805 | return self.abstract_forward(x, **kargs) if self.training or not self.only_train else x 806 | 807 | def abstract_forward(self, x, **kargs): 808 | return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu) 809 | 810 | def showNet(self, t = ""): 811 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)) 812 | 813 | class CorrelateAll(ToZono): 814 | def abstract_forward(self, x, **kargs): 815 | return x.abstractApplyLeaf('hybrid_to_zono',correlate=True, customRelu = self.customRelu) 816 | 817 | class ToHZono(ToZono): 818 | def abstract_forward(self, x, **kargs): 819 | return x.abstractApplyLeaf('zono_to_hybrid',customRelu = self.customRelu) 820 | 821 | class Concretize(ToZono): 822 | def init(self, in_shape, only_train = True, **kargs): 823 | self.only_train = only_train 824 | return 
in_shape 825 | 826 | def abstract_forward(self, x, **kargs): 827 | return x.abstractApplyLeaf('concretize') 828 | 829 | class NormalPermute(InferModule): 830 | def init(self, in_shape, prior_scale = None, **kargs): 831 | self.prior_scale = prior_scale 832 | return in_shape 833 | 834 | def forward(self, x, should_permute = True, **kargs): 835 | if should_permute: 836 | return x + torch.randn(x.shape).cudify() * self.prior_scale 837 | else: 838 | return x 839 | 840 | 841 | class CombineNormal(InferModule): 842 | def init(self, in_shape, latent_dims, prior_scale = None, **kargs): 843 | self.prior_scale = prior_scale 844 | return latent_dims 845 | 846 | def forward(self, x, is_training = True, **kargs): 847 | kl = 0 848 | while hasattr(x, "tag"): 849 | kl += x.tag 850 | x = x.a 851 | 852 | mu, logvar = x.view(x.shape[0], 2, *self.outShape).split(1, dim=1) 853 | 854 | if is_training: 855 | std = (logvar * 0.5).exp() 856 | eps = torch.randn(mu.shape).to(h.device) 857 | out = std * eps + mu 858 | else: 859 | out = mu 860 | 861 | if self.prior_scale is None: 862 | kl += 0.5 * (mu.pow(2) + logvar.exp() - logvar).view(x.shape[0], -1).sum(dim=1) 863 | else: 864 | prior = torch.distributions.normal.Normal(mu * 0 + 0.5, self.prior_scale + mu * 0) 865 | post = torch.distributions.normal.Normal(mu, logvar.exp().pow(0.5)) 866 | kl -= torch.distributions.kl.kl_divergence(post, prior).view(x.shape[0], -1).sum(dim=1) 867 | 868 | return ai.TaggedDomain(out.view(x.shape[0], *self.outShape), tag = kl) 869 | 870 | # stochastic correlation 871 | class CorrRand(Concretize): 872 | def init(self, in_shape, num_correlate, only_train = True, **kargs): 873 | self.only_train = only_train 874 | self.num_correlate = num_correlate 875 | return in_shape 876 | 877 | def abstract_forward(self, x): 878 | return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate) 879 | 880 | def showNet(self, t = ""): 881 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate)) 882 | 883 | class CorrMaxK(CorrRand): 884 | def abstract_forward(self, x): 885 | return x.abstractApplyLeaf("correlateMaxK", self.num_correlate) 886 | 887 | 888 | class CorrMaxPool2D(Concretize): 889 | def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs): 890 | self.only_train = only_train 891 | self.kernel_size = kernel_size 892 | self.max_type = max_type 893 | return in_shape 894 | 895 | def abstract_forward(self, x): 896 | return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type) 897 | 898 | def showNet(self, t = ""): 899 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type)) 900 | 901 | class CorrMaxPool3D(Concretize): 902 | def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs): 903 | self.only_train = only_train 904 | self.kernel_size = kernel_size 905 | self.max_type = max_type 906 | return in_shape 907 | 908 | def abstract_forward(self, x): 909 | return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d) 910 | 911 | def showNet(self, t = ""): 912 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +self.max_type) 913 | 914 | class CorrFix(Concretize): 915 | 
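    # CorrFix selects (up to) k evenly spaced coordinates of the flattened input and hands
    # them to the abstract element's "correlate" transformer; the triple-quoted block inside
    # abstract_forward sketches a finer per-dimension index selection kept for future use.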
def init(self,in_shape, k, only_train = True, **kargs): 916 | self.k = k 917 | self.only_train = only_train 918 | return in_shape 919 | 920 | def abstract_forward(self, x): 921 | sz = x.size() 922 | """ 923 | # for more control in the future 924 | indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) ) 925 | indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) ) 926 | indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) ) 927 | 928 | indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3) 929 | """ 930 | szm = h.product(sz[1:]) 931 | indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k)) 932 | indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0]) 933 | 934 | 935 | return x.abstractApplyLeaf("correlate", indxs) 936 | 937 | def showNet(self, t = ""): 938 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k)) 939 | 940 | 941 | class DecorrRand(Concretize): 942 | def init(self, in_shape, num_decorrelate, only_train = True, **kargs): 943 | self.only_train = only_train 944 | self.num_decorrelate = num_decorrelate 945 | return in_shape 946 | 947 | def abstract_forward(self, x): 948 | return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate) 949 | 950 | class DecorrMin(Concretize): 951 | def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs): 952 | self.only_train = only_train 953 | self.num_decorrelate = num_decorrelate 954 | self.num_to_keep = num_to_keep 955 | return in_shape 956 | 957 | def abstract_forward(self, x): 958 | return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep) 959 | 960 | 961 | def showNet(self, t = ""): 962 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) ) 963 | 964 | class DeepLoss(ToZono): 965 | def init(self, in_shape, bw = 0.01, act = F.relu, **kargs): # weight must be between 0 and 1 966 | self.only_train = True 967 | self.bw = S.Const.initConst(bw) 968 | self.act = act 969 | return in_shape 970 | 971 | def abstract_forward(self, x, **kargs): 972 | if x.isPoint(): 973 | return x 974 | return ai.TaggedDomain(x, self.MLoss(self, x)) 975 | 976 | class MLoss(): 977 | def __init__(self, obj, x): 978 | self.obj = obj 979 | self.x = x 980 | 981 | def loss(self, a, *args, lr = 1, time = 0, **kargs): 982 | bw = self.obj.bw.getVal(time = time) 983 | pre_loss = a.loss(*args, time = time, **kargs, lr = lr * (1 - bw)) 984 | if bw <= 0.0: 985 | return pre_loss 986 | return (1 - bw) * pre_loss + bw * self.x.deep_loss(act = self.obj.act) 987 | 988 | def showNet(self, t = ""): 989 | print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) ) 990 | 991 | class IdentLoss(DeepLoss): 992 | def abstract_forward(self, x, **kargs): 993 | return x 994 | 995 | def SkipNet(net1, net2, ffnn, **kargs): 996 | return Seq(Skip(net1,net2), FFNN(ffnn, **kargs)) 997 | 998 | def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs): 999 | if not downsample: 1000 | k_first = 3 1001 | skip_stride = 1 1002 | k_skip = 1 1003 | else: 1004 | k_first = 4 1005 | skip_stride = 2 1006 | k_skip = 2 1007 | 1008 | # conv2d280(input) 1009 | blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs) 1010 | 
1011 | # conv2d282(relu(conv2d278(input))) 1012 | blockB = Seq( Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs) 1013 | , Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs)) 1014 | return Seq(ParSum(blockA, blockB), activation(**kargs)) 1015 | 1016 | 1017 | 1018 | def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs): 1019 | block = Seq( Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs) 1020 | , Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs)) 1021 | 1022 | if stride != 1 or in_planes != planes: 1023 | block = ParSum(block, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs)) 1024 | elif not skip_net: 1025 | block = ParSum(block, Identity()) 1026 | return Seq(block, activation(**kargs)) 1027 | 1028 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 1029 | def ResNet(blocksList, extra = [], bias = False, **kargs): 1030 | 1031 | layers = [] 1032 | in_planes = 64 1033 | planes = 64 1034 | stride = 0 1035 | for num_blocks in blocksList: 1036 | if stride < 2: 1037 | stride += 1 1038 | 1039 | strides = [stride] + [1]*(num_blocks-1) 1040 | for stride in strides: 1041 | layers.append(BasicBlock(in_planes, planes, stride, bias = bias, **kargs)) 1042 | in_planes = planes 1043 | planes *= 2 1044 | 1045 | print("RESlayers: ", len(layers)) 1046 | for e,l in extra: 1047 | layers[l] = Seq(layers[l], e) 1048 | 1049 | return Seq(Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True), 1050 | *layers) 1051 | 1052 | 1053 | 1054 | def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True): 1055 | 1056 | def Bottleneck(growthRate): 1057 | interChannels = 4*growthRate 1058 | 1059 | n = Seq( ReLU(), 1060 | Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True), 1061 | ReLU(), 1062 | Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True) 1063 | ) 1064 | 1065 | return Skip(Identity(), n) 1066 | 1067 | def SingleLayer(growthRate): 1068 | n = Seq( ReLU(), 1069 | Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True)) 1070 | return Skip(Identity(), n) 1071 | 1072 | def Transition(nOutChannels): 1073 | return Seq( ReLU(), 1074 | Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True), 1075 | AvgPool2D(kernel_size=2)) 1076 | 1077 | def make_dense(growthRate, nDenseBlocks, bottleneck): 1078 | return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate) for i in range(nDenseBlocks)]) 1079 | 1080 | nDenseBlocks = (depth-4) // 3 1081 | if bottleneck: 1082 | nDenseBlocks //= 2 1083 | 1084 | nChannels = 2*growthRate 1085 | conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init = True) 1086 | dense1 = make_dense(growthRate, nDenseBlocks, bottleneck) 1087 | nChannels += nDenseBlocks * growthRate 1088 | nOutChannels = int(math.floor(nChannels*reduction)) 1089 | trans1 = Transition(nOutChannels) 1090 | 1091 | nChannels = nOutChannels 1092 | dense2 = make_dense(growthRate, nDenseBlocks, bottleneck) 1093 | nChannels += nDenseBlocks*growthRate 1094 | nOutChannels = int(math.floor(nChannels*reduction)) 1095 | trans2 = Transition(nOutChannels) 1096 | 1097 | nChannels = nOutChannels 1098 | dense3 = make_dense(growthRate, nDenseBlocks, bottleneck) 1099 | 1100 | return Seq(conv1, dense1, trans1, dense2, trans2, dense3, 1101 | ReLU(), 
1102 | AvgPool2D(kernel_size=8), 1103 | CorrelateAll(only_train=False, ignore_point = True), 1104 | Linear(num_classes, ibp_init = True)) 1105 | 1106 | --------------------------------------------------------------------------------