├── src ├── main │ ├── python │ │ └── pointer-generator │ │ │ ├── __init__.py │ │ │ ├── run_train.sh │ │ │ ├── run_inference.sh │ │ │ ├── inspect_checkpoint.py │ │ │ ├── flink_writer.py │ │ │ ├── util.py │ │ │ ├── train.py │ │ │ ├── beam_search.py │ │ │ ├── README.md │ │ │ └── LICENSE.txt │ ├── java │ │ ├── me │ │ │ └── littlebo │ │ │ │ ├── SysUtils.java │ │ │ │ ├── MessageSerializationSchema.java │ │ │ │ ├── MessageDeserializationSchema.java │ │ │ │ ├── Message.java │ │ │ │ ├── Summarization.java │ │ │ │ └── App.java │ │ └── org │ │ │ └── apache │ │ │ └── flink │ │ │ └── table │ │ │ └── ml │ │ │ └── lib │ │ │ └── tensorflow │ │ │ ├── param │ │ │ ├── HasTrainOutputCols.java │ │ │ ├── HasTrainSelectedCols.java │ │ │ ├── HasInferenceOutputCols.java │ │ │ ├── HasInferenceSelectedCols.java │ │ │ ├── HasTrainOutputTypes.java │ │ │ ├── HasInferenceOutputTypes.java │ │ │ ├── HasClusterConfig.java │ │ │ ├── HasTrainPythonConfig.java │ │ │ └── HasInferencePythonConfig.java │ │ │ ├── TFModel.java │ │ │ ├── TFEstimator.java │ │ │ └── util │ │ │ └── CodingUtils.java │ └── resources │ │ └── log4j2.xml └── test │ ├── java │ └── org │ │ └── apache │ │ └── flink │ │ └── table │ │ └── ml │ │ └── lib │ │ └── tensorflow │ │ ├── KafkaSourceSinkTest.java │ │ ├── SourceSinkTest.java │ │ └── InputOutputTest.java │ └── python │ └── test.py ├── doc ├── design.bmp ├── design.pptx ├── github │ ├── Issue6 │ │ ├── Issue6.bmp │ │ ├── Issue6.pptx │ │ └── Issue6_PR.bmp │ ├── [Issue] Simplify the ExampleCoding configuration process.md │ ├── [Issue] Exception when there is only InputTfExampleConfig but no OutputTfExampleConfig.md │ ├── Comment to [Issue 2] flink-ai-extended adapter flink ml pipeline.md │ └── [Issue] The streaming inference result needs to wait until the next query to write to sink.md └── deprecated │ ├── Flink-AI-Extended Integration design.md │ └── About StreamExeEnv in AI-Extended Issue.md ├── TextSummarization-On-Flink.iml ├── .gitignore ├── data └── cnn-dailymail │ ├── download_data.sh │ ├── LICENSE.md │ ├── README.md │ └── make_datafiles.py ├── log └── download_model.sh ├── pom.xml └── README.md /src/main/python/pointer-generator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/design.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/design.bmp -------------------------------------------------------------------------------- /doc/design.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/design.pptx -------------------------------------------------------------------------------- /doc/github/Issue6/Issue6.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6.bmp -------------------------------------------------------------------------------- /doc/github/Issue6/Issue6.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6.pptx -------------------------------------------------------------------------------- /doc/github/Issue6/Issue6_PR.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6_PR.bmp -------------------------------------------------------------------------------- /src/main/java/me/littlebo/SysUtils.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | public class SysUtils { 4 | public static String getProjectRootDir() { 5 | return System.getProperty("user.dir"); 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /TextSummarization-On-Flink.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | 3 | /data/cnn-dailymail/** 4 | /data/raw 5 | !/data/cnn-dailymail/LICENSE.md 6 | !/data/cnn-dailymail/make_datafiles.py 7 | !/data/cnn-dailymail/README.md 8 | !/data/cnn-dailymail/url_lists 9 | !/data/cnn-dailymail/download_data.sh 10 | 11 | /pointer-generator/*.pyc 12 | 13 | /log/** 14 | !/log/download_model.sh 15 | 16 | /temp/** 17 | /venv/** 18 | **/.idea/** -------------------------------------------------------------------------------- /src/main/python/pointer-generator/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_summarization.py --mode=train --data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/chunked/train_* --vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab --log_root=/Users/bodeng/TextSummarization-On-Flink/log --exp_name=pretrained_model_tf1.2.1 --max_enc_steps=400 --max_dec_steps=100 --coverage=1 3 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/run_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_summarization.py --mode=decode --data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/cnn_stories_test/0* --vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab --log_root=/Users/bodeng/TextSummarization-On-Flink/log --exp_name=pretrained_model_tf1.2.1 --max_enc_steps=400 --max_dec_steps=100 --coverage=1 --single_pass=1 --inference=1 3 | -------------------------------------------------------------------------------- /data/cnn-dailymail/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # script to download processed data from Google Drive 4 | # 5 | # not guaranteed to work indefinitely 6 | # taken from Stack Overflow answer: 7 | # http://stackoverflow.com/a/38937732/7002068 8 | 9 | #gURL='https://drive.google.com/uc?id=0BzQ6rtO2VN95a0c3TlZCWkl3aU0&export=download' 10 | 11 | # match id, more than 26 word characters 12 | #ggID=$(echo "$gURL" | egrep -o '(\w|-){26,}') 13 | 14 | ggID='0BzQ6rtO2VN95a0c3TlZCWkl3aU0' 15 | 16 | ggURL='https://drive.google.com/uc?export=download' 17 | gURL="${ggURL}&id=${ggID}" 18 | 19 | curl -sc /tmp/gcokie "${ggURL}&id=${ggID}" >/dev/null 20 | getcode="$(awk '/_warning_/ {print $NF}' /tmp/gcokie)" 21 | 22 | cmd='curl --insecure -C - -LOJb 
/tmp/gcokie "${ggURL}&confirm=${getcode}&id=${ggID}"' 23 | echo -e "Downloading from "$gURL"...\n" 24 | eval $cmd 25 | 26 | # unzip data file 27 | unzip finished_files.zip 28 | rm finished_files.zip -------------------------------------------------------------------------------- /log/download_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # script to download pretrained model from Google Drive 4 | # 5 | # not guaranteed to work indefinitely 6 | # taken from Stack Overflow answer: 7 | # http://stackoverflow.com/a/38937732/7002068 8 | 9 | #gURL='https://drive.google.com/uc?id=0B7pQmm-OfDv7ZUhHZm9ZWEZidDg&export=download' 10 | 11 | # match id, more than 26 word characters 12 | #ggID=$(echo "$gURL" | egrep -o '(\w|-){26,}') 13 | 14 | ggID='0B7pQmm-OfDv7ZUhHZm9ZWEZidDg' 15 | 16 | ggURL='https://drive.google.com/uc?export=download' 17 | gURL="${ggURL}&id=${ggID}" 18 | 19 | curl -sc /tmp/gcokie "${ggURL}&id=${ggID}" >/dev/null 20 | getcode="$(awk '/_warning_/ {print $NF}' /tmp/gcokie)" 21 | 22 | cmd='curl --insecure -C - -LOJb /tmp/gcokie "${ggURL}&confirm=${getcode}&id=${ggID}"' 23 | echo -e "Downloading from "$gURL"...\n" 24 | eval $cmd 25 | 26 | # unzip data file 27 | unzip pretrained_model_tf1.2.1.zip 28 | rm pretrained_model_tf1.2.1.zip -------------------------------------------------------------------------------- /src/main/resources/log4j2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainOutputCols.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * An interface for classes with a parameter specifying the names of multiple output columns. 9 | * @param the actual type of this WithParams, as the return type of setter 10 | */ 11 | public interface HasTrainOutputCols extends WithParams { 12 | ParamInfo TRAIN_OUTPUT_COLS = ParamInfoFactory 13 | .createParamInfo("trainOutputCols", String[].class) 14 | .setDescription("Names of the output columns for train processing") 15 | .setRequired() 16 | .build(); 17 | 18 | default String[] getTrainOutputCols() { 19 | return get(TRAIN_OUTPUT_COLS); 20 | } 21 | 22 | default T setTrainOutputCols(String... value) { 23 | return set(TRAIN_OUTPUT_COLS, value); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainSelectedCols.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * An interface for classes with a parameter specifying the name of multiple selected input columns. 
9 | * @param the actual type of this WithParams, as the return type of setter 10 | */ 11 | public interface HasTrainSelectedCols extends WithParams { 12 | ParamInfo TRAIN_SELECTED_COLS = ParamInfoFactory 13 | .createParamInfo("trainSelectedCols", String[].class) 14 | .setDescription("Names of the columns used for train processing") 15 | .setRequired() 16 | .build(); 17 | 18 | default String[] getTrainSelectedCols() { 19 | return get(TRAIN_SELECTED_COLS); 20 | } 21 | 22 | default T setTrainSelectedCols(String... value) { 23 | return set(TRAIN_SELECTED_COLS, value); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceOutputCols.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * An interface for classes with a parameter specifying the names of multiple output columns. 9 | * @param the actual type of this WithParams, as the return type of setter 10 | */ 11 | public interface HasInferenceOutputCols extends WithParams { 12 | ParamInfo INFERENCE_OUTPUT_COLS = ParamInfoFactory 13 | .createParamInfo("inferenceOutputCols", String[].class) 14 | .setDescription("Names of the output columns for inference processing") 15 | .setRequired() 16 | .build(); 17 | 18 | default String[] getInferenceOutputCols() { 19 | return get(INFERENCE_OUTPUT_COLS); 20 | } 21 | 22 | default T setInferenceOutputCols(String... value) { 23 | return set(INFERENCE_OUTPUT_COLS, value); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceSelectedCols.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * An interface for classes with a parameter specifying the name of multiple selected input columns. 9 | * @param the actual type of this WithParams, as the return type of setter 10 | */ 11 | public interface HasInferenceSelectedCols extends WithParams { 12 | ParamInfo INFERENCE_SELECTED_COLS = ParamInfoFactory 13 | .createParamInfo("inferenceSelectedCols", String[].class) 14 | .setDescription("Names of the columns used for inference processing") 15 | .setRequired() 16 | .build(); 17 | 18 | default String[] getInferenceSelectedCols() { 19 | return get(INFERENCE_SELECTED_COLS); 20 | } 21 | 22 | default T setInferenceSelectedCols(String... 
value) { 23 | return set(INFERENCE_SELECTED_COLS, value); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /data/cnn-dailymail/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Abi See 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainOutputTypes.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import com.alibaba.flink.ml.operator.util.DataTypes; 4 | import org.apache.flink.ml.api.misc.param.ParamInfo; 5 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 6 | import org.apache.flink.ml.api.misc.param.WithParams; 7 | 8 | /** 9 | * An interface for classes with a parameter specifying the types of multiple output columns. 10 | * @param the actual type of this WithParams, as the return type of setter 11 | */ 12 | public interface HasTrainOutputTypes extends WithParams { 13 | ParamInfo TRAIN_OUTPUT_TYPES = ParamInfoFactory 14 | .createParamInfo("trainOutputTypes", DataTypes[].class) 15 | .setDescription("TypeInformation of output columns for train processing") 16 | .setRequired() 17 | .build(); 18 | 19 | default DataTypes[] getTrainOutputTypes() { 20 | return get(TRAIN_OUTPUT_TYPES); 21 | } 22 | 23 | default T setTrainOutputTypes(DataTypes[] outputTypes) { 24 | return set(TRAIN_OUTPUT_TYPES, outputTypes); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceOutputTypes.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import com.alibaba.flink.ml.operator.util.DataTypes; 4 | import org.apache.flink.ml.api.misc.param.ParamInfo; 5 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 6 | import org.apache.flink.ml.api.misc.param.WithParams; 7 | 8 | /** 9 | * An interface for classes with a parameter specifying the types of multiple output columns. 
10 | * @param the actual type of this WithParams, as the return type of setter 11 | */ 12 | public interface HasInferenceOutputTypes extends WithParams { 13 | ParamInfo INFERENCE_OUTPUT_TYPES = ParamInfoFactory 14 | .createParamInfo("inferenceOutputTypes", DataTypes[].class) 15 | .setDescription("TypeInformation of output columns for inference processing") 16 | .setRequired() 17 | .build(); 18 | 19 | default DataTypes[] getInferenceOutputTypes() { 20 | return get(INFERENCE_OUTPUT_TYPES); 21 | } 22 | 23 | default T setInferenceOutputTypes(DataTypes[] outputTypes) { 24 | return set(INFERENCE_OUTPUT_TYPES, outputTypes); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/me/littlebo/MessageSerializationSchema.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | import org.apache.flink.api.common.serialization.SerializationSchema; 4 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException; 5 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; 6 | import org.apache.flink.types.Row; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | 10 | public class MessageSerializationSchema implements SerializationSchema { 11 | private ObjectMapper objectMapper; 12 | private Logger logger = LoggerFactory.getLogger(MessageSerializationSchema.class); 13 | 14 | @Override 15 | public byte[] serialize(Row row) { 16 | if(objectMapper == null) { 17 | objectMapper = new ObjectMapper(); 18 | } 19 | try { 20 | Message message = new Message((String)row.getField(0), (String)row.getField(1), 21 | (String)row.getField(2), (String)row.getField(3)); 22 | return objectMapper.writeValueAsString(message).getBytes(); 23 | } catch (JsonProcessingException e) { 24 | logger.error("Failed to parse JSON", e); 25 | } 26 | return new byte[0]; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/inspect_checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. 
Run like this: 3 | python inspect_checkpoint.py model.12345 4 | """ 5 | 6 | import tensorflow as tf 7 | import sys 8 | import numpy as np 9 | 10 | 11 | if __name__ == '__main__': 12 | if len(sys.argv) != 2: 13 | raise Exception("Usage: python inspect_checkpoint.py \nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.") 14 | file_name = sys.argv[1] 15 | reader = tf.train.NewCheckpointReader(file_name) 16 | var_to_shape_map = reader.get_variable_to_shape_map() 17 | 18 | finite = [] 19 | all_infnan = [] 20 | some_infnan = [] 21 | 22 | for key in sorted(var_to_shape_map.keys()): 23 | tensor = reader.get_tensor(key) 24 | if np.all(np.isfinite(tensor)): 25 | finite.append(key) 26 | else: 27 | if not np.any(np.isfinite(tensor)): 28 | all_infnan.append(key) 29 | else: 30 | some_infnan.append(key) 31 | 32 | print "\nFINITE VARIABLES:" 33 | for key in finite: print key 34 | 35 | print "\nVARIABLES THAT ARE ALL INF/NAN:" 36 | for key in all_infnan: print key 37 | 38 | print "\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:" 39 | for key in some_infnan: print key 40 | 41 | print "" 42 | if not all_infnan and not some_infnan: 43 | print "CHECK PASSED: checkpoint contains no inf/NaN values" 44 | else: 45 | print "CHECK FAILED: checkpoint contains some inf/NaN values" 46 | -------------------------------------------------------------------------------- /src/main/java/me/littlebo/MessageDeserializationSchema.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | import org.apache.flink.api.common.serialization.DeserializationSchema; 4 | import org.apache.flink.api.common.typeinfo.TypeInformation; 5 | import org.apache.flink.api.common.typeinfo.Types; 6 | import org.apache.flink.api.java.typeutils.RowTypeInfo; 7 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; 8 | import org.apache.flink.types.Row; 9 | 10 | import java.io.IOException; 11 | import java.util.concurrent.atomic.AtomicInteger; 12 | 13 | public class MessageDeserializationSchema implements DeserializationSchema { 14 | private static ObjectMapper objectMapper = new ObjectMapper(); 15 | private AtomicInteger counter = new AtomicInteger(); 16 | private Integer maxCount; 17 | 18 | public MessageDeserializationSchema(int maxCount) { 19 | this.maxCount = maxCount; 20 | } 21 | 22 | @Override 23 | public Row deserialize(byte[] bytes) throws IOException { 24 | Message message = objectMapper.readValue(bytes, Message.class); 25 | Row row = new Row(4); 26 | row.setField(0, message.getUuid()); 27 | row.setField(1, message.getArticle()); 28 | row.setField(2, message.getSummary()); 29 | row.setField(3, message.getReference()); 30 | counter.incrementAndGet(); 31 | return row; 32 | } 33 | 34 | @Override 35 | public boolean isEndOfStream(Row row) { 36 | if (counter.get() > maxCount) { 37 | return true; 38 | } 39 | return false; 40 | } 41 | 42 | @Override 43 | public TypeInformation getProducedType() { 44 | return new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING,Types.STRING); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/me/littlebo/Message.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; 4 | import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize; 5 | import org.apache.flink.types.Row; 6 | 7 | 8 | import java.io.Serializable; 9 | 10 | @JsonSerialize 11 | public class Message implements Serializable { 12 | @JsonProperty("uuid") 13 | private String uuid; 14 | @JsonProperty("article") 15 | private String article; 16 | @JsonProperty("summary") 17 | private String summary; 18 | @JsonProperty("reference") 19 | private String reference; 20 | 21 | public Message() { 22 | } 23 | 24 | public Message(String uuid, String article, String summary, String reference) { 25 | this.uuid = uuid; 26 | this.article = article; 27 | this.summary = summary; 28 | this.reference = reference; 29 | } 30 | 31 | public String getUuid() { 32 | return uuid; 33 | } 34 | 35 | public String getArticle() { 36 | return article; 37 | } 38 | 39 | public String getSummary() { 40 | return summary; 41 | } 42 | 43 | public String getReference() { 44 | return reference; 45 | } 46 | 47 | public void setUuid(String uuid) { 48 | this.uuid = uuid; 49 | } 50 | 51 | public void setArticle(String article) { 52 | this.article = article; 53 | } 54 | 55 | public void setSummary(String summary) { 56 | this.summary = summary; 57 | } 58 | 59 | public void setReference(String reference) { 60 | this.reference = reference; 61 | } 62 | 63 | public static Row toRow(Message message) { 64 | Row row = new Row(4); 65 | row.setField(0, message.uuid); 66 | row.setField(1, message.article); 67 | row.setField(2, message.summary); 68 | row.setField(3, message.reference); 69 | return row; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasClusterConfig.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * Parameters for cluster configuration, including: 9 | * 1. zookeeper address 10 | * 2. worker number 11 | * 3. 
ps number 12 | * @param the actual type of this WithParams, as the return type of setter 13 | */ 14 | public interface HasClusterConfig extends WithParams { 15 | ParamInfo ZOOKEEPER_CONNECT_STR = ParamInfoFactory 16 | .createParamInfo("zookeeper_connect_str", String.class) 17 | .setDescription("zookeeper address to connect") 18 | .setRequired() 19 | .setHasDefaultValue("127.0.0.1:2181").build(); 20 | ParamInfo WORKER_NUM = ParamInfoFactory 21 | .createParamInfo("worker_num", Integer.class) 22 | .setDescription("worker number") 23 | .setRequired() 24 | .setHasDefaultValue(1).build(); 25 | ParamInfo PS_NUM = ParamInfoFactory 26 | .createParamInfo("ps_num", Integer.class) 27 | .setDescription("ps number") 28 | .setRequired() 29 | .setHasDefaultValue(0).build(); 30 | 31 | default String getZookeeperConnStr() { 32 | return get(ZOOKEEPER_CONNECT_STR); 33 | } 34 | 35 | default T setZookeeperConnStr(String zookeeperConnStr) { 36 | return set(ZOOKEEPER_CONNECT_STR, zookeeperConnStr); 37 | } 38 | 39 | default int getWorkerNum() { 40 | return get(WORKER_NUM); 41 | } 42 | 43 | default T setWorkerNum(int workerNum) { 44 | return set(WORKER_NUM, workerNum); 45 | } 46 | 47 | default int getPsNum() { 48 | return get(PS_NUM); 49 | } 50 | 51 | default T setPsNum(int psNum) { 52 | return set(PS_NUM, psNum); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /doc/github/[Issue] Simplify the ExampleCoding configuration process.md: -------------------------------------------------------------------------------- 1 | ## [[Issue](https://github.com/alibaba/flink-ai-extended/issues/17)] Simplify the ExampleCoding configuration process 2 | 3 | Now the process of configuring ExampleCoding is cumbersome. As a common configuration, I think we can add some tool interfaces to help users simply configure. 4 | The following is the general configuration process under the current version: 5 | 6 | ```java 7 | // configure encode example coding 8 | String strInput = ExampleCodingConfig.createExampleConfigStr(encodeNames, encodeTypes, 9 | entryType, entryClass); 10 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 11 | config.getProperties().put(MLConstants.ENCODING_CLASS, 12 | ExampleCoding.class.getCanonicalName()); 13 | 14 | // configure decode example coding 15 | String strOutput = ExampleCodingConfig.createExampleConfigStr(decodeNames, decodeTypes, 16 | entryType, entryClass); 17 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput); 18 | config.getProperties().put(MLConstants.DECODING_CLASS, 19 | ExampleCoding.class.getCanonicalName()); 20 | ``` 21 | 22 | It can be seen that the user needs to know the column names, types, various constants, etc. of the field when configuring example coding. In fact, it can be encapsulated, and the user only needs to provide the input and output table schema to complete the configuration. For example: 23 | 24 | ```java 25 | ExampleCodingConfigUtil.configureExampleCoding(tfConfig, inputSchema, outputSchema, 26 | ExampleCodingConfig.ObjectType.ROW, Row.class); 27 | ``` 28 | 29 | In the current version, the data type that TF can accept is defined in **DataTypes** in Flink-AI-Extended project, and the data type of Flink Table field is defined in **TypeInformation** in Flink project and some basic types such as *BasicTypeInfo* and *BasicArrayTypeInfo* are implemented. But the problem is that the basic types of **DataTypes** and **TypeInformation** are not one-to-one correspondence. 
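To make the mismatch concrete, a conversion helper along the lines of option 1 below might look like the following minimal sketch (the `DataTypes` constant names here are assumptions for illustration only, not values confirmed from Flink-AI-Extended):

```java
// Hypothetical sketch of a best-effort TypeInformation -> DataTypes mapping;
// types without a counterpart have to be rejected explicitly.
public static DataTypes typeInformationToDataTypes(TypeInformation<?> type) {
    if (BasicTypeInfo.STRING_TYPE_INFO.equals(type)) {
        return DataTypes.STRING;
    } else if (BasicTypeInfo.INT_TYPE_INFO.equals(type)) {
        return DataTypes.INT_32;
    } else if (BasicTypeInfo.LONG_TYPE_INFO.equals(type)) {
        return DataTypes.INT_64;
    } else if (BasicTypeInfo.FLOAT_TYPE_INFO.equals(type)) {
        return DataTypes.FLOAT_32;
    } else if (BasicTypeInfo.DOUBLE_TYPE_INFO.equals(type)) {
        return DataTypes.FLOAT_64;
    }
    throw new UnsupportedOperationException("No DataTypes counterpart for " + type);
}
```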
30 | 31 | Therefore, if we want to encapsulate the ExampleCoding configuration process, we need to solve the problem that DataTypes is not compatible with TypeInformation. There are two options: 32 | 33 | 1. Provide a method for converting DataTypes and TypeInformation. Although most of the commonly used types can be matched, they are not completely one-to-one correspondence, so there are some problems that cannot be converted. 34 | 2. Discard DataTypes and use TypeInformation directly in Flink-AI-Extended. DataTypes is just a simple enumeration type that only participates in the identification of data types. TypeInformation can also achieve the same functionality. 35 | 36 | Solution 1 is relatively simple to implement and easy to be compatible, but solution 2 is better in the long run. 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/flink_writer.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class AbstractWriter(object): 7 | def build_graph(self): 8 | return 9 | 10 | def write_result(self, sess, uuid, article, abstract, reference): 11 | return 12 | 13 | def close(self, sess): 14 | return 15 | 16 | 17 | class FlinkWriter(AbstractWriter): 18 | def __init__(self, tf_context): 19 | self._context = tf_context 20 | 21 | def build_graph(self): 22 | self._write_feed = tf.placeholder(dtype=tf.string) 23 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed]) 24 | 25 | def write_result(self, sess, uuid, article, abstract, reference): 26 | example = tf.train.Example(features=tf.train.Features( 27 | feature={ 28 | 'uuid': tf.train.Feature(bytes_list=tf.train.BytesList(value=[uuid])), 29 | 'article': tf.train.Feature(bytes_list=tf.train.BytesList(value=[article])), 30 | 'summary': tf.train.Feature(bytes_list=tf.train.BytesList(value=[abstract])), 31 | 'reference': tf.train.Feature(bytes_list=tf.train.BytesList(value=[reference])), 32 | } 33 | )) 34 | sess.run(self.write_op, feed_dict={self._write_feed: example.SerializeToString()}) 35 | 36 | def close(self, sess): 37 | sess.run(self._close_op) 38 | 39 | 40 | class AbstractFlinkWriter(object): 41 | __metaclass__ = ABCMeta 42 | 43 | def __init__(self, context): 44 | """ 45 | Initialize the writer 46 | :param context: TFContext 47 | """ 48 | self._context = context 49 | self._build_graph() 50 | 51 | def _build_graph(self): 52 | self._write_feed = tf.placeholder(dtype=tf.string) 53 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed]) 54 | 55 | @abstractmethod 56 | def _example(self, results): 57 | # TODO: Implement this function using context automatically 58 | """ 59 | Encode the results tensor to `tf.train.Example` 60 | 61 | Examples: 62 | ``` 63 | example = tf.train.Example(features=tf.train.Features( 64 | feature={ 65 | 'predict_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[results[0]])), 66 | 'label_org': tf.train.Feature(int64_list=tf.train.Int64List(value=[results[1]])), 67 | 'abstract': tf.train.Feature(bytes_list=tf.train.BytesList(value=[results[2]])), 68 | } 69 | )) 70 | return example 71 | ``` 72 | 73 | :param results: the result list of `Tensor` to write into Flink 74 | :return: An Example. 75 | """ 76 | pass 77 | 78 | def write_result(self, sess, results): 79 | """ 80 | Encode the results tensor and write to Flink. 
81 | """ 82 | sess.run(self.write_op, feed_dict={self._write_feed: self._example(results).SerializeToString()}) 83 | 84 | def close(self, sess): 85 | """ 86 | close writer 87 | :param sess: the session to execute operator 88 | """ 89 | sess.run(self._close_op) 90 | -------------------------------------------------------------------------------- /doc/deprecated/Flink-AI-Extended Integration design.md: -------------------------------------------------------------------------------- 1 | ## Flink-AI-Extended Integration design 2 | 3 | ### TFEstimator: 4 | 5 | ```java 6 | /** 7 | * A general TensorFlow Estimator who need general TFConfig, 8 | * the fit() function is actually executed by Flink-AI-Extended 9 | */ 10 | public interface TFEstimator, M extends TFModel> 11 | implements WithTFConfigParams extends Estimator { 12 | } 13 | ``` 14 | 15 | ### TFModel: 16 | 17 | ```java 18 | /** 19 | * A general TensorFlow Model who need general TFConfig, 20 | * the transform() function is actually executed by Flink-AI-Extended 21 | */ 22 | public interface TFModel> 23 | implements WithTFConfigParams extends Model { 24 | } 25 | ``` 26 | 27 | ### WithTFConfigParams: 28 | 29 | ```java 30 | /** 31 | * Calling python scripts via Flink-AI-Extended generally requires the following parameters: 32 | * 1. Zookeeper address: String 33 | * 2. Python scripts path: String[], the first file will be the entry 34 | * 3. Worker number: Int 35 | * 4. Ps number: Int 36 | * 5. Map function: String, the function in entry file to be called 37 | * 6. Virtual environment path: String, could be null 38 | */ 39 | public interface WithTFConfigParams> extends WithZookeeperAddressParams, 40 | WithPythonScriptsParams, 41 | WithWorkerNumParams, 42 | WithPsNumParams, 43 | WithMapFunctionParams, 44 | WithEnvironmentPathParams { 45 | 46 | } 47 | ``` 48 | 49 | ### TFSummaryEstimator: 50 | 51 | ```java 52 | /** 53 | * A specific document summarization estimator, 54 | * the training process needs to specify the input article column 55 | * and the input abstract column, 56 | * in addition to the hyperparameter for TensorFlow model 57 | */ 58 | public class TFSummaryEstimator implements TFEstimator, 59 | WithInputArticleCol, 60 | WithInputAbstractCol, 61 | WithHyperParams { 62 | private Params params = new Params(); 63 | 64 | @Override 65 | public TFSummaryModel fit(TableEnviroment tEnv, Table input) { 66 | //TODO: call training process through Flink-AI-Extended 67 | } 68 | 69 | @Override 70 | public Params getParams() { 71 | return params; 72 | } 73 | } 74 | ``` 75 | 76 | ### TFSummaryModel: 77 | 78 | ```java 79 | /** 80 | * A specific document summarization estimator, 81 | * the inference process needs to specify the input article column 82 | * and the output abstract column, 83 | * in addition to the model path and hyperparameter for TensorFlow model 84 | */ 85 | public class TFSummaryModel implements TFModel, 86 | WithInputArticleCol, 87 | WithOutputAbstractCol, 88 | WithModelPathParams, 89 | WithHyperParams { 90 | private Params params = new Params(); 91 | 92 | @Override 93 | public Table transform(TableEnvironment tEnv, Table input) { 94 | //TODO: call inference process through Flink-AI-Extended 95 | } 96 | 97 | @Override 98 | public Params getParams() { 99 | return params; 100 | } 101 | } 102 | ``` 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /doc/deprecated/About StreamExeEnv in AI-Extended Issue.md: 
-------------------------------------------------------------------------------- 1 | ## About StreamExeEnv in AI-Extended Issue 2 | 3 | While integrating Flink-AI-Extended with the ML pipeline framework, I ran into the problem of a missing **StreamExecutionEnvironment**. 4 | 5 | The interfaces of the current ML pipeline framework are defined as follows: both fit() and transform() only receive the **TableEnv** that the input table is bound to, and that **TableEnv** is created by an upper-level **ExecutionEnvironment** (either a **StreamExeEnv** or a **BatchExeEnv**). The problem is that all train and inference methods in the current Flink-AI-Extended are based on **StreamExeEnv** and require both a **StreamExeEnv** and a **TableEnv** to be passed in, while a **StreamExeEnv** does not necessarily exist under the ML pipeline framework, so the existing framework cannot be supported. 6 | 7 | ```java 8 | public interface Transformer> extends PipelineStage { 9 | Table transform(TableEnvironment tEnv, Table input); 10 | } 11 | 12 | public interface Model> extends Transformer { 13 | } 14 | 15 | public interface Estimator, M extends Model> 16 | extends PipelineStage { 17 | M fit(TableEnvironment tEnv, Table input); 18 | } 19 | ``` 20 | 21 | I checked the references to **StreamExeEnv** inside Flink-AI-Extended; as shown below, it is mainly used in two places: 22 | 23 | 1. Registering python cached files. 24 | 25 | ```java 26 | package com.alibaba.flink.ml.operator.util; 27 | public class PythonFileUtil { 28 | //... 29 | private static List registerPythonLibFiles( 30 | StreamExecutionEnvironment env, 31 | String... userPyLibs) throws IOException { 32 | Tuple2, List> tuple2 = convertFiles(userPyLibs); 33 | Map files = (Map)tuple2.f0; 34 | files.forEach((name, uri) -> { 35 | env.registerCachedFile(uri.toString(), name); 36 | }); 37 | return (List)tuple2.f1; 38 | } 39 | //... 40 | } 41 | ``` 42 | 43 | 2. Adding a source for the jobs of different roles. 44 | 45 | ```java 46 | package com.alibaba.flink.ml.operator.client; 47 | public class RoleUtils { 48 | //... 49 | public static DataStream addRole( 50 | StreamExecutionEnvironment streamEnv, 51 | ExecutionMode mode, 52 | DataStream input, 53 | MLConfig mlConfig, 54 | TypeInformation outTI, 55 | BaseRole role) { 56 | if (null != input) { 57 | mlConfig.addProperty("job_has_input", "true"); 58 | } 59 | 60 | TypeInformation workerTI = outTI == null ? DUMMY_TI : outTI; 61 | DataStream worker = null; 62 | int workerParallelism = (Integer)mlConfig.getRoleParallelismMap().get(role.name()); 63 | if (input == null) { 64 | worker = streamEnv 65 | .addSource(NodeSource.createSource(mode, role, mlConfig, workerTI)) 66 | .setParallelism(workerParallelism) 67 | .name(role.name()); 68 | } else { 69 | FlatMapFunction flatMapper = new MLFlatMapOp(mode, role, mlConfig, 70 | input.getType(), workerTI); 71 | worker = input 72 | .flatMap(flatMapper) 73 | .setParallelism(workerParallelism) 74 | .name(role.name()); 75 | } 76 | 77 | if (outTI == null && worker != null) { 78 | worker 79 | .addSink(new DummySink()) 80 | .setParallelism(workerParallelism) 81 | .name("Dummy sink"); 82 | } 83 | 84 | return worker; 85 | } 86 | //... 87 | } 88 | ``` 89 | 90 | I am not very familiar with ExecutionEnvironment, so I am not sure whether these two parts could be implemented in another way, or whether there is some other way to solve the current incompatibility. Any advice would be much appreciated! -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainPythonConfig.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * Parameters for python configuration in training process, including: 9 | * 1. paths of python scripts 10 | * 2. entry function in main python file 11 | * 3. key to get hyper parameter in python 12 | * 4.
hyper parameter for python 13 | * 5. virtual environment path 14 | * @param the actual type of this WithParams, as the return type of setter 15 | */ 16 | public interface HasTrainPythonConfig extends WithParams { 17 | ParamInfo TRAIN_SCRIPTS = ParamInfoFactory 18 | .createParamInfo("train_scripts", String[].class) 19 | .setDescription("python scripts path, the first file entry, for train processing") 20 | .setRequired().build(); 21 | ParamInfo TRAIN_MAP_FUNC = ParamInfoFactory 22 | .createParamInfo("train_map_func", String.class) 23 | .setDescription("the entry function in entry file to be called, for train processing") 24 | .setRequired().build(); 25 | ParamInfo TRAIN_HYPER_PARAMS_KEY = ParamInfoFactory 26 | .createParamInfo("train_hyper_params_key", String.class) 27 | .setDescription("the key name to get hyper params from context inf TensorFlow, for train processing") 28 | .setRequired() 29 | .build(); 30 | ParamInfo TRAIN_HYPER_PARAMS = ParamInfoFactory 31 | .createParamInfo("train_hyper_params", String[].class) 32 | .setDescription("hyper params for TensorFlow, each param format is '--param1=value1', for train processing") 33 | .setRequired() 34 | .setHasDefaultValue(new String[]{}).build(); 35 | ParamInfo TRAIN_ENV_PATH = ParamInfoFactory 36 | .createParamInfo("train_env_path", String.class) 37 | .setDescription("virtual environment path, for train processing") 38 | .setOptional() 39 | .setHasDefaultValue(null).build(); 40 | 41 | default String[] getTrainScripts() { 42 | return get(TRAIN_SCRIPTS); 43 | } 44 | 45 | default T setTrainScripts(String[] scripts) { 46 | return set(TRAIN_SCRIPTS, scripts); 47 | } 48 | 49 | default String getTrainMapFunc() { 50 | return get(TRAIN_MAP_FUNC); 51 | } 52 | 53 | default T setTrainMapFunc(String mapFunc) { 54 | return set(TRAIN_MAP_FUNC, mapFunc); 55 | } 56 | 57 | default String getTrainHyperParamsKey() { 58 | return get(TRAIN_HYPER_PARAMS_KEY); 59 | } 60 | 61 | default T setTrainHyperParamsKey(String key) { 62 | return set(TRAIN_HYPER_PARAMS_KEY, key); 63 | } 64 | 65 | default String[] getTrainHyperParams() { 66 | return get(TRAIN_HYPER_PARAMS); 67 | } 68 | 69 | default T setTrainHyperParams(String[] hyperParams) { 70 | return set(TRAIN_HYPER_PARAMS, hyperParams); 71 | } 72 | 73 | default String getTrainEnvPath() { 74 | return get(TRAIN_ENV_PATH); 75 | } 76 | 77 | default T setTrainEnvPath(String envPath) { 78 | return set(TRAIN_ENV_PATH, envPath); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferencePythonConfig.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.param; 2 | 3 | import org.apache.flink.ml.api.misc.param.ParamInfo; 4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 5 | import org.apache.flink.ml.api.misc.param.WithParams; 6 | 7 | /** 8 | * Parameters for python configuration in inference process, including: 9 | * 1. paths of python scripts 10 | * 2. entry function in main python file 11 | * 3. key to get hyper parameter in python 12 | * 4. hyper parameter for python 13 | * 5. 
virtual environment path 14 | * @param the actual type of this WithParams, as the return type of setter 15 | */ 16 | public interface HasInferencePythonConfig extends WithParams { 17 | ParamInfo INFERENCE_SCRIPTS = ParamInfoFactory 18 | .createParamInfo("inference_scripts", String[].class) 19 | .setDescription("python scripts path, the first file entry, for inference processing") 20 | .setRequired().build(); 21 | ParamInfo INFERENCE_MAP_FUNC = ParamInfoFactory 22 | .createParamInfo("inference_map_func", String.class) 23 | .setDescription("the entry function in entry file to be called, for inference processing") 24 | .setRequired().build(); 25 | ParamInfo INFERENCE_HYPER_PARAMS_KEY = ParamInfoFactory 26 | .createParamInfo("inference_hyper_params_key", String.class) 27 | .setDescription("the key name to get hyper params from context inf TensorFlow, for inference processing") 28 | .setRequired() 29 | .build(); 30 | ParamInfo INFERENCE_HYPER_PARAMS = ParamInfoFactory 31 | .createParamInfo("inference_hyper_params", String[].class) 32 | .setDescription("hyper params for TensorFlow, each param format is '--param1=value1', for inference processing") 33 | .setRequired() 34 | .setHasDefaultValue(new String[]{}).build(); 35 | ParamInfo INFERENCE_ENV_PATH = ParamInfoFactory 36 | .createParamInfo("inference_env_path", String.class) 37 | .setDescription("virtual environment path, for inference processing") 38 | .setOptional() 39 | .setHasDefaultValue(null).build(); 40 | 41 | default String[] getInferenceScripts() { 42 | return get(INFERENCE_SCRIPTS); 43 | } 44 | 45 | default T setInferenceScripts(String[] scripts) { 46 | return set(INFERENCE_SCRIPTS, scripts); 47 | } 48 | 49 | default String getInferenceMapFunc() { 50 | return get(INFERENCE_MAP_FUNC); 51 | } 52 | 53 | default T setInferenceMapFunc(String mapFunc) { 54 | return set(INFERENCE_MAP_FUNC, mapFunc); 55 | } 56 | 57 | default String getInferenceHyperParamsKey() { 58 | return get(INFERENCE_HYPER_PARAMS_KEY); 59 | } 60 | 61 | default T setInferenceHyperParamsKey(String key) { 62 | return set(INFERENCE_HYPER_PARAMS_KEY, key); 63 | } 64 | 65 | default String[] getInferenceHyperParams() { 66 | return get(INFERENCE_HYPER_PARAMS); 67 | } 68 | 69 | default T setInferenceHyperParams(String[] hyperParams) { 70 | return set(INFERENCE_HYPER_PARAMS, hyperParams); 71 | } 72 | 73 | default String getInferenceEnvPath() { 74 | return get(INFERENCE_ENV_PATH); 75 | } 76 | 77 | default T setInferenceEnvPath(String envPath) { 78 | return set(INFERENCE_ENV_PATH, envPath); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /data/cnn-dailymail/README.md: -------------------------------------------------------------------------------- 1 | This code produces the non-anonymized version of the CNN / Daily Mail summarization dataset, as used in the ACL 2017 paper *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/pdf/1704.04368.pdf)*. It processes the dataset into the binary format expected by the [code](https://github.com/abisee/pointer-generator) for the Tensorflow model. 2 | 3 | **Python 3 version**: This code is in Python 2. If you want a Python 3 version, see [@becxer's fork](https://github.com/becxer/cnn-dailymail/). 4 | 5 | # Option 1: download the processed data 6 | User @JafferWilson has provided the processed data, which you can download [here](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail). 
(See discussion [here](https://github.com/abisee/cnn-dailymail/issues/9) about why we do not provide it ourselves). 7 | 8 | # Option 2: process the data yourself 9 | 10 | ## 1. Download data 11 | Download and unzip the `stories` directories from [here](http://cs.nyu.edu/~kcho/DMQA/) for both CNN and Daily Mail. 12 | 13 | **Warning:** These files contain a few (114, in a dataset of over 300,000) examples for which the article text is missing - see for example `cnn/stories/72aba2f58178f2d19d3fae89d5f3e9a4686bc4bb.story`. The [Tensorflow code](https://github.com/abisee/pointer-generator) has been updated to discard these examples. 14 | 15 | ## 2. Download Stanford CoreNLP 16 | We will need Stanford CoreNLP to tokenize the data. Download it [here](https://stanfordnlp.github.io/CoreNLP/) and unzip it. Then add the following command to your bash_profile: 17 | ``` 18 | export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar 19 | ``` 20 | replacing `/path/to/` with the path to where you saved the `stanford-corenlp-full-2016-10-31` directory. You can check if it's working by running 21 | ``` 22 | echo "Please tokenize this text." | java edu.stanford.nlp.process.PTBTokenizer 23 | ``` 24 | You should see something like: 25 | ``` 26 | Please 27 | tokenize 28 | this 29 | text 30 | . 31 | PTBTokenizer tokenized 5 tokens at 68.97 tokens per second. 32 | ``` 33 | ## 3. Process into .bin and vocab files 34 | Run 35 | ``` 36 | python make_datafiles.py /path/to/cnn/stories /path/to/dailymail/stories 37 | ``` 38 | replacing `/path/to/cnn/stories` with the path to where you saved the `cnn/stories` directory that you downloaded; similarly for `dailymail/stories`. 39 | 40 | This script will do several things: 41 | * The directories `cnn_stories_tokenized` and `dm_stories_tokenized` will be created and filled with tokenized versions of `cnn/stories` and `dailymail/stories`. This may take some time. ***Note**: you may see several `Untokenizable:` warnings from Stanford Tokenizer. These seem to be related to Unicode characters in the data; so far it seems OK to ignore them.* 42 | * For each of the url lists `all_train.txt`, `all_val.txt` and `all_test.txt`, the corresponding tokenized stories are read from file, lowercased and written to serialized binary files `train.bin`, `val.bin` and `test.bin`. These will be placed in the newly-created `finished_files` directory. This may take some time. 43 | * Additionally, a `vocab` file is created from the training data. This is also placed in `finished_files`. 44 | * Lastly, `train.bin`, `val.bin` and `test.bin` will be split into chunks of 1000 examples per chunk. These chunked files will be saved in `finished_files/chunked` as e.g. `train_000.bin`, `train_001.bin`, ..., `train_287.bin`. This should take a few seconds. You can use either the single files or the chunked files as input to the Tensorflow code (see considerations [here](https://github.com/abisee/cnn-dailymail/issues/3)). 
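If you need to consume the chunked files from the JVM side, the record layout is an 8-byte length prefix (written by Python's `struct.pack('q', ...)`, i.e. native byte order, little-endian on the usual x86 machines) followed by a serialized `tf.train.Example` of that length. The following minimal Java sketch, which is hypothetical and not part of the original pipeline, simply counts the records in one chunk under that assumption:

```java
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class ChunkReader {
    public static void main(String[] args) throws IOException {
        int count = 0;
        try (DataInputStream in = new DataInputStream(new FileInputStream(args[0]))) {
            byte[] lenBytes = new byte[8];
            while (in.read(lenBytes) == 8) {
                // little-endian 8-byte record length, as written by struct.pack('q', ...)
                long len = ByteBuffer.wrap(lenBytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
                byte[] payload = new byte[(int) len];
                in.readFully(payload); // serialized tf.train.Example bytes
                count++;
            }
        }
        System.out.println(count + " examples");
    }
}
```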
45 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | bodeng 8 | abstract-on-flink 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 20.0 13 | 2.7.1 14 | 1.8 15 | 4.12 16 | 1.8.0 17 | 1.9-SNAPSHOT 18 | 0.1.0 19 | 2.3.0 20 | 21 | 22 | 23 | 24 | 25 | org.apache.flink 26 | flink-ml-lib 27 | ${flink-ml.version} 28 | 29 | 30 | org.apache.flink 31 | flink-table-api-java 32 | 33 | 34 | 35 | 36 | com.alibaba.flink.ml 37 | flink-ml-tensorflow 38 | ${flink-ml-tensorflow.version} 39 | 40 | 41 | org.apache.flink 42 | flink-connector-kafka_2.11 43 | ${flink.version} 44 | 45 | 46 | 47 | org.apache.kafka 48 | kafka_2.11 49 | ${kafka.version} 50 | 51 | 52 | org.apache.kafka 53 | kafka-clients 54 | ${kafka.version} 55 | 56 | 57 | org.apache.curator 58 | curator-framework 59 | ${curator.version} 60 | 61 | 62 | org.apache.curator 63 | curator-test 64 | ${curator.version} 65 | 66 | 67 | com.google.guava 68 | guava 69 | 70 | 71 | 72 | 73 | com.google.guava 74 | guava 75 | ${guava.version} 76 | 77 | 78 | 79 | junit 80 | junit 81 | ${junit.version} 82 | 83 | 84 | 85 | 86 | 87 | 88 | org.apache.maven.plugins 89 | maven-compiler-plugin 90 | 3.1 91 | 92 | ${java.version} 93 | ${java.version} 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/TFModel.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow; 2 | 3 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 4 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 5 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 6 | import com.alibaba.flink.ml.util.MLConstants; 7 | import org.apache.flink.ml.api.core.Model; 8 | import org.apache.flink.ml.api.misc.param.Params; 9 | 10 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 11 | import org.apache.flink.table.api.Table; 12 | import org.apache.flink.table.api.TableEnvironment; 13 | import org.apache.flink.table.api.TableSchema; 14 | import org.apache.flink.table.api.java.StreamTableEnvironment; 15 | import org.apache.flink.table.ml.lib.tensorflow.param.*; 16 | import org.apache.flink.table.ml.lib.tensorflow.util.CodingUtils; 17 | import org.apache.flink.types.Row; 18 | import org.slf4j.Logger; 19 | import org.slf4j.LoggerFactory; 20 | 21 | import java.util.HashMap; 22 | import java.util.Map; 23 | 24 | /** 25 | * A general TensorFlow model implemented by Flink-AI-Extended, 26 | * is usually generated by an {@link TFEstimator} 27 | * when {@link TFEstimator#fit(TableEnvironment, Table)} is invoked. 
28 | */ 29 | public class TFModel implements Model, HasClusterConfig, HasInferencePythonConfig, 30 | HasInferenceSelectedCols, HasInferenceOutputCols, HasInferenceOutputTypes { 31 | private static final Logger LOG = LoggerFactory.getLogger(TFModel.class); 32 | private Params params = new Params(); 33 | 34 | protected Table configureInputTable(Table rawTable) { 35 | return rawTable.select(String.join(",", getInferenceSelectedCols())); 36 | } 37 | 38 | protected TableSchema configureOutputSchema() { 39 | return new TableSchema(getInferenceOutputCols(), 40 | CodingUtils.dataTypesListToTypeInformation(getInferenceOutputTypes())); 41 | } 42 | 43 | protected TFConfig configureTFConfig() { 44 | Map prop = new HashMap<>(); 45 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 46 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, getZookeeperConnStr()); 47 | prop.put(getInferenceHyperParamsKey(), String.join(" ", getInferenceHyperParams())); 48 | return new TFConfig(getWorkerNum(), getPsNum(), prop, getInferenceScripts(), getInferenceMapFunc(), getInferenceEnvPath()); 49 | } 50 | 51 | protected void configureExampleCoding(TFConfig config, TableSchema inputSchema, TableSchema outputSchema) { 52 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 53 | } 54 | 55 | @Override 56 | public Table transform(TableEnvironment tableEnvironment, Table table) { 57 | StreamExecutionEnvironment streamEnv; 58 | try { 59 | // TODO: [hack] transform table to dataStream to get StreamExecutionEnvironment 60 | if (tableEnvironment instanceof StreamTableEnvironment) { 61 | StreamTableEnvironment streamTableEnvironment = (StreamTableEnvironment)tableEnvironment; 62 | streamEnv = streamTableEnvironment.toAppendStream(table, Row.class).getExecutionEnvironment(); 63 | } else { 64 | throw new RuntimeException("Unsupported TableEnvironment, please use StreamTableEnvironment"); 65 | } 66 | 67 | // Select the necessary columns according to "SelectedCols" 68 | Table inputTable = configureInputTable(table); 69 | // Construct the output schema according on the "OutputCols" and "OutputTypes" 70 | TableSchema outputSchema = configureOutputSchema(); 71 | // Create a basic TFConfig according to "ClusterConfig" and "PythonConfig" 72 | TFConfig config = configureTFConfig(); 73 | // Configure the row encoding and decoding base on input & output schema 74 | configureExampleCoding(config, inputTable.getSchema(), outputSchema); 75 | // transform the table by TF which implemented by AI-Extended 76 | Table outputTable = TFUtils.inference(streamEnv, tableEnvironment, inputTable, config, outputSchema); 77 | return outputTable; 78 | } catch (Exception e) { 79 | throw new RuntimeException(e); 80 | } 81 | } 82 | 83 | @Override 84 | public Params getParams() { 85 | return params; 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/test/java/org/apache/flink/table/ml/lib/tensorflow/KafkaSourceSinkTest.java: -------------------------------------------------------------------------------- 1 | //package org.apache.flink.table.ml.lib.tensorflow; 2 | // 3 | //import kafka.server.KafkaConfig; 4 | //import kafka.server.KafkaServerStartable; 5 | //import org.apache.curator.test.TestingServer; 6 | //import org.apache.kafka.clients.consumer.Consumer; 7 | //import org.apache.kafka.clients.consumer.ConsumerConfig; 8 | //import org.apache.kafka.clients.producer.KafkaProducer; 9 | //import 
org.apache.kafka.clients.producer.Producer; 10 | //import org.apache.kafka.clients.producer.ProducerRecord; 11 | //import org.junit.After; 12 | //import org.junit.Before; 13 | //import org.junit.Test; 14 | // 15 | //import java.io.IOException; 16 | //import java.util.HashMap; 17 | //import java.util.Map; 18 | //import java.util.Properties; 19 | // 20 | //import static com.sun.org.apache.xerces.internal.util.PropertyState.is; 21 | // 22 | //public class KafkaSourceSinkTest { 23 | // public static final String topic = "topic1-" + System.currentTimeMillis(); 24 | // 25 | // private KafkaTestFixture server; 26 | // private Producer producer; 27 | // private Consumer consumer; 28 | // private ConsumerConnector consumerConnector; 29 | // 30 | // @Before 31 | // public void setup() throws Exception { 32 | // server = new KafkaTestFixture(); 33 | // server.start(serverProperties()); 34 | // } 35 | // 36 | // @After 37 | // public void teardown() throws Exception { 38 | // producer.close(); 39 | // consumerConnector.shutdown(); 40 | // server.stop(); 41 | // } 42 | // 43 | // @Test 44 | // public void shouldWriteThenRead() throws Exception { 45 | // 46 | // //Create a consumer 47 | // ConsumerIterator it = buildConsumer(KafkaSourceSinkTest.topic); 48 | // 49 | // //Create a producer 50 | // producer = new KafkaProducer(producerProps()); 51 | // 52 | // //send a message 53 | // producer.send(new ProducerRecord(KafkaSourceSinkTest.topic, "message")).get(); 54 | // 55 | // //read it back 56 | // MessageAndMetadata messageAndMetadata = it.next(); 57 | // String value = messageAndMetadata.message(); 58 | // assertThat(value, is("message")); 59 | // } 60 | // 61 | // private ConsumerIterator buildConsumer(String topic) { 62 | // Properties props = consumerProperties(); 63 | // 64 | // Map topicCountMap = new HashMap(); 65 | // topicCountMap.put(topic, 1); 66 | // ConsumerConfig consumerConfig = new ConsumerConfig(props); 67 | // consumerConnector = Consumer.createJavaConsumerConnector(consumerConfig); 68 | // Map>> consumers = consumerConnector.createMessageStreams(topicCountMap, new StringDecoder(null), new StringDecoder(null)); 69 | // KafkaStream stream = consumers.get(topic).get(0); 70 | // return stream.iterator(); 71 | // } 72 | // 73 | // private Properties consumerProperties() { 74 | // Properties props = new Properties(); 75 | // props.put("zookeeper.connect", serverProperties().get("zookeeper.connect")); 76 | // props.put("group.id", "group1"); 77 | // props.put("auto.offset.reset", "smallest"); 78 | // return props; 79 | // } 80 | // 81 | // private Properties producerProps() { 82 | // Properties props = new Properties(); 83 | // props.put("bootstrap.servers", "localhost:9092"); 84 | // props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer"); 85 | // props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer"); 86 | // props.put("request.required.acks", "1"); 87 | // return props; 88 | // } 89 | // 90 | // private Properties serverProperties() { 91 | // Properties props = new Properties(); 92 | // props.put("zookeeper.connect", "localhost:2181"); 93 | // props.put("broker.id", "1"); 94 | // return props; 95 | // } 96 | // 97 | // private static class KafkaTestFixture { 98 | // private TestingServer zk; 99 | // private KafkaServerStartable kafka; 100 | // 101 | // public void start(Properties properties) throws Exception { 102 | // Integer port = getZkPort(properties); 103 | // zk = new TestingServer(port); 104 | // zk.start(); 105 | // 
106 | // KafkaConfig kafkaConfig = new KafkaConfig(properties); 107 | // kafka = new KafkaServerStartable(kafkaConfig); 108 | // kafka.startup(); 109 | // } 110 | // 111 | // public void stop() throws IOException { 112 | // kafka.shutdown(); 113 | // zk.stop(); 114 | // zk.close(); 115 | // } 116 | // 117 | // private int getZkPort(Properties properties) { 118 | // String url = (String) properties.get("zookeeper.connect"); 119 | // String port = url.split(":")[1]; 120 | // return Integer.valueOf(port); 121 | // } 122 | // } 123 | //} 124 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains some utility functions""" 18 | import tensorflow as tf 19 | import time 20 | import os 21 | FLAGS = tf.app.flags.FLAGS 22 | 23 | def get_config(): 24 | """Returns config for tf.session""" 25 | config = tf.ConfigProto(allow_soft_placement=True) 26 | config.gpu_options.allow_growth=True 27 | return config 28 | 29 | def load_ckpt(saver, sess, ckpt_dir="train"): 30 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name.""" 31 | while True: 32 | try: 33 | latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None 34 | ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir) 35 | ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename) 36 | tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path) 37 | saver.restore(sess, ckpt_state.model_checkpoint_path) 38 | return ckpt_state.model_checkpoint_path 39 | except: 40 | tf.logging.info("Failed to load checkpoint from %s. 
Sleeping for %i secs...", ckpt_dir, 10) 41 | time.sleep(10) 42 | 43 | 44 | def bin2txt(data_path, finished_dir): 45 | import glob 46 | import json 47 | import struct 48 | import nltk 49 | import data 50 | from tensorflow.core.example import example_pb2 51 | from collections import OrderedDict 52 | 53 | def example_generator(file_path): 54 | with open(file_path, 'rb') as reader: 55 | while True: 56 | len_bytes = reader.read(8) 57 | if not len_bytes: 58 | break # finished reading this file 59 | str_len = struct.unpack('q', len_bytes)[0] 60 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 61 | yield example_pb2.Example.FromString(example_str) 62 | 63 | def text_generator(example_generator): 64 | while True: 65 | e = example_generator.next() # e is a tf.Example 66 | try: 67 | article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files 68 | abstract_text = e.features.feature['abstract'].bytes_list.value[0] # the abstract text was saved under the key 'abstract' in the data files 69 | except ValueError: 70 | tf.logging.error('Failed to get article or abstract from example') 71 | continue 72 | if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1 73 | tf.logging.warning('Found an example with empty article text. Skipping it.') 74 | else: 75 | yield (article_text, abstract_text) 76 | counter = 0 77 | filelist = glob.glob(data_path) # get the list of datafiles 78 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty 79 | filelist = sorted(filelist) 80 | for f in filelist: 81 | input_gen = text_generator(example_generator(f)) 82 | with open(finished_dir + '/' + f.split('/')[-1].replace('.bin', '.txt'), 'w') as writer: 83 | while True: 84 | try: 85 | (article, abstract) = input_gen.next() # read the next example from file. article and abstract are both strings. 86 | abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the and tags in abstract to get a list of sentences. 
87 | abstract = ' '.join(abstract_sentences) 88 | abstract_sentences = [' '.join(nltk.word_tokenize(sent)) for sent in nltk.sent_tokenize(abstract)] 89 | 90 | json_format = json.dumps(OrderedDict([('uuid', 'uuid-%i' % counter), ('article', article), ('summary', ''), ('reference', abstract)])) 91 | counter += 1 92 | writer.write(json_format) 93 | writer.write('\n') 94 | except StopIteration: # if there are no more examples: 95 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.") 96 | break 97 | except UnicodeDecodeError: 98 | continue 99 | print "finished " + f 100 | 101 | # data_path = '/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/chunked/train_*' 102 | # finished_dir = '/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/json' 103 | # bin2txt(data_path, finished_dir) 104 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/TFEstimator.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow; 2 | 3 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 4 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 5 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 6 | import com.alibaba.flink.ml.util.MLConstants; 7 | import org.apache.flink.ml.api.core.Estimator; 8 | import org.apache.flink.ml.api.misc.param.Params; 9 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 10 | import org.apache.flink.table.api.Table; 11 | import org.apache.flink.table.api.TableEnvironment; 12 | import org.apache.flink.table.api.TableSchema; 13 | import org.apache.flink.table.api.java.StreamTableEnvironment; 14 | import org.apache.flink.table.ml.lib.tensorflow.param.*; 15 | import org.apache.flink.table.ml.lib.tensorflow.util.CodingUtils; 16 | import org.apache.flink.types.Row; 17 | 18 | 19 | import java.util.HashMap; 20 | import java.util.Map; 21 | 22 | /** 23 | * A general TensorFlow estimator implemented by Flink-AI-Extended, 24 | * responsible for training and generating TensorFlow models. 
25 | */ 26 | public class TFEstimator implements Estimator, HasClusterConfig, 27 | HasTrainPythonConfig, HasInferencePythonConfig, 28 | HasTrainSelectedCols, HasTrainOutputCols, HasTrainOutputTypes, 29 | HasInferenceSelectedCols, HasInferenceOutputCols, HasInferenceOutputTypes { 30 | private Params params = new Params(); 31 | 32 | protected Table configureInputTable(Table rawTable) { 33 | if (getTrainSelectedCols().length == 0) { 34 | return null; 35 | } else { 36 | return rawTable.select(String.join(",", getTrainSelectedCols())); 37 | } 38 | } 39 | protected TableSchema configureOutputSchema() { 40 | if (getTrainOutputCols().length == 0) { 41 | return null; 42 | } else { 43 | return new TableSchema(getTrainOutputCols(), 44 | CodingUtils.dataTypesListToTypeInformation(getTrainOutputTypes())); 45 | } 46 | } 47 | 48 | protected TFConfig configureTFConfig() { 49 | Map prop = new HashMap<>(); 50 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 51 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, getZookeeperConnStr()); 52 | prop.put(getTrainHyperParamsKey(), String.join(" ", getTrainHyperParams())); 53 | return new TFConfig(getWorkerNum(), getPsNum(), prop, getTrainScripts(), getTrainMapFunc(), getTrainEnvPath()); 54 | } 55 | 56 | protected void configureExampleCoding(TFConfig config, TableSchema inputSchema, TableSchema outputSchema) { 57 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 58 | } 59 | 60 | @Override 61 | public TFModel fit(TableEnvironment tableEnvironment, Table table) { 62 | StreamExecutionEnvironment streamEnv; 63 | try { 64 | // TODO: [hack] transform table to dataStream to get StreamExecutionEnvironment 65 | if (tableEnvironment instanceof StreamTableEnvironment) { 66 | StreamTableEnvironment streamTableEnvironment = (StreamTableEnvironment)tableEnvironment; 67 | streamEnv = streamTableEnvironment.toAppendStream(table, Row.class).getExecutionEnvironment(); 68 | } else { 69 | throw new RuntimeException("Unsupported TableEnvironment, please use StreamTableEnvironment"); 70 | } 71 | // Select the necessary columns according to "SelectedCols" 72 | Table inputTable = configureInputTable(table); 73 | TableSchema inputSchema = null; 74 | if (inputTable != null) { 75 | inputSchema = inputTable.getSchema(); 76 | } 77 | // Construct the output schema according to the "OutputCols" and "OutputTypes" 78 | TableSchema outputSchema = configureOutputSchema(); 79 | // Create a basic TFConfig according to "ClusterConfig" and "PythonConfig" 80 | TFConfig config = configureTFConfig(); 81 | // Configure the row encoding and decoding based on the input & output schema 82 | configureExampleCoding(config, inputSchema, outputSchema); 83 | // Train the model with TensorFlow, as implemented by Flink-AI-Extended 84 | Table outputTable = TFUtils.train(streamEnv, tableEnvironment, inputTable, config, outputSchema); 85 | // Construct the trained model from the inference-related config 86 | TFModel model = new TFModel() 87 | .setZookeeperConnStr(getZookeeperConnStr()) 88 | .setWorkerNum(getWorkerNum()) 89 | .setPsNum(getPsNum()) 90 | .setInferenceScripts(getInferenceScripts()) 91 | .setInferenceMapFunc(getInferenceMapFunc()) 92 | .setInferenceHyperParams(getInferenceHyperParams()) 93 | .setInferenceEnvPath(getInferenceEnvPath()) 94 | .setInferenceSelectedCols(getInferenceSelectedCols()) 95 | .setInferenceOutputCols(getInferenceOutputCols()) 96 | .setInferenceOutputTypes(getInferenceOutputTypes()); 97 | 98 | return model; 99 | 
} catch (Exception e) { 100 | throw new RuntimeException(e); 101 | } 102 | } 103 | 104 | @Override 105 | public Params getParams() { 106 | return params; 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import tensorflow as tf 5 | from tensorflow.python.training.summary_io import SummaryWriterCache 6 | import numpy as np 7 | 8 | 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | 12 | class FlinkTestTrainer(object): 13 | def __init__(self, hps, batcher, sess_config, server_target): 14 | self._hps = hps 15 | # self._model = model 16 | self._batcher = batcher 17 | 18 | train_dir = os.path.join("temp", "test") 19 | if not os.path.exists(train_dir): 20 | os.makedirs(train_dir) 21 | global_step = tf.contrib.framework.get_or_create_global_step() 22 | 23 | # self._model.build_graph() 24 | # scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=3)) 25 | # hooks = None 26 | # if hps.num_steps > 0: 27 | # hooks = [tf.train.StopAtStepHook(num_steps=hps.num_steps)] 28 | 29 | self._sess = tf.train.MonitoredTrainingSession(master=server_target, 30 | is_chief=True, 31 | config=sess_config, 32 | checkpoint_dir=train_dir,) 33 | # save_checkpoint_secs=60, 34 | # save_summaries_secs=60) 35 | # hooks=hooks, 36 | # scaffold=scaffold) 37 | 38 | # self._summary_writer = SummaryWriterCache.get(train_dir) 39 | tf.logging.info("Created session.") 40 | 41 | def stop(self): 42 | self._sess.close() 43 | 44 | def train(self): 45 | tf.logging.info("starting run_training") 46 | try: 47 | while not self._sess.should_stop(): 48 | tf.logging.info('getting next batch...') 49 | batch = self._batcher.next_batch() 50 | tf.logging.info(batch.dec_padding_mask.shape) 51 | tf.logging.info(batch.target_batch.shape) 52 | except KeyboardInterrupt: 53 | tf.logging.info("Caught keyboard interrupt on worker. 
Stopping supervisor...") 54 | finally: 55 | self.stop() 56 | 57 | 58 | class FlinkTrainer(object): 59 | def __init__(self, hps, model, batcher, sess_config, server_target): 60 | self._hps = hps 61 | self._model = model 62 | self._batcher = batcher 63 | 64 | train_dir = os.path.join(FLAGS.log_root, "train") 65 | if not os.path.exists(train_dir): os.makedirs(train_dir) 66 | 67 | self._model.build_graph() 68 | # self._batcher.build_graph() 69 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=3)) 70 | hooks = None 71 | if hps.num_steps > 0: 72 | hooks = [tf.train.StopAtStepHook(num_steps=hps.num_steps)] 73 | 74 | self._sess = tf.train.MonitoredTrainingSession(master=server_target, 75 | is_chief=True, 76 | config=sess_config, 77 | checkpoint_dir=train_dir, 78 | # save_checkpoint_secs=60, 79 | # save_summaries_secs=60, 80 | scaffold=scaffold, 81 | hooks=hooks) 82 | 83 | self._summary_writer = SummaryWriterCache.get(train_dir) 84 | tf.logging.info("Created session.") 85 | 86 | def stop(self): 87 | self._sess.close() 88 | 89 | def train(self): 90 | tf.logging.info("starting run_training") 91 | try: 92 | while not self._sess.should_stop(): 93 | tf.logging.info('getting next batch...') 94 | # batch = self._batcher.next_batch() 95 | batch = self._batcher.next_batch() 96 | tf.logging.info(batch.target_batch.shape) 97 | # print batch 98 | tf.logging.info('running training step...') 99 | t0 = time.time() 100 | results = self._model.run_train_step(self._sess, batch) 101 | t1 = time.time() 102 | tf.logging.info('seconds for training step: %.3f', t1 - t0) 103 | 104 | loss = results['loss'] 105 | tf.logging.info('loss: %f', loss) # print the loss to screen 106 | 107 | if not np.isfinite(loss): 108 | raise Exception("Loss is not finite. Stopping.") 109 | 110 | if FLAGS.coverage: 111 | coverage_loss = results['coverage_loss'] 112 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen 113 | 114 | # get the summaries and iteration number so we can write summaries to tensorboard 115 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer 116 | train_step = results['global_step'] # we need this to update our running average loss 117 | 118 | self._summary_writer.add_summary(summaries, train_step) # write the summaries 119 | # if train_step % 100 == 0: # flush the summary writer every so often 120 | # self._summary_writer.flush() 121 | except KeyboardInterrupt: 122 | tf.logging.info("Caught keyboard interrupt on worker. 
Stopping supervisor...") 123 | # self.stop() 124 | finally: 125 | self.stop() 126 | -------------------------------------------------------------------------------- /src/test/java/org/apache/flink/table/ml/lib/tensorflow/SourceSinkTest.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow; 2 | 3 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 4 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 5 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 6 | import com.alibaba.flink.ml.util.MLConstants; 7 | import org.apache.curator.test.TestingServer; 8 | import org.apache.flink.api.common.restartstrategy.RestartStrategies; 9 | import org.apache.flink.api.common.state.ListState; 10 | import org.apache.flink.api.common.state.ListStateDescriptor; 11 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; 12 | import org.apache.flink.api.common.typeinfo.TypeInformation; 13 | import org.apache.flink.api.common.typeinfo.Types; 14 | import org.apache.flink.api.java.typeutils.RowTypeInfo; 15 | import org.apache.flink.runtime.state.FunctionInitializationContext; 16 | import org.apache.flink.runtime.state.FunctionSnapshotContext; 17 | import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; 18 | import org.apache.flink.streaming.api.datastream.DataStream; 19 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 20 | import org.apache.flink.streaming.api.functions.source.SourceFunction; 21 | import org.apache.flink.table.api.Table; 22 | import org.apache.flink.table.api.TableSchema; 23 | import org.apache.flink.table.api.java.StreamTableEnvironment; 24 | import org.apache.flink.table.ml.lib.tensorflow.util.CodingUtils; 25 | import org.apache.flink.types.Row; 26 | import org.junit.Test; 27 | import org.slf4j.Logger; 28 | import org.slf4j.LoggerFactory; 29 | 30 | import java.sql.Timestamp; 31 | import java.util.HashMap; 32 | import java.util.Map; 33 | 34 | public class SourceSinkTest { 35 | private static final String projectDir = System.getProperty("user.dir"); 36 | private static final String ZookeeperConn = "127.0.0.1:2181"; 37 | private static final String[] Scripts = {projectDir + "/src/test/python/test.py"}; 38 | private static final int WorkerNum = 1; 39 | private static final int PsNum = 0; 40 | 41 | @Test 42 | public void testSourceSink() throws Exception { 43 | TestingServer server = new TestingServer(2181, true); 44 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 45 | streamEnv.setRestartStrategy(RestartStrategies.noRestart()); 46 | 47 | DataStream sourceStream = streamEnv.addSource( 48 | new DummyTimedSource(20, 5), new RowTypeInfo(Types.STRING)).setParallelism(1); 49 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 50 | Table input = tableEnv.fromDataStream(sourceStream, "input"); 51 | TFConfig config = createTFConfig("test_source_sink"); 52 | TableSchema inputSchema = new TableSchema(new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 53 | TableSchema outputSchema = new TableSchema(new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 54 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 55 | Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema); 56 | tableEnv.toAppendStream(output, Row.class) 57 | .map(r -> 
"[Sink][" + new Timestamp(System.currentTimeMillis()) + "]finish " + r.getField(0) + "\n") 58 | .print().setParallelism(1); 59 | 60 | streamEnv.execute(); 61 | server.stop(); 62 | } 63 | 64 | private TFConfig createTFConfig(String mapFunc) { 65 | Map prop = new HashMap<>(); 66 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 67 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, ZookeeperConn); 68 | return new TFConfig(WorkerNum, PsNum, prop, Scripts, mapFunc, null); 69 | } 70 | 71 | private static class DummyTimedSource implements SourceFunction, CheckpointedFunction { 72 | public static final Logger LOG = LoggerFactory.getLogger(DummyTimedSource.class); 73 | private long count = 0L; 74 | private long MAX_COUNT; 75 | private long INTERVAL; 76 | private volatile boolean isRunning = true; 77 | 78 | private transient ListState checkpointedCount; 79 | 80 | public DummyTimedSource(long maxCount, long interval) { 81 | this.MAX_COUNT = maxCount; 82 | this.INTERVAL = interval; 83 | } 84 | 85 | @Override 86 | public void run(SourceContext ctx) throws Exception { 87 | while (isRunning && count < MAX_COUNT) { 88 | // this synchronized block ensures that state checkpointing, 89 | // internal state updates and emission of elements are an atomic operation 90 | synchronized (ctx.getCheckpointLock()) { 91 | Row row = new Row(1); 92 | row.setField(0, String.format("data-%d", count)); 93 | System.out.println("[Source][" + new Timestamp(System.currentTimeMillis()) + "]produce " + row.getField(0)); 94 | ctx.collect(row); 95 | count++; 96 | Thread.sleep(INTERVAL * 1000); 97 | } 98 | } 99 | } 100 | 101 | @Override 102 | public void cancel() { 103 | isRunning = false; 104 | } 105 | 106 | @Override 107 | public void snapshotState(FunctionSnapshotContext context) throws Exception { 108 | this.checkpointedCount.clear(); 109 | this.checkpointedCount.add(count); 110 | } 111 | 112 | @Override 113 | public void initializeState(FunctionInitializationContext context) throws Exception { 114 | this.checkpointedCount = context 115 | .getOperatorStateStore() 116 | .getListState(new ListStateDescriptor<>("count", Long.class)); 117 | 118 | if (context.isRestored()) { 119 | for (Long count : this.checkpointedCount.get()) { 120 | this.count = count; 121 | } 122 | } 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/test/java/org/apache/flink/table/ml/lib/tensorflow/InputOutputTest.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow; 2 | 3 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 4 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 5 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 6 | import com.alibaba.flink.ml.util.MLConstants; 7 | import org.apache.curator.test.TestingServer; 8 | import org.apache.flink.api.common.restartstrategy.RestartStrategies; 9 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; 10 | import org.apache.flink.api.common.typeinfo.TypeInformation; 11 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 12 | import org.apache.flink.table.api.Table; 13 | import org.apache.flink.table.api.TableSchema; 14 | import org.apache.flink.table.api.java.StreamTableEnvironment; 15 | import org.apache.flink.table.ml.lib.tensorflow.util.CodingUtils; 16 | import org.apache.flink.types.Row; 17 | import org.junit.Test; 18 | 19 | import java.util.ArrayList; 20 | import 
java.util.HashMap; 21 | import java.util.List; 22 | import java.util.Map; 23 | 24 | public class InputOutputTest { 25 | private static final String projectDir = System.getProperty("user.dir"); 26 | private static final String ZookeeperConn = "127.0.0.1:2181"; 27 | private static final String[] Scripts = {projectDir + "/src/test/python/test.py"}; 28 | private static final int WorkerNum = 1; 29 | private static final int PsNum = 0; 30 | 31 | @Test 32 | public void testExampleCoding() throws Exception { 33 | TestingServer server = new TestingServer(2181, true); 34 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 35 | streamEnv.setRestartStrategy(restartStrategy()); 36 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 37 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createDummyData()).setParallelism(1), "input"); 38 | TFConfig config = createTFConfig("test_example_coding"); 39 | TableSchema inputSchema = new TableSchema(new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 40 | TableSchema outputSchema = new TableSchema(new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 41 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 42 | Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema); 43 | tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 44 | streamEnv.execute(); 45 | server.stop(); 46 | } 47 | 48 | @Test 49 | public void testExampleCodingWithoutEncode() throws Exception { 50 | TestingServer server = new TestingServer(2181, true); 51 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 52 | streamEnv.setRestartStrategy(restartStrategy()); 53 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 54 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createDummyData()).setParallelism(1), "input"); 55 | TFConfig config = createTFConfig("test_example_coding_without_encode"); 56 | // TableSchema inputSchema = new TableSchema(new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 57 | TableSchema outputSchema = new TableSchema(new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 58 | TableSchema inputSchema = null; 59 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 60 | Table output = TFUtils.inference(streamEnv, tableEnv, null, config, outputSchema); 61 | tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 62 | streamEnv.execute(); 63 | server.stop(); 64 | } 65 | 66 | @Test 67 | public void testExampleCodingWithoutDecode() throws Exception { 68 | TestingServer server = new TestingServer(2181, true); 69 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 70 | streamEnv.setRestartStrategy(restartStrategy()); 71 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 72 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createDummyData()).setParallelism(1), "input"); 73 | TFConfig config = createTFConfig("test_example_coding_without_decode"); 74 | TableSchema inputSchema = new TableSchema(new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 75 | // TableSchema outputSchema = new TableSchema(new String[]{"output"}, new 
TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 76 | TableSchema outputSchema = null; 77 | CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 78 | Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema); 79 | // tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 80 | streamEnv.execute(); 81 | server.stop(); 82 | } 83 | 84 | @Test 85 | public void testExampleCodingWithNothing() throws Exception { 86 | TestingServer server = new TestingServer(2181, true); 87 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 88 | streamEnv.setRestartStrategy(restartStrategy()); 89 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 90 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createDummyData()).setParallelism(1), "input"); 91 | TFConfig config = createTFConfig("test_example_coding_with_nothing"); 92 | // TableSchema inputSchema = new TableSchema(new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 93 | // TableSchema outputSchema = new TableSchema(new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 94 | TableSchema inputSchema = null; 95 | TableSchema outputSchema = null; 96 | // CodingUtils.configureExampleCoding(config, inputSchema, outputSchema, ExampleCodingConfig.ObjectType.ROW, Row.class); 97 | Table output = TFUtils.inference(streamEnv, tableEnv, null, config, outputSchema); 98 | // tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 99 | streamEnv.execute(); 100 | server.stop(); 101 | } 102 | 103 | private List createDummyData() { 104 | List rows = new ArrayList<>(); 105 | for (int i = 0; i < 10; i++) { 106 | Row row = new Row(1); 107 | row.setField(0, String.format("data-%d", i)); 108 | rows.add(row); 109 | } 110 | return rows; 111 | } 112 | 113 | private TFConfig createTFConfig(String mapFunc) { 114 | Map prop = new HashMap<>(); 115 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 116 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, ZookeeperConn); 117 | return new TFConfig(WorkerNum, PsNum, prop, Scripts, mapFunc, null); 118 | } 119 | 120 | private RestartStrategies.RestartStrategyConfiguration restartStrategy() { 121 | // return RestartStrategies.fixedDelayRestart(2, Time.seconds(5)); 122 | return RestartStrategies.noRestart(); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /doc/github/[Issue] Exception when there is only InputTfExampleConfig but no OutputTfExampleConfig.md: -------------------------------------------------------------------------------- 1 | ## [[Issue](https://github.com/alibaba/flink-ai-extended/issues/10)] Exception when there is only InputTfExampleConfig but no OutputTfExampleConfig 2 | 3 | As follows, this is the normal configuration of ExampleCoding (that is, the format of Flink and Python for data transmission), but if only the encode part is configured without configuring the decode part, it will not be executed, and an exception will be thrown. vice versa. 
4 | 5 | ```java 6 | // configure encode example coding 7 | String strInput = ExampleCodingConfig.createExampleConfigStr(encodeNames, encodeTypes, 8 | entryType, entryClass); 9 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 10 | config.getProperties().put(MLConstants.ENCODING_CLASS, 11 | ExampleCoding.class.getCanonicalName()); 12 | 13 | // configure decode example coding 14 | String strOutput = ExampleCodingConfig.createExampleConfigStr(decodeNames, decodeTypes, 15 | entryType, entryClass); 16 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput); 17 | config.getProperties().put(MLConstants.DECODING_CLASS, 18 | ExampleCoding.class.getCanonicalName()); 19 | ``` 20 | 21 | This usage scenario is fairly common. For example, during training the user only needs to transfer data to TensorFlow without getting a table back, so there is only a Flink-to-TF encode phase and no TF-to-Flink decode phase. In that case it is natural for the user to set only the encoding-related configuration. 22 | **Therefore**, I want to be able to configure **only the encode** part **without** configuring the **decode** part, and vice versa. 23 | After investigation, the root cause lies in the method **ReflectUtil.createInstance(className, classes, objects)** in **CodingFactory.java**. This method creates an **ExampleCoding** instance according to **ENCODING_CLASS**. As defined in **ExampleCoding.java**, the constructor configures both inputConfig and outputConfig (even if one of them is not set), which results in a NullPointerException. 24 | The following is the exception information: 25 | 26 | ```java 27 | java.lang.reflect.InvocationTargetException 28 | at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) 29 | at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62) 30 | at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) 31 | at java.lang.reflect.Constructor.newInstance(Constructor.java:423) 32 | at com.alibaba.flink.ml.util.ReflectUtil.createInstance(ReflectUtil.java:36) 33 | at com.alibaba.flink.ml.coding.CodingFactory.getEncoding(CodingFactory.java:49) 34 | at com.alibaba.flink.ml.data.DataExchange.<init>(DataExchange.java:58) 35 | at com.alibaba.flink.ml.operator.ops.MLMapFunction.open(MLMapFunction.java:80) 36 | at com.alibaba.flink.ml.operator.ops.MLFlatMapOp.open(MLFlatMapOp.java:51) 37 | at org.apache.flink.api.common.functions.util.FunctionUtils.openFunction(FunctionUtils.java:36) 38 | at org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator.open(AbstractUdfStreamOperator.java:102) 39 | at org.apache.flink.streaming.api.operators.StreamFlatMap.open(StreamFlatMap.java:43) 40 | at org.apache.flink.streaming.runtime.tasks.StreamTask.openAllOperators(StreamTask.java:424) 41 | at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:290) 42 | at org.apache.flink.runtime.taskmanager.Task.run(Task.java:711) 43 | at java.lang.Thread.run(Thread.java:748) 44 | Caused by: java.lang.NullPointerException 45 | at com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig.fromJsonObject(ExampleCodingConfig.java:100) 46 | at com.alibaba.flink.ml.tensorflow.coding.ExampleCoding.<init>(ExampleCoding.java:57) 47 | ... 
16 more 48 | ``` 49 | 50 | The following is a detailed test case, Java code: 51 | 52 | ```java 53 | private static final String ZookeeperConn = "127.0.0.1:2181"; 54 | private static final String[] Scripts = {"test.py"}; 55 | private static final int WorkerNum = 1; 56 | private static final int PsNum = 0; 57 | 58 | @Test 59 | public void testExampleCodingWithoutDecode() throws Exception { 60 | TestingServer server = new TestingServer(2181, true); 61 | StreamExecutionEnvironment streamEnv = 62 | StreamExecutionEnvironment.createLocalEnvironment(1); 63 | streamEnv.setRestartStrategy(RestartStrategies.noRestart()); 64 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 65 | 66 | Table input = tableEnv 67 | .fromDataStream(streamEnv.fromCollection(createDummyData()), "input"); 68 | TableSchema inputSchema = 69 | new TableSchema(new String[]{"input"}, 70 | new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 71 | TableSchema outputSchema = null; 72 | 73 | TFConfig config = createTFConfig("test_example_coding_without_decode"); 74 | // configure encode coding 75 | String strInput = ExampleCodingConfig.createExampleConfigStr( 76 | new String[]{"input"}, new DataTypes[]{DataTypes.STRING}, 77 | ExampleCodingConfig.ObjectType.ROW, Row.class); 78 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 79 | config.getProperties().put(MLConstants.ENCODING_CLASS, 80 | ExampleCoding.class.getCanonicalName()); 81 | 82 | // run in python 83 | Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema); 84 | 85 | streamEnv.execute(); 86 | server.stop(); 87 | } 88 | 89 | private List createDummyData() { 90 | List rows = new ArrayList<>(); 91 | for (int i = 0; i < 10; i++) { 92 | Row row = new Row(1); 93 | row.setField(0, String.format("data-%d", i)); 94 | rows.add(row); 95 | } 96 | return rows; 97 | } 98 | 99 | private TFConfig createTFConfig(String mapFunc) { 100 | Map prop = new HashMap<>(); 101 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 102 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, ZookeeperConn); 103 | return new TFConfig(WorkerNum, PsNum, prop, Scripts, mapFunc, null); 104 | } 105 | ``` 106 | 107 | Python code: 108 | 109 | ```python 110 | import tensorflow as tf 111 | from flink_ml_tensorflow.tensorflow_context import TFContext 112 | 113 | 114 | class FlinkReader(object): 115 | def __init__(self, context, batch_size=1, features={'input': tf.FixedLenFeature([], tf.string)}): 116 | self._context = context 117 | self._batch_size = batch_size 118 | self._features = features 119 | self._build_graph() 120 | 121 | def _decode(self, features): 122 | return features['input'] 123 | 124 | def _build_graph(self): 125 | dataset = self._context.flink_stream_dataset() 126 | dataset = dataset.map(lambda record: tf.parse_single_example(record, features=self._features)) 127 | dataset = dataset.map(self._decode) 128 | dataset = dataset.batch(self._batch_size) 129 | iterator = dataset.make_one_shot_iterator() 130 | self._next_batch = iterator.get_next() 131 | 132 | def next_batch(self, sess): 133 | try: 134 | batch = sess.run(self._next_batch) 135 | return batch 136 | except tf.errors.OutOfRangeError: 137 | return None 138 | 139 | 140 | def test_example_coding_without_decode(context): 141 | tf_context = TFContext(context) 142 | if 'ps' == tf_context.get_role_name(): 143 | from time import sleep 144 | while True: 145 | sleep(1) 146 | else: 147 | index = tf_context.get_index() 148 | job_name = 
tf_context.get_role_name() 149 | cluster_json = tf_context.get_tf_cluster() 150 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 151 | 152 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 153 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 154 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 155 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 156 | reader = FlinkReader(tf_context) 157 | 158 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 159 | while True: 160 | batch = reader.next_batch(sess) 161 | tf.logging.info(str(batch)) 162 | if batch is None: 163 | break 164 | sys.stdout.flush() 165 | ``` 166 | 167 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/beam_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding""" 18 | 19 | import tensorflow as tf 20 | import numpy as np 21 | import data 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | 26 | class Hypothesis(object): 27 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis.""" 28 | 29 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage): 30 | """Hypothesis constructor. 31 | 32 | Args: 33 | tokens: List of integers. The ids of the tokens that form the summary so far. 34 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far. 35 | state: Current state of the decoder, a LSTMStateTuple. 36 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far. 37 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far. 38 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector. 39 | """ 40 | self.tokens = tokens 41 | self.log_probs = log_probs 42 | self.state = state 43 | self.attn_dists = attn_dists 44 | self.p_gens = p_gens 45 | self.coverage = coverage 46 | 47 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage): 48 | """Return a NEW hypothesis, extended with the information from the latest step of beam search. 49 | 50 | Args: 51 | token: Integer. Latest token produced by beam search. 52 | log_prob: Float. Log prob of the latest token. 53 | state: Current decoder state, a LSTMStateTuple. 54 | attn_dist: Attention distribution from latest step. 
Numpy array shape (attn_length). 55 | p_gen: Generation probability on latest step. Float. 56 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage. 57 | Returns: 58 | New Hypothesis for next step. 59 | """ 60 | return Hypothesis(tokens=self.tokens + [token], 61 | log_probs=self.log_probs + [log_prob], 62 | state=state, 63 | attn_dists=self.attn_dists + [attn_dist], 64 | p_gens=self.p_gens + [p_gen], 65 | coverage=coverage) 66 | 67 | @property 68 | def latest_token(self): 69 | return self.tokens[-1] 70 | 71 | @property 72 | def log_prob(self): 73 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far 74 | return sum(self.log_probs) 75 | 76 | @property 77 | def avg_log_prob(self): 78 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability) 79 | return self.log_prob / len(self.tokens) 80 | 81 | 82 | def run_beam_search(sess, model, vocab, batch): 83 | """Performs beam search decoding on the given example. 84 | 85 | Args: 86 | sess: a tf.Session 87 | model: a seq2seq model 88 | vocab: Vocabulary object 89 | batch: Batch object that is the same example repeated across the batch 90 | 91 | Returns: 92 | best_hyp: Hypothesis object; the best hypothesis found by beam search. 93 | """ 94 | # Run the encoder to get the encoder hidden states and decoder initial state 95 | enc_states, dec_in_state = model.run_encoder(sess, batch) 96 | # dec_in_state is a LSTMStateTuple 97 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 98 | 99 | # Initialize beam_size-many hyptheses 100 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)], 101 | log_probs=[0.0], 102 | state=dec_in_state, 103 | attn_dists=[], 104 | p_gens=[], 105 | coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length 106 | ) for _ in xrange(FLAGS.beam_size)] 107 | results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token) 108 | 109 | steps = 0 110 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size: 111 | latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis 112 | latest_tokens = [t if t in xrange(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in 113 | latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings 114 | states = [h.state for h in hyps] # list of current decoder states of the hypotheses 115 | prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None) 116 | 117 | # Run one step of the decoder to get the new info 118 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess, 119 | batch=batch, 120 | latest_tokens=latest_tokens, 121 | enc_states=enc_states, 122 | dec_init_states=states, 123 | prev_coverage=prev_coverage) 124 | 125 | # Extend each hypothesis and collect them all in all_hyps 126 | all_hyps = [] 127 | num_orig_hyps = 1 if steps == 0 else len( 128 | hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct. 
129 | for i in xrange(num_orig_hyps): 130 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], \ 131 | new_coverage[ 132 | i] # take the ith hypothesis and new decoder state info 133 | for j in xrange(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps: 134 | # Extend the ith hypothesis with the jth option 135 | new_hyp = h.extend(token=topk_ids[i, j], 136 | log_prob=topk_log_probs[i, j], 137 | state=new_state, 138 | attn_dist=attn_dist, 139 | p_gen=p_gen, 140 | coverage=new_coverage_i) 141 | all_hyps.append(new_hyp) 142 | 143 | # Filter and collect any hypotheses that have produced the end token. 144 | hyps = [] # will contain hypotheses for the next step 145 | for h in sort_hyps(all_hyps): # in order of most likely h 146 | if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached... 147 | # If this hypothesis is sufficiently long, put in results. Otherwise discard. 148 | if steps >= FLAGS.min_dec_steps: 149 | results.append(h) 150 | else: # hasn't reached stop token, so continue to extend this hypothesis 151 | hyps.append(h) 152 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: 153 | # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop. 154 | break 155 | 156 | steps += 1 157 | 158 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps 159 | 160 | if len( 161 | results) == 0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 162 | results = hyps 163 | 164 | # Sort hypotheses by average log probability 165 | hyps_sorted = sort_hyps(results) 166 | 167 | # Return the hypothesis with highest average log prob 168 | return hyps_sorted[0] 169 | 170 | 171 | def sort_hyps(hyps): 172 | """Return a list of Hypothesis objects, sorted by descending average log probability""" 173 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True) 174 | -------------------------------------------------------------------------------- /src/main/java/me/littlebo/Summarization.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | import com.alibaba.flink.ml.operator.coding.RowCSVCoding; 4 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 5 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 6 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCoding; 7 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 8 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig.ObjectType; 9 | import com.alibaba.flink.ml.tensorflow.util.TFConstants; 10 | import com.alibaba.flink.ml.util.MLConstants; 11 | import org.apache.curator.test.TestingServer; 12 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; 13 | import org.apache.flink.api.common.typeinfo.TypeInformation; 14 | import org.apache.flink.api.java.typeutils.RowTypeInfo; 15 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 16 | import org.apache.flink.table.api.Table; 17 | import org.apache.flink.table.api.TableSchema; 18 | import org.apache.flink.table.api.java.StreamTableEnvironment; 19 | import org.apache.flink.table.ml.lib.tensorflow.util.CodingUtils; 20 | import org.slf4j.Logger; 21 | import org.slf4j.LoggerFactory; 22 | 23 | import java.util.Arrays; 24 | import java.util.HashMap; 25 | import java.util.List; 26 | import java.util.Map; 27 | 28 | @Deprecated 29 | public 
class Summarization { 30 | private static Logger LOG = LoggerFactory.getLogger(Summarization.class); 31 | public static final String CONFIG_HYPERPARAMETER = "TF_Hyperparameter"; 32 | public static final String[] scripts = { 33 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/run_summarization.py", 34 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/__init__.py", 35 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/attention_decoder.py", 36 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/batcher.py", 37 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/beam_search.py", 38 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/data.py", 39 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/decode.py", 40 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/inspect_checkpoint.py", 41 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/model.py", 42 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/util.py", 43 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/flink_writer.py", 44 | "/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/train.py", 45 | }; 46 | 47 | private static void setExampleCodingTypeRow(TFConfig config) { 48 | String[] names = {"article"}; 49 | com.alibaba.flink.ml.operator.util.DataTypes[] types = {com.alibaba.flink.ml.operator.util.DataTypes.STRING}; 50 | String strInput = ExampleCodingConfig.createExampleConfigStr(names, types, 51 | ExampleCodingConfig.ObjectType.ROW, String.class); 52 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 53 | LOG.info("input tf example config: " + strInput); 54 | 55 | String[] namesOutput = {"abstract", "reference"}; 56 | com.alibaba.flink.ml.operator.util.DataTypes[] typesOutput = {com.alibaba.flink.ml.operator.util.DataTypes.STRING, 57 | com.alibaba.flink.ml.operator.util.DataTypes.STRING}; 58 | String strOutput = ExampleCodingConfig.createExampleConfigStr(namesOutput, typesOutput, 59 | ExampleCodingConfig.ObjectType.ROW, String.class); 60 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput); 61 | LOG.info("output tf example config: " + strOutput); 62 | 63 | config.getProperties().put(MLConstants.ENCODING_CLASS, ExampleCoding.class.getCanonicalName()); 64 | config.getProperties().put(MLConstants.DECODING_CLASS, ExampleCoding.class.getCanonicalName()); 65 | } 66 | 67 | private static void setCsvCodingTypeRow(TFConfig config) { 68 | config.getProperties().put(RowCSVCoding.DELIM_CONFIG, "#"); 69 | config.getProperties().put(MLConstants.ENCODING_CLASS, RowCSVCoding.class.getCanonicalName()); 70 | config.getProperties().put(MLConstants.DECODING_CLASS, RowCSVCoding.class.getCanonicalName()); 71 | 72 | StringBuilder inputSb = new StringBuilder(); 73 | inputSb.append(com.alibaba.flink.ml.operator.util.DataTypes.STRING.name()); 74 | config.getProperties().put(RowCSVCoding.ENCODE_TYPES, inputSb.toString()); 75 | inputSb.append(",").append(com.alibaba.flink.ml.operator.util.DataTypes.STRING.name()); 76 | config.getProperties().put(RowCSVCoding.DECODE_TYPES, inputSb.toString()); 77 | } 78 | 79 | public static void inference() throws Exception { 80 | // local zookeeper server 81 | TestingServer server = new TestingServer(2181, true); 82 | String[] hyperparameter = { 83 | 
"run_summarization.py", // first param is uesless but required 84 | "--mode=decode", 85 | "--data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/cnn_stories_test/0*", 86 | "--vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab", 87 | "--log_root=/Users/bodeng/TextSummarization-On-Flink/log", 88 | "--exp_name=pretrained_model_tf1.2.1", 89 | "--max_enc_steps=400", 90 | "--max_dec_steps=100", 91 | "--coverage=1", 92 | "--single_pass=1", 93 | "--inference=1", 94 | }; 95 | 96 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); 97 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 98 | // tableEnv.registerFunction("LEN", new LEN()); 99 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createArticleData()), "article"); 100 | 101 | // if zookeeper has other address 102 | Map prop = new HashMap<>(); 103 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 104 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, "127.0.0.1:2181"); 105 | prop.put(CONFIG_HYPERPARAMETER, String.join(" ", hyperparameter)); 106 | TFConfig config = new TFConfig(1, 0, prop, scripts, "main_on_flink", null); 107 | // setCsvCodingTypeRow(config); 108 | // setExampleCodingTypeRow(config); 109 | 110 | TableSchema outSchema = new TableSchema(new String[]{"abstract", "reference"}, 111 | new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO}); 112 | CodingUtils.configureExampleCoding(config, input.getSchema(), outSchema, ObjectType.ROW, String.class); 113 | // input = input.select("LEN(article) as len, article"); 114 | input.printSchema(); 115 | tableEnv.toRetractStream(input, new RowTypeInfo(BasicTypeInfo.STRING_TYPE_INFO)).print(); 116 | 117 | Table resultTable = TFUtils.inference(streamEnv, tableEnv, input, config, outSchema); 118 | tableEnv.toRetractStream(resultTable, 119 | new RowTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO)).print(); 120 | streamEnv.execute(); 121 | 122 | server.stop(); 123 | } 124 | 125 | public static void training() throws Exception { 126 | // local zookeeper server 127 | TestingServer server = new TestingServer(2181, true); 128 | String[] hyperparameter = { 129 | "run_summarization.py", // first param is uesless but required 130 | "--mode=train", 131 | "--data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/chunked/train_*", 132 | "--vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab", 133 | "--log_root=/Users/bodeng/TextSummarization-On-Flink/log", 134 | "--exp_name=pretrained_model_tf1.2.1", 135 | "--max_enc_steps=400", 136 | "--max_dec_steps=100", 137 | "--coverage=1", 138 | "--num_steps=3", 139 | }; 140 | 141 | 142 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(); 143 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 144 | 145 | // if zookeeper has other address 146 | Map prop = new HashMap<>(); 147 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 148 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, "127.0.0.1:2181"); 149 | prop.put(CONFIG_HYPERPARAMETER, String.join(" ", hyperparameter)); 150 | TFConfig config = new TFConfig(1, 0, prop, scripts, "main_on_flink", null); 151 | 152 | TFUtils.train(streamEnv, tableEnv, null, config, null); 153 | streamEnv.execute(); 154 | server.stop(); 155 | } 156 | 157 | 
private static List createArticleData() { 158 | return Arrays.asList("article 1.", "article 2.", "article 3.", "article 4.", "article 5.", 159 | "article 6.", "article 7.", "article 8.", "article 9.", "article 10."); 160 | } 161 | 162 | public static void main(String[] args) throws Exception { 163 | training(); 164 | // inference(); 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/main/python/pointer-generator/README.md: -------------------------------------------------------------------------------- 1 | *Note: this code is no longer actively maintained. However, feel free to use the Issues section to discuss the code with other users. Some users have updated this code for newer versions of Tensorflow and Python - see information below and Issues section.* 2 | 3 | --- 4 | 5 | This repository contains code for the ACL 2017 paper *[Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*. For an intuitive overview of the paper, read the [blog post](http://www.abigailsee.com/2017/04/16/taming-rnns-for-better-summarization.html). 6 | 7 | ## Looking for test set output? 8 | The test set output of the models described in the paper can be found [here](https://drive.google.com/file/d/0B7pQmm-OfDv7MEtMVU5sOHc5LTg/view?usp=sharing). 9 | 10 | ## Looking for pretrained model? 11 | A pretrained model is available here: 12 | * [Version for Tensorflow 1.0](https://drive.google.com/open?id=0B7pQmm-OfDv7SHFadHR4RllfR1E) 13 | * [Version for Tensorflow 1.2.1](https://drive.google.com/open?id=0B7pQmm-OfDv7ZUhHZm9ZWEZidDg) 14 | 15 | (The only difference between these two is the naming of some of the variables in the checkpoint. Tensorflow 1.0 uses `lstm_cell/biases` and `lstm_cell/weights` whereas Tensorflow 1.2.1 uses `lstm_cell/bias` and `lstm_cell/kernel`). 16 | 17 | **Note**: This pretrained model is *not* the exact same model that is reported in the paper. That is, it is the same architecture, trained with the same settings, but resulting from a different training run. Consequently this pretrained model has slightly lower ROUGE scores than those reported in the paper. This is probably due to us slightly overfitting to the randomness in our original experiments (in the original experiments we tried various hyperparameter settings and selected the model that performed best). Repeating the experiment once with the same settings did not perform quite as well. Better results might be obtained from further hyperparameter tuning. 18 | 19 | **Why can't you release the trained model reported in the paper?** Due to changes to the code between the original experiments and the time of releasing the code (e.g. TensorFlow version changes, lots of code cleanup), it is not possible to release the original trained model files. 20 | 21 | ## Looking for CNN / Daily Mail data? 22 | Instructions are [here](https://github.com/abisee/cnn-dailymail). 23 | 24 | ## About this code 25 | This code is based on the [TextSum code](https://github.com/tensorflow/models/tree/master/textsum) from Google Brain. 26 | 27 | This code was developed for Tensorflow 0.12, but has been updated to run with Tensorflow 1.0. 28 | In particular, the code in attention_decoder.py is based on [tf.contrib.legacy_seq2seq_attention_decoder](https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/attention_decoder), which is now outdated. 
29 | Tensorflow 1.0's [new seq2seq library](https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention) probably provides a way to do this (as well as beam search) more elegantly and efficiently in the future. 30 | 31 | **Python 3 version**: This code is in Python 2. If you want a Python 3 version, see [@becxer's fork](https://github.com/becxer/pointer-generator/). 32 | 33 | ## How to run 34 | 35 | ### Get the dataset 36 | To obtain the CNN / Daily Mail dataset, follow the instructions [here](https://github.com/abisee/cnn-dailymail). Once finished, you should have [chunked](https://github.com/abisee/cnn-dailymail/issues/3) datafiles `train_000.bin`, ..., `train_287.bin`, `val_000.bin`, ..., `val_013.bin`, `test_000.bin`, ..., `test_011.bin` (each contains 1000 examples) and a vocabulary file `vocab`. 37 | 38 | **Note**: If you did this before 7th May 2017, follow the instructions [here](https://github.com/abisee/cnn-dailymail/issues/2) to correct a bug in the process. 39 | 40 | ### Run training 41 | To train your model, run: 42 | 43 | ``` 44 | python run_summarization.py --mode=train --data_path=/path/to/chunked/train_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 45 | ``` 46 | 47 | This will create a subdirectory of your specified `log_root` called `myexperiment` where all checkpoints and other data will be saved. Then the model will start training using the `train_*.bin` files as training data. 48 | 49 | **Warning**: Using default settings as in the above command, both initializing the model and running training iterations will probably be quite slow. To make things faster, try setting the following flags (especially `max_enc_steps` and `max_dec_steps`) to something smaller than the defaults specified in `run_summarization.py`: `hidden_dim`, `emb_dim`, `batch_size`, `max_enc_steps`, `max_dec_steps`, `vocab_size`. 50 | 51 | **Increasing sequence length during training**: Note that to obtain the results described in the paper, we increase the values of `max_enc_steps` and `max_dec_steps` in stages throughout training (mostly so we can perform quicker iterations during early stages of training). If you wish to do the same, start with small values of `max_enc_steps` and `max_dec_steps`, then interrupt and restart the job with larger values when you want to increase them. 52 | 53 | ### Run (concurrent) eval 54 | You may want to run a concurrent evaluation job, that runs your model on the validation set and logs the loss. To do this, run: 55 | 56 | ``` 57 | python run_summarization.py --mode=eval --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 58 | ``` 59 | 60 | Note: you want to run the above command using the same settings you entered for your training job. 61 | 62 | **Restoring snapshots**: The eval job saves a snapshot of the model that scored the lowest loss on the validation data so far. You may want to restore one of these "best models", e.g. if your training job has overfit, or if the training checkpoint has become corrupted by NaN values. To do this, run your train command plus the `--restore_best_model=1` flag. This will copy the best model in the eval directory to the train directory. Then run the usual train command again. 
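For example, a hypothetical restore run (reusing the placeholder paths and experiment name from the training command above, with only the `--restore_best_model=1` flag added) could look like:

```
python run_summarization.py --mode=train --data_path=/path/to/chunked/train_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment --restore_best_model=1
```

Once the best model has been copied into the train directory, drop the flag and rerun the usual train command.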
63 | 64 | ### Run beam search decoding 65 | To run beam search decoding: 66 | 67 | ``` 68 | python run_summarization.py --mode=decode --data_path=/path/to/chunked/val_* --vocab_path=/path/to/vocab --log_root=/path/to/a/log/directory --exp_name=myexperiment 69 | ``` 70 | 71 | Note: you want to run the above command using the same settings you entered for your training job (plus any decode mode specific flags like `beam_size`). 72 | 73 | This will repeatedly load random examples from your specified datafile and generate a summary using beam search. The results will be printed to screen. 74 | 75 | **Visualize your output**: Additionally, the decode job produces a file called `attn_vis_data.json`. This file provides the data necessary for an in-browser visualization tool that allows you to view the attention distributions projected onto the text. To use the visualizer, follow the instructions [here](https://github.com/abisee/attn_vis). 76 | 77 | If you want to run evaluation on the entire validation or test set and get ROUGE scores, set the flag `single_pass=1`. This will go through the entire dataset in order, writing the generated summaries to file, and then run evaluation using [pyrouge](https://pypi.python.org/pypi/pyrouge). (Note this will *not* produce the `attn_vis_data.json` files for the attention visualizer). 78 | 79 | ### Evaluate with ROUGE 80 | `decode.py` uses the Python package [`pyrouge`](https://pypi.python.org/pypi/pyrouge) to run ROUGE evaluation. `pyrouge` provides an easier-to-use interface for the official Perl ROUGE package, which you must install for `pyrouge` to work. Here are some useful instructions on how to do this: 81 | * [How to setup Perl ROUGE](http://kavita-ganesan.com/rouge-howto) 82 | * [More details about plugins for Perl ROUGE](http://www.summarizerman.com/post/42675198985/figuring-out-rouge) 83 | 84 | **Note:** As of 18th May 2017 the [website](http://berouge.com/) for the official Perl package appears to be down. Unfortunately you need to download a directory called `ROUGE-1.5.5` from there. As an alternative, it seems that you can get that directory from [here](https://github.com/andersjo/pyrouge) (however, the version of `pyrouge` in that repo appears to be outdated, so best to install `pyrouge` from the [official source](https://pypi.python.org/pypi/pyrouge)). 85 | 86 | ### Tensorboard 87 | Run Tensorboard from the experiment directory (in the example above, `myexperiment`). You should be able to see data from the train and eval runs. If you select "embeddings", you should also see your word embeddings visualized. 88 | 89 | ### Help, I've got NaNs! 90 | For reasons that are [difficult to diagnose](https://github.com/abisee/pointer-generator/issues/4), NaNs sometimes occur during training, making the loss=NaN and sometimes also corrupting the model checkpoint with NaN values, making it unusable. Here are some suggestions: 91 | 92 | * If training stopped with the `Loss is not finite. Stopping.` exception, you can just try restarting. It may be that the checkpoint is not corrupted. 93 | * You can check if your checkpoint is corrupted by using the `inspect_checkpoint.py` script. If it says that all values are finite, then your checkpoint is OK and you can try resuming training with it. 94 | * The training job is set to keep 3 checkpoints at any one time (see the `max_to_keep` variable in `run_summarization.py`). If your newer checkpoint is corrupted, it may be that one of the older ones is not. 
You can switch to that checkpoint by editing the `checkpoint` file inside the `train` directory. 95 | * Alternatively, you can restore a "best model" from the `eval` directory. See the note **Restoring snapshots** above. 96 | * If you want to try to diagnose the cause of the NaNs, you can run with the `--debug=1` flag turned on. This will run [Tensorflow Debugger](https://www.tensorflow.org/versions/master/programmers_guide/debugger), which checks for NaNs and diagnoses their causes during training. 97 | -------------------------------------------------------------------------------- /src/test/python/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | 5 | import tensorflow as tf 6 | from flink_ml_tensorflow.tensorflow_context import TFContext 7 | 8 | 9 | class FlinkReader(object): 10 | def __init__(self, context, batch_size=1, features={'input': tf.FixedLenFeature([], tf.string)}): 11 | self._context = context 12 | self._batch_size = batch_size 13 | self._features = features 14 | self._build_graph() 15 | 16 | def _decode(self, features): 17 | return features['input'] 18 | 19 | def _build_graph(self): 20 | dataset = self._context.flink_stream_dataset() 21 | dataset = dataset.map(lambda record: tf.parse_single_example(record, features=self._features)) 22 | dataset = dataset.map(self._decode) 23 | dataset = dataset.batch(self._batch_size) 24 | iterator = dataset.make_one_shot_iterator() 25 | self._next_batch = iterator.get_next() 26 | 27 | def next_batch(self, sess): 28 | try: 29 | batch = sess.run(self._next_batch) 30 | return batch 31 | except tf.errors.OutOfRangeError: 32 | return None 33 | 34 | 35 | class FlinkWriter(object): 36 | def __init__(self, context): 37 | self._context = context 38 | self._build_graph() 39 | 40 | def _build_graph(self): 41 | self._write_feed = tf.placeholder(dtype=tf.string) 42 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed]) 43 | 44 | def _example(self, results): 45 | example = tf.train.Example(features=tf.train.Features( 46 | feature={ 47 | 'output': tf.train.Feature(bytes_list=tf.train.BytesList(value=[results[0]])), 48 | } 49 | )) 50 | return example 51 | 52 | def write_result(self, sess, results): 53 | sess.run(self.write_op, feed_dict={self._write_feed: self._example(results).SerializeToString()}) 54 | 55 | def close(self, sess): 56 | sess.run(self._close_op) 57 | 58 | 59 | def flink_server_device(tf_context): 60 | index = tf_context.get_index() 61 | job_name = tf_context.get_role_name() 62 | cluster_json = tf_context.get_tf_cluster() 63 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 64 | 65 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 66 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 67 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 68 | device = tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)) 69 | return (device, server, sess_config) 70 | 71 | 72 | def test_example_coding(context): 73 | tf_context = TFContext(context) 74 | if 'ps' == tf_context.get_role_name(): 75 | from time import sleep 76 | while True: 77 | sleep(1) 78 | else: 79 | index = tf_context.get_index() 80 | job_name = tf_context.get_role_name() 81 | cluster_json = tf_context.get_tf_cluster() 82 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 83 | 84 | server = tf.train.Server(cluster, job_name=job_name, 
task_index=index) 85 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 86 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 87 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 88 | # with flink_server_device(tf_context) as (device, server, sess_config): 89 | reader = FlinkReader(tf_context) 90 | writer = FlinkWriter(tf_context) 91 | 92 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 93 | while True: 94 | batch = reader.next_batch(sess) 95 | tf.logging.info(str(batch)) 96 | if batch is None: 97 | break 98 | writer.write_result(sess, batch) 99 | writer.close(sess) 100 | sys.stdout.flush() 101 | 102 | 103 | def test_example_coding_without_encode(context): 104 | tf_context = TFContext(context) 105 | if 'ps' == tf_context.get_role_name(): 106 | from time import sleep 107 | while True: 108 | sleep(1) 109 | else: 110 | index = tf_context.get_index() 111 | job_name = tf_context.get_role_name() 112 | cluster_json = tf_context.get_tf_cluster() 113 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 114 | 115 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 116 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 117 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 118 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 119 | # with flink_server_device(tf_context) as (device, server, sess_config): 120 | # reader = FlinkReader(tf_context) 121 | writer = FlinkWriter(tf_context) 122 | 123 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 124 | # while True: 125 | # batch = reader.next_batch(sess) 126 | # if batch is None: 127 | # break 128 | # writer.write_result(sess, batch) 129 | for i in range(10): 130 | writer.write_result(sess, ['output-%d' % i]) 131 | writer.close(sess) 132 | sys.stdout.flush() 133 | 134 | 135 | def test_example_coding_without_decode(context): 136 | tf_context = TFContext(context) 137 | if 'ps' == tf_context.get_role_name(): 138 | from time import sleep 139 | while True: 140 | sleep(1) 141 | else: 142 | index = tf_context.get_index() 143 | job_name = tf_context.get_role_name() 144 | cluster_json = tf_context.get_tf_cluster() 145 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 146 | 147 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 148 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 149 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 150 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 151 | # with flink_server_device(tf_context) as (device, server, sess_config): 152 | reader = FlinkReader(tf_context) 153 | # writer = FlinkWriter(tf_context) 154 | 155 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 156 | while True: 157 | batch = reader.next_batch(sess) 158 | tf.logging.info(str(batch)) 159 | if batch is None: 160 | break 161 | # writer.write_result(sess, batch) 162 | # writer.close(sess) 163 | sys.stdout.flush() 164 | 165 | 166 | def test_example_coding_with_nothing(context): 167 | tf_context = TFContext(context) 168 | if 'ps' == tf_context.get_role_name(): 169 | from time import sleep 170 | while True: 171 | sleep(1) 172 | 
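# Note (added comment, not in the original test): the "ps" branch above intentionally blocks forever.
# In these Flink-AI-Extended tests a parameter-server task does no TensorFlow work itself; it only
# needs to stay alive while the worker branch below builds the graph, exchanges records with Flink,
# and then shuts the job down.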
else: 173 | index = tf_context.get_index() 174 | job_name = tf_context.get_role_name() 175 | cluster_json = tf_context.get_tf_cluster() 176 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 177 | 178 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 179 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 180 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 181 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 182 | # with flink_server_device(tf_context) as (device, server, sess_config): 183 | # reader = FlinkReader(tf_context) 184 | # writer = FlinkWriter(tf_context) 185 | 186 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 187 | tf.logging.info('do nothing') 188 | # while True: 189 | # batch = reader.next_batch(sess) 190 | # if batch is None: 191 | # break 192 | # writer.write_result(sess, batch) 193 | # writer.close(sess) 194 | sys.stdout.flush() 195 | 196 | 197 | def test_source_sink(context): 198 | tf_context = TFContext(context) 199 | if 'ps' == tf_context.get_role_name(): 200 | from time import sleep 201 | while True: 202 | sleep(1) 203 | else: 204 | index = tf_context.get_index() 205 | job_name = tf_context.get_role_name() 206 | cluster_json = tf_context.get_tf_cluster() 207 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 208 | 209 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 210 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 211 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 212 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 213 | # with flink_server_device(tf_context) as (device, server, sess_config): 214 | reader = FlinkReader(tf_context) 215 | writer = FlinkWriter(tf_context) 216 | 217 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 218 | while True: 219 | batch = reader.next_batch(sess) 220 | if batch is None: 221 | break 222 | # tf.logging.info("[TF][%s]process %s" % (str(datetime.datetime.now()), str(batch))) 223 | 224 | writer.write_result(sess, batch) 225 | writer.close(sess) 226 | sys.stdout.flush() 227 | -------------------------------------------------------------------------------- /src/main/java/me/littlebo/App.java: -------------------------------------------------------------------------------- 1 | package me.littlebo; 2 | 3 | import com.alibaba.flink.ml.operator.util.DataTypes; 4 | import org.apache.flink.api.java.ExecutionEnvironment; 5 | import org.apache.flink.streaming.api.datastream.DataStream; 6 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 7 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer; 8 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer; 9 | import org.apache.flink.table.api.Table; 10 | import org.apache.flink.table.api.java.StreamTableEnvironment; 11 | import org.apache.flink.table.ml.lib.tensorflow.TFEstimator; 12 | import org.apache.flink.table.ml.lib.tensorflow.TFModel; 13 | import org.apache.flink.types.Row; 14 | import org.slf4j.Logger; 15 | import org.slf4j.LoggerFactory; 16 | 17 | import java.util.ArrayList; 18 | import java.util.List; 19 | import java.util.Properties; 20 | 21 | /** 22 | * An end-to-end document summarization application. 
23 | * In training mode, the TFEstimator uses data from a Kafka source to train 24 | a TensorFlow model implemented by Flink-AI-Extended and returns a TFModel. 25 | * 26 | * In inference mode, the TFModel processes summarization requests from a 27 | Kafka source and writes the results to a Kafka sink. 28 | */ 29 | public class App { 30 | public static final int MAX_ROW_COUNT = 8; 31 | 32 | public static final String trainTopic = "flink_train"; 33 | public static final String inputTopic = "flink_input"; 34 | public static final String outputTopic = "flink_output"; 35 | public static final String consumerGroup = "bode"; 36 | public static final String kafkaAddress = "127.0.0.1:9092"; 37 | 38 | public static final Logger LOG = LoggerFactory.getLogger(App.class); 39 | private static final String projectDir = SysUtils.getProjectRootDir(); 40 | public static final String[] scripts = { 41 | projectDir + "/src/main/python/pointer-generator/run_summarization.py", 42 | projectDir + "/src/main/python/pointer-generator/__init__.py", 43 | projectDir + "/src/main/python/pointer-generator/attention_decoder.py", 44 | projectDir + "/src/main/python/pointer-generator/batcher.py", 45 | projectDir + "/src/main/python/pointer-generator/beam_search.py", 46 | projectDir + "/src/main/python/pointer-generator/data.py", 47 | projectDir + "/src/main/python/pointer-generator/decode.py", 48 | projectDir + "/src/main/python/pointer-generator/inspect_checkpoint.py", 49 | projectDir + "/src/main/python/pointer-generator/model.py", 50 | projectDir + "/src/main/python/pointer-generator/util.py", 51 | projectDir + "/src/main/python/pointer-generator/flink_writer.py", 52 | projectDir + "/src/main/python/pointer-generator/train.py", 53 | }; 54 | private static final String hyperparameter_key = "TF_Hyperparameter"; 55 | public static final String[] inference_hyperparameter = { 56 | "run_summarization.py", // first param is unused, just a placeholder 57 | "--mode=decode", 58 | "--data_path=" + projectDir + "/data/cnn-dailymail/cnn_stories_test/0*", 59 | "--vocab_path=" + projectDir + "/data/cnn-dailymail/finished_files/vocab", 60 | "--log_root=" + projectDir + "/log", 61 | "--exp_name=pretrained_model_tf1.2.1", 62 | "--batch_size=2", // default to 16 63 | "--max_enc_steps=400", 64 | "--max_dec_steps=100", 65 | "--coverage=1", 66 | "--single_pass=1", 67 | "--inference=1", 68 | }; 69 | public static final String[] train_hyperparameter = { 70 | "run_summarization.py", // first param is unused, just a placeholder 71 | "--mode=train", 72 | "--data_path=" + projectDir + "/data/cnn-dailymail/finished_files/chunked/train_*", 73 | "--vocab_path=" + projectDir + "/data/cnn-dailymail/finished_files/vocab", 74 | "--log_root=" + projectDir + "/log", 75 | "--exp_name=pretrained_model_tf1.2.1", 76 | "--batch_size=2", // default to 16 77 | "--max_enc_steps=400", 78 | "--max_dec_steps=100", 79 | "--coverage=1", 80 | "--num_steps=1", // if 0, never stop 81 | }; 82 | 83 | public static String startTraining() throws Exception { 84 | ExecutionEnvironment executionEnvironment = ExecutionEnvironment.createLocalEnvironment(); 85 | 86 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 87 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 88 | FlinkKafkaConsumer kafkaConsumer = createMessageConsumer(trainTopic, kafkaAddress, consumerGroup); 89 | kafkaConsumer.setStartFromEarliest(); 90 | 91 | DataStream dataStream = streamEnv.addSource(kafkaConsumer).setParallelism(1); 92 | Table input =
tableEnv.fromDataStream(dataStream, "uuid,article,summary,reference"); 93 | 94 | // Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createArticleData()).setParallelism(1), 95 | // "uuid,article,summary,reference"); 96 | 97 | input.printSchema(); 98 | 99 | // input = input.select("uuid,article,reference"); 100 | tableEnv.toAppendStream(input, Row.class).print().setParallelism(1); 101 | TFEstimator estimator = createEstimator(); 102 | TFModel model = estimator.fit(tableEnv, input); 103 | streamEnv.execute(); 104 | LOG.info("trained model: " + model.toJson()); 105 | return model.toJson(); 106 | } 107 | 108 | public static void startInference(String modelJson) throws Exception { 109 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 110 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 111 | 112 | FlinkKafkaConsumer kafkaConsumer = createMessageConsumer(inputTopic, kafkaAddress, consumerGroup); 113 | kafkaConsumer.setStartFromEarliest(); 114 | FlinkKafkaProducer kafkaProducer = createStringProducer(outputTopic, kafkaAddress); 115 | 116 | Table input = tableEnv.fromDataStream(streamEnv 117 | .addSource(kafkaConsumer, "Kafaka Source") 118 | .setParallelism(1), "uuid,article,summary,reference"); 119 | input.printSchema(); 120 | 121 | tableEnv.toAppendStream(input, Row.class).print().setParallelism(1); 122 | 123 | input = input.select("uuid,article,reference"); 124 | TFModel model = createModel(); 125 | if (modelJson != null) { 126 | model.loadJson(modelJson); 127 | } 128 | Table output = model.transform(tableEnv, input); 129 | tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 130 | tableEnv.toAppendStream(output, Row.class).addSink(kafkaProducer).setParallelism(1); 131 | streamEnv.execute(); 132 | } 133 | 134 | public static FlinkKafkaConsumer createMessageConsumer(String topic, String kafkaAddress, String kafkaGroup) { 135 | Properties props = new Properties(); 136 | props.setProperty("bootstrap.servers", kafkaAddress); 137 | props.setProperty("group.id", kafkaGroup); 138 | return new FlinkKafkaConsumer<>(topic, new MessageDeserializationSchema(MAX_ROW_COUNT), props); 139 | } 140 | 141 | public static FlinkKafkaProducer createStringProducer(String topic, String kafkaAddress) { 142 | return new FlinkKafkaProducer<>(kafkaAddress, topic, new MessageSerializationSchema()); 143 | } 144 | 145 | public static TFModel createModel() { 146 | return new TFModel() 147 | .setZookeeperConnStr("127.0.0.1:2181") 148 | .setWorkerNum(1) 149 | .setPsNum(0) 150 | 151 | .setInferenceScripts(scripts) 152 | .setInferenceMapFunc("main_on_flink") 153 | .setInferenceHyperParamsKey(hyperparameter_key) 154 | .setInferenceHyperParams(inference_hyperparameter) 155 | .setInferenceEnvPath(null) 156 | 157 | .setInferenceSelectedCols(new String[]{ "uuid", "article", "reference" }) 158 | .setInferenceOutputCols(new String[]{ "uuid", "article", "summary", "reference" }) 159 | .setInferenceOutputTypes(new DataTypes[] {DataTypes.STRING, DataTypes.STRING, DataTypes.STRING, DataTypes.STRING}); 160 | } 161 | 162 | public static TFEstimator createEstimator() { 163 | return new TFEstimator() 164 | .setZookeeperConnStr("127.0.0.1:2181") 165 | .setWorkerNum(1) 166 | .setPsNum(0) 167 | 168 | .setTrainScripts(scripts) 169 | .setTrainMapFunc("main_on_flink") 170 | .setTrainHyperParamsKey(hyperparameter_key) 171 | .setTrainHyperParams(train_hyperparameter) 172 | .setTrainEnvPath(null) 173 | 174 | .setTrainSelectedCols(new String[]{ 
"uuid", "article", "reference" }) 175 | .setTrainOutputCols(new String[]{ "uuid"}) 176 | .setTrainOutputTypes(new DataTypes[]{ DataTypes.STRING }) 177 | 178 | .setInferenceScripts(scripts) 179 | .setInferenceMapFunc("main_on_flink") 180 | .setInferenceHyperParamsKey(hyperparameter_key) 181 | .setInferenceHyperParams(inference_hyperparameter) 182 | .setInferenceEnvPath(null) 183 | 184 | .setInferenceSelectedCols(new String[]{ "uuid", "article", "reference" }) 185 | .setInferenceOutputCols(new String[]{ "uuid", "article", "summary", "reference" }) 186 | .setInferenceOutputTypes(new DataTypes[] {DataTypes.STRING, DataTypes.STRING, DataTypes.STRING, DataTypes.STRING}); 187 | } 188 | 189 | private static List createArticleData() { 190 | List rows = new ArrayList<>(); 191 | for (int i = 0; i < 8; i++) { 192 | Row row = new Row(4); 193 | row.setField(0, String.format("uuid-%d", i)); 194 | row.setField(1, String.format("article %d.", i)); 195 | row.setField(2, ""); 196 | row.setField(3, String.format("reference %d.", i)); 197 | rows.add(row); 198 | } 199 | return rows; 200 | } 201 | 202 | public static void main(String[] args) throws Exception { 203 | // App.startTraining(); 204 | // App.startInference(null); 205 | String json = App.startTraining(); 206 | App.startInference(json); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /data/cnn-dailymail/make_datafiles.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import hashlib 4 | import struct 5 | import subprocess 6 | import collections 7 | import tensorflow as tf 8 | from tensorflow.core.example import example_pb2 9 | 10 | 11 | dm_single_close_quote = u'\u2019' # unicode 12 | dm_double_close_quote = u'\u201d' 13 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence 14 | 15 | # We use these to separate the summary sentences in the .bin datafiles 16 | SENTENCE_START = '' 17 | SENTENCE_END = '' 18 | 19 | all_train_urls = "url_lists/all_train.txt" 20 | all_val_urls = "url_lists/all_val.txt" 21 | all_test_urls = "url_lists/all_test.txt" 22 | 23 | cnn_tokenized_stories_dir = "cnn_stories_tokenized" 24 | dm_tokenized_stories_dir = "dm_stories_tokenized" 25 | finished_files_dir = "finished_files" 26 | chunks_dir = os.path.join(finished_files_dir, "chunked") 27 | 28 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir 29 | num_expected_cnn_stories = 92579 30 | num_expected_dm_stories = 219506 31 | 32 | VOCAB_SIZE = 200000 33 | CHUNK_SIZE = 1000 # num examples per chunk, for the chunked data 34 | 35 | 36 | def chunk_file(set_name): 37 | in_file = 'finished_files/%s.bin' % set_name 38 | reader = open(in_file, "rb") 39 | chunk = 0 40 | finished = False 41 | while not finished: 42 | chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk)) # new chunk 43 | with open(chunk_fname, 'wb') as writer: 44 | for _ in range(CHUNK_SIZE): 45 | len_bytes = reader.read(8) 46 | if not len_bytes: 47 | finished = True 48 | break 49 | str_len = struct.unpack('q', len_bytes)[0] 50 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 51 | writer.write(struct.pack('q', str_len)) 52 | writer.write(struct.pack('%ds' % str_len, example_str)) 53 | chunk += 1 54 | 55 | 56 | def chunk_all(): 57 | # Make a dir to hold the chunks 58 | if not os.path.isdir(chunks_dir): 59 | os.mkdir(chunks_dir) 60 | 
# Chunk the data 61 | for set_name in ['train', 'val', 'test']: 62 | print "Splitting %s data into chunks..." % set_name 63 | chunk_file(set_name) 64 | print "Saved chunked data in %s" % chunks_dir 65 | 66 | 67 | def tokenize_stories(stories_dir, tokenized_stories_dir): 68 | """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer""" 69 | print "Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir) 70 | stories = os.listdir(stories_dir) 71 | # make IO list file 72 | print "Making list of files to tokenize..." 73 | with open("mapping.txt", "w") as f: 74 | for s in stories: 75 | f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s))) 76 | command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt'] 77 | print "Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir) 78 | subprocess.call(command) 79 | print "Stanford CoreNLP Tokenizer has finished." 80 | os.remove("mapping.txt") 81 | 82 | # Check that the tokenized stories directory contains the same number of files as the original directory 83 | num_orig = len(os.listdir(stories_dir)) 84 | num_tokenized = len(os.listdir(tokenized_stories_dir)) 85 | if num_orig != num_tokenized: 86 | raise Exception("The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (tokenized_stories_dir, num_tokenized, stories_dir, num_orig)) 87 | print "Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir) 88 | 89 | 90 | def read_text_file(text_file): 91 | lines = [] 92 | with open(text_file, "r") as f: 93 | for line in f: 94 | lines.append(line.strip()) 95 | return lines 96 | 97 | 98 | def hashhex(s): 99 | """Returns a heximal formated SHA1 hash of the input string.""" 100 | h = hashlib.sha1() 101 | h.update(s) 102 | return h.hexdigest() 103 | 104 | 105 | def get_url_hashes(url_list): 106 | return [hashhex(url) for url in url_list] 107 | 108 | 109 | def fix_missing_period(line): 110 | """Adds a period to a line that is missing a period""" 111 | if "@highlight" in line: return line 112 | if line=="": return line 113 | if line[-1] in END_TOKENS: return line 114 | # print line[-1] 115 | return line + " ." 
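# --- Illustrative helper (added sketch, not part of the original script) -----
# Shows how the length-prefixed records produced by chunk_file() and
# write_to_bin() in this file can be read back: each record is an 8-byte
# packed length ('q') followed by that many bytes of a serialized tf.Example.
def read_bin_file(bin_path):
  """Yields the tf.Example protos stored in a .bin datafile (sketch)."""
  with open(bin_path, "rb") as reader:
    while True:
      len_bytes = reader.read(8)
      if not len_bytes:
        break
      str_len = struct.unpack('q', len_bytes)[0]
      example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
      yield example_pb2.Example.FromString(example_str)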
116 | 117 | 118 | def get_art_abs(story_file): 119 | lines = read_text_file(story_file) 120 | 121 | # Lowercase everything 122 | lines = [line.lower() for line in lines] 123 | 124 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) 125 | lines = [fix_missing_period(line) for line in lines] 126 | 127 | # Separate out article and abstract sentences 128 | article_lines = [] 129 | highlights = [] 130 | next_is_highlight = False 131 | for idx,line in enumerate(lines): 132 | if line == "": 133 | continue # empty line 134 | elif line.startswith("@highlight"): 135 | next_is_highlight = True 136 | elif next_is_highlight: 137 | highlights.append(line) 138 | else: 139 | article_lines.append(line) 140 | 141 | # Make article into a single string 142 | article = ' '.join(article_lines) 143 | 144 | # Make abstract into a signle string, putting and tags around the sentences 145 | abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) 146 | 147 | return article, abstract 148 | 149 | 150 | def write_to_bin(url_file, out_file, makevocab=False): 151 | """Reads the tokenized .story files corresponding to the urls listed in the url_file and writes them to a out_file.""" 152 | print "Making bin file for URLs listed in %s..." % url_file 153 | url_list = read_text_file(url_file) 154 | url_hashes = get_url_hashes(url_list) 155 | story_fnames = [s+".story" for s in url_hashes] 156 | num_stories = len(story_fnames) 157 | 158 | if makevocab: 159 | vocab_counter = collections.Counter() 160 | 161 | with open(out_file, 'wb') as writer: 162 | for idx,s in enumerate(story_fnames): 163 | if idx % 1000 == 0: 164 | print "Writing story %i of %i; %.2f percent done" % (idx, num_stories, float(idx)*100.0/float(num_stories)) 165 | 166 | # Look in the tokenized story dirs to find the .story file corresponding to this url 167 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)): 168 | story_file = os.path.join(cnn_tokenized_stories_dir, s) 169 | elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)): 170 | story_file = os.path.join(dm_tokenized_stories_dir, s) 171 | else: 172 | print "Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 173 | # Check again if tokenized stories directories contain correct number of files 174 | print "Checking that the tokenized stories directories %s and %s contain correct number of files..." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 175 | check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories) 176 | check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories) 177 | raise Exception("Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." 
% (cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s)) 178 | 179 | # Get the strings to write to .bin file 180 | article, abstract = get_art_abs(story_file) 181 | 182 | # Write to tf.Example 183 | tf_example = example_pb2.Example() 184 | tf_example.features.feature['article'].bytes_list.value.extend([article]) 185 | tf_example.features.feature['abstract'].bytes_list.value.extend([abstract]) 186 | tf_example_str = tf_example.SerializeToString() 187 | str_len = len(tf_example_str) 188 | writer.write(struct.pack('q', str_len)) 189 | writer.write(struct.pack('%ds' % str_len, tf_example_str)) 190 | 191 | # Write the vocab to file, if applicable 192 | if makevocab: 193 | art_tokens = article.split(' ') 194 | abs_tokens = abstract.split(' ') 195 | abs_tokens = [t for t in abs_tokens if t not in [SENTENCE_START, SENTENCE_END]] # remove these tags from vocab 196 | tokens = art_tokens + abs_tokens 197 | tokens = [t.strip() for t in tokens] # strip 198 | tokens = [t for t in tokens if t!=""] # remove empty 199 | vocab_counter.update(tokens) 200 | 201 | print "Finished writing file %s\n" % out_file 202 | 203 | # write vocab to file 204 | if makevocab: 205 | print "Writing vocab file..." 206 | with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer: 207 | for word, count in vocab_counter.most_common(VOCAB_SIZE): 208 | writer.write(word + ' ' + str(count) + '\n') 209 | print "Finished writing vocab file" 210 | 211 | 212 | def check_num_stories(stories_dir, num_expected): 213 | num_stories = len(os.listdir(stories_dir)) 214 | if num_stories != num_expected: 215 | raise Exception("stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected)) 216 | 217 | 218 | if __name__ == '__main__': 219 | if len(sys.argv) != 3: 220 | print "USAGE: python make_datafiles.py " 221 | sys.exit() 222 | cnn_stories_dir = sys.argv[1] 223 | dm_stories_dir = sys.argv[2] 224 | 225 | # Check the stories directories contain the correct number of .story files 226 | check_num_stories(cnn_stories_dir, num_expected_cnn_stories) 227 | check_num_stories(dm_stories_dir, num_expected_dm_stories) 228 | 229 | # Create some new directories 230 | if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir) 231 | if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir) 232 | if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir) 233 | 234 | # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories 235 | tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir) 236 | tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir) 237 | 238 | # Read the tokenized stories, do a little postprocessing then write to bin files 239 | write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test.bin")) 240 | write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val.bin")) 241 | write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin"), makevocab=True) 242 | 243 | # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 
1000 examples, and saves them in finished_files/chunked 244 | chunk_all() 245 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 2019 Summer Intern Mission 2 | 3 | ### Goal: 4 | 5 | Build an end-to-end application for abstractive document summarization on top of TensorFlow, Flink-AI-Extended [1] and the Flink ML pipeline framework [2]. 6 | 7 | ### Introduction 8 | 9 | Document summarization is the process of shortening a text document in order to create a summary with the major points of the original document. 10 | 11 | In general, a native Flink application will be built to serve an external document summarization application. In the training phase, a corpus with plenty of article-summary tuples will be fed as input to an estimator pipeline, which is trained to produce a model pipeline. In the inference phase, the Flink application will use the trained model pipeline to serve summarization requests from the external application, taking a raw article as input and responding with its summary. 12 | 13 | Inside the estimator pipeline, an abstract TF estimator will be created on top of the Flink ML pipeline framework. The TF estimator is actually an untrained TensorFlow model running in Python, which uses Flink-AI-Extended to connect to TensorFlow. After fitting the corpus (i.e. training), the estimator pipeline will be converted to a model pipeline. Similarly, an abstract TF Model will be created inside the model pipeline, which actually uses the trained TensorFlow model to execute the transform function. 14 | 15 | The design of the entire system is shown below: 16 | 17 | ![design](doc/design.bmp) 18 | 19 | ### Schedule (JUL 1st - AUG 2nd) 20 | 21 | **Week1**: Build an abstractive document summarization application on top of pure TensorFlow. Implement the ability to train models from the original corpus and generate summaries for new articles. 22 | 23 | **Week2**: Encapsulate and call the model's training & inference functions through Flink-AI-Extended 24 | 25 | **Week3**: Integrate into the ML pipeline framework and train/inference/persist the model through the pipeline 26 | 27 | **Week4**: Build a simple WebUI, summarize & demo, make suggestions and improvements 28 | 29 | **Week5**: buffer period 30 | 31 | ### Expected contributions 32 | 33 | 1. Build a complete application, connecting TensorFlow <—> Flink-AI-Extended <—> Flink ML pipeline framework. Lessons learned from the perspective of MLlib users and developers. 34 | 2. Identify the disadvantages of Flink-AI-Extended and the ML pipeline in usage, design and implementation 35 | 3. Suggestions and practical improvements for those disadvantages 36 | 37 | ### References 38 | 39 | \[1] [Flink-AI-Extended: Extend deep learning framework on the Flink project](https://github.com/alibaba/flink-ai-extended) 40 | 41 | \[2] [Flink ML pipeline framework: A new set of ML core interface on top of Flink TableAPI](https://github.com/apache/flink/tree/release-1.9/flink-ml-parent) 42 | 43 | 44 | 45 | ## Usage 46 | 47 | ### 1. Clone this project 48 | 49 | ```bash 50 | git clone git@github.com:LittleBBBo/TextSummarization-On-Flink.git 51 | ``` 52 | 53 | ### 2.
Download processed data & pretrained model 54 | 55 | ```bash 56 | # Change dir from project root 57 | cd data/cnn-dailymail 58 | # This step will download 'finished_files.zip' from Google Drive, 59 | # then unzip to 'data/cnn-dailymail/finished_files', 60 | # finally remove 'finished_files.zip' 61 | bash download_data.sh 62 | 63 | # Change dir from project root 64 | cd log 65 | # This step will download 'pretrained_model_tf1.2.1.zip' from Google Drive, 66 | # then unzip to 'log/pretrained_model_tf1.2.1', 67 | # finally remove 'pretrained_model_tf1.2.1.zip' 68 | bash download_model.sh 69 | ``` 70 | 71 | ### 3. Setup 72 | 73 | #### 3.1 Install required environment 74 | 75 | Python 2.7 76 | 77 | Java 1.8 78 | 79 | cmake >= 3.6 80 | 81 | Maven >= 3.3.0 82 | 83 | #### 3.2 Build Flink-AI-Extended from source 84 | 85 | Follow the README from https://github.com/alibaba/flink-ai-extended 86 | 87 | **Compiling commands will automatically install TensorFlow 1.11.0** 88 | 89 | ```bash 90 | git clone git@github.com:alibaba/flink-ai-extended.git 91 | mvn -DskipTests=true clean install 92 | ``` 93 | 94 | #### 3.3 Build Flink 1.9-SNAPSHOT from source 95 | 96 | Follow the README from https://github.com/apache/flink/tree/release-1.9: 97 | 98 | ```bash 99 | git clone git@github.com:apache/flink.git 100 | # switch to the release-1.9 branch 101 | mvn -DskipTests=true clean install 102 | ``` 103 | 104 | #### 3.4 Install nltk package 105 | 106 | ```bash 107 | pip install nltk 108 | 109 | # To install the 'punkt' tokenizers, 110 | # run the following once in Python: 111 | import nltk 112 | nltk.download('punkt') 113 | ``` 114 | 115 | ### 4. Run training or inference 116 | 117 | Run the examples in **TensorFlowTest.java**: 118 | 119 | ```java 120 | public class TensorFlowTest { 121 | public static final Logger LOG = LoggerFactory.getLogger(TensorFlowTest.class); 122 | private static final String projectDir = System.getProperty("user.dir"); 123 | public static final String[] scripts = { 124 | projectDir + "/src/main/python/pointer-generator/run_summarization.py", 125 | projectDir + "/src/main/python/pointer-generator/__init__.py", 126 | projectDir + "/src/main/python/pointer-generator/attention_decoder.py", 127 | projectDir + "/src/main/python/pointer-generator/batcher.py", 128 | projectDir + "/src/main/python/pointer-generator/beam_search.py", 129 | projectDir + "/src/main/python/pointer-generator/data.py", 130 | projectDir + "/src/main/python/pointer-generator/decode.py", 131 | projectDir + "/src/main/python/pointer-generator/inspect_checkpoint.py", 132 | projectDir + "/src/main/python/pointer-generator/model.py", 133 | projectDir + "/src/main/python/pointer-generator/util.py", 134 | projectDir + "/src/main/python/pointer-generator/flink_writer.py", 135 | projectDir + "/src/main/python/pointer-generator/train.py", 136 | }; 137 | private static final String hyperparameter_key = "TF_Hyperparameter"; 138 | public static final String[] inference_hyperparameter = { 139 | "run_summarization.py", // first param is unused, just a placeholder 140 | "--mode=decode", 141 | "--data_path=" + projectDir + "/data/cnn-dailymail/cnn_stories_test/0*", 142 | "--vocab_path=" + projectDir + "/data/cnn-dailymail/finished_files/vocab", 143 | "--log_root=" + projectDir + "/log", 144 | "--exp_name=pretrained_model_tf1.2.1", 145 | "--batch_size=4", // default to 16 146 | "--max_enc_steps=400", 147 | "--max_dec_steps=100", 148 | "--coverage=1", 149 | "--single_pass=1", 150 | "--inference=1", 151 | }; 152 | public static final String[] train_hyperparameter =
{ 153 | "run_summarization.py", // first param is uesless but placeholder 154 | "--mode=train", 155 | "--data_path=" + projectDir + "/data/cnn-dailymail/finished_files/chunked/train_*", 156 | "--vocab_path=" + projectDir + "/data/cnn-dailymail/finished_files/vocab", 157 | "--log_root=" + projectDir + "/log", 158 | "--exp_name=pretrained_model_tf1.2.1", 159 | "--batch_size=4", // default to 16 160 | "--max_enc_steps=400", 161 | "--max_dec_steps=100", 162 | "--coverage=1", 163 | "--num_steps=10", // if 0, never stop 164 | }; 165 | 166 | @Test 167 | public void testModelInference() throws Exception { 168 | TestingServer server = new TestingServer(2181, true); 169 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); 170 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 171 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createArticleData()), 172 | "uuid,article,summary,reference"); 173 | 174 | TFModel model = createModel(); 175 | Table output = model.transform(tableEnv, input); 176 | 177 | tableEnv.toAppendStream(output, Row.class).print().setParallelism(1); 178 | streamEnv.execute(); 179 | server.stop(); 180 | } 181 | 182 | @Test 183 | public void testModelTraining() throws Exception { 184 | TestingServer server = new TestingServer(2181, true); 185 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); 186 | 187 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 188 | Table input = tableEnv.fromDataStream(streamEnv.fromCollection(createArticleData()), 189 | "uuid,article,summary,reference"); 190 | TFEstimator estimator = createEstimator(); 191 | estimator.fit(tableEnv, input); 192 | streamEnv.execute(); 193 | server.stop(); 194 | } 195 | 196 | private List createArticleData() { 197 | List rows = new ArrayList<>(); 198 | for (int i = 0; i < 8; i++) { 199 | Row row = new Row(4); 200 | row.setField(0, String.format("uuid-%d", i)); 201 | row.setField(1, String.format("article %d.", i)); 202 | row.setField(2, ""); 203 | row.setField(3, String.format("reference %d.", i)); 204 | rows.add(row); 205 | } 206 | return rows; 207 | } 208 | 209 | public static TFModel createModel() { 210 | return new TFModel() 211 | .setZookeeperConnStr("127.0.0.1:2181") 212 | .setWorkerNum(1) 213 | .setPsNum(0) 214 | 215 | .setInferenceScripts(scripts) 216 | .setInferenceMapFunc("main_on_flink") 217 | .setInferenceHyperParamsKey(hyperparameter_key) 218 | .setInferenceHyperParams(inference_hyperparameter) 219 | .setInferenceEnvPath(null) 220 | 221 | .setInferenceSelectedCols(new String[]{ "uuid", "article", "reference" }) 222 | .setInferenceOutputCols(new String[]{ "uuid", "article", "summary", "reference" }) 223 | .setInferenceOutputTypes(new DataTypes[] {DataTypes.STRING, DataTypes.STRING, DataTypes.STRING, DataTypes.STRING}); 224 | } 225 | 226 | public static TFEstimator createEstimator() { 227 | return new TFEstimator() 228 | .setZookeeperConnStr("127.0.0.1:2181") 229 | .setWorkerNum(1) 230 | .setPsNum(0) 231 | 232 | .setTrainScripts(scripts) 233 | .setTrainMapFunc("main_on_flink") 234 | .setTrainHyperParamsKey(hyperparameter_key) 235 | .setTrainHyperParams(train_hyperparameter) 236 | .setTrainEnvPath(null) 237 | 238 | .setTrainSelectedCols(new String[]{ "uuid", "article", "reference" }) 239 | .setTrainOutputCols(new String[]{ "uuid"}) 240 | .setTrainOutputTypes(new DataTypes[]{ DataTypes.STRING }) 241 | 242 | .setInferenceScripts(scripts) 243 | 
.setInferenceMapFunc("main_on_flink") 244 | .setInferenceHyperParamsKey(hyperparameter_key) 245 | .setInferenceHyperParams(inference_hyperparameter) 246 | .setInferenceEnvPath(null) 247 | 248 | .setInferenceSelectedCols(new String[]{ "uuid", "article", "reference" }) 249 | .setInferenceOutputCols(new String[]{ "uuid", "article", "summary", "reference" }) 250 | .setInferenceOutputTypes(new DataTypes[] {DataTypes.STRING, DataTypes.STRING, DataTypes.STRING, DataTypes.STRING}); 251 | } 252 | } 253 | ``` 254 | 255 | -------------------------------------------------------------------------------- /doc/github/Comment to [Issue 2] flink-ai-extended adapter flink ml pipeline.md: -------------------------------------------------------------------------------- 1 | ## Comment to [[Issue 2](https://github.com/alibaba/flink-ai-extended/issues/2)] flink-ai-extended adapter flink ml pipeline 2 | 3 | I implemented two general classes——**TFEstimator** and **TFModel** based on the Flink ML pipeline framework(see [flink-ml-parent](https://github.com/apache/flink/tree/release-1.9/flink-ml-parent)). These two classes encapsulate the Flink-AI-Extended **train** and **inference** procedures in the **fit**() and **transform**() methods respectively. The WithParams interface implements the configuration and delivery of common parameters. 4 | 5 | **TFEstimator** and **TFModel** are specific implementation classes. In theory, **any** Flink-AI-Extended extension-based TensorFlow algorithm can be run. The train/inference process is completely encapsulated by some generalized parameter configuration, so that the user can simply construct an estimator or model of a TensorFlow algorithm. 6 | 7 | ### Params 8 | 9 | In general, the following types of parameters need to be configured when using Flink-AI-Extended: 10 | 11 | - **cluster information**: including zookeeper address, number of workers, number of ps. 12 | - **Input and output information**: including the column name of the input table that needs to be passed to the TensorFlow, and the column name and corresponding type of the output table that TensorFlow returns to flink. 13 | - **python information**: including all python file paths, main function entry, hyper parameters passed to python, virtual environment path. 14 | 15 | Therefore, some unified interfaces are designed for each type of parameter, and **TFEstimator** and **TFModel** implement these interfaces. It should be noted that for the input and output and python related parameters, **two sets** of the same interface are designed for the **training** and **Inference** processes respectively. This design is because although most parameters of TFEstimator and TFModel should be the same in general applications, it is impossible to force users to develop TensorFlow algorithm according to such specifications, so the ability to independently configure a set of parameters for TFEstimator and TFModel is retained. . For example, the user's entry function during training is "train_on_flink" and the reference process can be "inference_on_flink". 16 | 17 | The core design of parameter interfaces is as follows: 18 | 19 | ```java 20 | package org.apache.flink.table.ml.lib.tensorflow.param; 21 | 22 | import org.apache.flink.ml.api.misc.param.ParamInfo; 23 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory; 24 | import org.apache.flink.ml.api.misc.param.WithParams; 25 | 26 | /** 27 | * Parameters for cluster configuration, including: 28 | * 1. zookeeper address 29 | * 2. worker number 30 | * 3. 
ps number 31 | */ 32 | public interface HasClusterConfig extends WithParams { 33 | ParamInfo ZOOKEEPER_CONNECT_STR; 34 | ParamInfo WORKER_NUM; 35 | ParamInfo PS_NUM; 36 | } 37 | 38 | 39 | /** 40 | * Parameters for python configuration in training process, including: 41 | * 1. paths of python scripts 42 | * 2. entry function in main python file 43 | * 3. key to get hyper parameter in python 44 | * 4. hyper parameter for python 45 | * 5. virtual environment path 46 | */ 47 | public interface HasTrainPythonConfig extends WithParams { 48 | ParamInfo TRAIN_SCRIPTS; 49 | ParamInfo TRAIN_MAP_FUNC; 50 | ParamInfo TRAIN_HYPER_PARAMS_KEY; 51 | ParamInfo TRAIN_HYPER_PARAMS; 52 | ParamInfo TRAIN_ENV_PATH; 53 | } 54 | 55 | 56 | /** 57 | * An interface for classes with a parameter specifying 58 | * the name of multiple selected input columns. 59 | */ 60 | public interface HasTrainSelectedCols extends WithParams { 61 | ParamInfo TRAIN_SELECTED_COLS; 62 | } 63 | 64 | 65 | /** 66 | * An interface for classes with a parameter specifying 67 | * the names of multiple output columns. 68 | */ 69 | public interface HasTrainOutputCols extends WithParams { 70 | ParamInfo TRAIN_OUTPUT_COLS; 71 | } 72 | 73 | 74 | /** 75 | * An interface for classes with a parameter specifying 76 | * the types of multiple output columns. 77 | */ 78 | public interface HasTrainOutputTypes extends WithParams { 79 | ParamInfo TRAIN_OUTPUT_TYPES; 80 | } 81 | 82 | 83 | /** 84 | * Mirrored interfaces for configuration in inference process 85 | */ 86 | public interface HasInferencePythonConfig extends WithParams; 87 | public interface HasInferenceSelectedCols extends WithParams; 88 | public interface HasInferenceOutputCols extends WithParams; 89 | public interface HasInferenceOutputTypes extends WithParams; 90 | ``` 91 | 92 | 93 | 94 | ###TFModel 95 | 96 | The general **TFModel** configures cluster information and Python related information (such as file path, super parameter, etc.) through the **HasClusterConfig** and **HasInferencePythonConfig** interfaces. The encoding and decoding formats related to data transmission by the TF are configured through the **HasInferenceSelectedCols**, **HasInferenceOutputCols**, and **HasInferenceOutputTypes** interfaces. 97 | 98 | The core design of **TFModel** is as follows: 99 | 100 | ```java 101 | /** 102 | * A general TensorFlow model implemented by Flink-AI-Extended, 103 | * is usually generated by an {@link TFEstimator} 104 | * when {@link TFEstimator#fit(TableEnvironment, Table)} is invoked. 
105 | */ 106 | public class TFModel implements Model, 107 | HasClusterConfig, 108 | HasInferencePythonConfig, 109 | HasInferenceSelectedCols, 110 | HasInferenceOutputCols, 111 | HasInferenceOutputTypes { 112 | 113 | private static final Logger LOG = LoggerFactory.getLogger(TFModel.class); 114 | private Params params = new Params(); 115 | 116 | @Override 117 | public Table transform(TableEnvironment tableEnvironment, Table table) { 118 | StreamExecutionEnvironment streamEnv; 119 | try { 120 | // TODO: [hack] transform table to dataStream to get StreamExecutionEnvironment 121 | if (tableEnvironment instanceof StreamTableEnvironment) { 122 | StreamTableEnvironment streamTableEnvironment = 123 | (StreamTableEnvironment)tableEnvironment; 124 | streamEnv = streamTableEnvironment 125 | .toAppendStream(table, Row.class).getExecutionEnvironment(); 126 | } else { 127 | throw new RuntimeException("Unsupported TableEnvironment, please use StreamTableEnvironment"); 128 | } 129 | 130 | // Select the necessary columns according to "SelectedCols" 131 | Table inputTable = configureInputTable(table); 132 | // Construct the output schema according on the "OutputCols" and "OutputTypes" 133 | TableSchema outputSchema = configureOutputSchema(); 134 | // Create a basic TFConfig according to "ClusterConfig" and "PythonConfig" 135 | TFConfig config = configureTFConfig(); 136 | // Configure the row encoding and decoding base on input & output schema 137 | configureExampleCoding(config, inputTable.getSchema(), outputSchema); 138 | // transform the table by TF which implemented by AI-Extended 139 | Table outputTable = TFUtils 140 | .inference(streamEnv, tableEnvironment, inputTable, config, outputSchema); 141 | return outputTable; 142 | } catch (Exception e) { 143 | throw new RuntimeException(e); 144 | } 145 | } 146 | 147 | @Override 148 | public Params getParams() { 149 | return params; 150 | } 151 | } 152 | ``` 153 | 154 | 155 | 156 | ### TFEstimator 157 | 158 | The general **TFEstimator** parameter configuration process is similar to **TFModel**, but the parameters of the train process and the Inference process need to be configured at the same time. Because the **TFEstimator** needs to return an **instantiated TFModel**, most of the parameters should be the same, but the user should not be required to do that. So the ability to independently configure a set of parameters for TFEstimator and TFModel is retained. 159 | 160 | The core design of **TFEstimator** is as follows: 161 | 162 | ```java 163 | /** 164 | * A general TensorFlow estimator implemented by Flink-AI-Extended, 165 | * responsible for training and generating TensorFlow models. 
166 | */ 167 | public class TFEstimator implements Estimator, 168 | HasClusterConfig, 169 | 170 | HasTrainPythonConfig, 171 | HasInferencePythonConfig, 172 | 173 | HasTrainSelectedCols, 174 | HasTrainOutputCols, 175 | HasTrainOutputTypes, 176 | 177 | HasInferenceSelectedCols, 178 | HasInferenceOutputCols, 179 | HasInferenceOutputTypes { 180 | 181 | private Params params = new Params(); 182 | 183 | @Override 184 | public TFModel fit(TableEnvironment tableEnvironment, Table table) { 185 | StreamExecutionEnvironment streamEnv; 186 | try { 187 | // TODO: [hack] transform table to dataStream to get StreamExecutionEnvironment 188 | if (tableEnvironment instanceof StreamTableEnvironment) { 189 | StreamTableEnvironment streamTableEnvironment = 190 | (StreamTableEnvironment)tableEnvironment; 191 | streamEnv = streamTableEnvironment 192 | .toAppendStream(table, Row.class).getExecutionEnvironment(); 193 | } else { 194 | throw new RuntimeException("Unsupported TableEnvironment, please use StreamTableEnvironment"); 195 | } 196 | // Select the necessary columns according to "SelectedCols" 197 | Table inputTable = configureInputTable(table); 198 | // Construct the output schema according on the "OutputCols" and "OutputTypes" 199 | TableSchema outputSchema = configureOutputSchema(); 200 | // Create a basic TFConfig according to "ClusterConfig" and "PythonConfig" 201 | TFConfig config = configureTFConfig(); 202 | // Configure the row encoding and decoding base on input & output schema 203 | configureExampleCoding(config, inputTable.getSchema(), outputSchema); 204 | // transform the table by TF which implemented by AI-Extended 205 | Table outputTable = TFUtils.train(streamEnv, tableEnvironment, inputTable, config, outputSchema); 206 | // Construct the trained model by inference related config 207 | TFModel model = new TFModel() 208 | .setZookeeperConnStr(getZookeeperConnStr()) 209 | .setWorkerNum(getWorkerNum()) 210 | .setPsNum(getPsNum()) 211 | .setInferenceScripts(getInferenceScripts()) 212 | .setInferenceMapFunc(getInferenceMapFunc()) 213 | .setInferenceHyperParams(getInferenceHyperParams()) 214 | .setInferenceEnvPath(getInferenceEnvPath()) 215 | .setInferenceSelectedCols(getInferenceSelectedCols()) 216 | .setInferenceOutputCols(getInferenceOutputCols()) 217 | .setInferenceOutputTypes(getInferenceOutputTypes()); 218 | return model; 219 | } catch (Exception e) { 220 | throw new RuntimeException(e); 221 | } 222 | } 223 | 224 | @Override 225 | public Params getParams() { 226 | return params; 227 | } 228 | } 229 | ``` 230 | 231 | Please let me know If you think there is any problem with these designs. If not, I'd like to create a **pull request** with detailed code, documentation and testing. By the way, Flink ML pipeline is on top of Flink-1.9 and I think you may split a branch for Flink-1.9 because there are some incompatibilities between 1.8 and 1.9. -------------------------------------------------------------------------------- /src/main/python/pointer-generator/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 The TensorFlow Authors. All rights reserved. 2 | Modifications Copyright 2017 Abigail See 3 | 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 
15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. 
Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 
138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright 2017, The TensorFlow Authors. 194 | Modifications Copyright 2017 Abigail See 195 | 196 | Licensed under the Apache License, Version 2.0 (the "License"); 197 | you may not use this file except in compliance with the License. 
198 | You may obtain a copy of the License at 199 | 200 | http://www.apache.org/licenses/LICENSE-2.0 201 | 202 | Unless required by applicable law or agreed to in writing, software 203 | distributed under the License is distributed on an "AS IS" BASIS, 204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 205 | See the License for the specific language governing permissions and 206 | limitations under the License. 207 | -------------------------------------------------------------------------------- /doc/github/[Issue] The streaming inference result needs to wait until the next query to write to sink.md: -------------------------------------------------------------------------------- 1 | ## [[Issue](https://github.com/alibaba/flink-ai-extended/issues/11)] The streaming inference result needs to wait until the next query to write to sink 2 | 3 | In the stream processing environment, after reading data from the source and processing it through Flink-AI-Extended, the result is **not immediately written to the sink**; it is not written **until the next record** from the source arrives. 4 | 5 | I built a simple demo for stream processing. The source injects a message every 5 seconds, 25 in total. The Python side writes each record back to the sink immediately after reading it. 6 | When I inject a message into the source, the log shows that the Python process has received it and executed "context.output_writer_op", but the sink does not receive any message. Only when I inject the next message into the source is the previous result written to the sink. 7 | 8 | The following is the log: 9 | 10 | ```java 11 | ... 12 | [Source][2019-07-31 11:45:56.76]produce data-10 13 | [Sink][2019-07-31 11:45:56.76]finish data-9 14 | 15 | [Source][2019-07-31 11:46:01.765]produce data-11 16 | [Sink][2019-07-31 11:46:01.765]finish data-10 17 | ... 18 | ``` 19 | 20 | But I expect the result to be written back to the sink immediately after "output_writer_op" is executed: 21 | 22 | ```java 23 | ... 24 | [Source][2019-07-31 11:45:56.76]produce data-10 25 | [Sink][2019-07-31 11:45:56.76]finish data-10 26 | 27 | [Source][2019-07-31 11:46:01.765]produce data-11 28 | [Sink][2019-07-31 11:46:01.765]finish data-11 29 | ... 30 | ``` 31 | 32 | For the time being, it is not clear what causes this behavior.
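To make the behavior easier to follow, here is a condensed sketch of the pipeline wiring, extracted from the full demo code that follows. It is not runnable on its own; `streamEnv`, `config`, `outputSchema`, and `DummyTimedSource` are defined in the full listing below.

```java
// Condensed from the SourceSinkTest demo below; the helper classes, the TFConfig
// setup, and the ExampleCoding configuration are omitted here and shown in full
// in the listing that follows.
DataStream sourceStream = streamEnv.addSource(
        new DummyTimedSource(20, 5), new RowTypeInfo(Types.STRING)).setParallelism(1);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv);
Table input = tableEnv.fromDataStream(sourceStream, "input");

// Each record should reach the printing sink as soon as the Python side has
// called output_writer_op, but in practice it only shows up after the next
// source record has been produced.
Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema);
tableEnv.toAppendStream(output, Row.class)
        .map(r -> "[Sink][" + new Timestamp(System.currentTimeMillis()) + "]finish " + r.getField(0) + "\n")
        .print().setParallelism(1);
```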
33 | 34 | The following is my demo code: 35 | 36 | ```java 37 | package org.apache.flink.table.ml.lib.tensorflow; 38 | 39 | import com.alibaba.flink.ml.operator.util.DataTypes; 40 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 41 | import com.alibaba.flink.ml.tensorflow.client.TFUtils; 42 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCoding; 43 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 44 | import com.alibaba.flink.ml.tensorflow.util.TFConstants; 45 | import com.alibaba.flink.ml.util.MLConstants; 46 | import org.apache.curator.test.TestingServer; 47 | import org.apache.flink.api.common.restartstrategy.RestartStrategies; 48 | import org.apache.flink.api.common.state.ListState; 49 | import org.apache.flink.api.common.state.ListStateDescriptor; 50 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; 51 | import org.apache.flink.api.common.typeinfo.TypeInformation; 52 | import org.apache.flink.api.common.typeinfo.Types; 53 | import org.apache.flink.api.java.typeutils.RowTypeInfo; 54 | import org.apache.flink.runtime.state.FunctionInitializationContext; 55 | import org.apache.flink.runtime.state.FunctionSnapshotContext; 56 | import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; 57 | import org.apache.flink.streaming.api.datastream.DataStream; 58 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 59 | import org.apache.flink.streaming.api.functions.source.SourceFunction; 60 | import org.apache.flink.table.api.Table; 61 | import org.apache.flink.table.api.TableSchema; 62 | import org.apache.flink.table.api.java.StreamTableEnvironment; 63 | import org.apache.flink.table.ml.lib.tensorflow.util.Utils; 64 | import org.apache.flink.types.Row; 65 | import org.junit.Test; 66 | import org.slf4j.Logger; 67 | import org.slf4j.LoggerFactory; 68 | 69 | import java.sql.Timestamp; 70 | import java.util.HashMap; 71 | import java.util.Map; 72 | 73 | public class SourceSinkTest { 74 | private static final String ZookeeperConn = "127.0.0.1:2181"; 75 | private static final String[] Scripts = {"/Users/bodeng/TextSummarization-On-Flink/src/main/python/pointer-generator/test.py"}; 76 | private static final int WorkerNum = 1; 77 | private static final int PsNum = 0; 78 | 79 | @Test 80 | public void testSourceSink() throws Exception { 81 | TestingServer server = new TestingServer(2181, true); 82 | StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.createLocalEnvironment(1); 83 | streamEnv.setRestartStrategy(RestartStrategies.noRestart()); 84 | 85 | DataStream sourceStream = streamEnv.addSource( 86 | new DummyTimedSource(20, 5), new RowTypeInfo(Types.STRING)).setParallelism(1); 87 | StreamTableEnvironment tableEnv = StreamTableEnvironment.create(streamEnv); 88 | Table input = tableEnv.fromDataStream(sourceStream, "input"); 89 | TFConfig config = createTFConfig("test_source_sink"); 90 | 91 | TableSchema outputSchema = new TableSchema(new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO}); 92 | 93 | // configure encode coding 94 | String strInput = ExampleCodingConfig.createExampleConfigStr( 95 | new String[]{"input"}, new DataTypes[]{DataTypes.STRING}, 96 | ExampleCodingConfig.ObjectType.ROW, Row.class); 97 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 98 | config.getProperties().put(MLConstants.ENCODING_CLASS, 99 | ExampleCoding.class.getCanonicalName()); 100 | 101 | // configure decode coding 102 | String strOutput = 
ExampleCodingConfig.createExampleConfigStr( 103 | new String[]{"output"}, new DataTypes[]{DataTypes.STRING}, 104 | ExampleCodingConfig.ObjectType.ROW, Row.class); 105 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput); 106 | config.getProperties().put(MLConstants.DECODING_CLASS, 107 | ExampleCoding.class.getCanonicalName()); 108 | 109 | Table output = TFUtils.inference(streamEnv, tableEnv, input, config, outputSchema); 110 | tableEnv.toAppendStream(output, Row.class) 111 | .map(r -> "[Sink][" + new Timestamp(System.currentTimeMillis()) + "]finish " + r.getField(0) + "\n") 112 | .print().setParallelism(1); 113 | 114 | streamEnv.execute(); 115 | server.stop(); 116 | } 117 | 118 | private TFConfig createTFConfig(String mapFunc) { 119 | Map prop = new HashMap<>(); 120 | prop.put(MLConstants.CONFIG_STORAGE_TYPE, MLConstants.STORAGE_ZOOKEEPER); 121 | prop.put(MLConstants.CONFIG_ZOOKEEPER_CONNECT_STR, ZookeeperConn); 122 | return new TFConfig(WorkerNum, PsNum, prop, Scripts, mapFunc, null); 123 | } 124 | 125 | private static class DummyTimedSource implements SourceFunction, CheckpointedFunction { 126 | public static final Logger LOG = LoggerFactory.getLogger(DummyTimedSource.class); 127 | private long count = 0L; 128 | private long MAX_COUNT; 129 | private long INTERVAL; 130 | private volatile boolean isRunning = true; 131 | 132 | private transient ListState checkpointedCount; 133 | 134 | public DummyTimedSource(long maxCount, long interval) { 135 | this.MAX_COUNT = maxCount; 136 | this.INTERVAL = interval; 137 | } 138 | 139 | @Override 140 | public void run(SourceContext ctx) throws Exception { 141 | while (isRunning && count < MAX_COUNT) { 142 | // this synchronized block ensures that state checkpointing, 143 | // internal state updates and emission of elements are an atomic operation 144 | synchronized (ctx.getCheckpointLock()) { 145 | Row row = new Row(1); 146 | row.setField(0, String.format("data-%d", count)); 147 | System.out.println("[Source][" + new Timestamp(System.currentTimeMillis()) + "]produce " + row.getField(0)); 148 | ctx.collect(row); 149 | count++; 150 | Thread.sleep(INTERVAL * 1000); 151 | } 152 | } 153 | } 154 | 155 | @Override 156 | public void cancel() { 157 | isRunning = false; 158 | } 159 | 160 | @Override 161 | public void snapshotState(FunctionSnapshotContext context) throws Exception { 162 | this.checkpointedCount.clear(); 163 | this.checkpointedCount.add(count); 164 | } 165 | 166 | @Override 167 | public void initializeState(FunctionInitializationContext context) throws Exception { 168 | this.checkpointedCount = context 169 | .getOperatorStateStore() 170 | .getListState(new ListStateDescriptor<>("count", Long.class)); 171 | 172 | if (context.isRestored()) { 173 | for (Long count : this.checkpointedCount.get()) { 174 | this.count = count; 175 | } 176 | } 177 | } 178 | } 179 | } 180 | 181 | ``` 182 | 183 | and python code: 184 | 185 | ```python 186 | import sys 187 | import datetime 188 | 189 | import tensorflow as tf 190 | from flink_ml_tensorflow.tensorflow_context import TFContext 191 | 192 | 193 | class FlinkReader(object): 194 | def __init__(self, context, batch_size=1, features={'input': tf.FixedLenFeature([], tf.string)}): 195 | self._context = context 196 | self._batch_size = batch_size 197 | self._features = features 198 | self._build_graph() 199 | 200 | def _decode(self, features): 201 | return features['input'] 202 | 203 | def _build_graph(self): 204 | dataset = self._context.flink_stream_dataset() 205 | dataset = dataset.map(lambda 
record: tf.parse_single_example(record, features=self._features)) 206 | dataset = dataset.map(self._decode) 207 | dataset = dataset.batch(self._batch_size) 208 | iterator = dataset.make_one_shot_iterator() 209 | self._next_batch = iterator.get_next() 210 | 211 | def next_batch(self, sess): 212 | try: 213 | batch = sess.run(self._next_batch) 214 | return batch 215 | except tf.errors.OutOfRangeError: 216 | return None 217 | 218 | 219 | class FlinkWriter(object): 220 | def __init__(self, context): 221 | self._context = context 222 | self._build_graph() 223 | 224 | def _build_graph(self): 225 | self._write_feed = tf.placeholder(dtype=tf.string) 226 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed]) 227 | 228 | def _example(self, results): 229 | example = tf.train.Example(features=tf.train.Features( 230 | feature={ 231 | 'output': tf.train.Feature(bytes_list=tf.train.BytesList(value=[results[0]])), 232 | } 233 | )) 234 | return example 235 | 236 | def write_result(self, sess, results): 237 | sess.run(self.write_op, feed_dict={self._write_feed: self._example(results).SerializeToString()}) 238 | 239 | def close(self, sess): 240 | sess.run(self._close_op) 241 | 242 | 243 | 244 | def test_source_sink(context): 245 | tf_context = TFContext(context) 246 | if 'ps' == tf_context.get_role_name(): 247 | from time import sleep 248 | while True: 249 | sleep(1) 250 | else: 251 | index = tf_context.get_index() 252 | job_name = tf_context.get_role_name() 253 | cluster_json = tf_context.get_tf_cluster() 254 | cluster = tf.train.ClusterSpec(cluster=cluster_json) 255 | 256 | server = tf.train.Server(cluster, job_name=job_name, task_index=index) 257 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 258 | device_filters=["/job:ps", "/job:worker/task:%d" % index]) 259 | with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:' + str(index), cluster=cluster)): 260 | reader = FlinkReader(tf_context) 261 | writer = FlinkWriter(tf_context) 262 | 263 | with tf.train.ChiefSessionCreator(master=server.target, config=sess_config).create_session() as sess: 264 | while True: 265 | batch = reader.next_batch(sess) 266 | if batch is None: 267 | break 268 | # tf.logging.info("[TF][%s]process %s" % (str(datetime.datetime.now()), str(batch))) 269 | 270 | writer.write_result(sess, batch) 271 | writer.close(sess) 272 | sys.stdout.flush() 273 | ``` 274 | 275 | -------------------------------------------------------------------------------- /src/main/java/org/apache/flink/table/ml/lib/tensorflow/util/CodingUtils.java: -------------------------------------------------------------------------------- 1 | package org.apache.flink.table.ml.lib.tensorflow.util; 2 | 3 | import com.alibaba.flink.ml.operator.util.DataTypes; 4 | import com.alibaba.flink.ml.tensorflow.client.TFConfig; 5 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCoding; 6 | import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig; 7 | import com.alibaba.flink.ml.tensorflow.util.TFConstants; 8 | import com.alibaba.flink.ml.util.MLConstants; 9 | import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; 10 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; 11 | import org.apache.flink.api.common.typeinfo.TypeInformation; 12 | import org.apache.flink.table.api.TableSchema; 13 | import org.slf4j.Logger; 14 | import org.slf4j.LoggerFactory; 15 | 16 | import java.util.Arrays; 17 | 18 | public class CodingUtils { 19 | private static Logger LOG = 
LoggerFactory.getLogger(CodingUtils.class); 20 | 21 | /** 22 | * Map DataTypes list to TypeInformation list 23 | * @throws RuntimeException when meet unsupported type of DataTypes 24 | */ 25 | public static TypeInformation[] dataTypesListToTypeInformation(DataTypes[] dataTypes) throws RuntimeException { 26 | TypeInformation[] ret = new TypeInformation[dataTypes.length]; 27 | for (int i = 0; i < dataTypes.length; i++) { 28 | ret[i] = dataTypesToTypeInformation(dataTypes[i]); 29 | } 30 | return ret; 31 | } 32 | 33 | /** 34 | * Map DataTypes class to TypeInformation 35 | * @throws RuntimeException when meet unsupported type of DataTypes 36 | */ 37 | public static TypeInformation dataTypesToTypeInformation(DataTypes dataTypes) throws RuntimeException { 38 | if (dataTypes == DataTypes.STRING) { 39 | return BasicTypeInfo.STRING_TYPE_INFO; 40 | } else if (dataTypes == DataTypes.BOOL) { 41 | return BasicTypeInfo.BOOLEAN_TYPE_INFO; 42 | } else if (dataTypes == DataTypes.INT_8) { 43 | return BasicTypeInfo.BYTE_TYPE_INFO; 44 | } else if (dataTypes == DataTypes.INT_16) { 45 | return BasicTypeInfo.SHORT_TYPE_INFO; 46 | } else if (dataTypes == DataTypes.INT_32) { 47 | return BasicTypeInfo.INT_TYPE_INFO; 48 | } else if (dataTypes == DataTypes.INT_64) { 49 | return BasicTypeInfo.LONG_TYPE_INFO; 50 | } else if (dataTypes == DataTypes.FLOAT_32) { 51 | return BasicTypeInfo.FLOAT_TYPE_INFO; 52 | } else if (dataTypes == DataTypes.FLOAT_64) { 53 | return BasicTypeInfo.DOUBLE_TYPE_INFO; 54 | } else if (dataTypes == DataTypes.UINT_16) { 55 | return BasicTypeInfo.CHAR_TYPE_INFO; 56 | } else if (dataTypes == DataTypes.FLOAT_32_ARRAY) { 57 | return BasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO; 58 | } else { 59 | throw new RuntimeException("Unsupported data type of " + dataTypes.toString()); 60 | } 61 | } 62 | 63 | /** 64 | * Map TypeInformation list to DataTypes list 65 | * @throws RuntimeException when meet unsupported type of TypeInformation 66 | */ 67 | public static DataTypes[] typeInormationListToDataTypes(TypeInformation[] typeInformation) throws RuntimeException { 68 | DataTypes[] ret = new DataTypes[typeInformation.length]; 69 | for (int i = 0; i < typeInformation.length; i++) { 70 | ret[i] = typeInformationToDataTypes(typeInformation[i]); 71 | } 72 | return ret; 73 | } 74 | 75 | /** 76 | * Map TypeInformation class to DataTypes 77 | * @throws RuntimeException when meet unsupported type of TypeInformation 78 | */ 79 | public static DataTypes typeInformationToDataTypes(TypeInformation typeInformation) throws RuntimeException { 80 | if (typeInformation == BasicTypeInfo.STRING_TYPE_INFO) { 81 | return DataTypes.STRING; 82 | } else if (typeInformation == BasicTypeInfo.BOOLEAN_TYPE_INFO) { 83 | return DataTypes.BOOL; 84 | } else if (typeInformation == BasicTypeInfo.BYTE_TYPE_INFO) { 85 | return DataTypes.INT_8; 86 | } else if (typeInformation == BasicTypeInfo.SHORT_TYPE_INFO) { 87 | return DataTypes.INT_16; 88 | } else if (typeInformation == BasicTypeInfo.INT_TYPE_INFO) { 89 | return DataTypes.INT_32; 90 | } else if (typeInformation == BasicTypeInfo.LONG_TYPE_INFO) { 91 | return DataTypes.INT_64; 92 | } else if (typeInformation == BasicTypeInfo.FLOAT_TYPE_INFO) { 93 | return DataTypes.FLOAT_32; 94 | } else if (typeInformation == BasicTypeInfo.DOUBLE_TYPE_INFO) { 95 | return DataTypes.FLOAT_64; 96 | } else if (typeInformation == BasicTypeInfo.CHAR_TYPE_INFO) { 97 | return DataTypes.UINT_16; 98 | } else if (typeInformation == BasicTypeInfo.DATE_TYPE_INFO) { 99 | throw new RuntimeException("Unsupported data type of " + 
typeInformation.toString()); 100 | } else if (typeInformation == BasicTypeInfo.VOID_TYPE_INFO) { 101 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 102 | } else if (typeInformation == BasicTypeInfo.BIG_INT_TYPE_INFO) { 103 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 104 | } else if (typeInformation == BasicTypeInfo.BIG_DEC_TYPE_INFO) { 105 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 106 | } else if (typeInformation == BasicTypeInfo.INSTANT_TYPE_INFO) { 107 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 108 | } else if (typeInformation == BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO) { 109 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 110 | } else if (typeInformation == BasicArrayTypeInfo.BOOLEAN_ARRAY_TYPE_INFO) { 111 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 112 | } else if (typeInformation == BasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO) { 113 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 114 | } else if (typeInformation == BasicArrayTypeInfo.SHORT_ARRAY_TYPE_INFO) { 115 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 116 | } else if (typeInformation == BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO) { 117 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 118 | } else if (typeInformation == BasicArrayTypeInfo.LONG_ARRAY_TYPE_INFO) { 119 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 120 | } else if (typeInformation == BasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO) { 121 | return DataTypes.FLOAT_32_ARRAY; 122 | } else if (typeInformation == BasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO) { 123 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 124 | } else if (typeInformation == BasicArrayTypeInfo.CHAR_ARRAY_TYPE_INFO) { 125 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 126 | } else { 127 | throw new RuntimeException("Unsupported data type of " + typeInformation.toString()); 128 | } 129 | } 130 | 131 | public static void configureEncodeExampleCoding(TFConfig config, String[] encodeNames, DataTypes[] encodeTypes, 132 | ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 133 | String strInput = ExampleCodingConfig.createExampleConfigStr(encodeNames, encodeTypes, entryType, entryClass); 134 | LOG.info("input tf example config: " + strInput); 135 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput); 136 | config.getProperties().put(MLConstants.ENCODING_CLASS, ExampleCoding.class.getCanonicalName()); 137 | } 138 | 139 | public static void configureDecodeExampleCoding(TFConfig config, String[] decodeNames, DataTypes[] decodeTypes, 140 | ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 141 | String strOutput = ExampleCodingConfig.createExampleConfigStr(decodeNames, decodeTypes, entryType, entryClass); 142 | LOG.info("output tf example config: " + strOutput); 143 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput); 144 | config.getProperties().put(MLConstants.DECODING_CLASS, ExampleCoding.class.getCanonicalName()); 145 | } 146 | 147 | // public static void configureExampleCoding(TFConfig config, String[] 
encodeNames, DataTypes[] encodeTypes, 148 | // String[] decodeNames, DataTypes[] decodeTypes, 149 | // ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 150 | // configureEncodeExampleCoding(config, encodeNames, encodeTypes, entryType, entryClass); 151 | // configureDecodeExampleCoding(config, decodeNames, decodeTypes, entryType, entryClass); 152 | // } 153 | 154 | // public static void configureExampleCoding(TFConfig config, TableSchema encodeSchema, 155 | // String[] decodeNames, DataTypes[] decodeTypes, 156 | // ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 157 | // String[] encodeNames = encodeSchema.getFieldNames(); 158 | // TypeInformation[] encodeTypes = encodeSchema.getFieldTypes(); 159 | // DataTypes[] encodeDataTypes = Arrays 160 | // .stream(encodeTypes) 161 | // .map(CodingUtils::typeInformationToDataTypes) 162 | // .toArray(DataTypes[]::new); 163 | // configureExampleCoding(config, encodeNames, encodeDataTypes, decodeNames, decodeTypes, entryType, entryClass); 164 | // } 165 | // 166 | // public static void configureExampleCoding(TFConfig config, String[] encodeNames, DataTypes[] encodeTypes, 167 | // TableSchema decodeSchema, 168 | // ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 169 | // String[] decodeNames = decodeSchema.getFieldNames(); 170 | // TypeInformation[] decodeTypes = decodeSchema.getFieldTypes(); 171 | // DataTypes[] decodeDataTypes = Arrays 172 | // .stream(decodeTypes) 173 | // .map(CodingUtils::typeInformationToDataTypes) 174 | // .toArray(DataTypes[]::new); 175 | // configureExampleCoding(config, encodeNames, encodeTypes, decodeNames, decodeDataTypes, entryType, entryClass); 176 | // } 177 | 178 | public static void configureEncodeExampleCoding(TFConfig config, String[] encodeNames, TypeInformation[] encodeTypes, 179 | ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 180 | DataTypes[] encodeDataTypes = Arrays 181 | .stream(encodeTypes) 182 | .map(CodingUtils::typeInformationToDataTypes) 183 | .toArray(DataTypes[]::new); 184 | configureEncodeExampleCoding(config, encodeNames, encodeDataTypes, entryType, entryClass); 185 | } 186 | 187 | public static void configureDecodeExampleCoding(TFConfig config, String[] decodeNames, TypeInformation[] decodeTypes, 188 | ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 189 | DataTypes[] decodeDataTypes = Arrays 190 | .stream(decodeTypes) 191 | .map(CodingUtils::typeInformationToDataTypes) 192 | .toArray(DataTypes[]::new); 193 | configureDecodeExampleCoding(config, decodeNames, decodeDataTypes, entryType, entryClass); 194 | } 195 | 196 | public static void configureExampleCoding(TFConfig config, TableSchema encodeSchema, TableSchema decodeSchema, 197 | ExampleCodingConfig.ObjectType entryType, Class entryClass) throws RuntimeException { 198 | if (encodeSchema != null) { 199 | configureEncodeExampleCoding(config, encodeSchema.getFieldNames(), encodeSchema.getFieldTypes(), 200 | entryType, entryClass); 201 | } 202 | if (decodeSchema != null) { 203 | configureDecodeExampleCoding(config, decodeSchema.getFieldNames(), decodeSchema.getFieldTypes(), 204 | entryType, entryClass); 205 | } 206 | } 207 | } 208 | --------------------------------------------------------------------------------
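For reference, the following is a short, hypothetical usage sketch of `CodingUtils`; the class `CodingUtilsExample` and its method name are illustrative only and not part of the repository. It performs in one call the same ExampleCoding configuration that the `SourceSinkTest` demo above wires up by hand via `TFConstants.INPUT_TF_EXAMPLE_CONFIG` and `TFConstants.OUTPUT_TF_EXAMPLE_CONFIG`.

```java
package org.apache.flink.table.ml.lib.tensorflow.util;

import com.alibaba.flink.ml.tensorflow.client.TFConfig;
import com.alibaba.flink.ml.tensorflow.coding.ExampleCodingConfig;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/** Hypothetical example class, not part of the repository. */
public class CodingUtilsExample {

    /** Configures a string "input" column (encode) and a string "output" column (decode) on the given TFConfig. */
    public static void configureStringInStringOut(TFConfig config) {
        TableSchema inputSchema = new TableSchema(
                new String[]{"input"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO});
        TableSchema outputSchema = new TableSchema(
                new String[]{"output"}, new TypeInformation[]{BasicTypeInfo.STRING_TYPE_INFO});

        // One call sets INPUT_TF_EXAMPLE_CONFIG/ENCODING_CLASS and
        // OUTPUT_TF_EXAMPLE_CONFIG/DECODING_CLASS, replacing the manual
        // property wiring shown in the SourceSinkTest demo above.
        CodingUtils.configureExampleCoding(config, inputSchema, outputSchema,
                ExampleCodingConfig.ObjectType.ROW, Row.class);
    }
}
```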