"""RSRLayer PyTorch.

Keywords: **RSRAE, RSR Layer, RSR Layers, RSR AutoEncoder**

PyTorch implementation of the Robust Subspace Recovery Layer for Unsupervised
Anomaly Detection: https://arxiv.org/abs/1904.00152

For more details, see my blog post:
https://zablo.net/blog/post/robust-space-recovery-rsrlayer-pytorch/

Repository layout::

    README.md
    RSR_Layers_PyTorch.ipynb
    requirements.txt
    rsr_layer_pytorch.py

Pinned requirements (requirements.txt)::

    plotly==4.4.1
    pytorch-lightning==1.2.8
    scikit-learn~=0.22.2
    torch~=1.8.1
    torchvision
    numpy==1.19.5
    pandas==1.1.5

The script part of this module runs at import time: it downloads MNIST,
trains an RSR autoencoder and builds images/plots of the results (the bare
expressions below are notebook-style displays).
"""
import torch
from torch import nn
import torch.nn.functional as F


class RSRLayer(nn.Module):
    """Robust Subspace Recovery layer: the linear projection ``z -> A z``.

    ``A`` is a learnable ``(d, D)`` matrix, initialised orthogonally, that
    maps ``D``-dimensional encoder outputs onto a ``d``-dimensional subspace.

    Args:
        d: Dimension of the recovered subspace (output size).
        D: Dimension of the encoder output (input size).
    """

    def __init__(self, d: int, D: int):
        super().__init__()
        self.d = d
        self.D = D
        self.A = nn.Parameter(torch.nn.init.orthogonal_(torch.empty(d, D)))

    def forward(self, z):
        """Project a batch of encodings onto the subspace.

        Args:
            z: Encoder output; reshaped to ``(batch, D, 1)``, so it must hold
                ``D`` values per sample.

        Returns:
            Tensor of shape ``(batch, d)``.
        """
        # Batched matmul: (d, D) @ (batch, D, 1) -> (batch, d, 1).
        z_hat = self.A @ z.view(z.size(0), self.D, 1)
        return z_hat.squeeze(2)


class RSRLoss(nn.Module):
    """RSR penalty ``lambda1 * L1 + lambda2 * L2`` from the RSRAE paper.

    * ``L1 = sum_i ||z_i - A^T A z_i||_2``: a sum of *unsquared* per-sample
      residual norms, which is what makes the subspace fit robust.
    * ``L2 = ||A A^T - I_d||_F^2``: pushes the rows of ``A`` to be orthonormal.

    Args:
        lambda1: Weight of the projection-residual term.
        lambda2: Weight of the orthogonality term.
        d: Subspace dimension (rows of ``A``).
        D: Encoder output dimension (columns of ``A``).
    """

    def __init__(self, lambda1, lambda2, d, D):
        super().__init__()
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.d = d
        self.D = D
        self.register_buffer("Id", torch.eye(d))

    def forward(self, z, A):
        """Compute the RSR loss.

        Args:
            z: Encoder output with ``D`` values per sample.
            A: The ``(d, D)`` projection matrix of the RSR layer.

        Returns:
            Scalar loss tensor.
        """
        z_flat = z.view(z.size(0), self.D)
        z_hat = A @ z_flat.view(z.size(0), self.D, 1)
        AtAz = (A.T @ z_hat).squeeze(2)
        # L2 norm per sample (dim=1), then summed over the batch. Without
        # ``dim`` torch.norm returns the Frobenius norm of the whole batch and
        # the surrounding sum becomes a no-op, losing the robust L1-of-norms.
        term1 = torch.sum(torch.norm(z_flat - AtAz, p=2, dim=1))

        # On a 2-D tensor with p=2 and no ``dim`` this is the Frobenius norm.
        term2 = torch.norm(A @ A.T - self.Id, p=2) ** 2

        return self.lambda1 * term1 + self.lambda2 * term2


class L2p_Loss(nn.Module):
    """Sum over samples of the ``p``-th power of per-sample L2 distances.

    Computes ``sum_i ||y_i - y_hat_i||_2 ** p``. The first dimension is the
    batch and the remaining dimensions of each sample are flattened into one
    vector; a 1-D input is treated as a single sample.

    Args:
        p: Exponent applied to each per-sample distance.
    """

    def __init__(self, p=1.0):
        super().__init__()
        self.p = p

    def forward(self, y_hat, y):
        """Return the scalar loss between prediction ``y_hat`` and target ``y``."""
        diff = y - y_hat
        if diff.dim() > 1:
            # One row per sample so the norm is taken per sample, not over
            # the whole batch.
            diff = diff.flatten(start_dim=1)
        per_sample = torch.norm(diff, p=2, dim=-1)
        return torch.sum(torch.pow(per_sample, self.p))


