├── README.md ├── models.py └── qcnn.py /README.md: -------------------------------------------------------------------------------- 1 | ## Quaternion CNN 2 | 3 | This repository contains code for [Rotation-invariant gait identification with quaternion convolutional neural networks](https://arxiv.org/abs/2008.07393), arXiv 2020, by B Jing, V Prabhu, A Gu, and J Whaley. 4 | 5 | Due to privacy considerations, we are not able to release the datasets or the training code used to train the models described in the paper. This repository contains only the implementations of the QCNN kernels and model architecture, which should be sufficient to use QCNN in other domains of interest. 6 | 7 | All quaternions should be represented as tensors of shape `(4, 1)`, with the real part at index 0. For example, a tensor of quaternions with logical shape `dims` should therefore have actual shape `(*dims, 4, 1)`. 8 | 9 | The quaternion kernels are implemented in `qcnn.py`. The main exports of interest are: 10 | * `qcnn.QConv1d` Quaternion convolutional kernel; accepts arguments `inchannels, outchannels, filterlen, stride=1` and tensors of shape `(batch, in channels, in time, 4, 1)`. 11 | * `qcnn.QBatchNorm1d` Quaternion batch norm; accepts arguments `*dims, momentum=0.1` and tensors of shape `(batch, *dims, time, 4, 1)`. 12 | * `qcnn.cuda()` Call this once before running the kernels on GPU. 13 | 14 | Some utility exports of interest are: 15 | * `qcnn.checkGrad()` Call this to check that the implementations of quaternion gradients are correct. 16 | * `qcnn.checkEquivariant()` Call this to check that the quaternion kernel is equivariant. 17 | * `qcnn.qconj(q)` Returns the conjugate of `q`. 18 | * `qcnn.qnormsq(q)` Returns the squared norm of `q`. 19 | * `qcnn.qnorm(q)` Returns the norm of `q`. 20 | * `qcnn.qinv(q)` Returns the inverse of `q`. 21 | * `qcnn.rotate(q, r)` Rotates quaternion `q` by rotation quaternion `r`.
22 | 23 | Example usages can be found in `models.py`, which defines a CNN and QCNN in parallel to illustrate the similar usages of the `nn.Conv1D`, `qcnn.QConv1d`, `nn.BatchNorm1D`, and `qcnn.QBatchNorm1d`. These were also the models used for the multi-user experiments in the paper. 24 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import qcnn 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class QCNN(nn.Module): 6 | def cuda(self): 7 | qcnn.cuda() 8 | return super(QCNN, self).cuda() 9 | 10 | def __init__(self, outclasses): 11 | super(QCNN, self).__init__() 12 | 13 | self.qnorm1 = qcnn.QBatchNorm1d(1) # New layer 14 | 15 | self.qconv1 = qcnn.QConv1d(1, 16, 5) # 16 * 96 16 | self.qnorm1 = qcnn.QBatchNorm1d(16) 17 | 18 | # time = 96 19 | self.qconv2 = qcnn.QConv1d(16, 32, 7, stride=2) 20 | self.qnorm2 = qcnn.QBatchNorm1d(32) 21 | 22 | # time = 45 23 | # invariant 24 | 25 | self.conv3 = nn.Conv1d(32, 32, 5) 26 | self.norm3 = nn.BatchNorm1d(32) 27 | 28 | # time = 41 29 | self.conv4 = nn.Conv1d(32, 32, 7) 30 | self.norm4 = nn.BatchNorm1d(32) 31 | 32 | # Now 32 channels x 35 time 33 | 34 | self.dense5 = nn.Linear(32*35, 768) 35 | self.dense6 = nn.Linear(768, outclasses) 36 | 37 | 38 | def forward(self, x): 39 | x = self.qconv1(x) 40 | x = self.qnorm1(x) 41 | 42 | x = self.qconv2(x) 43 | x = self.qnorm2(x) 44 | 45 | x = qcnn.qnorm(x) 46 | 47 | x = F.relu(self.conv3(x)) 48 | x = self.norm3(x) 49 | 50 | x = F.relu(self.conv4(x)) 51 | x = self.norm4(x) 52 | 53 | x = x.view(-1, 32*35) 54 | x = F.selu(self.dense5(x)) 55 | 56 | logits = self.dense6(x) 57 | return F.softmax(logits, dim=1) 58 | 59 | class CNN(nn.Module): 60 | 61 | def __init__(self, outclasses): 62 | super(CNN, self).__init__() 63 | 64 | self.conv1 = nn.Conv1d(3, 64, 5) # 16 * 96 65 | self.norm1 = nn.BatchNorm1d(64) 66 | 67 | # time = 96 68 | self.conv2 = 
nn.Conv1d(64, 32, 7, stride=2) 69 | self.norm2 = nn.BatchNorm1d(32) 70 | 71 | # time = 45 72 | # invariant 73 | 74 | self.conv3 = nn.Conv1d(32, 32, 5) 75 | self.norm3 = nn.BatchNorm1d(32) 76 | 77 | # time = 41 78 | self.conv4 = nn.Conv1d(32, 32, 7) 79 | self.norm4 = nn.BatchNorm1d(32) 80 | 81 | # Now 32 channels x 35 time 82 | 83 | self.dense5 = nn.Linear(32*35, 768) 84 | self.dense6 = nn.Linear(768, outclasses) 85 | 86 | 87 | def forward(self, x): 88 | x = F.relu(self.conv1(x)) 89 | x = self.norm1(x) 90 | 91 | x = F.relu(self.conv2(x)) 92 | x = self.norm2(x) 93 | 94 | x = F.relu(self.conv3(x)) 95 | x = self.norm3(x) 96 | 97 | x = F.relu(self.conv4(x)) 98 | x = self.norm4(x) 99 | 100 | x = x.view(-1, 32*35) 101 | x = F.selu(self.dense5(x)) 102 | 103 | logits = self.dense6(x) 104 | return F.softmax(logits, dim=1) 105 | -------------------------------------------------------------------------------- /qcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import math 5 | 6 | 7 | Qmt = torch.Tensor([[ 8 | [1, 0, 0, 0], 9 | [0, 1, 0, 0], 10 | [0, 0, 1, 0], 11 | [0, 0, 0, 1] 12 | ],[ 13 | [0, -1, 0, 0], 14 | [1, 0, 0, 0], 15 | [0, 0, 0, -1], 16 | [0, 0, 1, 0] 17 | ],[ 18 | [0, 0, -1, 0], 19 | [0, 0, 0, 1], 20 | [1, 0, 0, 0], 21 | [0, -1, 0, 0] 22 | ],[ 23 | [0, 0, 0, -1], 24 | [0, 0, -1, 0], 25 | [0, 1, 0, 0], 26 | [1, 0, 0, 0] 27 | ]]).float() 28 | 29 | 30 | Qmt2 = torch.Tensor([[ 31 | [1, 0, 0, 0], 32 | [0, 1, 0, 0], 33 | [0, 0, 1, 0], 34 | [0, 0, 0, 1] 35 | ],[ 36 | [0, -1, 0, 0], 37 | [1, 0, 0, 0], 38 | [0, 0, 0, 1], 39 | [0, 0, -1, 0] 40 | ],[ 41 | [0, 0, -1, 0], 42 | [0, 0, 0, -1], 43 | [1, 0, 0, 0], 44 | [0, 1, 0, 0] 45 | ],[ 46 | [0, 0, 0, -1], 47 | [0, 0, 1, 0], 48 | [0, -1, 0, 0], 49 | [1, 0, 0, 0] 50 | ]]).float() 51 | 52 | Qmt3 = torch.Tensor([[ 53 | [1, 0, 0, 0], 54 | [0, -1, 0, 0], 55 | [0, 0, -1, 0], 56 | [0, 0, 0, -1] 57 | ],[ 58 | [0, 
1, 0, 0], 59 | [1, 0, 0, 0], 60 | [0, 0, 0, 1], 61 | [0, 0, -1, 0] 62 | ],[ 63 | [0, 0, 1, 0], 64 | [0, 0, 0, -1], 65 | [1, 0, 0, 0], 66 | [0, 1, 0, 0] 67 | ],[ 68 | [0, 0, 0, 1], 69 | [0, 0, 1, 0], 70 | [0, -1, 0, 0], 71 | [1, 0, 0, 0] 72 | ]]).float() 73 | 74 | transposer = torch.Tensor([ 75 | [1, 0, 0, 0], 76 | [0, -1, 0, 0], 77 | [0, 0, -1, 0], 78 | [0, 0, 0, -1] 79 | ]).float() 80 | 81 | def cuda(): 82 | global Qmt 83 | Qmt = Qmt.cuda() 84 | global Qmt2 85 | Qmt2 = Qmt.cuda() 86 | global Qmt3 87 | Qmt3 = Qmt3.cuda() 88 | global transposer 89 | transposer = transposer.cuda() 90 | 91 | def q2m(qs, M): 92 | ''' 93 | |qs|: dims (*, 4, 1) 94 | |m|: dims (4, 4, 4) 95 | ''' 96 | qs = qs.unsqueeze(-1) # dims (*, 4, 1, 1) 97 | qs = M * qs # dims (*, 4, 4, 4) 98 | return qs.sum(-3) # dims (*, 4, 4) 99 | 100 | def qconj(q): 101 | return transposer.matmul(q) 102 | 103 | def qnormsq(q): 104 | return (q**2).sum(-2, keepdim=True) # 30x speed 105 | 106 | def qinv(q): 107 | return qconj(q)/qnormsq(q) # 20x speed 108 | 109 | class QMultiply(Function): 110 | @staticmethod 111 | def forward(ctx, p, q): 112 | ctx.save_for_backward(p, q) 113 | return q2m(p, Qmt).matmul(q) 114 | 115 | @staticmethod 116 | def backward(ctx, grad_output): 117 | p, q = ctx.saved_tensors 118 | dout_dp = q2m(q, Qmt2) 119 | dout_dq = q2m(p, Qmt) 120 | grad_p = dout_dp.transpose(-1, -2).matmul(grad_output) 121 | grad_q = dout_dq.transpose(-1, -2).matmul(grad_output) 122 | return grad_p, grad_q 123 | 124 | 125 | class QMultiplyConjugate(Function): 126 | @staticmethod 127 | def forward(ctx, p, q): 128 | ctx.save_for_backward(p, q) 129 | return q2m(p, Qmt).matmul(qconj(q)) 130 | 131 | @staticmethod 132 | def backward(ctx, grad_output): 133 | p, q = ctx.saved_tensors 134 | dout_dp = q2m(q, Qmt2) 135 | dout_dq = q2m(p, Qmt3) 136 | grad_p = dout_dp.matmul(grad_output) 137 | grad_q = dout_dq.transpose(-1, -2).matmul(grad_output) 138 | return grad_p, grad_q 139 | 140 | class QConjugate(Function): 141 | 
@staticmethod 142 | def forward(ctx, p, q): # qpq*/||q||^2 143 | ctx.save_for_backward(p, q) 144 | return QMultiply.apply(q, QMultiplyConjugate.apply(p, q))/qnormsq(q) 145 | 146 | @staticmethod 147 | def backward(ctx, grad_output): 148 | p, q = ctx.saved_tensors 149 | qp = QMultiply.apply(q, p) 150 | qpq = QMultiplyConjugate.apply(qp, q) 151 | qnsq = qnormsq(q) 152 | grad_qnsq = (-grad_output * qpq / qnsq**2).sum(-2, keepdim=True) 153 | grad_qpq = grad_output / qnsq # correct 154 | dqpq_dqp = q2m(q, Qmt2) 155 | grad_pq = dqpq_dqp.matmul(grad_qpq) # correct 156 | grad_p = q2m(q, Qmt).transpose(-1, -2).matmul(grad_pq) 157 | 158 | dqpq_dq = q2m(qp, Qmt3) 159 | dpq_dq = q2m(p, Qmt2) 160 | grad_q = dpq_dq.transpose(-1, -2).matmul(grad_pq) + \ 161 | dqpq_dq.transpose(-1, -2).matmul(grad_qpq) + \ 162 | grad_qnsq * 2 * q 163 | return grad_p, grad_q 164 | 165 | def checkGrad(): 166 | torch.manual_seed(42) 167 | p = torch.rand(1, 4, 1, requires_grad=True).double() 168 | q = torch.rand(1, 4, 1, requires_grad=True).double() 169 | from torch.autograd import gradcheck 170 | global transposer 171 | transposer = transposer.double() 172 | test = gradcheck(QConjugate.apply, (p, q), eps=1e-6, atol=1e-4) 173 | print(test) 174 | 175 | 176 | def qnorm(q): # (*, 4, 1) 177 | return (q.squeeze(-1)**2).sum(-1).sqrt() 178 | 179 | def rotate(q, r): 180 | ''' 181 | |q|: dims (*, 4, 1) 182 | |r|: dims (4, 1) 183 | ''' 184 | return QMultiply.apply(r, QMultiplyConjugate.apply(q, r))/qnormsq(r) 185 | 186 | class QConv1d(nn.Module): 187 | def __init__(self, inchannels, outchannels, filterlen, stride=1): 188 | ''' 189 | |inchannels|: number of input quaternion channels 190 | |outchannels|: number of output quaternion channels 191 | |filterlen|: length of convolutional filter, recommended odd 192 | ''' 193 | super(QConv1d, self).__init__() 194 | self.filterlen = filterlen 195 | self.inchannels = inchannels 196 | self.outchannels = outchannels 197 | self.stride = stride 198 | 
self.register_buffer('eye', torch.tensor([[1],[0],[0],[0]]).float()) 199 | 200 | # sum a(q + b)(qp + c)q(qp + c)^-1 201 | dims = (1, outchannels, inchannels, 1, filterlen, 1, 1) 202 | he = math.sqrt(2 / filterlen / inchannels) 203 | self.a = nn.Parameter(torch.randn(*dims) * he) 204 | self.b = nn.Parameter(torch.randn(*dims) / 2) 205 | self.c = nn.Parameter(torch.randn(*dims) * math.sqrt(1.3780**2 - 1/4)) 206 | # self.beta = nn.Parameter(torch.randn(1, outchannels, 1, 1, 1)) 207 | 208 | def forward(self, x): 209 | ''' 210 | |x|: dims (batch, in channels, in time, 4, 1) 211 | returns: dims (batch, out channels, out time, 4, 1) 212 | ''' 213 | #import pdb; pdb.set_trace() 214 | q = x.unsqueeze(1) 215 | # q: (batch, 1, in channels, in time, 4, 1) 216 | 217 | qp = q.transpose(-3, 0)[self.filterlen//2:-(self.filterlen//2):self.stride].transpose(0, -3).unsqueeze(-3) 218 | # qp: (batch, 1, in channels, out time, 1, 4, 1) 219 | 220 | qpc = qp + self.c*self.eye # (4, 1) 221 | # self.c*self.eye: (1, outchannels, inchannels, 1, filterlen, 4, 1) 222 | # qpc: (batch, outchannels, inchannels, out time, filterlen, 4, 1) 223 | 224 | q = q.unfold(-3, self.filterlen, self.stride).transpose(-3, -1).transpose(-1,-2) 225 | # q: (batch, 1, inchannels, out time, filterlen, 4, 1) 226 | 227 | res = QMultiply.apply(qpc, QMultiplyConjugate.apply(q, qpc))/qnormsq(qpc) 228 | #res = QConjugate.apply(q, qpc) 229 | res = QMultiply.apply(q, res) + self.b * res 230 | res = res * self.a 231 | res = res.sum(-3).sum(2) 232 | # res: dims (batch, out channels, out time, 4, 1) 233 | return res 234 | 235 | class QBatchNorm1d(nn.Module): 236 | def __init__(self, *dims, momentum=0.1): 237 | super(QBatchNorm1d, self).__init__() 238 | 239 | self.register_buffer('mean', torch.ones(1, *dims, 1, 1, 1)) 240 | self.momentum = momentum 241 | 242 | def forward(self, x): 243 | ''' 244 | |x|: dims (batch, *dims, time, 4, 1) 245 | returns: dims (batch, *dims, time, 4, 1) 246 | ''' 247 | if self.training: 248 | 
self.mean = self.mean * (1-self.momentum) + self.momentum * qnormsq(x.detach()).mean(dim=0, keepdim=True).mean(dim=-3, keepdim=True).sqrt() 249 | 250 | return x / self.mean 251 | 252 | def checkEquivariant(): 253 | x = torch.randn(1,1,7,4,1) 254 | kernel = QConv1d(1, 1, 7) 255 | r = torch.randn(4,1) 256 | x_rot = rotate(x, r) 257 | fx_rot = kernel(x_rot) 258 | fx = kernel(x) 259 | rot_fx = rotate(fx, r) 260 | frac = (fx_rot*rot_fx).sum()/fx_rot.norm()/rot_fx.norm() 261 | print(frac) 262 | 263 | #checkEquivariant() 264 | #checkGrad() --------------------------------------------------------------------------------