├── .gitignore
├── LICENSE
├── README.md
├── pom.xml
└── src
    ├── main
    │   └── java
    │       └── com
    │           └── simiacryptus
    │               └── text
    │                   ├── GraphModifier.java
    │                   ├── LanguageCodeModel.java
    │                   ├── MinEntropyWrapper.java
    │                   ├── ModelWrapper.java
    │                   ├── SimpleModel.java
    │                   ├── SumModel.java
    │                   ├── TemperatureWrapper.java
    │                   ├── TextGenerator.java
    │                   ├── TopNWrapper.java
    │                   └── gpt2
    │                       ├── GPT2Codec.java
    │                       ├── GPT2Edit_345M.java
    │                       ├── GPT2Model.java
    │                       └── GPT2Util.java
    ├── site
    │   └── site.xml
    └── test
        ├── java
        │   └── com
        │       └── simiacryptus
        │           └── text
        │               └── gpt2
        │                   ├── DevTests.java
        │                   ├── GraphComparer.java
        │                   ├── TestUtil.java
        │                   └── UserTests.java
        └── resources
            └── logback.xml
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | !/.mvn/wrapper/maven-wrapper.jar
3 | # Avoid ignoring Maven wrapper jar file (.jar files are usually ignored)
4 | *.hprof
5 | *.pb
6 | *.pem
7 | *.zip
8 | .idea/
9 | .mvn/timing.properties
10 | buildNumber.properties
11 | dependency-reduced-pom.xml
12 | encoder_345M.json
13 | pom.xml.next
14 | pom.xml.releaseBackup
15 | pom.xml.tag
16 | pom.xml.versionsBackup
17 | release.properties
18 | target/
19 | tf-gpt-2.iml
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tf-gpt-2
2 |
3 | Java library for the GPT-2 Text Model using TensorFlow
4 |
5 | Source:
6 | 1. [Original Python Release](https://github.com/openai/gpt-2)
7 | 1. [OpenAI Blog - GPT2 Details](https://openai.com/blog/better-language-models/)
8 |
9 | More Background:
10 | 1. [Transformers and Attention Models](http://jalammar.github.io/illustrated-transformer/)
11 |
12 | ## Basic Usage
13 |
14 | ### Import the Library
15 |
16 | ```xml
17 | <dependency>
18 |   <groupId>com.simiacryptus</groupId>
19 |   <artifactId>tf-gpt-2</artifactId>
20 |   <version>1.7.1</version>
21 | </dependency>
22 | ```
23 |
24 | ### Instantiate the text generator
25 |
26 | ```java
27 | import com.simiacryptus.text.TextGenerator;
28 | import com.simiacryptus.text.gpt2.GPT2Util;
29 | TextGenerator textGenerator = GPT2Util.get345M();
30 | ```
31 |
32 | ### Generate text
33 |
34 | ```java
35 | System.out.println(textGenerator.generateText(500));
36 | ```
37 |
38 | ### Generate text given prefix
39 |
40 | ```java
41 | System.out.println(textGenerator.generateText(500, "Once upon a time"));
42 | ```
43 |
44 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
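1 | <?xml version="1.0" encoding="UTF-8"?>
2 |
20 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
22 |   <modelVersion>4.0.0</modelVersion>
23 |
24 |   <parent>
25 |     <groupId>com.simiacryptus</groupId>
26 |     <artifactId>java-parent</artifactId>
27 |     <version>2.1.0</version>
28 |     <relativePath>../../mvn-parents/java-parent</relativePath>
29 |   </parent>
30 |
31 |   <artifactId>tf-gpt-2</artifactId>
32 |   <description>GPT-2 Text Prediction via Tensorflow Java API</description>
33 |
34 |
35 |   <dependencyManagement>
36 |     <dependencies>
37 |       <dependency>
38 |         <groupId>com.simiacryptus</groupId>
39 |         <artifactId>bom</artifactId>
40 |         <version>${project.version}</version>
41 |         <type>pom</type>
42 |         <scope>import</scope>
43 |       </dependency>
44 |     </dependencies>
45 |   </dependencyManagement>
46 |
47 |   <dependencies>
48 |     <dependency>
49 |       <groupId>com.simiacryptus</groupId>
50 |       <artifactId>java-util</artifactId>
51 |     </dependency>
52 |     <dependency>
53 |       <groupId>org.tensorflow</groupId>
54 |       <artifactId>tensorflow</artifactId>
55 |     </dependency>
56 |     <dependency>
57 |       <groupId>org.tensorflow</groupId>
58 |       <artifactId>libtensorflow_jni_gpu</artifactId>
59 |     </dependency>
60 |     <dependency>
61 |       <groupId>org.tensorflow</groupId>
62 |       <artifactId>proto</artifactId>
63 |     </dependency>
64 |     <dependency>
65 |       <groupId>com.simiacryptus</groupId>
66 |       <artifactId>tensorflow-model</artifactId>
67 |     </dependency>
68 |
69 |     <dependency>
70 |       <groupId>com.google.code.gson</groupId>
71 |       <artifactId>gson</artifactId>
72 |     </dependency>
73 |     <dependency>
74 |       <groupId>commons-io</groupId>
75 |       <artifactId>commons-io</artifactId>
76 |     </dependency>
77 |     <dependency>
78 |       <groupId>org.slf4j</groupId>
79 |       <artifactId>slf4j-api</artifactId>
80 |     </dependency>
81 |
82 |     <dependency>
83 |       <groupId>org.junit.jupiter</groupId>
84 |       <artifactId>junit-jupiter</artifactId>
85 |       <scope>test</scope>
86 |     </dependency>
87 |     <dependency>
88 |       <groupId>ch.qos.logback</groupId>
89 |       <artifactId>logback-classic</artifactId>
90 |       <scope>test</scope>
91 |     </dependency>
92 |     <dependency>
93 |       <groupId>org.slf4j</groupId>
94 |       <artifactId>jcl-over-slf4j</artifactId>
95 |       <scope>test</scope>
96 |     </dependency>
97 |     <dependency>
98 |       <groupId>org.slf4j</groupId>
99 |       <artifactId>log4j-over-slf4j</artifactId>
100 |       <scope>test</scope>
101 |     </dependency>
102 |   </dependencies>
103 |
104 |   <url>http://code.simiacrypt.us/release/${project.version}/tf-gpt-2</url>
105 |   <distributionManagement>
106 |     <site>
107 |       <id>simiacryptus</id>
108 |       <url>s3://code.simiacrypt.us/release/${project.version}/tf-gpt-2</url>
109 |     </site>
110 |   </distributionManagement>
111 |
112 | </project>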
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/GraphModifier.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import com.google.protobuf.ByteString;
23 | import com.google.protobuf.InvalidProtocolBufferException;
24 | import com.google.protobuf.ProtocolStringList;
25 | import com.simiacryptus.text.gpt2.GPT2Codec;
26 | import org.slf4j.Logger;
27 | import org.slf4j.LoggerFactory;
28 | import org.tensorflow.*;
29 | import org.tensorflow.framework.GraphDef;
30 | import org.tensorflow.framework.NodeDef;
31 | import org.tensorflow.framework.TensorProto;
32 | import org.tensorflow.framework.TensorShapeProto;
33 |
34 | import javax.annotation.Nonnull;
35 | import javax.annotation.Nullable;
36 | import java.nio.ByteBuffer;
37 | import java.nio.IntBuffer;
38 | import java.util.ArrayList;
39 | import java.util.Arrays;
40 | import java.util.HashSet;
41 | import java.util.List;
42 | import java.util.function.Consumer;
43 | import java.util.stream.Collectors;
44 |
45 | import static org.tensorflow.framework.DataType.DT_INT32;
46 |
47 | /**
48 | * The type Graph modifier.
49 | */
50 | public abstract class GraphModifier {
51 | /**
52 | * The constant logger.
53 | */
54 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class);
55 |
56 | /**
57 | * Gets deletes init.
58 | *
59 | * @return the deletes init
60 | */
61 | @Nonnull
62 | public abstract HashSet getDeletes_Init();
63 |
64 | /**
65 | * Import graph def.
66 | *
67 | * @param graph the graph
68 | * @param graphdef the graphdef
69 | */
70 | public static void importGraphDef(@Nonnull Graph graph, @Nonnull GraphDef graphdef) {
71 | final HashSet opsPresent = new HashSet<>();
72 | graph.operations().forEachRemaining(op -> {
73 | opsPresent.add(op.name());
74 | });
75 | while (true) {
76 | final List nextToAdd = graphdef.getNodeList().stream()
77 | .filter(nodeDef -> !opsPresent.contains(nodeDef.getName()))
78 | .filter(nodeDef -> {
79 | final ProtocolStringList inputList = nodeDef.getInputList();
80 | return inputList.isEmpty() || inputList.stream().allMatch(input -> opsPresent.contains(input.split(":")[0]));
81 | })
82 | .collect(Collectors.toList());
83 | if (nextToAdd.isEmpty()) break;
84 | nextToAdd.forEach(nodeDef -> {
85 | opsPresent.add(nodeDef.getName());
86 | if (graph.operation(nodeDef.getName()) == null) {
87 | try {
88 | logger.debug("Adding node to graph: " + nodeDef.getName() + " <= [" + nodeDef.getInputList().stream().reduce((a, b) -> a + "," + b).orElse("") + "]");
89 | // Add new node
90 | final OperationBuilder operationBuilder = graph.opBuilder(nodeDef.getOp(), nodeDef.getName());
91 | operationBuilder.setDevice(nodeDef.getDevice());
92 | nodeDef.getAttrMap().forEach((k, v) -> {
93 | switch (v.getValueCase()) {
94 | case TENSOR: {
95 | final TensorProto tensorProto = v.getTensor();
96 | final long[] shape = tensorProto.getTensorShape().getDimList().stream().mapToLong(x -> x.getSize()).toArray();
97 | final Class> type;
98 | switch (tensorProto.getDtype()) {
99 | case DT_FLOAT:
100 | type = Float.class;
101 | break;
102 | case DT_INT32:
103 | type = Integer.class;
104 | break;
105 | default:
106 | throw new RuntimeException(tensorProto.getDtype().toString());
107 | }
108 | if (null != tensorProto.getTensorContent() && !tensorProto.getTensorContent().isEmpty()) {
109 | operationBuilder.setAttr(k, Tensor.create(type, shape, tensorProto.getTensorContent().asReadOnlyByteBuffer()));
110 | } else if (0 < tensorProto.getIntValCount()) {
111 | operationBuilder.setAttr(k, Tensor.create(shape, IntBuffer.wrap(tensorProto.getIntValList().stream().mapToInt(x -> x).toArray())));
112 | } else {
113 | throw new RuntimeException(tensorProto.toString());
114 | }
115 | break;
116 | }
117 | case SHAPE:
118 | final TensorShapeProto shapeProto = v.getShape();
119 | final long[] shape = shapeProto.getDimList().stream().mapToLong(x -> x.getSize()).toArray();
120 | operationBuilder.setAttr(k, Shape.make(shape[0], Arrays.copyOfRange(shape, 1, shape.length)));
121 | break;
122 | case TYPE:
123 | operationBuilder.setAttr(k, DataType.valueOf(v.getType().name().split("_")[1]));
124 | break;
125 | case I:
126 | operationBuilder.setAttr(k, v.getI());
127 | break;
128 | case B:
129 | operationBuilder.setAttr(k, v.getB());
130 | break;
131 | default:
132 | throw new RuntimeException(k + " = " + v.toString());
133 | }
134 | });
135 | final Output[] inputs = nodeDef.getInputList().stream().map(input -> {
136 | final String[] split = input.split(":");
137 | final int idx = 1 == split.length ? 0 : Integer.parseInt(split[1]);
138 | return graph.operation(split[0]).output(idx);
139 | }).toArray(i -> new Output[i]);
140 | if (nodeDef.getOp().equals("Pack")) {
141 | operationBuilder.addInputList(inputs);
142 | } else if (nodeDef.getOp().equals("ConcatV2")) {
143 | operationBuilder.addInputList(new Output[]{inputs[0], inputs[1]});
144 | operationBuilder.addInput(inputs[2]);
145 | operationBuilder.addControlInput(inputs[2].op());
146 | } else if (nodeDef.getOp().equals("StridedSlice")) {
147 | for (int i = 0; i < inputs.length; i++) {
148 | if (i == 0) {
149 | operationBuilder.addInput(inputs[i]);
150 | } else {
151 | operationBuilder.addInput(inputs[i]);
152 | operationBuilder.addControlInput(inputs[i].op());
153 | }
154 | }
155 | } else if (inputs.length > 1) {
156 | for (int i = 0; i < inputs.length; i++) {
157 | operationBuilder.addInput(inputs[i]);
158 | }
159 | } else if (inputs.length > 0) {
160 | operationBuilder.addInput(inputs[0]);
161 | }
162 | try {
163 | operationBuilder.build();
164 | } catch (Throwable e) {
165 | throw new RuntimeException("Error processing " + nodeDef.toString(), e);
166 | }
167 | } catch (RuntimeException e) {
168 | throw e;
169 | } catch (Throwable e) {
170 | throw new RuntimeException("Error processing " + nodeDef.toString(), e);
171 | }
172 | }
173 | });
174 | }
175 | graphdef.getNodeList().stream()
176 | .filter(nodeDef -> !opsPresent.contains(nodeDef.getName()))
177 | .forEach(nodeDef -> {
178 | logger.warn("Remaining Node: " + nodeDef.toString());
179 | });
180 | }
181 |
182 | /**
183 | * Edit byte buffer.
184 | *
185 | * @param srcBuffer the src buffer
186 | * @param fn the fn
187 | * @return the byte buffer
188 | */
189 | @Nonnull
190 | public static ByteBuffer edit(@Nonnull ByteBuffer srcBuffer, @Nonnull Consumer fn) {
191 | final ByteBuffer dstBuffer = copy(srcBuffer);
192 | final IntBuffer intBuffer = dstBuffer.asIntBuffer();
193 | fn.accept(intBuffer);
194 | return dstBuffer;
195 | }
196 |
197 | /**
198 | * Copy byte buffer.
199 | *
200 | * @param srcBuffer the src buffer
201 | * @return the byte buffer
202 | */
203 | @Nonnull
204 | public static ByteBuffer copy(@Nonnull ByteBuffer srcBuffer) {
205 | final ByteBuffer byteBuffer = ByteBuffer.allocate(srcBuffer.capacity());
206 | byteBuffer.put(srcBuffer);
207 | byteBuffer.clear();
208 | return byteBuffer;
209 | }
210 |
211 | /**
212 | * Tensor 1 tensor proto.
213 | *
214 | * @param shape the shape
215 | * @param vals the vals
216 | * @return the tensor proto
217 | */
218 | @Nonnull
219 | public static TensorProto tensor1(int[] shape, @Nonnull int... vals) {
220 | TensorProto.Builder builder = TensorProto.newBuilder().setTensorShape(shape(shape)).setDtype(DT_INT32);
221 | Arrays.stream(vals).forEach(x -> builder.addIntVal(x));
222 | return builder.build();
223 | }
224 |
225 | /**
226 | * Tensor 2 tensor proto.
227 | *
228 | * @param shape the shape
229 | * @param vals the vals
230 | * @return the tensor proto
231 | */
232 | @Nonnull
233 | public static TensorProto tensor2(int[] shape, @Nonnull int... vals) {
234 | TensorProto.Builder builder = TensorProto.newBuilder().setTensorShape(shape(shape));
235 | byte[] array = new byte[vals.length * 4];
236 | IntBuffer buffer = ByteBuffer.wrap(array).asIntBuffer();
237 | for (int val : vals) buffer.put(Integer.reverseBytes(val));
238 | builder.setTensorContent(ByteString.copyFrom(array)).setDtype(DT_INT32);
239 | return builder.build();
240 | }
241 |
242 | /**
243 | * Shape tensor shape proto.
244 | *
245 | * @param dims the dims
246 | * @return the tensor shape proto
247 | */
248 | @Nonnull
249 | public static TensorShapeProto shape(@Nonnull int... dims) {
250 | TensorShapeProto.Builder builder = TensorShapeProto.newBuilder();
251 | Arrays.stream(dims).mapToObj(v -> TensorShapeProto.Dim.newBuilder().setSize(v).build()).forEach(value -> builder.addDim(value));
252 | return builder.build();
253 | }
254 |
255 | /**
256 | * Edit graph def.
257 | *
258 | * @param src the src
259 | * @param prefix the prefix
260 | * @param includeOriginal the include original
261 | * @return the graph def
262 | * @throws InvalidProtocolBufferException the invalid protocol buffer exception
263 | */
264 | @Nonnull
265 | public GraphDef edit(@Nonnull GraphDef src, String prefix, boolean includeOriginal) throws InvalidProtocolBufferException {
266 | final GraphDef srcGraphDef = GraphDef.parseFrom(src.toByteArray());
267 | final GraphDef.Builder destGraphDef = GraphDef.newBuilder();
268 | final HashSet deletes = getDeletes_Init();
269 | final HashSet editedNodes = new HashSet<>();
270 | for (int index = 0; index < srcGraphDef.getNodeCount(); index++) {
271 | final NodeDef node = srcGraphDef.getNode(index);
272 | if (deletes.contains(node.getName())) {
273 | logger.debug("Omit Node: " + node.getName());
274 | } else {
275 | @Nullable NodeDef.Builder nodeBuilder = edit(node.toBuilder());
276 | if (null != nodeBuilder) {
277 | logger.debug("Edit Node: " + node.getName());
278 | destGraphDef.addNode(nodeBuilder.build());
279 | editedNodes.add(node.getName());
280 | } else {
281 | // logger.debug("Pass-thru Node: " + node.getName());
282 | destGraphDef.addNode(node);
283 | }
284 | }
285 | }
286 | addNodes(nodeDef -> {
287 | destGraphDef.addNode(nodeDef);
288 | editedNodes.add(nodeDef.getName());
289 | });
290 | // return destGraphDef.build();
291 | return prefixRewrite(destGraphDef.build(), editedNodes, prefix, includeOriginal);
292 | }
293 |
294 | /**
295 | * Edit node def . builder.
296 | *
297 | * @param node the node
298 | * @return the node def . builder
299 | */
300 | @Nullable
301 | public abstract NodeDef.Builder edit(NodeDef.Builder node);
302 |
303 | /**
304 | * Add nodes.
305 | *
306 | * @param add the add
307 | */
308 | protected abstract void addNodes(Consumer add);
309 |
310 | /**
311 | * Prefix rewrite graph def.
312 | *
313 | * @param graphDef the graph def
314 | * @param editedNodes the edited nodes
315 | * @param prefix the prefix
316 | * @param includeOriginal the include original
317 | * @return the graph def
318 | */
319 | @Nonnull
320 | protected GraphDef prefixRewrite(@Nonnull GraphDef graphDef, @Nonnull HashSet editedNodes, String prefix, boolean includeOriginal) {
321 | while (true) {
322 | final List newItems = graphDef.getNodeList().stream()
323 | .filter(nodeDef -> !editedNodes.contains(nodeDef.getName()))
324 | .filter(nodeDef -> nodeDef.getInputList().stream().filter(input -> editedNodes.contains(input.split(":")[0])).findAny().isPresent())
325 | .map(x -> x.getName()).collect(Collectors.toList());
326 | if (newItems.isEmpty()) break;
327 | for (String newItem : newItems) {
328 | logger.debug("Item touched by rename: " + newItem);
329 | }
330 | editedNodes.addAll(newItems);
331 | }
332 | final GraphDef.Builder destGraphDef = GraphDef.newBuilder();
333 | for (NodeDef nodeDef : graphDef.getNodeList()) {
334 | NodeDef.Builder builder;
335 | if (editedNodes.contains(nodeDef.getName())) {
336 | builder = nodeDef.toBuilder();
337 | builder.setName(prefix + nodeDef.getName());
338 | } else {
339 | builder = null;
340 | }
341 | final ArrayList inputs = new ArrayList<>(nodeDef.getInputList());
342 | if (inputs.stream().filter(o -> editedNodes.contains(o.split(":")[0])).findAny().isPresent()) {
343 | if (null == builder) builder = nodeDef.toBuilder();
344 | builder.clearInput();
345 | for (String input : inputs) {
346 | if (editedNodes.contains(input.split(":")[0])) {
347 | logger.debug(nodeDef.getName() + " [ " + input + " ] += " + prefix);
348 | builder.addInput(prefix + input);
349 | } else {
350 | builder.addInput(input);
351 | }
352 | }
353 | }
354 | if (null != builder) {
355 | logger.debug("Edit in renaming: " + builder.getName());
356 | destGraphDef.addNode(builder.build());
357 | } else {
358 | if (includeOriginal) destGraphDef.addNode(nodeDef);
359 | }
360 | }
361 | return destGraphDef.build();
362 | }
363 | }
364 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/LanguageCodeModel.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import org.tensorflow.Tensor;
23 |
24 | import javax.annotation.Nonnull;
25 | import javax.annotation.Nullable;
26 | import java.util.function.BiFunction;
27 |
28 | /**
29 | * Contract for a model that is fed one integer code per {@link #eval(int)} call and answers with a float array.
30 | */
31 | public interface LanguageCodeModel {
32 | /**
33 | * Returns the filter function currently attached to this model.
34 | * NOTE(review): the BiFunction type arguments are not visible in this listing; confirm them against implementations.
35 | * @return the filter function; may be null, as declared by {@code @Nullable}
36 | */
37 | @Nullable
38 | BiFunction getFilterFn();
39 |
40 | /**
41 | * Replaces the filter function attached to this model.
42 | *
43 | * @param filterFn the filter function to attach
44 | * @return a non-null model; presumably this instance, to allow call chaining (TODO confirm)
45 | */
46 | @Nonnull
47 | LanguageCodeModel setFilterFn(BiFunction filterFn);
48 |
49 | /**
50 | * Produces a copy of this model.
51 | * What is shared between the copy and the original is decided by the implementation.
52 | * @return the copy; never null
53 | */
54 | @Nonnull
55 | LanguageCodeModel copy();
56 |
57 | /**
58 | * Clears this model; presumably this discards state accumulated by earlier {@link #eval(int)} calls (TODO confirm).
59 | *
60 | * @return a non-null model; presumably this instance, to allow call chaining (TODO confirm)
61 | */
62 | @Nonnull
63 | LanguageCodeModel clear();
64 |
65 | /**
66 | * Evaluates the model for a single input code.
67 | * MinEntropyWrapper computes the entropy of the result, so it is presumably a probability distribution (TODO confirm).
68 | * @param data_X the input code
69 | * @return the model output for that code
70 | */
71 | float[] eval(int data_X);
72 |
73 | /**
74 | * Exposes the model state as a tensor.
75 | *
76 | * @return the state tensor; may be null, as declared by {@code @Nullable}
77 | */
78 | @Nullable
79 | Tensor> state();
80 | }
81 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/MinEntropyWrapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import com.simiacryptus.ref.wrappers.RefString;
23 | import org.slf4j.Logger;
24 | import org.slf4j.LoggerFactory;
25 |
26 | import javax.annotation.Nonnull;
27 | import java.util.ArrayList;
28 | import java.util.Arrays;
29 | import java.util.stream.DoubleStream;
30 | import java.util.stream.IntStream;
31 |
32 | /**
33 | * The type Min entropy wrapper.
34 | */
35 | public class MinEntropyWrapper extends ModelWrapper {
36 | /**
37 | * The constant logger.
38 | */
39 | protected static final Logger logger = LoggerFactory.getLogger(MinEntropyWrapper.class);
40 | private final ArrayList entropyHistory = new ArrayList<>();
41 | private double value;
42 |
43 | /**
44 | * Instantiates a new Min entropy wrapper.
45 | *
46 | * @param value the value
47 | * @param child the child
48 | */
49 | public MinEntropyWrapper(double value, LanguageCodeModel child) {
50 | super(child);
51 | this.value = value;
52 | }
53 |
54 | /**
55 | * Gets value.
56 | *
57 | * @return the value
58 | */
59 | public double getValue() {
60 | return value;
61 | }
62 |
63 | /**
64 | * Sets value.
65 | *
66 | * @param value the value
67 | */
68 | public void setValue(double value) {
69 | this.value = value;
70 | }
71 |
72 | /**
73 | * Entropy double.
74 | *
75 | * @param floats the floats
76 | * @return the double
77 | */
78 | public static double entropy(@Nonnull float[] floats) {
79 | return IntStream.range(0, floats.length).mapToDouble(i -> {
80 | float p = floats[i];
81 | return p <= 0 ? 0 : -p * Math.log(p);
82 | }).sum() / Math.log(2);
83 | }
84 |
85 | /**
86 | * Pow copy float [ ].
87 | *
88 | * @param floats the floats
89 | * @param value the value
90 | * @return the float [ ]
91 | */
92 | @Nonnull
93 | public static float[] powCopy(@Nonnull float[] floats, double value) {
94 | float[] copy = Arrays.copyOf(floats, floats.length);
95 | pow(copy, value);
96 | return copy;
97 | }
98 |
99 | /**
100 | * Pow.
101 | *
102 | * @param floats the floats
103 | * @param value the value
104 | */
105 | public static void pow(@Nonnull float[] floats, double value) {
106 | for (int i = 0; i < floats.length; i++) {
107 | floats[i] = (float) Math.pow(floats[i], value);
108 | }
109 | }
110 |
111 | @Override
112 | public float[] eval(int data_X) {
113 | LanguageCodeModel child = children[0];
114 | float[] floats = child.eval(data_X);
115 | double entropy = entropy(floats);
116 | entropyHistory.add(entropy);
117 | double[] schedule = DoubleStream.iterate(1.0, x -> x * 0.9).limit(1000).toArray();
118 | for (int i = 0; i < schedule.length; i++) {
119 | float[] copy = powCopy(floats, schedule[i]);
120 | SumModel.normalize(copy);
121 | if (entropy(copy) > value) {
122 | floats = copy;
123 | break;
124 | }
125 | }
126 | logger.debug(RefString.format("Entropy = %s => %s", entropy, entropy(floats)));
127 | //SumModel.normalize(floats);
128 | return floats;
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/ModelWrapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import org.tensorflow.Tensor;
23 |
24 | import javax.annotation.Nonnull;
25 | import java.util.Arrays;
26 | import java.util.function.BiFunction;
27 |
28 | /**
29 | * The type Model wrapper.
30 | */
31 | public abstract class ModelWrapper implements LanguageCodeModel {
32 | /**
33 | * The Children.
34 | */
35 | protected final LanguageCodeModel[] children;
36 |
37 | /**
38 | * Instantiates a new Model wrapper.
39 | *
40 | * @param children the children
41 | */
42 | public ModelWrapper(LanguageCodeModel... children) {
43 | this.children = children;
44 | }
45 |
46 | @Override
47 | public BiFunction getFilterFn() {
48 | return children[0].getFilterFn();
49 | }
50 |
51 | @Nonnull
52 | @Override
53 | public LanguageCodeModel copy() {
54 | return new SumModel(Arrays.stream(children)
55 | .map(languageCodeModel -> languageCodeModel.copy())
56 | .toArray(i -> new LanguageCodeModel[i]));
57 | }
58 |
59 | @Nonnull
60 | @Override
61 | public LanguageCodeModel clear() {
62 | for (LanguageCodeModel child : children) {
63 | child.clear();
64 | }
65 | return this;
66 | }
67 |
68 | @Override
69 | public abstract float[] eval(int data_X);
70 |
71 | @Nonnull
72 | @Override
73 | public LanguageCodeModel setFilterFn(BiFunction filterFn) {
74 | for (LanguageCodeModel child : children) {
75 | child.setFilterFn(filterFn);
76 | }
77 | return this;
78 | }
79 |
80 | @Override
81 | public Tensor> state() {
82 | assert 1 == children.length;
83 | return children[0].state();
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/SimpleModel.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import com.simiacryptus.text.gpt2.GPT2Codec;
23 | import org.tensorflow.Tensor;
24 |
25 | import javax.annotation.Nonnull;
26 | import javax.annotation.Nullable;
27 | import java.util.Arrays;
28 | import java.util.List;
29 | import java.util.Map;
30 | import java.util.function.BiFunction;
31 | import java.util.stream.Collectors;
32 |
33 | /**
34 | * The type Simple model.
35 | */
36 | public class SimpleModel implements LanguageCodeModel {
37 |
38 | @Nonnull
39 | private final float[] result;
40 |
41 | /**
42 | * Instantiates a new Simple model.
43 | *
44 | * @param result the result
45 | */
46 | public SimpleModel(@Nonnull float... result) {
47 | this.result = Arrays.copyOf(result, result.length);
48 | }
49 |
50 | @Nullable
51 | @Override
52 | public BiFunction getFilterFn() {
53 | return null;
54 | }
55 |
56 | /**
57 | * Build simple model.
58 | *
59 | * @param codec the codec
60 | * @param text the text
61 | * @return the simple model
62 | */
63 | @Nonnull
64 | public static SimpleModel build(@Nonnull GPT2Codec codec, String text) {
65 | List encode = codec.encode(text);
66 | Map counts = encode.stream().collect(Collectors.groupingBy(x -> x, Collectors.counting()));
67 | float[] result = new float[codec.getVocabSize()];
68 | for (int i = 0; i < result.length; i++) {
69 | result[i] = (float) counts.getOrDefault(i, 0l) / encode.size();
70 | }
71 | return new SimpleModel(result);
72 | }
73 |
74 | @Nonnull
75 | @Override
76 | public LanguageCodeModel copy() {
77 | return new SimpleModel(result);
78 | }
79 |
80 | @Nonnull
81 | @Override
82 | public LanguageCodeModel clear() {
83 | return this;
84 | }
85 |
86 | @Nonnull
87 | @Override
88 | public float[] eval(int data_X) {
89 | return Arrays.copyOf(result, result.length);
90 | }
91 |
92 | @Nonnull
93 | @Override
94 | public LanguageCodeModel setFilterFn(BiFunction filterFn) {
95 | return this;
96 | }
97 |
98 | @Nullable
99 | @Override
100 | public Tensor> state() {
101 | return null;
102 | }
103 |
104 | /**
105 | * Sets temperature.
106 | *
107 | * @return the temperature
108 | */
109 | @Nonnull
110 | public LanguageCodeModel setTemperature() {
111 | return this;
112 | }
113 |
114 | }
115 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/SumModel.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import com.simiacryptus.ref.lang.RefUtil;
23 |
24 | import javax.annotation.Nonnull;
25 | import java.util.Arrays;
26 | import java.util.stream.IntStream;
27 |
28 | /**
29 | * The type Sum model.
30 | */
31 | public class SumModel extends ModelWrapper {
32 |
33 | /**
34 | * Instantiates a new Sum model.
35 | *
36 | * @param children the children
37 | */
38 | public SumModel(LanguageCodeModel... children) {
39 | super(children);
40 | }
41 |
42 | /**
43 | * Normalize.
44 | *
45 | * @param sums the sums
46 | */
47 | public static void normalize(@Nonnull float[] sums) {
48 | double sum = IntStream.range(0, sums.length).mapToDouble(x -> sums[x]).sum();
49 | for (int i = 0; i < sums.length; i++) sums[i] /= sum;
50 | }
51 |
52 | @Nonnull
53 | @Override
54 | public float[] eval(int data_X) {
55 | float[] sums = RefUtil.get(Arrays.stream(children).map(c -> c.eval(data_X)).reduce((a, b) -> {
56 | for (int i = 0; i < a.length; i++) a[i] *= b[i];
57 | return a;
58 | }));
59 | normalize(sums);
60 | return sums;
61 | }
62 |
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/TemperatureWrapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | /**
23 | * The type Temperature wrapper.
24 | */
25 | public class TemperatureWrapper extends ModelWrapper {
26 |
27 | private double value;
28 |
29 | /**
30 | * Instantiates a new Temperature wrapper.
31 | *
32 | * @param value the value
33 | * @param child the child
34 | */
35 | public TemperatureWrapper(double value, LanguageCodeModel child) {
36 | super(child);
37 | this.value = value;
38 | }
39 |
40 | /**
41 | * Gets value.
42 | *
43 | * @return the value
44 | */
45 | public double getValue() {
46 | return value;
47 | }
48 |
49 | /**
50 | * Sets value.
51 | *
52 | * @param value the value
53 | */
54 | public void setValue(double value) {
55 | this.value = value;
56 | }
57 |
58 | @Override
59 | public float[] eval(int data_X) {
60 | LanguageCodeModel child = children[0];
61 | float[] floats = child.eval(data_X);
62 | for (int i = 0; i < floats.length; i++) {
63 | floats[i] = (float) Math.pow(floats[i], getValue());
64 | }
65 | SumModel.normalize(floats);
66 | return floats;
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/TextGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import com.simiacryptus.ref.wrappers.RefString;
23 | import com.simiacryptus.text.gpt2.GPT2Codec;
24 | import org.slf4j.Logger;
25 | import org.slf4j.LoggerFactory;
26 |
27 | import javax.annotation.Nonnull;
28 | import javax.annotation.Nullable;
29 | import java.util.ArrayList;
30 | import java.util.Arrays;
31 | import java.util.Comparator;
32 | import java.util.List;
33 | import java.util.function.Predicate;
34 | import java.util.stream.IntStream;
35 |
36 | /**
37 | * The type Text generator.
38 | */
39 | public class TextGenerator {
40 | /**
41 | * The constant logger.
42 | */
43 | protected static final Logger logger = LoggerFactory.getLogger(TextGenerator.class);
44 | /**
45 | * The Vocabulary size.
46 | */
47 | protected final int vocabularySize;
48 | /**
49 | * The Codec.
50 | */
51 | protected final GPT2Codec codec;
52 | /**
53 | * The Verbose.
54 | */
55 | protected boolean verbose = false;
56 | /**
57 | * The Choices to log.
58 | */
59 | protected int choicesToLog = 10;
60 | /**
61 | * The Codes.
62 | */
63 | @Nonnull
64 | List codes = new ArrayList<>();
65 | /**
66 | * The Next selections.
67 | */
68 | @Nullable
69 | float[] nextSelections;
70 | private LanguageCodeModel model;
71 |
72 | /**
73 | * Instantiates a new Text generator.
74 | *
75 | * @param vocabularySize the vocabulary size
76 | * @param model the model
77 | * @param codec the codec
78 | */
79 | public TextGenerator(int vocabularySize, LanguageCodeModel model, GPT2Codec codec) {
80 | this.setModel(model);
81 | this.vocabularySize = vocabularySize;
82 | this.codec = codec;
83 | }
84 |
85 | /**
86 | * Gets choices to log.
87 | *
88 | * @return the choices to log
89 | */
90 | public int getChoicesToLog() {
91 | return choicesToLog;
92 | }
93 |
94 | /**
95 | * Sets choices to log.
96 | *
97 | * @param choicesToLog the choices to log
98 | * @return the choices to log
99 | */
100 | @Nonnull
101 | public TextGenerator setChoicesToLog(int choicesToLog) {
102 | this.choicesToLog = choicesToLog;
103 | return this;
104 | }
105 |
106 | /**
107 | * Gets model.
108 | *
109 | * @return the model
110 | */
111 | public LanguageCodeModel getModel() {
112 | return model;
113 | }
114 |
115 | /**
116 | * Sets model.
117 | *
118 | * @param model the model
119 | * @return the model
120 | */
121 | @Nonnull
122 | public TextGenerator setModel(LanguageCodeModel model) {
123 | if (this.model == model) return this;
124 | if (null != this.model) this.model.clear();
125 | this.model = model;
126 | return this;
127 | }
128 |
129 | /**
130 | * Gets text.
131 | *
132 | * @return the text
133 | */
134 | public String getText() {
135 | return codec.decode(codes.toArray(new Integer[]{}));
136 | }
137 |
138 | /**
139 | * Gets vocabulary size.
140 | *
141 | * @return the vocabulary size
142 | */
143 | public int getVocabularySize() {
144 | return vocabularySize;
145 | }
146 |
147 | /**
148 | * Is verbose boolean.
149 | *
150 | * @return the boolean
151 | */
152 | public boolean isVerbose() {
153 | return verbose;
154 | }
155 |
156 | /**
157 | * Sets verbose.
158 | *
159 | * @param verbose the verbose
160 | * @return the verbose
161 | */
162 | @Nonnull
163 | public TextGenerator setVerbose(boolean verbose) {
164 | this.verbose = verbose;
165 | return this;
166 | }
167 |
168 | /**
169 | * Sorted indices int [ ].
170 | *
171 | * @param chosen the chosen
172 | * @param limit the limit
173 | * @return the int [ ]
174 | */
175 | public static int[] sortedIndices(@Nonnull float[] chosen, int limit) {
176 | return IntStream.range(0, chosen.length)
177 | .mapToObj(x -> x)
178 | .sorted(Comparator.comparing(c -> -chosen[c]))
179 | .limit(limit)
180 | .mapToInt(x -> x)
181 | .toArray();
182 | }
183 |
184 | /**
185 | * Copy text generator.
186 | *
187 | * @return the text generator
188 | */
189 | @Nonnull
190 | public TextGenerator copy() {
191 | TextGenerator copy = new TextGenerator(vocabularySize, getModel().copy(), codec);
192 | copy.codes.addAll(this.codes);
193 | copy.verbose = this.verbose;
194 | copy.choicesToLog = this.choicesToLog;
195 | copy.nextSelections = null == this.nextSelections ? null : Arrays.copyOf(this.nextSelections, this.nextSelections.length);
196 | return copy;
197 | }
198 |
199 | /**
200 | * Generate text string.
201 | *
202 | * @param terminator the terminator
203 | * @return the string
204 | */
205 | @Nonnull
206 | public String generateText(@Nonnull Predicate terminator) {
207 | return generateText(terminator, null);
208 | }
209 |
210 | /**
211 | * Generate text string.
212 | *
213 | * @param numberOfWords the number of words
214 | * @return the string
215 | */
216 | @Nonnull
217 | public String generateText(int numberOfWords) {
218 | return generateText(numberOfWords, null);
219 | }
220 |
221 | /**
222 | * Generate text string.
223 | *
224 | * @param terminator the terminator
225 | * @param prefix the prefix
226 | * @return the string
227 | */
228 | @Nonnull
229 | public String generateText(@Nonnull Predicate terminator, String prefix) {
230 | reset();
231 | feed(prefix);
232 | generate(terminator);
233 | return getText();
234 | }
235 |
236 | /**
237 | * Generate text string.
238 | *
239 | * @param numberOfTokens the number of tokens
240 | * @param prefix the prefix
241 | * @return the string
242 | */
243 | @Nonnull
244 | public String generateText(int numberOfTokens, String prefix) {
245 | reset();
246 | feed(prefix);
247 | generate(numberOfTokens);
248 | return getText();
249 | }
250 |
251 | /**
252 | * Generate string.
253 | *
254 | * @param fn the fn
255 | * @return the string
256 | */
257 | public String generate(@Nonnull Predicate fn) {
258 | init();
259 | ArrayList theseCodes = new ArrayList<>();
260 | try {
261 | for (int wordIndex = 0; wordIndex == 0 || fn.test(codec.decode(theseCodes.toArray(new Integer[]{}))); wordIndex++) {
262 | assert nextSelections != null;
263 | int selected = select(nextSelections);
264 | if (isVerbose()) {
265 | if (wordIndex != 0) log(nextSelections, codec, getChoicesToLog());
266 | logger.info(RefString.format("Selected New Text: '%s'", codec.decode(selected)));
267 | }
268 | if (selected == getVocabularySize() - 1) break;
269 | codes.add(selected);
270 | theseCodes.add(selected);
271 | nextSelections = getModel().eval(selected);
272 | }
273 | } catch (Throwable e) {
274 | //logger.warn("Error generating text", e);
275 | throw new RuntimeException("Error generating text: " + codec.decode(theseCodes.toArray(new Integer[]{})), e);
276 | }
277 | return codec.decode(theseCodes.toArray(new Integer[]{}));
278 | }
279 |
280 | /**
281 | * Generate.
282 | *
283 | * @param numberOfWords the number of words
284 | */
285 | public void generate(int numberOfWords) {
286 | init();
287 | try {
288 | for (int wordIndex = 0; wordIndex < numberOfWords; wordIndex++) {
289 | assert nextSelections != null;
290 | int selected = select(nextSelections);
291 | if (isVerbose()) {
292 | if (wordIndex != 0) log(nextSelections, codec, getChoicesToLog());
293 | logger.info(RefString.format("Selected New Text: '%s'", codec.decode(selected)));
294 | }
295 | if (selected == getVocabularySize() - 1) break;
296 | codes.add(selected);
297 | nextSelections = getModel().eval(selected);
298 | }
299 | } catch (Throwable e) {
300 | logger.warn("Error generating text", e);
301 | }
302 | }
303 |
304 | /**
305 | * Init text generator.
306 | *
307 | * @return the text generator
308 | */
309 | @Nonnull
310 | public TextGenerator init() {
311 | if (nextSelections == null) feed("");
312 | return this;
313 | }
314 |
315 | /**
316 | * Feed double.
317 | *
318 | * @param text the text
319 | * @return the double
320 | */
321 | public double feed(String text) {
322 | double entropy = 0.0;
323 | List codeList = new ArrayList<>();
324 | codeList.addAll(codec.encode(text));
325 | if (codeList.isEmpty()) codeList.add(getVocabularySize() - 1);
326 | for (Integer code : codeList) {
327 | if (null != nextSelections) {
328 | float p = nextSelections[code];
329 | entropy += p != 0 ? -Math.log(p) : Math.log(getVocabularySize());
330 | }
331 | codes.add(code);
332 | nextSelections = getModel().eval(code);
333 | if (isVerbose()) {
334 | logger.info(RefString.format("Feed token: '%s'", codec.decode(code)));
335 | assert nextSelections != null;
336 | log(nextSelections, codec, getChoicesToLog());
337 | }
338 | }
339 | return entropy / Math.log(2);
340 | }
341 |
342 | /**
343 | * Reset text generator.
344 | *
345 | * @return the text generator
346 | */
347 | @Nonnull
348 | public TextGenerator reset() {
349 | codes.clear();
350 | getModel().clear();
351 | return this;
352 | }
353 |
354 | /**
355 | * Select int.
356 | *
357 | * @param chosen the chosen
358 | * @return the int
359 | */
360 | protected int select(@Nonnull float[] chosen) {
361 | double originalFate = Math.random() * 1;
362 | double fate = originalFate;
363 | int j = 0;
364 | int[] topCandidates = sortedIndices(chosen, chosen.length);
365 | while (j < topCandidates.length && fate > chosen[topCandidates[j]]) {
366 | int topCandidate = topCandidates[j++];
367 | fate -= chosen[topCandidate];
368 | }
369 | int topCandidate = topCandidates[j];
370 | logger.debug(RefString.format("Chose #%s with fate %s", topCandidate, originalFate));
371 | return topCandidate;
372 | }
373 |
374 | /**
375 | * Log.
376 | *
377 | * @param chosen the chosen
378 | * @param codec the codec
379 | * @param count the count
380 | */
381 | protected void log(@Nonnull float[] chosen, @Nonnull GPT2Codec codec, int count) {
382 | Arrays.stream(sortedIndices(chosen, count))
383 | .forEach(candidate -> logger.info(RefString.format("\t#%d %.4f%% '%s'", candidate, chosen[candidate] * 100, codec.decode(candidate))));
384 | }
385 | }
386 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/TopNWrapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text;
21 |
22 | import java.util.Comparator;
23 | import java.util.stream.IntStream;
24 |
25 | /**
26 | * The type Top n wrapper.
27 | */
28 | public class TopNWrapper extends ModelWrapper {
29 |
30 | private int value;
31 |
32 | /**
33 | * Instantiates a new Top n wrapper.
34 | *
35 | * @param value the value
36 | * @param child the child
37 | */
38 | public TopNWrapper(int value, LanguageCodeModel child) {
39 | super(child);
40 | this.value = value;
41 | }
42 |
43 | /**
44 | * Gets value.
45 | *
46 | * @return the value
47 | */
48 | public int getValue() {
49 | return value;
50 | }
51 |
52 | /**
53 | * Sets value.
54 | *
55 | * @param value the value
56 | */
57 | public void setValue(int value) {
58 | this.value = value;
59 | }
60 |
61 | @Override
62 | public float[] eval(int data_X) {
63 | LanguageCodeModel child = children[0];
64 | float[] floats = child.eval(data_X);
65 | int[] sortedIndices = IntStream.range(0, floats.length)
66 | .mapToObj(x -> x)
67 | .sorted(Comparator.comparing(x -> -floats[x]))
68 | .mapToInt(x -> x)
69 | .toArray();
70 | for (int i = value; i < sortedIndices.length; i++) {
71 | floats[sortedIndices[i]] = 0;
72 | }
73 | SumModel.normalize(floats);
74 | return floats;
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/gpt2/GPT2Codec.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.google.gson.GsonBuilder;
23 | import com.google.gson.JsonObject;
24 | import com.simiacryptus.ref.lang.RefUtil;
25 | import com.simiacryptus.ref.wrappers.RefStringBuilder;
26 | import com.simiacryptus.util.Util;
27 | import org.apache.commons.io.FileUtils;
28 | import org.slf4j.Logger;
29 | import org.slf4j.LoggerFactory;
30 |
31 | import javax.annotation.Nonnull;
32 | import javax.annotation.Nullable;
33 | import java.io.File;
34 | import java.io.IOException;
35 | import java.util.*;
36 | import java.util.function.Function;
37 | import java.util.stream.Collectors;
38 | import java.util.stream.Stream;
39 |
40 | /**
41 | * The type Gpt 2 codec.
42 | */
43 | public class GPT2Codec {
44 | /**
45 | * The constant logger.
46 | */
47 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class);
48 |
49 | /**
50 | * The Encoder.
51 | */
52 | protected final TreeMap encoder;
53 | /**
54 | * The Decoder.
55 | */
56 | @Nonnull
57 | protected final TreeMap decoder;
58 | private final int vocabSize;
59 |
60 | /**
61 | * Instantiates a new Gpt 2 codec.
62 | *
63 | * @param encoder the encoder
64 | * @param vocabSize the vocab size
65 | */
66 | public GPT2Codec(TreeMap encoder, int vocabSize) {
67 | this.encoder = encoder;
68 | this.vocabSize = vocabSize;
69 | this.decoder = buildDecoder(this.encoder);
70 | }
71 |
72 | /**
73 | * Instantiates a new Gpt 2 codec.
74 | *
75 | * @param file the file
76 | * @param vocabSize the vocab size
77 | */
78 | public GPT2Codec(@Nonnull File file, int vocabSize) {
79 | this(GPT2Codec.loadEncoder(file), vocabSize);
80 | }
81 |
82 | /**
83 | * Gets character transformer.
84 | *
85 | * @return the character transformer
86 | */
87 | @Nonnull
88 | public static Function getCharacterTransformer() {
89 | Map byteEncoder = byteEncoder();
90 | return x -> {
91 | char[] chars = x.toCharArray();
92 | for (int i = 0; i < chars.length; i++) {
93 | chars[i] = byteEncoder.getOrDefault(chars[i], chars[i]);
94 | }
95 | return new String(chars);
96 | };
97 | }
98 |
99 | /**
100 | * Gets vocab size.
101 | *
102 | * @return the vocab size
103 | */
104 | public int getVocabSize() {
105 | return vocabSize;
106 | }
107 |
108 | /**
109 | * Build decoder tree map.
110 | *
111 | * @param encoder the encoder
112 | * @return the tree map
113 | */
114 | @Nonnull
115 | public static TreeMap buildDecoder(@Nonnull TreeMap encoder) {
116 | Stream> stream = encoder.entrySet().stream();
117 | return new TreeMap<>(stream.collect(Collectors.toMap(
118 | (Map.Entry e) -> e.getValue(),
119 | (Map.Entry e) -> e.getKey()
120 | )));
121 | }
122 |
123 | /**
124 | * Load encoder tree map.
125 | *
126 | * @param file the file
127 | * @return the tree map
128 | */
129 | @Nonnull
130 | public static TreeMap loadEncoder(@Nonnull File file) {
131 | try {
132 | return toMap(FileUtils.readFileToString(file, "UTF-8"), getCharacterTransformer());
133 | } catch (IOException e) {
134 | throw Util.throwException(e);
135 | }
136 | }
137 |
138 | /**
139 | * To map tree map.
140 | *
141 | * @param jsonTxt the json txt
142 | * @param keyEncoder the key encoder
143 | * @return the tree map
144 | */
145 | @Nonnull
146 | public static TreeMap toMap(String jsonTxt, @Nonnull Function keyEncoder) {
147 | JsonObject json = new GsonBuilder().create().fromJson(jsonTxt, JsonObject.class);
148 | return new TreeMap<>(json.keySet().stream().collect(Collectors.toMap(keyEncoder, x -> json.get(x).getAsInt(), (a, b) -> a)));
149 | }
150 |
151 | /**
152 | * Byte encoder map.
153 | *
154 | * @return the map
155 | */
156 | @Nonnull
157 | public static Map byteEncoder() {
158 | try {
159 | HashMap characterMap = new HashMap<>();
160 | for (int c = 0; c < 256; c++) {
161 | characterMap.put((char) (c + 256), (char) c);
162 | }
163 | for (char i = '!'; i < '~'; i++) {
164 | characterMap.put(i, i);
165 | }
166 | for (char i = '¡'; i < '¬'; i++) {
167 | characterMap.put(i, i);
168 | }
169 | for (char i = '®'; i < 'ÿ'; i++) {
170 | characterMap.put(i, i);
171 | }
172 | return characterMap;
173 | } catch (Throwable e) {
174 | throw Util.throwException(e);
175 | }
176 | }
177 |
178 | /**
179 | * Decode string.
180 | *
181 | * @param msg the msg
182 | * @return the string
183 | */
184 | public String decode(@Nonnull Integer... msg) {
185 | return Arrays.stream(msg).map(i -> decoder.getOrDefault(i, "")).reduce((a, b) -> a + b).orElseGet(() -> "");
186 | }
187 |
188 | /**
189 | * Encode list.
190 | *
191 | * @param msg the msg
192 | * @return the list
193 | */
194 | @Nonnull
195 | public List encode(@Nullable String msg) {
196 | ArrayList list = new ArrayList<>();
197 | if (null != msg && !msg.isEmpty()) {
198 | RefStringBuilder stringBuffer = new RefStringBuilder(msg);
199 | while (stringBuffer.length() > 0) {
200 | Optional codeString = lookup(stringBuffer.toString());
201 | if (codeString.isPresent()) {
202 | String key = RefUtil.get(codeString);
203 | stringBuffer.delete(0, key.length());
204 | list.add(encoder.get(key));
205 | } else {
206 | stringBuffer.delete(0, 1);
207 | }
208 | }
209 | }
210 | return list;
211 | }
212 |
213 | /**
214 | * Lookup optional.
215 | *
216 | * @param searchStr the search str
217 | * @return the optional
218 | */
219 | protected Optional lookup(@Nullable String searchStr) {
220 | if (null == searchStr || searchStr.isEmpty()) return Optional.empty();
221 | String ceilingKey = encoder.ceilingKey(searchStr);
222 | String floorKey = encoder.floorKey(searchStr);
223 | if (null != ceilingKey && !searchStr.startsWith(ceilingKey)) ceilingKey = null;
224 | if (null != floorKey && !searchStr.startsWith(floorKey)) floorKey = null;
225 | Optional codeString;
226 | if (null != ceilingKey || null != floorKey) {
227 | if (null != ceilingKey && null != floorKey) {
228 | if (floorKey.length() < ceilingKey.length()) {
229 | codeString = Optional.of(ceilingKey);
230 | } else {
231 | codeString = Optional.of(floorKey);
232 | }
233 | } else if (null != ceilingKey) {
234 | codeString = Optional.of(ceilingKey);
235 | } else {
236 | codeString = Optional.of(floorKey);
237 | }
238 | } else {
239 | codeString = Optional.empty();
240 | }
241 | // codeString = encoder.keySet().stream()
242 | // .filter(x -> x.equals(searchStr.substring(0, Math.min(searchStr.length(), x.length()))))
243 | // .sorted(Comparator.comparing(x -> -x.length()))
244 | // .findFirst();
245 | return codeString;
246 | }
247 | }
248 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/gpt2/GPT2Model.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.google.protobuf.InvalidProtocolBufferException;
23 | import com.simiacryptus.ref.wrappers.RefString;
24 | import com.simiacryptus.text.GraphModifier;
25 | import com.simiacryptus.text.LanguageCodeModel;
26 | import com.simiacryptus.text.TextGenerator;
27 | import com.simiacryptus.util.Util;
28 | import org.apache.commons.io.FileUtils;
29 | import org.slf4j.Logger;
30 | import org.slf4j.LoggerFactory;
31 | import org.tensorflow.Graph;
32 | import org.tensorflow.Session;
33 | import org.tensorflow.Tensor;
34 | import org.tensorflow.framework.*;
35 |
36 | import javax.annotation.Nonnull;
37 | import javax.annotation.Nullable;
38 | import java.io.File;
39 | import java.io.IOException;
40 | import java.nio.FloatBuffer;
41 | import java.nio.IntBuffer;
42 | import java.util.*;
43 | import java.util.function.BiFunction;
44 | import java.util.stream.DoubleStream;
45 | import java.util.stream.IntStream;
46 |
47 | /**
48 | * The type Gpt 2 model.
49 | */
50 | public class GPT2Model implements LanguageCodeModel {
51 | /**
52 | * The constant logger.
53 | */
54 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Model.class);
55 |
56 | /**
57 | * The Name.
58 | */
59 | public final String name;
60 | /**
61 | * The Graph def.
62 | */
63 | protected final byte[] graphDef;
64 | /**
65 | * The Code history.
66 | */
67 | protected final ArrayList code_history = new ArrayList<>();
68 | /**
69 | * The Graph modifier.
70 | */
71 | protected final GraphModifier graphModifier;
72 | /**
73 | * The Codec.
74 | */
75 | protected final GPT2Codec codec;
76 | /**
77 | * The Loaded subnets.
78 | */
79 | public HashSet loadedSubnets;
80 | /**
81 | * The Graph.
82 | */
83 | public Graph graph;
84 | /**
85 | * The Session.
86 | */
87 | public Session session;
88 | /**
89 | * The History size.
90 | */
91 | protected int history_size = 0;
92 | /**
93 | * The Tensor state.
94 | */
95 | @Nullable
96 | protected Tensor tensor_state = null;
97 | private BiFunction filterFn = (a, b) -> true;
98 |
99 | /**
100 | * Instantiates a new Gpt 2 model.
101 | *
102 | * @param name the name
103 | * @param graphModifier the graph modifier
104 | * @param file the file
105 | * @param codec the codec
106 | */
107 | public GPT2Model(String name, GraphModifier graphModifier, @Nonnull File file, GPT2Codec codec) {
108 | this(name, loadModel(file), graphModifier, codec);
109 | }
110 |
111 | /**
112 | * Instantiates a new Gpt 2 model.
113 | *
114 | * @param name the name
115 | * @param graphDef the graph def
116 | * @param graphModifier the graph modifier
117 | * @param codec the codec
118 | */
119 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec) {
120 | this(name, graphDef, graphModifier, codec, new Graph());
121 | }
122 |
123 | /**
124 | * Instantiates a new Gpt 2 model.
125 | *
126 | * @param name the name
127 | * @param graphDef the graph def
128 | * @param graphModifier the graph modifier
129 | * @param codec the codec
130 | * @param graph the graph
131 | */
132 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, @Nonnull Graph graph) {
133 | this(name, graphDef, graphModifier, codec, graph, new Session(graph, ConfigProto.newBuilder()
134 | //.setLogDevicePlacement(true)
135 | // .setUsePerSessionThreads(true)
136 | // .setInterOpParallelismThreads(8)
137 | // .setIntraOpParallelismThreads(8)
138 | // .setIsolateSessionState(false)
139 | .setGraphOptions(GraphOptions.newBuilder()
140 | .setOptimizerOptions(OptimizerOptions.newBuilder()
141 | .setDoConstantFolding(true)
142 | .setDoFunctionInlining(true)
143 | .setDoCommonSubexpressionElimination(true)
144 | .build())
145 | .build())
146 | .setGpuOptions(GPUOptions.newBuilder()
147 | .setForceGpuCompatible(true)
148 | .setAllowGrowth(true)
149 | .setPerProcessGpuMemoryFraction(0.5)
150 | .build())
151 | .build().toByteArray()));
152 | }
153 |
154 | /**
155 | * Instantiates a new Gpt 2 model.
156 | *
157 | * @param name the name
158 | * @param graphDef the graph def
159 | * @param graphModifier the graph modifier
160 | * @param codec the codec
161 | * @param graph the graph
162 | * @param session the session
163 | */
164 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, Graph graph, Session session) {
165 | this.name = name;
166 | this.graphDef = graphDef;
167 | this.graphModifier = graphModifier;
168 | this.codec = codec;
169 | this.graph = graph;
170 | this.session = session;
171 | loadedSubnets = new HashSet<>();
172 | }
173 |
174 | @Override
175 | public BiFunction getFilterFn() {
176 | return filterFn;
177 | }
178 |
179 | /**
180 | * Load model byte [ ].
181 | *
182 | * @param file the file
183 | * @return the byte [ ]
184 | */
185 | public static byte[] loadModel(@Nonnull File file) {
186 | try {
187 | return FileUtils.readFileToByteArray(file);
188 | } catch (IOException e) {
189 | throw Util.throwException(e);
190 | }
191 | }
192 |
193 | /**
194 | * Copy tensor.
195 | *
196 | * @param toCopy the to copy
197 | * @return the tensor
198 | */
199 | @Nonnull
200 | public static Tensor copy(@Nonnull Tensor toCopy) {
201 | FloatBuffer floatBuffer = FloatBuffer.allocate(toCopy.numElements());
202 | toCopy.writeTo(floatBuffer);
203 | floatBuffer.flip();
204 | return Tensor.create(toCopy.shape(), floatBuffer);
205 | }
206 |
207 | @Nonnull
208 | @Override
209 | public LanguageCodeModel copy() {
210 | GPT2Model copy = new GPT2Model(name, graphDef, graphModifier, this.codec, this.graph, this.session);
211 | if (null == this.tensor_state) {
212 | copy.tensor_state = null;
213 | } else {
214 | copy.tensor_state = copy(this.tensor_state);
215 | }
216 | copy.history_size = this.history_size;
217 | copy.loadedSubnets = this.loadedSubnets;
218 | copy.code_history.addAll(this.code_history);
219 | copy.filterFn = this.filterFn;
220 | return copy;
221 | }
222 |
223 | /**
224 | * Logits to probabilities float [ ].
225 | *
226 | * @param logits the logits
227 | * @return the float [ ]
228 | */
229 | @Nonnull
230 | public float[] logitsToProbabilities(@Nonnull float[] logits) {
231 | String prefix = codec.decode(code_history.stream().toArray(i -> new Integer[i]));
232 | int[] sortedIndices = Arrays.stream(TextGenerator.sortedIndices(logits, Integer.MAX_VALUE))
233 | .filter(item -> {
234 | if (item == logits.length - 1) return true;
235 | String thisStr = codec.decode(item);
236 | assert getFilterFn() != null;
237 | return getFilterFn().apply(prefix, thisStr);
238 | })
239 | .toArray();
240 | double[] input = IntStream.range(0, sortedIndices.length).mapToDouble(c -> logits[sortedIndices[c]]).toArray();
241 | assert 1 < input.length : "input.length() = " + input.length;
242 |
243 | final DoubleSummaryStatistics summaryStatistics = DoubleStream.of(input).filter(x -> Double.isFinite(x)).summaryStatistics();
244 | final double max = summaryStatistics.getMax();
245 | @Nullable final double[] exp = Arrays.stream(input).map(x -> {
246 | double xx = Math.exp(x - max);
247 | return Double.isFinite(xx) ? xx : 0;
248 | }).toArray();
249 | final double sum = 0 < Arrays.stream(exp).sum() ? Arrays.stream(exp).sum() : 1;
250 | assert Double.isFinite(sum);
251 | @Nullable double[] chosen = Arrays.stream(exp).map(x -> x / sum).toArray();
252 |
253 | for (int i = 0; i < logits.length; i++) logits[i] = 0;
254 | assert chosen != null;
255 | IntStream.range(0, chosen.length).forEach(c -> {
256 | logits[sortedIndices[c]] = (float) chosen[c];
257 | });
258 | return logits;
259 | }
260 |
261 | @Nonnull
262 | @Override
263 | public synchronized LanguageCodeModel clear() {
264 | logger.debug("Reset Language Model State");
265 | if (null != this.tensor_state) this.tensor_state.close();
266 | this.tensor_state = null;
267 | history_size = 0;
268 | code_history.clear();
269 | return this;
270 | }
271 |
272 | @Nonnull
273 | @Override
274 | public synchronized float[] eval(int data_X) {
275 | logger.debug(RefString.format("Eval %d", data_X));
276 | try {
277 | String prefix;
278 | if (!loadedSubnets.contains("")) {
279 | loadedSubnets.add("");
280 | graph.importGraphDef(this.graphDef);
281 | }
282 | if (null == this.tensor_state) {
283 | prefix = "init/";
284 | if (!loadedSubnets.contains(prefix)) {
285 | GraphModifier.importGraphDef(graph, this.graphModifier.edit(GraphDef.parseFrom(this.graphDef), prefix, false));
286 | loadedSubnets.add(prefix);
287 | }
288 | } else {
289 | prefix = "";
290 | }
291 | this.code_history.add(data_X);
292 | final float[] eval;
293 | if (0 == history_size) {
294 | eval = eval(prefix, data_X);
295 | } else {
296 | final int[] activeCodes = this.code_history
297 | .subList(this.code_history.size() - 1, this.code_history.size())
298 | .stream().mapToInt(x -> x).toArray();
299 | eval = eval(prefix, activeCodes);
300 | }
301 | return eval;
302 | } catch (InvalidProtocolBufferException e) {
303 | throw Util.throwException(e);
304 | }
305 | }
306 |
307 | /**
308 | * Eval float [ ].
309 | *
310 | * @param prefix the prefix
311 | * @param data_X the data x
312 | * @return the float [ ]
313 | */
314 | @Nonnull
315 | public synchronized float[] eval(String prefix, @Nonnull int... data_X) {
316 | synchronized (session) {
317 | logger.debug(RefString.format("Eval(%s,%s)", session, Arrays.toString(data_X)));
318 | Tensor input_X = Tensor.create(new long[]{1, data_X.length}, IntBuffer.wrap(data_X));
319 | Session.Runner runner = session.runner().feed("input_X", input_X);
320 | if (null != this.tensor_state) runner = runner.feed(prefix + "input_past", this.tensor_state);
321 | logger.debug("Input Codes: " + Arrays.toString(data_X));
322 | logger.debug("Input State: " + (this.tensor_state == null ? "null" : Arrays.toString(this.tensor_state.shape())));
323 | final Tensor prevState = this.tensor_state;
324 | runner = runner
325 | .fetch(prefix + "output/strided_slice_1")
326 | .fetch(0 == history_size ? prefix + "model/stack" : prefix + "output/concat");
327 | List> run = runner.run();
328 | Tensor tensor_next = run.get(0).expect(Float.class);
329 | final Tensor outputState = run.get(1).expect(Float.class); // reshape(shape_state, run.get(1).expect(Float.class));
330 | logger.debug("Output Logits: " + Arrays.toString(tensor_next.shape()));
331 | logger.debug("Output State: " + Arrays.toString(outputState.shape()));
332 | if (null == this.tensor_state) {
333 | this.history_size = (int) outputState.shape()[4];
334 | this.tensor_state = outputState;
335 | } else {
336 | this.history_size = this.history_size + 1;
337 | this.tensor_state.close();
338 | this.tensor_state = outputState;
339 | }
340 | float[] logits = new float[tensor_next.numElements()];
341 | tensor_next.writeTo(FloatBuffer.wrap(logits));
342 | tensor_next.close();
343 | if (null != prevState) prevState.close();
344 | input_X.close();
345 | return logitsToProbabilities(logits);
346 | }
347 | }
348 |
349 | @Nonnull
350 | @Override
351 | public LanguageCodeModel setFilterFn(BiFunction filterFn) {
352 | this.filterFn = filterFn;
353 | return this;
354 | }
355 |
356 | @Nullable
357 | @Override
358 | public Tensor> state() {
359 | return this.tensor_state;
360 | }
361 |
362 | }
363 |
--------------------------------------------------------------------------------
/src/main/java/com/simiacryptus/text/gpt2/GPT2Util.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.simiacryptus.text.LanguageCodeModel;
23 | import com.simiacryptus.text.SumModel;
24 | import com.simiacryptus.text.TextGenerator;
25 | import com.simiacryptus.util.Util;
26 | import org.apache.commons.io.FileUtils;
27 | import org.apache.commons.io.IOUtils;
28 |
29 | import javax.annotation.Nonnull;
30 | import javax.annotation.Nullable;
31 | import java.io.File;
32 | import java.io.IOException;
33 | import java.net.URI;
34 | import java.security.KeyManagementException;
35 | import java.security.NoSuchAlgorithmException;
36 | import java.util.ArrayList;
37 | import java.util.Arrays;
38 | import java.util.List;
39 | import java.util.TreeSet;
40 | import java.util.stream.Collectors;
41 | import java.util.stream.Stream;
42 | import java.util.zip.ZipFile;
43 |
44 | /**
45 | * The type Gpt 2 util.
46 | */
47 | public class GPT2Util {
48 |
49 | private static final String MODEL_URL_BASE = System.getProperty(
50 | "GPT2_MODEL_URL",
51 | "https://s3-us-west-2.amazonaws.com/simiacryptus/gpt2/");
52 | private static final TextGenerator prototype = get345M().setVerbose(false);
53 |
54 | /**
55 | * Gets 345 m.
56 | *
57 | * @return the 345 m
58 | */
59 | @Nonnull
60 | public static TextGenerator get345M() {
61 | return new TextGenerator(50257, getModel_345M(), getCodec_345M());
62 | }
63 |
64 | /**
65 | * Gets codec 345 m.
66 | *
67 | * @return the codec 345 m
68 | */
69 | @Nonnull
70 | public static GPT2Codec getCodec_345M() {
71 | return new GPT2Codec(getEncoderFile_345M(), 50257);
72 | }
73 |
74 | /**
75 | * Gets encoder file 345 m.
76 | *
77 | * @return the encoder file 345 m
78 | */
79 | @Nonnull
80 | public static File getEncoderFile_345M() {
81 | return loadZippedInternetFile(MODEL_URL_BASE + "encoder_345M.zip", "encoder_345M.json");
82 | }
83 |
84 | /**
85 | * Gets graph file 345 m.
86 | *
87 | * @return the graph file 345 m
88 | */
89 | @Nonnull
90 | public static File getGraphFile_345M() {
91 | return loadRawInternetFile(MODEL_URL_BASE, "345M.pb");
92 | }
93 |
94 | /**
95 | * Gets model 345 m.
96 | *
97 | * @return the model 345 m
98 | */
99 | @Nonnull
100 | public static GPT2Model getModel_345M() {
101 | return getModel_345M(getGraphFile_345M());
102 | }
103 |
104 | /**
105 | * Gets text generator.
106 | *
107 | * @return the text generator
108 | */
109 | @Nonnull
110 | public static TextGenerator getTextGenerator() {
111 | return prototype.copy();
112 | }
113 |
114 | /**
115 | * Load zipped internet file file.
116 | *
117 | * @param zipUrl the zip url
118 | * @param pathname the pathname
119 | * @return the file
120 | */
121 | public @Nonnull
122 | static File loadZippedInternetFile(@Nonnull String zipUrl, @Nonnull String pathname) {
123 | File encoderFile = new File(pathname);
124 | if (new File(encoderFile.getName()).exists()) {
125 | encoderFile = new File(encoderFile.getName());
126 | } else {
127 | try {
128 | try (ZipFile zipFile = new ZipFile(Util.cacheFile(new URI(zipUrl)))) {
129 | byte[] graphDefBytes = IOUtils.toByteArray(zipFile.getInputStream(zipFile.getEntry(pathname)));
130 | encoderFile = new File(encoderFile.getName());
131 | FileUtils.writeByteArrayToFile(encoderFile, graphDefBytes);
132 | }
133 | } catch (Exception e) {
134 | throw Util.throwException(e);
135 | }
136 | }
137 | return encoderFile;
138 | }
139 |
140 | /**
141 | * Load raw internet file file.
142 | *
143 | * @param urlBase the url base
144 | * @param fileName the file name
145 | * @return the file
146 | */
147 | public @Nonnull
148 | static File loadRawInternetFile(String urlBase, @Nonnull String fileName) {
149 | File graphFile = new File(fileName);
150 | if (new File(graphFile.getName()).exists()) {
151 | graphFile = new File(graphFile.getName());
152 | } else {
153 | try {
154 | graphFile = Util.cacheFile(new URI(urlBase + fileName));
155 | } catch (Exception e) {
156 | throw Util.throwException(e);
157 | }
158 | }
159 | return graphFile;
160 | }
161 |
162 | /**
163 | * Gets model 345 m.
164 | *
165 | * @param file the file
166 | * @return the model 345 m
167 | */
168 | @Nonnull
169 | public static GPT2Model getModel_345M(@Nonnull File file) {
170 | return getModel_345M("345M", file);
171 | }
172 |
173 | /**
174 | * Gets model 345 m.
175 | *
176 | * @param name the name
177 | * @param file the file
178 | * @return the model 345 m
179 | */
180 | @Nonnull
181 | public static GPT2Model getModel_345M(String name, @Nonnull File file) {
182 | return new GPT2Model(name, new GPT2Edit_345M(), file, getCodec_345M());
183 | }
184 |
185 | /**
186 | * Gets text generator.
187 | *
188 | * @param textGenerator the text generator
189 | * @param characterWhitelist the character whitelist
190 | * @param wordlist the wordlist
191 | * @return the text generator
192 | * @throws IOException the io exception
193 | * @throws NoSuchAlgorithmException the no such algorithm exception
194 | * @throws KeyManagementException the key management exception
195 | */
196 | @Nonnull
197 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator textGenerator, @Nullable String characterWhitelist, @Nullable URI wordlist) throws IOException, NoSuchAlgorithmException, KeyManagementException {
198 | TreeSet wordList = null == wordlist ? null : new TreeSet<>(
199 | Arrays.stream(FileUtils.readFileToString(Util.cacheFile(wordlist), "UTF-8").split("\\s+"))
200 | .map(x -> x.trim().toLowerCase()).collect(Collectors.toSet())
201 | );
202 | textGenerator.getModel().setFilterFn((prefix, txt) -> {
203 | if (null != characterWhitelist && !characterWhitelist.isEmpty() &&
204 | txt.matches(".*[^" + characterWhitelist + "].*")) return false;
205 | String[] words = txt.split("[^\\w]+");
206 | if (null != wordList && !wordList.isEmpty())
207 | for (int i = 0; i < words.length; i++) {
208 | String word = words[i].toLowerCase();
209 | if (word.isEmpty()) continue;
210 | if (i < words.length - 1 && !wordList.contains(word)) return false;
211 | else {
212 | if (wordList.contains(word)) continue;
213 | String floor = wordList.floor(word);
214 | if (null != floor && floor.startsWith(word)) continue;
215 | String ceiling = wordList.ceiling(word);
216 | if (null != ceiling && ceiling.startsWith(word)) continue;
217 | return false;
218 | }
219 | }
220 | return true;
221 | });
222 | return textGenerator;
223 | }
224 |
225 | /**
226 | * Gets text generator.
227 | *
228 | * @param seeds the seeds
229 | * @return the text generator
230 | */
231 | @Nonnull
232 | public static TextGenerator getTextGenerator(String... seeds) {
233 | return getTextGenerator(get345M().setVerbose(false), seeds);
234 | }
235 |
236 | /**
237 | * Gets text generator.
238 | *
239 | * @param base the base
240 | * @param seeds the seeds
241 | * @return the text generator
242 | */
243 | @Nonnull
244 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator base, String... seeds) {
245 | ArrayList languageCodeModels = new ArrayList<>();
246 | return getTextGenerator(base, languageCodeModels, seeds);
247 | }
248 |
249 | /**
250 | * Gets text generator.
251 | *
252 | * @param base the base
253 | * @param languageCodeModels the language code models
254 | * @param seeds the seeds
255 | * @return the text generator
256 | */
257 | @Nonnull
258 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator base, @Nonnull List languageCodeModels, @Nonnull String... seeds) {
259 | base.setModel(new SumModel(Stream.concat(
260 | Arrays.stream(seeds).map(seed -> {
261 | TextGenerator copy = base.copy();
262 | copy.feed(seed);
263 | return copy.getModel();
264 | }),
265 | languageCodeModels.stream()
266 | ).toArray(i -> new LanguageCodeModel[i])));
267 | return base;
268 | }
269 |
270 | /**
271 | * Gets text generator.
272 | *
273 | * @param characterWhitelist the character whitelist
274 | * @param wordlist the wordlist
275 | * @return the text generator
276 | * @throws IOException the io exception
277 | * @throws NoSuchAlgorithmException the no such algorithm exception
278 | * @throws KeyManagementException the key management exception
279 | */
280 | @Nonnull
281 | protected static TextGenerator getTextGenerator(String characterWhitelist, URI wordlist) throws IOException, NoSuchAlgorithmException, KeyManagementException {
282 | return getTextGenerator(getTextGenerator(), characterWhitelist, wordlist);
283 | }
284 | }
285 |
--------------------------------------------------------------------------------
/src/site/site.xml:
--------------------------------------------------------------------------------
1 |
19 |
20 |
21 |
22 | org.apache.maven.skins
23 | maven-fluido-skin
24 | 1.8
25 |
26 |
27 |
28 | true
29 | false
30 | true
31 |
32 | SimiaCryptus/tf-gpt-2
33 | left
34 | black
35 |
36 | release
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/src/test/java/com/simiacryptus/text/gpt2/DevTests.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.simiacryptus.ref.lang.RefUtil;
23 | import com.simiacryptus.tensorflow.GraphModel;
24 | import com.simiacryptus.tensorflow.TFUtil;
25 | import org.junit.jupiter.api.Assertions;
26 | import org.junit.jupiter.api.Test;
27 | import org.tensorflow.Output;
28 | import org.tensorflow.TensorFlowException;
29 |
30 | import javax.annotation.Nonnull;
31 | import java.io.File;
32 | import java.util.Arrays;
33 | import java.util.Map;
34 |
35 | import static com.simiacryptus.tensorflow.TFUtil.find;
36 |
37 | /**
38 | * The type Dev tests.
39 | */
40 | public class DevTests {
41 |
42 | /**
43 | * Model home file.
44 | *
45 | * @return the file
46 | */
47 | @Nonnull
48 | public static File modelHome() {
49 | return new File(System.getProperty("MODEL_HOME", "H:\\SimiaCryptus\\data-science-tools\\gpt-2\\models"));
50 | }
51 |
52 | /**
53 | * Summary json.
54 | *
55 | * @throws Exception the exception
56 | */
57 | @Test
58 | public void summaryJson() throws Exception {
59 | TestUtil.open("345M.", new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb"))));
60 | TestUtil.open("345M_Init.", new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb"))));
61 | }
62 |
63 | /**
64 | * Compare init.
65 | */
66 | @Test
67 | public void compare_init() {
68 | final GraphModel a = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb")));
69 | final GraphModel b = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb")));
70 | // TestUtil.open("345M.", a);
71 | // TestUtil.open("345M_Init.", b);
72 | // TestUtil.open("345M_cmp.", b.compare(a));
73 | new GraphComparer().compare(a, b);
74 | }
75 |
76 | /**
77 | * Test init.
78 | *
79 | * @throws Exception the exception
80 | */
81 | @Test
82 | public void test_init() throws Exception {
83 | final GraphModel a = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb")));
84 | final GraphModel b = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb")));
85 | final GraphModel edited = new GraphModel(new GPT2Edit_345M().edit(a.graphDef, "", true).toByteArray());
86 | TestUtil.open("345M_edit.", edited);
87 | final Map compare = b.compare(edited);
88 | TestUtil.open("345M_test.", compare);
89 | compare.values().forEach(deltaRecord -> {
90 | System.out.println("left=" + (null == deltaRecord.left ? "null" : deltaRecord.left.getNodeDef()));
91 | System.out.println("right=" + (null == deltaRecord.right ? "null" : deltaRecord.right.getNodeDef()));
92 | });
93 | }
94 |
95 | /**
96 | * Add gradient.
97 | *
98 | * @throws Exception the exception
99 | */
100 | @Test
101 | public void addGradient() throws Exception {
102 | Assertions.assertThrows(TensorFlowException.class, () -> {
103 | try {
104 | byte[] originalGraphDef = GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb"));
105 | byte[] newGraphDef = TFUtil.editGraph(originalGraphDef, graph -> {
106 | graph.addGradients("gradient_", new Output[]{
107 | find(graph, "model/Reshape_1").output(0),
108 | find(graph, "model/stack").output(0)
109 | }, new Output[]{
110 | find(graph, "input_past").output(0)
111 | }, null);
112 | });
113 | System.out.println("GPT2Model: " + TFUtil.toJson(new GraphModel(newGraphDef)));
114 | } catch (Throwable e) {
115 | e.printStackTrace();
116 | throw e;
117 | }
118 | });
119 | }
120 |
121 | /**
122 | * Encode.
123 | */
124 | @Test
125 | public void encode() {
126 | GPT2Codec encoder = new GPT2Codec(new File(modelHome(), "345M" + "\\encoder.json"), 50257);
127 | for (String text : Arrays.asList(
128 | "This is a test",
129 | "<|endoftext|>"
130 | )) {
131 | System.out.println(text + " => " + RefUtil.get(encoder.encode(text).stream().map(x -> x.toString()).reduce((a, b) -> a + ", " + b)));
132 | }
133 | }
134 |
135 | }
136 |
--------------------------------------------------------------------------------
/src/test/java/com/simiacryptus/text/gpt2/GraphComparer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.simiacryptus.ref.wrappers.RefString;
23 | import com.simiacryptus.tensorflow.GraphModel;
24 | import org.tensorflow.framework.AttrValue;
25 | import org.tensorflow.framework.NodeDef;
26 | import org.tensorflow.framework.TensorProto;
27 | import org.tensorflow.framework.TensorShapeProto;
28 |
29 | import javax.annotation.Nonnull;
30 | import javax.annotation.Nullable;
31 | import java.nio.FloatBuffer;
32 | import java.nio.IntBuffer;
33 | import java.util.*;
34 | import java.util.function.Consumer;
35 | import java.util.stream.IntStream;
36 | import java.util.stream.Stream;
37 |
38 | /**
39 | * The type Graph comparer.
40 | */
41 | class GraphComparer implements Consumer {
42 | /**
43 | * The Node deletes.
44 | */
45 | public final ArrayList nodeDeletes = new ArrayList();
46 | /**
47 | * The New nodes.
48 | */
49 | public final ArrayList newNodes = new ArrayList();
50 | /**
51 | * The Node edits.
52 | */
53 | public final Map> nodeEdits = new HashMap<>();
54 |
55 | /**
56 | * To string string.
57 | *
58 | * @param value the value
59 | * @return the string
60 | */
61 | public static String toString(@Nonnull AttrValue value) {
62 | switch (value.getValueCase()) {
63 | case I:
64 | return RefString.format("AttrValue.newBuilder().setI(%s).build()", value.getI());
65 | case F:
66 | return RefString.format("AttrValue.newBuilder().setF(%s).build()", value.getF());
67 | case B:
68 | return RefString.format("AttrValue.newBuilder().setB(%s).build()", value.getB());
69 | case S:
70 | return RefString.format("AttrValue.newBuilder().setS(%s).build()", value.getS().toStringUtf8());
71 | case TYPE:
72 | return RefString.format("AttrValue.newBuilder().setType(DataType.forNumber(%s)).build()", value.getType().getNumber());
73 | case SHAPE:
74 | return RefString.format("AttrValue.newBuilder().setShape(shape(%s)).build()", toString(dims(value.getShape())));
75 | case TENSOR:
76 | TensorProto tensor = value.getTensor();
77 | TensorShapeProto shape = tensor.getTensorShape();
78 | switch (tensor.getDtype()) {
79 | case DT_INT32: {
80 | String shapeElements = shape.getDimList().stream().map(x -> Long.toString(x.getSize())).reduce((a, b) -> a + ", " + b).orElse("");
81 | return tensor.getIntValList().stream().map(x -> Integer.toString(x)).reduce((a, b) -> a + ", " + b).map(elements -> {
82 | return RefString.format("AttrValue.newBuilder().setTensor(tensor1(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements);
83 | }).orElseGet(() -> {
84 | IntBuffer intBuffer = tensor.getTensorContent().asReadOnlyByteBuffer().asIntBuffer();
85 | int[] dst = new int[intBuffer.remaining()];
86 | intBuffer.get(dst);
87 | String elements = Arrays.stream(dst).map(i1 -> Integer.reverseBytes(i1)).mapToObj(i -> Integer.toString(i)).reduce((a, b) -> a + ", " + b).orElse("");
88 | return RefString.format("AttrValue.newBuilder().setTensor(tensor2(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements);
89 | });
90 | }
91 | case DT_FLOAT:
92 | String shapeElements = shape.getDimList().stream().map(x -> Long.toString(x.getSize())).reduce((a, b) -> a + ", " + b).orElse("");
93 | return tensor.getFloatValList().stream().map(x -> Float.toString(x)).reduce((a, b) -> a + ", " + b).map(elements -> {
94 | return RefString.format("AttrValue.newBuilder().setTensor(tensor1(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements);
95 | }).orElseGet(() -> {
96 | FloatBuffer intBuffer = tensor.getTensorContent().asReadOnlyByteBuffer().asFloatBuffer();
97 | float[] dst = new float[intBuffer.remaining()];
98 | intBuffer.get(dst);
99 | String elements = IntStream.range(0, dst.length).mapToDouble(i -> dst[i]).mapToObj(d -> Double.toString(d)).reduce((a, b) -> a + ", " + b).orElse("");
100 | return RefString.format("AttrValue.newBuilder().setTensor(tensor2(new int[]{ %s }, new float[]{ %s })).build()", shapeElements, elements);
101 | });
102 | }
103 | default:
104 | return "/* " + value.getType() + " - " + value.toString().trim() + " */";
105 | }
106 | }
107 |
108 | /**
109 | * To string string.
110 | *
111 | * @param dims the dims
112 | * @return the string
113 | */
114 | @Nonnull
115 | public static String toString(@Nonnull long[] dims) {
116 | return Arrays.stream(dims).mapToObj(size -> Long.toString(size)).reduce((a, b) -> a + ", " + b).orElse("");
117 | }
118 |
119 | /**
120 | * Dims long [ ].
121 | *
122 | * @param shape the shape
123 | * @return the long [ ]
124 | */
125 | public static long[] dims(@Nonnull TensorShapeProto shape) {
126 | return shape.getDimList().stream().mapToLong(dim -> dim.getSize()).toArray();
127 | }
128 |
129 | /**
130 | * Compare.
131 | *
132 | * @param left the left
133 | * @param right the right
134 | */
135 | public void compare(@Nonnull GraphModel left, @Nonnull GraphModel right) {
136 | left.compare(right).values().stream().forEach(this);
137 | System.out.println("\n" +
138 | " @Override\n" +
139 | " public HashSet getDeletes_Init() {\n" +
140 | " final HashSet toDelete = new HashSet<>();\n" +
141 | "\n" + this.nodeDeletes.stream().map(s2 -> s2.trim()).reduce((a, b) -> a + "\n" + b).orElse("") +
142 | " return toDelete;\n" +
143 | " }\n" +
144 | "\n" +
145 | " protected void addNodes(Consumer add) {\n" +
146 | "\n" + this.newNodes.stream().map(s1 -> s1.trim()).reduce((a, b) -> a + "\n" + b).orElse("") +
147 | " }\n" +
148 | "\n" +
149 | " @Override\n" +
150 | " public NodeDef.Builder edit(NodeDef.Builder node) {\n" +
151 | "\n" + nodeEdits.entrySet().stream().sorted(Comparator.comparing(x -> x.getKey())).map(e ->
152 | RefString.format("if(node.getName().equals(\"%s\")) {\n%s\n}", e.getKey(),
153 | e.getValue().stream().map(x -> "\t" + x).reduce((a, b) -> a + "\n" + b).orElse("")
154 | )).map(s -> s.trim()).reduce((a, b) -> a + "\nelse " + b).orElse("") +
155 | " else {\n" +
156 | " return null;\n" +
157 | " }\n" +
158 | " return node;\n" +
159 | " }\n");
160 | }
161 |
162 | @Override
163 | public void accept(@Nonnull GraphModel.DeltaRecord delta) {
164 | if (delta.left == null || delta.right == null) {
165 | if (delta.left == null) {
166 | GraphModel.GraphNode node = delta.right;
167 | this.newNodes.add(RefString.format("add.accept(NodeDef.newBuilder().setName(\"%s\").setOp(\"%s\")%s%s.build());",
168 | node.name,
169 | node.getOp(),
170 | node.getInputKeys().stream().map(s -> '"' + s + '"')
171 | .reduce((a, b) -> a + ", " + b)
172 | .map(s -> ".addAllInput(Arrays.asList(" + s + "))").orElse(""),
173 | node.getNodeDef().getAttrMap().entrySet().stream()
174 | .map(e -> RefString.format(".putAttr(\"%s\",%s)", e.getKey(), toString(e.getValue())))
175 | .reduce((a, b) -> a + b).orElse("")
176 | ).replaceAll("\\)\\.", ")\n\t."));
177 | } else {
178 | nodeDeletes.add(RefString.format("toDelete.add(\"%s\");%n", delta.name));
179 | }
180 | } else {
181 | final NodeDef leftNode = delta.left.getNodeDef();
182 | final NodeDef rightNode = delta.right.getNodeDef();
183 | if (null != leftNode && null != rightNode) {
184 | if (!leftNode.getOp().equals(rightNode.getOp())) {
185 | getBuffer(delta.name).add(RefString.format("node.setOp(\"%s\");", rightNode.getOp()));
186 | }
187 | }
188 | assert rightNode != null;
189 | assert leftNode != null;
190 | compare(delta, leftNode.getAttrMap(), rightNode.getAttrMap());
191 | compareInputs(delta,
192 | delta.left.getInputKeys(),
193 | delta.right.getInputKeys());
194 | }
195 | }
196 |
197 | /**
198 | * Compare inputs.
199 | *
200 | * @param delta the delta
201 | * @param leftData the left data
202 | * @param rightData the right data
203 | */
204 | public void compareInputs(@Nonnull GraphModel.DeltaRecord delta, @Nullable List leftData, @Nonnull List rightData) {
205 | if (leftData == null || leftData.size() == 0) {
206 | getBuffer(delta.name).add(rightData.stream().map(input ->
207 | RefString.format("node.addInput(\"%s\");", input)
208 | ).reduce((a, b) -> a + "\n" + b).orElse(""));
209 | } else if (rightData.size() == 0) {
210 | getBuffer(delta.name).add("node.clearInput();");
211 | } else if (leftData.size() != rightData.size()) {
212 | if (leftData.size() + 1 == rightData.size() && leftData.equals(rightData.subList(0, leftData.size()))) {
213 | getBuffer(delta.name).add(RefString.format("node.addInput(\"%s\");", rightData.get(leftData.size())));
214 | } else {
215 | System.out.printf("// %s: Input %s vs %s%n", delta.name,
216 | leftData.stream().reduce((a, b) -> a + "," + b).orElse("-"),
217 | rightData.stream().reduce((a, b) -> a + "," + b).orElse("-"));
218 | }
219 | } else {
220 | final int[] mismatchedIndices = IntStream.range(0, leftData.size()).filter(i -> !leftData.get(i).equals(rightData.get(i))).toArray();
221 | if (1 == mismatchedIndices.length) {
222 | getBuffer(delta.name).add(RefString.format("node.setInput(%s, \"%s\");", mismatchedIndices[0], rightData.get(mismatchedIndices[0])));
223 | } else if (0 < mismatchedIndices.length) {
224 | getBuffer(delta.name).add(RefString.format("node.clearInput();node.addAllInput(Arrays.asList(%s));",
225 | rightData.stream().map(x -> '"' + x + '"').reduce((a, b) -> a + "," + b).orElseGet(() -> "")));
226 | }
227 | }
228 | }
229 |
230 | private void compare(@Nonnull GraphModel.DeltaRecord delta, @Nonnull Map left, @Nonnull Map right) {
231 | Stream.concat(
232 | left.keySet().stream(),
233 | right.keySet().stream()
234 | )
235 | .forEach(key -> {
236 | AttrValue leftValue = left.get(key);
237 | AttrValue rightValue = right.get(key);
238 | if (null == leftValue) {
239 | getBuffer(delta.name).add(RefString.format("node.putAttr(\"%s\", %s);", key, toString(rightValue).trim()));
240 | } else if (null == rightValue) {
241 | getBuffer(delta.name).add(RefString.format("node.removeAttr(\"%s\");", key));
242 | } else {
243 | if (!leftValue.toString().equals(rightValue.toString())) {
244 | getBuffer(delta.name).add(RefString.format("node.putAttr(\"%s\", %s);", key, toString(rightValue).trim()));
245 | }
246 | }
247 | });
248 | }
249 |
250 | @Nonnull
251 | private ArrayList getBuffer(String name) {
252 | return nodeEdits.computeIfAbsent(name, k -> new ArrayList());
253 | }
254 |
255 | }
256 |
--------------------------------------------------------------------------------
/src/test/java/com/simiacryptus/text/gpt2/TestUtil.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.fasterxml.jackson.databind.ObjectMapper;
23 | import com.fasterxml.jackson.databind.SerializationFeature;
24 | import com.simiacryptus.tensorflow.TFUtil;
25 | import com.simiacryptus.tensorflow.TensorboardEventWriter;
26 | import com.simiacryptus.util.Util;
27 | import org.tensorflow.framework.GraphDef;
28 |
29 | import javax.annotation.Nonnull;
30 | import javax.swing.*;
31 | import java.awt.*;
32 | import java.io.File;
33 | import java.io.FileOutputStream;
34 | import java.io.IOException;
35 | import java.net.URISyntaxException;
36 |
37 | /**
38 | * The type Test util.
39 | */
40 | public class TestUtil {
41 | /**
42 | * Launch tensorboard.
43 | *
44 | * @param tensorboardDir the tensorboard dir
45 | * @throws IOException the io exception
46 | * @throws URISyntaxException the uri syntax exception
47 | */
48 | public static void launchTensorboard(@Nonnull File tensorboardDir) throws IOException, URISyntaxException {
49 | TFUtil.launchTensorboard(tensorboardDir.getAbsolutePath(), tensorboard -> {
50 | try {
51 | JOptionPane.showConfirmDialog(null, "Press OK to close");
52 | tensorboard.destroyForcibly();
53 | } catch (Exception e) {
54 | throw Util.throwException(e);
55 | }
56 | });
57 | }
58 |
59 | /**
60 | * Write graph file.
61 | *
62 | * @param graphDef the graph def
63 | * @param location the location
64 | * @param name the name
65 | * @return the file
66 | * @throws IOException the io exception
67 | */
68 | public static File writeGraph(@Nonnull GraphDef graphDef, File location, @Nonnull String name) throws IOException {
69 | TensorboardEventWriter eventWriter = new TensorboardEventWriter(new File(location, name), graphDef);
70 | eventWriter.write(graphDef);
71 | eventWriter.close();
72 | return location;
73 | }
74 |
75 | /**
76 | * Open.
77 | *
78 | * @param prefix the prefix
79 | * @param model the model
80 | * @throws IOException the io exception
81 | */
82 | public static void open(@Nonnull String prefix, Object model) throws IOException {
83 | File tempFile = File.createTempFile(prefix, ".json");
84 | FileOutputStream fileOutputStream = new FileOutputStream(tempFile);
85 | new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT).writeValue(fileOutputStream, model);
86 | fileOutputStream.close();
87 | Desktop.getDesktop().open(tempFile);
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/src/test/java/com/simiacryptus/text/gpt2/UserTests.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019 by Andrew Charneski.
3 | *
4 | * The author licenses this file to you under the
5 | * Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance
7 | * with the License. You may obtain a copy
8 | * of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | */
19 |
20 | package com.simiacryptus.text.gpt2;
21 |
22 | import com.simiacryptus.text.TextGenerator;
23 | import org.junit.jupiter.api.Test;
24 |
25 | import java.util.Arrays;
26 |
27 | /**
28 | * The type User tests.
29 | */
30 | public class UserTests {
31 |
32 | // /**
33 | // * Tensorboard graph.
34 | // */
35 | // @Test
36 | // public void tensorboardGraph() {
37 | // String now = new SimpleDateFormat("yyyyMMddHHmm").format(new Date());
38 | // try {
39 | // String id = UUID.randomUUID().toString();
40 | // File location = new File("target/" + now + "/tensorboard/" + id);
41 | // File graphFile = GPT2Util.getGraphFile_345M();
42 | // TestUtil.launchTensorboard(TestUtil.writeGraph(GraphDef.parseFrom(GPT2Model.loadModel(graphFile)), location, id));
43 | // } catch (IOException e) {
44 | // throw Util.throwException(e);
45 | // } catch (URISyntaxException e) {
46 | // throw Util.throwException(e);
47 | // }
48 | // }
49 |
50 | /**
51 | * Generate unconditional text.
52 | */
53 | @Test
54 | public void generateUnconditionalText() {
55 | TextGenerator textGenerator = GPT2Util.get345M().setVerbose(false);
56 | for (double t = 1.0; t < 3; t *= 1.1) {
57 | textGenerator.getModel();
58 | System.out.println("Temperature=" + t);
59 | System.out.println(textGenerator.generateText(150));
60 | }
61 | }
62 |
63 | /**
64 | * Generate conditional text.
65 | */
66 | @Test
67 | public void generateConditionalText() {
68 | TextGenerator textGenerator = GPT2Util.get345M().setVerbose(false);
69 | for (double t = 1.0; t < 3; t *= 1.1) {
70 | textGenerator.getModel();
71 | System.out.println("Temperature=" + t);
72 | for (String seed : Arrays.asList(
73 | "Hello World",
74 | "The opposite of up is",
75 | "English: Thank You\nSpanish:",
76 | "public void main(",
77 | "The problem with the world is",
78 | "I love",
79 | "You people are"
80 | )) {
81 | System.out.println(textGenerator.generateText(150, seed));
82 | }
83 | }
84 | }
85 |
86 |
87 | }
88 |
89 |
90 |
--------------------------------------------------------------------------------
/src/test/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | %msg%n
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------