# relative_position.py
import torch
import torch.nn as nn
from torch.nn import Parameter


class RelativePosition(nn.Module):
    """Learned relative positional embeddings (Shaw et al., 2018).

    Holds a table of ``2 * max_relative_position + 1`` embedding vectors,
    one per clipped relative distance in ``[-max, +max]``. ``forward``
    returns, for every (query, key) index pair, the embedding of their
    clipped relative distance.

    Args:
        num_units: embedding dimension of each relative-position vector.
        max_relative_position: distances beyond this are clipped to +/- max.
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # Fixed: original was missing the closing ')' here (SyntaxError).
        self.embeddings_table = Parameter(
            torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return relative-position embeddings of shape (length_q, length_k, num_units).

        Args:
            length_q: number of query positions.
            length_k: number of key positions.
        """
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # distance_mat[i, j] = j - i  (key position minus query position)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position)
        # Shift from [-max, max] to [0, 2*max] so it can index the table.
        final_mat = (distance_mat_clipped + self.max_relative_position).long()
        # Fixed: follow the parameter's device instead of hard-coded .cuda(),
        # which crashed on CPU-only hosts and ignored .to(device) moves.
        # (torch.LongTensor(tensor) was also deprecated construction.)
        final_mat = final_mat.to(self.embeddings_table.device)
        embeddings = self.embeddings_table[final_mat]

        return embeddings


# --- Usage example (from the enclosing multi-head attention module) ----------
# The original dump kept this as live module-level code referencing `self`,
# which is unbound here; it is preserved as documentation only.
#
# In __init__ (fixed: original passed a stray extra first argument `i` to the
# two-parameter RelativePosition constructor):
#
#   self.relative_position_k = RelativePosition(self.d_k, max_relative_position)
#   self.relative_position_v = RelativePosition(self.d_v, max_relative_position)
#
# In forward:
#
#   r_q = q.permute(2, 0, 1, 3).contiguous().view(len_q, sz_b * n_head, d_k)
#   r_k = self.relative_position_k(len_q, len_k)
#   attn_2 = torch.matmul(r_q, r_k.transpose(1, 2)).transpose(0, 1)
#   # NOTE(review): original used view(sz_b, self.n_head, len_k, len_k);
#   # the leading length should presumably be len_q — confirm against caller:
#   attn_2 = attn_2.contiguous().view(sz_b, self.n_head, len_q, len_k)
#
#   r_v = self.relative_position_v(len_q, len_v)
#   weight = attn.permute(2, 0, 1, 3).contiguous().view(len_q, sz_b * n_head, len_k)
#   weight = torch.matmul(weight, r_v)
#   weight = weight.transpose(0, 1).contiguous().view(sz_b, self.n_head, len_q, d_v)