├── src
│   ├── main
│   │   ├── python
│   │   │   └── pointer-generator
│   │   │       ├── __init__.py
│   │   │       ├── run_train.sh
│   │   │       ├── run_inference.sh
│   │   │       ├── inspect_checkpoint.py
│   │   │       ├── flink_writer.py
│   │   │       ├── util.py
│   │   │       ├── train.py
│   │   │       ├── beam_search.py
│   │   │       ├── README.md
│   │   │       └── LICENSE.txt
│   │   ├── java
│   │   │   ├── me
│   │   │   │   └── littlebo
│   │   │   │       ├── SysUtils.java
│   │   │   │       ├── MessageSerializationSchema.java
│   │   │   │       ├── MessageDeserializationSchema.java
│   │   │   │       ├── Message.java
│   │   │   │       ├── Summarization.java
│   │   │   │       └── App.java
│   │   │   └── org
│   │   │       └── apache
│   │   │           └── flink
│   │   │               └── table
│   │   │                   └── ml
│   │   │                       └── lib
│   │   │                           └── tensorflow
│   │   │                               ├── param
│   │   │                               │   ├── HasTrainOutputCols.java
│   │   │                               │   ├── HasTrainSelectedCols.java
│   │   │                               │   ├── HasInferenceOutputCols.java
│   │   │                               │   ├── HasInferenceSelectedCols.java
│   │   │                               │   ├── HasTrainOutputTypes.java
│   │   │                               │   ├── HasInferenceOutputTypes.java
│   │   │                               │   ├── HasClusterConfig.java
│   │   │                               │   ├── HasTrainPythonConfig.java
│   │   │                               │   └── HasInferencePythonConfig.java
│   │   │                               ├── TFModel.java
│   │   │                               ├── TFEstimator.java
│   │   │                               └── util
│   │   │                                   └── CodingUtils.java
│   │   └── resources
│   │       └── log4j2.xml
│   └── test
│       ├── java
│       │   └── org
│       │       └── apache
│       │           └── flink
│       │               └── table
│       │                   └── ml
│       │                       └── lib
│       │                           └── tensorflow
│       │                               ├── KafkaSourceSinkTest.java
│       │                               ├── SourceSinkTest.java
│       │                               └── InputOutputTest.java
│       └── python
│           └── test.py
├── doc
│   ├── design.bmp
│   ├── design.pptx
│   ├── github
│   │   ├── Issue6
│   │   │   ├── Issue6.bmp
│   │   │   ├── Issue6.pptx
│   │   │   └── Issue6_PR.bmp
│   │   ├── [Issue] Simplify the ExampleCoding configuration process.md
│   │   ├── [Issue] Exception when there is only InputTfExampleConfig but no OutputTfExampleConfig.md
│   │   ├── Comment to [Issue 2] flink-ai-extended adapter flink ml pipeline.md
│   │   └── [Issue] The streaming inference result needs to wait until the next query to write to sink.md
│   └── deprecated
│       ├── Flink-AI-Extended Integration design.md
│       └── About StreamExeEnv in AI-Extended Issue.md
├── TextSummarization-On-Flink.iml
├── .gitignore
├── data
│   └── cnn-dailymail
│       ├── download_data.sh
│       ├── LICENSE.md
│       ├── README.md
│       └── make_datafiles.py
├── log
│   └── download_model.sh
├── pom.xml
└── README.md
/src/main/python/pointer-generator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/doc/design.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/design.bmp
--------------------------------------------------------------------------------
/doc/design.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/design.pptx
--------------------------------------------------------------------------------
/doc/github/Issue6/Issue6.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6.bmp
--------------------------------------------------------------------------------
/doc/github/Issue6/Issue6.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6.pptx
--------------------------------------------------------------------------------
/doc/github/Issue6/Issue6_PR.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LittleBBBo/TextSummarization-On-Flink/HEAD/doc/github/Issue6/Issue6_PR.bmp
--------------------------------------------------------------------------------
/src/main/java/me/littlebo/SysUtils.java:
--------------------------------------------------------------------------------
1 | package me.littlebo;
2 |
3 | public class SysUtils {
4 | public static String getProjectRootDir() {
5 | return System.getProperty("user.dir");
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/TextSummarization-On-Flink.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 |
3 | /data/cnn-dailymail/**
4 | /data/raw
5 | !/data/cnn-dailymail/LICENSE.md
6 | !/data/cnn-dailymail/make_datafiles.py
7 | !/data/cnn-dailymail/README.md
8 | !/data/cnn-dailymail/url_lists
9 | !/data/cnn-dailymail/download_data.sh
10 |
11 | /pointer-generator/*.pyc
12 |
13 | /log/**
14 | !/log/download_model.sh
15 |
16 | /temp/**
17 | /venv/**
18 | **/.idea/**
--------------------------------------------------------------------------------
/src/main/python/pointer-generator/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_summarization.py --mode=train --data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/chunked/train_* --vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab --log_root=/Users/bodeng/TextSummarization-On-Flink/log --exp_name=pretrained_model_tf1.2.1 --max_enc_steps=400 --max_dec_steps=100 --coverage=1
3 |
--------------------------------------------------------------------------------
/src/main/python/pointer-generator/run_inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_summarization.py --mode=decode --data_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/cnn_stories_test/0* --vocab_path=/Users/bodeng/TextSummarization-On-Flink/data/cnn-dailymail/finished_files/vocab --log_root=/Users/bodeng/TextSummarization-On-Flink/log --exp_name=pretrained_model_tf1.2.1 --max_enc_steps=400 --max_dec_steps=100 --coverage=1 --single_pass=1 --inference=1
3 |
--------------------------------------------------------------------------------
/data/cnn-dailymail/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # script to download processed data from Google Drive
4 | #
5 | # not guaranteed to work indefinitely
6 | # taken from Stack Overflow answer:
7 | # http://stackoverflow.com/a/38937732/7002068
8 |
9 | #gURL='https://drive.google.com/uc?id=0BzQ6rtO2VN95a0c3TlZCWkl3aU0&export=download'
10 |
11 | # match id, more than 26 word characters
12 | #ggID=$(echo "$gURL" | egrep -o '(\w|-){26,}')
13 |
14 | ggID='0BzQ6rtO2VN95a0c3TlZCWkl3aU0'
15 |
16 | ggURL='https://drive.google.com/uc?export=download'
17 | gURL="${ggURL}&id=${ggID}"
18 |
19 | curl -sc /tmp/gcokie "${ggURL}&id=${ggID}" >/dev/null
20 | getcode="$(awk '/_warning_/ {print $NF}' /tmp/gcokie)"
21 |
22 | cmd='curl --insecure -C - -LOJb /tmp/gcokie "${ggURL}&confirm=${getcode}&id=${ggID}"'
23 | echo -e "Downloading from ${gURL}...\n"
24 | eval $cmd
25 |
26 | # unzip data file
27 | unzip finished_files.zip
28 | rm finished_files.zip
--------------------------------------------------------------------------------
/log/download_model.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # script to download pretrained model from Google Drive
4 | #
5 | # not guaranteed to work indefinitely
6 | # taken from Stack Overflow answer:
7 | # http://stackoverflow.com/a/38937732/7002068
8 |
9 | #gURL='https://drive.google.com/uc?id=0B7pQmm-OfDv7ZUhHZm9ZWEZidDg&export=download'
10 |
11 | # match id, more than 26 word characters
12 | #ggID=$(echo "$gURL" | egrep -o '(\w|-){26,}')
13 |
14 | ggID='0B7pQmm-OfDv7ZUhHZm9ZWEZidDg'
15 |
16 | ggURL='https://drive.google.com/uc?export=download'
17 | gURL="${ggURL}&id=${ggID}"
18 |
19 | curl -sc /tmp/gcokie "${ggURL}&id=${ggID}" >/dev/null
20 | getcode="$(awk '/_warning_/ {print $NF}' /tmp/gcokie)"
21 |
22 | cmd='curl --insecure -C - -LOJb /tmp/gcokie "${ggURL}&confirm=${getcode}&id=${ggID}"'
23 | echo -e "Downloading from ${gURL}...\n"
24 | eval $cmd
25 |
26 | # unzip data file
27 | unzip pretrained_model_tf1.2.1.zip
28 | rm pretrained_model_tf1.2.1.zip
--------------------------------------------------------------------------------
/src/main/resources/log4j2.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainOutputCols.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import org.apache.flink.ml.api.misc.param.ParamInfo;
4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
5 | import org.apache.flink.ml.api.misc.param.WithParams;
6 |
7 | /**
8 | * An interface for classes with a parameter specifying the names of multiple output columns.
9 | * @param <T> the actual type of this WithParams, as the return type of setter
10 | */
11 | public interface HasTrainOutputCols<T> extends WithParams<T> {
12 | ParamInfo<String[]> TRAIN_OUTPUT_COLS = ParamInfoFactory
13 | .createParamInfo("trainOutputCols", String[].class)
14 | .setDescription("Names of the output columns for train processing")
15 | .setRequired()
16 | .build();
17 |
18 | default String[] getTrainOutputCols() {
19 | return get(TRAIN_OUTPUT_COLS);
20 | }
21 |
22 | default T setTrainOutputCols(String... value) {
23 | return set(TRAIN_OUTPUT_COLS, value);
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
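The param interfaces in this package all follow the same self-typed mixin pattern: the type parameter `T` is the concrete implementing class, so every setter can return `T` and calls chain fluently. A minimal sketch of how such interfaces compose; the `SummarizerParams` class and the column names are illustrative assumptions, not part of this repo:

```java
import org.apache.flink.ml.api.misc.param.Params;

// Illustrative params holder mixing in two of the interfaces above.
public class SummarizerParams implements
        HasTrainSelectedCols<SummarizerParams>,
        HasTrainOutputCols<SummarizerParams> {

    private final Params params = new Params();

    // WithParams requires access to the backing Params storage.
    @Override
    public Params getParams() {
        return params;
    }

    public static void main(String[] args) {
        // Setters return SummarizerParams, so configuration chains fluently.
        SummarizerParams p = new SummarizerParams()
                .setTrainSelectedCols("article", "reference")
                .setTrainOutputCols("summary");
        System.out.println(String.join(",", p.getTrainOutputCols()));
    }
}
```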
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainSelectedCols.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import org.apache.flink.ml.api.misc.param.ParamInfo;
4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
5 | import org.apache.flink.ml.api.misc.param.WithParams;
6 |
7 | /**
8 | * An interface for classes with a parameter specifying the names of multiple selected input columns.
9 | * @param <T> the actual type of this WithParams, as the return type of setter
10 | */
11 | public interface HasTrainSelectedCols<T> extends WithParams<T> {
12 | ParamInfo<String[]> TRAIN_SELECTED_COLS = ParamInfoFactory
13 | .createParamInfo("trainSelectedCols", String[].class)
14 | .setDescription("Names of the columns used for train processing")
15 | .setRequired()
16 | .build();
17 |
18 | default String[] getTrainSelectedCols() {
19 | return get(TRAIN_SELECTED_COLS);
20 | }
21 |
22 | default T setTrainSelectedCols(String... value) {
23 | return set(TRAIN_SELECTED_COLS, value);
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceOutputCols.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import org.apache.flink.ml.api.misc.param.ParamInfo;
4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
5 | import org.apache.flink.ml.api.misc.param.WithParams;
6 |
7 | /**
8 | * An interface for classes with a parameter specifying the names of multiple output columns.
9 | * @param <T> the actual type of this WithParams, as the return type of setter
10 | */
11 | public interface HasInferenceOutputCols<T> extends WithParams<T> {
12 | ParamInfo<String[]> INFERENCE_OUTPUT_COLS = ParamInfoFactory
13 | .createParamInfo("inferenceOutputCols", String[].class)
14 | .setDescription("Names of the output columns for inference processing")
15 | .setRequired()
16 | .build();
17 |
18 | default String[] getInferenceOutputCols() {
19 | return get(INFERENCE_OUTPUT_COLS);
20 | }
21 |
22 | default T setInferenceOutputCols(String... value) {
23 | return set(INFERENCE_OUTPUT_COLS, value);
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceSelectedCols.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import org.apache.flink.ml.api.misc.param.ParamInfo;
4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
5 | import org.apache.flink.ml.api.misc.param.WithParams;
6 |
7 | /**
8 | * An interface for classes with a parameter specifying the names of multiple selected input columns.
9 | * @param <T> the actual type of this WithParams, as the return type of setter
10 | */
11 | public interface HasInferenceSelectedCols<T> extends WithParams<T> {
12 | ParamInfo<String[]> INFERENCE_SELECTED_COLS = ParamInfoFactory
13 | .createParamInfo("inferenceSelectedCols", String[].class)
14 | .setDescription("Names of the columns used for inference processing")
15 | .setRequired()
16 | .build();
17 |
18 | default String[] getInferenceSelectedCols() {
19 | return get(INFERENCE_SELECTED_COLS);
20 | }
21 |
22 | default T setInferenceSelectedCols(String... value) {
23 | return set(INFERENCE_SELECTED_COLS, value);
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/data/cnn-dailymail/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Abi See
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasTrainOutputTypes.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import com.alibaba.flink.ml.operator.util.DataTypes;
4 | import org.apache.flink.ml.api.misc.param.ParamInfo;
5 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
6 | import org.apache.flink.ml.api.misc.param.WithParams;
7 |
8 | /**
9 | * An interface for classes with a parameter specifying the types of multiple output columns.
10 | * @param <T> the actual type of this WithParams, as the return type of setter
11 | */
12 | public interface HasTrainOutputTypes<T> extends WithParams<T> {
13 | ParamInfo<DataTypes[]> TRAIN_OUTPUT_TYPES = ParamInfoFactory
14 | .createParamInfo("trainOutputTypes", DataTypes[].class)
15 | .setDescription("TypeInformation of output columns for train processing")
16 | .setRequired()
17 | .build();
18 |
19 | default DataTypes[] getTrainOutputTypes() {
20 | return get(TRAIN_OUTPUT_TYPES);
21 | }
22 |
23 | default T setTrainOutputTypes(DataTypes[] outputTypes) {
24 | return set(TRAIN_OUTPUT_TYPES, outputTypes);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasInferenceOutputTypes.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import com.alibaba.flink.ml.operator.util.DataTypes;
4 | import org.apache.flink.ml.api.misc.param.ParamInfo;
5 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
6 | import org.apache.flink.ml.api.misc.param.WithParams;
7 |
8 | /**
9 | * An interface for classes with a parameter specifying the types of multiple output columns.
10 | * @param <T> the actual type of this WithParams, as the return type of setter
11 | */
12 | public interface HasInferenceOutputTypes<T> extends WithParams<T> {
13 | ParamInfo<DataTypes[]> INFERENCE_OUTPUT_TYPES = ParamInfoFactory
14 | .createParamInfo("inferenceOutputTypes", DataTypes[].class)
15 | .setDescription("TypeInformation of output columns for inference processing")
16 | .setRequired()
17 | .build();
18 |
19 | default DataTypes[] getInferenceOutputTypes() {
20 | return get(INFERENCE_OUTPUT_TYPES);
21 | }
22 |
23 | default T setInferenceOutputTypes(DataTypes[] outputTypes) {
24 | return set(INFERENCE_OUTPUT_TYPES, outputTypes);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/me/littlebo/MessageSerializationSchema.java:
--------------------------------------------------------------------------------
1 | package me.littlebo;
2 |
3 | import org.apache.flink.api.common.serialization.SerializationSchema;
4 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
5 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
6 | import org.apache.flink.types.Row;
7 | import org.slf4j.Logger;
8 | import org.slf4j.LoggerFactory;
9 |
10 | public class MessageSerializationSchema implements SerializationSchema<Row> {
11 | private ObjectMapper objectMapper;
12 | private static final Logger logger = LoggerFactory.getLogger(MessageSerializationSchema.class);
13 |
14 | @Override
15 | public byte[] serialize(Row row) {
16 | if(objectMapper == null) {
17 | objectMapper = new ObjectMapper();
18 | }
19 | try {
20 | Message message = new Message((String)row.getField(0), (String)row.getField(1),
21 | (String)row.getField(2), (String)row.getField(3));
22 | return objectMapper.writeValueAsString(message).getBytes();
23 | } catch (JsonProcessingException e) {
24 | logger.error("Failed to serialize message to JSON", e);
25 | }
26 | return new byte[0];
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/python/pointer-generator/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 | python inspect_checkpoint.py model.12345
4 | """
5 |
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 |
10 |
11 | if __name__ == '__main__':
12 | if len(sys.argv) != 2:
13 | raise Exception("Usage: python inspect_checkpoint.py <file_name>\nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 | file_name = sys.argv[1]
15 | reader = tf.train.NewCheckpointReader(file_name)
16 | var_to_shape_map = reader.get_variable_to_shape_map()
17 |
18 | finite = []
19 | all_infnan = []
20 | some_infnan = []
21 |
22 | for key in sorted(var_to_shape_map.keys()):
23 | tensor = reader.get_tensor(key)
24 | if np.all(np.isfinite(tensor)):
25 | finite.append(key)
26 | else:
27 | if not np.any(np.isfinite(tensor)):
28 | all_infnan.append(key)
29 | else:
30 | some_infnan.append(key)
31 |
32 | print("\nFINITE VARIABLES:")
33 | for key in finite: print(key)
34 |
35 | print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 | for key in all_infnan: print(key)
37 |
38 | print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 | for key in some_infnan: print(key)
40 |
41 | print("")
42 | if not all_infnan and not some_infnan:
43 | print("CHECK PASSED: checkpoint contains no inf/NaN values")
44 | else:
45 | print("CHECK FAILED: checkpoint contains some inf/NaN values")
46 |
--------------------------------------------------------------------------------
/src/main/java/me/littlebo/MessageDeserializationSchema.java:
--------------------------------------------------------------------------------
1 | package me.littlebo;
2 |
3 | import org.apache.flink.api.common.serialization.DeserializationSchema;
4 | import org.apache.flink.api.common.typeinfo.TypeInformation;
5 | import org.apache.flink.api.common.typeinfo.Types;
6 | import org.apache.flink.api.java.typeutils.RowTypeInfo;
7 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
8 | import org.apache.flink.types.Row;
9 |
10 | import java.io.IOException;
11 | import java.util.concurrent.atomic.AtomicInteger;
12 |
13 | public class MessageDeserializationSchema implements DeserializationSchema<Row> {
14 | private static ObjectMapper objectMapper = new ObjectMapper();
15 | private AtomicInteger counter = new AtomicInteger();
16 | private Integer maxCount;
17 |
18 | public MessageDeserializationSchema(int maxCount) {
19 | this.maxCount = maxCount;
20 | }
21 |
22 | @Override
23 | public Row deserialize(byte[] bytes) throws IOException {
24 | Message message = objectMapper.readValue(bytes, Message.class);
25 | Row row = new Row(4);
26 | row.setField(0, message.getUuid());
27 | row.setField(1, message.getArticle());
28 | row.setField(2, message.getSummary());
29 | row.setField(3, message.getReference());
30 | counter.incrementAndGet();
31 | return row;
32 | }
33 |
34 | @Override
35 | public boolean isEndOfStream(Row row) {
36 | if (counter.get() > maxCount) {
37 | return true;
38 | }
39 | return false;
40 | }
41 |
42 | @Override
43 | public TypeInformation<Row> getProducedType() {
44 | return new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING, Types.STRING);
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
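Together, the two schemas above let Kafka carry `Message` records as JSON on both ends of the job: the deserializer turns incoming JSON into 4-field `Row`s (and ends the stream after `maxCount` records), and the serializer does the reverse. A hedged sketch of the wiring; the topic names, properties, and the choice of the Kafka 0.11 connector are assumptions, not taken from this repo:

```java
import java.util.Properties;

import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer011;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer011;
import org.apache.flink.types.Row;

// Illustrative wiring: read at most 100 messages as Rows, write Rows back as JSON.
Properties props = new Properties();
props.setProperty("bootstrap.servers", "localhost:9092");
props.setProperty("group.id", "summarization");

FlinkKafkaConsumer011<Row> source = new FlinkKafkaConsumer011<>(
        "summarization-input", new MessageDeserializationSchema(100), props);
FlinkKafkaProducer011<Row> sink = new FlinkKafkaProducer011<>(
        "summarization-output", new MessageSerializationSchema(), props);
```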
/src/main/java/me/littlebo/Message.java:
--------------------------------------------------------------------------------
1 | package me.littlebo;
2 |
3 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
4 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize;
5 | import org.apache.flink.types.Row;
6 |
7 |
8 | import java.io.Serializable;
9 |
10 | @JsonSerialize
11 | public class Message implements Serializable {
12 | @JsonProperty("uuid")
13 | private String uuid;
14 | @JsonProperty("article")
15 | private String article;
16 | @JsonProperty("summary")
17 | private String summary;
18 | @JsonProperty("reference")
19 | private String reference;
20 |
21 | public Message() {
22 | }
23 |
24 | public Message(String uuid, String article, String summary, String reference) {
25 | this.uuid = uuid;
26 | this.article = article;
27 | this.summary = summary;
28 | this.reference = reference;
29 | }
30 |
31 | public String getUuid() {
32 | return uuid;
33 | }
34 |
35 | public String getArticle() {
36 | return article;
37 | }
38 |
39 | public String getSummary() {
40 | return summary;
41 | }
42 |
43 | public String getReference() {
44 | return reference;
45 | }
46 |
47 | public void setUuid(String uuid) {
48 | this.uuid = uuid;
49 | }
50 |
51 | public void setArticle(String article) {
52 | this.article = article;
53 | }
54 |
55 | public void setSummary(String summary) {
56 | this.summary = summary;
57 | }
58 |
59 | public void setReference(String reference) {
60 | this.reference = reference;
61 | }
62 |
63 | public static Row toRow(Message message) {
64 | Row row = new Row(4);
65 | row.setField(0, message.uuid);
66 | row.setField(1, message.article);
67 | row.setField(2, message.summary);
68 | row.setField(3, message.reference);
69 | return row;
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/flink/table/ml/lib/tensorflow/param/HasClusterConfig.java:
--------------------------------------------------------------------------------
1 | package org.apache.flink.table.ml.lib.tensorflow.param;
2 |
3 | import org.apache.flink.ml.api.misc.param.ParamInfo;
4 | import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
5 | import org.apache.flink.ml.api.misc.param.WithParams;
6 |
7 | /**
8 | * Parameters for cluster configuration, including:
9 | * 1. zookeeper address
10 | * 2. worker number
11 | * 3. ps number
12 | * @param <T> the actual type of this WithParams, as the return type of setter
13 | */
14 | public interface HasClusterConfig<T> extends WithParams<T> {
15 | ParamInfo<String> ZOOKEEPER_CONNECT_STR = ParamInfoFactory
16 | .createParamInfo("zookeeper_connect_str", String.class)
17 | .setDescription("zookeeper address to connect")
18 | .setRequired()
19 | .setHasDefaultValue("127.0.0.1:2181").build();
20 | ParamInfo<Integer> WORKER_NUM = ParamInfoFactory
21 | .createParamInfo("worker_num", Integer.class)
22 | .setDescription("worker number")
23 | .setRequired()
24 | .setHasDefaultValue(1).build();
25 | ParamInfo<Integer> PS_NUM = ParamInfoFactory
26 | .createParamInfo("ps_num", Integer.class)
27 | .setDescription("ps number")
28 | .setRequired()
29 | .setHasDefaultValue(0).build();
30 |
31 | default String getZookeeperConnStr() {
32 | return get(ZOOKEEPER_CONNECT_STR);
33 | }
34 |
35 | default T setZookeeperConnStr(String zookeeperConnStr) {
36 | return set(ZOOKEEPER_CONNECT_STR, zookeeperConnStr);
37 | }
38 |
39 | default int getWorkerNum() {
40 | return get(WORKER_NUM);
41 | }
42 |
43 | default T setWorkerNum(int workerNum) {
44 | return set(WORKER_NUM, workerNum);
45 | }
46 |
47 | default int getPsNum() {
48 | return get(PS_NUM);
49 | }
50 |
51 | default T setPsNum(int psNum) {
52 | return set(PS_NUM, psNum);
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
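The three cluster parameters above map directly onto the TFConfig that Flink-AI-Extended expects when it launches a TF cluster. A rough sketch, assuming the constructor order used in the flink-ai-extended examples (worker count, ps count, properties, python files, entry function, virtualenv path); the script name and entry function are placeholders:

```java
import java.util.HashMap;
import java.util.Map;

import com.alibaba.flink.ml.tensorflow.client.TFConfig;

// Sketch: build a TFConfig from HasClusterConfig values (inside a class
// that mixes in HasClusterConfig<T>). Script and function names are placeholders.
Map<String, String> properties = new HashMap<>();
TFConfig config = new TFConfig(
        getWorkerNum(),                        // number of TF worker nodes
        getPsNum(),                            // number of parameter servers
        properties,                            // extra job properties
        new String[]{"run_summarization.py"},  // python scripts; the first is the entry
        "map_fun",                             // entry function inside the script
        null);                                 // virtualenv path, may be null
```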
/doc/github/[Issue] Simplify the ExampleCoding configuration process.md:
--------------------------------------------------------------------------------
1 | ## [[Issue](https://github.com/alibaba/flink-ai-extended/issues/17)] Simplify the ExampleCoding configuration process
2 |
3 | The current process of configuring ExampleCoding is cumbersome. Since this is a common configuration step, I think we can add some utility interfaces to help users configure it simply.
4 | The following is the general configuration process under the current version:
5 |
6 | ```java
7 | // configure encode example coding
8 | String strInput = ExampleCodingConfig.createExampleConfigStr(encodeNames, encodeTypes,
9 | entryType, entryClass);
10 | config.getProperties().put(TFConstants.INPUT_TF_EXAMPLE_CONFIG, strInput);
11 | config.getProperties().put(MLConstants.ENCODING_CLASS,
12 | ExampleCoding.class.getCanonicalName());
13 |
14 | // configure decode example coding
15 | String strOutput = ExampleCodingConfig.createExampleConfigStr(decodeNames, decodeTypes,
16 | entryType, entryClass);
17 | config.getProperties().put(TFConstants.OUTPUT_TF_EXAMPLE_CONFIG, strOutput);
18 | config.getProperties().put(MLConstants.DECODING_CLASS,
19 | ExampleCoding.class.getCanonicalName());
20 | ```
21 |
22 | As shown above, the user needs to know the field's column names and types plus various constants when configuring example coding. In fact, this can be encapsulated so that the user only needs to provide the input and output table schemas to complete the configuration. For example:
23 |
24 | ```java
25 | ExampleCodingConfigUtil.configureExampleCoding(tfConfig, inputSchema, outputSchema,
26 | ExampleCodingConfig.ObjectType.ROW, Row.class);
27 | ```
28 |
29 | In the current version, the data types that TF can accept are defined by **DataTypes** in the Flink-AI-Extended project, while the data types of Flink Table fields are defined by **TypeInformation** in the Flink project, with basic types implemented in classes such as *BasicTypeInfo* and *BasicArrayTypeInfo*. The problem is that the basic types of **DataTypes** and **TypeInformation** do not correspond one-to-one.
30 |
31 | Therefore, if we want to encapsulate the ExampleCoding configuration process, we need to solve the problem that DataTypes is not compatible with TypeInformation. There are two options:
32 |
33 | 1. Provide a method for converting between DataTypes and TypeInformation (a sketch follows this note). Most commonly used types can be matched, but because the correspondence is incomplete, some types cannot be converted.
34 | 2. Discard DataTypes and use TypeInformation directly in Flink-AI-Extended. DataTypes is just a simple enumeration that only identifies data types; TypeInformation can provide the same functionality.
35 |
36 | Solution 1 is simpler to implement and easier to keep compatible, but solution 2 is better in the long run.
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
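To make solution 1 concrete, here is a minimal sketch of such a conversion method. The set of supported mappings is an assumption (it mirrors the common TF dtypes); everything unmatched must fail loudly, which is exactly the incompleteness the issue describes:

```java
import com.alibaba.flink.ml.operator.util.DataTypes;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;

// Sketch of solution 1: map Flink TypeInformation to Flink-AI-Extended DataTypes.
public static DataTypes typeInformationToDataTypes(TypeInformation<?> type) {
    if (BasicTypeInfo.STRING_TYPE_INFO.equals(type)) {
        return DataTypes.STRING;
    } else if (BasicTypeInfo.INT_TYPE_INFO.equals(type)) {
        return DataTypes.INT_32;
    } else if (BasicTypeInfo.LONG_TYPE_INFO.equals(type)) {
        return DataTypes.INT_64;
    } else if (BasicTypeInfo.FLOAT_TYPE_INFO.equals(type)) {
        return DataTypes.FLOAT_32;
    } else if (BasicTypeInfo.DOUBLE_TYPE_INFO.equals(type)) {
        return DataTypes.FLOAT_64;
    }
    // The types that cannot be matched are the incompatibility discussed above.
    throw new UnsupportedOperationException("No DataTypes counterpart for " + type);
}
```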
/src/main/python/pointer-generator/flink_writer.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import tensorflow as tf
4 |
5 |
6 | class AbstractWriter(object):
7 | def build_graph(self):
8 | return
9 |
10 | def write_result(self, sess, uuid, article, abstract, reference):
11 | return
12 |
13 | def close(self, sess):
14 | return
15 |
16 |
17 | class FlinkWriter(AbstractWriter):
18 | def __init__(self, tf_context):
19 | self._context = tf_context
20 |
21 | def build_graph(self):
22 | self._write_feed = tf.placeholder(dtype=tf.string)
23 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed])
24 |
25 | def write_result(self, sess, uuid, article, abstract, reference):
26 | example = tf.train.Example(features=tf.train.Features(
27 | feature={
28 | 'uuid': tf.train.Feature(bytes_list=tf.train.BytesList(value=[uuid])),
29 | 'article': tf.train.Feature(bytes_list=tf.train.BytesList(value=[article])),
30 | 'summary': tf.train.Feature(bytes_list=tf.train.BytesList(value=[abstract])),
31 | 'reference': tf.train.Feature(bytes_list=tf.train.BytesList(value=[reference])),
32 | }
33 | ))
34 | sess.run(self.write_op, feed_dict={self._write_feed: example.SerializeToString()})
35 |
36 | def close(self, sess):
37 | sess.run(self._close_op)
38 |
39 |
40 | class AbstractFlinkWriter(object):
41 | __metaclass__ = ABCMeta
42 |
43 | def __init__(self, context):
44 | """
45 | Initialize the writer
46 | :param context: TFContext
47 | """
48 | self._context = context
49 | self._build_graph()
50 |
51 | def _build_graph(self):
52 | self._write_feed = tf.placeholder(dtype=tf.string)
53 | self.write_op, self._close_op = self._context.output_writer_op([self._write_feed])
54 |
55 | @abstractmethod
56 | def _example(self, results):
57 | # TODO: Implement this function using context automatically
58 | """
59 | Encode the results tensor to `tf.train.Example`
60 |
61 | Examples:
62 | ```
63 | example = tf.train.Example(features=tf.train.Features(
64 | feature={
65 | 'predict_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[results[0]])),
66 | 'label_org': tf.train.Feature(int64_list=tf.train.Int64List(value=[results[1]])),
67 | 'abstract': tf.train.Feature(bytes_list=tf.train.BytesList(value=[results[2]])),
68 | }
69 | ))
70 | return example
71 | ```
72 |
73 | :param results: the result list of `Tensor` to write into Flink
74 | :return: An Example.
75 | """
76 | pass
77 |
78 | def write_result(self, sess, results):
79 | """
80 | Encode the results tensor and write to Flink.
81 | """
82 | sess.run(self.write_op, feed_dict={self._write_feed: self._example(results).SerializeToString()})
83 |
84 | def close(self, sess):
85 | """
86 | close writer
87 | :param sess: the session to execute operator
88 | """
89 | sess.run(self._close_op)
90 |
--------------------------------------------------------------------------------
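On the Java side, each record that `FlinkWriter` pushes through `output_writer_op` surfaces as a row of the output Table of the Flink-AI-Extended job. A hedged sketch of that counterpart, with the TFUtils call paraphrased from the flink-ai-extended examples (argument order is an assumption):

```java
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;

import com.alibaba.flink.ml.tensorflow.client.TFUtils;

// The output schema must mirror the tf.train.Example features written in
// flink_writer.py: uuid, article, summary, reference (all strings).
TableSchema outputSchema = new TableSchema(
        new String[]{"uuid", "article", "summary", "reference"},
        new TypeInformation<?>[]{Types.STRING, Types.STRING, Types.STRING, Types.STRING});

// Paraphrased call: launch the python inference job and read its output as a Table.
Table output = TFUtils.inference(streamEnv, tableEnv, inputTable, tfConfig, outputSchema);
```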
/doc/deprecated/Flink-AI-Extended Integration design.md:
--------------------------------------------------------------------------------
1 | ## Flink-AI-Extended Integration design
2 |
3 | ### TFEstimator:
4 |
5 | ```java
6 | /**
7 | * A general TensorFlow Estimator which needs a general TFConfig;
8 | * the fit() function is actually executed by Flink-AI-Extended.
9 | */
10 | public interface TFEstimator<E extends TFEstimator<E, M>, M extends TFModel<M>>
11 | extends Estimator<E, M>, WithTFConfigParams<E> {
12 | }
13 | ```
14 |
15 | ### TFModel:
16 |
17 | ```java
18 | /**
19 | * A general TensorFlow Model which needs a general TFConfig;
20 | * the transform() function is actually executed by Flink-AI-Extended.
21 | */
22 | public interface TFModel<M extends TFModel<M>>
23 | extends Model<M>, WithTFConfigParams<M> {
24 | }
25 | ```
26 |
27 | ### WithTFConfigParams:
28 |
29 | ```java
30 | /**
31 | * Calling python scripts via Flink-AI-Extended generally requires the following parameters:
32 | * 1. Zookeeper address: String
33 | * 2. Python scripts path: String[], the first file will be the entry
34 | * 3. Worker number: Int
35 | * 4. Ps number: Int
36 | * 5. Map function: String, the function in entry file to be called
37 | * 6. Virtual environment path: String, could be null
38 | */
39 | public interface WithTFConfigParams<T> extends WithZookeeperAddressParams<T>,
40 | WithPythonScriptsParams<T>,
41 | WithWorkerNumParams<T>,
42 | WithPsNumParams<T>,
43 | WithMapFunctionParams<T>,
44 | WithEnvironmentPathParams<T> {
45 |
46 | }
47 | ```
48 |
49 | ### TFSummaryEstimator:
50 |
51 | ```java
52 | /**
53 | * A specific document summarization estimator,
54 | * the training process needs to specify the input article column
55 | * and the input abstract column,
56 | * in addition to the hyperparameter for TensorFlow model
57 | */
58 | public class TFSummaryEstimator implements TFEstimator<TFSummaryEstimator, TFSummaryModel>,
59 | WithInputArticleCol<TFSummaryEstimator>,
60 | WithInputAbstractCol<TFSummaryEstimator>,
61 | WithHyperParams<TFSummaryEstimator> {
62 | private Params params = new Params();
63 |
64 | @Override
65 | public TFSummaryModel fit(TableEnvironment tEnv, Table input) {
66 | //TODO: call training process through Flink-AI-Extended
67 | }
68 |
69 | @Override
70 | public Params getParams() {
71 | return params;
72 | }
73 | }
74 | ```
75 |
76 | ### TFSummaryModel:
77 |
78 | ```java
79 | /**
80 | * A specific document summarization model,
81 | * the inference process needs to specify the input article column
82 | * and the output abstract column,
83 | * in addition to the model path and hyperparameter for TensorFlow model
84 | */
85 | public class TFSummaryModel implements TFModel<TFSummaryModel>,
86 | WithInputArticleCol<TFSummaryModel>,
87 | WithOutputAbstractCol<TFSummaryModel>,
88 | WithModelPathParams<TFSummaryModel>,
89 | WithHyperParams<TFSummaryModel> {
90 | private Params params = new Params();
91 |
92 | @Override
93 | public Table transform(TableEnvironment tEnv, Table input) {
94 | //TODO: call inference process through Flink-AI-Extended
95 | }
96 |
97 | @Override
98 | public Params getParams() {
99 | return params;
100 | }
101 | }
102 | ```
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
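For reference, the estimator/model pair designed above would be used end to end like this; a hypothetical snippet, with setter, table, and column names made up for illustration:

```java
// Hypothetical usage of the deprecated design: train, then streaming inference.
TFSummaryEstimator estimator = new TFSummaryEstimator()
        .setInputArticleCol("article")
        .setInputAbstractCol("abstract");

TFSummaryModel model = estimator.fit(tEnv, trainTable);   // training via Flink-AI-Extended
Table summaries = model.transform(tEnv, articleStream);   // inference, results as a Table
```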
/doc/deprecated/About StreamExeEnv in AI-Extended Issue.md:
--------------------------------------------------------------------------------
1 | ## About StreamExeEnv in AI-Extended Issue
2 |
3 | When integrating Flink-AI-Extended into the ML pipeline framework, I ran into the problem that the **StreamExecutionEnvironment** is missing.
4 |
5 | The interface definitions of the current ML pipeline framework are shown below. Both fit() and transform() receive only the **TableEnv** to which the input table is bound, and this **TableEnv** is created by an upper-level **ExecutionEnvironment** (either a **StreamExeEnv** or a **BatchExeEnv**). The problem is that all train and inference methods in the current Flink-AI-Extended are based on **StreamExeEnv** and require both the **StreamExeEnv** and the **TableEnv** to be passed in, while a **StreamExeEnv** does not necessarily exist under the ML pipeline framework, so the existing framework cannot be supported.
6 |
7 | ```java
8 | public interface Transformer<T extends Transformer<T>> extends PipelineStage<T> {
9 | Table transform(TableEnvironment tEnv, Table input);
10 | }
11 |
12 | public interface Model<M extends Model<M>> extends Transformer<M> {
13 | }
14 |
15 | public interface Estimator<E extends Estimator<E, M>, M extends Model<M>>
16 | extends PipelineStage<E> {
17 | M fit(TableEnvironment tEnv, Table input);
18 | }
19 | ```
20 |
21 | I checked the references to **StreamExeEnv** in Flink-AI-Extended; as shown below, it is mainly used in two places:
22 |
23 | 1. Registering python cache files.
24 |
25 | ```java
26 | package com.alibaba.flink.ml.operator.util;
27 | public class PythonFileUtil {
28 | //...
29 | private static List registerPythonLibFiles(
30 | StreamExecutionEnvironment env,
31 | String... userPyLibs) throws IOException {
32 | Tuple2