├── README.md
├── src
│   ├── main
│   │   ├── resources
│   │   │   └── logback.xml
│   │   └── java
│   │       └── tech
│   │           └── dubs
│   │               └── dl4j
│   │                   └── contrib
│   │                       └── attention
│   │                           ├── activations
│   │                           │   └── ActivationMaskedSoftmax.java
│   │                           ├── conf
│   │                           │   ├── RecurrentAttentionLayer.java
│   │                           │   ├── TimestepAttentionLayer.java
│   │                           │   └── SelfAttentionLayer.java
│   │                           └── nn
│   │                               ├── TimestepAttentionLayer.java
│   │                               ├── params
│   │                               │   ├── QueryAttentionParamInitializer.java
│   │                               │   ├── SelfAttentionParamInitializer.java
│   │                               │   └── RecurrentQueryAttentionParamInitializer.java
│   │                               ├── SelfAttentionLayer.java
│   │                               ├── RecurrentAttentionLayer.java
│   │                               └── AdditiveAttentionMechanism.java
│   └── test
│       └── java
│           └── tech
│               └── dubs
│                   └── dl4j
│                       └── contrib
│                           └── attention
│                               ├── Serialization.java
│                               └── GradientChecks.java
├── .gitignore
└── pom.xml
/README.md:
--------------------------------------------------------------------------------
1 | # Work in Progress: Attention Layers for DL4J
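
This module provides self-, timestep- and recurrent-attention layers that plug into a standard DL4J `MultiLayerConfiguration`. Below is a minimal usage sketch mirroring the configurations exercised in the serialization and gradient-check tests; the layer sizes, the example class name and the `main` scaffold are illustrative, not part of the library:

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer;

public class SelfAttentionExample {
    public static void main(String[] args) {
        int nIn = 3;             // input features per timestep (illustrative)
        int nClasses = 5;        // output classes (illustrative)
        int lstmSize = 8;        // LSTM layer size (illustrative)
        int attentionHeads = 2;  // nOut of the attention layer = number of attention heads

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new LSTM.Builder().nOut(lstmSize).build())
                .layer(new SelfAttentionLayer.Builder().nOut(attentionHeads).build())
                .layer(new OutputLayer.Builder().nOut(nClasses).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(nIn))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // net.fit(...) with a DataSetIterator of sequence data
    }
}
```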
--------------------------------------------------------------------------------
/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
<configuration>
    <!-- Assumed minimal setup: only the %msg%n pattern was recoverable from the original file, so this
         reconstructs a plain console appender around it; the INFO root level is a guess. -->
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%msg%n</pattern>
        </encoder>
    </appender>

    <root level="INFO">
        <appender-ref ref="STDOUT"/>
    </root>
</configuration>
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Maven template
3 | target/
4 | pom.xml.tag
5 | pom.xml.releaseBackup
6 | pom.xml.versionsBackup
7 | pom.xml.next
8 | release.properties
9 | dependency-reduced-pom.xml
10 | buildNumber.properties
11 | .mvn/timing.properties
12 | .mvn/wrapper/maven-wrapper.jar
13 | ### JetBrains template
14 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
15 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
16 |
17 | *.iml
18 | .idea
19 |
20 | # Gradle and Maven with auto-import
21 | # When using Gradle or Maven with auto-import, you should exclude module files,
22 | # since they will be recreated, and may cause churn. Uncomment if using
23 | # auto-import.
24 | # .idea/modules.xml
25 | # .idea/*.iml
26 | # .idea/modules
27 |
28 | # CMake
29 | cmake-build-*/
30 |
31 | # Mongo Explorer plugin
32 | .idea/**/mongoSettings.xml
33 |
34 | # File-based project format
35 | *.iws
36 |
37 | # IntelliJ
38 | out/
39 |
40 | # mpeltonen/sbt-idea plugin
41 | .idea_modules/
42 |
43 | # JIRA plugin
44 | atlassian-ide-plugin.xml
45 |
46 | # Cursive Clojure plugin
47 | .idea/replstate.xml
48 |
49 | # Crashlytics plugin (for Android Studio and IntelliJ)
50 | com_crashlytics_export_strings.xml
51 | crashlytics.properties
52 | crashlytics-build.properties
53 | fabric.properties
54 |
55 | # Editor-based Rest Client
56 | .idea/httpRequests
57 | ### Java template
58 | # Compiled class file
59 | *.class
60 |
61 | # Log file
62 | *.log
63 |
64 | # BlueJ files
65 | *.ctxt
66 |
67 | # Mobile Tools for Java (J2ME)
68 | .mtj.tmp/
69 |
70 | # Package Files #
71 | *.jar
72 | *.war
73 | *.nar
74 | *.ear
75 | *.zip
76 | *.tar.gz
77 | *.rar
78 |
79 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
80 | hs_err_pid*
81 |
82 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/activations/ActivationMaskedSoftmax.java:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * Copyright (c) 2015-2018 Skymind, Inc.
3 | *
4 | * This program and the accompanying materials are made available under the
5 | * terms of the Apache License, Version 2.0 which is available at
6 | * https://www.apache.org/licenses/LICENSE-2.0.
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11 | * License for the specific language governing permissions and limitations
12 | * under the License.
13 | *
14 | * SPDX-License-Identifier: Apache-2.0
15 | ******************************************************************************/
16 |
17 | package tech.dubs.dl4j.contrib.attention.activations;
18 |
19 | import org.nd4j.linalg.api.ndarray.INDArray;
20 | import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
21 | import org.nd4j.linalg.factory.Nd4j;
22 | import org.nd4j.linalg.ops.transforms.Transforms;
23 | import org.nd4j.linalg.primitives.Pair;
24 |
25 | import java.util.Arrays;
26 |
27 | /**
28 | * f_i(x, m) = m_i*exp(x_i - shift) / sum_j m_j*exp(x_j - shift)
29 | * where shift = max_i(x_i), m = mask
30 | *
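 * Worked example (illustrative numbers): for x = [1.0, 2.0, 3.0] and m = [1, 1, 0]:
 * shift = 3, exp(x - shift) = [0.135, 0.368, 1.000], masked = [0.135, 0.368, 0], sum = 0.503,
 * so f(x, m) ≈ [0.269, 0.731, 0].
 *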
31 | * @author Paul Dubs
32 | */
33 | public class ActivationMaskedSoftmax {
34 |
35 | public INDArray getActivation(INDArray in, INDArray mask) {
36 | if(mask == null){
37 | return Nd4j.getExecutioner().execAndReturn(new OldSoftMax(in));
38 | }else {
39 | assertShape(in, mask, null);
40 |
41 | final INDArray shift = in.max(-1);
42 | final INDArray exp = Transforms.exp(in.subiColumnVector(shift), false);
43 |
44 | final INDArray masked = exp.muli(mask);
45 |
46 | final INDArray sum = masked.sum(-1);
47 | masked.diviColumnVector(sum);
48 | return masked;
49 | }
50 | }
51 |
52 | public Pair<INDArray, INDArray> backprop(INDArray postSoftmax, INDArray mask, INDArray epsilon) {
53 | INDArray x = postSoftmax.mul(epsilon).sum(1);
54 | INDArray dLdz = postSoftmax.mul(epsilon.subColumnVector(x));
55 | return new Pair<>(dLdz, null);
56 | }
57 |
58 | @Override
59 | public String toString() {
60 | return "maskedSoftmax";
61 | }
62 |
63 | private void assertShape(INDArray in, INDArray mask, INDArray epsilon) {
64 | if (mask != null && !in.equalShapes(mask)) {
65 | throw new IllegalStateException("Shapes must be equal: in.shape{} = " + Arrays.toString(in.shape())
66 | + ", mask.shape() = " + Arrays.toString(mask.shape()));
67 | }
68 | if (epsilon != null && !in.equalShapes(epsilon)) {
69 | throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())
70 | + ", epsilon.shape() = " + Arrays.toString(epsilon.shape()));
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>tech.dubs.dl4j.contrib</groupId>
    <artifactId>attention</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <dl4j.version>1.0.0-beta2</dl4j.version>
        <logback.version>1.2.3</logback.version>
        <java.version>1.8</java.version>
        <!-- Property name assumed; only the value 2.4.3 was recoverable. -->
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
    </properties>

    <repositories>
        <repository>
            <id>snapshots-repo</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
                <updatePolicy>daily</updatePolicy>
            </snapshots>
        </repository>
    </repositories>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <!-- ND4J backend, used to run the tests -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${dl4j.version}</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.5.1</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/conf/RecurrentAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.conf;
2 |
3 | import org.deeplearning4j.nn.api.Layer;
4 | import org.deeplearning4j.nn.api.ParamInitializer;
5 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
6 | import org.deeplearning4j.nn.conf.inputs.InputType;
7 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
8 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
9 | import org.deeplearning4j.optimize.api.TrainingListener;
10 | import org.nd4j.linalg.api.ndarray.INDArray;
11 | import tech.dubs.dl4j.contrib.attention.nn.params.RecurrentQueryAttentionParamInitializer;
12 |
13 | import java.util.Collection;
14 | import java.util.Map;
15 |
16 | /**
17 | * TODO: Memory Report, Configurable Activation for Attention
18 | *
19 | * @author Paul Dubs
20 | */
21 | public class RecurrentAttentionLayer extends BaseRecurrentLayer {
22 | // No-Op Constructor for Deserialization
23 | public RecurrentAttentionLayer() { }
24 |
25 | private RecurrentAttentionLayer(Builder builder) {
26 | super(builder);
27 | }
28 |
29 | @Override
30 | public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners,
31 | int layerIndex, INDArray layerParamsView, boolean initializeParams) {
32 |
33 | tech.dubs.dl4j.contrib.attention.nn.RecurrentAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.RecurrentAttentionLayer(conf);
34 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any
35 | layer.setIndex(layerIndex); //Integer index of the layer
36 |
37 | layer.setParamsViewArray(layerParamsView);
38 |
39 | Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
40 | layer.setParamTable(paramTable);
41 | layer.setConf(conf);
42 | return layer;
43 | }
44 |
45 | @Override
46 | public ParamInitializer initializer() {
47 | return RecurrentQueryAttentionParamInitializer.getInstance();
48 | }
49 |
50 | @Override
51 | public double getL1ByParam(String paramName) {
52 | if(initializer().isWeightParam(this, paramName)){
53 | return l1;
54 | }else if(initializer().isBiasParam(this, paramName)){
55 | return l1Bias;
56 | }
57 |
58 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
59 | }
60 |
61 | @Override
62 | public double getL2ByParam(String paramName) {
63 | if(initializer().isWeightParam(this, paramName)){
64 | return l2;
65 | }else if(initializer().isBiasParam(this, paramName)){
66 | return l2Bias;
67 | }
68 |
69 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
70 | }
71 |
72 |
73 | @Override
74 | public LayerMemoryReport getMemoryReport(InputType inputType) {
75 | return null;
76 | }
77 |
78 |
79 | public static class Builder extends BaseRecurrentLayer.Builder<Builder> {
80 |
81 | @Override
82 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required.
83 | public RecurrentAttentionLayer build() {
84 | return new RecurrentAttentionLayer(this);
85 | }
86 |
87 | }
88 | }
89 |
--------------------------------------------------------------------------------
/src/test/java/tech/dubs/dl4j/contrib/attention/Serialization.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention;
2 |
3 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
4 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
5 | import org.deeplearning4j.nn.conf.inputs.InputType;
6 | import org.deeplearning4j.nn.conf.layers.LSTM;
7 | import org.deeplearning4j.nn.conf.layers.OutputLayer;
8 | import org.deeplearning4j.nn.weights.WeightInit;
9 | import org.junit.Assert;
10 | import org.junit.Test;
11 | import org.nd4j.linalg.activations.Activation;
12 | import org.nd4j.linalg.learning.config.NoOp;
13 | import org.nd4j.linalg.lossfunctions.LossFunctions;
14 | import tech.dubs.dl4j.contrib.attention.conf.RecurrentAttentionLayer;
15 | import tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer;
16 | import tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer;
17 |
18 | public class Serialization {
19 | @Test
20 | public void testSelfAttentionSerialization(){
21 | int nIn = 3;
22 | int nOut = 5;
23 | int layerSize = 8;
24 |
25 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
26 | .activation(Activation.TANH)
27 | .updater(new NoOp())
28 | .weightInit(WeightInit.XAVIER)
29 | .list()
30 | .layer(new LSTM.Builder().nOut(layerSize).build())
31 | .layer(new SelfAttentionLayer.Builder().build())
32 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
33 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
34 | .setInputType(InputType.recurrent(nIn))
35 | .build();
36 |
37 | final String json = conf.toJson();
38 | final String yaml = conf.toYaml();
39 |
40 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
41 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml);
42 |
43 | Assert.assertEquals(conf, fromJson);
44 | Assert.assertEquals(conf, fromYaml);
45 | }
46 |
47 | @Test
48 | public void testTimestepAttentionSerialization(){
49 | int nIn = 3;
50 | int nOut = 5;
51 | int layerSize = 8;
52 |
53 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
54 | .activation(Activation.TANH)
55 | .updater(new NoOp())
56 | .weightInit(WeightInit.XAVIER)
57 | .list()
58 | .layer(new LSTM.Builder().nOut(layerSize).build())
59 | .layer(new TimestepAttentionLayer.Builder().build())
60 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
61 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
62 | .setInputType(InputType.recurrent(nIn))
63 | .build();
64 |
65 | final String json = conf.toJson();
66 | final String yaml = conf.toYaml();
67 |
68 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
69 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml);
70 |
71 | Assert.assertEquals(conf, fromJson);
72 | Assert.assertEquals(conf, fromYaml);
73 | }
74 |
75 | @Test
76 | public void testRecurrentAttentionSerialization(){
77 | int nIn = 3;
78 | int nOut = 5;
79 | int layerSize = 8;
80 |
81 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
82 | .activation(Activation.TANH)
83 | .updater(new NoOp())
84 | .weightInit(WeightInit.XAVIER)
85 | .list()
86 | .layer(new LSTM.Builder().nOut(layerSize).build())
87 | .layer(new RecurrentAttentionLayer.Builder().build())
88 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
89 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
90 | .setInputType(InputType.recurrent(nIn))
91 | .build();
92 |
93 | final String json = conf.toJson();
94 | final String yaml = conf.toYaml();
95 |
96 | final MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
97 | final MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml);
98 |
99 | Assert.assertEquals(conf, fromJson);
100 | Assert.assertEquals(conf, fromYaml);
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/TimestepAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.nn;
2 |
3 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
4 | import org.deeplearning4j.nn.gradient.DefaultGradient;
5 | import org.deeplearning4j.nn.gradient.Gradient;
6 | import org.deeplearning4j.nn.layers.BaseLayer;
7 | import org.deeplearning4j.nn.workspace.ArrayType;
8 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
9 | import org.nd4j.base.Preconditions;
10 | import org.nd4j.linalg.activations.IActivation;
11 | import org.nd4j.linalg.api.ndarray.INDArray;
12 | import org.nd4j.linalg.api.shape.Shape;
13 | import org.nd4j.linalg.primitives.Pair;
14 | import tech.dubs.dl4j.contrib.attention.activations.ActivationMaskedSoftmax;
15 | import tech.dubs.dl4j.contrib.attention.nn.params.QueryAttentionParamInitializer;
16 |
17 | /**
18 | * Timestep Attention Layer Implementation
19 | *
20 | *
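 * For an input of shape [miniBatch, nIn, timeSeriesLength], every timestep is used as a query over the whole
 * sequence (keys = values = input), giving attended activations of shape [miniBatch, nIn * nOut, timeSeriesLength].
 *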
21 | * TODO:
22 | * - Optionally keep attention weights around for inspection
23 | * - Handle Masking
24 | *
25 | * @author Paul Dubs
26 | */
27 | public class TimestepAttentionLayer extends BaseLayer<tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer> {
28 | private ActivationMaskedSoftmax softmax = new ActivationMaskedSoftmax();
29 |
30 | public TimestepAttentionLayer(NeuralNetConfiguration conf) {
31 | super(conf);
32 | }
33 |
34 | @Override
35 | public boolean isPretrainLayer() {
36 | return false;
37 | }
38 |
39 | @Override
40 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
41 | assertInputSet(false);
42 | Preconditions.checkState(input.rank() == 3,
43 | "3D input expected to RNN layer expected, got " + input.rank());
44 |
45 | applyDropOutIfNecessary(training, workspaceMgr);
46 |
47 | INDArray W = getParamWithNoise(QueryAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr);
48 | INDArray Q = getParamWithNoise(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr);
49 | INDArray b = getParamWithNoise(QueryAttentionParamInitializer.BIAS_KEY, training, workspaceMgr);
50 |
51 |
52 | long nIn = layerConf().getNIn();
53 | long nOut = layerConf().getNOut();
54 | IActivation a = layerConf().getActivationFn();
55 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0];
56 | long tsLength = input.shape()[0] == nIn ? input.shape()[1] : input.shape()[2];
57 |
58 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nIn*nOut, tsLength}, 'f');
59 |
60 | if(input.shape()[0] != nIn)
61 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
62 |
63 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, training);
64 | final INDArray attention = attentionMechanism.query(input, input, input, maskArray);
65 | activations.assign(attention);
66 |
67 | return activations;
68 | }
69 |
70 |
71 |
72 | @Override
73 | public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
74 | assertInputSet(true);
75 | if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon))
76 | epsilon = epsilon.dup('f');
77 |
78 | INDArray W = getParamWithNoise(QueryAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
79 | INDArray Q = getParamWithNoise(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr);
80 | INDArray b = getParamWithNoise(QueryAttentionParamInitializer.BIAS_KEY, true, workspaceMgr);
81 |
82 | INDArray Wg = gradientViews.get(QueryAttentionParamInitializer.WEIGHT_KEY);
83 | INDArray Qg = gradientViews.get(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY);
84 | INDArray bg = gradientViews.get(QueryAttentionParamInitializer.BIAS_KEY);
85 | gradientsFlattened.assign(0);
86 |
87 | applyDropOutIfNecessary(true, workspaceMgr);
88 |
89 | IActivation a = layerConf().getActivationFn();
90 |
91 | long nIn = layerConf().getNIn();
92 | if(input.shape()[0] != nIn)
93 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
94 |
95 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f');
96 |
97 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, true);
98 | attentionMechanism
99 | .withGradientViews(Wg, Qg, bg, epsOut, epsOut, epsOut)
100 | .backprop(epsilon, input, input, input, maskArray);
101 |
102 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f');
103 |
104 | weightNoiseParams.clear();
105 |
106 | Gradient g = new DefaultGradient(gradientsFlattened);
107 | g.gradientForVariable().put(QueryAttentionParamInitializer.WEIGHT_KEY, Wg);
108 | g.gradientForVariable().put(QueryAttentionParamInitializer.QUERY_WEIGHT_KEY, Qg);
109 | g.gradientForVariable().put(QueryAttentionParamInitializer.BIAS_KEY, bg);
110 |
111 | epsOut = backpropDropOutIfPresent(epsOut);
112 | return new Pair<>(g, epsOut);
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/conf/TimestepAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.conf;
2 |
3 | import org.deeplearning4j.nn.api.Layer;
4 | import org.deeplearning4j.nn.api.ParamInitializer;
5 | import org.deeplearning4j.nn.conf.InputPreProcessor;
6 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
7 | import org.deeplearning4j.nn.conf.inputs.InputType;
8 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
9 | import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
10 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
11 | import org.deeplearning4j.nn.conf.memory.MemoryReport;
12 | import org.deeplearning4j.optimize.api.TrainingListener;
13 | import org.nd4j.linalg.api.ndarray.INDArray;
14 | import tech.dubs.dl4j.contrib.attention.nn.params.QueryAttentionParamInitializer;
15 |
16 | import java.util.Collection;
17 | import java.util.Map;
18 |
19 | /**
20 | * Timestep Attention Layer Configuration
21 | *
22 | * @author Paul Dubs
23 | */
24 | public class TimestepAttentionLayer extends BaseRecurrentLayer {
25 | // No-Op Constructor for Deserialization
26 | public TimestepAttentionLayer() { }
27 |
28 | private TimestepAttentionLayer(Builder builder) {
29 | super(builder);
30 | }
31 |
32 | @Override
33 | public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners,
34 | int layerIndex, INDArray layerParamsView, boolean initializeParams) {
35 |
36 | tech.dubs.dl4j.contrib.attention.nn.TimestepAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.TimestepAttentionLayer(conf);
37 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any
38 | layer.setIndex(layerIndex); //Integer index of the layer
39 |
40 | layer.setParamsViewArray(layerParamsView);
41 |
42 | Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
43 | layer.setParamTable(paramTable);
44 | layer.setConf(conf);
45 | return layer;
46 | }
47 |
48 | @Override
49 | public ParamInitializer initializer() {
50 | return QueryAttentionParamInitializer.getInstance();
51 | }
52 |
53 | @Override
54 | public double getL1ByParam(String paramName) {
55 | if(initializer().isWeightParam(this, paramName)){
56 | return l1;
57 | }else if(initializer().isBiasParam(this, paramName)){
58 | return l1Bias;
59 | }
60 |
61 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
62 | }
63 |
64 | @Override
65 | public double getL2ByParam(String paramName) {
66 | if(initializer().isWeightParam(this, paramName)){
67 | return l2;
68 | }else if(initializer().isBiasParam(this, paramName)){
69 | return l2Bias;
70 | }
71 |
72 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
73 | }
74 |
75 | @Override
76 | public InputType getOutputType(int layerIndex, InputType inputType) {
77 | if (inputType == null || inputType.getType() != InputType.Type.RNN) {
78 | throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex
79 | + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
80 | + inputType);
81 | }
82 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
83 |
84 | return InputType.recurrent(nIn * nOut, itr.getTimeSeriesLength());
85 | }
86 |
87 | @Override
88 | public void setNIn(InputType inputType, boolean override) {
89 | if (inputType == null || inputType.getType() != InputType.Type.RNN) {
90 | throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + getLayerName()
91 | + "\"): expect RNN input type with size > 0. Got: " + inputType);
92 | }
93 |
94 | if (nIn <= 0 || override) {
95 | InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
96 | this.nIn = r.getSize();
97 | }
98 | }
99 |
100 | @Override
101 | public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
102 | return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
103 | }
104 |
105 | @Override
106 | public LayerMemoryReport getMemoryReport(InputType inputType) {
107 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
108 | final long tsLength = itr.getTimeSeriesLength();
109 |
110 | InputType outputType = getOutputType(-1, inputType);
111 |
112 | long numParams = initializer().numParams(this);
113 | int updaterStateSize = (int)getIUpdater().stateSize(numParams);
114 |
115 | int trainSizeFixed = 0;
116 | int trainSizeVariable = 0;
117 | if(getIDropout() != null){
118 | //Assume we dup the input for dropout
119 | trainSizeVariable += inputType.arrayElementsPerExample();
120 | }
121 | trainSizeVariable += outputType.arrayElementsPerExample() * tsLength;
122 | trainSizeVariable += itr.getSize() * outputType.arrayElementsPerExample();
123 |
124 | return new LayerMemoryReport.Builder(layerName, TimestepAttentionLayer.class, inputType, outputType)
125 | .standardMemory(numParams, updaterStateSize)
126 | .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference
127 | .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
128 | .build();
129 | }
130 |
131 |
132 | public static class Builder extends BaseRecurrentLayer.Builder<Builder> {
133 |
134 | @Override
135 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required.
136 | public TimestepAttentionLayer build() {
137 | return new TimestepAttentionLayer(this);
138 | }
139 |
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/conf/SelfAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.conf;
2 |
3 | import org.deeplearning4j.nn.api.Layer;
4 | import org.deeplearning4j.nn.api.ParamInitializer;
5 | import org.deeplearning4j.nn.conf.InputPreProcessor;
6 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
7 | import org.deeplearning4j.nn.conf.inputs.InputType;
8 | import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
9 | import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
10 | import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
11 | import org.deeplearning4j.nn.conf.memory.MemoryReport;
12 | import org.deeplearning4j.optimize.api.TrainingListener;
13 | import org.nd4j.linalg.api.ndarray.INDArray;
14 | import tech.dubs.dl4j.contrib.attention.nn.params.SelfAttentionParamInitializer;
15 |
16 | import java.util.Collection;
17 | import java.util.Map;
18 |
19 | /**
20 | * nOut = Number of Attention Heads
21 | *
22 | * The self-attention layer is the simplest of the attention layers. It takes a recurrent input and returns a
23 | * fixed-size dense output. It requires just the same parameters as a dense layer, since the learnable part of it
24 | * effectively is a dense layer.
25 | *
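 * Shape-wise, a recurrent input of shape [miniBatch, nIn, timeSeriesLength] is reduced to a feed-forward output of
 * shape [miniBatch, nIn * nOut], i.e. one attended vector of size nIn per attention head.
 *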
26 | * @author Paul Dubs
27 | */
28 | public class SelfAttentionLayer extends FeedForwardLayer {
29 | // No-Op Constructor for Deserialization
30 | public SelfAttentionLayer() { }
31 |
32 | private SelfAttentionLayer(Builder builder) {
33 | super(builder);
34 | }
35 |
36 | @Override
37 | public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners,
38 | int layerIndex, INDArray layerParamsView, boolean initializeParams) {
39 |
40 | tech.dubs.dl4j.contrib.attention.nn.SelfAttentionLayer layer = new tech.dubs.dl4j.contrib.attention.nn.SelfAttentionLayer(conf);
41 | layer.setListeners(iterationListeners); //Set the iteration listeners, if any
42 | layer.setIndex(layerIndex); //Integer index of the layer
43 |
44 | layer.setParamsViewArray(layerParamsView);
45 |
46 | Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
47 | layer.setParamTable(paramTable);
48 | layer.setConf(conf);
49 | return layer;
50 | }
51 |
52 | @Override
53 | public ParamInitializer initializer() {
54 | return SelfAttentionParamInitializer.getInstance();
55 | }
56 |
57 |
58 | @Override
59 | public double getL1ByParam(String paramName) {
60 | if(initializer().isWeightParam(this, paramName)){
61 | return l1;
62 | }else if(initializer().isBiasParam(this, paramName)){
63 | return l1Bias;
64 | }
65 |
66 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
67 | }
68 |
69 | @Override
70 | public double getL2ByParam(String paramName) {
71 | if(initializer().isWeightParam(this, paramName)){
72 | return l2;
73 | }else if(initializer().isBiasParam(this, paramName)){
74 | return l2Bias;
75 | }
76 |
77 | throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\"");
78 | }
79 |
80 | @Override
81 | public InputType getOutputType(int layerIndex, InputType inputType) {
82 | if (inputType == null || inputType.getType() != InputType.Type.RNN) {
83 | throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex
84 | + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
85 | + inputType);
86 | }
87 |
88 | return InputType.feedForward(nIn * nOut);
89 | }
90 |
91 | @Override
92 | public void setNIn(InputType inputType, boolean override) {
93 | if (inputType == null || inputType.getType() != InputType.Type.RNN) {
94 | throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + getLayerName()
95 | + "\"): expect RNN input type with size > 0. Got: " + inputType);
96 | }
97 |
98 | if (nIn <= 0 || override) {
99 | InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
100 | this.nIn = r.getSize();
101 | }
102 | }
103 |
104 | @Override
105 | public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
106 | return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
107 | }
108 |
109 | @Override
110 | public LayerMemoryReport getMemoryReport(InputType inputType) {
111 | InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
112 | final long tsLength = itr.getTimeSeriesLength();
113 |
114 | InputType outputType = getOutputType(-1, inputType);
115 |
116 | long numParams = initializer().numParams(this);
117 | int updaterStateSize = (int)getIUpdater().stateSize(numParams);
118 |
119 | int trainSizeFixed = 0;
120 | int trainSizeVariable = 0;
121 | if(getIDropout() != null){
122 | //Assume we dup the input for dropout
123 | trainSizeVariable += inputType.arrayElementsPerExample();
124 | }
125 | trainSizeVariable += outputType.arrayElementsPerExample() * tsLength;
126 | trainSizeVariable += itr.getSize() * outputType.arrayElementsPerExample();
127 |
128 | return new LayerMemoryReport.Builder(layerName, SelfAttentionLayer.class, inputType, outputType)
129 | .standardMemory(numParams, updaterStateSize)
130 | .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference
131 | .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
132 | .build();
133 | }
134 |
135 |
136 | public static class Builder extends FeedForwardLayer.Builder<Builder> {
137 | @Override
138 | @SuppressWarnings("unchecked") //To stop warnings about unchecked cast. Not required.
139 | public SelfAttentionLayer build() {
140 | return new SelfAttentionLayer(this);
141 | }
142 | }
143 | }
144 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/QueryAttentionParamInitializer.java:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * Copyright (c) 2015-2018 Skymind, Inc.
3 | *
4 | * This program and the accompanying materials are made available under the
5 | * terms of the Apache License, Version 2.0 which is available at
6 | * https://www.apache.org/licenses/LICENSE-2.0.
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11 | * License for the specific language governing permissions and limitations
12 | * under the License.
13 | *
14 | * SPDX-License-Identifier: Apache-2.0
15 | ******************************************************************************/
16 |
17 | package tech.dubs.dl4j.contrib.attention.nn.params;
18 |
19 | import org.deeplearning4j.nn.api.ParamInitializer;
20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
21 | import org.deeplearning4j.nn.conf.distribution.Distributions;
22 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
23 | import org.deeplearning4j.nn.conf.layers.Layer;
24 | import org.deeplearning4j.nn.params.DefaultParamInitializer;
25 | import org.deeplearning4j.nn.weights.WeightInit;
26 | import org.deeplearning4j.nn.weights.WeightInitUtil;
27 | import org.nd4j.linalg.api.ndarray.INDArray;
28 | import org.nd4j.linalg.api.rng.distribution.Distribution;
29 |
30 | import java.util.*;
31 |
32 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
33 | import static org.nd4j.linalg.indexing.NDArrayIndex.point;
34 |
35 | /**
36 | * @author Paul Dubs
37 | */
38 | public class QueryAttentionParamInitializer implements ParamInitializer {
39 |
40 | private static final QueryAttentionParamInitializer INSTANCE = new QueryAttentionParamInitializer();
41 |
42 | public static QueryAttentionParamInitializer getInstance(){
43 | return INSTANCE;
44 | }
45 |
46 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY;
47 | public static final String QUERY_WEIGHT_KEY = "Q";
48 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY;
49 |
50 | private static final List<String> PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY));
51 | private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY));
52 | private static final List<String> BIAS_KEYS = Collections.singletonList(BIAS_KEY);
53 |
54 |
55 | @Override
56 | public long numParams(NeuralNetConfiguration conf) {
57 | return numParams(conf.getLayer());
58 | }
59 |
60 | @Override
61 | public long numParams(Layer layer) {
62 | BaseRecurrentLayer c = (BaseRecurrentLayer) layer;
63 | final long nIn = c.getNIn();
64 | final long nOut = c.getNOut();
65 |
66 | final long paramsW = nIn * nOut;
67 | final long paramsWq = nIn * nOut;
68 | final long paramsB = nOut;
69 | return paramsW + paramsWq + paramsB;
70 | }
71 |
72 | @Override
73 | public List<String> paramKeys(Layer layer) {
74 | return PARAM_KEYS;
75 | }
76 |
77 | @Override
78 | public List<String> weightKeys(Layer layer) {
79 | return WEIGHT_KEYS;
80 | }
81 |
82 | @Override
83 | public List<String> biasKeys(Layer layer) {
84 | return BIAS_KEYS;
85 | }
86 |
87 | @Override
88 | public boolean isWeightParam(Layer layer, String key) {
89 | return WEIGHT_KEYS.contains(key);
90 | }
91 |
92 | @Override
93 | public boolean isBiasParam(Layer layer, String key) {
94 | return BIAS_KEYS.contains(key);
95 | }
96 |
97 | @Override
98 | public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
99 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer();
100 | final long nIn = c.getNIn();
101 | final long nOut = c.getNOut();
102 |
103 | Map<String, INDArray> m;
104 |
105 | if (initializeParams) {
106 | Distribution dist = Distributions.createDistribution(c.getDist());
107 |
108 | m = getSubsets(paramsView, nIn, nOut, false);
109 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY));
110 | m.put(WEIGHT_KEY, w);
111 |
112 | WeightInit rqInit;
113 | Distribution rqDist = dist;
114 | if (c.getWeightInitRecurrent() != null) {
115 | rqInit = c.getWeightInitRecurrent();
116 | if(c.getDistRecurrent() != null) {
117 | rqDist = Distributions.createDistribution(c.getDistRecurrent());
118 | }
119 | } else {
120 | rqInit = c.getWeightInit();
121 | }
122 |
123 | INDArray rq = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, rqInit, rqDist, 'f', m.get(QUERY_WEIGHT_KEY));
124 | m.put(QUERY_WEIGHT_KEY, rq);
125 | } else {
126 | m = getSubsets(paramsView, nIn, nOut, true);
127 | }
128 |
129 | conf.addVariable(WEIGHT_KEY);
130 | conf.addVariable(QUERY_WEIGHT_KEY);
131 | conf.addVariable(BIAS_KEY);
132 |
133 | return m;
134 | }
135 |
136 | @Override
137 | public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
138 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer();
139 | final long nIn = c.getNIn();
140 | final long nOut = c.getNOut();
141 |
142 | return getSubsets(gradientView, nIn, nOut, true);
143 | }
144 |
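// Layout of the flattened parameter/gradient view: [ W (nIn*nOut) | Q (nIn*nOut) | b (nOut) ]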
145 | private static Map<String, INDArray> getSubsets(INDArray in, long nIn, long nOut, boolean reshape){
146 | long pos = nIn * nOut;
147 | INDArray w = in.get(point(0), interval(0, pos));
148 | INDArray rq = in.get(point(0), interval(pos, pos + nIn * nOut));
149 | pos += nIn * nOut;
150 | INDArray b = in.get(point(0), interval(pos, pos + nOut));
151 |
152 | if(reshape){
153 | w = w.reshape('f', nIn, nOut);
154 | rq = rq.reshape('f', nIn, nOut);
155 | }
156 |
157 | Map<String, INDArray> m = new LinkedHashMap<>();
158 | m.put(WEIGHT_KEY, w);
159 | m.put(QUERY_WEIGHT_KEY, rq);
160 | m.put(BIAS_KEY, b);
161 | return m;
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/SelfAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.nn;
2 |
3 | import org.deeplearning4j.nn.api.MaskState;
4 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
5 | import org.deeplearning4j.nn.gradient.DefaultGradient;
6 | import org.deeplearning4j.nn.gradient.Gradient;
7 | import org.deeplearning4j.nn.layers.BaseLayer;
8 | import org.deeplearning4j.nn.workspace.ArrayType;
9 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
10 | import org.nd4j.base.Preconditions;
11 | import org.nd4j.linalg.activations.IActivation;
12 | import org.nd4j.linalg.activations.impl.ActivationSoftmax;
13 | import org.nd4j.linalg.api.ndarray.INDArray;
14 | import org.nd4j.linalg.primitives.Pair;
15 | import tech.dubs.dl4j.contrib.attention.nn.params.SelfAttentionParamInitializer;
16 |
17 | /**
18 | * Self Attention Layer Implementation
19 | *
20 | * The implementation of mmul across time isn't the most efficient thing possible in nd4j, since the reshapes require
21 | * a copy, but it is the easiest to follow for now.
22 | *
23 | * TODO:
24 | * - Optionally keep attention weights around for inspection
25 | * - Handle Masking
26 | *
27 | * @author Paul Dubs
28 | */
29 | public class SelfAttentionLayer extends BaseLayer<tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer> {
30 | private IActivation softmax = new ActivationSoftmax();
31 |
32 | public SelfAttentionLayer(NeuralNetConfiguration conf) {
33 | super(conf);
34 | }
35 |
36 | @Override
37 | public boolean isPretrainLayer() {
38 | return false;
39 | }
40 |
41 | @Override
42 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
43 | assertInputSet(false);
44 | Preconditions.checkState(input.rank() == 3,
45 | "3D input expected to RNN layer expected, got " + input.rank());
46 |
47 | applyDropOutIfNecessary(training, workspaceMgr);
48 |
49 | INDArray W = getParamWithNoise(SelfAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr);
50 | INDArray Q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr);
51 | INDArray b = getParamWithNoise(SelfAttentionParamInitializer.BIAS_KEY, training, workspaceMgr);
52 | INDArray q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_KEY, training, workspaceMgr);
53 |
54 | long nOut = layerConf().getNOut();
55 | long nIn = layerConf().getNIn();
56 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0];
57 | IActivation a = layerConf().getActivationFn();
58 |
59 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nIn * nOut}, 'f');
60 |
61 | if(input.shape()[0] != nIn)
62 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
63 |
64 | final INDArray queries = q.reshape(nIn, 1, 1).broadcast(nIn, 1, examples);
65 |
66 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, training);
67 | final INDArray attention = attentionMechanism.query(queries, input, input, maskArray);
68 | activations.assign(attention.reshape(activations.shape()));
69 |
70 | return activations;
71 | }
72 |
73 |
74 |
75 | @Override
76 | public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
77 | assertInputSet(true);
78 |
79 | INDArray W = getParamWithNoise(SelfAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
80 | INDArray Q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr);
81 | INDArray b = getParamWithNoise(SelfAttentionParamInitializer.BIAS_KEY, true, workspaceMgr);
82 | INDArray q = getParamWithNoise(SelfAttentionParamInitializer.QUERY_KEY, true, workspaceMgr);
83 |
84 | INDArray Wg = gradientViews.get(SelfAttentionParamInitializer.WEIGHT_KEY);
85 | INDArray Qg = gradientViews.get(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY);
86 | INDArray bg = gradientViews.get(SelfAttentionParamInitializer.BIAS_KEY);
87 | INDArray qg = gradientViews.get(SelfAttentionParamInitializer.QUERY_KEY);
88 | gradientsFlattened.assign(0);
89 |
90 | applyDropOutIfNecessary(true, workspaceMgr);
91 |
92 | long nIn = layerConf().getNIn();
93 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0];
94 | IActivation a = layerConf().getActivationFn();
95 |
96 | if(input.shape()[0] != nIn)
97 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
98 |
99 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f');
100 |
101 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Q, W, b, a, workspaceMgr, true);
102 |
103 | final INDArray queries = q.reshape(nIn, 1, 1).broadcast(nIn, 1, examples);
104 | final INDArray queryG = workspaceMgr.create(ArrayType.BP_WORKING_MEM, queries.shape(), 'f');
105 |
106 | attentionMechanism.withGradientViews(Wg, Qg, bg, epsOut, epsOut, queryG)
107 | .backprop(epsilon, queries, input, input, maskArray);
108 |
109 | qg.assign(queryG.sum(2).transposei());
110 |
111 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f');
112 |
113 | weightNoiseParams.clear();
114 |
115 | Gradient g = new DefaultGradient(gradientsFlattened);
116 | g.gradientForVariable().put(SelfAttentionParamInitializer.WEIGHT_KEY, Wg);
117 | g.gradientForVariable().put(SelfAttentionParamInitializer.QUERY_WEIGHT_KEY, Qg);
118 | g.gradientForVariable().put(SelfAttentionParamInitializer.BIAS_KEY, bg);
119 | g.gradientForVariable().put(SelfAttentionParamInitializer.QUERY_KEY, qg);
120 |
121 | epsOut = backpropDropOutIfPresent(epsOut);
122 | return new Pair<>(g, epsOut);
123 | }
124 |
125 | @Override
126 | public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
127 | int minibatchSize) {
128 | // no masking is possible after this point... i.e., masks have been taken into account
129 | // as part of the self-attention
130 | this.maskArray = maskArray;
131 | this.maskState = null; //Not used - the mask is applied within the attention itself
132 |
133 | return null;
134 | }
135 | }
136 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/SelfAttentionParamInitializer.java:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * Copyright (c) 2015-2018 Skymind, Inc.
3 | *
4 | * This program and the accompanying materials are made available under the
5 | * terms of the Apache License, Version 2.0 which is available at
6 | * https://www.apache.org/licenses/LICENSE-2.0.
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11 | * License for the specific language governing permissions and limitations
12 | * under the License.
13 | *
14 | * SPDX-License-Identifier: Apache-2.0
15 | ******************************************************************************/
16 |
17 | package tech.dubs.dl4j.contrib.attention.nn.params;
18 |
19 | import org.deeplearning4j.nn.api.ParamInitializer;
20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
21 | import org.deeplearning4j.nn.conf.distribution.Distributions;
22 | import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
23 | import org.deeplearning4j.nn.conf.layers.Layer;
24 | import org.deeplearning4j.nn.params.DefaultParamInitializer;
25 | import org.deeplearning4j.nn.weights.WeightInitUtil;
26 | import org.nd4j.linalg.api.ndarray.INDArray;
27 | import org.nd4j.linalg.api.rng.distribution.Distribution;
28 |
29 | import java.util.*;
30 |
31 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
32 | import static org.nd4j.linalg.indexing.NDArrayIndex.point;
33 |
34 | /**
35 | * @author Paul Dubs
36 | */
37 | public class SelfAttentionParamInitializer implements ParamInitializer {
38 |
39 | private static final SelfAttentionParamInitializer INSTANCE = new SelfAttentionParamInitializer();
40 |
41 | public static SelfAttentionParamInitializer getInstance(){
42 | return INSTANCE;
43 | }
44 |
45 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY;
46 | public static final String QUERY_WEIGHT_KEY = "Q";
47 | public static final String QUERY_KEY = "q";
48 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY;
49 |
50 | private static final List<String> PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY, QUERY_KEY));
51 | private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY));
52 | private static final List<String> BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(BIAS_KEY, QUERY_KEY));
53 |
54 |
55 | @Override
56 | public long numParams(NeuralNetConfiguration conf) {
57 | return numParams(conf.getLayer());
58 | }
59 |
60 | @Override
61 | public long numParams(Layer layer) {
62 | FeedForwardLayer c = (FeedForwardLayer) layer;
63 | final long nIn = c.getNIn();
64 | final long nOut = c.getNOut();
65 |
66 | final long paramsW = nIn * nOut;
67 | final long paramsWq = nIn * nOut;
68 | final long paramsB = nOut;
69 | final long paramsQ = nIn;
70 | return paramsW + paramsWq + paramsB + paramsQ;
71 | }
72 |
73 | @Override
74 | public List<String> paramKeys(Layer layer) {
75 | return PARAM_KEYS;
76 | }
77 |
78 | @Override
79 | public List<String> weightKeys(Layer layer) {
80 | return WEIGHT_KEYS;
81 | }
82 |
83 | @Override
84 | public List<String> biasKeys(Layer layer) {
85 | return BIAS_KEYS;
86 | }
87 |
88 | @Override
89 | public boolean isWeightParam(Layer layer, String key) {
90 | return WEIGHT_KEYS.contains(key);
91 | }
92 |
93 | @Override
94 | public boolean isBiasParam(Layer layer, String key) {
95 | return BIAS_KEYS.contains(key);
96 | }
97 |
98 | @Override
99 | public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
100 | FeedForwardLayer c = (FeedForwardLayer) conf.getLayer();
101 | final long nIn = c.getNIn();
102 | final long nOut = c.getNOut();
103 |
104 | Map<String, INDArray> m;
105 |
106 | if (initializeParams) {
107 | Distribution dist = Distributions.createDistribution(c.getDist());
108 |
109 | m = getSubsets(paramsView, nIn, nOut, false);
110 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY));
111 | m.put(WEIGHT_KEY, w);
112 |
113 | INDArray rq = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist,'f', m.get(QUERY_WEIGHT_KEY));
114 | m.put(QUERY_WEIGHT_KEY, rq);
115 |
116 | INDArray q = WeightInitUtil.initWeights(nIn, 1, new long[]{nIn, 1}, c.getWeightInit(), dist,'f', m.get(QUERY_KEY));
117 | m.put(QUERY_KEY, q);
118 | } else {
119 | m = getSubsets(paramsView, nIn, nOut, true);
120 | }
121 |
122 | conf.addVariable(WEIGHT_KEY);
123 | conf.addVariable(QUERY_WEIGHT_KEY);
124 | conf.addVariable(BIAS_KEY);
125 | conf.addVariable(QUERY_KEY);
126 |
127 | return m;
128 | }
129 |
130 | @Override
131 | public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
132 | FeedForwardLayer c = (FeedForwardLayer) conf.getLayer();
133 | final long nIn = c.getNIn();
134 | final long nOut = c.getNOut();
135 |
136 | return getSubsets(gradientView, nIn, nOut, true);
137 | }
138 |
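// Layout of the flattened parameter/gradient view: [ W (nIn*nOut) | Q (nIn*nOut) | b (nOut) | q (nIn) ]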
139 | private static Map<String, INDArray> getSubsets(INDArray in, long nIn, long nOut, boolean reshape){
140 | final long endW = nIn * nOut;
141 | final long endWq = endW + nIn * nOut;
142 | final long endB = endWq + nOut;
143 | final long endQ = endB + nIn;
144 |
145 | INDArray w = in.get(point(0), interval(0, endW));
146 | INDArray wq = in.get(point(0), interval(endW, endWq));
147 | INDArray b = in.get(point(0), interval(endWq, endB));
148 | INDArray q = in.get(point(0), interval(endB, endQ));
149 |
150 | if (reshape) {
151 | w = w.reshape('f', nIn, nOut);
152 | wq = wq.reshape('f', nIn, nOut);
153 | b = b.reshape('f', 1, nOut);
154 | q = q.reshape('f', 1, nIn);
155 | }
156 |
157 | Map<String, INDArray> m = new LinkedHashMap<>();
158 | m.put(WEIGHT_KEY, w);
159 | m.put(QUERY_WEIGHT_KEY, wq);
160 | m.put(BIAS_KEY, b);
161 | m.put(QUERY_KEY, q);
162 | return m;
163 | }
164 | }
165 |
--------------------------------------------------------------------------------
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/params/RecurrentQueryAttentionParamInitializer.java:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * Copyright (c) 2015-2018 Skymind, Inc.
3 | *
4 | * This program and the accompanying materials are made available under the
5 | * terms of the Apache License, Version 2.0 which is available at
6 | * https://www.apache.org/licenses/LICENSE-2.0.
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11 | * License for the specific language governing permissions and limitations
12 | * under the License.
13 | *
14 | * SPDX-License-Identifier: Apache-2.0
15 | ******************************************************************************/
16 |
17 | package tech.dubs.dl4j.contrib.attention.nn.params;
18 |
19 | import org.deeplearning4j.nn.api.ParamInitializer;
20 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
21 | import org.deeplearning4j.nn.conf.distribution.Distributions;
22 | import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
23 | import org.deeplearning4j.nn.conf.layers.Layer;
24 | import org.deeplearning4j.nn.params.DefaultParamInitializer;
25 | import org.deeplearning4j.nn.weights.WeightInit;
26 | import org.deeplearning4j.nn.weights.WeightInitUtil;
27 | import org.nd4j.linalg.api.ndarray.INDArray;
28 | import org.nd4j.linalg.api.rng.distribution.Distribution;
29 |
30 | import java.util.*;
31 |
32 | import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
33 | import static org.nd4j.linalg.indexing.NDArrayIndex.point;
34 |
35 | /**
36 | *
37 | * TODO: Allow configuring different weight init distribution for every weight and bias type
38 | *
39 | * @author Paul Dubs
40 | */
41 | public class RecurrentQueryAttentionParamInitializer implements ParamInitializer {
42 |
43 | private static final RecurrentQueryAttentionParamInitializer INSTANCE = new RecurrentQueryAttentionParamInitializer();
44 |
45 | public static RecurrentQueryAttentionParamInitializer getInstance() {
46 | return INSTANCE;
47 | }
48 |
49 | public static final String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY;
50 | public static final String RECURRENT_WEIGHT_KEY = "WR";
51 | public static final String BIAS_KEY = DefaultParamInitializer.BIAS_KEY;
52 | public static final String QUERY_WEIGHT_KEY = "WQ";
53 | public static final String RECURRENT_QUERY_WEIGHT_KEY = "WQR";
54 | public static final String QUERY_BIAS_KEY = "bQ";
55 |
56 | private static final List<String> PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, BIAS_KEY, RECURRENT_WEIGHT_KEY, RECURRENT_QUERY_WEIGHT_KEY, QUERY_BIAS_KEY));
57 | private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, QUERY_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, RECURRENT_QUERY_WEIGHT_KEY));
58 | private static final List<String> BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(BIAS_KEY, QUERY_BIAS_KEY));
59 |
60 |
61 | @Override
62 | public long numParams(NeuralNetConfiguration conf) {
63 | return numParams(conf.getLayer());
64 | }
65 |
66 | @Override
67 | public long numParams(Layer layer) {
68 | BaseRecurrentLayer c = (BaseRecurrentLayer) layer;
69 | final long nIn = c.getNIn();
70 | final long nOut = c.getNOut();
71 |
72 | final long paramsW = nIn * nOut;
73 | final long paramsWR = nIn * nOut;
74 | final long paramsWq = nIn;
75 | final long paramsWqR = nOut;
76 | final long paramsB = nOut;
77 | final long paramsBq = 1;
78 | return paramsW + paramsWR + paramsWq + paramsWqR + paramsB + paramsBq;
79 | }
80 |
81 | @Override
82 | public List<String> paramKeys(Layer layer) {
83 | return PARAM_KEYS;
84 | }
85 |
86 | @Override
87 | public List<String> weightKeys(Layer layer) {
88 | return WEIGHT_KEYS;
89 | }
90 |
91 | @Override
92 | public List<String> biasKeys(Layer layer) {
93 | return BIAS_KEYS;
94 | }
95 |
96 | @Override
97 | public boolean isWeightParam(Layer layer, String key) {
98 | return WEIGHT_KEYS.contains(key);
99 | }
100 |
101 | @Override
102 | public boolean isBiasParam(Layer layer, String key) {
103 | return BIAS_KEYS.contains(key);
104 | }
105 |
106 | @Override
107 | public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
108 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer();
109 | final long nIn = c.getNIn();
110 | final long nOut = c.getNOut();
111 |
112 | Map<String, INDArray> m;
113 |
114 | if (initializeParams) {
115 | Distribution dist = Distributions.createDistribution(c.getDist());
116 |
117 | m = getSubsets(paramsView, nIn, nOut, false);
118 | INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY));
119 | m.put(WEIGHT_KEY, w);
120 | INDArray wq = WeightInitUtil.initWeights(nIn, 1, new long[]{nIn, 1}, c.getWeightInit(), dist, 'f', m.get(QUERY_WEIGHT_KEY));
121 | m.put(QUERY_WEIGHT_KEY, wq);
122 |
123 |
124 | WeightInit rwInit;
125 | Distribution rwDist = dist;
126 | if (c.getWeightInitRecurrent() == null) {
127 | rwInit = c.getWeightInit();
128 | } else {
129 | rwInit = c.getWeightInitRecurrent();
130 | if (c.getDistRecurrent() != null) {
131 | rwDist = Distributions.createDistribution(c.getDistRecurrent());
132 | }
133 | }
134 |
135 | INDArray rw = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, rwInit, rwDist, 'f', m.get(RECURRENT_WEIGHT_KEY));
136 | m.put(RECURRENT_WEIGHT_KEY, rw);
137 | INDArray wqr = WeightInitUtil.initWeights(nOut, 1, new long[]{nOut, 1}, rwInit, rwDist, 'f', m.get(RECURRENT_QUERY_WEIGHT_KEY));
138 | m.put(RECURRENT_QUERY_WEIGHT_KEY, wqr);
139 | } else {
140 | m = getSubsets(paramsView, nIn, nOut, true);
141 | }
142 |
143 | for (String paramKey : PARAM_KEYS) {
144 | conf.addVariable(paramKey);
145 | }
146 |
147 | return m;
148 | }
149 |
150 | @Override
151 | public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
152 | BaseRecurrentLayer c = (BaseRecurrentLayer) conf.getLayer();
153 | final long nIn = c.getNIn();
154 | final long nOut = c.getNOut();
155 |
156 | return getSubsets(gradientView, nIn, nOut, true);
157 | }
158 |
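// Layout of the flattened parameter/gradient view:
// [ W (nIn*nOut) | WQ (nIn) | WR (nIn*nOut) | WQR (nOut) | b (nOut) | bQ (1) ]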
159 | private static Map<String, INDArray> getSubsets(INDArray in, long nIn, long nOut, boolean reshape) {
160 | final long endW = nIn * nOut;
161 | final long endWq = endW + nIn;
162 | final long endWR = endWq + nIn * nOut;
163 | final long endWqR = endWR + nOut;
164 | final long endB = endWqR + nOut;
165 | final long endBq = endB + 1;
166 |
167 | INDArray w = in.get(point(0), interval(0, endW));
168 | INDArray wq = in.get(point(0), interval(endW, endWq));
169 | INDArray wr = in.get(point(0), interval(endWq, endWR));
170 | INDArray wqr = in.get(point(0), interval(endWR, endWqR));
171 | INDArray b = in.get(point(0), interval(endWqR, endB));
172 | INDArray bq = in.get(point(0), interval(endB, endBq));
173 |
174 | if (reshape) {
175 | w = w.reshape('f', nIn, nOut);
176 | wr = wr.reshape('f', nIn, nOut);
177 | wq = wq.reshape('f', nIn, 1);
178 | wqr = wqr.reshape('f', nOut, 1);
179 | b = b.reshape('f', 1, nOut);
180 | bq = bq.reshape('f', 1, 1);
181 | }
182 |
183 | Map<String, INDArray> m = new LinkedHashMap<>();
184 | m.put(WEIGHT_KEY, w);
185 | m.put(QUERY_WEIGHT_KEY, wq);
186 | m.put(RECURRENT_WEIGHT_KEY, wr);
187 | m.put(RECURRENT_QUERY_WEIGHT_KEY, wqr);
188 | m.put(BIAS_KEY, b);
189 | m.put(QUERY_BIAS_KEY, bq);
190 | return m;
191 | }
192 | }
193 |
--------------------------------------------------------------------------------
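
The flattened parameter view sliced by getSubsets(...) above packs the six parameter sets in the order W, Wq, Wr, Wqr, b, bq. As a sanity check, the following standalone sketch (hypothetical class name and example sizes, not part of the repository) reproduces the offset arithmetic for nIn = 3, nOut = 5 and shows that the final offset matches numParams(layer):

    public class ParamLayoutExample {
        public static void main(String[] args) {
            long nIn = 3, nOut = 5;
            long endW   = nIn * nOut;         // W   : [nIn, nOut] -> end offset 15
            long endWq  = endW  + nIn;        // Wq  : [nIn, 1]    -> end offset 18
            long endWR  = endWq + nIn * nOut; // Wr  : [nIn, nOut] -> end offset 33
            long endWqR = endWR + nOut;       // Wqr : [nOut, 1]   -> end offset 38
            long endB   = endWqR + nOut;      // b   : [1, nOut]   -> end offset 43
            long endBq  = endB  + 1;          // bq  : [1, 1]      -> end offset 44
            System.out.println("total params = " + endBq); // 44, same as numParams(layer) for nIn=3, nOut=5
        }
    }
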
/src/test/java/tech/dubs/dl4j/contrib/attention/GradientChecks.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention;
2 |
3 | import org.deeplearning4j.gradientcheck.GradientCheckUtil;
4 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
5 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
6 | import org.deeplearning4j.nn.conf.inputs.InputType;
7 | import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
8 | import org.deeplearning4j.nn.conf.layers.LSTM;
9 | import org.deeplearning4j.nn.conf.layers.OutputLayer;
10 | import org.deeplearning4j.nn.conf.layers.PoolingType;
11 | import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
12 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
13 | import org.deeplearning4j.nn.weights.WeightInit;
14 | import org.junit.Test;
15 | import org.nd4j.linalg.activations.Activation;
16 | import org.nd4j.linalg.api.buffer.DataBuffer;
17 | import org.nd4j.linalg.api.ndarray.INDArray;
18 | import org.nd4j.linalg.factory.Nd4j;
19 | import org.nd4j.linalg.learning.config.NoOp;
20 | import org.nd4j.linalg.lossfunctions.LossFunctions;
21 | import org.nd4j.linalg.profiler.OpProfiler;
22 | import tech.dubs.dl4j.contrib.attention.conf.RecurrentAttentionLayer;
23 | import tech.dubs.dl4j.contrib.attention.conf.SelfAttentionLayer;
24 | import tech.dubs.dl4j.contrib.attention.conf.TimestepAttentionLayer;
25 |
26 | import java.util.Random;
27 |
28 | import static org.junit.Assert.assertTrue;
29 |
30 | public class GradientChecks {
31 | private static final boolean PRINT_RESULTS = false;
32 | private static final boolean RETURN_ON_FIRST_FAILURE = false;
33 | private static final double DEFAULT_EPS = 1e-6;
34 | private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
35 | private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
36 |
37 | static {
38 | Nd4j.setDataType(DataBuffer.Type.DOUBLE);
39 | }
40 |
41 | @Test
42 | public void testSelfAttentionLayer() {
43 | int nIn = 3;
44 | int nOut = 5;
45 | int tsLength = 4;
46 | int layerSize = 8;
47 | int attentionHeads = 2;
48 |
49 | Random r = new Random(12345);
50 | for (int mb : new int[]{1, 2, 3}) {
51 | for (boolean inputMask : new boolean[]{false, true}) {
52 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
53 | INDArray labels = Nd4j.create(mb, nOut);
54 | for (int i = 0; i < mb; i++) {
55 | labels.putScalar(i, r.nextInt(nOut), 1.0);
56 | }
57 | String maskType = (inputMask ? "inputMask" : "none");
58 |
59 | INDArray inMask = null;
60 | if (inputMask) {
61 | inMask = Nd4j.ones(mb, tsLength);
62 | for (int i = 0; i < mb; i++) {
63 | int firstMaskedStep = tsLength - 1 - i; // mask a different number of trailing steps per example
64 | if (firstMaskedStep == 0) { // but never mask the entire sequence
65 | firstMaskedStep = tsLength;
66 | }
67 | for (int j = firstMaskedStep; j < tsLength; j++) {
68 | inMask.putScalar(i, j, 0.0);
69 | }
70 | }
71 | }
72 |
73 | String name = "testSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
74 | System.out.println("Starting test: " + name);
75 |
76 |
77 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
78 | .activation(Activation.TANH)
79 | .updater(new NoOp())
80 | .weightInit(WeightInit.XAVIER)
81 | .list()
82 | .layer(new LSTM.Builder().nOut(layerSize).build())
83 | .layer(new SelfAttentionLayer.Builder().nOut(attentionHeads).build())
84 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
85 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
86 | .setInputType(InputType.recurrent(nIn))
87 | .build();
88 |
89 | MultiLayerNetwork net = new MultiLayerNetwork(conf);
90 | net.init();
91 |
92 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
93 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE,in, labels, inMask, null, false, -1,
94 | null //Sets.newHashSet( /*"1_b", "1_W",* "1_WR", "1_WQR", "1_WQ", "1_bQ",*/ "2_b", "2_W" ,"0_W", "0_RW", "0_b"/**/)
95 | );
96 | assertTrue(name, gradOK);
97 | }
98 | }
99 |
100 | OpProfiler.getInstance().printOutDashboard();
101 | }
102 |
103 | @Test
104 | public void testTimestepAttentionLayer() {
105 | int nIn = 3;
106 | int nOut = 5;
107 | int tsLength = 4;
108 | int layerSize = 8;
109 | int attentionHeads = 3;
110 |
111 |
112 | Random r = new Random(12345);
113 | for (int mb : new int[]{1, 2, 3}) {
114 | for (boolean inputMask : new boolean[]{false, true}) {
115 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
116 | INDArray labels = Nd4j.create(mb, nOut);
117 | for (int i = 0; i < mb; i++) {
118 | labels.putScalar(i, r.nextInt(nOut), 1.0);
119 | }
120 | String maskType = (inputMask ? "inputMask" : "none");
121 |
122 | INDArray inMask = null;
123 | if (inputMask) {
124 | inMask = Nd4j.ones(mb, tsLength);
125 | for (int i = 0; i < mb; i++) {
126 | int firstMaskedStep = tsLength - 1 - i;
127 | if (firstMaskedStep == 0) {
128 | firstMaskedStep = tsLength;
129 | }
130 | for (int j = firstMaskedStep; j < tsLength; j++) {
131 | inMask.putScalar(i, j, 0.0);
132 | }
133 | }
134 | }
135 |
136 | String name = "testTimestepAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
137 | System.out.println("Starting test: " + name);
138 |
139 |
140 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
141 | .activation(Activation.TANH)
142 | .updater(new NoOp())
143 | .weightInit(WeightInit.XAVIER)
144 | .list()
145 | .layer(new LSTM.Builder().nOut(layerSize).build())
146 | .layer(new TimestepAttentionLayer.Builder().nOut(attentionHeads).build())
147 | .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
148 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
149 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
150 | .setInputType(InputType.recurrent(nIn))
151 | .build();
152 |
153 | MultiLayerNetwork net = new MultiLayerNetwork(conf);
154 | net.init();
155 |
156 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
157 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null);
158 | assertTrue(name, gradOK);
159 | }
160 | }
161 | }
162 |
163 | @Test
164 | public void testRecurrentAttentionLayer() {
165 | int nIn = 3;
166 | int nOut = 5;
167 | int tsLength = 4;
168 | int layerSize = 8;
169 | int attentionHeads = 7;
170 |
171 |
172 | Random r = new Random(12345);
173 | for (int mb : new int[]{3, 2, 1}) {
174 | for (boolean inputMask : new boolean[]{true, false}) {
175 | INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
176 | INDArray labels = Nd4j.create(mb, nOut);
177 | for (int i = 0; i < mb; i++) {
178 | labels.putScalar(i, r.nextInt(nOut), 1.0);
179 | }
180 | String maskType = (inputMask ? "inputMask" : "none");
181 |
182 | INDArray inMask = null;
183 | if (inputMask) {
184 | inMask = Nd4j.ones(mb, tsLength);
185 | for (int i = 0; i < mb; i++) {
186 | int firstMaskedStep = tsLength - 1 - i;
187 | if (firstMaskedStep == 0) {
188 | firstMaskedStep = tsLength;
189 | }
190 | for (int j = firstMaskedStep; j < tsLength; j++) {
191 | inMask.putScalar(i, j, 0.0);
192 | }
193 | }
194 | }
195 |
196 | String name = "testRecurrentAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
197 | System.out.println("Starting test: " + name);
198 |
199 |
200 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
201 | .activation(Activation.TANH)
202 | .updater(new NoOp())
203 | .weightInit(WeightInit.XAVIER)
204 | .list()
205 | .layer(new LSTM.Builder().nOut(layerSize).build())
206 | .layer(new LastTimeStep(new RecurrentAttentionLayer.Builder().nOut(attentionHeads).build()))
207 | .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
208 | .lossFunction(LossFunctions.LossFunction.MCXENT).build())
209 | .setInputType(InputType.recurrent(nIn))
210 | .build();
211 |
212 | MultiLayerNetwork net = new MultiLayerNetwork(conf);
213 | net.init();
214 |
215 | //System.out.println("Original");
216 | boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
217 | DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, false, -1, null
218 | //Sets.newHashSet( /*"1_b", "1_W",* "1_WR", "1_WQR", "1_WQ", "1_bQ",*/ "2_b", "2_W" ,"0_W", "0_RW", "0_b"/**/)
219 | );
220 | assertTrue(name, gradOK);
221 | }
222 | }
223 | }
224 | }
--------------------------------------------------------------------------------
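
All three tests above build their masks the same way: a [miniBatch, timeSteps] array in which 1.0 marks a real time step and 0.0 a padded (masked) one, with a different number of trailing steps zeroed per example. A minimal standalone sketch of that convention (hypothetical class name and arbitrary example values, not part of the test suite):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MaskSketch {
        public static void main(String[] args) {
            int mb = 3, tsLength = 4;
            INDArray mask = Nd4j.ones(mb, tsLength); // [miniBatch, timeSteps], all steps valid
            mask.putScalar(1, 3, 0.0);               // example 1: last step is padding
            mask.putScalar(2, 2, 0.0);               // example 2: last two steps are padding
            mask.putScalar(2, 3, 0.0);
            System.out.println(mask);
        }
    }

In the gradient checks this array is passed to GradientCheckUtil.checkGradients(...) as the input mask; at inference time the equivalent would be supplying it as the feature mask when evaluating the network.
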
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/RecurrentAttentionLayer.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.nn;
2 |
3 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
4 | import org.deeplearning4j.nn.gradient.DefaultGradient;
5 | import org.deeplearning4j.nn.gradient.Gradient;
6 | import org.deeplearning4j.nn.layers.BaseLayer;
7 | import org.deeplearning4j.nn.workspace.ArrayType;
8 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
9 | import org.nd4j.base.Preconditions;
10 | import org.nd4j.linalg.activations.IActivation;
11 | import org.nd4j.linalg.activations.impl.ActivationSoftmax;
12 | import org.nd4j.linalg.api.ndarray.INDArray;
13 | import org.nd4j.linalg.api.shape.Shape;
14 | import org.nd4j.linalg.factory.Nd4j;
15 | import org.nd4j.linalg.primitives.Pair;
16 | import tech.dubs.dl4j.contrib.attention.nn.params.RecurrentQueryAttentionParamInitializer;
17 |
18 | import static org.nd4j.linalg.indexing.NDArrayIndex.all;
19 | import static org.nd4j.linalg.indexing.NDArrayIndex.point;
20 |
21 | /**
22 | * Recurrent Attention Layer Implementation
23 | * At each time step the output combines a dense transform of the current input with a
24 | * recurrent term: additive attention over the whole input sequence, queried by the
25 | * previous time step's activation (see AdditiveAttentionMechanism).
26 | * TODO:
27 | * - Optionally keep attention weights around for inspection
28 | * - Handle Masking
29 | *
30 | * @author Paul Dubs
31 | */
32 | public class RecurrentAttentionLayer extends BaseLayer<tech.dubs.dl4j.contrib.attention.conf.RecurrentAttentionLayer> {
33 | private IActivation softmax = new ActivationSoftmax();
34 |
35 | public RecurrentAttentionLayer(NeuralNetConfiguration conf) {
36 | super(conf);
37 | }
38 |
39 | @Override
40 | public boolean isPretrainLayer() {
41 | return false;
42 | }
43 |
44 | @Override
45 | public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
46 | assertInputSet(false);
47 | Preconditions.checkState(input.rank() == 3,
48 | "3D input expected to RNN layer expected, got " + input.rank());
49 |
50 | applyDropOutIfNecessary(training, workspaceMgr);
51 |
52 | INDArray W = getParamWithNoise(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, training, workspaceMgr);
53 | INDArray Wr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr);
54 | INDArray Wq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, training, workspaceMgr);
55 | INDArray Wqr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, training, workspaceMgr);
56 | INDArray b = getParamWithNoise(RecurrentQueryAttentionParamInitializer.BIAS_KEY, training, workspaceMgr);
57 | INDArray bq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, training, workspaceMgr);
58 |
59 | long examples = input.size(0);
60 | long tsLength = input.size(2);
61 | long nIn = layerConf().getNIn();
62 | long nOut = layerConf().getNOut();
63 | IActivation a = layerConf().getActivationFn();
64 |
65 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nOut, tsLength}, 'f');
66 |
67 | if(input.shape()[0] != nIn) // permute the input once into [nIn, tsLength, examples] order for the attention mechanism
68 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
69 |
70 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Wqr, Wq, bq, a, workspaceMgr, training).useCaching();
71 |
72 |
73 | // pre-compute non-recurrent part
74 | activations.assign(
75 | Nd4j.gemm(W, input.reshape('f', nIn, tsLength * examples), true, false)
76 | .addiColumnVector(b.transpose())
77 | .reshape('f', nOut, tsLength, examples).permute(2, 0, 1)
78 | );
79 |
80 |
81 | for (long timestep = 0; timestep < tsLength; timestep++) {
82 | final INDArray curOut = timestepArray(activations, timestep);
83 | if(timestep > 0){
84 | final INDArray prevActivation = timestepArray(activations, timestep - 1);
85 | final INDArray queries = Nd4j.expandDims(prevActivation, 2).permute(1,2,0);
86 | final INDArray attention = attentionMechanism.query(queries, input, input, maskArray);
87 | curOut.addi(Nd4j.squeeze(attention, 2).mmul(Wr));
88 | }
89 |
90 | a.getActivation(curOut, true);
91 | }
92 |
93 |
94 | return activations;
95 | }
96 |
97 |
98 |
99 | /*
100 | * Notice that the epsilon given here does not contain the recurrent component, which will have to be calculated
101 | * manually.
102 | */
103 | @Override
104 | public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
105 | assertInputSet(true);
106 | if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon))
107 | epsilon = epsilon.dup('f');
108 |
109 | INDArray W = getParamWithNoise(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
110 | INDArray Wr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, true, workspaceMgr);
111 | INDArray Wq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, true, workspaceMgr);
112 | INDArray Wqr = getParamWithNoise(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, true, workspaceMgr);
113 | INDArray b = getParamWithNoise(RecurrentQueryAttentionParamInitializer.BIAS_KEY, true, workspaceMgr);
114 | INDArray bq = getParamWithNoise(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, true, workspaceMgr);
115 |
116 | INDArray Wg = gradientViews.get(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY);
117 | INDArray Wrg = gradientViews.get(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY);
118 | INDArray Wqg = gradientViews.get(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY);
119 | INDArray Wqrg = gradientViews.get(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY);
120 | INDArray bg = gradientViews.get(RecurrentQueryAttentionParamInitializer.BIAS_KEY);
121 | INDArray bqg = gradientViews.get(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY);
122 | gradientsFlattened.assign(0);
123 |
124 | applyDropOutIfNecessary(true, workspaceMgr);
125 |
126 |
127 | long nIn = layerConf().getNIn();
128 | long nOut = layerConf().getNOut();
129 | long examples = input.shape()[0] == nIn ? input.shape()[2] : input.shape()[0];
130 | long tsLength = input.shape()[0] == nIn ? input.shape()[1] : input.shape()[2];
131 | IActivation a = layerConf().getActivationFn();
132 |
133 | INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{examples, nOut, tsLength}, 'f');
134 | INDArray preOut = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{examples, nOut, tsLength}, 'f');
135 | INDArray attentions = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{examples, nIn, tsLength}, 'f');
136 | INDArray queryG = workspaceMgr.create(ArrayType.BP_WORKING_MEM, new long[]{nOut, 1, examples}, 'f');
137 |
138 | if(input.shape()[0] != nIn) // permute the input once into [nIn, tsLength, examples] order, as in activate(...)
139 | input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input.permute(1, 2, 0), 'f');
140 |
141 | INDArray epsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.shape(), 'f');
142 | epsOut.assign(0);
143 |
144 |
145 | final AdditiveAttentionMechanism attentionMechanism = new AdditiveAttentionMechanism(Wqr, Wq, bq, a, workspaceMgr, true).useCaching();
146 |
147 | // pre-compute non-recurrent part
148 | activations.assign(
149 | Nd4j.gemm(W, input.reshape('f', nIn, tsLength * examples), true, false)
150 | .addiColumnVector(b.transpose())
151 | .reshape('f', nOut, tsLength, examples).permute(2, 0, 1)
152 | );
153 |
154 |
155 | for (long timestep = 0; timestep < tsLength; timestep++) {
156 | final INDArray curOut = timestepArray(activations, timestep);
157 |
158 | if(timestep > 0){
159 | final INDArray prevActivation = timestepArray(activations, timestep - 1);
160 | final INDArray query = Nd4j.expandDims(prevActivation, 2).permute(1, 2, 0);
161 | final INDArray attention = Nd4j.squeeze(attentionMechanism.query(query, input, input, maskArray), 2);
162 | timestepArray(attentions, timestep).assign(attention);
163 |
164 | curOut.addi(attention.mmul(Wr));
165 | }
166 | timestepArray(preOut, timestep).assign(curOut);
167 | a.getActivation(curOut, true);
168 | }
169 |
170 |
171 | for (long timestep = tsLength - 1; timestep >= 0; timestep--) {
172 | final INDArray curEps = timestepArray(epsilon, timestep);
173 | final INDArray curPreOut = timestepArray(preOut, timestep);
174 | final INDArray curIn = input.get(all(), point(timestep), all());
175 |
176 | final INDArray dldz = a.backprop(curPreOut, curEps).getFirst();
177 | Wg.addi(Nd4j.gemm(curIn, dldz, false, false));
178 | bg.addi(dldz.sum(0));
179 | epsOut.tensorAlongDimension((int)timestep, 0, 2).addi(Nd4j.gemm(dldz, W, false, true).transposei());
180 |
181 | if(timestep > 0){
182 | final INDArray curAttn = timestepArray(attentions, timestep);
183 |
184 | Wrg.addi(Nd4j.gemm(curAttn, dldz, true, false));
185 |
186 | final INDArray prevEps = timestepArray(epsilon, timestep - 1);
187 | final INDArray prevActivation = timestepArray(activations, timestep - 1);
188 | final INDArray query = Nd4j.expandDims(prevActivation, 2).permute(1,2,0);
189 | queryG.assign(0);
190 |
191 | final INDArray dldAtt = Nd4j.gemm(dldz, Wr, false, true);
192 | attentionMechanism
193 | .withGradientViews(Wqg, Wqrg, bqg, epsOut, epsOut, queryG)
194 | .backprop(dldAtt, query, input, input, maskArray);
195 |
196 | prevEps.addi(Nd4j.squeeze(queryG, 1).transpose());
197 | }
198 | }
199 |
200 | weightNoiseParams.clear();
201 |
202 | Gradient g = new DefaultGradient(gradientsFlattened);
203 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.WEIGHT_KEY, Wg);
204 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.QUERY_WEIGHT_KEY, Wqg);
205 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.RECURRENT_QUERY_WEIGHT_KEY, Wqrg);
206 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.RECURRENT_WEIGHT_KEY, Wrg);
207 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.BIAS_KEY, bg);
208 | g.gradientForVariable().put(RecurrentQueryAttentionParamInitializer.QUERY_BIAS_KEY, bqg);
209 |
210 | epsOut = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsOut.permute(2, 0, 1), 'f');
211 | epsOut = backpropDropOutIfPresent(epsOut);
212 |
213 | return new Pair<>(g, epsOut);
214 | }
215 |
216 | private INDArray subArray(INDArray in, int example, int timestep){
217 | return in.tensorAlongDimension(example, 1, 2).tensorAlongDimension(timestep, 0);
218 | }
219 |
220 | private INDArray timestepArray(INDArray in, long timestep){
221 | return in.tensorAlongDimension((int) timestep, 0, 1);
222 | }
223 | }
224 |
--------------------------------------------------------------------------------
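
Reading activate(...) above, the recurrence implemented by this layer can be summarised as follows (my notation, not an official formula: x_t is the input at time step t, y_t the layer output, X the full input sequence used as both keys and values, and a(.) the layer's activation function):

    y_0 = a( W^T x_0 + b )
    y_t = a( W^T x_t + b + Wr^T attn(y_{t-1}, X, X) )    for t > 0

where attn(q, K, V) is the additive attention of AdditiveAttentionMechanism (next file) with query weights Wqr, key weights Wq and bias bq:

    e_s    = a( Wq^T x_s + Wqr^T q + bq )                for each time step s
    alpha  = maskedSoftmax(e)   (softmax over the time steps, respecting the mask)
    attn   = sum_s alpha_s * x_s

The non-recurrent part W^T x_t + b is precomputed for all time steps in a single gemm; only the attention term has to be evaluated step by step.
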
/src/main/java/tech/dubs/dl4j/contrib/attention/nn/AdditiveAttentionMechanism.java:
--------------------------------------------------------------------------------
1 | package tech.dubs.dl4j.contrib.attention.nn;
2 |
3 | import org.deeplearning4j.nn.workspace.ArrayType;
4 | import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
5 | import org.nd4j.linalg.activations.IActivation;
6 | import org.nd4j.linalg.api.memory.MemoryWorkspace;
7 | import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
8 | import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
9 | import org.nd4j.linalg.api.memory.enums.LearningPolicy;
10 | import org.nd4j.linalg.api.ndarray.INDArray;
11 | import org.nd4j.linalg.api.shape.Shape;
12 | import org.nd4j.linalg.factory.Nd4j;
13 | import tech.dubs.dl4j.contrib.attention.activations.ActivationMaskedSoftmax;
14 |
15 | import java.util.Arrays;
16 |
17 | import static org.nd4j.linalg.indexing.NDArrayIndex.all;
18 | import static org.nd4j.linalg.indexing.NDArrayIndex.point;
19 |
20 | /**
21 | * Attention: Shapes for keys, values and queries should be in [features, timesteps, examples] order!
22 | * @author Paul Dubs
23 | */
24 | public class AdditiveAttentionMechanism {
25 | private final INDArray W;
26 | private final INDArray Q;
27 | private final INDArray b;
28 | private final IActivation activation;
29 | private final ActivationMaskedSoftmax softmax;
30 | private final LayerWorkspaceMgr mgr;
31 | private final boolean training;
32 | private boolean caching;
33 | private INDArray WkCache;
34 |
35 | // Required to be set for backprop
36 | private INDArray Wg;
37 | private INDArray Qg;
38 | private INDArray bg;
39 | private INDArray keyG;
40 | private INDArray valueG;
41 | private INDArray queryG;
42 |
43 | public AdditiveAttentionMechanism(INDArray queryWeight, INDArray keyWeight, INDArray bias, IActivation activation, LayerWorkspaceMgr mgr, boolean training) {
44 | assertWeightShapes(queryWeight, keyWeight, bias);
45 | Q = queryWeight;
46 | W = keyWeight;
47 | b = bias;
48 | this.activation = activation;
49 | softmax = new ActivationMaskedSoftmax();
50 | this.mgr = mgr;
51 | this.training = training;
52 |
53 | this.caching = false;
54 | }
55 |
56 | public AdditiveAttentionMechanism useCaching() {
57 | this.caching = true;
58 | return this;
59 | }
60 |
61 | public INDArray query(INDArray queries, INDArray keys, INDArray values, INDArray mask) {
62 | assertShapes(queries, keys, values);
63 |
64 | final long examples = queries.shape()[2];
65 | final long queryCount = queries.shape()[1];
66 | final long queryWidth = queries.shape()[0];
67 | final long attentionHeads = W.shape()[1];
68 | final long memoryWidth = W.shape()[0];
69 | final long tsLength = keys.shape()[1];
70 |
71 | final INDArray result = mgr.createUninitialized(ArrayType.FF_WORKING_MEM, new long[]{examples, memoryWidth * attentionHeads, queryCount}, 'f');
72 |
73 | WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder()
74 | .policyAllocation(AllocationPolicy.STRICT)
75 | .policyLearning(LearningPolicy.FIRST_LOOP)
76 | .build();
77 |
78 | if (this.caching && this.WkCache == null) {
79 | final INDArray target = mgr.createUninitialized(ArrayType.FF_WORKING_MEM, new long[]{attentionHeads, tsLength * examples}, 'f');
80 |
81 | this.WkCache = Nd4j.gemm(W, keys.reshape('f', memoryWidth, tsLength * examples), target, true, false, 1.0, 0.0)
82 | .addiColumnVector(b.transpose())
83 | .reshape('f', attentionHeads, tsLength, examples);
84 | }
85 |
86 | final INDArray queryRes = Nd4j.gemm(Q, queries.reshape('f', queryWidth, queryCount * examples), true, false)
87 | .reshape('f', attentionHeads, queryCount, examples);
88 |
89 | for (long example = 0; example < examples; example++) {
90 | try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "ATTENTION_FF")) {
91 | final INDArray curValues = values.get(all(), all(), point(example));
92 | final INDArray curKeys = keys.get(all(), all(), point(example));
93 |
94 | final INDArray preResult;
95 | if (this.caching) {
96 | preResult = this.WkCache.get(all(), all(), point(example));
97 | } else {
98 | preResult = Nd4j.gemm(W, curKeys, true, false);
99 | preResult.addiColumnVector(b.transpose());
100 | }
101 |
102 |
103 | final INDArray curMask = subMask(mask, example);
104 | final INDArray attentionHeadMask = attentionHeadMask(curMask, preResult.shape());
105 |
106 | for (long queryIdx = 0; queryIdx < queryCount; queryIdx++) {
107 | final INDArray query = queries.get(all(), point(queryIdx), point(example));
108 | final INDArray curResult = subArray(result, example, queryIdx);
109 |
110 | final INDArray queryResult = queryRes.get(all(), point(queryIdx), point(example));
111 |
112 | final INDArray preA = preResult.addColumnVector(queryResult);
113 | final INDArray preS = this.activation.getActivation(preA, training);
114 | final INDArray attW = softmax.getActivation(preS, attentionHeadMask);
115 |
116 | final INDArray att = Nd4j.gemm(curValues, attW, false, true);
117 | curResult.assign(att.reshape('f', 1, memoryWidth * attentionHeads));
118 | }
119 | }
120 | }
121 | return result;
122 | }
123 |
124 | public AdditiveAttentionMechanism withGradientViews(INDArray W, INDArray Q, INDArray b, INDArray keys, INDArray values, INDArray queries) {
125 | Wg = W;
126 | Qg = Q;
127 | bg = b;
128 | keyG = keys;
129 | valueG = values;
130 | queryG = queries;
131 |
132 | return this;
133 | }
134 |
135 | public void backprop(INDArray epsilon, INDArray queries, INDArray keys, INDArray values, INDArray mask) {
136 | if (Wg == null || Qg == null || bg == null || keyG == null || valueG == null || queryG == null) {
137 | throw new IllegalStateException("Gradient views must be set before backprop; call withGradientViews(...) first.");
138 | }
139 |
140 | assertShapes(queries, keys, values);
141 |
142 | final long examples = queries.shape()[2];
143 | final long queryCount = queries.shape()[1];
144 | final long queryWidth = queries.shape()[0];
145 | final long attentionHeads = W.shape()[1];
146 | final long memoryWidth = W.shape()[0];
147 | final long tsLength = keys.shape()[1];
148 |
149 |
150 | if (epsilon.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilon))
151 | epsilon = epsilon.dup('c');
152 |
153 | final long[] epsilonShape = epsilon.shape();
154 | if (epsilonShape[0] != examples || epsilonShape[1] != (attentionHeads * memoryWidth) || (epsilonShape.length == 2 && queryCount != 1) || (epsilonShape.length == 3 && epsilonShape[2] != queryCount)) {
155 | throw new IllegalStateException("Epsilon shape must match result shape. Got epsilon.shape() = " + Arrays.toString(epsilonShape)
156 | + "; result shape = [" + examples + ", " + attentionHeads * memoryWidth + ", " + queryCount + "]");
157 | }
158 |
159 | WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder()
160 | .policyAllocation(AllocationPolicy.STRICT)
161 | .policyLearning(LearningPolicy.FIRST_LOOP)
162 | .build();
163 |
164 | final INDArray dldAtt = epsilon.reshape('c', examples, attentionHeads, memoryWidth, queryCount);
165 |
166 | if (this.caching && this.WkCache == null) {
167 | final INDArray target = mgr.createUninitialized(ArrayType.BP_WORKING_MEM, new long[]{attentionHeads, tsLength * examples}, 'f');
168 |
169 | this.WkCache = Nd4j.gemm(W, keys.reshape('f', memoryWidth, tsLength * examples), target, true, false, 1.0, 0.0)
170 | .addiColumnVector(b.transpose())
171 | .reshape('f', attentionHeads, tsLength, examples);
172 | }
173 |
174 | final INDArray queryRes = Nd4j.gemm(Q, queries.reshape('f', queryWidth, queryCount * examples), true, false)
175 | .reshape('f', attentionHeads, queryCount, examples);
176 |
177 | for (long example = 0; example < examples; example++) {
178 | try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "ATTENTION_BP")) {
179 | final INDArray curValues = values.get(all(), all(), point(example));
180 | final INDArray curKeys = keys.get(all(), all(), point(example));
181 |
182 | final INDArray exEps = dldAtt.tensorAlongDimension((int) example, 1, 2, 3);
183 |
184 | final INDArray preResult;
185 | if (this.caching) {
186 | preResult = this.WkCache.get(all(), all(), point(example));
187 | } else {
188 | preResult = Nd4j.gemm(W, curKeys, true, false);
189 | preResult.addiColumnVector(b.transpose());
190 | }
191 |
192 | final INDArray curMask = subMask(mask, example);
193 | final INDArray attentionHeadMask = attentionHeadMask(curMask, preResult.shape());
194 |
195 | for (long queryIdx = 0; queryIdx < queryCount; queryIdx++) {
196 | final INDArray curEps = exEps.tensorAlongDimension((int) queryIdx, 0, 1);
197 | final INDArray query = queries.get(all(), point(queryIdx), point(example));
198 |
199 | final INDArray queryResult = queryRes.get(all(), point(queryIdx), point(example));
200 |
201 | final INDArray preA = preResult.addColumnVector(queryResult);
202 | final INDArray preS = this.activation.getActivation(preA.dup(), training);
203 | final INDArray attW = softmax.getActivation(preS, attentionHeadMask);
204 |
205 | valueG.get(all(), all(), point(example)).addi(Nd4j.gemm(curEps, attW, true, false));
206 |
207 | final INDArray dldAttW = Nd4j.gemm(curEps, curValues, false, false);
208 | final INDArray dldPreS = softmax.backprop(attW, attentionHeadMask, dldAttW).getFirst();
209 | final INDArray dldPreA = activation.backprop(preA, dldPreS).getFirst();
210 |
211 | final INDArray dldPreASum = dldPreA.sum(1);
212 |
213 | Nd4j.gemm(query, dldPreASum, Qg, false, true, 1.0, 1.0);
214 | Nd4j.gemm(curKeys, dldPreA, Wg, false, true, 1.0, 1.0);
215 |
216 | bg.addi(dldPreASum.transpose());
217 |
218 | keyG.get(all(), all(), point(example)).addi(Nd4j.gemm(W, dldPreA, false, false));
219 | queryG.get(all(), point(queryIdx), point(example)).addi(Nd4j.gemm(Q, dldPreASum, false, false));
220 | }
221 | }
222 | }
223 | }
224 |
225 | private void assertWeightShapes(INDArray queryWeight, INDArray keyWeight, INDArray bias) {
226 | final long qOut = queryWeight.shape()[1];
227 | final long kOut = keyWeight.shape()[1];
228 | final long bOut = bias.shape()[1];
229 | if (qOut != kOut || qOut != bOut) {
230 | throw new IllegalStateException("Shapes must be compatible: queryWeight.shape() = " + Arrays.toString(queryWeight.shape())
231 | + ", keyWeight.shape() = " + Arrays.toString(keyWeight.shape())
232 | + ", bias.shape() = " + Arrays.toString(bias.shape())
233 | + "\n Compatible shapes should have the same second dimension, but got: [" + qOut + ", " + kOut + ", " + bOut + "]"
234 | );
235 | }
236 | }
237 |
238 | private void assertShapes(INDArray query, INDArray keys, INDArray values) {
239 | final long kIn = W.shape()[0];
240 | final long qIn = Q.shape()[0];
241 |
242 | if (query.shape()[0] != qIn || keys.shape()[0] != kIn) {
243 | throw new IllegalStateException("Shapes of query and keys must be compatible to weights, but got: queryWeight.shape() = " + Arrays.toString(Q.shape())
244 | + ", queries.shape() = " + Arrays.toString(query.shape())
245 | + "; keyWeight.shape() = " + Arrays.toString(W.shape())
246 | + ", keys.shape() = " + Arrays.toString(keys.shape())
247 | );
248 | }
249 |
250 | if (keys.shape()[1] != values.shape()[1]) {
251 | throw new IllegalStateException("Keys must be the same length as values! But got keys.shape() = " + Arrays.toString(keys.shape())
252 | + ", values.shape = " + Arrays.toString(values.shape()));
253 | }
254 |
255 | if (keys.shape()[2] != values.shape()[2] || query.shape()[2] != keys.shape()[2]) {
256 | throw new IllegalStateException("Queries, Keys and Values must have same mini-batch size! But got keys.shape() = " + Arrays.toString(keys.shape())
257 | + ", values.shape = " + Arrays.toString(values.shape())
258 | + ", queries.shape = " + Arrays.toString(query.shape())
259 | );
260 | }
261 | }
262 |
263 |
264 | private INDArray subArray(INDArray in, long example) {
265 | return in.tensorAlongDimension((int) example, 1, 2);
266 | }
267 |
268 | private INDArray subArray(INDArray in, long example, long timestep) {
269 | return subArray(in, example).tensorAlongDimension((int) timestep, 0);
270 | }
271 |
272 | private INDArray subMask(INDArray mask, long example) {
273 | if (mask == null) {
274 | return null;
275 | } else {
276 | return mask.tensorAlongDimension((int) example, 1);
277 | }
278 | }
279 |
280 | private INDArray attentionHeadMask(INDArray mask, long[] shape) {
281 | if (mask == null) {
282 | return null;
283 | } else {
284 | return mask.broadcast(shape);
285 | }
286 | }
287 | }
288 |
--------------------------------------------------------------------------------
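
For reference, a standalone sketch of how the mechanism above can be driven directly, assuming the standard DL4J/ND4J classes used elsewhere in this repository are on the classpath (the class name and all sizes are hypothetical; it only relies on the constructor, useCaching() and query(...) shown in the file):

    import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
    import org.nd4j.linalg.activations.impl.ActivationTanH;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import tech.dubs.dl4j.contrib.attention.nn.AdditiveAttentionMechanism;

    public class AdditiveAttentionSketch {
        public static void main(String[] args) {
            int features = 8, tsLength = 4, examples = 2, heads = 3, queryCount = 1;

            INDArray keyWeight   = Nd4j.rand(features, heads); // W : [features, heads]
            INDArray queryWeight = Nd4j.rand(features, heads); // Q : [queryWidth, heads]
            INDArray bias        = Nd4j.rand(1, heads);        // b : [1, heads]

            // [features, timesteps, examples] order, as required by the mechanism
            INDArray keys    = Nd4j.rand(new int[]{features, tsLength, examples});
            INDArray values  = keys; // self-attention style: values == keys
            INDArray queries = Nd4j.rand(new int[]{features, queryCount, examples});

            AdditiveAttentionMechanism attention = new AdditiveAttentionMechanism(
                    queryWeight, keyWeight, bias, new ActivationTanH(),
                    LayerWorkspaceMgr.noWorkspaces(), false).useCaching();

            INDArray out = attention.query(queries, keys, values, /* mask = */ null);
            // out has shape [examples, features * heads, queryCount]
            System.out.println(java.util.Arrays.toString(out.shape()));
        }
    }
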