├── README.md ├── src ├── main │ ├── resources │ │ └── logback.xml │ └── java │ │ └── tech │ │ └── dubs │ │ └── dl4j │ │ └── contrib │ │ └── attention │ │ ├── activations │ │ └── ActivationMaskedSoftmax.java │ │ ├── conf │ │ ├── RecurrentAttentionLayer.java │ │ ├── TimestepAttentionLayer.java │ │ └── SelfAttentionLayer.java │ │ └── nn │ │ ├── TimestepAttentionLayer.java │ │ ├── params │ │ ├── QueryAttentionParamInitializer.java │ │ ├── SelfAttentionParamInitializer.java │ │ └── RecurrentQueryAttentionParamInitializer.java │ │ ├── SelfAttentionLayer.java │ │ ├── RecurrentAttentionLayer.java │ │ └── AdditiveAttentionMechanism.java └── test │ └── java │ └── tech │ └── dubs │ └── dl4j │ └── contrib │ └── attention │ ├── Serialization.java │ └── GradientChecks.java ├── .gitignore └── pom.xml /README.md: -------------------------------------------------------------------------------- 1 | # Work in Progress: Attention Layers for DL4J -------------------------------------------------------------------------------- /src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | %msg%n 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Maven template 3 | target/ 4 | pom.xml.tag 5 | pom.xml.releaseBackup 6 | pom.xml.versionsBackup 7 | pom.xml.next 8 | release.properties 9 | dependency-reduced-pom.xml 10 | buildNumber.properties 11 | .mvn/timing.properties 12 | .mvn/wrapper/maven-wrapper.jar 13 | ### JetBrains template 14 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 15 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 16 | 17 | *.iml 18 | .idea 19 | 20 | # Gradle and Maven with auto-import 21 | # When using Gradle or Maven with auto-import, you should exclude module files, 22 | # since they will be recreated, and may cause churn. Uncomment if using 23 | # auto-import. 
24 | # .idea/modules.xml 25 | # .idea/*.iml 26 | # .idea/modules 27 | 28 | # CMake 29 | cmake-build-*/ 30 | 31 | # Mongo Explorer plugin 32 | .idea/**/mongoSettings.xml 33 | 34 | # File-based project format 35 | *.iws 36 | 37 | # IntelliJ 38 | out/ 39 | 40 | # mpeltonen/sbt-idea plugin 41 | .idea_modules/ 42 | 43 | # JIRA plugin 44 | atlassian-ide-plugin.xml 45 | 46 | # Cursive Clojure plugin 47 | .idea/replstate.xml 48 | 49 | # Crashlytics plugin (for Android Studio and IntelliJ) 50 | com_crashlytics_export_strings.xml 51 | crashlytics.properties 52 | crashlytics-build.properties 53 | fabric.properties 54 | 55 | # Editor-based Rest Client 56 | .idea/httpRequests 57 | ### Java template 58 | # Compiled class file 59 | *.class 60 | 61 | # Log file 62 | *.log 63 | 64 | # BlueJ files 65 | *.ctxt 66 | 67 | # Mobile Tools for Java (J2ME) 68 | .mtj.tmp/ 69 | 70 | # Package Files # 71 | *.jar 72 | *.war 73 | *.nar 74 | *.ear 75 | *.zip 76 | *.tar.gz 77 | *.rar 78 | 79 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 80 | hs_err_pid* 81 | 82 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/activations/ActivationMaskedSoftmax.java: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Copyright (c) 2015-2018 Skymind, Inc. 3 | * 4 | * This program and the accompanying materials are made available under the 5 | * terms of the Apache License, Version 2.0 which is available at 6 | * https://www.apache.org/licenses/LICENSE-2.0. 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 11 | * License for the specific language governing permissions and limitations 12 | * under the License. 
13 | * 14 | * SPDX-License-Identifier: Apache-2.0 15 | ******************************************************************************/ 16 | 17 | package tech.dubs.dl4j.contrib.attention.activations; 18 | 19 | import org.nd4j.linalg.api.ndarray.INDArray; 20 | import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax; 21 | import org.nd4j.linalg.factory.Nd4j; 22 | import org.nd4j.linalg.ops.transforms.Transforms; 23 | import org.nd4j.linalg.primitives.Pair; 24 | 25 | import java.util.Arrays; 26 | 27 | /** 28 | * f_i(x, m) = m_i*exp(x_i - shift) / sum_j m_j*exp(x_j - shift) 29 | * where shift = max_i(x_i), m = mask 30 | * 31 | * @author Paul Dubs 32 | */ 33 | public class ActivationMaskedSoftmax { 34 | 35 | public INDArray getActivation(INDArray in, INDArray mask) { 36 | if(mask == null){ 37 | return Nd4j.getExecutioner().execAndReturn(new OldSoftMax(in)); 38 | }else { 39 | assertShape(in, mask, null); 40 | 41 | final INDArray shift = in.max(-1); 42 | final INDArray exp = Transforms.exp(in.subiColumnVector(shift), false); 43 | 44 | final INDArray masked = exp.muli(mask); 45 | 46 | final INDArray sum = masked.sum(-1); 47 | masked.diviColumnVector(sum); 48 | return masked; 49 | } 50 | } 51 | 52 | public Pair<INDArray, INDArray> backprop(INDArray postSoftmax, INDArray mask, INDArray epsilon) { 53 | INDArray x = postSoftmax.mul(epsilon).sum(1); 54 | INDArray dLdz = postSoftmax.mul(epsilon.subColumnVector(x)); 55 | return new Pair<>(dLdz, null); 56 | } 57 | 58 | @Override 59 | public String toString() { 60 | return "maskedSoftmax"; 61 | } 62 | 63 | private void assertShape(INDArray in, INDArray mask, INDArray epsilon) { 64 | if (mask != null && !in.equalShapes(mask)) { 65 | throw new IllegalStateException("Shapes must be equal: in.shape{} = " + Arrays.toString(in.shape()) 66 | + ", mask.shape() = " + Arrays.toString(mask.shape())); 67 | } 68 | if (epsilon != null && !in.equalShapes(epsilon)) { 69 | throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape()) 70 | + ", epsilon.shape() = " + Arrays.toString(epsilon.shape())); 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 5 | <modelVersion>4.0.0</modelVersion> 6 | 7 | <groupId>tech.dubs.dl4j.contrib</groupId> 8 | <artifactId>attention</artifactId> 9 | <version>1.0-SNAPSHOT</version> 10 | 11 | <properties> 12 | <dl4j.version>1.0.0-beta2</dl4j.version> 13 | <logback.version>1.2.3</logback.version> 14 | <java.version>1.8</java.version> 15 | <!-- 2.4.3 --> 16 | </properties> 17 | 18 | <repositories> 19 | <repository> 20 | <id>snapshots-repo</id> 21 | <url>https://oss.sonatype.org/content/repositories/snapshots</url> 22 | <releases> 23 | <enabled>false</enabled> 24 | </releases> 25 | <snapshots> 26 | <enabled>true</enabled> 27 | <updatePolicy>daily</updatePolicy> 28 | </snapshots> 29 | </repository> 30 | </repositories> 31 | 32 | <dependencies> 33 | 34 | <dependency> 35 | <groupId>org.deeplearning4j</groupId> 36 | <artifactId>deeplearning4j-nn</artifactId> 37 | <version>${dl4j.version}</version> 38 | </dependency> 39 | 40 | 44 | <dependency> 45 | <groupId>org.nd4j</groupId> 46 | <artifactId>nd4j-native-platform</artifactId> 47 | <version>${dl4j.version}</version> 48 | <scope>test</scope> 49 | </dependency> 50 | 51 | 52 | <dependency> 53 | <groupId>ch.qos.logback</groupId> 54 | <artifactId>logback-classic</artifactId> 55 | <version>${logback.version}</version> 56 | <scope>test</scope> 57 | </dependency> 58 | 59 | <dependency> 60 | <groupId>junit</groupId> 61 | <artifactId>junit</artifactId> 62 | <version>4.12</version> 63 | <scope>test</scope> 64 | </dependency> 65 | </dependencies> 66 | 67 | 68 | <build> 69 | <plugins> 70 | <plugin> 71 | <groupId>org.apache.maven.plugins</groupId> 72 | <artifactId>maven-compiler-plugin</artifactId> 73 | <version>3.5.1</version> 74 | <configuration> 75 | <source>${java.version}</source> 76 | <target>${java.version}</target> 77 | </configuration> 78 | </plugin> 79 | </plugins> 80 | </build> 81 | </project> -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/conf/RecurrentAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.conf; 2 | 3 | import org.deeplearning4j.nn.api.Layer; 4 | import org.deeplearning4j.nn.api.ParamInitializer; 5 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 6
| import org.deeplearning4j.nn.conf.inputs.InputType; 7 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; 8 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; 9 | import org.deeplearning4j.optimize.api.TrainingListener; 10 | import org.nd4j.linalg.api.ndarray.INDArray; 11 | import tech.dubs.dl4j.contrib.attention.nn.params.RecurrentQueryAttentionParamInitializer; 12 | 13 | import java.util.Collection; 14 | import java.util.Map; 15 | 16 | /** 17 | * TODO: Memory Report, Configurable Activation for Attention 18 | * 19 | * @author Paul Dubs 20 | */ 21 | public class RecurrentAttentionLayer extends BaseRecurrentLayer { 22 | // No-Op Constructor for Deserialization 23 | public RecurrentAttentionLayer() { } 24 | 25 | private RecurrentAttentionLayer(Builder builder) { 26 | super(builder); 27 | } 28 | 29 | @Override 30 | public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, 31 | int layerIndex, INDArray layerParamsView, boolean initializeParams) { 32 | 33 | tech.dubs.dl4j.contrib.attention.nn.RecurrentAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.RecurrentAttentionLayer(conf); 34 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any 35 | layer.setIndex(layerIndex); //Integer index of the layer 36 | 37 | layer.setParamsViewArray(layerParamsView); 38 | 39 | Map paramTable = initializer().init(conf, layerParamsView, initializeParams); 40 | layer.setParamTable(paramTable); 41 | layer.setConf(conf); 42 | return layer; 43 | } 44 | 45 | @Override 46 | public ParamInitializer initializer() { 47 | return RecurrentQueryAttentionParamInitializer.getInstance(); 48 | } 49 | 50 | @Override 51 | public double getL1ByParam(String paramName) { 52 | if(initializer().isWeightParam(this, paramName)){ 53 | return l1; 54 | }else if(initializer().isBiasParam(this, paramName)){ 55 | return l1Bias; 56 | } 57 | 58 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 59 | } 60 | 61 | @Override 62 | public double getL2ByParam(String paramName) { 63 | if(initializer().isWeightParam(this, paramName)){ 64 | return l2; 65 | }else if(initializer().isBiasParam(this, paramName)){ 66 | return l2Bias; 67 | } 68 | 69 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 70 | } 71 | 72 | 73 | @Override 74 | public LayerMemoryReport getMemoryReport(InputType inputType) { 75 | return null; 76 | } 77 | 78 | 79 | public static class Builder extends BaseRecurrentLayer.Builder { 80 | 81 | @Override 82 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required. 
83 | public RecurrentAttentionLayer build() { 84 | return new RecurrentAttentionLayer(this); 85 | } 86 | 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/test/java/tech/dubs/dl4j/contrib/attention/Serialization.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention; 2 | 3 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration; 4 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 5 | import org.deeplearning4j.nn.conf.inputs.InputType; 6 | import org.deeplearning4j.nn.conf.layers.LSTM; 7 | import org.deeplearning4j.nn.conf.layers.OutputLayer; 8 | import org.deeplearning4j.nn.weights.WeightInit; 9 | import org.junit.Assert; 10 | import org.junit.Test; 11 | import org.nd4j.linalg.activations.Activation; 12 | import org.nd4j.linalg.learning.config.NoOp; 13 | import org.nd4j.linalg.lossfunctions.LossFunctions; 14 | import tech.dubs.dl4j.contrib.attention.conf.RecurrentAttentionLayer; 15 | import tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer; 16 | import tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer; 17 | 18 | public class Serialization { 19 | @Test 20 | public void testSelfAttentionSerialization(){ 21 | int nIn = 3; 22 | int nOut = 5; 23 | int layerSize = 8; 24 | 25 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 26 | .activation(Activation.TANH) 27 | .updater(new NoOp()) 28 | .weightInit(WeightInit.XAVIER) 29 | .list() 30 | .layer(new LSTM.Builder().nOut(layerSize).build()) 31 | .layer(new SelfAttentionLayer.Builder().build()) 32 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 33 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 34 | .setInputType(InputType.recurrent(nIn)) 35 | .build(); 36 | 37 | final String json = conf.toJson(); 38 | final String yaml = conf.toYaml(); 39 | 40 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); 41 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml); 42 | 43 | Assert.assertEquals(conf, fromJson); 44 | Assert.assertEquals(conf, fromYaml); 45 | } 46 | 47 | @Test 48 | public void testTimestepAttentionSerialization(){ 49 | int nIn = 3; 50 | int nOut = 5; 51 | int layerSize = 8; 52 | 53 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 54 | .activation(Activation.TANH) 55 | .updater(new NoOp()) 56 | .weightInit(WeightInit.XAVIER) 57 | .list() 58 | .layer(new LSTM.Builder().nOut(layerSize).build()) 59 | .layer(new TimestepAttentionLayer.Builder().build()) 60 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 61 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 62 | .setInputType(InputType.recurrent(nIn)) 63 | .build(); 64 | 65 | final String json = conf.toJson(); 66 | final String yaml = conf.toYaml(); 67 | 68 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); 69 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml); 70 | 71 | Assert.assertEquals(conf, fromJson); 72 | Assert.assertEquals(conf, fromYaml); 73 | } 74 | 75 | @Test 76 | public void testRecurrentAttentionSerialization(){ 77 | int nIn = 3; 78 | int nOut = 5; 79 | int layerSize = 8; 80 | 81 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 82 | .activation(Activation.TANH) 83 | .updater(new NoOp()) 84 | .weightInit(WeightInit.XAVIER) 85 | .list() 86 | .layer(new 
LSTM.Builder().nOut(layerSize).build()) 87 | .layer(new RecurrentAttentionLayer.Builder().build()) 88 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 89 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 90 | .setInputType(InputType.recurrent(nIn)) 91 | .build(); 92 | 93 | final String json = conf.toJson(); 94 | final String yaml = conf.toYaml(); 95 | 96 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); 97 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml); 98 | 99 | Assert.assertEquals(conf, fromJson); 100 | Assert.assertEquals(conf, fromYaml); 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/TimestepAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.nn; 2 | 3 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 4 | import org.deeplearning4j.nn.gradient.DefaultGradient; 5 | import org.deeplearning4j.nn.gradient.Gradient; 6 | import org.deeplearning4j.nn.layers.BaseLayer; 7 | import org.deeplearning4j.nn.workspace.ArrayType; 8 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; 9 | import org.nd4j.base.Preconditions; 10 | import org.nd4j.linalg.activations.IActivation; 11 | import org.nd4j.linalg.api.ndarray.INDArray; 12 | import org.nd4j.linalg.api.shape.Shape; 13 | import org.nd4j.linalg.primitives.Pair; 14 | import tech.dubs.dl4j.contrib.attention.activations.ActivationMaskedSoftmax; 15 | import tech.dubs.dl4j.contrib.attention.nn.params.QueryAttentionParamInitializer; 16 | 17 | /** 18 | * Timestep Attention Layer Implementation 19 | * 20 | * 21 | * TODO: 22 | * - Optionally keep attention weights around for inspection 23 | * - Handle Masking 24 | * 25 | * @author Paul Dubs 26 | */ 27 | public class TimestepAttentionLayer extends BaseLayer { 28 | private ActivationMaskedSoftmax softmax = new ActivationMaskedSoftmax(); 29 | 30 | public TimestepAttentionLayer(NeuralNetConfiguration conf) { 31 | super(conf); 32 | } 33 | 34 | @Override 35 | public boolean isPretrainLayer() { 36 | return false; 37 | } 38 | 39 | @Override 40 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { 41 | assertInputSet(false); 42 | Preconditions.checkState(input.rank() == 3, 43 | "3D input expected to RNN layer expected, got " + input.rank()); 44 | 45 | applyDropOutIfNecessary(training, workspaceMgr); 46 | 47 | INDArray W = getParamWithNoise(QueryAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr); 48 | INDArray Q = getParamWithNoise(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr); 49 | INDArray b = getParamWithNoise(QueryAttentionParamInitializer.BIAS_KEY, training, workspaceMgr); 50 | 51 | 52 | long nIn = layerConf().getNIn(); 53 | long nOut = layerConf().getNOut(); 54 | IActivation a = layerConf().getActivationFn(); 55 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0]; 56 | long tsLength = input.shape()[0] == nIn ? 
input.shape()[1] : input.shape()[2]; 57 | 58 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nIn*nOut, tsLength}, 'f'); 59 | 60 | if(input.shape()[0] != nIn) 61 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 62 | 63 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, training); 64 | final INDArray attention = attentionMechanism.query(input, input, input, maskArray); 65 | activations.assign(attention); 66 | 67 | return activations; 68 | } 69 | 70 | 71 | 72 | @Override 73 | public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { 74 | assertInputSet(true); 75 | if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon)) 76 | epsilon = epsilon.dup('f'); 77 | 78 | INDArray W = getParamWithNoise(QueryAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr); 79 | INDArray Q = getParamWithNoise(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr); 80 | INDArray b = getParamWithNoise(QueryAttentionParamInitializer.BIAS_KEY, true, workspaceMgr); 81 | 82 | INDArray Wg = gradientViews.get(QueryAttentionParamInitializer.WEIGHT_KEY); 83 | INDArray Qg = gradientViews.get(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY); 84 | INDArray bg = gradientViews.get(QueryAttentionParamInitializer.BIAS_KEY); 85 | gradientsFlattened.assign(0); 86 | 87 | applyDropOutIfNecessary(true, workspaceMgr); 88 | 89 | IActivation a = layerConf().getActivationFn(); 90 | 91 | long nIn = layerConf().getNIn(); 92 | if(input.shape()[0] != nIn) 93 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 94 | 95 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f'); 96 | 97 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, true); 98 | attentionMechanism 99 | .withGradientViews(Wg, Qg, bg, epsOut, epsOut, epsOut) 100 | .backprop(epsilon, input, input, input, maskArray); 101 | 102 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f'); 103 | 104 | weightNoiseParams.clear(); 105 | 106 | Gradient g = new DefaultGradient(gradientsFlattened); 107 | g.gradientForVariable().put(QueryAttentionParamInitializer.WEIGHT_KEY, Wg); 108 | g.gradientForVariable().put(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, Qg); 109 | g.gradientForVariable().put(QueryAttentionParamInitializer.BIAS_KEY, bg); 110 | 111 | epsOut = backpropDropOutIfPresent(epsOut); 112 | return new Pair<>(g, epsOut); 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/conf/TimestepAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.conf; 2 | 3 | import org.deeplearning4j.nn.api.Layer; 4 | import org.deeplearning4j.nn.api.ParamInitializer; 5 | import org.deeplearning4j.nn.conf.InputPreProcessor; 6 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 7 | import org.deeplearning4j.nn.conf.inputs.InputType; 8 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; 9 | import org.deeplearning4j.nn.conf.layers.InputTypeUtil; 10 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; 11 | import org.deeplearning4j.nn.conf.memory.MemoryReport; 12 | import org.deeplearning4j.optimize.api.TrainingListener; 13 | import 
org.nd4j.linalg.api.ndarray.INDArray; 14 | import tech.dubs.dl4j.contrib.attention.nn.params.QueryAttentionParamInitializer; 15 | 16 | import java.util.Collection; 17 | import java.util.Map; 18 | 19 | /** 20 | * 21 | * 22 | * @author Paul Dubs 23 | */ 24 | public class TimestepAttentionLayer extends BaseRecurrentLayer { 25 | // No-Op Constructor for Deserialization 26 | public TimestepAttentionLayer() { } 27 | 28 | private TimestepAttentionLayer(Builder builder) { 29 | super(builder); 30 | } 31 | 32 | @Override 33 | public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, 34 | int layerIndex, INDArray layerParamsView, boolean initializeParams) { 35 | 36 | tech.dubs.dl4j.contrib.attention.nn.TimestepAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.TimestepAttentionLayer(conf); 37 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any 38 | layer.setIndex(layerIndex); //Integer index of the layer 39 | 40 | layer.setParamsViewArray(layerParamsView); 41 | 42 | Map paramTable = initializer().init(conf, layerParamsView, initializeParams); 43 | layer.setParamTable(paramTable); 44 | layer.setConf(conf); 45 | return layer; 46 | } 47 | 48 | @Override 49 | public ParamInitializer initializer() { 50 | return QueryAttentionParamInitializer.getInstance(); 51 | } 52 | 53 | @Override 54 | public double getL1ByParam(String paramName) { 55 | if(initializer().isWeightParam(this, paramName)){ 56 | return l1; 57 | }else if(initializer().isBiasParam(this, paramName)){ 58 | return l1Bias; 59 | } 60 | 61 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 62 | } 63 | 64 | @Override 65 | public double getL2ByParam(String paramName) { 66 | if(initializer().isWeightParam(this, paramName)){ 67 | return l2; 68 | }else if(initializer().isBiasParam(this, paramName)){ 69 | return l2Bias; 70 | } 71 | 72 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 73 | } 74 | 75 | @Override 76 | public InputType getOutputType(int layerIndex, InputType inputType) { 77 | if (inputType == null || inputType.getType() != InputType.Type.RNN) { 78 | throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex 79 | + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: " 80 | + inputType); 81 | } 82 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; 83 | 84 | return InputType.recurrent(nIn * nOut, itr.getTimeSeriesLength()); 85 | } 86 | 87 | @Override 88 | public void setNIn(InputType inputType, boolean override) { 89 | if (inputType == null || inputType.getType() != InputType.Type.RNN) { 90 | throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + getLayerName() 91 | + "\"): expect RNN input type with size > 0. 
Got: " + inputType); 92 | } 93 | 94 | if (nIn <= 0 || override) { 95 | InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; 96 | this.nIn = r.getSize(); 97 | } 98 | } 99 | 100 | @Override 101 | public InputPreProcessor getPreProcessorForInputType(InputType inputType) { 102 | return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); 103 | } 104 | 105 | @Override 106 | public LayerMemoryReport getMemoryReport(InputType inputType) { 107 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; 108 | final long tsLength = itr.getTimeSeriesLength(); 109 | 110 | InputType outputType = getOutputType(-1, inputType); 111 | 112 | long numParams = initializer().numParams(this); 113 | int updaterStateSize = (int)getIUpdater().stateSize(numParams); 114 | 115 | int trainSizeFixed = 0; 116 | int trainSizeVariable = 0; 117 | if(getIDropout() != null){ 118 | //Assume we dup the input for dropout 119 | trainSizeVariable += inputType.arrayElementsPerExample(); 120 | } 121 | trainSizeVariable += outputType.arrayElementsPerExample() * tsLength; 122 | trainSizeVariable += itr.getSize() * outputType.arrayElementsPerExample(); 123 | 124 | return new LayerMemoryReport.Builder(layerName, TimestepAttentionLayer.class, inputType, outputType) 125 | .standardMemory(numParams, updaterStateSize) 126 | .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference 127 | .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in this layer 128 | .build(); 129 | } 130 | 131 | 132 | public static class Builder extends BaseRecurrentLayer.Builder<Builder> { 133 | 134 | @Override 135 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required. 136 | public TimestepAttentionLayer build() { 137 | return new TimestepAttentionLayer(this); 138 | } 139 | 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/conf/SelfAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.conf; 2 | 3 | import org.deeplearning4j.nn.api.Layer; 4 | import org.deeplearning4j.nn.api.ParamInitializer; 5 | import org.deeplearning4j.nn.conf.InputPreProcessor; 6 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 7 | import org.deeplearning4j.nn.conf.inputs.InputType; 8 | import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; 9 | import org.deeplearning4j.nn.conf.layers.InputTypeUtil; 10 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; 11 | import org.deeplearning4j.nn.conf.memory.MemoryReport; 12 | import org.deeplearning4j.optimize.api.TrainingListener; 13 | import org.nd4j.linalg.api.ndarray.INDArray; 14 | import tech.dubs.dl4j.contrib.attention.nn.params.SelfAttentionParamInitializer; 15 | 16 | import java.util.Collection; 17 | import java.util.Map; 18 | 19 | /** 20 | * nOut = Number of Attention Heads 21 | * 22 | * The self attention layer is the simplest of the attention layers. It takes a recurrent input and returns a fixed-size 23 | * dense output. For this it requires just the same parameters as a dense layer, since the learnable part of it 24 | * actually is, for the most part, a dense layer.
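 * <p>
 * A minimal configuration sketch, mirroring how this layer is wired up in the tests of this repository (the layer sizes are illustrative only, not prescriptive):
 * <pre>
 * MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
 *     .list()
 *     .layer(new LSTM.Builder().nOut(8).build())                    // recurrent activations to attend over
 *     .layer(new SelfAttentionLayer.Builder().nOut(2).build())      // 2 attention heads
 *     .layer(new OutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX)
 *             .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 *     .setInputType(InputType.recurrent(3))
 *     .build();
 * </pre>
 * The layer consumes the 3D recurrent activations of the layer below and emits a feed-forward activation of size nIn * nOut, i.e. one attended vector of size nIn per attention head.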
25 | * 26 | * @author Paul Dubs 27 | */ 28 | public class SelfAttentionLayer extends FeedForwardLayer { 29 | // No-Op Constructor for Deserialization 30 | public SelfAttentionLayer() { } 31 | 32 | private SelfAttentionLayer(Builder builder) { 33 | super(builder); 34 | } 35 | 36 | @Override 37 | public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, 38 | int layerIndex, INDArray layerParamsView, boolean initializeParams) { 39 | 40 | tech.dubs.dl4j.contrib.attention.nn.SelfAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.SelfAttentionLayer(conf); 41 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any 42 | layer.setIndex(layerIndex); //Integer index of the layer 43 | 44 | layer.setParamsViewArray(layerParamsView); 45 | 46 | Map paramTable = initializer().init(conf, layerParamsView, initializeParams); 47 | layer.setParamTable(paramTable); 48 | layer.setConf(conf); 49 | return layer; 50 | } 51 | 52 | @Override 53 | public ParamInitializer initializer() { 54 | return SelfAttentionParamInitializer.getInstance(); 55 | } 56 | 57 | 58 | @Override 59 | public double getL1ByParam(String paramName) { 60 | if(initializer().isWeightParam(this, paramName)){ 61 | return l1; 62 | }else if(initializer().isBiasParam(this, paramName)){ 63 | return l1Bias; 64 | } 65 | 66 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 67 | } 68 | 69 | @Override 70 | public double getL2ByParam(String paramName) { 71 | if(initializer().isWeightParam(this, paramName)){ 72 | return l2; 73 | }else if(initializer().isBiasParam(this, paramName)){ 74 | return l2Bias; 75 | } 76 | 77 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); 78 | } 79 | 80 | @Override 81 | public InputType getOutputType(int layerIndex, InputType inputType) { 82 | if (inputType == null || inputType.getType() != InputType.Type.RNN) { 83 | throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex 84 | + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: " 85 | + inputType); 86 | } 87 | 88 | return InputType.feedForward(nIn * nOut); 89 | } 90 | 91 | @Override 92 | public void setNIn(InputType inputType, boolean override) { 93 | if (inputType == null || inputType.getType() != InputType.Type.RNN) { 94 | throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + getLayerName() 95 | + "\"): expect RNN input type with size > 0. 
Got: " + inputType); 96 | } 97 | 98 | if (nIn <= 0 || override) { 99 | InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; 100 | this.nIn = r.getSize(); 101 | } 102 | } 103 | 104 | @Override 105 | public InputPreProcessor getPreProcessorForInputType(InputType inputType) { 106 | return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); 107 | } 108 | 109 | @Override 110 | public LayerMemoryReport getMemoryReport(InputType inputType) { 111 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; 112 | final long tsLength = itr.getTimeSeriesLength(); 113 | 114 | InputType outputType = getOutputType(-1, inputType); 115 | 116 | long numParams = initializer().numParams(this); 117 | int updaterStateSize = (int)getIUpdater().stateSize(numParams); 118 | 119 | int trainSizeFixed = 0; 120 | int trainSizeVariable = 0; 121 | if(getIDropout() != null){ 122 | //Assume we dup the input for dropout 123 | trainSizeVariable += inputType.arrayElementsPerExample(); 124 | } 125 | trainSizeVariable += outputType.arrayElementsPerExample() * tsLength; 126 | trainSizeVariable += itr.getSize() * outputType.arrayElementsPerExample(); 127 | 128 | return new LayerMemoryReport.Builder(layerName, SelfAttentionLayer.class, inputType, outputType) 129 | .standardMemory(numParams, updaterStateSize) 130 | .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference 131 | .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer 132 | .build(); 133 | } 134 | 135 | 136 | public static class Builder extends FeedForwardLayer.Builder { 137 | @Override 138 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required. 139 | public SelfAttentionLayer build() { 140 | return new SelfAttentionLayer(this); 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/QueryAttentionParamInitializer.java: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Copyright (c) 2015-2018 Skymind, Inc. 3 | * 4 | * This program and the accompanying materials are made available under the 5 | * terms of the Apache License, Version 2.0 which is available at 6 | * https://www.apache.org/licenses/LICENSE-2.0. 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 11 | * License for the specific language governing permissions and limitations 12 | * under the License. 
13 | * 14 | * SPDX-License-Identifier: Apache-2.0 15 | ******************************************************************************/ 16 | 17 | package tech.dubs.dl4j.contrib.attention.nn.params; 18 | 19 | import org.deeplearning4j.nn.api.ParamInitializer; 20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 21 | import org.deeplearning4j.nn.conf.distribution.Distributions; 22 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; 23 | import org.deeplearning4j.nn.conf.layers.Layer; 24 | import org.deeplearning4j.nn.params.DefaultParamInitializer; 25 | import org.deeplearning4j.nn.weights.WeightInit; 26 | import org.deeplearning4j.nn.weights.WeightInitUtil; 27 | import org.nd4j.linalg.api.ndarray.INDArray; 28 | import org.nd4j.linalg.api.rng.distribution.Distribution; 29 | 30 | import java.util.*; 31 | 32 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval; 33 | import static org.nd4j.linalg.indexing.NDArrayIndex.point; 34 | 35 | /** 36 | * @author Paul Dubs 37 | */ 38 | public class QueryAttentionParamInitializer implements ParamInitializer { 39 | 40 | private static final QueryAttentionParamInitializer INSTANCE = new QueryAttentionParamInitializer(); 41 | 42 | public static QueryAttentionParamInitializer getInstance(){ 43 | return INSTANCE; 44 | } 45 | 46 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY; 47 | public static final String QUERY_WEIGHT_KEY = "Q"; 48 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; 49 | 50 | private static final List PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY)); 51 | private static final List WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY)); 52 | private static final List BIAS_KEYS = Collections.singletonList(BIAS_KEY); 53 | 54 | 55 | @Override 56 | public long numParams(NeuralNetConfiguration conf) { 57 | return numParams(conf.getLayer()); 58 | } 59 | 60 | @Override 61 | public long numParams(Layer layer) { 62 | BaseRecurrentLayer c = (BaseRecurrentLayer) layer; 63 | final long nIn = c.getNIn(); 64 | final long nOut = c.getNOut(); 65 | 66 | final long paramsW = nIn * nOut; 67 | final long paramsWq = nIn * nOut; 68 | final long paramsB = nOut; 69 | return paramsW + paramsWq + paramsB; 70 | } 71 | 72 | @Override 73 | public List paramKeys(Layer layer) { 74 | return PARAM_KEYS; 75 | } 76 | 77 | @Override 78 | public List weightKeys(Layer layer) { 79 | return WEIGHT_KEYS; 80 | } 81 | 82 | @Override 83 | public List biasKeys(Layer layer) { 84 | return BIAS_KEYS; 85 | } 86 | 87 | @Override 88 | public boolean isWeightParam(Layer layer, String key) { 89 | return WEIGHT_KEYS.contains(key); 90 | } 91 | 92 | @Override 93 | public boolean isBiasParam(Layer layer, String key) { 94 | return BIAS_KEYS.contains(key); 95 | } 96 | 97 | @Override 98 | public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { 99 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer(); 100 | final long nIn = c.getNIn(); 101 | final long nOut = c.getNOut(); 102 | 103 | Map m; 104 | 105 | if (initializeParams) { 106 | Distribution dist = Distributions.createDistribution(c.getDist()); 107 | 108 | m = getSubsets(paramsView, nIn, nOut, false); 109 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY)); 110 | m.put(WEIGHT_KEY, w); 111 | 112 | WeightInit rqInit; 113 | Distribution rqDist = dist; 114 | if 
(c.getWeightInitRecurrent() != null) { 115 | rqInit = c.getWeightInitRecurrent(); 116 | if(c.getDistRecurrent() != null) { 117 | rqDist = Distributions.createDistribution(c.getDistRecurrent()); 118 | } 119 | } else { 120 | rqInit = c.getWeightInit(); 121 | } 122 | 123 | INDArray rq = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, rqInit, rqDist, 'f', m.get(QUERY_WEIGHT_KEY)); 124 | m.put(QUERY_WEIGHT_KEY, rq); 125 | } else { 126 | m = getSubsets(paramsView, nIn, nOut, true); 127 | } 128 | 129 | conf.addVariable(WEIGHT_KEY); 130 | conf.addVariable(QUERY_WEIGHT_KEY); 131 | conf.addVariable(BIAS_KEY); 132 | 133 | return m; 134 | } 135 | 136 | @Override 137 | public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { 138 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer(); 139 | final long nIn = c.getNIn(); 140 | final long nOut = c.getNOut(); 141 | 142 | return getSubsets(gradientView, nIn, nOut, true); 143 | } 144 | 145 | private static Map getSubsets(INDArray in, long nIn, long nOut, boolean reshape){ 146 | long pos = nIn * nOut; 147 | INDArray w = in.get(point(0), interval(0, pos)); 148 | INDArray rq = in.get(point(0), interval(pos, pos + nIn * nOut)); 149 | pos += nIn * nOut; 150 | INDArray b = in.get(point(0), interval(pos, pos + nOut)); 151 | 152 | if(reshape){ 153 | w = w.reshape('f', nIn, nOut); 154 | rq = rq.reshape('f', nIn, nOut); 155 | } 156 | 157 | Map m = new LinkedHashMap<>(); 158 | m.put(WEIGHT_KEY, w); 159 | m.put(QUERY_WEIGHT_KEY, rq); 160 | m.put(BIAS_KEY, b); 161 | return m; 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/SelfAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.nn; 2 | 3 | import org.deeplearning4j.nn.api.MaskState; 4 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 5 | import org.deeplearning4j.nn.gradient.DefaultGradient; 6 | import org.deeplearning4j.nn.gradient.Gradient; 7 | import org.deeplearning4j.nn.layers.BaseLayer; 8 | import org.deeplearning4j.nn.workspace.ArrayType; 9 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; 10 | import org.nd4j.base.Preconditions; 11 | import org.nd4j.linalg.activations.IActivation; 12 | import org.nd4j.linalg.activations.impl.ActivationSoftmax; 13 | import org.nd4j.linalg.api.ndarray.INDArray; 14 | import org.nd4j.linalg.primitives.Pair; 15 | import tech.dubs.dl4j.contrib.attention.nn.params.SelfAttentionParamInitializer; 16 | 17 | /** 18 | * Self Attention Layer Implementation 19 | * 20 | * The implementation of mmul across time isn't the most efficient thing possible in nd4j, since the reshapes require 21 | * a copy, but it is the easiest to follow for now. 
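 * <p>
 * As a rough sketch (inferred from the parameter keys and the calls into AdditiveAttentionMechanism, so the exact arrangement of the factors is an assumption, not a verbatim description of that class), each of the nOut attention heads computes, per example:
 * <pre>
 * e_t = activationFn( W * x_t + Q * q + b )   // additive attention score for timestep t
 * a   = softmax(e_1 ... e_T)                  // attention weights over all timesteps
 * out = sum_t a_t * x_t                       // attended context vector of size nIn
 * </pre>
 * where x_t are the input timesteps (used as both keys and values), q is the learned query parameter of this layer, and W, Q, b are the attention weight matrices and bias. Concatenating the heads gives the nIn * nOut feed-forward output declared in the layer configuration.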
22 | * 23 | * TODO: 24 | * - Optionally keep attention weights around for inspection 25 | * - Handle Masking 26 | * 27 | * @author Paul Dubs 28 | */ 29 | public class SelfAttentionLayer extends BaseLayer { 30 | private IActivation softmax = new ActivationSoftmax(); 31 | 32 | public SelfAttentionLayer(NeuralNetConfiguration conf) { 33 | super(conf); 34 | } 35 | 36 | @Override 37 | public boolean isPretrainLayer() { 38 | return false; 39 | } 40 | 41 | @Override 42 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { 43 | assertInputSet(false); 44 | Preconditions.checkState(input.rank() == 3, 45 | "3D input expected to RNN layer expected, got " + input.rank()); 46 | 47 | applyDropOutIfNecessary(training, workspaceMgr); 48 | 49 | INDArray W = getParamWithNoise(SelfAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr); 50 | INDArray Q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr); 51 | INDArray b = getParamWithNoise(SelfAttentionParamInitializer.BIAS_KEY, training, workspaceMgr); 52 | INDArray q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_KEY, training, workspaceMgr); 53 | 54 | long nOut = layerConf().getNOut(); 55 | long nIn = layerConf().getNIn(); 56 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0]; 57 | IActivation a = layerConf().getActivationFn(); 58 | 59 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nIn * nOut}, 'f'); 60 | 61 | if(input.shape()[0] != nIn) 62 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 63 | 64 | final INDArray queries = q.reshape(nIn, 1, 1).broadcast(nIn, 1, examples); 65 | 66 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, training); 67 | final INDArray attention = attentionMechanism.query(queries, input, input, maskArray); 68 | activations.assign(attention.reshape(activations.shape())); 69 | 70 | return activations; 71 | } 72 | 73 | 74 | 75 | @Override 76 | public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { 77 | assertInputSet(true); 78 | 79 | INDArray W = getParamWithNoise(SelfAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr); 80 | INDArray Q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr); 81 | INDArray b = getParamWithNoise(SelfAttentionParamInitializer.BIAS_KEY, true, workspaceMgr); 82 | INDArray q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_KEY, true, workspaceMgr); 83 | 84 | INDArray Wg = gradientViews.get(SelfAttentionParamInitializer.WEIGHT_KEY); 85 | INDArray Qg = gradientViews.get(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY); 86 | INDArray bg = gradientViews.get(SelfAttentionParamInitializer.BIAS_KEY); 87 | INDArray qg = gradientViews.get(SelfAttentionParamInitializer.QUERY_KEY); 88 | gradientsFlattened.assign(0); 89 | 90 | applyDropOutIfNecessary(true, workspaceMgr); 91 | 92 | long nIn = layerConf().getNIn(); 93 | long examples = input.shape()[0] == nIn ? 
input.shape()[2] : input.shape()[0]; 94 | IActivation a = layerConf().getActivationFn(); 95 | 96 | if(input.shape()[0] != nIn) 97 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 98 | 99 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f'); 100 | 101 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, true); 102 | 103 | final INDArray queries = q.reshape(nIn, 1, 1).broadcast(nIn, 1, examples); 104 | final INDArray queryG = workspaceMgr.create(ArrayType.BP_WORKING_MEM, queries.shape(), 'f'); 105 | 106 | attentionMechanism.withGradientViews(Wg, Qg, bg, epsOut, epsOut, queryG) 107 | .backprop(epsilon, queries, input, input, maskArray); 108 | 109 | qg.assign(queryG.sum(2).transposei()); 110 | 111 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f'); 112 | 113 | weightNoiseParams.clear(); 114 | 115 | Gradient g = new DefaultGradient(gradientsFlattened); 116 | g.gradientForVariable().put(SelfAttentionParamInitializer.WEIGHT_KEY, Wg); 117 | g.gradientForVariable().put(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, Qg); 118 | g.gradientForVariable().put(SelfAttentionParamInitializer.BIAS_KEY, bg); 119 | g.gradientForVariable().put(SelfAttentionParamInitializer.QUERY_KEY, qg); 120 | 121 | epsOut = backpropDropOutIfPresent(epsOut); 122 | return new Pair<>(g, epsOut); 123 | } 124 | 125 | @Override 126 | public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, 127 | int minibatchSize) { 128 | // no masking is possible after this point... i.e., masks have been taken into account 129 | // as part of the selfattention 130 | this.maskArray = maskArray; 131 | this.maskState = null; //Not used in global pooling - always applied 132 | 133 | return null; 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/SelfAttentionParamInitializer.java: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Copyright (c) 2015-2018 Skymind, Inc. 3 | * 4 | * This program and the accompanying materials are made available under the 5 | * terms of the Apache License, Version 2.0 which is available at 6 | * https://www.apache.org/licenses/LICENSE-2.0. 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 11 | * License for the specific language governing permissions and limitations 12 | * under the License. 
13 | * 14 | * SPDX-License-Identifier: Apache-2.0 15 | ******************************************************************************/ 16 | 17 | package tech.dubs.dl4j.contrib.attention.nn.params; 18 | 19 | import org.deeplearning4j.nn.api.ParamInitializer; 20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 21 | import org.deeplearning4j.nn.conf.distribution.Distributions; 22 | import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; 23 | import org.deeplearning4j.nn.conf.layers.Layer; 24 | import org.deeplearning4j.nn.params.DefaultParamInitializer; 25 | import org.deeplearning4j.nn.weights.WeightInitUtil; 26 | import org.nd4j.linalg.api.ndarray.INDArray; 27 | import org.nd4j.linalg.api.rng.distribution.Distribution; 28 | 29 | import java.util.*; 30 | 31 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval; 32 | import static org.nd4j.linalg.indexing.NDArrayIndex.point; 33 | 34 | /** 35 | * @author Paul Dubs 36 | */ 37 | public class SelfAttentionParamInitializer implements ParamInitializer { 38 | 39 | private static final SelfAttentionParamInitializer INSTANCE = new SelfAttentionParamInitializer(); 40 | 41 | public static SelfAttentionParamInitializer getInstance(){ 42 | return INSTANCE; 43 | } 44 | 45 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY; 46 | public static final String QUERY_WEIGHT_KEY = "Q"; 47 | public static final String QUERY_KEY = "q"; 48 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; 49 | 50 | private static final List PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY, QUERY_KEY)); 51 | private static final List WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY)); 52 | private static final List BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(BIAS_KEY, QUERY_KEY)); 53 | 54 | 55 | @Override 56 | public long numParams(NeuralNetConfiguration conf) { 57 | return numParams(conf.getLayer()); 58 | } 59 | 60 | @Override 61 | public long numParams(Layer layer) { 62 | FeedForwardLayer c = (FeedForwardLayer) layer; 63 | final long nIn = c.getNIn(); 64 | final long nOut = c.getNOut(); 65 | 66 | final long paramsW = nIn * nOut; 67 | final long paramsWq = nIn * nOut; 68 | final long paramsB = nOut; 69 | final long paramsQ = nIn; 70 | return paramsW + paramsWq + paramsB + paramsQ; 71 | } 72 | 73 | @Override 74 | public List paramKeys(Layer layer) { 75 | return PARAM_KEYS; 76 | } 77 | 78 | @Override 79 | public List weightKeys(Layer layer) { 80 | return WEIGHT_KEYS; 81 | } 82 | 83 | @Override 84 | public List biasKeys(Layer layer) { 85 | return BIAS_KEYS; 86 | } 87 | 88 | @Override 89 | public boolean isWeightParam(Layer layer, String key) { 90 | return WEIGHT_KEYS.contains(key); 91 | } 92 | 93 | @Override 94 | public boolean isBiasParam(Layer layer, String key) { 95 | return BIAS_KEYS.contains(key); 96 | } 97 | 98 | @Override 99 | public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { 100 | FeedForwardLayer c = (FeedForwardLayer) conf.getLayer(); 101 | final long nIn = c.getNIn(); 102 | final long nOut = c.getNOut(); 103 | 104 | Map m; 105 | 106 | if (initializeParams) { 107 | Distribution dist = Distributions.createDistribution(c.getDist()); 108 | 109 | m = getSubsets(paramsView, nIn, nOut, false); 110 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY)); 111 | m.put(WEIGHT_KEY, w); 112 | 113 | INDArray rq = 
WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist,'f', m.get(QUERY_WEIGHT_KEY)); 114 | m.put(QUERY_WEIGHT_KEY, rq); 115 | 116 | INDArray q = WeightInitUtil.initWeights(nIn, 1, new long[]{nIn, 1}, c.getWeightInit(), dist,'f', m.get(QUERY_KEY)); 117 | m.put(QUERY_KEY, q); 118 | } else { 119 | m = getSubsets(paramsView, nIn, nOut, true); 120 | } 121 | 122 | conf.addVariable(WEIGHT_KEY); 123 | conf.addVariable(QUERY_WEIGHT_KEY); 124 | conf.addVariable(BIAS_KEY); 125 | conf.addVariable(QUERY_KEY); 126 | 127 | return m; 128 | } 129 | 130 | @Override 131 | public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { 132 | FeedForwardLayer c = (FeedForwardLayer) conf.getLayer(); 133 | final long nIn = c.getNIn(); 134 | final long nOut = c.getNOut(); 135 | 136 | return getSubsets(gradientView, nIn, nOut, true); 137 | } 138 | 139 | private static Map getSubsets(INDArray in, long nIn, long nOut, boolean reshape){ 140 | final long endW = nIn * nOut; 141 | final long endWq = endW + nIn * nOut; 142 | final long endB = endWq + nOut; 143 | final long endQ = endB + nIn; 144 | 145 | INDArray w = in.get(point(0), interval(0, endW)); 146 | INDArray wq = in.get(point(0), interval(endW, endWq)); 147 | INDArray b = in.get(point(0), interval(endWq, endB)); 148 | INDArray q = in.get(point(0), interval(endB, endQ)); 149 | 150 | if (reshape) { 151 | w = w.reshape('f', nIn, nOut); 152 | wq = wq.reshape('f', nIn, nOut); 153 | b = b.reshape('f', 1, nOut); 154 | q = q.reshape('f', 1, nIn); 155 | } 156 | 157 | Map m = new LinkedHashMap<>(); 158 | m.put(WEIGHT_KEY, w); 159 | m.put(QUERY_WEIGHT_KEY, wq); 160 | m.put(BIAS_KEY, b); 161 | m.put(QUERY_KEY, q); 162 | return m; 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/RecurrentQueryAttentionParamInitializer.java: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Copyright (c) 2015-2018 Skymind, Inc. 3 | * 4 | * This program and the accompanying materials are made available under the 5 | * terms of the Apache License, Version 2.0 which is available at 6 | * https://www.apache.org/licenses/LICENSE-2.0. 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 11 | * License for the specific language governing permissions and limitations 12 | * under the License. 
13 | * 14 | * SPDX-License-Identifier: Apache-2.0 15 | ******************************************************************************/ 16 | 17 | package tech.dubs.dl4j.contrib.attention.nn.params; 18 | 19 | import org.deeplearning4j.nn.api.ParamInitializer; 20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 21 | import org.deeplearning4j.nn.conf.distribution.Distributions; 22 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; 23 | import org.deeplearning4j.nn.conf.layers.Layer; 24 | import org.deeplearning4j.nn.params.DefaultParamInitializer; 25 | import org.deeplearning4j.nn.weights.WeightInit; 26 | import org.deeplearning4j.nn.weights.WeightInitUtil; 27 | import org.nd4j.linalg.api.ndarray.INDArray; 28 | import org.nd4j.linalg.api.rng.distribution.Distribution; 29 | 30 | import java.util.*; 31 | 32 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval; 33 | import static org.nd4j.linalg.indexing.NDArrayIndex.point; 34 | 35 | /** 36 | * 37 | * TODO: Allow configuring different weight init distribution for every weight and bias type 38 | * 39 | * @author Paul Dubs 40 | */ 41 | public class RecurrentQueryAttentionParamInitializer implements ParamInitializer { 42 | 43 | private static final RecurrentQueryAttentionParamInitializer INSTANCE = new RecurrentQueryAttentionParamInitializer(); 44 | 45 | public static RecurrentQueryAttentionParamInitializer getInstance() { 46 | return INSTANCE; 47 | } 48 | 49 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY; 50 | public static final String RECURRENT_WEIGHT_KEY = "WR"; 51 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; 52 | public static final String QUERY_WEIGHT_KEY = "WQ"; 53 | public static final String RECURRENT_QUERY_WEIGHT_KEY = "WQR"; 54 | public static final String QUERY_BIAS_KEY = "bQ"; 55 | 56 | private static final List PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY, RECURRENT_WEIGHT_KEY, RECURRENT_QUERY_WEIGHT_KEY, QUERY_BIAS_KEY)); 57 | private static final List WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, RECURRENT_QUERY_WEIGHT_KEY)); 58 | private static final List BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(BIAS_KEY, QUERY_BIAS_KEY)); 59 | 60 | 61 | @Override 62 | public long numParams(NeuralNetConfiguration conf) { 63 | return numParams(conf.getLayer()); 64 | } 65 | 66 | @Override 67 | public long numParams(Layer layer) { 68 | BaseRecurrentLayer c = (BaseRecurrentLayer) layer; 69 | final long nIn = c.getNIn(); 70 | final long nOut = c.getNOut(); 71 | 72 | final long paramsW = nIn * nOut; 73 | final long paramsWR = nIn * nOut; 74 | final long paramsWq = nIn; 75 | final long paramsWqR = nOut; 76 | final long paramsB = nOut; 77 | final long paramsBq = 1; 78 | return paramsW + paramsWR + paramsWq + paramsWqR + paramsB + paramsBq; 79 | } 80 | 81 | @Override 82 | public List paramKeys(Layer layer) { 83 | return PARAM_KEYS; 84 | } 85 | 86 | @Override 87 | public List weightKeys(Layer layer) { 88 | return WEIGHT_KEYS; 89 | } 90 | 91 | @Override 92 | public List biasKeys(Layer layer) { 93 | return BIAS_KEYS; 94 | } 95 | 96 | @Override 97 | public boolean isWeightParam(Layer layer, String key) { 98 | return WEIGHT_KEYS.contains(key); 99 | } 100 | 101 | @Override 102 | public boolean isBiasParam(Layer layer, String key) { 103 | return BIAS_KEYS.contains(key); 104 | } 105 | 106 | @Override 107 | public Map init(NeuralNetConfiguration conf, 
INDArray paramsView, boolean initializeParams) { 108 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer(); 109 | final long nIn = c.getNIn(); 110 | final long nOut = c.getNOut(); 111 | 112 | Map m; 113 | 114 | if (initializeParams) { 115 | Distribution dist = Distributions.createDistribution(c.getDist()); 116 | 117 | m = getSubsets(paramsView, nIn, nOut, false); 118 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY)); 119 | m.put(WEIGHT_KEY, w); 120 | INDArray wq = WeightInitUtil.initWeights(nIn, 1, new long[]{nIn, 1}, c.getWeightInit(), dist, 'f', m.get(QUERY_WEIGHT_KEY)); 121 | m.put(QUERY_WEIGHT_KEY, wq); 122 | 123 | 124 | WeightInit rwInit; 125 | Distribution rwDist = dist; 126 | if (c.getWeightInitRecurrent() == null) { 127 | rwInit = c.getWeightInit(); 128 | } else { 129 | rwInit = c.getWeightInitRecurrent(); 130 | if (c.getDistRecurrent() != null) { 131 | rwDist = Distributions.createDistribution(c.getDistRecurrent()); 132 | } 133 | } 134 | 135 | INDArray rw = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, rwInit, rwDist, 'f', m.get(RECURRENT_WEIGHT_KEY)); 136 | m.put(RECURRENT_WEIGHT_KEY, rw); 137 | INDArray wqr = WeightInitUtil.initWeights(nOut, 1, new long[]{nOut, 1}, rwInit, rwDist, 'f', m.get(RECURRENT_QUERY_WEIGHT_KEY)); 138 | m.put(RECURRENT_QUERY_WEIGHT_KEY, wqr); 139 | } else { 140 | m = getSubsets(paramsView, nIn, nOut, true); 141 | } 142 | 143 | for (String paramKey : PARAM_KEYS) { 144 | conf.addVariable(paramKey); 145 | } 146 | 147 | return m; 148 | } 149 | 150 | @Override 151 | public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { 152 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer(); 153 | final long nIn = c.getNIn(); 154 | final long nOut = c.getNOut(); 155 | 156 | return getSubsets(gradientView, nIn, nOut, true); 157 | } 158 | 159 | private static Map getSubsets(INDArray in, long nIn, long nOut, boolean reshape) { 160 | final long endW = nIn * nOut; 161 | final long endWq = endW + nIn; 162 | final long endWR = endWq + nIn * nOut; 163 | final long endWqR = endWR + nOut; 164 | final long endB = endWqR + nOut; 165 | final long endBq = endB + 1; 166 | 167 | INDArray w = in.get(point(0), interval(0, endW)); 168 | INDArray wq = in.get(point(0), interval(endW, endWq)); 169 | INDArray wr = in.get(point(0), interval(endWq, endWR)); 170 | INDArray wqr = in.get(point(0), interval(endWR, endWqR)); 171 | INDArray b = in.get(point(0), interval(endWqR, endB)); 172 | INDArray bq = in.get(point(0), interval(endB, endBq)); 173 | 174 | if (reshape) { 175 | w = w.reshape('f', nIn, nOut); 176 | wr = wr.reshape('f', nIn, nOut); 177 | wq = wq.reshape('f', nIn, 1); 178 | wqr = wqr.reshape('f', nOut, 1); 179 | b = b.reshape('f', 1, nOut); 180 | bq = bq.reshape('f', 1, 1); 181 | } 182 | 183 | Map m = new LinkedHashMap<>(); 184 | m.put(WEIGHT_KEY, w); 185 | m.put(QUERY_WEIGHT_KEY, wq); 186 | m.put(RECURRENT_WEIGHT_KEY, wr); 187 | m.put(RECURRENT_QUERY_WEIGHT_KEY, wqr); 188 | m.put(BIAS_KEY, b); 189 | m.put(QUERY_BIAS_KEY, bq); 190 | return m; 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /src/test/java/tech/dubs/dl4j/contrib/attention/GradientChecks.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention; 2 | 3 | import org.deeplearning4j.gradientcheck.GradientCheckUtil; 4 | import 
org.deeplearning4j.nn.conf.MultiLayerConfiguration; 5 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 6 | import org.deeplearning4j.nn.conf.inputs.InputType; 7 | import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; 8 | import org.deeplearning4j.nn.conf.layers.LSTM; 9 | import org.deeplearning4j.nn.conf.layers.OutputLayer; 10 | import org.deeplearning4j.nn.conf.layers.PoolingType; 11 | import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; 12 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; 13 | import org.deeplearning4j.nn.weights.WeightInit; 14 | import org.junit.Test; 15 | import org.nd4j.linalg.activations.Activation; 16 | import org.nd4j.linalg.api.buffer.DataBuffer; 17 | import org.nd4j.linalg.api.ndarray.INDArray; 18 | import org.nd4j.linalg.factory.Nd4j; 19 | import org.nd4j.linalg.learning.config.NoOp; 20 | import org.nd4j.linalg.lossfunctions.LossFunctions; 21 | import org.nd4j.linalg.profiler.OpProfiler; 22 | import tech.dubs.dl4j.contrib.attention.conf.RecurrentAttentionLayer; 23 | import tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer; 24 | import tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer; 25 | 26 | import java.util.Random; 27 | 28 | import static org.junit.Assert.assertTrue; 29 | 30 | public class GradientChecks { 31 | private static final boolean PRINT_RESULTS = false; 32 | private static final boolean RETURN_ON_FIRST_FAILURE = false; 33 | private static final double DEFAULT_EPS = 1e-6; 34 | private static final double DEFAULT_MAX_REL_ERROR = 1e-3; 35 | private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; 36 | 37 | static { 38 | Nd4j.setDataType(DataBuffer.Type.DOUBLE); 39 | } 40 | 41 | @Test 42 | public void testSelfAttentionLayer() { 43 | int nIn = 3; 44 | int nOut = 5; 45 | int tsLength = 4; 46 | int layerSize = 8; 47 | int attentionHeads = 2; 48 | 49 | Random r = new Random(12345); 50 | for (int mb : new int[]{1, 2, 3}) { 51 | for (boolean inputMask : new boolean[]{false, true}) { 52 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); 53 | INDArray labels = Nd4j.create(mb, nOut); 54 | for (int i = 0; i < mb; i++) { 55 | labels.putScalar(i, r.nextInt(nOut), 1.0); 56 | } 57 | String maskType = (inputMask ? 
"inputMask" : "none"); 58 | 59 | INDArray inMask = null; 60 | if (inputMask) { 61 | inMask = Nd4j.ones(mb, tsLength); 62 | for (int i = 0; i < mb; i++) { 63 | int firstMaskedStep = tsLength - 1 - i; 64 | if (firstMaskedStep == 0) { 65 | firstMaskedStep = tsLength; 66 | } 67 | for (int j = firstMaskedStep; j < tsLength; j++) { 68 | inMask.putScalar(i, j, 0.0); 69 | } 70 | } 71 | } 72 | 73 | String name = "testSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType; 74 | System.out.println("Starting test: " + name); 75 | 76 | 77 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 78 | .activation(Activation.TANH) 79 | .updater(new NoOp()) 80 | .weightInit(WeightInit.XAVIER) 81 | .list() 82 | .layer(new LSTM.Builder().nOut(layerSize).build()) 83 | .layer(new SelfAttentionLayer.Builder().nOut(attentionHeads).build()) 84 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 85 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 86 | .setInputType(InputType.recurrent(nIn)) 87 | .build(); 88 | 89 | MultiLayerNetwork net = new MultiLayerNetwork(conf); 90 | net.init(); 91 | 92 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, 93 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE,in, labels, inMask, null, false, -1, 94 | null //Sets.newHashSet( /*"1_b", "1_W",* "1_WR", "1_WQR", "1_WQ", "1_bQ",*/ "2_b", "2_W" ,"0_W", "0_RW", "0_b"/**/) 95 | ); 96 | assertTrue(name, gradOK); 97 | } 98 | } 99 | 100 | OpProfiler.getInstance().printOutDashboard(); 101 | } 102 | 103 | @Test 104 | public void testTimestepAttentionLayer() { 105 | int nIn = 3; 106 | int nOut = 5; 107 | int tsLength = 4; 108 | int layerSize = 8; 109 | int attentionHeads = 3; 110 | 111 | 112 | Random r = new Random(12345); 113 | for (int mb : new int[]{1, 2, 3}) { 114 | for (boolean inputMask : new boolean[]{false, true}) { 115 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); 116 | INDArray labels = Nd4j.create(mb, nOut); 117 | for (int i = 0; i < mb; i++) { 118 | labels.putScalar(i, r.nextInt(nOut), 1.0); 119 | } 120 | String maskType = (inputMask ? 
"inputMask" : "none"); 121 | 122 | INDArray inMask = null; 123 | if (inputMask) { 124 | inMask = Nd4j.ones(mb, tsLength); 125 | for (int i = 0; i < mb; i++) { 126 | int firstMaskedStep = tsLength - 1 - i; 127 | if (firstMaskedStep == 0) { 128 | firstMaskedStep = tsLength; 129 | } 130 | for (int j = firstMaskedStep; j < tsLength; j++) { 131 | inMask.putScalar(i, j, 0.0); 132 | } 133 | } 134 | } 135 | 136 | String name = "testTimestepAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType; 137 | System.out.println("Starting test: " + name); 138 | 139 | 140 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 141 | .activation(Activation.TANH) 142 | .updater(new NoOp()) 143 | .weightInit(WeightInit.XAVIER) 144 | .list() 145 | .layer(new LSTM.Builder().nOut(layerSize).build()) 146 | .layer(new TimestepAttentionLayer.Builder().nOut(attentionHeads).build()) 147 | .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) 148 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 149 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 150 | .setInputType(InputType.recurrent(nIn)) 151 | .build(); 152 | 153 | MultiLayerNetwork net = new MultiLayerNetwork(conf); 154 | net.init(); 155 | 156 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, 157 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null); 158 | assertTrue(name, gradOK); 159 | } 160 | } 161 | } 162 | 163 | @Test 164 | public void testRecurrentAttentionLayer() { 165 | int nIn = 3; 166 | int nOut = 5; 167 | int tsLength = 4; 168 | int layerSize = 8; 169 | int attentionHeads = 7; 170 | 171 | 172 | Random r = new Random(12345); 173 | for (int mb : new int[]{3, 2, 1}) { 174 | for (boolean inputMask : new boolean[]{true, false}) { 175 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); 176 | INDArray labels = Nd4j.create(mb, nOut); 177 | for (int i = 0; i < mb; i++) { 178 | labels.putScalar(i, r.nextInt(nOut), 1.0); 179 | } 180 | String maskType = (inputMask ? 
"inputMask" : "none"); 181 | 182 | INDArray inMask = null; 183 | if (inputMask) { 184 | inMask = Nd4j.ones(mb, tsLength); 185 | for (int i = 0; i < mb; i++) { 186 | int firstMaskedStep = tsLength - 1 - i; 187 | if (firstMaskedStep == 0) { 188 | firstMaskedStep = tsLength; 189 | } 190 | for (int j = firstMaskedStep; j < tsLength; j++) { 191 | inMask.putScalar(i, j, 0.0); 192 | } 193 | } 194 | } 195 | 196 | String name = "testRecurrentAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType; 197 | System.out.println("Starting test: " + name); 198 | 199 | 200 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 201 | .activation(Activation.TANH) 202 | .updater(new NoOp()) 203 | .weightInit(WeightInit.XAVIER) 204 | .list() 205 | .layer(new LSTM.Builder().nOut(layerSize).build()) 206 | .layer(new LastTimeStep(new RecurrentAttentionLayer.Builder().nOut(7).build())) 207 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 208 | .lossFunction(LossFunctions.LossFunction.MCXENT).build()) 209 | .setInputType(InputType.recurrent(nIn)) 210 | .build(); 211 | 212 | MultiLayerNetwork net = new MultiLayerNetwork(conf); 213 | net.init(); 214 | 215 | //System.out.println("Original"); 216 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, 217 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, false, -1, null 218 | //Sets.newHashSet( /*"1_b", "1_W",* "1_WR", "1_WQR", "1_WQ", "1_bQ",*/ "2_b", "2_W" ,"0_W", "0_RW", "0_b"/**/) 219 | ); 220 | assertTrue(name, gradOK); 221 | } 222 | } 223 | } 224 | } -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/RecurrentAttentionLayer.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.nn; 2 | 3 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 4 | import org.deeplearning4j.nn.gradient.DefaultGradient; 5 | import org.deeplearning4j.nn.gradient.Gradient; 6 | import org.deeplearning4j.nn.layers.BaseLayer; 7 | import org.deeplearning4j.nn.workspace.ArrayType; 8 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; 9 | import org.nd4j.base.Preconditions; 10 | import org.nd4j.linalg.activations.IActivation; 11 | import org.nd4j.linalg.activations.impl.ActivationSoftmax; 12 | import org.nd4j.linalg.api.ndarray.INDArray; 13 | import org.nd4j.linalg.api.shape.Shape; 14 | import org.nd4j.linalg.factory.Nd4j; 15 | import org.nd4j.linalg.primitives.Pair; 16 | import tech.dubs.dl4j.contrib.attention.nn.params.RecurrentQueryAttentionParamInitializer; 17 | 18 | import static org.nd4j.linalg.indexing.NDArrayIndex.all; 19 | import static org.nd4j.linalg.indexing.NDArrayIndex.point; 20 | 21 | /** 22 | * Recurrent Attention Layer Implementation 23 | * 24 | * 25 | * 26 | * TODO: 27 | * - Optionally keep attention weights around for inspection 28 | * - Handle Masking 29 | * 30 | * @author Paul Dubs 31 | */ 32 | public class RecurrentAttentionLayer extends BaseLayer { 33 | private IActivation softmax = new ActivationSoftmax(); 34 | 35 | public RecurrentAttentionLayer(NeuralNetConfiguration conf) { 36 | super(conf); 37 | } 38 | 39 | @Override 40 | public boolean isPretrainLayer() { 41 | return false; 42 | } 43 | 44 | @Override 45 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { 46 | assertInputSet(false); 47 | 
Preconditions.checkState(input.rank() == 3, 48 | "3D input expected to RNN layer, got rank " + input.rank()); 49 | 50 | applyDropOutIfNecessary(training, workspaceMgr); 51 | 52 | INDArray W = getParamWithNoise(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr); 53 | INDArray Wr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); 54 | INDArray Wq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr); 55 | INDArray Wqr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, training, workspaceMgr); 56 | INDArray b = getParamWithNoise(RecurrentQueryAttentionParamInitializer.BIAS_KEY, training, workspaceMgr); 57 | INDArray bq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, training, workspaceMgr); 58 | 59 | long examples = input.size(0); 60 | long tsLength = input.size(2); 61 | long nIn = layerConf().getNIn(); 62 | long nOut = layerConf().getNOut(); 63 | IActivation a = layerConf().getActivationFn(); 64 | 65 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nOut, tsLength}, 'f'); 66 | // bring the input into [nIn, tsLength, examples] order, as expected by AdditiveAttentionMechanism 67 | if(input.shape()[0] != nIn) 68 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 69 | 70 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Wqr, Wq, bq, a, workspaceMgr, training).useCaching(); 71 | 72 | 73 | // pre-compute non-recurrent part 74 | activations.assign( 75 | Nd4j.gemm(W, input.reshape('f', nIn, tsLength * examples), true, false) 76 | .addiColumnVector(b.transpose()) 77 | .reshape('f', nOut, tsLength, examples).permute(2, 0, 1) 78 | ); 79 | 80 | // each timestep attends over the whole input sequence, using the previous timestep's activation as the query 81 | for (long timestep = 0; timestep < tsLength; timestep++) { 82 | final INDArray curOut = timestepArray(activations, timestep); 83 | if(timestep > 0){ 84 | final INDArray prevActivation = timestepArray(activations, timestep - 1); 85 | final INDArray queries = Nd4j.expandDims(prevActivation, 2).permute(1,2,0); 86 | final INDArray attention = attentionMechanism.query(queries, input, input, maskArray); 87 | curOut.addi(Nd4j.squeeze(attention, 2).mmul(Wr)); 88 | } 89 | 90 | a.getActivation(curOut, true); 91 | } 92 | 93 | 94 | return activations; 95 | } 96 | 97 | 98 | 99 | /* 100 | * Notice that the epsilon given here does not contain the recurrent component, which will have to be calculated 101 | * manually. 
102 | */ 103 | @Override 104 | public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { 105 | assertInputSet(true); 106 | if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon)) 107 | epsilon = epsilon.dup('f'); 108 | 109 | INDArray W = getParamWithNoise(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr); 110 | INDArray Wr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, true, workspaceMgr); 111 | INDArray Wq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr); 112 | INDArray Wqr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, true, workspaceMgr); 113 | INDArray b = getParamWithNoise(RecurrentQueryAttentionParamInitializer.BIAS_KEY, true, workspaceMgr); 114 | INDArray bq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, true, workspaceMgr); 115 | 116 | INDArray Wg = gradientViews.get(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY); 117 | INDArray Wrg = gradientViews.get(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY); 118 | INDArray Wqg = gradientViews.get(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY); 119 | INDArray Wqrg = gradientViews.get(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY); 120 | INDArray bg = gradientViews.get(RecurrentQueryAttentionParamInitializer.BIAS_KEY); 121 | INDArray bqg = gradientViews.get(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY); 122 | gradientsFlattened.assign(0); 123 | 124 | applyDropOutIfNecessary(true, workspaceMgr); 125 | 126 | 127 | long nIn = layerConf().getNIn(); 128 | long nOut = layerConf().getNOut(); 129 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0]; 130 | long tsLength = input.shape()[0] == nIn ? 
input.shape()[1] : input.shape()[2]; 131 | IActivation a = layerConf().getActivationFn(); 132 | 133 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nOut, tsLength}, 'f'); 134 | INDArray preOut = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{examples, nOut, tsLength}, 'f'); 135 | INDArray attentions = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{examples, nIn, tsLength}, 'f'); 136 | INDArray queryG = workspaceMgr.create(ArrayType.BP_WORKING_MEM, new long[]{nOut, 1, examples}, 'f'); 137 | 138 | if(input.shape()[0] != nIn) 139 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f'); 140 | 141 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f'); 142 | epsOut.assign(0); 143 | 144 | 145 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Wqr, Wq, bq, a, workspaceMgr, true).useCaching(); 146 | 147 | // pre-compute non-recurrent part 148 | activations.assign( 149 | Nd4j.gemm(W, input.reshape('f', nIn, tsLength * examples), true, false) 150 | .addiColumnVector(b.transpose()) 151 | .reshape('f', nOut, tsLength, examples).permute(2, 0, 1) 152 | ); 153 | 154 | 155 | for (long timestep = 0; timestep < tsLength; timestep++) { 156 | final INDArray curOut = timestepArray(activations, timestep); 157 | 158 | if(timestep > 0){ 159 | final INDArray prevActivation = timestepArray(activations, timestep - 1); 160 | final INDArray query = Nd4j.expandDims(prevActivation, 2).permute(1, 2, 0); 161 | final INDArray attention = Nd4j.squeeze(attentionMechanism.query(query, input, input, maskArray), 2); 162 | timestepArray(attentions, timestep).assign(attention); 163 | 164 | curOut.addi(attention.mmul(Wr)); 165 | } 166 | timestepArray(preOut, timestep).assign(curOut); 167 | a.getActivation(curOut, true); 168 | } 169 | 170 | 171 | for (long timestep = tsLength - 1; timestep >= 0; timestep--) { 172 | final INDArray curEps = timestepArray(epsilon, timestep); 173 | final INDArray curPreOut = timestepArray(preOut, timestep); 174 | final INDArray curIn = input.get(all(), point(timestep), all()); 175 | 176 | final INDArray dldz = a.backprop(curPreOut, curEps).getFirst(); 177 | Wg.addi(Nd4j.gemm(curIn, dldz, false, false)); 178 | bg.addi(dldz.sum(0)); 179 | epsOut.tensorAlongDimension((int)timestep, 0, 2).addi(Nd4j.gemm(dldz, W, false, true).transposei()); 180 | 181 | if(timestep > 0){ 182 | final INDArray curAttn = timestepArray(attentions, timestep); 183 | 184 | Wrg.addi(Nd4j.gemm(curAttn, dldz, true, false)); 185 | 186 | final INDArray prevEps = timestepArray(epsilon, timestep - 1); 187 | final INDArray prevActivation = timestepArray(activations, timestep - 1); 188 | final INDArray query = Nd4j.expandDims(prevActivation, 2).permute(1,2,0); 189 | queryG.assign(0); 190 | 191 | final INDArray dldAtt = Nd4j.gemm(dldz, Wr, false, true); 192 | attentionMechanism 193 | .withGradientViews(Wqg, Wqrg, bqg, epsOut, epsOut, queryG) 194 | .backprop(dldAtt, query, input, input, maskArray); 195 | 196 | prevEps.addi(Nd4j.squeeze(queryG, 1).transpose()); 197 | } 198 | } 199 | 200 | weightNoiseParams.clear(); 201 | 202 | Gradient g = new DefaultGradient(gradientsFlattened); 203 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, Wg); 204 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, Wqg); 205 | 
g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, Wqrg); 206 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, Wrg); 207 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.BIAS_KEY, bg); 208 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, bqg); 209 | 210 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f'); 211 | epsOut = backpropDropOutIfPresent(epsOut); 212 | 213 | return new Pair<>(g, epsOut); 214 | } 215 | 216 | private INDArray subArray(INDArray in, int example, int timestep){ 217 | return in.tensorAlongDimension(example, 1, 2).tensorAlongDimension(timestep, 0); 218 | } 219 | 220 | private INDArray timestepArray(INDArray in, long timestep){ 221 | return in.tensorAlongDimension((int) timestep, 0, 1); 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /src/main/java/tech/dubs/dl4j/contrib/attention/nn/AdditiveAttentionMechanism.java: -------------------------------------------------------------------------------- 1 | package tech.dubs.dl4j.contrib.attention.nn; 2 | 3 | import org.deeplearning4j.nn.workspace.ArrayType; 4 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; 5 | import org.nd4j.linalg.activations.IActivation; 6 | import org.nd4j.linalg.api.memory.MemoryWorkspace; 7 | import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; 8 | import org.nd4j.linalg.api.memory.enums.AllocationPolicy; 9 | import org.nd4j.linalg.api.memory.enums.LearningPolicy; 10 | import org.nd4j.linalg.api.ndarray.INDArray; 11 | import org.nd4j.linalg.api.shape.Shape; 12 | import org.nd4j.linalg.factory.Nd4j; 13 | import tech.dubs.dl4j.contrib.attention.activations.ActivationMaskedSoftmax; 14 | 15 | import java.util.Arrays; 16 | 17 | import static org.nd4j.linalg.indexing.NDArrayIndex.all; 18 | import static org.nd4j.linalg.indexing.NDArrayIndex.point; 19 | 20 | /* 21 | * Attention: Shapes for keys, values and queries should be in [features, timesteps, examples] order! 
22 | * @author Paul Dubs 23 | */ 24 | public class AdditiveAttentionMechanism { 25 | private final INDArray W; 26 | private final INDArray Q; 27 | private final INDArray b; 28 | private final IActivation activation; 29 | private final ActivationMaskedSoftmax softmax; 30 | private final LayerWorkspaceMgr mgr; 31 | private final boolean training; 32 | private boolean caching; 33 | private INDArray WkCache; 34 | 35 | // Required to be set for backprop 36 | private INDArray Wg; 37 | private INDArray Qg; 38 | private INDArray bg; 39 | private INDArray keyG; 40 | private INDArray valueG; 41 | private INDArray queryG; 42 | 43 | public AdditiveAttentionMechanism(INDArray queryWeight, INDArray keyWeight, INDArray bias, IActivation activation, LayerWorkspaceMgr mgr, boolean training) { 44 | assertWeightShapes(queryWeight, keyWeight, bias); 45 | Q = queryWeight; 46 | W = keyWeight; 47 | b = bias; 48 | this.activation = activation; 49 | softmax = new ActivationMaskedSoftmax(); 50 | this.mgr = mgr; 51 | this.training = training; 52 | 53 | this.caching = false; 54 | } 55 | 56 | public AdditiveAttentionMechanism useCaching() { 57 | this.caching = true; 58 | return this; 59 | } 60 | 61 | public INDArray query(INDArray queries, INDArray keys, INDArray values, INDArray mask) { 62 | assertShapes(queries, keys, values); 63 | 64 | final long examples = queries.shape()[2]; 65 | final long queryCount = queries.shape()[1]; 66 | final long queryWidth = queries.shape()[0]; 67 | final long attentionHeads = W.shape()[1]; 68 | final long memoryWidth = W.shape()[0]; 69 | final long tsLength = keys.shape()[1]; 70 | 71 | final INDArray result = mgr.createUninitialized(ArrayType.FF_WORKING_MEM, new long[]{examples, memoryWidth * attentionHeads, queryCount}, 'f'); 72 | 73 | WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder() 74 | .policyAllocation(AllocationPolicy.STRICT) 75 | .policyLearning(LearningPolicy.FIRST_LOOP) 76 | .build(); 77 | 78 | if (this.caching && this.WkCache == null) { 79 | final INDArray target = mgr.createUninitialized(ArrayType.FF_WORKING_MEM, new long[]{attentionHeads, tsLength * examples}, 'f'); 80 | 81 | this.WkCache = Nd4j.gemm(W, keys.reshape('f', memoryWidth, tsLength * examples), target, true, false, 1.0, 0.0) 82 | .addiColumnVector(b.transpose()) 83 | .reshape('f', attentionHeads, tsLength, examples); 84 | } 85 | 86 | final INDArray queryRes = Nd4j.gemm(Q, queries.reshape('f', queryWidth, queryCount * examples), true, false) 87 | .reshape('f', attentionHeads, queryCount, examples); 88 | 89 | for (long example = 0; example < examples; example++) { 90 | try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "ATTENTION_FF")) { 91 | final INDArray curValues = values.get(all(), all(), point(example)); 92 | final INDArray curKeys = keys.get(all(), all(), point(example)); 93 | 94 | final INDArray preResult; 95 | if (this.caching) { 96 | preResult = this.WkCache.get(all(), all(), point(example)); 97 | } else { 98 | preResult = Nd4j.gemm(W, curKeys, true, false); 99 | preResult.addiColumnVector(b.transpose()); 100 | } 101 | 102 | 103 | final INDArray curMask = subMask(mask, example); 104 | final INDArray attentionHeadMask = attentionHeadMask(curMask, preResult.shape()); 105 | 106 | for (long queryIdx = 0; queryIdx < queryCount; queryIdx++) { 107 | final INDArray query = queries.get(all(), point(queryIdx), point(example)); 108 | final INDArray curResult = subArray(result, example, queryIdx); 109 | 110 | final INDArray queryResult = 
queryRes.get(all(), point(queryIdx), point(example)); 111 | 112 | final INDArray preA = preResult.addColumnVector(queryResult); 113 | final INDArray preS = this.activation.getActivation(preA, training); 114 | final INDArray attW = softmax.getActivation(preS, attentionHeadMask); 115 | 116 | final INDArray att = Nd4j.gemm(curValues, attW, false, true); 117 | curResult.assign(att.reshape('f', 1, memoryWidth * attentionHeads)); 118 | } 119 | } 120 | } 121 | return result; 122 | } 123 | 124 | public AdditiveAttentionMechanism withGradientViews(INDArray W, INDArray Q, INDArray b, INDArray keys, INDArray values, INDArray queries) { 125 | Wg = W; 126 | Qg = Q; 127 | bg = b; 128 | keyG = keys; 129 | valueG = values; 130 | queryG = queries; 131 | 132 | return this; 133 | } 134 | 135 | public void backprop(INDArray epsilon, INDArray queries, INDArray keys, INDArray values, INDArray mask) { 136 | if (Wg == null || Qg == null || bg == null || keyG == null || valueG == null || queryG == null) { 137 | throw new IllegalStateException("You MUST use attnMech.withGradientViews(...).backprop(...)."); 138 | } 139 | 140 | assertShapes(queries, keys, values); 141 | 142 | final long examples = queries.shape()[2]; 143 | final long queryCount = queries.shape()[1]; 144 | final long queryWidth = queries.shape()[0]; 145 | final long attentionHeads = W.shape()[1]; 146 | final long memoryWidth = W.shape()[0]; 147 | final long tsLength = keys.shape()[1]; 148 | 149 | 150 | if (epsilon.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilon)) 151 | epsilon = epsilon.dup('c'); 152 | 153 | final long[] epsilonShape = epsilon.shape(); 154 | if (epsilonShape[0] != examples || epsilonShape[1] != (attentionHeads * memoryWidth) || (epsilonShape.length == 2 && queryCount != 1) || (epsilonShape.length == 3 && epsilonShape[2] != queryCount)) { 155 | throw new IllegalStateException("Epsilon shape must match result shape. 
Got epsilon.shape() = " + Arrays.toString(epsilonShape) 156 | + "; result shape = [" + examples + ", " + attentionHeads * memoryWidth + ", " + queryCount + "]"); 157 | } 158 | 159 | WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder() 160 | .policyAllocation(AllocationPolicy.STRICT) 161 | .policyLearning(LearningPolicy.FIRST_LOOP) 162 | .build(); 163 | 164 | final INDArray dldAtt = epsilon.reshape('c', examples, attentionHeads, memoryWidth, queryCount); 165 | 166 | if (this.caching && this.WkCache == null) { 167 | final INDArray target = mgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{attentionHeads, tsLength * examples}, 'f'); 168 | 169 | this.WkCache = Nd4j.gemm(W, keys.reshape('f', memoryWidth, tsLength * examples), target, true, false, 1.0, 0.0) 170 | .addiColumnVector(b.transpose()) 171 | .reshape('f', attentionHeads, tsLength, examples); 172 | } 173 | 174 | final INDArray queryRes = Nd4j.gemm(Q, queries.reshape('f', queryWidth, queryCount * examples), true, false) 175 | .reshape('f', attentionHeads, queryCount, examples); 176 | 177 | for (long example = 0; example < examples; example++) { 178 | try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "ATTENTION_BP")) { 179 | final INDArray curValues = values.get(all(), all(), point(example)); 180 | final INDArray curKeys = keys.get(all(), all(), point(example)); 181 | 182 | final INDArray exEps = dldAtt.tensorAlongDimension((int) example, 1, 2, 3); 183 | 184 | final INDArray preResult; 185 | if (this.caching) { 186 | preResult = this.WkCache.get(all(), all(), point(example)); 187 | } else { 188 | preResult = Nd4j.gemm(W, curKeys, true, false); 189 | preResult.addiColumnVector(b.transpose()); 190 | } 191 | 192 | final INDArray curMask = subMask(mask, example); 193 | final INDArray attentionHeadMask = attentionHeadMask(curMask, preResult.shape()); 194 | 195 | for (long queryIdx = 0; queryIdx < queryCount; queryIdx++) { 196 | final INDArray curEps = exEps.tensorAlongDimension((int) queryIdx, 0, 1); 197 | final INDArray query = queries.get(all(), point(queryIdx), point(example)); 198 | 199 | final INDArray queryResult = queryRes.get(all(), point(queryIdx), point(example)); 200 | 201 | final INDArray preA = preResult.addColumnVector(queryResult); 202 | final INDArray preS = this.activation.getActivation(preA.dup(), training); 203 | final INDArray attW = softmax.getActivation(preS, attentionHeadMask); 204 | 205 | valueG.get(all(), all(), point(example)).addi(Nd4j.gemm(curEps, attW, true, false)); 206 | 207 | final INDArray dldAttW = Nd4j.gemm(curEps, curValues, false, false); 208 | final INDArray dldPreS = softmax.backprop(attW, attentionHeadMask, dldAttW).getFirst(); 209 | final INDArray dldPreA = activation.backprop(preA, dldPreS).getFirst(); 210 | 211 | final INDArray dldPreASum = dldPreA.sum(1); 212 | 213 | Nd4j.gemm(query, dldPreASum, Qg, false, true, 1.0, 1.0); 214 | Nd4j.gemm(curKeys, dldPreA, Wg, false, true, 1.0, 1.0); 215 | 216 | bg.addi(dldPreASum.transpose()); 217 | 218 | keyG.get(all(), all(), point(example)).addi(Nd4j.gemm(W, dldPreA, false, false)); 219 | queryG.get(all(), point(queryIdx), point(example)).addi(Nd4j.gemm(Q, dldPreASum, false, false)); 220 | } 221 | } 222 | } 223 | } 224 | 225 | private void assertWeightShapes(INDArray queryWeight, INDArray keyWeight, INDArray bias) { 226 | final long qOut = queryWeight.shape()[1]; 227 | final long kOut = keyWeight.shape()[1]; 228 | final long bOut = bias.shape()[1]; 229 | if (qOut != kOut || qOut != bOut) { 
230 | throw new IllegalStateException("Shapes must be compatible: queryWeight.shape() = " + Arrays.toString(queryWeight.shape()) 231 | + ", keyWeight.shape() = " + Arrays.toString(keyWeight.shape()) 232 | + ", bias.shape() = " + Arrays.toString(bias.shape()) 233 | + "\n Compatible shapes should have the same second dimension, but got: [" + qOut + ", " + kOut + ", " + bOut + "]" 234 | ); 235 | } 236 | } 237 | 238 | private void assertShapes(INDArray query, INDArray keys, INDArray values) { 239 | final long kIn = W.shape()[0]; 240 | final long qIn = Q.shape()[0]; 241 | 242 | if (query.shape()[0] != qIn || keys.shape()[0] != kIn) { 243 | throw new IllegalStateException("Shapes of query and keys must be compatible with the weights, but got: queryWeight.shape() = " + Arrays.toString(Q.shape()) 244 | + ", queries.shape() = " + Arrays.toString(query.shape()) 245 | + "; keyWeight.shape() = " + Arrays.toString(W.shape()) 246 | + ", keys.shape() = " + Arrays.toString(keys.shape()) 247 | ); 248 | } 249 | 250 | if (keys.shape()[1] != values.shape()[1]) { 251 | throw new IllegalStateException("Keys must be the same length as values! But got keys.shape() = " + Arrays.toString(keys.shape()) 252 | + ", values.shape() = " + Arrays.toString(values.shape())); 253 | } 254 | 255 | if (keys.shape()[2] != values.shape()[2] || query.shape()[2] != keys.shape()[2]) { 256 | throw new IllegalStateException("Queries, Keys and Values must have the same mini-batch size! But got keys.shape() = " + Arrays.toString(keys.shape()) 257 | + ", values.shape() = " + Arrays.toString(values.shape()) 258 | + ", queries.shape() = " + Arrays.toString(query.shape()) 259 | ); 260 | } 261 | } 262 | 263 | 264 | private INDArray subArray(INDArray in, long example) { 265 | return in.tensorAlongDimension((int) example, 1, 2); 266 | } 267 | 268 | private INDArray subArray(INDArray in, long example, long timestep) { 269 | return subArray(in, example).tensorAlongDimension((int) timestep, 0); 270 | } 271 | 272 | private INDArray subMask(INDArray mask, long example) { 273 | if (mask == null) { 274 | return null; 275 | } else { 276 | return mask.tensorAlongDimension((int) example, 1); 277 | } 278 | } 279 | 280 | private INDArray attentionHeadMask(INDArray mask, long[] shape) { 281 | if (mask == null) { 282 | return null; 283 | } else { 284 | return mask.broadcast(shape); 285 | } 286 | } 287 | } 288 | --------------------------------------------------------------------------------
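Usage sketch (illustrative, not part of the repository sources): the example below shows how one of these attention layers can be dropped into an ordinary DL4J configuration, mirroring the setup exercised in GradientChecks. The class name and the sizes nIn, layerSize, attentionHeads and nOut are assumptions chosen for the sake of the example.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer;

public class TimestepAttentionExample {
    public static void main(String[] args) {
        int nIn = 3;             // features per timestep
        int layerSize = 8;       // LSTM output size, becomes the attention layer's input size
        int attentionHeads = 3;  // nOut of the attention layer = number of attention heads
        int nOut = 5;            // number of classes

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new LSTM.Builder().nOut(layerSize).build())
                .layer(new TimestepAttentionLayer.Builder().nOut(attentionHeads).build())
                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
                .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(nIn))  // expects features of shape [miniBatch, nIn, timesteps]
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        System.out.println(net.summary());
    }
}

The same wiring applies to SelfAttentionLayer (followed directly by an OutputLayer) and to RecurrentAttentionLayer wrapped in LastTimeStep, as in the other two gradient check tests.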