├── run_experiments.sh ├── utils.py ├── README.md ├── main.py ├── attack.py ├── evolve.py ├── classify.py ├── generator.py ├── facenet.py └── SAC.py /run_experiments.sh: -------------------------------------------------------------------------------- 1 | # Command for the simplified experiment with the reduced number of target classes and maximum episodes. 2 | # python main.py -model_name VGG16 -max_episodes 10000 -max_step 1 -alpha 0 -n_classes 1000 -z_dim 100 -n_target 100 3 | 4 | # If you want to experiment with the settings reported in the paper, please run the commands below. 5 | 6 | # VGG16 7 | python main.py -model_name VGG16 -max_episodes 40000 -max_step 1 -alpha 0 -n_classes 1000 -z_dim 100 -n_target 1000 8 | 9 | # ResNet-152 10 | # python main.py -model_name ResNet-152 -max_episodes 40000 -max_step 1 -alpha 0 -n_classes 1000 -z_dim 100 -n_target 1000 11 | 12 | # Face.evoLVe 13 | # python main.py -model_name Face.evoLVe -max_episodes 40000 -max_step 1 -alpha 0 -n_classes 1000 -z_dim 100 -n_target 1000 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | def load_my_state_dict(self, state_dict): 5 | own_state = self.state_dict() 6 | for name, param in state_dict.items(): 7 | if name not in own_state: 8 | print(name) 9 | continue 10 | own_state[name].copy_(param.data) 11 | 12 | def get_deprocessor(): 13 | proc = [] 14 | proc.append(transforms.Resize((112, 112))) 15 | proc.append(transforms.ToTensor()) 16 | return transforms.Compose(proc) 17 | 18 | def low2high(img): 19 | bs = img.size(0) 20 | proc = get_deprocessor() 21 | img_tensor = img.detach().cpu().float() 22 | img = torch.zeros(bs, 3, 112, 112) 23 | for i in range(bs): 24 | img_i = transforms.ToPILImage()(img_tensor[i, :, :, :]).convert('RGB') 25 | img_i = proc(img_i) 26 | img[i, :, :, :] = img_i[:, :, :] 27 | 28 | img = img.cuda() 29 | return img 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning-Based Black-Box Model Inversion Attacks 2 | 3 | This is a PyTorch implementation of the paper "Reinforcement Learning-Based Black-Box Model Inversion Attacks" accepted by CVPR 2023. 4 | 5 | ## Dependencies 6 | 7 | This code has been tested with Python 3.8.8, PyTorch 1.8.0 and cuda 10.2.89. 8 | 9 | ## Weights 10 | 11 | Model weights for experiments can be downloaded from the link below. 12 | https://drive.google.com/drive/folders/15Xcqoz53TQVUUyZe9HNtchoCriLUeQ-O?usp=sharing 13 | 14 | ## Usage 15 | 16 | Please check the commands included in `run_experiments.sh`. 17 | There are commands for both the simplified experiment and the experiments reported in the paper. 18 | 19 | Please run 20 | 21 | `bash run_experiments.sh` 22 | 23 | to reproduce the results. 24 | 25 | ## Timeline 26 | 27 | [04.28] Mistakes made during code cleanup have been fixed. 28 | 29 | ## Acknowledgements 30 | 31 | This repository contains code snippets and some model weights from repositories mentioned below. 
32 | 33 | https://github.com/MKariya1998/GMI-Attack 34 | 35 | https://github.com/SCccc21/Knowledge-Enriched-DMI 36 | 37 | https://github.com/BY571/Soft-Actor-Critic-and-Extensions 38 | 39 | 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import argparse 8 | from generator import Generator 9 | from classify import * 10 | from utils import * 11 | from SAC import Agent 12 | from attack import inversion 13 | 14 | parser = argparse.ArgumentParser(description="RLB-MI") 15 | parser.add_argument('-model_name', default='VGG16') 16 | parser.add_argument("-max_episodes", type=int, default=40000) 17 | parser.add_argument("-max_step", type=int, default=1) 18 | parser.add_argument("-seed", type=int, default=42) 19 | parser.add_argument("-alpha", type=float, default=0) 20 | parser.add_argument("-n_classes", type=int, default=1000) 21 | parser.add_argument("-z_dim", type=int, default=100) 22 | parser.add_argument("-n_target", type=int, default=100) 23 | args = parser.parse_args() 24 | 25 | if __name__ == "__main__": 26 | model_name = args.model_name 27 | max_episodes = args.max_episodes 28 | max_step = args.max_step 29 | seed = args.seed 30 | alpha = args.alpha 31 | n_classes = args.n_classes 32 | z_dim = args.z_dim 33 | n_target = args.n_target 34 | 35 | print("Target Model : " + model_name) 36 | G = Generator(z_dim) 37 | G = nn.DataParallel(G).cuda() 38 | G = G.cuda() # redundant (DataParallel already moved G), but harmless 39 | ckp_G = torch.load('weights/CelebA.tar')['state_dict'] 40 | load_my_state_dict(G, ckp_G) 41 | G.eval() 42 | 43 | if model_name == "VGG16": 44 | T = VGG16(n_classes) 45 | path_T = './weights/VGG16.tar' 46 | elif model_name == 'ResNet-152': 47 | T = IR152(n_classes) 48 | path_T = './weights/ResNet-152.tar' 49 | elif model_name == "Face.evoLVe": 50 | T = FaceNet64(n_classes) 51 | path_T = './weights/Face.evoLVe.tar' 52 | else: raise ValueError("Unknown model_name: " + model_name) # fail early instead of hitting a NameError on T below 53 | T = torch.nn.DataParallel(T).cuda() 54 | ckp_T = torch.load(path_T) 55 | T.load_state_dict(ckp_T['state_dict'], strict=False) 56 | T.eval() 57 | 58 | E = FaceNet(n_classes) 59 | path_E = './weights/Eval.tar' 60 | E = torch.nn.DataParallel(E).cuda() 61 | ckp_E = torch.load(path_E) 62 | E.load_state_dict(ckp_E['state_dict'], strict=False) 63 | E.eval() 64 | 65 | def seed_everything(seed: int = 42): 66 | random.seed(seed) 67 | np.random.seed(seed) 68 | os.environ["PYTHONHASHSEED"] = str(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed(seed) # type: ignore 71 | torch.backends.cudnn.deterministic = True # type: ignore 72 | torch.backends.cudnn.benchmark = False # type: ignore (benchmark must be off for deterministic runs) 73 | 74 | seed_everything(seed) 75 | 76 | total = 0 77 | cnt = 0 78 | cnt5 = 0 79 | 80 | identities = range(n_classes) 81 | targets = random.sample(identities, n_target) 82 | 83 | for i in targets: 84 | agent = Agent(state_size=z_dim, action_size=z_dim, random_seed=seed, hidden_size=256, action_prior="uniform") 85 | recon_image = inversion(agent, G, T, alpha, z_dim=z_dim, max_episodes=max_episodes, max_step=max_step, label=i, model_name=model_name) 86 | _, output = E(low2high(recon_image)) 87 | eval_prob = F.softmax(output[0], dim=-1) 88 | top_idx = torch.argmax(eval_prob) 89 | _, top5_idx = torch.topk(eval_prob, 5) 90 | 91 | total += 1 92 | if top_idx == i: 93 | cnt += 1 94 | if i in top5_idx: 95 | cnt5 += 1 96 | 97 | acc = cnt / total 98 | acc5 = cnt5 / total 99 | 
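# Running attack accuracy: a reconstruction counts as a success when the
# evaluation network E classifies its 112x112 upsampling as the target
# identity; top-1 and top-5 hit rates are printed after every target.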
print("Classes {}/{}, Accuracy : {:.3f}, Top-5 Accuracy : {:.3f}".format(total, n_target, acc, acc5)) 100 | 101 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.utils import save_image 7 | from generator import Generator 8 | from classify import * 9 | from utils import * 10 | from copy import deepcopy 11 | 12 | def inversion(agent, G, T, alpha, z_dim = 100, max_episodes=40000, max_step=1, label=0, model_name="VGG16"): 13 | print("Target Label : " + str(label)) 14 | best_score = 0 15 | 16 | for i_episode in range(1, max_episodes + 1): 17 | y = torch.tensor([label]).cuda() 18 | 19 | # Initialize the state at the beginning of each episode. 20 | z = torch.randn(1, z_dim).cuda() 21 | state = deepcopy(z.cpu().numpy()) 22 | for t in range(max_step): 23 | 24 | # Update the state and create images from the updated state and action. 25 | action = agent.act(state) 26 | z = alpha * z + (1 - alpha) * action.clone().detach().reshape((1, len(action))).cuda() 27 | next_state = deepcopy(z.cpu().numpy()) 28 | state_image = G(z).detach() 29 | action_image = G(action.clone().detach().reshape((1, len(action))).cuda()).detach() 30 | 31 | # Calculate the reward. 32 | _, state_output = T(state_image) 33 | _, action_output = T(action_image) 34 | score1 = float(torch.mean(torch.diag(torch.index_select(torch.log(F.softmax(state_output, dim=-1)).data, 1, y)))) 35 | score2 = float(torch.mean(torch.diag(torch.index_select(torch.log(F.softmax(action_output, dim=-1)).data, 1, y)))) 36 | score3 = math.log(max(1e-7, float(torch.index_select(F.softmax(state_output, dim=-1).data, 1, y)) - float(torch.max(torch.cat((F.softmax(state_output, dim=-1)[0,:y],F.softmax(state_output, dim=-1)[0,y+1:])), dim=-1)[0]))) 37 | reward = 2 * score1 + 2 * score2 + 8 * score3 38 | 39 | # Update policy. 40 | if t == max_step - 1 : 41 | done = True 42 | else : 43 | done = False 44 | 45 | agent.step(state, action, reward, next_state, done, t) 46 | state = next_state 47 | 48 | # Save the image with the maximum confidence score. 
49 | test_images = [] 50 | test_scores = [] 51 | for i in range(1): 52 | with torch.no_grad(): 53 | z_test = torch.randn(1, z_dim).cuda() 54 | for t in range(max_step): 55 | state_test = z_test.cpu().numpy() 56 | action_test = agent.act(state_test) 57 | z_test = alpha * z_test + (1 - alpha) * action_test.clone().detach().reshape((1, len(action_test))).cuda() 58 | test_image = G(z_test).detach() 59 | test_images.append(test_image.cpu()) 60 | _, test_output = T(test_image) 61 | test_score = float(torch.mean(torch.diag(torch.index_select(F.softmax(test_output, dim=-1).data, 1, y)))) 62 | test_scores.append(test_score) 63 | mean_score = sum(test_scores) / len(test_scores) 64 | if mean_score >= best_score: 65 | best_score = mean_score 66 | best_images = torch.vstack(test_images) 67 | os.makedirs("./result/images/{}".format(model_name), exist_ok=True) 68 | os.makedirs("./result/models/{}".format(model_name), exist_ok=True) 69 | save_image(best_images, "./result/images/{}/{}_{}.png".format(model_name, label, alpha), nrow=10) 70 | torch.save(agent.actor_local.state_dict(), "./result/models/{}/actor_{}_{}.pt".format(model_name, label, alpha)) 71 | if i_episode % 10000 == 0 or i_episode == max_episodes: 72 | print('Episodes {}/{}, Confidence score for the target model : {:.4f}'.format(i_episode, max_episodes,best_score)) 73 | return best_images -------------------------------------------------------------------------------- /evolve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \ 4 | AdaptiveAvgPool2d, Sequential, Module 5 | from collections import namedtuple 6 | 7 | 8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152'] 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | 20 | return output 21 | 22 | 23 | class SEModule(Module): 24 | def __init__(self, channels, reduction): 25 | super(SEModule, self).__init__() 26 | self.avg_pool = AdaptiveAvgPool2d(1) 27 | self.fc1 = Conv2d( 28 | channels, channels // reduction, kernel_size=1, padding=0, bias=False) 29 | 30 | nn.init.xavier_uniform_(self.fc1.weight.data) 31 | 32 | self.relu = ReLU(inplace=True) 33 | self.fc2 = Conv2d( 34 | channels // reduction, channels, kernel_size=1, padding=0, bias=False) 35 | 36 | self.sigmoid = Sigmoid() 37 | 38 | def forward(self, x): 39 | module_input = x 40 | x = self.avg_pool(x) 41 | x = self.fc1(x) 42 | x = self.relu(x) 43 | x = self.fc2(x) 44 | x = self.sigmoid(x) 45 | 46 | return module_input * x 47 | 48 | 49 | class bottleneck_IR(Module): 50 | def __init__(self, in_channel, depth, stride): 51 | super(bottleneck_IR, self).__init__() 52 | if in_channel == depth: 53 | self.shortcut_layer = MaxPool2d(1, stride) 54 | else: 55 | self.shortcut_layer = Sequential( 56 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 57 | self.res_layer = Sequential( 58 | BatchNorm2d(in_channel), 59 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 60 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) 61 | 62 | def forward(self, x): 63 | shortcut = self.shortcut_layer(x) 64 | res = self.res_layer(x) 65 | 66 | return res + shortcut 67 | 68 | 69 | class 
bottleneck_IR_SE(Module): 70 | def __init__(self, in_channel, depth, stride): 71 | super(bottleneck_IR_SE, self).__init__() 72 | if in_channel == depth: 73 | self.shortcut_layer = MaxPool2d(1, stride) 74 | else: 75 | self.shortcut_layer = Sequential( 76 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 77 | BatchNorm2d(depth)) 78 | self.res_layer = Sequential( 79 | BatchNorm2d(in_channel), 80 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 81 | PReLU(depth), 82 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 83 | BatchNorm2d(depth), 84 | SEModule(depth, 16) 85 | ) 86 | 87 | def forward(self, x): 88 | shortcut = self.shortcut_layer(x) 89 | res = self.res_layer(x) 90 | 91 | return res + shortcut 92 | 93 | 94 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 95 | '''A named tuple describing a ResNet block.''' 96 | 97 | 98 | def get_block(in_channel, depth, num_units, stride=2): 99 | 100 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 101 | 102 | 103 | def get_blocks(num_layers): 104 | if num_layers == 50: 105 | blocks = [ 106 | get_block(in_channel=64, depth=64, num_units=3), 107 | get_block(in_channel=64, depth=128, num_units=4), 108 | get_block(in_channel=128, depth=256, num_units=14), 109 | get_block(in_channel=256, depth=512, num_units=3) 110 | ] 111 | elif num_layers == 100: 112 | blocks = [ 113 | get_block(in_channel=64, depth=64, num_units=3), 114 | get_block(in_channel=64, depth=128, num_units=13), 115 | get_block(in_channel=128, depth=256, num_units=30), 116 | get_block(in_channel=256, depth=512, num_units=3) 117 | ] 118 | elif num_layers == 152: 119 | blocks = [ 120 | get_block(in_channel=64, depth=64, num_units=3), 121 | get_block(in_channel=64, depth=128, num_units=8), 122 | get_block(in_channel=128, depth=256, num_units=36), 123 | get_block(in_channel=256, depth=512, num_units=3) 124 | ] 125 | 126 | return blocks 127 | 128 | 129 | class Backbone64(Module): 130 | def __init__(self, input_size, num_layers, mode='ir'): 131 | super(Backbone64, self).__init__() 132 | assert input_size[0] in [64, 112, 224], "input_size should be [112, 112] or [224, 224]" 133 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 134 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 135 | blocks = get_blocks(num_layers) 136 | if mode == 'ir': 137 | unit_module = bottleneck_IR 138 | elif mode == 'ir_se': 139 | unit_module = bottleneck_IR_SE 140 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 141 | BatchNorm2d(64), 142 | PReLU(64)) 143 | 144 | modules = [] 145 | for block in blocks: 146 | for bottleneck in block: 147 | modules.append( 148 | unit_module(bottleneck.in_channel, 149 | bottleneck.depth, 150 | bottleneck.stride)) 151 | self.body = Sequential(*modules) 152 | 153 | self._initialize_weights() 154 | 155 | def forward(self, x): 156 | x = self.input_layer(x) 157 | x = self.body(x) 158 | #x = self.output_layer(x) 159 | 160 | return x 161 | 162 | def _initialize_weights(self): 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.xavier_uniform_(m.weight.data) 166 | if m.bias is not None: 167 | m.bias.data.zero_() 168 | elif isinstance(m, nn.BatchNorm2d): 169 | m.weight.data.fill_(1) 170 | m.bias.data.zero_() 171 | elif isinstance(m, nn.BatchNorm1d): 172 | m.weight.data.fill_(1) 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.Linear): 175 | nn.init.xavier_uniform_(m.weight.data) 176 | if m.bias is 
not None: 177 | m.bias.data.zero_() 178 | 179 | class Backbone112(Module): 180 | def __init__(self, input_size, num_layers, mode='ir'): 181 | super(Backbone112, self).__init__() 182 | assert input_size[0] in [64, 112, 224], "input_size should be [64, 64], [112, 112] or [224, 224]" 183 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 184 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 185 | blocks = get_blocks(num_layers) 186 | if mode == 'ir': 187 | unit_module = bottleneck_IR 188 | elif mode == 'ir_se': 189 | unit_module = bottleneck_IR_SE 190 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 191 | BatchNorm2d(64), 192 | PReLU(64)) 193 | 194 | if input_size[0] == 112: 195 | self.output_layer = Sequential(BatchNorm2d(512), 196 | Dropout(), 197 | Flatten(), 198 | Linear(512 * 7 * 7, 512), 199 | BatchNorm1d(512)) 200 | else: 201 | self.output_layer = Sequential(BatchNorm2d(512), 202 | Dropout(), 203 | Flatten(), 204 | Linear(512 * 14 * 14, 512), 205 | BatchNorm1d(512)) 206 | 207 | modules = [] 208 | for block in blocks: 209 | for bottleneck in block: 210 | modules.append( 211 | unit_module(bottleneck.in_channel, 212 | bottleneck.depth, 213 | bottleneck.stride)) 214 | self.body = Sequential(*modules) 215 | 216 | self._initialize_weights() 217 | 218 | def forward(self, x): 219 | x = self.input_layer(x) 220 | x = self.body(x) 221 | x = self.output_layer(x) 222 | 223 | return x 224 | 225 | def _initialize_weights(self): 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.xavier_uniform_(m.weight.data) 229 | if m.bias is not None: 230 | m.bias.data.zero_() 231 | elif isinstance(m, nn.BatchNorm2d): 232 | m.weight.data.fill_(1) 233 | m.bias.data.zero_() 234 | elif isinstance(m, nn.BatchNorm1d): 235 | m.weight.data.fill_(1) 236 | m.bias.data.zero_() 237 | elif isinstance(m, nn.Linear): 238 | nn.init.xavier_uniform_(m.weight.data) 239 | if m.bias is not None: 240 | m.bias.data.zero_() 241 | 242 | 243 | def IR_50_64(input_size): 244 | """Constructs a ir-50 model. 245 | """ 246 | model = Backbone64(input_size, 50, 'ir') 247 | 248 | return model 249 | 250 | def IR_50_112(input_size): 251 | """Constructs a ir-50 model. 252 | """ 253 | model = Backbone112(input_size, 50, 'ir') 254 | 255 | return model 256 | 257 | 258 | def IR_100(input_size): 259 | """Constructs a ir-100 model. 260 | """ 261 | model = Backbone112(input_size, 100, 'ir') # the generic `Backbone` class referenced here does not exist in this file; Backbone112 is the closest equivalent 262 | 263 | return model 264 | 265 | def IR_152_64(input_size): 266 | """Constructs a ir-152 model. 267 | """ 268 | model = Backbone64(input_size, 152, 'ir') 269 | 270 | return model 271 | 272 | 273 | def IR_152_112(input_size): 274 | """Constructs a ir-152 model. 275 | """ 276 | model = Backbone112(input_size, 152, 'ir') 277 | 278 | return model 279 | 280 | def IR_SE_50(input_size): 281 | """Constructs a ir_se-50 model. 282 | """ 283 | model = Backbone112(input_size, 50, 'ir_se') # `Backbone` -> Backbone112 (see note in IR_100) 284 | 285 | return model 286 | 287 | 288 | def IR_SE_101(input_size): 289 | """Constructs a ir_se-101 model. 290 | """ 291 | model = Backbone112(input_size, 100, 'ir_se') # `Backbone` -> Backbone112 (see note in IR_100) 292 | 293 | return model 294 | 295 | 296 | def IR_SE_152(input_size): 297 | """Constructs a ir_se-152 model. 
298 | """ 299 | model = Backbone(input_size, 152, 'ir_se') 300 | 301 | return model -------------------------------------------------------------------------------- /classify.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torchvision.models 7 | import torch.nn.functional as F 8 | from torch.nn.modules.loss import _Loss 9 | import math, evolve 10 | 11 | class Flatten(nn.Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | class Mnist_CNN(nn.Module): 16 | def __init__(self): 17 | super(Mnist_CNN, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 19 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 20 | self.conv2_drop = nn.Dropout2d() 21 | self.fc1 = nn.Linear(500, 50) 22 | self.fc2 = nn.Linear(50, 5) 23 | 24 | def forward(self, x): 25 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 26 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 27 | x = x.view(x.size(0), -1) 28 | x = F.relu(self.fc1(x)) 29 | x = F.dropout(x, training=self.training) 30 | res = self.fc2(x) 31 | return [x, res] 32 | 33 | class VGG16(nn.Module): 34 | def __init__(self, n_classes): 35 | super(VGG16, self).__init__() 36 | model = torchvision.models.vgg16_bn(pretrained=False) 37 | self.feature = model.features 38 | self.feat_dim = 512 * 2 * 2 39 | self.n_classes = n_classes 40 | self.bn = nn.BatchNorm1d(self.feat_dim) 41 | self.bn.bias.requires_grad_(False) # no shift 42 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 43 | 44 | def forward(self, x): 45 | feature = self.feature(x) 46 | feature = feature.view(feature.size(0), -1) 47 | feature = self.bn(feature) 48 | res = self.fc_layer(feature) 49 | 50 | return [feature, res] 51 | 52 | def predict(self, x): 53 | feature = self.feature(x) 54 | feature = feature.view(feature.size(0), -1) 55 | feature = self.bn(feature) 56 | res = self.fc_layer(feature) 57 | out = F.softmax(res, dim=1) 58 | 59 | return out 60 | 61 | class VGG16_vib(nn.Module): 62 | def __init__(self, n_classes): 63 | super(VGG16_vib, self).__init__() 64 | model = torchvision.models.vgg16_bn(pretrained=True) 65 | self.feature = model.features 66 | self.feat_dim = 512 * 2 * 2 67 | self.k = self.feat_dim // 2 68 | self.n_classes = n_classes 69 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 70 | self.fc_layer = nn.Linear(self.k, self.n_classes) 71 | 72 | def forward(self, x, mode="train"): 73 | feature = self.feature(x) 74 | feature = feature.view(feature.size(0), -1) 75 | statis = self.st_layer(feature) 76 | mu, std = statis[:, :self.k], statis[:, self.k:] 77 | 78 | std = F.softplus(std-5, beta=1) 79 | eps = torch.FloatTensor(std.size()).normal_().cuda() 80 | res = mu + std * eps 81 | out = self.fc_layer(res) 82 | 83 | return [feature, out, mu, std] 84 | 85 | def predict(self, x): 86 | feature = self.feature(x) 87 | feature = feature.view(feature.size(0), -1) 88 | statis = self.st_layer(feature) 89 | mu, std = statis[:, :self.k], statis[:, self.k:] 90 | 91 | std = F.softplus(std-5, beta=1) 92 | eps = torch.FloatTensor(std.size()).normal_().cuda() 93 | res = mu + std * eps 94 | out = self.fc_layer(res) 95 | 96 | return out 97 | 98 | class CrossEntropyLoss(_Loss): 99 | def forward(self, out, gt, mode="reg"): 100 | bs = out.size(0) 101 | loss = - torch.mul(gt.float(), torch.log(out.float() + 1e-7)) 102 | if mode == "dp": 103 | loss = torch.sum(loss, dim=1).view(-1) 104 | else: 105 | 
loss = torch.sum(loss) / bs 106 | return loss 107 | 108 | class BinaryLoss(_Loss): 109 | def forward(self, out, gt): 110 | bs = out.size(0) 111 | loss = - (gt * torch.log(out.float()+1e-7) + (1-gt) * torch.log(1-out.float()+1e-7)) 112 | loss = torch.mean(loss) 113 | return loss 114 | 115 | 116 | class FaceNet(nn.Module): 117 | def __init__(self, num_classes=1000): 118 | super(FaceNet, self).__init__() 119 | self.feature = evolve.IR_50_112((112, 112)) 120 | self.feat_dim = 512 121 | self.num_classes = num_classes 122 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 123 | 124 | def predict(self, x): 125 | feat = self.feature(x) 126 | feat = feat.view(feat.size(0), -1) 127 | out = self.fc_layer(feat) 128 | return out 129 | 130 | def forward(self, x): 131 | # print("input shape:", x.shape) 132 | # import pdb; pdb.set_trace() 133 | 134 | feat = self.feature(x) 135 | feat = feat.view(feat.size(0), -1) 136 | out = self.fc_layer(feat) 137 | return [feat, out] 138 | 139 | class FaceNet64(nn.Module): 140 | def __init__(self, num_classes = 1000): 141 | super(FaceNet64, self).__init__() 142 | self.feature = evolve.IR_50_64((64, 64)) 143 | self.feat_dim = 512 144 | self.num_classes = num_classes 145 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 146 | nn.Dropout(), 147 | Flatten(), 148 | nn.Linear(512 * 4 * 4, 512), 149 | nn.BatchNorm1d(512)) 150 | 151 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 152 | 153 | def forward(self, x): 154 | feat = self.feature(x) 155 | feat = self.output_layer(feat) 156 | feat = feat.view(feat.size(0), -1) 157 | out = self.fc_layer(feat) 158 | __, iden = torch.max(out, dim=1) # computed but unused; kept for parity with facenet.py 159 | iden = iden.view(-1, 1) 160 | return feat, out 161 | 162 | class IR152(nn.Module): 163 | def __init__(self, num_classes=1000): 164 | super(IR152, self).__init__() 165 | self.feature = evolve.IR_152_64((64, 64)) 166 | self.feat_dim = 512 167 | self.num_classes = num_classes 168 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 169 | nn.Dropout(), 170 | Flatten(), 171 | nn.Linear(512 * 4 * 4, 512), 172 | nn.BatchNorm1d(512)) 173 | 174 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 175 | 176 | def forward(self, x): 177 | feat = self.feature(x) 178 | feat = self.output_layer(feat) 179 | feat = feat.view(feat.size(0), -1) 180 | out = self.fc_layer(feat) 181 | return feat, out 182 | 183 | class IR152_vib(nn.Module): 184 | def __init__(self, num_classes=1000): 185 | super(IR152_vib, self).__init__() 186 | self.feature = evolve.IR_152_64((64, 64)) 187 | self.feat_dim = 512 188 | self.k = self.feat_dim // 2 189 | self.n_classes = num_classes 190 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 191 | nn.Dropout(), 192 | Flatten(), 193 | nn.Linear(512 * 4 * 4, 512), 194 | nn.BatchNorm1d(512)) 195 | 196 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 197 | self.fc_layer = nn.Sequential( 198 | nn.Linear(self.k, self.n_classes), 199 | nn.Softmax(dim = 1)) 200 | 201 | def forward(self, x): 202 | feature = self.output_layer(self.feature(x)) 203 | feature = feature.view(feature.size(0), -1) 204 | statis = self.st_layer(feature) 205 | mu, std = statis[:, :self.k], statis[:, self.k:] 206 | 207 | std = F.softplus(std-5, beta=1) 208 | eps = torch.FloatTensor(std.size()).normal_().cuda() 209 | res = mu + std * eps 210 | out = self.fc_layer(res) 211 | __, iden = torch.max(out, dim=1) 212 | iden = iden.view(-1, 1) 213 | 214 | return feature, out, iden, mu, std # fix: `st` was a typo (NameError) 215 | 216 | class IR50(nn.Module): 217 | def __init__(self, num_classes=1000): 218 | 
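# NOTE: despite the plain name, IR50 carries the same variational-information-
# bottleneck head as the *_vib classes: st_layer emits (mu, std), softplus(std - 5)
# keeps the scale positive and small at init, and the classifier sees res = mu + std * eps.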
super(IR50, self).__init__() 219 | self.feature = evolve.IR_50_64((64, 64)) 220 | self.feat_dim = 512 221 | self.num_classes = num_classes; self.n_classes = num_classes; self.k = self.feat_dim // 2 # fix: n_classes and k are used below but were never defined 222 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 223 | nn.Dropout(), 224 | Flatten(), 225 | nn.Linear(512 * 4 * 4, 512), 226 | nn.BatchNorm1d(512)) 227 | 228 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 229 | self.fc_layer = nn.Sequential( 230 | nn.Linear(self.k, self.n_classes), 231 | nn.Softmax(dim = 1)) 232 | 233 | def forward(self, x): 234 | feature = self.output_layer(self.feature(x)) 235 | feature = feature.view(feature.size(0), -1) 236 | statis = self.st_layer(feature) 237 | mu, std = statis[:, :self.k], statis[:, self.k:] 238 | 239 | std = F.softplus(std-5, beta=1) 240 | eps = torch.FloatTensor(std.size()).normal_().cuda() 241 | res = mu + std * eps 242 | out = self.fc_layer(res) 243 | __, iden = torch.max(out, dim=1) 244 | iden = iden.view(-1, 1) 245 | 246 | return feature, out, iden, mu, std 247 | 248 | class IR50_vib(nn.Module): 249 | def __init__(self, num_classes=1000): 250 | super(IR50_vib, self).__init__() 251 | self.feature = evolve.IR_50_64((64, 64)) 252 | self.feat_dim = 512 253 | self.n_classes = num_classes 254 | self.k = self.feat_dim // 2 255 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 256 | nn.Dropout(), 257 | Flatten(), 258 | nn.Linear(512 * 4 * 4, 512), 259 | nn.BatchNorm1d(512)) 260 | 261 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 262 | self.fc_layer = nn.Sequential( 263 | nn.Linear(self.k, self.n_classes), 264 | nn.Softmax(dim=1)) 265 | 266 | def forward(self, x): 267 | feat = self.output_layer(self.feature(x)) 268 | feat = feat.view(feat.size(0), -1) 269 | statis = self.st_layer(feat) 270 | mu, std = statis[:, :self.k], statis[:, self.k:] 271 | 272 | std = F.softplus(std-5, beta=1) 273 | eps = torch.FloatTensor(std.size()).normal_().cuda() 274 | res = mu + std * eps 275 | out = self.fc_layer(res) 276 | __, iden = torch.max(out, dim=1) 277 | iden = iden.view(-1, 1) 278 | 279 | return feat, out, iden, mu, std 280 | 281 | 282 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Generator(nn.Module): 5 | def __init__(self, in_dim=100, dim=64): 6 | super(Generator, self).__init__() 7 | def dconv_bn_relu(in_dim, out_dim): 8 | return nn.Sequential( 9 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 10 | padding=2, output_padding=1, bias=False), 11 | nn.BatchNorm2d(out_dim), 12 | nn.ReLU()) 13 | 14 | self.l1 = nn.Sequential( 15 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), 16 | nn.BatchNorm1d(dim * 8 * 4 * 4), 17 | nn.ReLU()) 18 | self.l2_5 = nn.Sequential( 19 | dconv_bn_relu(dim * 8, dim * 4), 20 | dconv_bn_relu(dim * 4, dim * 2), 21 | dconv_bn_relu(dim * 2, dim), 22 | nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1), 23 | nn.Sigmoid()) 24 | 25 | def forward(self, x): 26 | y = self.l1(x) 27 | y = y.view(y.size(0), -1, 4, 4) 28 | y = self.l2_5(y) 29 | return y 30 | 31 | class GeneratorMNIST(nn.Module): 32 | def __init__(self, in_dim=100, dim=64): 33 | super(GeneratorMNIST, self).__init__() 34 | def dconv_bn_relu(in_dim, out_dim): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 37 | padding=2, output_padding=1, bias=False), 38 | nn.BatchNorm2d(out_dim), 39 | nn.ReLU()) 40 | 41 | self.l1 = nn.Sequential( 42 | nn.Linear(in_dim, dim * 4 * 4 * 4, bias=False), 43 | 
nn.BatchNorm1d(dim * 4 * 4 * 4), 44 | nn.ReLU()) 45 | self.l2_5 = nn.Sequential( 46 | dconv_bn_relu(dim * 4, dim * 2), 47 | dconv_bn_relu(dim * 2, dim), 48 | nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1), 49 | nn.Sigmoid()) 50 | 51 | def forward(self, x): 52 | y = self.l1(x) 53 | y = y.view(y.size(0), -1, 4, 4) 54 | y = self.l2_5(y) 55 | return y 56 | 57 | class CompletionNetwork(nn.Module): 58 | def __init__(self): 59 | super(CompletionNetwork, self).__init__() 60 | # input_shape: (None, 4, img_h, img_w) 61 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2) 62 | self.bn1 = nn.BatchNorm2d(32) 63 | self.act1 = nn.ReLU() 64 | # input_shape: (None, 32, img_h, img_w) 65 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 66 | self.bn2 = nn.BatchNorm2d(64) 67 | self.act2 = nn.ReLU() 68 | # input_shape: (None, 64, img_h//2, img_w//2) 69 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 70 | self.bn3 = nn.BatchNorm2d(64) 71 | self.act3 = nn.ReLU() 72 | # input_shape: (None, 64, img_h//2, img_w//2) 73 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 74 | self.bn4 = nn.BatchNorm2d(128) 75 | self.act4 = nn.ReLU() 76 | # input_shape: (None, 128, img_h//4, img_w//4) 77 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 78 | self.bn5 = nn.BatchNorm2d(128) 79 | self.act5 = nn.ReLU() 80 | # input_shape: (None, 128, img_h//4, img_w//4) 81 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 82 | self.bn6 = nn.BatchNorm2d(128) 83 | self.act6 = nn.ReLU() 84 | # input_shape: (None, 128, img_h//4, img_w//4) 85 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2) 86 | self.bn7 = nn.BatchNorm2d(128) 87 | self.act7 = nn.ReLU() 88 | # input_shape: (None, 128, img_h//4, img_w//4) 89 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4) 90 | self.bn8 = nn.BatchNorm2d(128) 91 | self.act8 = nn.ReLU() 92 | # input_shape: (None, 128, img_h//4, img_w//4) 93 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8) 94 | self.bn9 = nn.BatchNorm2d(128) 95 | self.act9 = nn.ReLU() 96 | # input_shape: (None, 128, img_h//4, img_w//4) 97 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16) 98 | self.bn10 = nn.BatchNorm2d(128) 99 | self.act10 = nn.ReLU() 100 | # input_shape: (None, 128, img_h//4, img_w//4) 101 | self.conv11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 102 | self.bn11 = nn.BatchNorm2d(128) 103 | self.act11 = nn.ReLU() 104 | # input_shape: (None, 128, img_h//4, img_w//4) 105 | self.conv12 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 106 | self.bn12 = nn.BatchNorm2d(128) 107 | self.act12 = nn.ReLU() 108 | # input_shape: (None, 128, img_h//4, img_w//4) 109 | self.deconv13 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) 110 | self.bn13 = nn.BatchNorm2d(64) 111 | self.act13 = nn.ReLU() 112 | # input_shape: (None, 64, img_h//2, img_w//2) 113 | self.conv14 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 114 | self.bn14 = nn.BatchNorm2d(64) 115 | self.act14 = nn.ReLU() 116 | # input_shape: (None, 64, img_h//2, img_w//2) 117 | self.deconv15 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1) 118 | self.bn15 = nn.BatchNorm2d(32) 119 | self.act15 = nn.ReLU() 120 | # input_shape: (None, 32, img_h, img_w) 121 | self.conv16 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 122 | self.bn16 = nn.BatchNorm2d(32) 123 
| self.act16 = nn.ReLU() 124 | # input_shape: (None, 32, img_h, img_w) 125 | self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1) 126 | self.act17 = nn.Sigmoid() 127 | # output_shape: (None, 3, img_h. img_w) 128 | 129 | def forward(self, x): 130 | x = self.bn1(self.act1(self.conv1(x))) 131 | x = self.bn2(self.act2(self.conv2(x))) 132 | x = self.bn3(self.act3(self.conv3(x))) 133 | x = self.bn4(self.act4(self.conv4(x))) 134 | x = self.bn5(self.act5(self.conv5(x))) 135 | x = self.bn6(self.act6(self.conv6(x))) 136 | x = self.bn7(self.act7(self.conv7(x))) 137 | x = self.bn8(self.act8(self.conv8(x))) 138 | x = self.bn9(self.act9(self.conv9(x))) 139 | x = self.bn10(self.act10(self.conv10(x))) 140 | x = self.bn11(self.act11(self.conv11(x))) 141 | x = self.bn12(self.act12(self.conv12(x))) 142 | x = self.bn13(self.act13(self.deconv13(x))) 143 | x = self.bn14(self.act14(self.conv14(x))) 144 | x = self.bn15(self.act15(self.deconv15(x))) 145 | x = self.bn16(self.act16(self.conv16(x))) 146 | x = self.act17(self.conv17(x)) 147 | return x 148 | 149 | def dconv_bn_relu(in_dim, out_dim): 150 | return nn.Sequential( 151 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 152 | padding=2, output_padding=1, bias=False), 153 | nn.BatchNorm2d(out_dim), 154 | nn.ReLU()) 155 | 156 | class ContextNetwork(nn.Module): 157 | def __init__(self): 158 | super(ContextNetwork, self).__init__() 159 | # input_shape: (None, 4, img_h, img_w) 160 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2) 161 | self.bn1 = nn.BatchNorm2d(32) 162 | self.act1 = nn.ReLU() 163 | # input_shape: (None, 32, img_h, img_w) 164 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 165 | self.bn2 = nn.BatchNorm2d(64) 166 | self.act2 = nn.ReLU() 167 | # input_shape: (None, 64, img_h//2, img_w//2) 168 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 169 | self.bn3 = nn.BatchNorm2d(64) 170 | self.act3 = nn.ReLU() 171 | # input_shape: (None, 128, img_h//2, img_w//2) 172 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 173 | self.bn4 = nn.BatchNorm2d(128) 174 | self.act4 = nn.ReLU() 175 | # input_shape: (None, 128, img_h//4, img_w//4) 176 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 177 | self.bn5 = nn.BatchNorm2d(128) 178 | self.act5 = nn.ReLU() 179 | # input_shape: (None, 128, img_h//4, img_w//4) 180 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 181 | self.bn6 = nn.BatchNorm2d(128) 182 | self.act6 = nn.ReLU() 183 | # input_shape: (None, 128, img_h//4, img_w//4) 184 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2) 185 | self.bn7 = nn.BatchNorm2d(128) 186 | self.act7 = nn.ReLU() 187 | # input_shape: (None, 128, img_h//4, img_w//4) 188 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4) 189 | self.bn8 = nn.BatchNorm2d(128) 190 | self.act8 = nn.ReLU() 191 | # input_shape: (None, 128, img_h//4, img_w//4) 192 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8) 193 | self.bn9 = nn.BatchNorm2d(128) 194 | self.act9 = nn.ReLU() 195 | # input_shape: (None, 128, img_h//4, img_w//4) 196 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16) 197 | self.bn10 = nn.BatchNorm2d(128) 198 | self.act10 = nn.ReLU() 199 | 200 | 201 | 202 | def forward(self, x): 203 | x = self.bn1(self.act1(self.conv1(x))) 204 | x = self.bn2(self.act2(self.conv2(x))) 205 | x = self.bn3(self.act3(self.conv3(x))) 206 | x = 
self.bn4(self.act4(self.conv4(x))) 207 | x = self.bn5(self.act5(self.conv5(x))) 208 | x = self.bn6(self.act6(self.conv6(x))) 209 | x = self.bn7(self.act7(self.conv7(x))) 210 | x = self.bn8(self.act8(self.conv8(x))) 211 | x = self.bn9(self.act9(self.conv9(x))) 212 | x = self.bn10(self.act10(self.conv10(x))) 213 | return x 214 | 215 | class IdentityGenerator(nn.Module): 216 | 217 | def __init__(self, in_dim = 100, dim=64): 218 | super(IdentityGenerator, self).__init__() 219 | 220 | self.l1 = nn.Sequential( 221 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), 222 | nn.BatchNorm1d(dim * 8 * 4 * 4), 223 | nn.ReLU()) 224 | self.l2_5 = nn.Sequential( 225 | dconv_bn_relu(dim * 8, dim * 4), 226 | dconv_bn_relu(dim * 4, dim * 2)) 227 | 228 | def forward(self, x): 229 | y = self.l1(x) 230 | y = y.view(y.size(0), -1, 4, 4) 231 | y = self.l2_5(y) 232 | return y 233 | 234 | class InversionNet(nn.Module): 235 | def __init__(self, out_dim = 128): 236 | super(InversionNet, self).__init__() 237 | 238 | # input [4, h, w] output [256, h // 4, w // 4] 239 | self.ContextNetwork = ContextNetwork() 240 | # input [z_dim] output[128, 16, 16] 241 | self.IdentityGenerator = IdentityGenerator() 242 | 243 | self.dim = 128 + 128 244 | self.out_dim = out_dim 245 | 246 | self.Dconv = nn.Sequential( 247 | dconv_bn_relu(self.dim, self.out_dim), 248 | dconv_bn_relu(self.out_dim, self.out_dim // 2)) 249 | 250 | self.Conv = nn.Sequential( 251 | nn.Conv2d(self.out_dim // 2, self.out_dim // 4, kernel_size=3, stride=1, padding=1), 252 | nn.BatchNorm2d(self.out_dim // 4), 253 | nn.ReLU(), 254 | nn.Conv2d(self.out_dim // 4, 3, kernel_size=3, stride=1, padding=1), 255 | nn.Sigmoid()) 256 | 257 | 258 | def forward(self, inp): 259 | # x.shape [4, h, w] z.shape [100] 260 | x, z = inp 261 | context_info = self.ContextNetwork(x) 262 | identity_info = self.IdentityGenerator(z) 263 | y = torch.cat((context_info, identity_info), dim=1) 264 | y = self.Dconv(y) 265 | y = self.Conv(y) 266 | 267 | return y 268 | -------------------------------------------------------------------------------- /facenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \ 4 | AdaptiveAvgPool2d, Sequential, Module 5 | from collections import namedtuple 6 | 7 | 8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152'] 9 | class FaceNet(nn.Module): 10 | def __init__(self, num_classes = 1000): 11 | super(FaceNet, self).__init__() 12 | self.feature = IR_50_112((112, 112)) 13 | self.feat_dim = 512 14 | self.num_classes = num_classes 15 | self.fc_layer = nn.Sequential( 16 | nn.Linear(self.feat_dim, self.num_classes), 17 | nn.Softmax(dim = 1)) 18 | 19 | def forward(self, x): 20 | feat = self.feature(x) 21 | feat = feat.view(feat.size(0), -1) 22 | out = self.fc_layer(feat) 23 | __, iden = torch.max(out, dim = 1) 24 | iden = iden.view(-1, 1) 25 | return feat, out, iden 26 | 27 | class FaceNet64(nn.Module): 28 | def __init__(self, num_classes = 1000): 29 | super(FaceNet64, self).__init__() 30 | self.feature = IR_50_64((64, 64)) 31 | self.feat_dim = 512 32 | self.num_classes = num_classes 33 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 34 | nn.Dropout(), 35 | Flatten(), 36 | nn.Linear(512 * 4 * 4, 512), 37 | nn.BatchNorm1d(512)) 38 | 39 | self.fc_layer = nn.Sequential( 40 | nn.Linear(self.feat_dim, self.num_classes), 41 | nn.Softmax(dim = 1)) 42 | 43 | def 
forward(self, x): 44 | feat = self.feature(x) 45 | feat = self.output_layer(feat) 46 | feat = feat.view(feat.size(0), -1) 47 | out = self.fc_layer(feat) 48 | __, iden = torch.max(out, dim = 1) 49 | iden = iden.view(-1, 1) 50 | return feat, out, iden 51 | 52 | class Flatten(Module): 53 | def forward(self, input): 54 | return input.view(input.size(0), -1) 55 | 56 | 57 | def l2_norm(input, axis=1): 58 | norm = torch.norm(input, 2, axis, True) 59 | output = torch.div(input, norm) 60 | 61 | return output 62 | 63 | 64 | class SEModule(Module): 65 | def __init__(self, channels, reduction): 66 | super(SEModule, self).__init__() 67 | self.avg_pool = AdaptiveAvgPool2d(1) 68 | self.fc1 = Conv2d( 69 | channels, channels // reduction, kernel_size=1, padding=0, bias=False) 70 | 71 | nn.init.xavier_uniform_(self.fc1.weight.data) 72 | 73 | self.relu = ReLU(inplace=True) 74 | self.fc2 = Conv2d( 75 | channels // reduction, channels, kernel_size=1, padding=0, bias=False) 76 | 77 | self.sigmoid = Sigmoid() 78 | 79 | def forward(self, x): 80 | module_input = x 81 | x = self.avg_pool(x) 82 | x = self.fc1(x) 83 | x = self.relu(x) 84 | x = self.fc2(x) 85 | x = self.sigmoid(x) 86 | 87 | return module_input * x 88 | 89 | 90 | class bottleneck_IR(Module): 91 | def __init__(self, in_channel, depth, stride): 92 | super(bottleneck_IR, self).__init__() 93 | if in_channel == depth: 94 | self.shortcut_layer = MaxPool2d(1, stride) 95 | else: 96 | self.shortcut_layer = Sequential( 97 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 98 | self.res_layer = Sequential( 99 | BatchNorm2d(in_channel), 100 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 101 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) 102 | 103 | def forward(self, x): 104 | shortcut = self.shortcut_layer(x) 105 | res = self.res_layer(x) 106 | 107 | return res + shortcut 108 | 109 | 110 | class bottleneck_IR_SE(Module): 111 | def __init__(self, in_channel, depth, stride): 112 | super(bottleneck_IR_SE, self).__init__() 113 | if in_channel == depth: 114 | self.shortcut_layer = MaxPool2d(1, stride) 115 | else: 116 | self.shortcut_layer = Sequential( 117 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 118 | BatchNorm2d(depth)) 119 | self.res_layer = Sequential( 120 | BatchNorm2d(in_channel), 121 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 122 | PReLU(depth), 123 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 124 | BatchNorm2d(depth), 125 | SEModule(depth, 16) 126 | ) 127 | 128 | def forward(self, x): 129 | shortcut = self.shortcut_layer(x) 130 | res = self.res_layer(x) 131 | 132 | return res + shortcut 133 | 134 | 135 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 136 | '''A named tuple describing a ResNet block.''' 137 | 138 | 139 | def get_block(in_channel, depth, num_units, stride=2): 140 | 141 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 142 | 143 | 144 | def get_blocks(num_layers): 145 | if num_layers == 50: 146 | blocks = [ 147 | get_block(in_channel=64, depth=64, num_units=3), 148 | get_block(in_channel=64, depth=128, num_units=4), 149 | get_block(in_channel=128, depth=256, num_units=14), 150 | get_block(in_channel=256, depth=512, num_units=3) 151 | ] 152 | elif num_layers == 100: 153 | blocks = [ 154 | get_block(in_channel=64, depth=64, num_units=3), 155 | get_block(in_channel=64, depth=128, num_units=13), 156 | get_block(in_channel=128, depth=256, 
num_units=30), 157 | get_block(in_channel=256, depth=512, num_units=3) 158 | ] 159 | elif num_layers == 152: 160 | blocks = [ 161 | get_block(in_channel=64, depth=64, num_units=3), 162 | get_block(in_channel=64, depth=128, num_units=8), 163 | get_block(in_channel=128, depth=256, num_units=36), 164 | get_block(in_channel=256, depth=512, num_units=3) 165 | ] 166 | 167 | return blocks 168 | 169 | 170 | class Backbone64(Module): 171 | def __init__(self, input_size, num_layers, mode='ir'): 172 | super(Backbone64, self).__init__() 173 | assert input_size[0] in [64], "input_size should be [64, 64]" 174 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 175 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 176 | blocks = get_blocks(num_layers) 177 | if mode == 'ir': 178 | unit_module = bottleneck_IR 179 | elif mode == 'ir_se': 180 | unit_module = bottleneck_IR_SE 181 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 182 | BatchNorm2d(64), 183 | PReLU(64)) 184 | 185 | self.output_layer = Sequential(BatchNorm2d(512), # NOTE: defined but never used; forward() returns the raw feature map 186 | Dropout(), 187 | Flatten(), 188 | Linear(512 * 14 * 14, 512), 189 | BatchNorm1d(512)) 190 | 191 | modules = [] 192 | for block in blocks: 193 | for bottleneck in block: 194 | modules.append( 195 | unit_module(bottleneck.in_channel, 196 | bottleneck.depth, 197 | bottleneck.stride)) 198 | self.body = Sequential(*modules) 199 | 200 | self._initialize_weights() 201 | 202 | def forward(self, x): 203 | x = self.input_layer(x) 204 | x = self.body(x) 205 | 206 | return x 207 | 208 | def _initialize_weights(self): 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | nn.init.xavier_uniform_(m.weight.data) 212 | if m.bias is not None: 213 | m.bias.data.zero_() 214 | elif isinstance(m, nn.BatchNorm2d): 215 | m.weight.data.fill_(1) 216 | m.bias.data.zero_() 217 | elif isinstance(m, nn.BatchNorm1d): 218 | m.weight.data.fill_(1) 219 | m.bias.data.zero_() 220 | elif isinstance(m, nn.Linear): 221 | nn.init.xavier_uniform_(m.weight.data) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | class Backbone112(Module): 226 | def __init__(self, input_size, num_layers, mode='ir'): 227 | super(Backbone112, self).__init__() 228 | assert input_size[0] in [112], "input_size should be [112, 112]" 229 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 230 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 231 | blocks = get_blocks(num_layers) 232 | if mode == 'ir': 233 | unit_module = bottleneck_IR 234 | elif mode == 'ir_se': 235 | unit_module = bottleneck_IR_SE 236 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 237 | BatchNorm2d(64), 238 | PReLU(64)) 239 | 240 | if input_size[0] == 112: 241 | self.output_layer = Sequential(BatchNorm2d(512), 242 | Dropout(), 243 | Flatten(), 244 | Linear(512 * 7 * 7, 512), 245 | BatchNorm1d(512)) 246 | 247 | modules = [] 248 | for block in blocks: 249 | for bottleneck in block: 250 | modules.append( 251 | unit_module(bottleneck.in_channel, 252 | bottleneck.depth, 253 | bottleneck.stride)) 254 | self.body = Sequential(*modules) 255 | 256 | self._initialize_weights() 257 | 258 | def forward(self, x): 259 | x = self.input_layer(x) 260 | x = self.body(x) 261 | x = self.output_layer(x) 262 | 263 | return x 264 | 265 | def _initialize_weights(self): 266 | for m in self.modules(): 267 | if isinstance(m, nn.Conv2d): 268 | nn.init.xavier_uniform_(m.weight.data) 269 | if m.bias is not None: 270 | 
m.bias.data.zero_() 271 | elif isinstance(m, nn.BatchNorm2d): 272 | m.weight.data.fill_(1) 273 | m.bias.data.zero_() 274 | elif isinstance(m, nn.BatchNorm1d): 275 | m.weight.data.fill_(1) 276 | m.bias.data.zero_() 277 | elif isinstance(m, nn.Linear): 278 | nn.init.xavier_uniform_(m.weight.data) 279 | if m.bias is not None: 280 | m.bias.data.zero_() 281 | 282 | 283 | def IR_50_64(input_size): 284 | """Constructs a ir-50 model. 285 | """ 286 | model = Backbone64(input_size, 50, 'ir') 287 | 288 | return model 289 | 290 | def IR_50_112(input_size): 291 | """Constructs a ir-50 model. 292 | """ 293 | model = Backbone112(input_size, 50, 'ir') 294 | 295 | return model 296 | 297 | 298 | def IR_101(input_size): 299 | """Constructs a ir-101 model. 300 | """ 301 | model = Backbone112(input_size, 100, 'ir') # the generic `Backbone` class referenced here does not exist in this file; Backbone112 is the closest equivalent 302 | 303 | return model 304 | 305 | 306 | def IR_152_64(input_size): 307 | """Constructs a ir-152 model. 308 | """ 309 | model = Backbone64(input_size, 152, 'ir') 310 | 311 | return model 312 | 313 | def IR_152_112(input_size): 314 | """Constructs a ir-152 model. 315 | """ 316 | model = Backbone112(input_size, 152, 'ir') 317 | 318 | return model 319 | 320 | 321 | def IR_SE_50(input_size): 322 | """Constructs a ir_se-50 model. 323 | """ 324 | model = Backbone112(input_size, 50, 'ir_se') # `Backbone` -> Backbone112 (see note in IR_101) 325 | 326 | return model 327 | 328 | 329 | def IR_SE_101(input_size): 330 | """Constructs a ir_se-101 model. 331 | """ 332 | model = Backbone112(input_size, 100, 'ir_se') # `Backbone` -> Backbone112 (see note in IR_101) 333 | 334 | return model 335 | 336 | 337 | def IR_SE_152(input_size): 338 | """Constructs a ir_se-152 model. 339 | """ 340 | model = Backbone112(input_size, 152, 'ir_se') # `Backbone` -> Backbone112 (see note in IR_101) 341 | 342 | return model -------------------------------------------------------------------------------- /SAC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import namedtuple, deque 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal, MultivariateNormal 8 | import torch.optim as optim 9 | 10 | GAMMA = 0.99 11 | TAU = 1e-2 12 | HIDDEN_SIZE = 256 13 | BUFFER_SIZE = int(1e6) 14 | BATCH_SIZE = 256 15 | LR_ACTOR = 5e-4 16 | LR_CRITIC = 5e-4 17 | FIXED_ALPHA = None 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | def hidden_init(layer): 22 | fan_in = layer.weight.data.size()[0] 23 | lim = 1. / np.sqrt(fan_in) 24 | return (-lim, lim) 25 | 26 | class Actor(nn.Module): 27 | """Actor (Policy) Model.""" 28 | 29 | def __init__(self, state_size, action_size, seed, hidden_size=32, init_w=3e-3, log_std_min=-20, log_std_max=2): 30 | """Initialize parameters and build model. 
31 | Params 32 | ====== 33 | state_size (int): Dimension of each state 34 | action_size (int): Dimension of each action 35 | seed (int): Random seed 36 | hidden_size (int): Number of nodes in each of the two hidden layers 37 | init_w (float): Uniform initialization bound for the output heads; log_std_min/log_std_max bound the predicted log-std 38 | """ 39 | super(Actor, self).__init__() 40 | self.seed = torch.manual_seed(seed) 41 | self.log_std_min = log_std_min 42 | self.log_std_max = log_std_max 43 | self.init_w = init_w # fix: stored so reset_parameters() can use it (it was a free, undefined name there) 44 | self.fc1 = nn.Linear(state_size, hidden_size) 45 | self.fc2 = nn.Linear(hidden_size, hidden_size) 46 | 47 | self.mu = nn.Linear(hidden_size, action_size) 48 | self.log_std_linear = nn.Linear(hidden_size, action_size) 49 | 50 | 51 | def reset_parameters(self): 52 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 53 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 54 | self.mu.weight.data.uniform_(-self.init_w, self.init_w) 55 | self.log_std_linear.weight.data.uniform_(-self.init_w, self.init_w) 56 | 57 | def forward(self, state): 58 | 59 | x = F.relu(self.fc1(state), inplace=True) 60 | x = F.relu(self.fc2(x), inplace=True) 61 | mu = self.mu(x) 62 | 63 | log_std = self.log_std_linear(x) 64 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 65 | return mu, log_std 66 | 67 | def evaluate(self, state, epsilon=1e-6): 68 | mu, log_std = self.forward(state) 69 | std = log_std.exp() 70 | dist = Normal(0, 1) 71 | e = dist.sample().to(device) 72 | action = torch.tanh(mu + e * std) 73 | log_prob = (Normal(mu, std).log_prob(mu + e * std) - torch.log(1 - action.pow(2) + epsilon)).mean(1, keepdim=True) 74 | 75 | return action, log_prob 76 | 77 | 78 | def get_action(self, state): 79 | """ 80 | returns the action based on a squashed Gaussian policy. That means the samples are obtained according to: 81 | a(s,e) = tanh(mu(s) + sigma(s)*e) 82 | """ 83 | #state = torch.FloatTensor(state).to(device) #.unsqueeze(0) 84 | mu, log_std = self.forward(state) 85 | std = log_std.exp() 86 | dist = Normal(0, 1) 87 | e = dist.sample().to(device) 88 | action = torch.tanh(mu + e * std).cpu() 89 | #action = torch.clamp(action*action_high, action_low, action_high) 90 | return action[0] 91 | 92 | 93 | class Critic(nn.Module): 94 | """Critic (Value) Model.""" 95 | 96 | def __init__(self, state_size, action_size, seed, hidden_size=32): 97 | """Initialize parameters and build model. 98 | Params 99 | ====== 100 | state_size (int): Dimension of each state 101 | action_size (int): Dimension of each action 102 | seed (int): Random seed 103 | hidden_size (int): Number of nodes in the network layers 104 | """ 105 | super(Critic, self).__init__() 106 | self.seed = torch.manual_seed(seed) 107 | self.fc1 = nn.Linear(state_size+action_size, hidden_size) 108 | self.fc2 = nn.Linear(hidden_size, hidden_size) 109 | self.fc3 = nn.Linear(hidden_size, 1) 110 | self.reset_parameters() 111 | 112 | def reset_parameters(self): 113 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 114 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 115 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 116 | 117 | def forward(self, state, action): 118 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 119 | x = torch.cat((state, action), dim=1) 120 | x = F.relu(self.fc1(x)) 121 | x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | class Agent(): 125 | """Interacts with and learns from the environment.""" 126 | 127 | def __init__(self, state_size, action_size, random_seed, hidden_size, action_prior="uniform"): 128 | """Initialize an Agent object. 
129 | 130 | Params 131 | ====== 132 | state_size (int): dimension of each state 133 | action_size (int): dimension of each action 134 | random_seed (int): random seed 135 | """ 136 | self.state_size = state_size 137 | self.action_size = action_size 138 | self.seed = random.seed(random_seed) 139 | 140 | self.target_entropy = -action_size # -dim(A) 141 | self.alpha = 1 142 | self.log_alpha = torch.tensor([0.0], requires_grad=True) 143 | self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=LR_ACTOR) 144 | self._action_prior = action_prior 145 | 146 | # print("Using: ", device) 147 | 148 | # Actor Network 149 | self.actor_local = Actor(state_size, action_size, random_seed, hidden_size).to(device) 150 | self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR) 151 | 152 | # Critic Network (w/ Target Network) 153 | self.critic1 = Critic(state_size, action_size, random_seed, hidden_size).to(device) 154 | self.critic2 = Critic(state_size, action_size, random_seed, hidden_size).to(device) 155 | 156 | self.critic1_target = Critic(state_size, action_size, random_seed,hidden_size).to(device) 157 | self.critic1_target.load_state_dict(self.critic1.state_dict()) 158 | 159 | self.critic2_target = Critic(state_size, action_size, random_seed,hidden_size).to(device) 160 | self.critic2_target.load_state_dict(self.critic2.state_dict()) 161 | 162 | self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=LR_CRITIC, weight_decay=0) 163 | self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=LR_CRITIC, weight_decay=0) 164 | 165 | # Replay memory 166 | self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, random_seed) 167 | 168 | 169 | def step(self, state, action, reward, next_state, done, step): 170 | """Save experience in replay memory, and use random sample from buffer to learn.""" 171 | # Save experience / reward 172 | self.memory.add(state, action, reward, next_state, done) 173 | 174 | # Learn, if enough samples are available in memory 175 | if len(self.memory) > BATCH_SIZE: 176 | experiences = self.memory.sample() 177 | self.learn(step, experiences, GAMMA) 178 | 179 | 180 | def act(self, state): 181 | """Returns actions for given state as per current policy.""" 182 | state = torch.from_numpy(state).float().to(device) 183 | action = self.actor_local.get_action(state).detach() 184 | return action 185 | 186 | def learn(self, step, experiences, gamma, d=1): 187 | """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples. 
188 | Q_targets = r + γ * (min_critic_target(next_state, actor_target(next_state)) - α *log_pi(next_action|next_state)) 189 | Critic_loss = MSE(Q, Q_target) 190 | Actor_loss = α * log_pi(a|s) - Q(s,a) 191 | where: 192 | actor_target(state) -> action 193 | critic_target(state, action) -> Q-value 194 | Params 195 | ====== 196 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 197 | gamma (float): discount factor 198 | """ 199 | states, actions, rewards, next_states, dones = experiences 200 | 201 | 202 | # ---------------------------- update critic ---------------------------- # 203 | # Get predicted next-state actions and Q values from target models 204 | next_action, log_pis_next = self.actor_local.evaluate(next_states) 205 | 206 | Q_target1_next = self.critic1_target(next_states.to(device), next_action.squeeze(0).to(device)) 207 | Q_target2_next = self.critic2_target(next_states.to(device), next_action.squeeze(0).to(device)) 208 | 209 | # take the min of the two target critics (clipped double-Q) for the update target 210 | Q_target_next = torch.min(Q_target1_next, Q_target2_next) 211 | 212 | if FIXED_ALPHA is None: 213 | # Compute Q targets for current states (y_i) 214 | Q_targets = rewards.cpu() + (gamma * (1 - dones.cpu()) * (Q_target_next.cpu() - self.alpha * log_pis_next.squeeze(0).cpu())) 215 | else: 216 | Q_targets = rewards.cpu() + (gamma * (1 - dones.cpu()) * (Q_target_next.cpu() - FIXED_ALPHA * log_pis_next.squeeze(0).cpu())) 217 | 218 | # Compute critic loss 219 | Q_1 = self.critic1(states, actions).cpu() 220 | Q_2 = self.critic2(states, actions).cpu() 221 | 222 | critic1_loss = 0.5*F.mse_loss(Q_1, Q_targets.detach()) 223 | critic2_loss = 0.5*F.mse_loss(Q_2, Q_targets.detach()) 224 | # Update critics 225 | # critic 1 226 | self.critic1_optimizer.zero_grad() 227 | critic1_loss.backward() 228 | self.critic1_optimizer.step() 229 | # critic 2 230 | self.critic2_optimizer.zero_grad() 231 | critic2_loss.backward() 232 | self.critic2_optimizer.step() 233 | if step % d == 0: 234 | # ---------------------------- update actor ---------------------------- # 235 | if FIXED_ALPHA is None: 236 | alpha = torch.exp(self.log_alpha) 237 | # Compute alpha loss 238 | actions_pred, log_pis = self.actor_local.evaluate(states) 239 | alpha_loss = - (self.log_alpha.cpu() * (log_pis.cpu() + self.target_entropy).detach().cpu()).mean() 240 | self.alpha_optimizer.zero_grad() 241 | alpha_loss.backward() 242 | self.alpha_optimizer.step() 243 | 244 | self.alpha = alpha 245 | # Compute actor loss 246 | if self._action_prior == "normal": 247 | policy_prior = MultivariateNormal(loc=torch.zeros(self.action_size), scale_tril=torch.diag(torch.ones(self.action_size))) # fix: scale_tril must be an (A, A) lower-triangular matrix, not a (1, A) vector 248 | policy_prior_log_probs = policy_prior.log_prob(actions_pred) 249 | elif self._action_prior == "uniform": 250 | policy_prior_log_probs = 0.0 251 | 252 | actor_loss = (alpha * log_pis.squeeze(0).cpu() - self.critic1(states, actions_pred.squeeze(0)).cpu() - policy_prior_log_probs).mean() 253 | else: 254 | 255 | actions_pred, log_pis = self.actor_local.evaluate(states) 256 | if self._action_prior == "normal": 257 | policy_prior = MultivariateNormal(loc=torch.zeros(self.action_size), scale_tril=torch.diag(torch.ones(self.action_size))) # fix: see note above 258 | policy_prior_log_probs = policy_prior.log_prob(actions_pred) 259 | elif self._action_prior == "uniform": 260 | policy_prior_log_probs = 0.0 261 | 262 | actor_loss = (FIXED_ALPHA * log_pis.squeeze(0).cpu() - self.critic1(states, actions_pred.squeeze(0)).cpu() - policy_prior_log_probs).mean() 263 | # Minimize the loss 264 | 
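# Actor and temperature updates (and the Polyak target updates further below,
# with tau = TAU = 1e-2) run only every d-th call; with the default d=1 they
# happen on every learn() step, while the critics are updated on every call.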
self.actor_optimizer.zero_grad() 265 | actor_loss.backward() 266 | self.actor_optimizer.step() 267 | 268 | # ----------------------- update target networks ----------------------- # 269 | self.soft_update(self.critic1, self.critic1_target, TAU) 270 | self.soft_update(self.critic2, self.critic2_target, TAU) 271 | 272 | 273 | 274 | def soft_update(self, local_model, target_model, tau): 275 | """Soft update model parameters. 276 | θ_target = τ*θ_local + (1 - τ)*θ_target 277 | Params 278 | ====== 279 | local_model: PyTorch model (weights will be copied from) 280 | target_model: PyTorch model (weights will be copied to) 281 | tau (float): interpolation parameter 282 | """ 283 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 284 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data) 285 | 286 | class ReplayBuffer: 287 | """Fixed-size buffer to store experience tuples.""" 288 | 289 | def __init__(self, action_size, buffer_size, batch_size, seed): 290 | """Initialize a ReplayBuffer object. 291 | Params 292 | ====== 293 | buffer_size (int): maximum size of buffer 294 | batch_size (int): size of each training batch 295 | """ 296 | self.action_size = action_size 297 | self.memory = deque(maxlen=buffer_size) # internal memory (deque) 298 | self.batch_size = batch_size 299 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 300 | self.seed = random.seed(seed) 301 | 302 | def add(self, state, action, reward, next_state, done): 303 | """Add a new experience to memory.""" 304 | e = self.experience(state, action, reward, next_state, done) 305 | self.memory.append(e) 306 | 307 | def sample(self): 308 | """Randomly sample a batch of experiences from memory.""" 309 | experiences = random.sample(self.memory, k=self.batch_size) 310 | 311 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) 312 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(device) 313 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) 314 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 315 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) 316 | 317 | return (states, actions, rewards, next_states, dones) 318 | 319 | def __len__(self): 320 | """Return the current size of internal memory.""" 321 | return len(self.memory) --------------------------------------------------------------------------------
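For orientation, here is a minimal sketch of how the SAC pieces above interact outside the attack loop. It is illustrative only, not part of the original repository: the 100-dim state/action mirrors the default z_dim, and the reward is a placeholder (the real reward in attack.py scores G(z) with the target classifier).

import numpy as np
from SAC import Agent

agent = Agent(state_size=100, action_size=100, random_seed=42, hidden_size=256, action_prior="uniform")
state = np.random.randn(1, 100).astype(np.float32)   # a latent vector plays the role of the state
for t in range(3):
    action = agent.act(state)                        # squashed-Gaussian sample from the current policy
    next_state = action.numpy().reshape(1, 100)      # in RLB-MI (alpha=0) the next state is just the action
    reward = 0.0                                     # placeholder; attack.py derives it from the target model's confidences
    agent.step(state, action, reward, next_state, t == 2, t)   # stores the transition; learning starts once the buffer exceeds BATCH_SIZE
    state = next_state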