├── .gitignore ├── LICENSE ├── README.md ├── pom.xml └── src ├── main └── java │ └── com │ └── simiacryptus │ └── text │ ├── GraphModifier.java │ ├── LanguageCodeModel.java │ ├── MinEntropyWrapper.java │ ├── ModelWrapper.java │ ├── SimpleModel.java │ ├── SumModel.java │ ├── TemperatureWrapper.java │ ├── TextGenerator.java │ ├── TopNWrapper.java │ └── gpt2 │ ├── GPT2Codec.java │ ├── GPT2Edit_345M.java │ ├── GPT2Model.java │ └── GPT2Util.java ├── site └── site.xml └── test ├── java └── com │ └── simiacryptus │ └── text │ └── gpt2 │ ├── DevTests.java │ ├── GraphComparer.java │ ├── TestUtil.java │ └── UserTests.java └── resources └── logback.xml /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | !/.mvn/wrapper/maven-wrapper.jar 3 | # Avoid ignoring Maven wrapper jar file (.jar files are usually ignored) 4 | *.hprof 5 | *.pb 6 | *.pem 7 | *.zip 8 | .idea/ 9 | .mvn/timing.properties 10 | buildNumber.properties 11 | dependency-reduced-pom.xml 12 | encoder_345M.json 13 | pom.xml.next 14 | pom.xml.releaseBackup 15 | pom.xml.tag 16 | pom.xml.versionsBackup 17 | release.properties 18 | target/ 19 | tf-gpt-2.iml 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability.
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-gpt-2 2 | 3 | Java library for the GPT-2 Text Model using Tensorflow 4 | 5 | Source: 6 | 1. [Original Python Release](https://github.com/openai/gpt-2) 7 | 1. [OpenAI Blog - GPT2 Details](https://openai.com/blog/better-language-models/) 8 | 9 | More Background: 10 | 1. 
[Transformers and Attention Models](http://jalammar.github.io/illustrated-transformer/) 11 | 12 | ## Basic Usage 13 | 14 | ### Import the Library 15 | 16 | ```xml 17 | <dependency> 18 | <groupId>com.simiacryptus</groupId> 19 | <artifactId>tf-gpt-2</artifactId> 20 | <version>1.7.1</version> 21 | </dependency> 22 | ``` 23 | 24 | ### Instantiate the text generator 25 | 26 | ```java 27 | import com.simiacryptus.text.TextGenerator; 28 | import com.simiacryptus.text.gpt2.GPT2Util; 29 | TextGenerator textGenerator = GPT2Util.get345M(); 30 | ``` 31 | 32 | ### Generate text 33 | 34 | ```java 35 | System.out.println(textGenerator.generateText(500)); 36 | ``` 37 | 38 | ### Generate text given prefix 39 | 40 | ```java 41 | System.out.println(textGenerator.generateText(500, "Once upon a time")); 42 | ``` 43 | 44 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 20 | 22 | 4.0.0 23 | 24 | 25 | com.simiacryptus 26 | java-parent 27 | 2.1.0 28 | ../../mvn-parents/java-parent 29 | 30 | 31 | tf-gpt-2 32 | GPT-2 Text Prediction via Tensorflow Java API 33 | 34 | 35 | 36 | 37 | 38 | com.simiacryptus 39 | bom 40 | ${project.version} 41 | pom 42 | import 43 | 44 | 45 | 46 | 47 | 48 | 49 | com.simiacryptus 50 | java-util 51 | 52 | 53 | org.tensorflow 54 | tensorflow 55 | 56 | 57 | org.tensorflow 58 | libtensorflow_jni_gpu 59 | 60 | 61 | org.tensorflow 62 | proto 63 | 64 | 65 | com.simiacryptus 66 | tensorflow-model 67 | 68 | 69 | 70 | com.google.code.gson 71 | gson 72 | 73 | 74 | commons-io 75 | commons-io 76 | 77 | 78 | org.slf4j 79 | slf4j-api 80 | 81 | 82 | 83 | org.junit.jupiter 84 | junit-jupiter 85 | test 86 | 87 | 88 | ch.qos.logback 89 | logback-classic 90 | test 91 | 92 | 93 | org.slf4j 94 | jcl-over-slf4j 95 | test 96 | 97 | 98 | org.slf4j 99 | log4j-over-slf4j 100 | test 101 | 102 | 103 | 104 | http://code.simiacrypt.us/release/${project.version}/tf-gpt-2 105 | 106 | 107 | simiacryptus 108 | 
s3://code.simiacrypt.us/release/${project.version}/tf-gpt-2 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/GraphModifier.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import com.google.protobuf.ByteString; 23 | import com.google.protobuf.InvalidProtocolBufferException; 24 | import com.google.protobuf.ProtocolStringList; 25 | import com.simiacryptus.text.gpt2.GPT2Codec; 26 | import org.slf4j.Logger; 27 | import org.slf4j.LoggerFactory; 28 | import org.tensorflow.*; 29 | import org.tensorflow.framework.GraphDef; 30 | import org.tensorflow.framework.NodeDef; 31 | import org.tensorflow.framework.TensorProto; 32 | import org.tensorflow.framework.TensorShapeProto; 33 | 34 | import javax.annotation.Nonnull; 35 | import javax.annotation.Nullable; 36 | import java.nio.ByteBuffer; 37 | import java.nio.IntBuffer; 38 | import java.util.ArrayList; 39 | import java.util.Arrays; 40 | import java.util.HashSet; 41 | import java.util.List; 42 | import java.util.function.Consumer; 43 | import java.util.stream.Collectors; 44 | 45 | import static org.tensorflow.framework.DataType.DT_INT32; 46 | 47 | /** 48 | * 
The type Graph modifier. 49 | */ 50 | public abstract class GraphModifier { 51 | /** 52 | * The constant logger. 53 | */ 54 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class); 55 | 56 | /** 57 | * Gets deletes init. 58 | * 59 | * @return the deletes init 60 | */ 61 | @Nonnull 62 | public abstract HashSet getDeletes_Init(); 63 | 64 | /** 65 | * Import graph def. 66 | * 67 | * @param graph the graph 68 | * @param graphdef the graphdef 69 | */ 70 | public static void importGraphDef(@Nonnull Graph graph, @Nonnull GraphDef graphdef) { 71 | final HashSet opsPresent = new HashSet<>(); 72 | graph.operations().forEachRemaining(op -> { 73 | opsPresent.add(op.name()); 74 | }); 75 | while (true) { 76 | final List nextToAdd = graphdef.getNodeList().stream() 77 | .filter(nodeDef -> !opsPresent.contains(nodeDef.getName())) 78 | .filter(nodeDef -> { 79 | final ProtocolStringList inputList = nodeDef.getInputList(); 80 | return inputList.isEmpty() || inputList.stream().allMatch(input -> opsPresent.contains(input.split(":")[0])); 81 | }) 82 | .collect(Collectors.toList()); 83 | if (nextToAdd.isEmpty()) break; 84 | nextToAdd.forEach(nodeDef -> { 85 | opsPresent.add(nodeDef.getName()); 86 | if (graph.operation(nodeDef.getName()) == null) { 87 | try { 88 | logger.debug("Adding node to graph: " + nodeDef.getName() + " <= [" + nodeDef.getInputList().stream().reduce((a, b) -> a + "," + b).orElse("") + "]"); 89 | // Add new node 90 | final OperationBuilder operationBuilder = graph.opBuilder(nodeDef.getOp(), nodeDef.getName()); 91 | operationBuilder.setDevice(nodeDef.getDevice()); 92 | nodeDef.getAttrMap().forEach((k, v) -> { 93 | switch (v.getValueCase()) { 94 | case TENSOR: { 95 | final TensorProto tensorProto = v.getTensor(); 96 | final long[] shape = tensorProto.getTensorShape().getDimList().stream().mapToLong(x -> x.getSize()).toArray(); 97 | final Class type; 98 | switch (tensorProto.getDtype()) { 99 | case DT_FLOAT: 100 | type = Float.class; 101 | break; 
102 | case DT_INT32: 103 | type = Integer.class; 104 | break; 105 | default: 106 | throw new RuntimeException(tensorProto.getDtype().toString()); 107 | } 108 | if (null != tensorProto.getTensorContent() && !tensorProto.getTensorContent().isEmpty()) { 109 | operationBuilder.setAttr(k, Tensor.create(type, shape, tensorProto.getTensorContent().asReadOnlyByteBuffer())); 110 | } else if (0 < tensorProto.getIntValCount()) { 111 | operationBuilder.setAttr(k, Tensor.create(shape, IntBuffer.wrap(tensorProto.getIntValList().stream().mapToInt(x -> x).toArray()))); 112 | } else { 113 | throw new RuntimeException(tensorProto.toString()); 114 | } 115 | break; 116 | } 117 | case SHAPE: 118 | final TensorShapeProto shapeProto = v.getShape(); 119 | final long[] shape = shapeProto.getDimList().stream().mapToLong(x -> x.getSize()).toArray(); 120 | operationBuilder.setAttr(k, Shape.make(shape[0], Arrays.copyOfRange(shape, 1, shape.length))); 121 | break; 122 | case TYPE: 123 | operationBuilder.setAttr(k, DataType.valueOf(v.getType().name().split("_")[1])); 124 | break; 125 | case I: 126 | operationBuilder.setAttr(k, v.getI()); 127 | break; 128 | case B: 129 | operationBuilder.setAttr(k, v.getB()); 130 | break; 131 | default: 132 | throw new RuntimeException(k + " = " + v.toString()); 133 | } 134 | }); 135 | final Output[] inputs = nodeDef.getInputList().stream().map(input -> { 136 | final String[] split = input.split(":"); 137 | final int idx = 1 == split.length ? 
0 : Integer.parseInt(split[1]); 138 | return graph.operation(split[0]).output(idx); 139 | }).toArray(i -> new Output[i]); 140 | if (nodeDef.getOp().equals("Pack")) { 141 | operationBuilder.addInputList(inputs); 142 | } else if (nodeDef.getOp().equals("ConcatV2")) { 143 | operationBuilder.addInputList(new Output[]{inputs[0], inputs[1]}); 144 | operationBuilder.addInput(inputs[2]); 145 | operationBuilder.addControlInput(inputs[2].op()); 146 | } else if (nodeDef.getOp().equals("StridedSlice")) { 147 | for (int i = 0; i < inputs.length; i++) { 148 | if (i == 0) { 149 | operationBuilder.addInput(inputs[i]); 150 | } else { 151 | operationBuilder.addInput(inputs[i]); 152 | operationBuilder.addControlInput(inputs[i].op()); 153 | } 154 | } 155 | } else if (inputs.length > 1) { 156 | for (int i = 0; i < inputs.length; i++) { 157 | operationBuilder.addInput(inputs[i]); 158 | } 159 | } else if (inputs.length > 0) { 160 | operationBuilder.addInput(inputs[0]); 161 | } 162 | try { 163 | operationBuilder.build(); 164 | } catch (Throwable e) { 165 | throw new RuntimeException("Error processing " + nodeDef.toString(), e); 166 | } 167 | } catch (RuntimeException e) { 168 | throw e; 169 | } catch (Throwable e) { 170 | throw new RuntimeException("Error processing " + nodeDef.toString(), e); 171 | } 172 | } 173 | }); 174 | } 175 | graphdef.getNodeList().stream() 176 | .filter(nodeDef -> !opsPresent.contains(nodeDef.getName())) 177 | .forEach(nodeDef -> { 178 | logger.warn("Remaining Node: " + nodeDef.toString()); 179 | }); 180 | } 181 | 182 | /** 183 | * Edit byte buffer. 
184 | * 185 | * @param srcBuffer the src buffer 186 | * @param fn the fn 187 | * @return the byte buffer 188 | */ 189 | @Nonnull 190 | public static ByteBuffer edit(@Nonnull ByteBuffer srcBuffer, @Nonnull Consumer fn) { 191 | final ByteBuffer dstBuffer = copy(srcBuffer); 192 | final IntBuffer intBuffer = dstBuffer.asIntBuffer(); 193 | fn.accept(intBuffer); 194 | return dstBuffer; 195 | } 196 | 197 | /** 198 | * Copy byte buffer. 199 | * 200 | * @param srcBuffer the src buffer 201 | * @return the byte buffer 202 | */ 203 | @Nonnull 204 | public static ByteBuffer copy(@Nonnull ByteBuffer srcBuffer) { 205 | final ByteBuffer byteBuffer = ByteBuffer.allocate(srcBuffer.capacity()); 206 | byteBuffer.put(srcBuffer); 207 | byteBuffer.clear(); 208 | return byteBuffer; 209 | } 210 | 211 | /** 212 | * Tensor 1 tensor proto. 213 | * 214 | * @param shape the shape 215 | * @param vals the vals 216 | * @return the tensor proto 217 | */ 218 | @Nonnull 219 | public static TensorProto tensor1(int[] shape, @Nonnull int... vals) { 220 | TensorProto.Builder builder = TensorProto.newBuilder().setTensorShape(shape(shape)).setDtype(DT_INT32); 221 | Arrays.stream(vals).forEach(x -> builder.addIntVal(x)); 222 | return builder.build(); 223 | } 224 | 225 | /** 226 | * Tensor 2 tensor proto. 227 | * 228 | * @param shape the shape 229 | * @param vals the vals 230 | * @return the tensor proto 231 | */ 232 | @Nonnull 233 | public static TensorProto tensor2(int[] shape, @Nonnull int... vals) { 234 | TensorProto.Builder builder = TensorProto.newBuilder().setTensorShape(shape(shape)); 235 | byte[] array = new byte[vals.length * 4]; 236 | IntBuffer buffer = ByteBuffer.wrap(array).asIntBuffer(); 237 | for (int val : vals) buffer.put(Integer.reverseBytes(val)); 238 | builder.setTensorContent(ByteString.copyFrom(array)).setDtype(DT_INT32); 239 | return builder.build(); 240 | } 241 | 242 | /** 243 | * Shape tensor shape proto. 
244 | * 245 | * @param dims the dims 246 | * @return the tensor shape proto 247 | */ 248 | @Nonnull 249 | public static TensorShapeProto shape(@Nonnull int... dims) { 250 | TensorShapeProto.Builder builder = TensorShapeProto.newBuilder(); 251 | Arrays.stream(dims).mapToObj(v -> TensorShapeProto.Dim.newBuilder().setSize(v).build()).forEach(value -> builder.addDim(value)); 252 | return builder.build(); 253 | } 254 | 255 | /** 256 | * Edit graph def. 257 | * 258 | * @param src the src 259 | * @param prefix the prefix 260 | * @param includeOriginal the include original 261 | * @return the graph def 262 | * @throws InvalidProtocolBufferException the invalid protocol buffer exception 263 | */ 264 | @Nonnull 265 | public GraphDef edit(@Nonnull GraphDef src, String prefix, boolean includeOriginal) throws InvalidProtocolBufferException { 266 | final GraphDef srcGraphDef = GraphDef.parseFrom(src.toByteArray()); 267 | final GraphDef.Builder destGraphDef = GraphDef.newBuilder(); 268 | final HashSet deletes = getDeletes_Init(); 269 | final HashSet editedNodes = new HashSet<>(); 270 | for (int index = 0; index < srcGraphDef.getNodeCount(); index++) { 271 | final NodeDef node = srcGraphDef.getNode(index); 272 | if (deletes.contains(node.getName())) { 273 | logger.debug("Omit Node: " + node.getName()); 274 | } else { 275 | @Nullable NodeDef.Builder nodeBuilder = edit(node.toBuilder()); 276 | if (null != nodeBuilder) { 277 | logger.debug("Edit Node: " + node.getName()); 278 | destGraphDef.addNode(nodeBuilder.build()); 279 | editedNodes.add(node.getName()); 280 | } else { 281 | // logger.debug("Pass-thru Node: " + node.getName()); 282 | destGraphDef.addNode(node); 283 | } 284 | } 285 | } 286 | addNodes(nodeDef -> { 287 | destGraphDef.addNode(nodeDef); 288 | editedNodes.add(nodeDef.getName()); 289 | }); 290 | // return destGraphDef.build(); 291 | return prefixRewrite(destGraphDef.build(), editedNodes, prefix, includeOriginal); 292 | } 293 | 294 | /** 295 | * Edit node def . builder. 
296 | * 297 | * @param node the node 298 | * @return the node def . builder 299 | */ 300 | @Nullable 301 | public abstract NodeDef.Builder edit(NodeDef.Builder node); 302 | 303 | /** 304 | * Add nodes. 305 | * 306 | * @param add the add 307 | */ 308 | protected abstract void addNodes(Consumer add); 309 | 310 | /** 311 | * Prefix rewrite graph def. 312 | * 313 | * @param graphDef the graph def 314 | * @param editedNodes the edited nodes 315 | * @param prefix the prefix 316 | * @param includeOriginal the include original 317 | * @return the graph def 318 | */ 319 | @Nonnull 320 | protected GraphDef prefixRewrite(@Nonnull GraphDef graphDef, @Nonnull HashSet editedNodes, String prefix, boolean includeOriginal) { 321 | while (true) { 322 | final List newItems = graphDef.getNodeList().stream() 323 | .filter(nodeDef -> !editedNodes.contains(nodeDef.getName())) 324 | .filter(nodeDef -> nodeDef.getInputList().stream().filter(input -> editedNodes.contains(input.split(":")[0])).findAny().isPresent()) 325 | .map(x -> x.getName()).collect(Collectors.toList()); 326 | if (newItems.isEmpty()) break; 327 | for (String newItem : newItems) { 328 | logger.debug("Item touched by rename: " + newItem); 329 | } 330 | editedNodes.addAll(newItems); 331 | } 332 | final GraphDef.Builder destGraphDef = GraphDef.newBuilder(); 333 | for (NodeDef nodeDef : graphDef.getNodeList()) { 334 | NodeDef.Builder builder; 335 | if (editedNodes.contains(nodeDef.getName())) { 336 | builder = nodeDef.toBuilder(); 337 | builder.setName(prefix + nodeDef.getName()); 338 | } else { 339 | builder = null; 340 | } 341 | final ArrayList inputs = new ArrayList<>(nodeDef.getInputList()); 342 | if (inputs.stream().filter(o -> editedNodes.contains(o.split(":")[0])).findAny().isPresent()) { 343 | if (null == builder) builder = nodeDef.toBuilder(); 344 | builder.clearInput(); 345 | for (String input : inputs) { 346 | if (editedNodes.contains(input.split(":")[0])) { 347 | logger.debug(nodeDef.getName() + " [ " + input + " ] 
+= " + prefix); 348 | builder.addInput(prefix + input); 349 | } else { 350 | builder.addInput(input); 351 | } 352 | } 353 | } 354 | if (null != builder) { 355 | logger.debug("Edit in renaming: " + builder.getName()); 356 | destGraphDef.addNode(builder.build()); 357 | } else { 358 | if (includeOriginal) destGraphDef.addNode(nodeDef); 359 | } 360 | } 361 | return destGraphDef.build(); 362 | } 363 | } 364 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/LanguageCodeModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import org.tensorflow.Tensor; 23 | 24 | import javax.annotation.Nonnull; 25 | import javax.annotation.Nullable; 26 | import java.util.function.BiFunction; 27 | 28 | /** 29 | * The interface Language code model. 30 | */ 31 | public interface LanguageCodeModel { 32 | /** 33 | * Gets filter fn. 34 | * 35 | * @return the filter fn 36 | */ 37 | @Nullable 38 | BiFunction getFilterFn(); 39 | 40 | /** 41 | * Sets filter fn. 
42 | * 43 | * @param filterFn the filter fn 44 | * @return the filter fn 45 | */ 46 | @Nonnull 47 | LanguageCodeModel setFilterFn(BiFunction filterFn); 48 | 49 | /** 50 | * Creates a copy of this language code model. 51 | * 52 | * @return the copied model; never null 53 | */ 54 | @Nonnull 55 | LanguageCodeModel copy(); 56 | 57 | /** 58 | * Clears this language code model. NOTE(review): presumably this resets accumulated evaluation state - confirm against the implementations. 59 | * 60 | * @return a language code model, never null (presumably this instance, to allow chaining - confirm) 61 | */ 62 | @Nonnull 63 | LanguageCodeModel clear(); 64 | 65 | /** 66 | * Evaluates the model for one input code. 67 | * 68 | * @param data_X the input code to evaluate (presumably a token id - TODO confirm) 69 | * @return the resulting float array (presumably one weight per candidate next code - TODO confirm) 70 | */ 71 | float[] eval(int data_X); 72 | 73 | /** 74 | * Returns the model state as a tensor. 75 | * 76 | * @return the state tensor; may be null 77 | */ 78 | @Nullable 79 | Tensor state(); 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/MinEntropyWrapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 
18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import com.simiacryptus.ref.wrappers.RefString; 23 | import org.slf4j.Logger; 24 | import org.slf4j.LoggerFactory; 25 | 26 | import javax.annotation.Nonnull; 27 | import java.util.ArrayList; 28 | import java.util.Arrays; 29 | import java.util.stream.DoubleStream; 30 | import java.util.stream.IntStream; 31 | 32 | /** 33 | * The type Min entropy wrapper. 34 | */ 35 | public class MinEntropyWrapper extends ModelWrapper { 36 | /** 37 | * The constant logger. 38 | */ 39 | protected static final Logger logger = LoggerFactory.getLogger(MinEntropyWrapper.class); 40 | private final ArrayList entropyHistory = new ArrayList<>(); 41 | private double value; 42 | 43 | /** 44 | * Instantiates a new Min entropy wrapper. 45 | * 46 | * @param value the value 47 | * @param child the child 48 | */ 49 | public MinEntropyWrapper(double value, LanguageCodeModel child) { 50 | super(child); 51 | this.value = value; 52 | } 53 | 54 | /** 55 | * Gets value. 56 | * 57 | * @return the value 58 | */ 59 | public double getValue() { 60 | return value; 61 | } 62 | 63 | /** 64 | * Sets value. 65 | * 66 | * @param value the value 67 | */ 68 | public void setValue(double value) { 69 | this.value = value; 70 | } 71 | 72 | /** 73 | * Entropy double. 74 | * 75 | * @param floats the floats 76 | * @return the double 77 | */ 78 | public static double entropy(@Nonnull float[] floats) { 79 | return IntStream.range(0, floats.length).mapToDouble(i -> { 80 | float p = floats[i]; 81 | return p <= 0 ? 0 : -p * Math.log(p); 82 | }).sum() / Math.log(2); 83 | } 84 | 85 | /** 86 | * Pow copy float [ ]. 87 | * 88 | * @param floats the floats 89 | * @param value the value 90 | * @return the float [ ] 91 | */ 92 | @Nonnull 93 | public static float[] powCopy(@Nonnull float[] floats, double value) { 94 | float[] copy = Arrays.copyOf(floats, floats.length); 95 | pow(copy, value); 96 | return copy; 97 | } 98 | 99 | /** 100 | * Pow. 
101 | * 102 | * @param floats the floats 103 | * @param value the value 104 | */ 105 | public static void pow(@Nonnull float[] floats, double value) { 106 | for (int i = 0; i < floats.length; i++) { 107 | floats[i] = (float) Math.pow(floats[i], value); 108 | } 109 | } 110 | 111 | @Override 112 | public float[] eval(int data_X) { 113 | LanguageCodeModel child = children[0]; 114 | float[] floats = child.eval(data_X); 115 | double entropy = entropy(floats); 116 | entropyHistory.add(entropy); 117 | double[] schedule = DoubleStream.iterate(1.0, x -> x * 0.9).limit(1000).toArray(); 118 | for (int i = 0; i < schedule.length; i++) { 119 | float[] copy = powCopy(floats, schedule[i]); 120 | SumModel.normalize(copy); 121 | if (entropy(copy) > value) { 122 | floats = copy; 123 | break; 124 | } 125 | } 126 | logger.debug(RefString.format("Entropy = %s => %s", entropy, entropy(floats))); 127 | //SumModel.normalize(floats); 128 | return floats; 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/ModelWrapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 
18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import org.tensorflow.Tensor; 23 | 24 | import javax.annotation.Nonnull; 25 | import java.util.Arrays; 26 | import java.util.function.BiFunction; 27 | 28 | /** 29 | * The type Model wrapper. 30 | */ 31 | public abstract class ModelWrapper implements LanguageCodeModel { 32 | /** 33 | * The Children. 34 | */ 35 | protected final LanguageCodeModel[] children; 36 | 37 | /** 38 | * Instantiates a new Model wrapper. 39 | * 40 | * @param children the children 41 | */ 42 | public ModelWrapper(LanguageCodeModel... children) { 43 | this.children = children; 44 | } 45 | 46 | @Override 47 | public BiFunction getFilterFn() { 48 | return children[0].getFilterFn(); 49 | } 50 | 51 | @Nonnull 52 | @Override 53 | public LanguageCodeModel copy() { 54 | return new SumModel(Arrays.stream(children) 55 | .map(languageCodeModel -> languageCodeModel.copy()) 56 | .toArray(i -> new LanguageCodeModel[i])); 57 | } 58 | 59 | @Nonnull 60 | @Override 61 | public LanguageCodeModel clear() { 62 | for (LanguageCodeModel child : children) { 63 | child.clear(); 64 | } 65 | return this; 66 | } 67 | 68 | @Override 69 | public abstract float[] eval(int data_X); 70 | 71 | @Nonnull 72 | @Override 73 | public LanguageCodeModel setFilterFn(BiFunction filterFn) { 74 | for (LanguageCodeModel child : children) { 75 | child.setFilterFn(filterFn); 76 | } 77 | return this; 78 | } 79 | 80 | @Override 81 | public Tensor state() { 82 | assert 1 == children.length; 83 | return children[0].state(); 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/SimpleModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 
3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import com.simiacryptus.text.gpt2.GPT2Codec; 23 | import org.tensorflow.Tensor; 24 | 25 | import javax.annotation.Nonnull; 26 | import javax.annotation.Nullable; 27 | import java.util.Arrays; 28 | import java.util.List; 29 | import java.util.Map; 30 | import java.util.function.BiFunction; 31 | import java.util.stream.Collectors; 32 | 33 | /** 34 | * The type Simple model. 35 | */ 36 | public class SimpleModel implements LanguageCodeModel { 37 | 38 | @Nonnull 39 | private final float[] result; 40 | 41 | /** 42 | * Instantiates a new Simple model. 43 | * 44 | * @param result the result 45 | */ 46 | public SimpleModel(@Nonnull float... result) { 47 | this.result = Arrays.copyOf(result, result.length); 48 | } 49 | 50 | @Nullable 51 | @Override 52 | public BiFunction getFilterFn() { 53 | return null; 54 | } 55 | 56 | /** 57 | * Build simple model. 
58 | * 59 | * @param codec the codec 60 | * @param text the text 61 | * @return the simple model 62 | */ 63 | @Nonnull 64 | public static SimpleModel build(@Nonnull GPT2Codec codec, String text) { 65 | List encode = codec.encode(text); 66 | Map counts = encode.stream().collect(Collectors.groupingBy(x -> x, Collectors.counting())); 67 | float[] result = new float[codec.getVocabSize()]; 68 | for (int i = 0; i < result.length; i++) { 69 | result[i] = (float) counts.getOrDefault(i, 0l) / encode.size(); 70 | } 71 | return new SimpleModel(result); 72 | } 73 | 74 | @Nonnull 75 | @Override 76 | public LanguageCodeModel copy() { 77 | return new SimpleModel(result); 78 | } 79 | 80 | @Nonnull 81 | @Override 82 | public LanguageCodeModel clear() { 83 | return this; 84 | } 85 | 86 | @Nonnull 87 | @Override 88 | public float[] eval(int data_X) { 89 | return Arrays.copyOf(result, result.length); 90 | } 91 | 92 | @Nonnull 93 | @Override 94 | public LanguageCodeModel setFilterFn(BiFunction filterFn) { 95 | return this; 96 | } 97 | 98 | @Nullable 99 | @Override 100 | public Tensor state() { 101 | return null; 102 | } 103 | 104 | /** 105 | * Sets temperature. 106 | * 107 | * @return the temperature 108 | */ 109 | @Nonnull 110 | public LanguageCodeModel setTemperature() { 111 | return this; 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/SumModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. 
You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import com.simiacryptus.ref.lang.RefUtil; 23 | 24 | import javax.annotation.Nonnull; 25 | import java.util.Arrays; 26 | import java.util.stream.IntStream; 27 | 28 | /** 29 | * The type Sum model. 30 | */ 31 | public class SumModel extends ModelWrapper { 32 | 33 | /** 34 | * Instantiates a new Sum model. 35 | * 36 | * @param children the children 37 | */ 38 | public SumModel(LanguageCodeModel... children) { 39 | super(children); 40 | } 41 | 42 | /** 43 | * Normalize. 44 | * 45 | * @param sums the sums 46 | */ 47 | public static void normalize(@Nonnull float[] sums) { 48 | double sum = IntStream.range(0, sums.length).mapToDouble(x -> sums[x]).sum(); 49 | for (int i = 0; i < sums.length; i++) sums[i] /= sum; 50 | } 51 | 52 | @Nonnull 53 | @Override 54 | public float[] eval(int data_X) { 55 | float[] sums = RefUtil.get(Arrays.stream(children).map(c -> c.eval(data_X)).reduce((a, b) -> { 56 | for (int i = 0; i < a.length; i++) a[i] *= b[i]; 57 | return a; 58 | })); 59 | normalize(sums); 60 | return sums; 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/TemperatureWrapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 
3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | /** 23 | * The type Temperature wrapper. 24 | */ 25 | public class TemperatureWrapper extends ModelWrapper { 26 | 27 | private double value; 28 | 29 | /** 30 | * Instantiates a new Temperature wrapper. 31 | * 32 | * @param value the value 33 | * @param child the child 34 | */ 35 | public TemperatureWrapper(double value, LanguageCodeModel child) { 36 | super(child); 37 | this.value = value; 38 | } 39 | 40 | /** 41 | * Gets value. 42 | * 43 | * @return the value 44 | */ 45 | public double getValue() { 46 | return value; 47 | } 48 | 49 | /** 50 | * Sets value. 51 | * 52 | * @param value the value 53 | */ 54 | public void setValue(double value) { 55 | this.value = value; 56 | } 57 | 58 | @Override 59 | public float[] eval(int data_X) { 60 | LanguageCodeModel child = children[0]; 61 | float[] floats = child.eval(data_X); 62 | for (int i = 0; i < floats.length; i++) { 63 | floats[i] = (float) Math.pow(floats[i], getValue()); 64 | } 65 | SumModel.normalize(floats); 66 | return floats; 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/TextGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 
3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import com.simiacryptus.ref.wrappers.RefString; 23 | import com.simiacryptus.text.gpt2.GPT2Codec; 24 | import org.slf4j.Logger; 25 | import org.slf4j.LoggerFactory; 26 | 27 | import javax.annotation.Nonnull; 28 | import javax.annotation.Nullable; 29 | import java.util.ArrayList; 30 | import java.util.Arrays; 31 | import java.util.Comparator; 32 | import java.util.List; 33 | import java.util.function.Predicate; 34 | import java.util.stream.IntStream; 35 | 36 | /** 37 | * The type Text generator. 38 | */ 39 | public class TextGenerator { 40 | /** 41 | * The constant logger. 42 | */ 43 | protected static final Logger logger = LoggerFactory.getLogger(TextGenerator.class); 44 | /** 45 | * The Vocabulary size. 46 | */ 47 | protected final int vocabularySize; 48 | /** 49 | * The Codec. 50 | */ 51 | protected final GPT2Codec codec; 52 | /** 53 | * The Verbose. 54 | */ 55 | protected boolean verbose = false; 56 | /** 57 | * The Choices to log. 58 | */ 59 | protected int choicesToLog = 10; 60 | /** 61 | * The Codes. 62 | */ 63 | @Nonnull 64 | List codes = new ArrayList<>(); 65 | /** 66 | * The Next selections. 67 | */ 68 | @Nullable 69 | float[] nextSelections; 70 | private LanguageCodeModel model; 71 | 72 | /** 73 | * Instantiates a new Text generator. 
74 | * 75 | * @param vocabularySize the vocabulary size 76 | * @param model the model 77 | * @param codec the codec 78 | */ 79 | public TextGenerator(int vocabularySize, LanguageCodeModel model, GPT2Codec codec) { 80 | this.setModel(model); 81 | this.vocabularySize = vocabularySize; 82 | this.codec = codec; 83 | } 84 | 85 | /** 86 | * Gets choices to log. 87 | * 88 | * @return the choices to log 89 | */ 90 | public int getChoicesToLog() { 91 | return choicesToLog; 92 | } 93 | 94 | /** 95 | * Sets choices to log. 96 | * 97 | * @param choicesToLog the choices to log 98 | * @return the choices to log 99 | */ 100 | @Nonnull 101 | public TextGenerator setChoicesToLog(int choicesToLog) { 102 | this.choicesToLog = choicesToLog; 103 | return this; 104 | } 105 | 106 | /** 107 | * Gets model. 108 | * 109 | * @return the model 110 | */ 111 | public LanguageCodeModel getModel() { 112 | return model; 113 | } 114 | 115 | /** 116 | * Sets model. 117 | * 118 | * @param model the model 119 | * @return the model 120 | */ 121 | @Nonnull 122 | public TextGenerator setModel(LanguageCodeModel model) { 123 | if (this.model == model) return this; 124 | if (null != this.model) this.model.clear(); 125 | this.model = model; 126 | return this; 127 | } 128 | 129 | /** 130 | * Gets text. 131 | * 132 | * @return the text 133 | */ 134 | public String getText() { 135 | return codec.decode(codes.toArray(new Integer[]{})); 136 | } 137 | 138 | /** 139 | * Gets vocabulary size. 140 | * 141 | * @return the vocabulary size 142 | */ 143 | public int getVocabularySize() { 144 | return vocabularySize; 145 | } 146 | 147 | /** 148 | * Is verbose boolean. 149 | * 150 | * @return the boolean 151 | */ 152 | public boolean isVerbose() { 153 | return verbose; 154 | } 155 | 156 | /** 157 | * Sets verbose. 
158 | * 159 | * @param verbose the verbose 160 | * @return the verbose 161 | */ 162 | @Nonnull 163 | public TextGenerator setVerbose(boolean verbose) { 164 | this.verbose = verbose; 165 | return this; 166 | } 167 | 168 | /** 169 | * Sorted indices int [ ]. 170 | * 171 | * @param chosen the chosen 172 | * @param limit the limit 173 | * @return the int [ ] 174 | */ 175 | public static int[] sortedIndices(@Nonnull float[] chosen, int limit) { 176 | return IntStream.range(0, chosen.length) 177 | .mapToObj(x -> x) 178 | .sorted(Comparator.comparing(c -> -chosen[c])) 179 | .limit(limit) 180 | .mapToInt(x -> x) 181 | .toArray(); 182 | } 183 | 184 | /** 185 | * Copy text generator. 186 | * 187 | * @return the text generator 188 | */ 189 | @Nonnull 190 | public TextGenerator copy() { 191 | TextGenerator copy = new TextGenerator(vocabularySize, getModel().copy(), codec); 192 | copy.codes.addAll(this.codes); 193 | copy.verbose = this.verbose; 194 | copy.choicesToLog = this.choicesToLog; 195 | copy.nextSelections = null == this.nextSelections ? null : Arrays.copyOf(this.nextSelections, this.nextSelections.length); 196 | return copy; 197 | } 198 | 199 | /** 200 | * Generate text string. 201 | * 202 | * @param terminator the terminator 203 | * @return the string 204 | */ 205 | @Nonnull 206 | public String generateText(@Nonnull Predicate terminator) { 207 | return generateText(terminator, null); 208 | } 209 | 210 | /** 211 | * Generate text string. 212 | * 213 | * @param numberOfWords the number of words 214 | * @return the string 215 | */ 216 | @Nonnull 217 | public String generateText(int numberOfWords) { 218 | return generateText(numberOfWords, null); 219 | } 220 | 221 | /** 222 | * Generate text string. 
223 | * 224 | * @param terminator the terminator 225 | * @param prefix the prefix 226 | * @return the string 227 | */ 228 | @Nonnull 229 | public String generateText(@Nonnull Predicate terminator, String prefix) { 230 | reset(); 231 | feed(prefix); 232 | generate(terminator); 233 | return getText(); 234 | } 235 | 236 | /** 237 | * Generate text string. 238 | * 239 | * @param numberOfTokens the number of tokens 240 | * @param prefix the prefix 241 | * @return the string 242 | */ 243 | @Nonnull 244 | public String generateText(int numberOfTokens, String prefix) { 245 | reset(); 246 | feed(prefix); 247 | generate(numberOfTokens); 248 | return getText(); 249 | } 250 | 251 | /** 252 | * Generate string. 253 | * 254 | * @param fn the fn 255 | * @return the string 256 | */ 257 | public String generate(@Nonnull Predicate fn) { 258 | init(); 259 | ArrayList theseCodes = new ArrayList<>(); 260 | try { 261 | for (int wordIndex = 0; wordIndex == 0 || fn.test(codec.decode(theseCodes.toArray(new Integer[]{}))); wordIndex++) { 262 | assert nextSelections != null; 263 | int selected = select(nextSelections); 264 | if (isVerbose()) { 265 | if (wordIndex != 0) log(nextSelections, codec, getChoicesToLog()); 266 | logger.info(RefString.format("Selected New Text: '%s'", codec.decode(selected))); 267 | } 268 | if (selected == getVocabularySize() - 1) break; 269 | codes.add(selected); 270 | theseCodes.add(selected); 271 | nextSelections = getModel().eval(selected); 272 | } 273 | } catch (Throwable e) { 274 | //logger.warn("Error generating text", e); 275 | throw new RuntimeException("Error generating text: " + codec.decode(theseCodes.toArray(new Integer[]{})), e); 276 | } 277 | return codec.decode(theseCodes.toArray(new Integer[]{})); 278 | } 279 | 280 | /** 281 | * Generate. 
282 | * 283 | * @param numberOfWords the number of words 284 | */ 285 | public void generate(int numberOfWords) { 286 | init(); 287 | try { 288 | for (int wordIndex = 0; wordIndex < numberOfWords; wordIndex++) { 289 | assert nextSelections != null; 290 | int selected = select(nextSelections); 291 | if (isVerbose()) { 292 | if (wordIndex != 0) log(nextSelections, codec, getChoicesToLog()); 293 | logger.info(RefString.format("Selected New Text: '%s'", codec.decode(selected))); 294 | } 295 | if (selected == getVocabularySize() - 1) break; 296 | codes.add(selected); 297 | nextSelections = getModel().eval(selected); 298 | } 299 | } catch (Throwable e) { 300 | logger.warn("Error generating text", e); 301 | } 302 | } 303 | 304 | /** 305 | * Init text generator. 306 | * 307 | * @return the text generator 308 | */ 309 | @Nonnull 310 | public TextGenerator init() { 311 | if (nextSelections == null) feed(""); 312 | return this; 313 | } 314 | 315 | /** 316 | * Feed double. 317 | * 318 | * @param text the text 319 | * @return the double 320 | */ 321 | public double feed(String text) { 322 | double entropy = 0.0; 323 | List codeList = new ArrayList<>(); 324 | codeList.addAll(codec.encode(text)); 325 | if (codeList.isEmpty()) codeList.add(getVocabularySize() - 1); 326 | for (Integer code : codeList) { 327 | if (null != nextSelections) { 328 | float p = nextSelections[code]; 329 | entropy += p != 0 ? -Math.log(p) : Math.log(getVocabularySize()); 330 | } 331 | codes.add(code); 332 | nextSelections = getModel().eval(code); 333 | if (isVerbose()) { 334 | logger.info(RefString.format("Feed token: '%s'", codec.decode(code))); 335 | assert nextSelections != null; 336 | log(nextSelections, codec, getChoicesToLog()); 337 | } 338 | } 339 | return entropy / Math.log(2); 340 | } 341 | 342 | /** 343 | * Reset text generator. 
344 | * 345 | * @return the text generator 346 | */ 347 | @Nonnull 348 | public TextGenerator reset() { 349 | codes.clear(); 350 | getModel().clear(); 351 | return this; 352 | } 353 | 354 | /** 355 | * Select int. 356 | * 357 | * @param chosen the chosen 358 | * @return the int 359 | */ 360 | protected int select(@Nonnull float[] chosen) { 361 | double originalFate = Math.random() * 1; 362 | double fate = originalFate; 363 | int j = 0; 364 | int[] topCandidates = sortedIndices(chosen, chosen.length); 365 | while (j < topCandidates.length && fate > chosen[topCandidates[j]]) { 366 | int topCandidate = topCandidates[j++]; 367 | fate -= chosen[topCandidate]; 368 | } 369 | int topCandidate = topCandidates[j]; 370 | logger.debug(RefString.format("Chose #%s with fate %s", topCandidate, originalFate)); 371 | return topCandidate; 372 | } 373 | 374 | /** 375 | * Log. 376 | * 377 | * @param chosen the chosen 378 | * @param codec the codec 379 | * @param count the count 380 | */ 381 | protected void log(@Nonnull float[] chosen, @Nonnull GPT2Codec codec, int count) { 382 | Arrays.stream(sortedIndices(chosen, count)) 383 | .forEach(candidate -> logger.info(RefString.format("\t#%d %.4f%% '%s'", candidate, chosen[candidate] * 100, codec.decode(candidate)))); 384 | } 385 | } 386 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/TopNWrapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. 
You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text; 21 | 22 | import java.util.Comparator; 23 | import java.util.stream.IntStream; 24 | 25 | /** 26 | * The type Top n wrapper. 27 | */ 28 | public class TopNWrapper extends ModelWrapper { 29 | 30 | private int value; 31 | 32 | /** 33 | * Instantiates a new Top n wrapper. 34 | * 35 | * @param value the value 36 | * @param child the child 37 | */ 38 | public TopNWrapper(int value, LanguageCodeModel child) { 39 | super(child); 40 | this.value = value; 41 | } 42 | 43 | /** 44 | * Gets value. 45 | * 46 | * @return the value 47 | */ 48 | public int getValue() { 49 | return value; 50 | } 51 | 52 | /** 53 | * Sets value. 
54 | * 55 | * @param value the value 56 | */ 57 | public void setValue(int value) { 58 | this.value = value; 59 | } 60 | 61 | @Override 62 | public float[] eval(int data_X) { 63 | LanguageCodeModel child = children[0]; 64 | float[] floats = child.eval(data_X); 65 | int[] sortedIndices = IntStream.range(0, floats.length) 66 | .mapToObj(x -> x) 67 | .sorted(Comparator.comparing(x -> -floats[x])) 68 | .mapToInt(x -> x) 69 | .toArray(); 70 | for (int i = value; i < sortedIndices.length; i++) { 71 | floats[sortedIndices[i]] = 0; 72 | } 73 | SumModel.normalize(floats); 74 | return floats; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/gpt2/GPT2Codec.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 
18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.google.gson.GsonBuilder; 23 | import com.google.gson.JsonObject; 24 | import com.simiacryptus.ref.lang.RefUtil; 25 | import com.simiacryptus.ref.wrappers.RefStringBuilder; 26 | import com.simiacryptus.util.Util; 27 | import org.apache.commons.io.FileUtils; 28 | import org.slf4j.Logger; 29 | import org.slf4j.LoggerFactory; 30 | 31 | import javax.annotation.Nonnull; 32 | import javax.annotation.Nullable; 33 | import java.io.File; 34 | import java.io.IOException; 35 | import java.util.*; 36 | import java.util.function.Function; 37 | import java.util.stream.Collectors; 38 | import java.util.stream.Stream; 39 | 40 | /** 41 | * The type Gpt 2 codec. 42 | */ 43 | public class GPT2Codec { 44 | /** 45 | * The constant logger. 46 | */ 47 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class); 48 | 49 | /** 50 | * The Encoder. 51 | */ 52 | protected final TreeMap encoder; 53 | /** 54 | * The Decoder. 55 | */ 56 | @Nonnull 57 | protected final TreeMap decoder; 58 | private final int vocabSize; 59 | 60 | /** 61 | * Instantiates a new Gpt 2 codec. 62 | * 63 | * @param encoder the encoder 64 | * @param vocabSize the vocab size 65 | */ 66 | public GPT2Codec(TreeMap encoder, int vocabSize) { 67 | this.encoder = encoder; 68 | this.vocabSize = vocabSize; 69 | this.decoder = buildDecoder(this.encoder); 70 | } 71 | 72 | /** 73 | * Instantiates a new Gpt 2 codec. 74 | * 75 | * @param file the file 76 | * @param vocabSize the vocab size 77 | */ 78 | public GPT2Codec(@Nonnull File file, int vocabSize) { 79 | this(GPT2Codec.loadEncoder(file), vocabSize); 80 | } 81 | 82 | /** 83 | * Gets character transformer. 
84 | * 85 | * @return the character transformer 86 | */ 87 | @Nonnull 88 | public static Function getCharacterTransformer() { 89 | Map byteEncoder = byteEncoder(); 90 | return x -> { 91 | char[] chars = x.toCharArray(); 92 | for (int i = 0; i < chars.length; i++) { 93 | chars[i] = byteEncoder.getOrDefault(chars[i], chars[i]); 94 | } 95 | return new String(chars); 96 | }; 97 | } 98 | 99 | /** 100 | * Gets vocab size. 101 | * 102 | * @return the vocab size 103 | */ 104 | public int getVocabSize() { 105 | return vocabSize; 106 | } 107 | 108 | /** 109 | * Build decoder tree map. 110 | * 111 | * @param encoder the encoder 112 | * @return the tree map 113 | */ 114 | @Nonnull 115 | public static TreeMap buildDecoder(@Nonnull TreeMap encoder) { 116 | Stream> stream = encoder.entrySet().stream(); 117 | return new TreeMap<>(stream.collect(Collectors.toMap( 118 | (Map.Entry e) -> e.getValue(), 119 | (Map.Entry e) -> e.getKey() 120 | ))); 121 | } 122 | 123 | /** 124 | * Load encoder tree map. 125 | * 126 | * @param file the file 127 | * @return the tree map 128 | */ 129 | @Nonnull 130 | public static TreeMap loadEncoder(@Nonnull File file) { 131 | try { 132 | return toMap(FileUtils.readFileToString(file, "UTF-8"), getCharacterTransformer()); 133 | } catch (IOException e) { 134 | throw Util.throwException(e); 135 | } 136 | } 137 | 138 | /** 139 | * To map tree map. 140 | * 141 | * @param jsonTxt the json txt 142 | * @param keyEncoder the key encoder 143 | * @return the tree map 144 | */ 145 | @Nonnull 146 | public static TreeMap toMap(String jsonTxt, @Nonnull Function keyEncoder) { 147 | JsonObject json = new GsonBuilder().create().fromJson(jsonTxt, JsonObject.class); 148 | return new TreeMap<>(json.keySet().stream().collect(Collectors.toMap(keyEncoder, x -> json.get(x).getAsInt(), (a, b) -> a))); 149 | } 150 | 151 | /** 152 | * Byte encoder map. 
153 | * 154 | * @return the map 155 | */ 156 | @Nonnull 157 | public static Map byteEncoder() { 158 | try { 159 | HashMap characterMap = new HashMap<>(); 160 | for (int c = 0; c < 256; c++) { 161 | characterMap.put((char) (c + 256), (char) c); 162 | } 163 | for (char i = '!'; i < '~'; i++) { 164 | characterMap.put(i, i); 165 | } 166 | for (char i = '¡'; i < '¬'; i++) { 167 | characterMap.put(i, i); 168 | } 169 | for (char i = '®'; i < 'ÿ'; i++) { 170 | characterMap.put(i, i); 171 | } 172 | return characterMap; 173 | } catch (Throwable e) { 174 | throw Util.throwException(e); 175 | } 176 | } 177 | 178 | /** 179 | * Decode string. 180 | * 181 | * @param msg the msg 182 | * @return the string 183 | */ 184 | public String decode(@Nonnull Integer... msg) { 185 | return Arrays.stream(msg).map(i -> decoder.getOrDefault(i, "")).reduce((a, b) -> a + b).orElseGet(() -> ""); 186 | } 187 | 188 | /** 189 | * Encode list. 190 | * 191 | * @param msg the msg 192 | * @return the list 193 | */ 194 | @Nonnull 195 | public List encode(@Nullable String msg) { 196 | ArrayList list = new ArrayList<>(); 197 | if (null != msg && !msg.isEmpty()) { 198 | RefStringBuilder stringBuffer = new RefStringBuilder(msg); 199 | while (stringBuffer.length() > 0) { 200 | Optional codeString = lookup(stringBuffer.toString()); 201 | if (codeString.isPresent()) { 202 | String key = RefUtil.get(codeString); 203 | stringBuffer.delete(0, key.length()); 204 | list.add(encoder.get(key)); 205 | } else { 206 | stringBuffer.delete(0, 1); 207 | } 208 | } 209 | } 210 | return list; 211 | } 212 | 213 | /** 214 | * Lookup optional. 
215 | * 216 | * @param searchStr the search str 217 | * @return the optional 218 | */ 219 | protected Optional lookup(@Nullable String searchStr) { 220 | if (null == searchStr || searchStr.isEmpty()) return Optional.empty(); 221 | String ceilingKey = encoder.ceilingKey(searchStr); 222 | String floorKey = encoder.floorKey(searchStr); 223 | if (null != ceilingKey && !searchStr.startsWith(ceilingKey)) ceilingKey = null; 224 | if (null != floorKey && !searchStr.startsWith(floorKey)) floorKey = null; 225 | Optional codeString; 226 | if (null != ceilingKey || null != floorKey) { 227 | if (null != ceilingKey && null != floorKey) { 228 | if (floorKey.length() < ceilingKey.length()) { 229 | codeString = Optional.of(ceilingKey); 230 | } else { 231 | codeString = Optional.of(floorKey); 232 | } 233 | } else if (null != ceilingKey) { 234 | codeString = Optional.of(ceilingKey); 235 | } else { 236 | codeString = Optional.of(floorKey); 237 | } 238 | } else { 239 | codeString = Optional.empty(); 240 | } 241 | // codeString = encoder.keySet().stream() 242 | // .filter(x -> x.equals(searchStr.substring(0, Math.min(searchStr.length(), x.length())))) 243 | // .sorted(Comparator.comparing(x -> -x.length())) 244 | // .findFirst(); 245 | return codeString; 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/gpt2/GPT2Model.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. 
You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.google.protobuf.InvalidProtocolBufferException; 23 | import com.simiacryptus.ref.wrappers.RefString; 24 | import com.simiacryptus.text.GraphModifier; 25 | import com.simiacryptus.text.LanguageCodeModel; 26 | import com.simiacryptus.text.TextGenerator; 27 | import com.simiacryptus.util.Util; 28 | import org.apache.commons.io.FileUtils; 29 | import org.slf4j.Logger; 30 | import org.slf4j.LoggerFactory; 31 | import org.tensorflow.Graph; 32 | import org.tensorflow.Session; 33 | import org.tensorflow.Tensor; 34 | import org.tensorflow.framework.*; 35 | 36 | import javax.annotation.Nonnull; 37 | import javax.annotation.Nullable; 38 | import java.io.File; 39 | import java.io.IOException; 40 | import java.nio.FloatBuffer; 41 | import java.nio.IntBuffer; 42 | import java.util.*; 43 | import java.util.function.BiFunction; 44 | import java.util.stream.DoubleStream; 45 | import java.util.stream.IntStream; 46 | 47 | /** 48 | * The type Gpt 2 model. 49 | */ 50 | public class GPT2Model implements LanguageCodeModel { 51 | /** 52 | * The constant logger. 53 | */ 54 | protected static final Logger logger = LoggerFactory.getLogger(GPT2Model.class); 55 | 56 | /** 57 | * The Name. 58 | */ 59 | public final String name; 60 | /** 61 | * The Graph def. 62 | */ 63 | protected final byte[] graphDef; 64 | /** 65 | * The Code history. 66 | */ 67 | protected final ArrayList code_history = new ArrayList<>(); 68 | /** 69 | * The Graph modifier. 
70 | */ 71 | protected final GraphModifier graphModifier; 72 | /** 73 | * The Codec. 74 | */ 75 | protected final GPT2Codec codec; 76 | /** 77 | * The Loaded subnets. 78 | */ 79 | public HashSet loadedSubnets; 80 | /** 81 | * The Graph. 82 | */ 83 | public Graph graph; 84 | /** 85 | * The Session. 86 | */ 87 | public Session session; 88 | /** 89 | * The History size. 90 | */ 91 | protected int history_size = 0; 92 | /** 93 | * The Tensor state. 94 | */ 95 | @Nullable 96 | protected Tensor tensor_state = null; 97 | private BiFunction filterFn = (a, b) -> true; 98 | 99 | /** 100 | * Instantiates a new Gpt 2 model. 101 | * 102 | * @param name the name 103 | * @param graphModifier the graph modifier 104 | * @param file the file 105 | * @param codec the codec 106 | */ 107 | public GPT2Model(String name, GraphModifier graphModifier, @Nonnull File file, GPT2Codec codec) { 108 | this(name, loadModel(file), graphModifier, codec); 109 | } 110 | 111 | /** 112 | * Instantiates a new Gpt 2 model. 113 | * 114 | * @param name the name 115 | * @param graphDef the graph def 116 | * @param graphModifier the graph modifier 117 | * @param codec the codec 118 | */ 119 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec) { 120 | this(name, graphDef, graphModifier, codec, new Graph()); 121 | } 122 | 123 | /** 124 | * Instantiates a new Gpt 2 model. 
125 | * 126 | * @param name the name 127 | * @param graphDef the graph def 128 | * @param graphModifier the graph modifier 129 | * @param codec the codec 130 | * @param graph the graph 131 | */ 132 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, @Nonnull Graph graph) { 133 | this(name, graphDef, graphModifier, codec, graph, new Session(graph, ConfigProto.newBuilder() 134 | //.setLogDevicePlacement(true) 135 | // .setUsePerSessionThreads(true) 136 | // .setInterOpParallelismThreads(8) 137 | // .setIntraOpParallelismThreads(8) 138 | // .setIsolateSessionState(false) 139 | .setGraphOptions(GraphOptions.newBuilder() 140 | .setOptimizerOptions(OptimizerOptions.newBuilder() 141 | .setDoConstantFolding(true) 142 | .setDoFunctionInlining(true) 143 | .setDoCommonSubexpressionElimination(true) 144 | .build()) 145 | .build()) 146 | .setGpuOptions(GPUOptions.newBuilder() 147 | .setForceGpuCompatible(true) 148 | .setAllowGrowth(true) 149 | .setPerProcessGpuMemoryFraction(0.5) 150 | .build()) 151 | .build().toByteArray())); 152 | } 153 | 154 | /** 155 | * Instantiates a new Gpt 2 model. 156 | * 157 | * @param name the name 158 | * @param graphDef the graph def 159 | * @param graphModifier the graph modifier 160 | * @param codec the codec 161 | * @param graph the graph 162 | * @param session the session 163 | */ 164 | public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, Graph graph, Session session) { 165 | this.name = name; 166 | this.graphDef = graphDef; 167 | this.graphModifier = graphModifier; 168 | this.codec = codec; 169 | this.graph = graph; 170 | this.session = session; 171 | loadedSubnets = new HashSet<>(); 172 | } 173 | 174 | @Override 175 | public BiFunction getFilterFn() { 176 | return filterFn; 177 | } 178 | 179 | /** 180 | * Load model byte [ ]. 
181 | * 182 | * @param file the file 183 | * @return the byte [ ] 184 | */ 185 | public static byte[] loadModel(@Nonnull File file) { 186 | try { 187 | return FileUtils.readFileToByteArray(file); 188 | } catch (IOException e) { 189 | throw Util.throwException(e); 190 | } 191 | } 192 | 193 | /** 194 | * Copy tensor. 195 | * 196 | * @param toCopy the to copy 197 | * @return the tensor 198 | */ 199 | @Nonnull 200 | public static Tensor copy(@Nonnull Tensor toCopy) { 201 | FloatBuffer floatBuffer = FloatBuffer.allocate(toCopy.numElements()); 202 | toCopy.writeTo(floatBuffer); 203 | floatBuffer.flip(); 204 | return Tensor.create(toCopy.shape(), floatBuffer); 205 | } 206 | 207 | @Nonnull 208 | @Override 209 | public LanguageCodeModel copy() { 210 | GPT2Model copy = new GPT2Model(name, graphDef, graphModifier, this.codec, this.graph, this.session); 211 | if (null == this.tensor_state) { 212 | copy.tensor_state = null; 213 | } else { 214 | copy.tensor_state = copy(this.tensor_state); 215 | } 216 | copy.history_size = this.history_size; 217 | copy.loadedSubnets = this.loadedSubnets; 218 | copy.code_history.addAll(this.code_history); 219 | copy.filterFn = this.filterFn; 220 | return copy; 221 | } 222 | 223 | /** 224 | * Logits to probabilities float [ ]. 
225 | * 226 | * @param logits the logits 227 | * @return the float [ ] 228 | */ 229 | @Nonnull 230 | public float[] logitsToProbabilities(@Nonnull float[] logits) { 231 | String prefix = codec.decode(code_history.stream().toArray(i -> new Integer[i])); 232 | int[] sortedIndices = Arrays.stream(TextGenerator.sortedIndices(logits, Integer.MAX_VALUE)) 233 | .filter(item -> { 234 | if (item == logits.length - 1) return true; 235 | String thisStr = codec.decode(item); 236 | assert getFilterFn() != null; 237 | return getFilterFn().apply(prefix, thisStr); 238 | }) 239 | .toArray(); 240 | double[] input = IntStream.range(0, sortedIndices.length).mapToDouble(c -> logits[sortedIndices[c]]).toArray(); 241 | assert 1 < input.length : "input.length() = " + input.length; 242 | 243 | final DoubleSummaryStatistics summaryStatistics = DoubleStream.of(input).filter(x -> Double.isFinite(x)).summaryStatistics(); 244 | final double max = summaryStatistics.getMax(); 245 | @Nullable final double[] exp = Arrays.stream(input).map(x -> { 246 | double xx = Math.exp(x - max); 247 | return Double.isFinite(xx) ? xx : 0; 248 | }).toArray(); 249 | final double sum = 0 < Arrays.stream(exp).sum() ? 
Arrays.stream(exp).sum() : 1; 250 | assert Double.isFinite(sum); 251 | @Nullable double[] chosen = Arrays.stream(exp).map(x -> x / sum).toArray(); 252 | 253 | for (int i = 0; i < logits.length; i++) logits[i] = 0; 254 | assert chosen != null; 255 | IntStream.range(0, chosen.length).forEach(c -> { 256 | logits[sortedIndices[c]] = (float) chosen[c]; 257 | }); 258 | return logits; 259 | } 260 | 261 | @Nonnull 262 | @Override 263 | public synchronized LanguageCodeModel clear() { 264 | logger.debug("Reset Language Model State"); 265 | if (null != this.tensor_state) this.tensor_state.close(); 266 | this.tensor_state = null; 267 | history_size = 0; 268 | code_history.clear(); 269 | return this; 270 | } 271 | 272 | @Nonnull 273 | @Override 274 | public synchronized float[] eval(int data_X) { 275 | logger.debug(RefString.format("Eval %d", data_X)); 276 | try { 277 | String prefix; 278 | if (!loadedSubnets.contains("")) { 279 | loadedSubnets.add(""); 280 | graph.importGraphDef(this.graphDef); 281 | } 282 | if (null == this.tensor_state) { 283 | prefix = "init/"; 284 | if (!loadedSubnets.contains(prefix)) { 285 | GraphModifier.importGraphDef(graph, this.graphModifier.edit(GraphDef.parseFrom(this.graphDef), prefix, false)); 286 | loadedSubnets.add(prefix); 287 | } 288 | } else { 289 | prefix = ""; 290 | } 291 | this.code_history.add(data_X); 292 | final float[] eval; 293 | if (0 == history_size) { 294 | eval = eval(prefix, data_X); 295 | } else { 296 | final int[] activeCodes = this.code_history 297 | .subList(this.code_history.size() - 1, this.code_history.size()) 298 | .stream().mapToInt(x -> x).toArray(); 299 | eval = eval(prefix, activeCodes); 300 | } 301 | return eval; 302 | } catch (InvalidProtocolBufferException e) { 303 | throw Util.throwException(e); 304 | } 305 | } 306 | 307 | /** 308 | * Eval float [ ]. 
309 | * 310 | * @param prefix the prefix 311 | * @param data_X the data x 312 | * @return the float [ ] 313 | */ 314 | @Nonnull 315 | public synchronized float[] eval(String prefix, @Nonnull int... data_X) { 316 | synchronized (session) { 317 | logger.debug(RefString.format("Eval(%s,%s)", session, Arrays.toString(data_X))); 318 | Tensor input_X = Tensor.create(new long[]{1, data_X.length}, IntBuffer.wrap(data_X)); 319 | Session.Runner runner = session.runner().feed("input_X", input_X); 320 | if (null != this.tensor_state) runner = runner.feed(prefix + "input_past", this.tensor_state); 321 | logger.debug("Input Codes: " + Arrays.toString(data_X)); 322 | logger.debug("Input State: " + (this.tensor_state == null ? "null" : Arrays.toString(this.tensor_state.shape()))); 323 | final Tensor prevState = this.tensor_state; 324 | runner = runner 325 | .fetch(prefix + "output/strided_slice_1") 326 | .fetch(0 == history_size ? prefix + "model/stack" : prefix + "output/concat"); 327 | List> run = runner.run(); 328 | Tensor tensor_next = run.get(0).expect(Float.class); 329 | final Tensor outputState = run.get(1).expect(Float.class); // reshape(shape_state, run.get(1).expect(Float.class)); 330 | logger.debug("Output Logits: " + Arrays.toString(tensor_next.shape())); 331 | logger.debug("Output State: " + Arrays.toString(outputState.shape())); 332 | if (null == this.tensor_state) { 333 | this.history_size = (int) outputState.shape()[4]; 334 | this.tensor_state = outputState; 335 | } else { 336 | this.history_size = this.history_size + 1; 337 | this.tensor_state.close(); 338 | this.tensor_state = outputState; 339 | } 340 | float[] logits = new float[tensor_next.numElements()]; 341 | tensor_next.writeTo(FloatBuffer.wrap(logits)); 342 | tensor_next.close(); 343 | if (null != prevState) prevState.close(); 344 | input_X.close(); 345 | return logitsToProbabilities(logits); 346 | } 347 | } 348 | 349 | @Nonnull 350 | @Override 351 | public LanguageCodeModel setFilterFn(BiFunction filterFn) 
{ 352 | this.filterFn = filterFn; 353 | return this; 354 | } 355 | 356 | @Nullable 357 | @Override 358 | public Tensor state() { 359 | return this.tensor_state; 360 | } 361 | 362 | } 363 | -------------------------------------------------------------------------------- /src/main/java/com/simiacryptus/text/gpt2/GPT2Util.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.simiacryptus.text.LanguageCodeModel; 23 | import com.simiacryptus.text.SumModel; 24 | import com.simiacryptus.text.TextGenerator; 25 | import com.simiacryptus.util.Util; 26 | import org.apache.commons.io.FileUtils; 27 | import org.apache.commons.io.IOUtils; 28 | 29 | import javax.annotation.Nonnull; 30 | import javax.annotation.Nullable; 31 | import java.io.File; 32 | import java.io.IOException; 33 | import java.net.URI; 34 | import java.security.KeyManagementException; 35 | import java.security.NoSuchAlgorithmException; 36 | import java.util.ArrayList; 37 | import java.util.Arrays; 38 | import java.util.List; 39 | import java.util.TreeSet; 40 | import java.util.stream.Collectors; 41 | import java.util.stream.Stream; 42 | import java.util.zip.ZipFile; 43 | 44 | /** 45 | * The type Gpt 2 util. 
46 | */ 47 | public class GPT2Util { 48 | 49 | private static final String MODEL_URL_BASE = System.getProperty( 50 | "GPT2_MODEL_URL", 51 | "https://s3-us-west-2.amazonaws.com/simiacryptus/gpt2/"); 52 | private static final TextGenerator prototype = get345M().setVerbose(false); 53 | 54 | /** 55 | * Gets 345 m. 56 | * 57 | * @return the 345 m 58 | */ 59 | @Nonnull 60 | public static TextGenerator get345M() { 61 | return new TextGenerator(50257, getModel_345M(), getCodec_345M()); 62 | } 63 | 64 | /** 65 | * Gets codec 345 m. 66 | * 67 | * @return the codec 345 m 68 | */ 69 | @Nonnull 70 | public static GPT2Codec getCodec_345M() { 71 | return new GPT2Codec(getEncoderFile_345M(), 50257); 72 | } 73 | 74 | /** 75 | * Gets encoder file 345 m. 76 | * 77 | * @return the encoder file 345 m 78 | */ 79 | @Nonnull 80 | public static File getEncoderFile_345M() { 81 | return loadZippedInternetFile(MODEL_URL_BASE + "encoder_345M.zip", "encoder_345M.json"); 82 | } 83 | 84 | /** 85 | * Gets graph file 345 m. 86 | * 87 | * @return the graph file 345 m 88 | */ 89 | @Nonnull 90 | public static File getGraphFile_345M() { 91 | return loadRawInternetFile(MODEL_URL_BASE, "345M.pb"); 92 | } 93 | 94 | /** 95 | * Gets model 345 m. 96 | * 97 | * @return the model 345 m 98 | */ 99 | @Nonnull 100 | public static GPT2Model getModel_345M() { 101 | return getModel_345M(getGraphFile_345M()); 102 | } 103 | 104 | /** 105 | * Gets text generator. 106 | * 107 | * @return the text generator 108 | */ 109 | @Nonnull 110 | public static TextGenerator getTextGenerator() { 111 | return prototype.copy(); 112 | } 113 | 114 | /** 115 | * Load zipped internet file file. 
116 | * 117 | * @param zipUrl the zip url 118 | * @param pathname the pathname 119 | * @return the file 120 | */ 121 | public @Nonnull 122 | static File loadZippedInternetFile(@Nonnull String zipUrl, @Nonnull String pathname) { 123 | File encoderFile = new File(pathname); 124 | if (new File(encoderFile.getName()).exists()) { 125 | encoderFile = new File(encoderFile.getName()); 126 | } else { 127 | try { 128 | try (ZipFile zipFile = new ZipFile(Util.cacheFile(new URI(zipUrl)))) { 129 | byte[] graphDefBytes = IOUtils.toByteArray(zipFile.getInputStream(zipFile.getEntry(pathname))); 130 | encoderFile = new File(encoderFile.getName()); 131 | FileUtils.writeByteArrayToFile(encoderFile, graphDefBytes); 132 | } 133 | } catch (Exception e) { 134 | throw Util.throwException(e); 135 | } 136 | } 137 | return encoderFile; 138 | } 139 | 140 | /** 141 | * Load raw internet file file. 142 | * 143 | * @param urlBase the url base 144 | * @param fileName the file name 145 | * @return the file 146 | */ 147 | public @Nonnull 148 | static File loadRawInternetFile(String urlBase, @Nonnull String fileName) { 149 | File graphFile = new File(fileName); 150 | if (new File(graphFile.getName()).exists()) { 151 | graphFile = new File(graphFile.getName()); 152 | } else { 153 | try { 154 | graphFile = Util.cacheFile(new URI(urlBase + fileName)); 155 | } catch (Exception e) { 156 | throw Util.throwException(e); 157 | } 158 | } 159 | return graphFile; 160 | } 161 | 162 | /** 163 | * Gets model 345 m. 164 | * 165 | * @param file the file 166 | * @return the model 345 m 167 | */ 168 | @Nonnull 169 | public static GPT2Model getModel_345M(@Nonnull File file) { 170 | return getModel_345M("345M", file); 171 | } 172 | 173 | /** 174 | * Gets model 345 m. 
175 | * 176 | * @param name the name 177 | * @param file the file 178 | * @return the model 345 m 179 | */ 180 | @Nonnull 181 | public static GPT2Model getModel_345M(String name, @Nonnull File file) { 182 | return new GPT2Model(name, new GPT2Edit_345M(), file, getCodec_345M()); 183 | } 184 | 185 | /** 186 | * Gets text generator. 187 | * 188 | * @param textGenerator the text generator 189 | * @param characterWhitelist the character whitelist 190 | * @param wordlist the wordlist 191 | * @return the text generator 192 | * @throws IOException the io exception 193 | * @throws NoSuchAlgorithmException the no such algorithm exception 194 | * @throws KeyManagementException the key management exception 195 | */ 196 | @Nonnull 197 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator textGenerator, @Nullable String characterWhitelist, @Nullable URI wordlist) throws IOException, NoSuchAlgorithmException, KeyManagementException { 198 | TreeSet wordList = null == wordlist ? null : new TreeSet<>( 199 | Arrays.stream(FileUtils.readFileToString(Util.cacheFile(wordlist), "UTF-8").split("\\s+")) 200 | .map(x -> x.trim().toLowerCase()).collect(Collectors.toSet()) 201 | ); 202 | textGenerator.getModel().setFilterFn((prefix, txt) -> { 203 | if (null != characterWhitelist && !characterWhitelist.isEmpty() && 204 | txt.matches(".*[^" + characterWhitelist + "].*")) return false; 205 | String[] words = txt.split("[^\\w]+"); 206 | if (null != wordList && !wordList.isEmpty()) 207 | for (int i = 0; i < words.length; i++) { 208 | String word = words[i].toLowerCase(); 209 | if (word.isEmpty()) continue; 210 | if (i < words.length - 1 && !wordList.contains(word)) return false; 211 | else { 212 | if (wordList.contains(word)) continue; 213 | String floor = wordList.floor(word); 214 | if (null != floor && floor.startsWith(word)) continue; 215 | String ceiling = wordList.ceiling(word); 216 | if (null != ceiling && ceiling.startsWith(word)) continue; 217 | return false; 218 | } 219 | 
} 220 | return true; 221 | }); 222 | return textGenerator; 223 | } 224 | 225 | /** 226 | * Gets text generator. 227 | * 228 | * @param seeds the seeds 229 | * @return the text generator 230 | */ 231 | @Nonnull 232 | public static TextGenerator getTextGenerator(String... seeds) { 233 | return getTextGenerator(get345M().setVerbose(false), seeds); 234 | } 235 | 236 | /** 237 | * Gets text generator. 238 | * 239 | * @param base the base 240 | * @param seeds the seeds 241 | * @return the text generator 242 | */ 243 | @Nonnull 244 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator base, String... seeds) { 245 | ArrayList languageCodeModels = new ArrayList<>(); 246 | return getTextGenerator(base, languageCodeModels, seeds); 247 | } 248 | 249 | /** 250 | * Gets text generator. 251 | * 252 | * @param base the base 253 | * @param languageCodeModels the language code models 254 | * @param seeds the seeds 255 | * @return the text generator 256 | */ 257 | @Nonnull 258 | public static TextGenerator getTextGenerator(@Nonnull TextGenerator base, @Nonnull List languageCodeModels, @Nonnull String... seeds) { 259 | base.setModel(new SumModel(Stream.concat( 260 | Arrays.stream(seeds).map(seed -> { 261 | TextGenerator copy = base.copy(); 262 | copy.feed(seed); 263 | return copy.getModel(); 264 | }), 265 | languageCodeModels.stream() 266 | ).toArray(i -> new LanguageCodeModel[i]))); 267 | return base; 268 | } 269 | 270 | /** 271 | * Gets text generator. 
272 | * 273 | * @param characterWhitelist the character whitelist 274 | * @param wordlist the wordlist 275 | * @return the text generator 276 | * @throws IOException the io exception 277 | * @throws NoSuchAlgorithmException the no such algorithm exception 278 | * @throws KeyManagementException the key management exception 279 | */ 280 | @Nonnull 281 | protected static TextGenerator getTextGenerator(String characterWhitelist, URI wordlist) throws IOException, NoSuchAlgorithmException, KeyManagementException { 282 | return getTextGenerator(getTextGenerator(), characterWhitelist, wordlist); 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /src/site/site.xml: -------------------------------------------------------------------------------- 1 | 19 | 20 | 21 | 22 | org.apache.maven.skins 23 | maven-fluido-skin 24 | 1.8 25 | 26 | 27 | 28 | true 29 | false 30 | true 31 | 32 | SimiaCryptus/tf-gpt-2 33 | left 34 | black 35 | 36 | release 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /src/test/java/com/simiacryptus/text/gpt2/DevTests.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.simiacryptus.ref.lang.RefUtil; 23 | import com.simiacryptus.tensorflow.GraphModel; 24 | import com.simiacryptus.tensorflow.TFUtil; 25 | import org.junit.jupiter.api.Assertions; 26 | import org.junit.jupiter.api.Test; 27 | import org.tensorflow.Output; 28 | import org.tensorflow.TensorFlowException; 29 | 30 | import javax.annotation.Nonnull; 31 | import java.io.File; 32 | import java.util.Arrays; 33 | import java.util.Map; 34 | 35 | import static com.simiacryptus.tensorflow.TFUtil.find; 36 | 37 | /** 38 | * The type Dev tests. 39 | */ 40 | public class DevTests { 41 | 42 | /** 43 | * Model home file. 44 | * 45 | * @return the file 46 | */ 47 | @Nonnull 48 | public static File modelHome() { 49 | return new File(System.getProperty("MODEL_HOME", "H:\\SimiaCryptus\\data-science-tools\\gpt-2\\models")); 50 | } 51 | 52 | /** 53 | * Summary json. 
54 | * 55 | * @throws Exception the exception 56 | */ 57 | @Test 58 | public void summaryJson() throws Exception { 59 | TestUtil.open("345M.", new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb")))); 60 | TestUtil.open("345M_Init.", new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb")))); 61 | } 62 | 63 | /** 64 | * Compare init. 65 | */ 66 | @Test 67 | public void compare_init() { 68 | final GraphModel a = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb"))); 69 | final GraphModel b = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb"))); 70 | // TestUtil.open("345M.", a); 71 | // TestUtil.open("345M_Init.", b); 72 | // TestUtil.open("345M_cmp.", b.compare(a)); 73 | new GraphComparer().compare(a, b); 74 | } 75 | 76 | /** 77 | * Test init. 78 | * 79 | * @throws Exception the exception 80 | */ 81 | @Test 82 | public void test_init() throws Exception { 83 | final GraphModel a = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb"))); 84 | final GraphModel b = new GraphModel(GPT2Model.loadModel(new File(modelHome(), "345M_Init" + ".pb"))); 85 | final GraphModel edited = new GraphModel(new GPT2Edit_345M().edit(a.graphDef, "", true).toByteArray()); 86 | TestUtil.open("345M_edit.", edited); 87 | final Map compare = b.compare(edited); 88 | TestUtil.open("345M_test.", compare); 89 | compare.values().forEach(deltaRecord -> { 90 | System.out.println("left=" + (null == deltaRecord.left ? "null" : deltaRecord.left.getNodeDef())); 91 | System.out.println("right=" + (null == deltaRecord.right ? "null" : deltaRecord.right.getNodeDef())); 92 | }); 93 | } 94 | 95 | /** 96 | * Add gradient. 
97 | * 98 | * @throws Exception the exception 99 | */ 100 | @Test 101 | public void addGradient() throws Exception { 102 | Assertions.assertThrows(TensorFlowException.class, () -> { 103 | try { 104 | byte[] originalGraphDef = GPT2Model.loadModel(new File(modelHome(), "345M" + ".pb")); 105 | byte[] newGraphDef = TFUtil.editGraph(originalGraphDef, graph -> { 106 | graph.addGradients("gradient_", new Output[]{ 107 | find(graph, "model/Reshape_1").output(0), 108 | find(graph, "model/stack").output(0) 109 | }, new Output[]{ 110 | find(graph, "input_past").output(0) 111 | }, null); 112 | }); 113 | System.out.println("GPT2Model: " + TFUtil.toJson(new GraphModel(newGraphDef))); 114 | } catch (Throwable e) { 115 | e.printStackTrace(); 116 | throw e; 117 | } 118 | }); 119 | } 120 | 121 | /** 122 | * Encode. 123 | */ 124 | @Test 125 | public void encode() { 126 | GPT2Codec encoder = new GPT2Codec(new File(modelHome(), "345M" + "\\encoder.json"), 50257); 127 | for (String text : Arrays.asList( 128 | "This is a test", 129 | "<|endoftext|>" 130 | )) { 131 | System.out.println(text + " => " + RefUtil.get(encoder.encode(text).stream().map(x -> x.toString()).reduce((a, b) -> a + ", " + b))); 132 | } 133 | } 134 | 135 | } 136 | -------------------------------------------------------------------------------- /src/test/java/com/simiacryptus/text/gpt2/GraphComparer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. 
You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.simiacryptus.ref.wrappers.RefString; 23 | import com.simiacryptus.tensorflow.GraphModel; 24 | import org.tensorflow.framework.AttrValue; 25 | import org.tensorflow.framework.NodeDef; 26 | import org.tensorflow.framework.TensorProto; 27 | import org.tensorflow.framework.TensorShapeProto; 28 | 29 | import javax.annotation.Nonnull; 30 | import javax.annotation.Nullable; 31 | import java.nio.FloatBuffer; 32 | import java.nio.IntBuffer; 33 | import java.util.*; 34 | import java.util.function.Consumer; 35 | import java.util.stream.IntStream; 36 | import java.util.stream.Stream; 37 | 38 | /** 39 | * The type Graph comparer. 40 | */ 41 | class GraphComparer implements Consumer { 42 | /** 43 | * The Node deletes. 44 | */ 45 | public final ArrayList nodeDeletes = new ArrayList(); 46 | /** 47 | * The New nodes. 48 | */ 49 | public final ArrayList newNodes = new ArrayList(); 50 | /** 51 | * The Node edits. 52 | */ 53 | public final Map> nodeEdits = new HashMap<>(); 54 | 55 | /** 56 | * To string string. 
57 | * 58 | * @param value the value 59 | * @return the string 60 | */ 61 | public static String toString(@Nonnull AttrValue value) { 62 | switch (value.getValueCase()) { 63 | case I: 64 | return RefString.format("AttrValue.newBuilder().setI(%s).build()", value.getI()); 65 | case F: 66 | return RefString.format("AttrValue.newBuilder().setF(%s).build()", value.getF()); 67 | case B: 68 | return RefString.format("AttrValue.newBuilder().setB(%s).build()", value.getB()); 69 | case S: 70 | return RefString.format("AttrValue.newBuilder().setS(%s).build()", value.getS().toStringUtf8()); 71 | case TYPE: 72 | return RefString.format("AttrValue.newBuilder().setType(DataType.forNumber(%s)).build()", value.getType().getNumber()); 73 | case SHAPE: 74 | return RefString.format("AttrValue.newBuilder().setShape(shape(%s)).build()", toString(dims(value.getShape()))); 75 | case TENSOR: 76 | TensorProto tensor = value.getTensor(); 77 | TensorShapeProto shape = tensor.getTensorShape(); 78 | switch (tensor.getDtype()) { 79 | case DT_INT32: { 80 | String shapeElements = shape.getDimList().stream().map(x -> Long.toString(x.getSize())).reduce((a, b) -> a + ", " + b).orElse(""); 81 | return tensor.getIntValList().stream().map(x -> Integer.toString(x)).reduce((a, b) -> a + ", " + b).map(elements -> { 82 | return RefString.format("AttrValue.newBuilder().setTensor(tensor1(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements); 83 | }).orElseGet(() -> { 84 | IntBuffer intBuffer = tensor.getTensorContent().asReadOnlyByteBuffer().asIntBuffer(); 85 | int[] dst = new int[intBuffer.remaining()]; 86 | intBuffer.get(dst); 87 | String elements = Arrays.stream(dst).map(i1 -> Integer.reverseBytes(i1)).mapToObj(i -> Integer.toString(i)).reduce((a, b) -> a + ", " + b).orElse(""); 88 | return RefString.format("AttrValue.newBuilder().setTensor(tensor2(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements); 89 | }); 90 | } 91 | case DT_FLOAT: 92 | String shapeElements = 
shape.getDimList().stream().map(x -> Long.toString(x.getSize())).reduce((a, b) -> a + ", " + b).orElse(""); 93 | return tensor.getFloatValList().stream().map(x -> Float.toString(x)).reduce((a, b) -> a + ", " + b).map(elements -> { 94 | return RefString.format("AttrValue.newBuilder().setTensor(tensor1(new int[]{ %s }, new int[] { %s })).build()", shapeElements, elements); 95 | }).orElseGet(() -> { 96 | FloatBuffer intBuffer = tensor.getTensorContent().asReadOnlyByteBuffer().asFloatBuffer(); 97 | float[] dst = new float[intBuffer.remaining()]; 98 | intBuffer.get(dst); 99 | String elements = IntStream.range(0, dst.length).mapToDouble(i -> dst[i]).mapToObj(d -> Double.toString(d)).reduce((a, b) -> a + ", " + b).orElse(""); 100 | return RefString.format("AttrValue.newBuilder().setTensor(tensor2(new int[]{ %s }, new float[]{ %s })).build()", shapeElements, elements); 101 | }); 102 | } 103 | default: 104 | return "/* " + value.getType() + " - " + value.toString().trim() + " */"; 105 | } 106 | } 107 | 108 | /** 109 | * To string string. 110 | * 111 | * @param dims the dims 112 | * @return the string 113 | */ 114 | @Nonnull 115 | public static String toString(@Nonnull long[] dims) { 116 | return Arrays.stream(dims).mapToObj(size -> Long.toString(size)).reduce((a, b) -> a + ", " + b).orElse(""); 117 | } 118 | 119 | /** 120 | * Dims long [ ]. 121 | * 122 | * @param shape the shape 123 | * @return the long [ ] 124 | */ 125 | public static long[] dims(@Nonnull TensorShapeProto shape) { 126 | return shape.getDimList().stream().mapToLong(dim -> dim.getSize()).toArray(); 127 | } 128 | 129 | /** 130 | * Compare. 
131 | * 132 | * @param left the left 133 | * @param right the right 134 | */ 135 | public void compare(@Nonnull GraphModel left, @Nonnull GraphModel right) { 136 | left.compare(right).values().stream().forEach(this); 137 | System.out.println("\n" + 138 | " @Override\n" + 139 | " public HashSet getDeletes_Init() {\n" + 140 | " final HashSet toDelete = new HashSet<>();\n" + 141 | "\n" + this.nodeDeletes.stream().map(s2 -> s2.trim()).reduce((a, b) -> a + "\n" + b).orElse("") + 142 | " return toDelete;\n" + 143 | " }\n" + 144 | "\n" + 145 | " protected void addNodes(Consumer add) {\n" + 146 | "\n" + this.newNodes.stream().map(s1 -> s1.trim()).reduce((a, b) -> a + "\n" + b).orElse("") + 147 | " }\n" + 148 | "\n" + 149 | " @Override\n" + 150 | " public NodeDef.Builder edit(NodeDef.Builder node) {\n" + 151 | "\n" + nodeEdits.entrySet().stream().sorted(Comparator.comparing(x -> x.getKey())).map(e -> 152 | RefString.format("if(node.getName().equals(\"%s\")) {\n%s\n}", e.getKey(), 153 | e.getValue().stream().map(x -> "\t" + x).reduce((a, b) -> a + "\n" + b).orElse("") 154 | )).map(s -> s.trim()).reduce((a, b) -> a + "\nelse " + b).orElse("") + 155 | " else {\n" + 156 | " return null;\n" + 157 | " }\n" + 158 | " return node;\n" + 159 | " }\n"); 160 | } 161 | 162 | @Override 163 | public void accept(@Nonnull GraphModel.DeltaRecord delta) { 164 | if (delta.left == null || delta.right == null) { 165 | if (delta.left == null) { 166 | GraphModel.GraphNode node = delta.right; 167 | this.newNodes.add(RefString.format("add.accept(NodeDef.newBuilder().setName(\"%s\").setOp(\"%s\")%s%s.build());", 168 | node.name, 169 | node.getOp(), 170 | node.getInputKeys().stream().map(s -> '"' + s + '"') 171 | .reduce((a, b) -> a + ", " + b) 172 | .map(s -> ".addAllInput(Arrays.asList(" + s + "))").orElse(""), 173 | node.getNodeDef().getAttrMap().entrySet().stream() 174 | .map(e -> RefString.format(".putAttr(\"%s\",%s)", e.getKey(), toString(e.getValue()))) 175 | .reduce((a, b) -> a + b).orElse("") 
176 | ).replaceAll("\\)\\.", ")\n\t.")); 177 | } else { 178 | nodeDeletes.add(RefString.format("toDelete.add(\"%s\");%n", delta.name)); 179 | } 180 | } else { 181 | final NodeDef leftNode = delta.left.getNodeDef(); 182 | final NodeDef rightNode = delta.right.getNodeDef(); 183 | if (null != leftNode && null != rightNode) { 184 | if (!leftNode.getOp().equals(rightNode.getOp())) { 185 | getBuffer(delta.name).add(RefString.format("node.setOp(\"%s\");", rightNode.getOp())); 186 | } 187 | } 188 | assert rightNode != null; 189 | assert leftNode != null; 190 | compare(delta, leftNode.getAttrMap(), rightNode.getAttrMap()); 191 | compareInputs(delta, 192 | delta.left.getInputKeys(), 193 | delta.right.getInputKeys()); 194 | } 195 | } 196 | 197 | /** 198 | * Compare inputs. 199 | * 200 | * @param delta the delta 201 | * @param leftData the left data 202 | * @param rightData the right data 203 | */ 204 | public void compareInputs(@Nonnull GraphModel.DeltaRecord delta, @Nullable List leftData, @Nonnull List rightData) { 205 | if (leftData == null || leftData.size() == 0) { 206 | getBuffer(delta.name).add(rightData.stream().map(input -> 207 | RefString.format("node.addInput(\"%s\");", input) 208 | ).reduce((a, b) -> a + "\n" + b).orElse("")); 209 | } else if (rightData.size() == 0) { 210 | getBuffer(delta.name).add("node.clearInput();"); 211 | } else if (leftData.size() != rightData.size()) { 212 | if (leftData.size() + 1 == rightData.size() && leftData.equals(rightData.subList(0, leftData.size()))) { 213 | getBuffer(delta.name).add(RefString.format("node.addInput(\"%s\");", rightData.get(leftData.size()))); 214 | } else { 215 | System.out.printf("// %s: Input %s vs %s%n", delta.name, 216 | leftData.stream().reduce((a, b) -> a + "," + b).orElse("-"), 217 | rightData.stream().reduce((a, b) -> a + "," + b).orElse("-")); 218 | } 219 | } else { 220 | final int[] mismatchedIndices = IntStream.range(0, leftData.size()).filter(i -> !leftData.get(i).equals(rightData.get(i))).toArray(); 
221 | if (1 == mismatchedIndices.length) { 222 | getBuffer(delta.name).add(RefString.format("node.setInput(%s, \"%s\");", mismatchedIndices[0], rightData.get(mismatchedIndices[0]))); 223 | } else if (0 < mismatchedIndices.length) { 224 | getBuffer(delta.name).add(RefString.format("node.clearInput();node.addAllInput(Arrays.asList(%s));", 225 | rightData.stream().map(x -> '"' + x + '"').reduce((a, b) -> a + "," + b).orElseGet(() -> ""))); 226 | } 227 | } 228 | } 229 | 230 | private void compare(@Nonnull GraphModel.DeltaRecord delta, @Nonnull Map left, @Nonnull Map right) { 231 | Stream.concat( 232 | left.keySet().stream(), 233 | right.keySet().stream() 234 | ) 235 | .forEach(key -> { 236 | AttrValue leftValue = left.get(key); 237 | AttrValue rightValue = right.get(key); 238 | if (null == leftValue) { 239 | getBuffer(delta.name).add(RefString.format("node.putAttr(\"%s\", %s);", key, toString(rightValue).trim())); 240 | } else if (null == rightValue) { 241 | getBuffer(delta.name).add(RefString.format("node.removeAttr(\"%s\");", key)); 242 | } else { 243 | if (!leftValue.toString().equals(rightValue.toString())) { 244 | getBuffer(delta.name).add(RefString.format("node.putAttr(\"%s\", %s);", key, toString(rightValue).trim())); 245 | } 246 | } 247 | }); 248 | } 249 | 250 | @Nonnull 251 | private ArrayList getBuffer(String name) { 252 | return nodeEdits.computeIfAbsent(name, k -> new ArrayList()); 253 | } 254 | 255 | } 256 | -------------------------------------------------------------------------------- /src/test/java/com/simiacryptus/text/gpt2/TestUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. 
You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.fasterxml.jackson.databind.ObjectMapper; 23 | import com.fasterxml.jackson.databind.SerializationFeature; 24 | import com.simiacryptus.tensorflow.TFUtil; 25 | import com.simiacryptus.tensorflow.TensorboardEventWriter; 26 | import com.simiacryptus.util.Util; 27 | import org.tensorflow.framework.GraphDef; 28 | 29 | import javax.annotation.Nonnull; 30 | import javax.swing.*; 31 | import java.awt.*; 32 | import java.io.File; 33 | import java.io.FileOutputStream; 34 | import java.io.IOException; 35 | import java.net.URISyntaxException; 36 | 37 | /** 38 | * The type Test util. 39 | */ 40 | public class TestUtil { 41 | /** 42 | * Launch tensorboard. 43 | * 44 | * @param tensorboardDir the tensorboard dir 45 | * @throws IOException the io exception 46 | * @throws URISyntaxException the uri syntax exception 47 | */ 48 | public static void launchTensorboard(@Nonnull File tensorboardDir) throws IOException, URISyntaxException { 49 | TFUtil.launchTensorboard(tensorboardDir.getAbsolutePath(), tensorboard -> { 50 | try { 51 | JOptionPane.showConfirmDialog(null, "Press OK to close"); 52 | tensorboard.destroyForcibly(); 53 | } catch (Exception e) { 54 | throw Util.throwException(e); 55 | } 56 | }); 57 | } 58 | 59 | /** 60 | * Write graph file. 
61 | * 62 | * @param graphDef the graph def 63 | * @param location the location 64 | * @param name the name 65 | * @return the file 66 | * @throws IOException the io exception 67 | */ 68 | public static File writeGraph(@Nonnull GraphDef graphDef, File location, @Nonnull String name) throws IOException { 69 | TensorboardEventWriter eventWriter = new TensorboardEventWriter(new File(location, name), graphDef); 70 | eventWriter.write(graphDef); 71 | eventWriter.close(); 72 | return location; 73 | } 74 | 75 | /** 76 | * Open. 77 | * 78 | * @param prefix the prefix 79 | * @param model the model 80 | * @throws IOException the io exception 81 | */ 82 | public static void open(@Nonnull String prefix, Object model) throws IOException { 83 | File tempFile = File.createTempFile(prefix, ".json"); 84 | FileOutputStream fileOutputStream = new FileOutputStream(tempFile); 85 | new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT).writeValue(fileOutputStream, model); 86 | fileOutputStream.close(); 87 | Desktop.getDesktop().open(tempFile); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/test/java/com/simiacryptus/text/gpt2/UserTests.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 by Andrew Charneski. 3 | * 4 | * The author licenses this file to you under the 5 | * Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance 7 | * with the License. You may obtain a copy 8 | * of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 
18 | */ 19 | 20 | package com.simiacryptus.text.gpt2; 21 | 22 | import com.simiacryptus.text.TextGenerator; 23 | import org.junit.jupiter.api.Test; 24 | 25 | import java.util.Arrays; 26 | 27 | /** 28 | * The type User tests. 29 | */ 30 | public class UserTests { 31 | 32 | // /** 33 | // * Tensorboard graph. 34 | // */ 35 | // @Test 36 | // public void tensorboardGraph() { 37 | // String now = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()); 38 | // try { 39 | // String id = UUID.randomUUID().toString(); 40 | // File location = new File("target/" + now + "/tensorboard/" + id); 41 | // File graphFile = GPT2Util.getGraphFile_345M(); 42 | // TestUtil.launchTensorboard(TestUtil.writeGraph(GraphDef.parseFrom(GPT2Model.loadModel(graphFile)), location, id)); 43 | // } catch (IOException e) { 44 | // throw Util.throwException(e); 45 | // } catch (URISyntaxException e) { 46 | // throw Util.throwException(e); 47 | // } 48 | // } 49 | 50 | /** 51 | * Generate unconditional text. 52 | */ 53 | @Test 54 | public void generateUnconditionalText() { 55 | TextGenerator textGenerator = GPT2Util.get345M().setVerbose(false); 56 | for (double t = 1.0; t < 3; t *= 1.1) { 57 | textGenerator.getModel(); 58 | System.out.println("Temperature=" + t); 59 | System.out.println(textGenerator.generateText(150)); 60 | } 61 | } 62 | 63 | /** 64 | * Generate conditional text. 
65 | */ 66 | @Test 67 | public void generateConditionalText() { 68 | TextGenerator textGenerator = GPT2Util.get345M().setVerbose(false); 69 | for (double t = 1.0; t < 3; t *= 1.1) { 70 | textGenerator.getModel(); 71 | System.out.println("Temperature=" + t); 72 | for (String seed : Arrays.asList( 73 | "Hello World", 74 | "The opposite of up is", 75 | "English: Thank You\nSpanish:", 76 | "public void main(", 77 | "The problem with the world is", 78 | "I love", 79 | "You people are" 80 | )) { 81 | System.out.println(textGenerator.generateText(150, seed)); 82 | } 83 | } 84 | } 85 | 86 | 87 | } 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/test/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | %msg%n 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | --------------------------------------------------------------------------------