class RSRAutoEncoder(nn.Module):
    """Autoencoder with an RSR layer between encoder and decoder.

    Args:
        input_dim: Size of a flattened input sample.
        d: Dimension of the RSR subspace (decoder input size).
        D: Dimension of the encoder output.
    """

    def __init__(self, input_dim, d, D):
        super().__init__()
        # Put your encoder network here, remember about the output D-dimension
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(input_dim // 2, input_dim // 4),
            nn.LeakyReLU(),
            nn.Linear(input_dim // 4, D),
        )

        self.rsr = RSRLayer(d, D)

        # Put your decoder network here, remember about the input d-dimension
        self.decoder = nn.Sequential(
            nn.Linear(d, D),
            nn.LeakyReLU(),
            nn.Linear(D, input_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(input_dim // 2, input_dim),
        )

    def forward(self, x):
        """Run ``x`` through encoder, RSR layer and decoder.

        Returns:
            Tuple ``(enc, dec, latent, A)``: the encoder output, the raw
            output of the decoder's final Linear layer (callers in this file
            apply ``torch.sigmoid``), the RSR projection and the RSR matrix.
        """
        enc = self.encoder(x)  # obtain the embedding from the encoder
        latent = self.rsr(enc)  # RSR manifold
        dec = self.decoder(latent)  # obtain the representation in the input space
        return enc, dec, latent, self.rsr.A


import pytorch_lightning as pl

pl.seed_everything(666)


from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# https://github.com/pytorch/vision/issues/1938#issuecomment-594623431
from six.moves import urllib

# Install a global urllib opener that sends a browser-like User-agent header
# (workaround from the linked issue for downloading MNIST).
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)

mnist = MNIST(".", download=True, transform=ToTensor())


class RSRDs(torch.utils.data.Dataset):
    """MNIST subset: every image of ``target_class`` plus a few outliers.

    The first ``n_examples_per_other`` images of each class in
    ``other_classes`` are placed before all images of ``target_class``.
    Items are the ``(image, label)`` pairs returned by torchvision's MNIST.

    Args:
        target_class: The "normal" digit; all of its images are used.
        other_classes: Digits used as outliers.
        n_examples_per_other: Number of images taken from each outlier class.
    """

    def __init__(self, target_class, other_classes, n_examples_per_other):
        super().__init__()
        self.mnist = MNIST(".", download=True, transform=ToTensor())
        self.target_indices = (self.mnist.targets == target_class).nonzero().flatten()

        other = [
            (self.mnist.targets == other_class).nonzero().flatten()[:n_examples_per_other]
            for other_class in other_classes
        ]
        # torch.cat keeps the integer index dtype; the explicit empty long
        # tensor covers an empty ``other_classes`` (torch.tensor([]) would be
        # float and break indexing).
        self.other_indices = (
            torch.cat(other) if other else torch.empty(0, dtype=torch.long)
        )
        self.all_indices = torch.cat([self.other_indices, self.target_indices])
        print(f"Targets: {self.target_indices.size(0)}")
        print(f"Others : {self.other_indices.size(0)}")

    def __getitem__(self, idx):
        actual_idx = self.all_indices[idx].item()
        return self.mnist[actual_idx]

    def __len__(self):
        return self.all_indices.size(0)


ds = RSRDs(target_class=4, other_classes=(0, 1, 2, 8), n_examples_per_other=100)


class RSRAE(pl.LightningModule):
    """Lightning module that trains an :class:`RSRAutoEncoder`.

    Args:
        hparams: Mapping with keys ``input_dim``, ``d``, ``D``, ``lambda1``,
            ``lambda2``, ``lr``, ``epochs`` and ``steps_per_epoch``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Exposes the dict as ``self.hparams`` (attribute access); assigning
        # ``self.hparams`` directly is not allowed by later Lightning releases.
        self.save_hyperparameters(hparams)
        self.ae = RSRAutoEncoder(self.hparams.input_dim, self.hparams.d, self.hparams.D)
        self.reconstruction_loss = L2p_Loss(p=1.0)
        self.rsr_loss = RSRLoss(
            self.hparams.lambda1, self.hparams.lambda2, self.hparams.d, self.hparams.D
        )

    def forward(self, x):
        return self.ae(x)

    def training_step(self, batch, batch_idx):
        X, _ = batch
        x = X.view(X.size(0), -1)
        enc, dec, _, A = self.ae(x)

        rec_loss = self.reconstruction_loss(torch.sigmoid(dec), x)
        rsr_loss = self.rsr_loss(enc, A)
        loss = rec_loss + rsr_loss

        # log some useful stuff
        self.log(
            "reconstruction_loss",
            rec_loss.item(),
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )
        self.log(
            "rsr_loss", rsr_loss.item(), on_step=True, on_epoch=False, prog_bar=True
        )
        return {"loss": loss}

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        # Fast.AI's best practices :)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=self.hparams.lr,
            epochs=self.hparams.epochs,
            steps_per_epoch=self.hparams.steps_per_epoch,
        )
        # OneCycleLR is sized in optimizer steps, so step it every batch.
        return [opt], [{"scheduler": scheduler, "interval": "step"}]


from torch.utils.data import DataLoader

dl = DataLoader(ds, batch_size=64, shuffle=True, drop_last=True)

hparams = dict(
    d=16,
    D=128,
    input_dim=28 * 28,
    # Peak learning rate
    lr=0.01,
    # Configuration for the OneCycleLR scheduler
    epochs=150,
    steps_per_epoch=len(dl),
    # lambda coefficients from RSR Loss
    lambda1=1.0,
    lambda2=1.0,
)
model = RSRAE(hparams)
model  # notebook-style display of the module

# Train on the GPU when one is present; gpus=0 falls back to the CPU instead
# of failing on machines without CUDA.
trainer = pl.Trainer(
    max_epochs=model.hparams.epochs,
    gpus=1 if torch.cuda.is_available() else 0,
)

trainer.fit(model, dl)


model.freeze()
from torchvision.transforms import functional as tvf
from torchvision.utils import make_grid


def reconstruct(x, model):
    """Return a PIL image with the reconstruction of ``x`` next to ``x``.

    Args:
        x: A single image tensor that can be viewed as ``(1, 28, 28)``
            (one item of ``ds``).
        model: Trained model returning ``(enc, dec, latent, A)``.

    Returns:
        PIL image of a grid: the reconstruction first, then the input.
    """
    _, x_hat, _, _ = model(x.view(1, -1))
    # Match training: the decoder output goes through a sigmoid.
    x_hat = torch.sigmoid(x_hat)
    return tvf.to_pil_image(make_grid([x_hat.squeeze(0).view(1, 28, 28), x]))


tvf.to_pil_image(
    make_grid(
        [
            tvf.to_tensor(reconstruct(ds[i][0], model))
            for i in torch.randint(0, len(ds), (8,))
        ],
        nrow=1,
    )
)

import pandas as pd

rsr_embeddings = []
classes = []
errors = []
sample_loss = L2p_Loss()
# shuffle=False keeps rows in dataset order, so DataFrame index == ds index.
for X, cl in DataLoader(ds, batch_size=64, shuffle=False):
    x = X.view(X.size(0), -1)
    _, x_hat, latent, _ = model(x)
    rsr_embeddings.append(latent)
    classes.extend(cl.numpy())
    # Score each sample on its own: its reconstruction distance.
    for rec, orig in zip(torch.sigmoid(x_hat), x):
        errors.append(sample_loss(rec.unsqueeze(0), orig.unsqueeze(0)).item())

all_embs = torch.vstack(rsr_embeddings)
df = pd.DataFrame(
    all_embs.numpy(),
    columns=["x", "y", "z"] + [f"dim_{i}" for i in range(hparams["d"] - 3)],
)
df.loc[:, "class"] = classes
df.loc[:, "errors"] = errors

from sklearn.decomposition import PCA

pca = PCA(n_components=3)
rsr_3d = pd.DataFrame(pca.fit_transform(all_embs), columns=["x", "y", "z"])
rsr_3d.loc[:, "class"] = classes

import plotly.express as px
import plotly

# NOTE(review): this plots the first three raw RSR dimensions from ``df``;
# ``rsr_3d`` holds the PCA projection if that view is wanted instead.
fig = px.scatter_3d(
    df, x="x", y="y", z="z", symbol="class", color="class", opacity=0.95
)
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

df.errors.describe()

df.groupby("class").errors.hist(legend=True, bins=60, figsize=(12, 12))


lowest_mistakes = (
    df.sort_values(by="errors", ascending=True).head(60).loc[:, ["errors", "class"]]
)
highest_mistakes = (
    df.sort_values(by="errors", ascending=False).head(60).loc[:, ["errors", "class"]]
)
highest_mistakes.head(10)


print("Images with the highest reconstruction loss")
tvf.to_pil_image(
    make_grid(
        [tvf.to_tensor(reconstruct(ds[i][0], model)) for i in highest_mistakes.index],
        nrow=6,
        pad_value=0.5,
    )
)

print("Images with the lowest reconstruction loss")
tvf.to_pil_image(
    make_grid(
        [tvf.to_tensor(reconstruct(ds[i][0], model)) for i in lowest_mistakes.index],
        nrow=6,
        pad_value=0.5,
    )
)