├── .clang-format ├── docs └── clink_logo.png ├── .gitmodules ├── .gitignore ├── third_party ├── json.BUILD └── BUILD ├── .github └── workflows │ ├── clink_ci_docker.yml │ └── clink_format.yml ├── mlir_test └── executor │ └── basic.mlir ├── lib ├── kernels │ ├── static_registration.cc │ ├── clink_kernels.cc │ └── opdefs │ │ └── clink_kernels.cc ├── linalg │ └── sparse_vector.cc ├── utils │ ├── clink_runner.cc │ └── clink_utils.cc ├── feature │ └── one_hot_encoder.cc ├── executor │ └── main.cc └── jna │ └── clink_jna.cc ├── include └── clink │ ├── feature │ ├── proto │ │ └── one_hot_encoder.proto │ └── one_hot_encoder.h │ ├── kernels │ ├── clink_kernels.h │ └── opdefs │ │ ├── types.h │ │ ├── clink_kernels.h │ │ └── clink_kernels.td │ ├── utils │ ├── clink_utils.h │ └── clink_runner.h │ ├── linalg │ ├── vector.h │ └── sparse_vector.h │ └── api │ └── model.h ├── java-lib ├── src │ ├── test │ │ └── java │ │ │ └── org │ │ │ └── flinkextended │ │ │ └── clink │ │ │ ├── example │ │ │ └── ExampleTest.java │ │ │ ├── util │ │ │ └── AllTestsRunner.java │ │ │ └── feature │ │ │ └── ClinkOneHotEncoderTest.java │ └── main │ │ └── java │ │ └── org │ │ └── flinkextended │ │ └── clink │ │ ├── util │ │ ├── ByteArrayEncoder.java │ │ ├── ByteArrayDecoder.java │ │ └── ClinkReadWriteUtils.java │ │ ├── jna │ │ ├── SparseVectorJna.java │ │ └── ClinkJna.java │ │ └── feature │ │ └── onehotencoder │ │ ├── ClinkOneHotEncoder.java │ │ └── ClinkOneHotEncoderModel.java └── pom.xml ├── .bazelrc ├── cpp_tests ├── linalg │ └── sparse_vector_test.cc ├── BUILD ├── include │ └── clink │ │ └── cpp_tests │ │ └── test_util.h └── feature │ └── one_hot_encoder_test.cc ├── docker ├── Dockerfile_centos_77 ├── Dockerfile_ubuntu_1604 └── README.md ├── tools └── format-code.sh ├── README.md ├── WORKSPACE └── BUILD /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | 
-------------------------------------------------------------------------------- /docs/clink_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flink-extended/clink/HEAD/docs/clink_logo.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tfrt"] 2 | path = tfrt 3 | url = https://github.com/tensorflow/runtime.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bazel-* 2 | .vscode 3 | .idea 4 | cache 5 | cscope* 6 | tags 7 | *.swp 8 | protos/*.cc 9 | protos/*.h 10 | *.jar 11 | *.iml 12 | .ijwb/ 13 | java-lib/target 14 | .DS_Store 15 | -------------------------------------------------------------------------------- /third_party/json.BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | load("@rules_cc//cc:defs.bzl", "cc_library") 4 | 5 | licenses(["notice"]) 6 | 7 | cc_library( 8 | name = "json", 9 | hdrs = ["single_include/nlohmann/json.hpp"], 10 | strip_include_prefix = "single_include/", 11 | ) 12 | -------------------------------------------------------------------------------- /third_party/BUILD: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | package( 14 | default_visibility = ["//:__subpackages__"], 15 | ) 16 | -------------------------------------------------------------------------------- /.github/workflows/clink_ci_docker.yml: -------------------------------------------------------------------------------- 1 | name: Clink CI Docker 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | bazel_test: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | os: ["centos7.7.1908", "ubuntu16.04"] 11 | steps: 12 | - name: Check out the repo 13 | uses: actions/checkout@v2 14 | - name: Update submodules 15 | run: git submodule update --init --recursive 16 | - name: Pull Docker image 17 | run: docker pull docker.io/flinkextended/clink:${{matrix.os}} 18 | - name: Run tests in Docker image 19 | run: | 20 | docker run -t -v ${GITHUB_WORKSPACE}:/root/clink -w /root/clink \ 21 | docker.io/flinkextended/clink:${{matrix.os}} /bin/bash \ 22 | bazel test $(bazel query //...) -c dbg 23 | -------------------------------------------------------------------------------- /mlir_test/executor/basic.mlir: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | func @main(%arg_0: f64) { 14 | %ch0 = tfrt.new.chain 15 | 16 | %value_1 = clink.square.f64 %arg_0 17 | %result = clink.square_add.f64 %value_1, %arg_0 18 | 19 | %ch1 = tfrt.print.f64 %result, %ch0 20 | tfrt.return 21 | } 22 | -------------------------------------------------------------------------------- /lib/kernels/static_registration.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | // This file uses a static constructor to automatically register all of the 16 | // kernels in this directory. This can be used to simplify clients that don't 17 | // care about selective registration of kernels. 
18 | 19 | #include "clink/kernels/clink_kernels.h" 20 | #include "tfrt/host_context/kernel_registry.h" 21 | 22 | namespace clink { 23 | 24 | TFRT_STATIC_KERNEL_REGISTRATION(RegisterClinkKernels); 25 | 26 | } // namespace clink 27 | -------------------------------------------------------------------------------- /.github/workflows/clink_format.yml: -------------------------------------------------------------------------------- 1 | name: Clink Format 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | format_check: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Check out the repo 10 | uses: actions/checkout@v2 11 | - name: Update submodules 12 | run: git submodule update --init --recursive 13 | - name: Check md5sum equality 14 | run: | 15 | pre_built_md5sum=`find ./ -type f -exec md5sum {} \; | md5sum` 16 | 17 | echo "Running ./tools/format-code.sh" 18 | ./tools/format-code.sh 19 | 20 | post_built_md5sum=`find ./ -type f -exec md5sum {} \; | md5sum` 21 | 22 | echo "Pre-formatted md5sum: $pre_built_md5sum" 23 | echo "Post-formatted md5sum: $post_built_md5sum" 24 | if [ "${pre_built_md5sum}" != "${post_built_md5sum}" ]; then 25 | echo "Pre-formatted md5sum and Post-formatted md5sum do not equal." 26 | echo "Please format all Clink codes using ./tools/format-code.sh" \ 27 | "before submitting commits." 28 | exit 1 29 | fi 30 | -------------------------------------------------------------------------------- /include/clink/feature/proto/one_hot_encoder.proto: -------------------------------------------------------------------------------- 1 | 2 | // Copyright 2021 The Clink Authors 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | 16 | // [START declaration] 17 | syntax = "proto3"; 18 | package clink; 19 | // [END declaration] 20 | 21 | // [START java_declaration] 22 | option java_multiple_files = true; 23 | option java_package = "org.clink.feature.onehotencoder"; 24 | option java_outer_classname = "OneHotEncoderProto"; 25 | // [END java_declaration] 26 | 27 | // [START messages] 28 | message OneHotEncoderModelDataProto { 29 | repeated int32 featureSizes = 1; 30 | } 31 | // [END messages] 32 | -------------------------------------------------------------------------------- /include/clink/kernels/clink_kernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
13 | */ 14 | 15 | #ifndef CLINK_KERNELS_CLINK_KERNELS_H_ 16 | #define CLINK_KERNELS_CLINK_KERNELS_H_ 17 | 18 | #include "tfrt/host_context/kernel_utils.h" 19 | 20 | using namespace tfrt; 21 | 22 | namespace clink { 23 | 24 | AsyncValueRef<double> SquareAdd(Argument<double> x, Argument<double> y, 25 | const ExecutionContext &exec_ctx); 26 | 27 | double Square(double x); 28 | 29 | void RegisterClinkKernels(tfrt::KernelRegistry *registry); 30 | 31 | } // namespace clink 32 | 33 | #endif // CLINK_KERNELS_CLINK_KERNELS_H_ 34 | -------------------------------------------------------------------------------- /java-lib/src/test/java/org/flinkextended/clink/example/ExampleTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.example; 18 | 19 | import org.flinkextended.clink.jna.ClinkJna; 20 | import org.junit.Test; 21 | 22 | import static org.junit.Assert.assertEquals; 23 | 24 | /** Simple test of native library declaration and usage. 
*/ 25 | public class ExampleTest { 26 | @Test 27 | public void testExample() { 28 | assertEquals(9.0, ClinkJna.INSTANCE.Square(3.0), 1e-5); 29 | assertEquals(10.0, ClinkJna.INSTANCE.SquareAdd(1.0, 3.0), 1e-5); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /include/clink/kernels/opdefs/types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef CLINK_FEATURE_OPDEFS_TYPES_H_ 18 | #define CLINK_FEATURE_OPDEFS_TYPES_H_ 19 | 20 | #include "mlir/IR/Types.h" 21 | 22 | namespace clink { 23 | 24 | class ModelType 25 | : public mlir::Type::TypeBase<ModelType, mlir::Type, mlir::TypeStorage> { 26 | public: 27 | using Base::Base; 28 | }; 29 | 30 | class VectorType 31 | : public mlir::Type::TypeBase<VectorType, mlir::Type, mlir::TypeStorage> { 32 | public: 33 | using Base::Base; 34 | }; 35 | 36 | } // namespace clink 37 | 38 | #endif // CLINK_FEATURE_OPDEFS_TYPES_H_ 39 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/util/ByteArrayEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.util; 18 | 19 | import org.apache.flink.api.common.serialization.Encoder; 20 | import org.apache.flink.core.memory.DataOutputViewStreamWrapper; 21 | 22 | import java.io.IOException; 23 | import java.io.OutputStream; 24 | 25 | /** Encoder that writes a byte array as an int length prefix followed by the raw bytes. */ 26 | public class ByteArrayEncoder implements Encoder<byte[]> { 27 | @Override 28 | public void encode(byte[] bytes, OutputStream outputStream) throws IOException { 29 | DataOutputViewStreamWrapper outputViewStreamWrapper = 30 | new DataOutputViewStreamWrapper(outputStream); 31 | outputViewStreamWrapper.writeInt(bytes.length); 32 | outputViewStreamWrapper.write(bytes); 33 | outputStream.flush(); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # https://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | build --enable_platform_specific_config 14 | 15 | build:linux --config=clang # On Linux, build with clang by default. 16 | 17 | # Build with clang. 18 | build:clang --repo_env=CC=clang 19 | # This should be 'build:clang' as well, but then could no longer be overridden by 20 | # --config=nvcc. See https://github.com/bazelbuild/bazel/issues/13603. 21 | build:clang --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 22 | 23 | # Build with gcc (and nvcc if --config=cuda). 24 | build:gcc --repo_env=CC=gcc 25 | build:gcc --config=nvcc 26 | build:gcc --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 27 | build:gcc --cxxopt=-Wno-maybe-uninitialized 28 | build:gcc --cxxopt=-Wno-sign-compare 29 | 30 | # Default to an optimized build. 31 | # Override via: "-c dbg" or --compilation_mode=dbg 32 | build --compilation_mode=opt 33 | 34 | # Disable RTTI and exceptions 35 | build:disable_rtti_and_exceptions --no//:rtti_and_exceptions 36 | 37 | -------------------------------------------------------------------------------- /cpp_tests/linalg/sparse_vector_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "clink/linalg/sparse_vector.h" 18 | 19 | #include "clink/cpp_tests/test_util.h" 20 | #include "gtest/gtest.h" 21 | 22 | namespace clink { 23 | 24 | namespace { 25 | 26 | TEST(SparseVectorTest, CreatesVector) { 27 | SparseVector vector(5); 28 | EXPECT_EQ(vector.size(), 5); 29 | } 30 | 31 | TEST(SparseVectorTest, SetGetValue) { 32 | SparseVector vector(5); 33 | vector.set(1, 1.0); 34 | vector.set(2, 3.0); 35 | vector.set(4, 2.5); 36 | EXPECT_EQ(vector.get(0).get(), 0.0); 37 | EXPECT_EQ(vector.get(1).get(), 1.0); 38 | EXPECT_EQ(vector.get(2).get(), 3.0); 39 | EXPECT_EQ(vector.get(3).get(), 0.0); 40 | EXPECT_EQ(vector.get(4).get(), 2.5); 41 | EXPECT_FALSE((bool)vector.get(4).takeError()); 42 | EXPECT_EQ(tfrt::StrCat(vector.get(5).takeError()), "Index out of range."); 43 | } 44 | 45 | } // namespace 46 | } // namespace clink 47 | -------------------------------------------------------------------------------- /include/clink/utils/clink_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
13 | */ 14 | 15 | #ifndef CLINK_UTILS_CLINK_UTILS_H_ 16 | #define CLINK_UTILS_CLINK_UTILS_H_ 17 | 18 | #include "tfrt/bef_executor_driver/bef_executor_driver.h" 19 | 20 | using namespace tfrt; 21 | 22 | namespace clink { 23 | 24 | std::unique_ptr<tfrt::HostContext> CreateHostContext( 25 | string_view work_queue_type, tfrt::HostAllocatorType host_allocator_type); 26 | 27 | // Given a directory path, gets the only file in the directory. 28 | // 29 | // Flink ML operators that produce model data as a protobuf record save that 30 | // model data in a directory with only one file. C++ knows the directory of 31 | // the file, but the file's name is unknown to C++. This function helps to 32 | // locate that file. 33 | // 34 | // This function returns an empty string if the directory does not exist, or 35 | // if there is zero or more than one file in the directory. 36 | std::string getOnlyFileInDirectory(std::string path); 37 | 38 | } // namespace clink 39 | 40 | #endif // CLINK_UTILS_CLINK_UTILS_H_ 41 | -------------------------------------------------------------------------------- /include/clink/kernels/opdefs/clink_kernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | // MLIR op definitions for clink kernels. 16 | // This file declares the 'clink' dialect as well as the operators in 17 | // the clink library. 
18 | 19 | #ifndef CLINK_KERNELS_OPDEFS_CLINK_KERNELS_H_ 20 | #define CLINK_KERNELS_OPDEFS_CLINK_KERNELS_H_ 21 | 22 | #include "mlir/IR/DialectImplementation.h" 23 | #include "mlir/Interfaces/InferTypeOpInterface.h" 24 | 25 | using namespace mlir; 26 | 27 | namespace clink { 28 | 29 | // Dialect for clink operations. 30 | class ClinkDialect : public Dialect { 31 | public: 32 | static StringRef getDialectNamespace() { return "clink"; } 33 | explicit ClinkDialect(MLIRContext *context); 34 | 35 | mlir::Type parseType(mlir::DialectAsmParser &parser) const override; 36 | void printType(mlir::Type type, 37 | mlir::DialectAsmPrinter &printer) const override; 38 | }; 39 | 40 | } // namespace clink 41 | 42 | #define GET_OP_CLASSES 43 | #include "clink/kernels/opdefs/clink_kernels.h.inc" 44 | 45 | #endif // CLINK_KERNELS_OPDEFS_CLINK_KERNELS_H_ 46 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/jna/SparseVectorJna.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package org.flinkextended.clink.jna; 18 | 19 | import org.apache.flink.ml.linalg.SparseVector; 20 | 21 | import com.sun.jna.Pointer; 22 | import com.sun.jna.Structure; 23 | import com.sun.jna.Structure.FieldOrder; 24 | 25 | /** 26 | * Class that corresponds to struct SparseVectorJNA in C++. It is only used by JNA to transmit data 27 | * between Java and C++. 28 | */ 29 | @FieldOrder({"n", "indices", "values", "length"}) 30 | public class SparseVectorJna extends Structure { 31 | public static class ByReference extends SparseVectorJna implements Structure.ByReference {} 32 | 33 | public int n; 34 | public Pointer indices; 35 | public Pointer values; 36 | public int length; 37 | 38 | /** Converts this class to {@link SparseVector}. */ 39 | public SparseVector toSparseVector() { 40 | return new SparseVector( 41 | n, indices.getIntArray(0, length), values.getDoubleArray(0, length)); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /java-lib/src/test/java/org/flinkextended/clink/util/AllTestsRunner.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package org.flinkextended.clink.util; 18 | 19 | import org.apache.flink.table.shaded.org.reflections.Reflections; 20 | import org.apache.flink.table.shaded.org.reflections.scanners.SubTypesScanner; 21 | 22 | import junit.framework.JUnit4TestAdapter; 23 | import junit.framework.TestSuite; 24 | import org.junit.runner.RunWith; 25 | 26 | /** Runs every class found under org/flinkextended/clink whose name ends with "Test" as one JUnit suite. */ 27 | @RunWith(org.junit.runners.AllTests.class) 28 | public class AllTestsRunner { 29 | public static TestSuite suite() throws Exception { 30 | TestSuite suite = new TestSuite(); 31 | Reflections reflections = 32 | new Reflections("org/flinkextended/clink", new SubTypesScanner(false)); 33 | reflections.getSubTypesOf(Object.class).stream() 34 | .filter(clazz -> clazz.getName().endsWith("Test")) 35 | .map(JUnit4TestAdapter::new) 36 | .forEach(suite::addTest); 37 | return suite; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /include/clink/linalg/vector.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef CLINK_LINALG_VECTOR_H_ 18 | #define CLINK_LINALG_VECTOR_H_ 19 | 20 | #include "tfrt/host_context/chain.h" 21 | 22 | namespace clink { 23 | 24 | // A vector of double values. 
25 | class Vector { 26 | public: 27 | // Sets the value at a certain index of the vector. 28 | virtual llvm::Error set(const int index, const double value) = 0; 29 | 30 | // Gets the value at a certain index of the vector. 31 | virtual llvm::Expected get(const int index) const = 0; 32 | 33 | // Gets the total number of dimensions of the vector. 34 | virtual int size() const = 0; 35 | 36 | protected: 37 | // Move operations are supported. 38 | Vector(Vector &&other) = default; 39 | Vector &operator=(Vector &&other) = default; 40 | 41 | // This class is not copyable or assignable. 42 | Vector(const Vector &other) = delete; 43 | Vector &operator=(const Vector &) = delete; 44 | 45 | Vector() = default; 46 | 47 | virtual ~Vector() {} 48 | }; 49 | 50 | } // namespace clink 51 | 52 | #endif // CLINK_LINALG_VECTOR_H_ 53 | -------------------------------------------------------------------------------- /cpp_tests/BUILD: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library", "tfrt_cc_test") 14 | 15 | licenses(["notice"]) 16 | 17 | package( 18 | default_visibility = [":__subpackages__"], 19 | ) 20 | 21 | tfrt_cc_library( 22 | name = "common", 23 | testonly = True, 24 | hdrs = [ 25 | "include/clink/cpp_tests/test_util.h", 26 | ], 27 | visibility = [ 28 | "//visibility:public", 29 | ], 30 | deps = [ 31 | "@clink//:clink_kernels", 32 | "@clink//:clink_kernels_alwayslink", 33 | "@clink//:clink_kernels_opdefs", 34 | "@clink//:clink_utils", 35 | "@com_google_googletest//:gtest_main", 36 | "@tf_runtime//:basic_kernels_alwayslink", 37 | "@tf_runtime//:hostcontext_alwayslink", 38 | ], 39 | ) 40 | 41 | tfrt_cc_test( 42 | name = "linalg/sparse_vector_test", 43 | srcs = ["linalg/sparse_vector_test.cc"], 44 | deps = [ 45 | ":common", 46 | ], 47 | ) 48 | 49 | tfrt_cc_test( 50 | name = "feature/one_hot_encoder_test", 51 | srcs = ["feature/one_hot_encoder_test.cc"], 52 | deps = [ 53 | ":common", 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /lib/linalg/sparse_vector.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "clink/linalg/sparse_vector.h" 18 | 19 | namespace clink { 20 | 21 | llvm::Expected SparseVector::get(const int index) const { 22 | if (index >= n_ || index < 0) { 23 | return tfrt::MakeStringError("Index out of range."); 24 | } 25 | 26 | for (int i = 0; i < indices_.size(); i++) { 27 | if (indices_[i] == index) { 28 | return values_[i]; 29 | } 30 | } 31 | return 0.0; 32 | } 33 | 34 | llvm::Error SparseVector::set(const int index, const double value) { 35 | if (index >= n_) { 36 | return tfrt::MakeStringError("Index out of range."); 37 | } 38 | 39 | for (int i = 0; i < indices_.size(); i++) { 40 | if (indices_[i] == index) { 41 | values_[i] = value; 42 | return llvm::Error::success(); 43 | } 44 | } 45 | 46 | indices_.push_back(index); 47 | values_.push_back(value); 48 | return llvm::Error::success(); 49 | } 50 | 51 | int SparseVector::size() const { return n_; } 52 | 53 | bool SparseVector::operator==(const SparseVector &other) const { 54 | if (n_ != other.n_) { 55 | return false; 56 | } 57 | for (int i = 0; i < n_; i++) { 58 | if (this->get(i).get() != other.get(i).get()) { 59 | return false; 60 | } 61 | } 62 | return true; 63 | } 64 | 65 | } // namespace clink 66 | -------------------------------------------------------------------------------- /docker/Dockerfile_centos_77: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | # CentOS 7.7.1908 installed with the following libraries: 14 | # - Bazel 4.0.0 15 | # - Clang 11.1.0 16 | # - libstdc++8 or greater 17 | # - openjdk-8 18 | 19 | FROM centos:centos7.7.1908 as base 20 | 21 | RUN curl https://copr.fedorainfracloud.org/coprs/vbatts/bazel/repo/epel-7/vbatts-bazel-epel-7.repo > /etc/yum.repos.d/vbatts-bazel-epel-7.repo && \ 22 | echo "source /opt/rh/devtoolset-8/enable" >> ~/.bashrc 23 | 24 | RUN yum update -y && yum install -y bazel4 centos-release-scl epel-release wget gcc git 25 | 26 | RUN yum update -y && yum install -y cmake3 devtoolset-8 27 | 28 | RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-project-11.1.0.src.tar.xz && tar xvf llvm-project-11.1.0.src.tar.xz 29 | 30 | RUN mkdir /llvm-project-11.1.0.src/build && \ 31 | cd /llvm-project-11.1.0.src/build && \ 32 | source ~/.bashrc && \ 33 | cmake3 -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS=clang -G "Unix Makefiles" ../llvm && \ 34 | make && \ 35 | make install -j 8 && \ 36 | rm -rf /llvm-project-11.1.0.src 37 | 38 | RUN git clone --depth 1 https://github.com/flink-extended/clink.git /tmp/clink && \ 39 | cd /tmp/clink && \ 40 | git submodule update --init --recursive && \ 41 | bazel build --disk_cache=~/.cache/bazel @tf_runtime//tools:bef_executor_lite && \ 42 | bazel build --disk_cache=~/.cache/bazel @tf_runtime//tools:tfrt_translate && \ 43 | rm -rf /tmp/clink 44 | -------------------------------------------------------------------------------- /include/clink/linalg/sparse_vector.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef CLINK_LINALG_SPARSE_VECTOR_H_ 18 | #define CLINK_LINALG_SPARSE_VECTOR_H_ 19 | 20 | #include "clink/linalg/vector.h" 21 | 22 | namespace clink { 23 | 24 | // A sparse vector of double values. 25 | class SparseVector : public Vector { 26 | public: 27 | // Constructor for SparseVector. 28 | // `n` stands for the number of dimensions of the vector. 29 | explicit SparseVector(const int n) : n_(n) {} 30 | 31 | // Move operations are supported. 32 | SparseVector(SparseVector &&other) = default; 33 | SparseVector &operator=(SparseVector &&other) = default; 34 | 35 | // This class is not copyable or assignable. 36 | SparseVector(const SparseVector &other) = delete; 37 | SparseVector &operator=(const SparseVector &) = delete; 38 | 39 | // Sets the value at a certain index of the vector. 40 | llvm::Error set(const int index, const double value); 41 | 42 | // Gets the value at a certain index of the vector. 43 | llvm::Expected get(const int index) const; 44 | 45 | // Gets the total number of dimensions of the vector. 
46 | int size() const; 47 | 48 | // Returns true when both vectors have the same size and the same value at 49 | // every index. 50 | bool operator==(const SparseVector &other) const; 51 | 52 | private: 53 | // Number of dimensions. Deliberately not const: a const data member would 54 | // make the defaulted move assignment operator implicitly deleted. 55 | int n_; 56 | llvm::SmallVector indices_; 57 | llvm::SmallVector values_; 58 | }; 59 | 60 | } // namespace clink 61 | 62 | #endif // CLINK_LINALG_SPARSE_VECTOR_H_ 63 | -------------------------------------------------------------------------------- /docker/Dockerfile_ubuntu_1604: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | # Ubuntu 16.04 installed with the following libraries: 14 | # - Bazel 4.0.0 15 | # - Clang 11.1.0 16 | # - libstdc++8 or greater 17 | # - openjdk-8 18 | 19 | FROM ubuntu:16.04 as base 20 | 21 | RUN apt-get update && apt-get install -y apt-transport-https software-properties-common sudo 22 | 23 | RUN echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial main" >> /etc/apt/sources.list && \ 24 | echo "deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial main" >> /etc/apt/sources.list && \ 25 | echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-11 main" >> /etc/apt/sources.list && \ 26 | echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" >> /etc/apt/sources.list.d/bazel.list 27 | 28 | RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test && \ 29 | apt-get update && \ 30 | apt-get install -y --allow-unauthenticated --fix-missing bazel clang-11 gcc-8 g++-8 openjdk-8-jdk git && \ 31 | rm -rf /var/lib/apt/lists/* 32 | 33 | RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-11 11 && \ 34 | update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-11 11 35 | 36 | RUN git clone --depth 1 https://github.com/flink-extended/clink.git /tmp/clink && \ 37 | cd /tmp/clink && \ 38 | git submodule update --init --recursive && \ 39 | bazel build --disk_cache=~/.cache/bazel @tf_runtime//tools:bef_executor_lite && \ 40 | bazel build --disk_cache=~/.cache/bazel @tf_runtime//tools:tfrt_translate && \ 41 | rm -rf /tmp/clink 42 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | Clink provides dockerfiles in this folder to create Docker images that provide 2 | the environment required to build and execute Clink programs. This README 3 | contains guidelines of how to build Clink Docker images and push them to Docker 4 | Hub. 
5 | 6 | **When should Clink Docker images on Docker Hub be updated?** 7 | 8 | What Clink images provide is the environment to execute Clink programs, not 9 | Clink itself, so Clink images on Docker Hub need to be updated only when the 10 | environment required for Clink changes. It is worth noting that Clink 11 | dockerfiles also cache prebuilt TFRT in the images, thus changing the commit id 12 | of the TFRT submodule of this repository also means that Clink Docker images 13 | should be updated. 14 | 15 | ## Prerequisites 16 | 17 | - Make sure you have installed [Docker](https://docs.docker.com/engine/install/) 18 | on your system. 19 | - Make sure that you have write permission to `flinkextended/clink`, which means 20 | one of the following. 21 | - You are authenticated with your Docker ID, and your Docker ID has 22 | access to `flinkextended/clink`, or 23 | - You have an access token for `flinkextended/clink`, so that you can log in 24 | directly to `flinkextended`. 25 | - If you do not have such permissions, please contact the reviewers of your PR 26 | and request access to `flinkextended` from them. They are expected to be 27 | familiar with Clink and to know where to get such access. 28 | 29 | ## Docker Image Publication Guideline 30 | 31 | 1. Build the Docker image locally. From this directory: 32 | 33 | ```sh 34 | $ docker build -t clink:<tag> -f <dockerfile> .. 35 | ``` 36 | 37 | 2. Create the Docker image to be published from the locally built image. 38 | 39 | ```sh 40 | $ docker tag clink:<tag> flinkextended/clink:<tag> 41 | ``` 42 | 43 | 3. Log in to the Docker Hub account with write permission. 44 | 45 | ```sh 46 | $ docker login -u <username> 47 | ``` 48 | 49 | 4. Push the Docker image to Docker Hub. 
50 | 51 | ```sh 52 | $ docker push flinkextended/clink:<tag> 53 | ``` 54 | -------------------------------------------------------------------------------- /include/clink/api/model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef CLINK_API_MODEL_H_ 18 | #define CLINK_API_MODEL_H_ 19 | 20 | #include "tfrt/host_context/async_dispatch.h" 21 | #include "tfrt/host_context/async_value.h" 22 | #include "tfrt/host_context/host_allocator.h" 23 | #include "tfrt/support/ref_count.h" 24 | 25 | namespace clink { 26 | 27 | // Basic interface for Clink operators that provide feature processing 28 | // functions. 29 | // 30 | // NOTE: every Model subclass should implement a static method with signature 31 | // `static llvm::Expected<tfrt::RCReference<T>> load(const std::string &path, 32 | // tfrt::HostContext *host)`, where `T` refers to the concrete subclass. This 33 | // static method should instantiate a new Model instance based on the data read 34 | // from the given path. 35 | class Model : public tfrt::ReferenceCounted { 36 | public: 37 | virtual ~Model() {} 38 | 39 | // Applies the Model on the given ArrayRef of input AsyncValues and returns 40 | // a SmallVector of AsyncValues. 
41 | virtual llvm::SmallVector, 4> transform( 42 | llvm::ArrayRef inputs, 43 | const tfrt::ExecutionContext &exec_ctx) = 0; 44 | 45 | protected: 46 | template 47 | static void DestroyImpl(SubClass *ptr, tfrt::HostAllocator *allocator) { 48 | ptr->~SubClass(); 49 | allocator->DeallocateBytes(ptr, sizeof(SubClass)); 50 | } 51 | 52 | private: 53 | // For access to Destroy(). 54 | friend class ReferenceCounted; 55 | 56 | virtual void Destroy() = 0; 57 | }; 58 | 59 | } // namespace clink 60 | 61 | #endif // CLINK_API_MODEL_H_ 62 | -------------------------------------------------------------------------------- /tools/format-code.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ################################################################################ 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ################################################################################ 15 | 16 | # This script formats all codes in the Clink repository. It uses clang-format to 17 | # format C++ code, diffplug/spotless to format Java code, and Buildifier to 18 | # format Bazel code. 
19 | 20 | set -e 21 | 22 | version_array=( 23 | "clang-format" "11.1.0" "clang-format --version | cut -d\" \" -f3" 24 | "bazel" "4.0.0" "bazel --version | cut -d\" \" -f2" 25 | "mvn" "3.1.0" "mvn --version | head -n1 | cut -d\" \" -f3" 26 | ) 27 | 28 | # Checks whether required tools have been installed 29 | for ((i = 0; i < ${#version_array[@]}; i += 3)); do 30 | cmd=${version_array[$i]} 31 | if ! command -v $cmd &> /dev/null 32 | then 33 | echo "$cmd: command not found" 34 | exit 1 35 | fi 36 | expected_version=${version_array[$i+1]} 37 | actual_version=`eval "${version_array[$i+2]}"` 38 | unsorted_versions="${expected_version}\n${actual_version}\n" 39 | sorted_versions=`printf ${unsorted_versions} | sort -V` 40 | unsorted_versions=`printf ${unsorted_versions}` 41 | if [ "${unsorted_versions}" != "${sorted_versions}" ]; then 42 | echo "$cmd $expected_version or a higher version is required, but found $actual_version" 43 | exit 1 44 | fi 45 | done 46 | 47 | # Formats C++ codes 48 | find . \( -name "*.cc" -or -name "*.h" \) -not -path "./tfrt/*" -exec clang-format -i {} \; 49 | 50 | # Formats Java codes 51 | mvn -f java-lib spotless:apply 52 | 53 | # Formats Bazel codes 54 | bazel run //:buildifier 55 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/util/ByteArrayDecoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.util; 18 | 19 | import org.apache.flink.api.common.typeinfo.TypeInformation; 20 | import org.apache.flink.api.common.typeinfo.Types; 21 | import org.apache.flink.configuration.Configuration; 22 | import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; 23 | import org.apache.flink.core.fs.FSDataInputStream; 24 | import org.apache.flink.core.memory.DataInputViewStreamWrapper; 25 | import org.apache.flink.util.Preconditions; 26 | 27 | import java.io.EOFException; 28 | import java.io.IOException; 29 | 30 | /** 31 | * Decoder for length-prefixed byte arrays: each record is an int length followed by that many 32 | * bytes. 33 | */ 34 | public class ByteArrayDecoder extends SimpleStreamFormat { 35 | @Override 36 | public Reader createReader(Configuration config, FSDataInputStream inputStream) { 37 | return new Reader() { 38 | final DataInputViewStreamWrapper inputViewStreamWrapper = 39 | new DataInputViewStreamWrapper(inputStream); 40 | 41 | /** 42 | * Reads the next record. 43 | * 44 | * @return the record bytes, or {@code null} once the end of the stream is reached; the 45 | * precondition check fails if the stream ends before the record is complete 46 | */ 47 | @Override 48 | public byte[] read() throws IOException { 49 | try { 50 | int expectedLen = inputViewStreamWrapper.readInt(); 51 | byte[] bytes = new byte[expectedLen]; 52 | // A single read() may return fewer bytes than requested, so keep reading 53 | // until the record is complete or the stream is exhausted. 54 | int actualLen = 0; 55 | while (actualLen < expectedLen) { 56 | int count = 57 | inputViewStreamWrapper.read(bytes, actualLen, expectedLen - actualLen); 58 | if (count < 0) { 59 | break; 60 | } 61 | actualLen += count; 62 | } 63 | Preconditions.checkArgument(expectedLen == actualLen); 64 | return bytes; 65 | } catch (EOFException e) { 66 | return null; 67 | } 68 | } 69 | 70 | @Override 71 | public void close() throws IOException { 72 | inputStream.close(); 73 | } 74 | }; 75 | } 76 | 77 | @Override 78 | public TypeInformation getProducedType() { 79 | return (TypeInformation) Types.PRIMITIVE_ARRAY(Types.BYTE); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/feature/onehotencoder/ClinkOneHotEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed 
under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.feature.onehotencoder; 18 | 19 | import org.apache.flink.ml.api.Estimator; 20 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; 21 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; 22 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderParams; 23 | import org.apache.flink.ml.param.Param; 24 | import org.apache.flink.ml.util.ReadWriteUtils; 25 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 26 | import org.apache.flink.table.api.Table; 27 | 28 | import java.io.IOException; 29 | import java.util.Map; 30 | 31 | public class ClinkOneHotEncoder 32 | implements Estimator, 33 | OneHotEncoderParams { 34 | private final OneHotEncoder estimator; 35 | 36 | public ClinkOneHotEncoder() { 37 | this(new OneHotEncoder()); 38 | } 39 | 40 | private ClinkOneHotEncoder(OneHotEncoder estimator) { 41 | this.estimator = estimator; 42 | } 43 | 44 | @Override 45 | public ClinkOneHotEncoderModel fit(Table... 
inputs) { 46 | OneHotEncoderModel model = estimator.fit(inputs); 47 | ClinkOneHotEncoderModel clinkModel = new ClinkOneHotEncoderModel(); 48 | ReadWriteUtils.updateExistingParams(clinkModel, model.getParamMap()); 49 | clinkModel.setModelData(model.getModelData()); 50 | return clinkModel; 51 | } 52 | 53 | @Override 54 | public void save(String path) throws IOException { 55 | estimator.save(path); 56 | } 57 | 58 | public static ClinkOneHotEncoder load(StreamExecutionEnvironment env, String path) 59 | throws IOException { 60 | OneHotEncoder estimator = OneHotEncoder.load(env, path); 61 | return new ClinkOneHotEncoder(estimator); 62 | } 63 | 64 | @Override 65 | public Map, Object> getParamMap() { 66 | return estimator.getParamMap(); 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /include/clink/feature/one_hot_encoder.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef CLINK_FEATURE_ONE_HOT_ENCODER_H_ 18 | #define CLINK_FEATURE_ONE_HOT_ENCODER_H_ 19 | 20 | #include "clink/api/model.h" 21 | #include "clink/feature/proto/one_hot_encoder.pb.h" 22 | #include "clink/linalg/sparse_vector.h" 23 | 24 | namespace clink { 25 | 26 | // A Model which encodes data into one-hot format. 
27 | class OneHotEncoderModel : public Model { 28 | public: 29 | // Constructs a model that uses the allocator of the given HostContext. 30 | OneHotEncoderModel(tfrt::HostContext *host) : allocator_(host->allocator()) {} 31 | 32 | // Move operations are supported. 33 | OneHotEncoderModel(OneHotEncoderModel &&other) = default; 34 | OneHotEncoderModel &operator=(OneHotEncoderModel &&other) = default; 35 | 36 | // This class is not copyable or assignable. 37 | OneHotEncoderModel(const OneHotEncoderModel &other) = delete; 38 | OneHotEncoderModel &operator=(const OneHotEncoderModel &) = delete; 39 | 40 | llvm::SmallVector, 4> transform( 41 | llvm::ArrayRef inputs, 42 | const tfrt::ExecutionContext &exec_ctx) override; 43 | 44 | // Loads a OneHotEncoderModel from the given path. The path should be a directory 45 | // containing params and model data saved through 46 | // org.flinkextended.clink.feature.onehotencoder.ClinkOneHotEncoderModel::save(...). 47 | static llvm::Expected> load( 48 | const std::string &path, tfrt::HostContext *host); 49 | 50 | void setDropLast(const bool is_droplast); 51 | 52 | bool getDropLast() const; 53 | 54 | private: 55 | void Destroy() override { 56 | Model::DestroyImpl(this, allocator_); 57 | } 58 | 59 | // Params of OneHotEncoderModel. 60 | struct Params { 61 | // Whether to drop the last category. 62 | bool is_droplast; 63 | }; 64 | 65 | // Params of OneHotEncoderModel. 66 | Params params_; 67 | 68 | // Model data of OneHotEncoderModel. 
69 | OneHotEncoderModelDataProto model_data_; 70 | 71 | tfrt::HostAllocator *allocator_; 72 | }; 73 | 74 | } // namespace clink 75 | 76 | #endif // CLINK_FEATURE_ONE_HOT_ENCODER_H_ 77 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/jna/ClinkJna.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.jna; 18 | 19 | import com.sun.jna.LastErrorException; 20 | import com.sun.jna.Library; 21 | import com.sun.jna.Native; 22 | import com.sun.jna.Pointer; 23 | 24 | /** Utility methods that have implementations in C++. */ 25 | public interface ClinkJna extends Library { 26 | ClinkJna INSTANCE = Native.load("clink_jna", ClinkJna.class); 27 | 28 | double SquareAdd(double x, double y); 29 | 30 | double Square(double x); 31 | 32 | /** 33 | * Deletes a {@link SparseVectorJna} in order to avoid memory leak. 34 | * 35 | * @param vector A reference to the {@link SparseVectorJna} object. 36 | */ 37 | // TODO: Automatically free C++ objects to avoid memory leak and improve usability. 38 | void SparseVector_delete(SparseVectorJna.ByReference vector); 39 | 40 | /** 41 | * Loads a {@link org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel} C++ operator 42 | * from given path. 
The path should be a directory containing params saved in json format and 43 | * model data saved in protobuf format. 44 | * 45 | * @return Pointer to the loaded C++ Operator 46 | */ 47 | Pointer OneHotEncoderModel_load(String path) throws LastErrorException; 48 | 49 | /** 50 | * Converts an indexed integer to one-hot-encoded sparse vector, using the {@link 51 | * org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel} C++ operator. 52 | * 53 | * @param modelPointer Pointer to the OneHotEncoder C++ operator 54 | * @param value The indexed integer to be converted 55 | * @param columnIndex The column index which the indexed integer locates 56 | * @return A one-hot-encoded sparse vector 57 | */ 58 | // TODO: Compare the performance of using ByReference v.s. ByValue and optimize accordingly. 59 | SparseVectorJna.ByReference OneHotEncoderModel_transform( 60 | Pointer modelPointer, int value, int columnIndex) throws LastErrorException; 61 | 62 | /** 63 | * Deletes a {@link org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel} C++ operator 64 | * in order to avoid memory leak. 65 | * 66 | * @param modelPointer Pointer to the OneHotEncoder C++ operator 67 | */ 68 | void OneHotEncoderModel_delete(Pointer modelPointer); 69 | } 70 | -------------------------------------------------------------------------------- /lib/utils/clink_runner.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | #include "clink/utils/clink_runner.h" 16 | 17 | #include "mlir/Parser.h" 18 | #include "tfrt/bef_converter/mlir_to_bef.h" 19 | #include "tfrt/host_context/function.h" 20 | #include "tfrt/host_context/host_context.h" 21 | 22 | namespace clink { 23 | 24 | ClinkRunner::Builder::Builder() {} 25 | 26 | // Parses the configured MLIR source, converts it to BEF and wraps the result in 27 | // a ClinkRunner bound to the configured function name and host context. 28 | ClinkRunner ClinkRunner::Builder::Compile() { 29 | assert(!mlir_input_.empty() && 30 | "mlir_input must be set before calling Compile."); 31 | assert(!fn_name_.empty() && "fn_name must be set before calling Compile."); 32 | assert(mlir_context_ && "MLIR context must be set before calling Compile."); 33 | // host_context_ is dereferenced by the ClinkRunner constructor. 34 | assert(host_context_ && "HostContext must be set before calling Compile."); 35 | 36 | mlir::OwningModuleRef module = 37 | mlir::parseSourceString(mlir_input_, mlir_context_); 38 | 39 | tfrt::BefBuffer bef_buffer = 40 | tfrt::ConvertMLIRToBEF(module.get(), /*disable_optional_sections=*/true); 41 | // The BEF file itself is opened by the ClinkRunner constructor. 42 | return ClinkRunner(fn_name_, std::move(bef_buffer), host_context_); 43 | } 44 | 45 | ClinkRunner::ClinkRunner(const std::string &fn_name, BefBuffer bef_buffer, 46 | HostContext *host_context) 47 | : fn_name_(fn_name), 48 | bef_buffer_(std::move(bef_buffer)), 49 | host_context_(host_context), 50 | execution_context_( 51 | *tfrt::RequestContextBuilder(host_context_, resource_context_.get()) 52 | .build()) { 53 | bef_file_ = 54 | BEFFile::Open(bef_buffer_, host_context_->GetKernelRegistry(), 55 | host_context_->diag_handler(), host_context_->allocator()); 56 | func_ = bef_file_->GetFunction(fn_name_); 57 | } 58 | 59 | // Executes the compiled function on `inputs` and returns its results. 60 | llvm::SmallVector, 4> ClinkRunner::Run( 61 | ArrayRef> inputs) { 62 | assert((func_->num_arguments() == inputs.size()) && 63 | "Incorrect number of arguments set."); 64 | 65 | std::vector input_ptrs; 66 | input_ptrs.resize(inputs.size()); 67 | std::transform(inputs.begin(), 
inputs.end(), input_ptrs.begin(), 68 | [](auto &value) { return value.get(); }); 69 | 70 | llvm::SmallVector, 4> results; 71 | results.resize(func_->result_types().size()); 72 | func_->Execute(execution_context_, input_ptrs, results); 73 | // Return the local directly so copy elision / implicit move applies. 74 | return results; 75 | } 76 | 77 | } // namespace clink 78 | -------------------------------------------------------------------------------- /lib/utils/clink_utils.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | */ 14 | 15 | #include "clink/utils/clink_utils.h" 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | #include "tfrt/host_context/concurrent_work_queue.h" 23 | #include "tfrt/host_context/host_context.h" 24 | #include "tfrt/host_context/profiled_allocator.h" 25 | 26 | namespace clink { 27 | 28 | std::unique_ptr CreateHostContext( 29 | string_view work_queue_type, tfrt::HostAllocatorType host_allocator_type) { 30 | auto decoded_diagnostic_handler = [&](const DecodedDiagnostic &diag) { 31 | TFRT_LOG(FATAL) << "Encountered error while executing, aborting: " 32 | << diag.message; 33 | }; 34 | std::unique_ptr work_queue = 35 | CreateWorkQueue(work_queue_type); 36 | 37 | std::unique_ptr host_allocator; 38 | switch (host_allocator_type) { 39 | case HostAllocatorType::kMalloc: 40 | host_allocator = CreateMallocAllocator(); 41 | llvm::outs() << "Choosing malloc.\n"; 42 | break; 43 | case HostAllocatorType::kTestFixedSizeMalloc: 44 | host_allocator = tfrt::CreateFixedSizeAllocator(); 45 | llvm::outs() << "Choosing fixed size malloc.\n"; 46 | break; 47 | case HostAllocatorType::kProfiledMalloc: 48 | host_allocator = CreateMallocAllocator(); 49 | host_allocator = CreateProfiledAllocator(std::move(host_allocator)); 50 | llvm::outs() << "Choosing profiled allocator based on malloc.\n"; 51 | break; 52 | case HostAllocatorType::kLeakCheckMalloc: 53 | host_allocator = CreateMallocAllocator(); 54 | host_allocator = CreateLeakCheckAllocator(std::move(host_allocator)); 55 | llvm::outs() << "Choosing memory leak check allocator.\n"; 56 | } 57 | llvm::outs().flush(); 58 | 59 | auto host_ctx = std::make_unique(decoded_diagnostic_handler, 60 | std::move(host_allocator), 61 | std::move(work_queue)); 62 | RegisterStaticKernels(host_ctx->GetMutableRegistry()); 63 | return host_ctx; 64 | } 65 | 66 | std::string getOnlyFileInDirectory(std::string path) { 67 | std::string result = ""; 68 | struct dirent *entry; 69 | struct stat st; 70 | DIR *dir = opendir(path.c_str()); 71 | 72 | 
if (dir == NULL) { 73 | return ""; 74 | } 75 | while ((entry = readdir(dir)) != NULL) { 76 | const std::string full_file_name = tfrt::StrCat(path, "/", entry->d_name); 77 | if (stat(full_file_name.c_str(), &st) == -1) continue; 78 | bool is_directory = (st.st_mode & S_IFDIR) != 0; 79 | if (!is_directory) { 80 | if (result != "") { 81 | return ""; 82 | } 83 | result = std::string(entry->d_name); 84 | } 85 | } 86 | closedir(dir); 87 | return result; 88 | } 89 | 90 | } // namespace clink 91 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/util/ClinkReadWriteUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.util; 18 | 19 | import org.apache.flink.ml.param.Param; 20 | 21 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; 22 | 23 | import java.io.BufferedWriter; 24 | import java.io.File; 25 | import java.io.FileWriter; 26 | import java.io.IOException; 27 | import java.util.HashMap; 28 | import java.util.Map; 29 | 30 | /** Utility methods for reading and writing clink operators. 
*/ 31 | public class ClinkReadWriteUtils { 32 | public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); 33 | 34 | /** 35 | * Saves the metadata of the given stage and the extra metadata to a file named `metadata` under 36 | * the given path. The metadata of a stage includes the stage class name, parameter values etc. 37 | * 38 | *

Required: the metadata file under the given path should not exist. 39 | * 40 | * @param paramMap The parameter map of the stage. 41 | * @param modelClass The concrete class of the stage. 42 | * @param path The parent directory in which to save the stage metadata. 43 | */ 44 | public static void saveMetadata( 45 | Map, Object> paramMap, Class modelClass, String path) throws IOException { 46 | Map metadata = new HashMap<>(); 47 | metadata.put("className", modelClass.getName()); 48 | metadata.put("timestamp", System.currentTimeMillis()); 49 | metadata.put("paramMap", jsonEncode(paramMap)); 50 | String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata); 51 | 52 | // Creates parent directories if not already created. 53 | new File(path).mkdirs(); 54 | 55 | File metadataFile = new File(path, "metadata"); 56 | if (!metadataFile.createNewFile()) { 57 | throw new IOException("File " + metadataFile.toString() + " already exists."); 58 | } 59 | try (BufferedWriter writer = new BufferedWriter(new FileWriter(metadataFile))) { 60 | writer.write(metadataStr); 61 | } 62 | } 63 | 64 | /** Encodes each parameter value as JSON, returning a map keyed by parameter name. */ 65 | public static Map jsonEncode(Map, Object> paramMap) 66 | throws IOException { 67 | Map result = new HashMap<>(paramMap.size()); 68 | for (Map.Entry, Object> entry : paramMap.entrySet()) { 69 | String json = jsonEncodeHelper(entry.getKey(), entry.getValue()); 70 | result.put(entry.getKey().name, json); 71 | } 72 | return result; 73 | } 74 | 75 | // A helper method that encodes the given parameter value to a json string. We cannot 76 | // call param.jsonEncode(value) directly because Param::jsonEncode(...) needs the actual type 77 | // of the value. 
78 | private static String jsonEncodeHelper(Param param, Object value) throws IOException { 79 | return param.jsonEncode((T) value); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /include/clink/utils/clink_runner.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | #ifndef CLINK_UTILS_CLINK_RUNNER_H_ 16 | #define CLINK_UTILS_CLINK_RUNNER_H_ 17 | 18 | #include "mlir/IR/MLIRContext.h" 19 | #include "tfrt/bef/bef_buffer.h" 20 | #include "tfrt/bef_executor/bef_file.h" 21 | #include "tfrt/host_context/execution_context.h" 22 | 23 | using namespace tfrt; 24 | 25 | namespace clink { 26 | 27 | // This class is a utility class that provides support for users to specify an 28 | // MLIR function, supply inputs and then have it compiled and run through TFRT. 29 | class ClinkRunner { 30 | public: 31 | class Builder { 32 | public: 33 | Builder(); 34 | 35 | // Sets the MLIR function string and returns the object to chain setters. 36 | // Does not perform validation, will be validated when Compile is called. 
37 | Builder &set_mlir_input(string_view mlir_input) { 38 | assert(!mlir_input.empty() && "MLIR input must not be empty."); 39 | mlir_input_ = mlir_input.str(); 40 | return *this; 41 | } 42 | 43 | // Sets the MLIR function name that will be compiled and run, returns the 44 | // object to chain setters. 45 | Builder &set_mlir_fn_name(string_view fn_name) { 46 | assert(!fn_name.empty() && "Function name must not be empty."); 47 | fn_name_ = fn_name.str(); 48 | return *this; 49 | } 50 | 51 | // Sets the `host_context_` that should be used for opening the BefFile. 52 | // `host_context` must outlive ClinkRunner. 53 | Builder &set_host_context(HostContext *host_context) { 54 | assert(host_context && "HostContext must not be null."); 55 | host_context_ = host_context; 56 | return *this; 57 | } 58 | 59 | // Sets the `mlir_context` that should be used for compiling the MLIR code. 60 | // `mlir_context` must outlive ClinkRunner. 61 | Builder &set_mlir_context(mlir::MLIRContext *mlir_context) { 62 | assert(mlir_context && "MLIR context must not be null."); 63 | mlir_context_ = mlir_context; 64 | return *this; 65 | } 66 | 67 | // Compiles the MLIR function to BEF and returns a ClinkRunner 68 | // object that can be used to Run the MLIR function of interest on TFRT and 69 | // extract outputs. Assert fails if any of mlir_input, fn_name, mlir_context 70 | // are not set. 71 | ClinkRunner Compile(); 72 | 73 | private: 74 | std::string mlir_input_; 75 | std::string fn_name_; 76 | mlir::MLIRContext *mlir_context_ = nullptr; 77 | HostContext *host_context_ = nullptr; 78 | }; 79 | 80 | // Runs the MLIR function on TFRT and returns the outputs. 81 | llvm::SmallVector, 4> Run( 82 | llvm::ArrayRef> inputs); 83 | 84 | private: 85 | // Use ClinkRunner::Builder to get a ClinkRunner object. 
86 | ClinkRunner(const std::string &fn_name, BefBuffer bef_buffer, 87 | HostContext *host_context); 88 | 89 | std::string fn_name_; 90 | BefBuffer bef_buffer_; 91 | HostContext *host_context_ = nullptr; 92 | RCReference bef_file_; 93 | const tfrt::Function *func_; 94 | std::unique_ptr resource_context_ = nullptr; 95 | ExecutionContext execution_context_; 96 | }; 97 | 98 | } // namespace clink 99 | 100 | #endif // CLINK_UTILS_CLINK_RUNNER_H_ 101 | -------------------------------------------------------------------------------- /lib/kernels/clink_kernels.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | #include "clink/kernels/clink_kernels.h" 16 | 17 | #include "clink/feature/one_hot_encoder.h" 18 | #include "clink/linalg/sparse_vector.h" 19 | #include "tfrt/host_context/async_dispatch.h" 20 | #include "tfrt/host_context/kernel_utils.h" 21 | 22 | using namespace tfrt; 23 | 24 | namespace clink { 25 | 26 | AsyncValueRef SquareAdd(Argument x, Argument y, 27 | const ExecutionContext &exec_ctx) { 28 | HostContext *host = exec_ctx.host(); 29 | // Submit a subtask to compute x^2. 30 | AsyncValueRef x_square = 31 | EnqueueWork(exec_ctx, [x = x.ValueRef()] { return x.get() * x.get(); }); 32 | 33 | // Submit a subtask to compute y^2. 
34 | AsyncValueRef y_square = 35 | EnqueueWork(exec_ctx, [y = y.ValueRef()] { return y.get() * y.get(); }); 36 | 37 | SmallVector async_values; 38 | async_values.push_back(x_square.GetAsyncValue()); 39 | async_values.push_back(y_square.GetAsyncValue()); 40 | 41 | // Submit a subtask to compute x^2 + y^2 once the previous two subtasks have 42 | // completed. 43 | auto output = MakeUnconstructedAsyncValueRef(host); 44 | RunWhenReady(async_values, 45 | [x_square = std::move(x_square), y_square = std::move(y_square), 46 | output = output.CopyRef(), exec_ctx]() { 47 | output.emplace(x_square.get() + y_square.get()); 48 | }); 49 | 50 | return output; 51 | } 52 | 53 | double Square(double x) { return x * x; } 54 | 55 | template 56 | void ModelLoad(Argument path, 57 | Result> result_model, 58 | const ExecutionContext &exec_ctx) { 59 | auto result = result_model.Allocate(); 60 | // Model Loading might invoke slow IO operation, which requires the usage of 61 | // EnqueueBlockingWork. 62 | bool work_enqueued = tfrt::EnqueueBlockingWork( 63 | exec_ctx.host(), [result = result.CopyRef(), path, exec_ctx]() { 64 | auto model = T::load(path.get(), exec_ctx.host()); 65 | if (auto err = model.takeError()) { 66 | result.SetError(err); 67 | } else { 68 | result.emplace(std::move(model.get())); 69 | } 70 | }); 71 | if (!work_enqueued) result.SetError("Failed to enqueue blocking work."); 72 | } 73 | 74 | void ModelTransform(RCReference model, RemainingArguments args, 75 | tfrt::RemainingResults results, 76 | const ExecutionContext &exec_ctx) { 77 | auto outputs = model->transform(args.values(), exec_ctx); 78 | for (int i = 0; i < outputs.size(); i++) { 79 | results.AllocateIndirectResultAt(i)->ForwardTo(std::move(outputs[i])); 80 | } 81 | } 82 | 83 | //===----------------------------------------------------------------------===// 84 | // Registration 85 | //===----------------------------------------------------------------------===// 86 | 87 | void 
RegisterClinkKernels(tfrt::KernelRegistry *registry) { 88 | registry->AddKernel("clink.square_add.f64", TFRT_KERNEL(SquareAdd)); 89 | registry->AddKernel("clink.square.f64", TFRT_KERNEL(Square)); 90 | registry->AddKernel("clink.load.onehotencoder", 91 | TFRT_KERNEL(ModelLoad)); 92 | registry->AddKernel("clink.transform", TFRT_KERNEL(ModelTransform)); 93 | } 94 | 95 | } // namespace clink 96 | -------------------------------------------------------------------------------- /lib/kernels/opdefs/clink_kernels.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | // This file implements MLIR operation functions for the clink library. 
16 | 17 | #include "clink/kernels/opdefs/clink_kernels.h" 18 | 19 | #include "clink/kernels/opdefs/types.h" 20 | #include "mlir/IR/BuiltinOps.h" 21 | #include "tfrt/basic_kernels/opdefs/types.h" 22 | 23 | namespace clink { 24 | 25 | //===----------------------------------------------------------------------===// 26 | // Clink Dialect 27 | //===----------------------------------------------------------------------===// 28 | 29 | ClinkDialect::ClinkDialect(MLIRContext *context) 30 | : Dialect(/*name=*/"clink", context, TypeID::get()) { 31 | allowUnknownTypes(); 32 | allowUnknownOperations(); 33 | 34 | addTypes(); 35 | 36 | addOperations< 37 | #define GET_OP_LIST 38 | #include "clink/kernels/opdefs/clink_kernels.cpp.inc" 39 | >(); 40 | } 41 | 42 | mlir::Type ClinkDialect::parseType(mlir::DialectAsmParser &parser) const { 43 | llvm::StringRef spec = parser.getFullSymbolSpec(); 44 | if (spec == "model") return ModelType::get(getContext()); 45 | if (spec == "vector") return VectorType::get(getContext()); 46 | 47 | if (auto type = mlir::Dialect::parseType(parser)) return type; 48 | 49 | mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); 50 | mlir::emitError(loc) << "unknown data type " << spec; 51 | return {}; 52 | } 53 | 54 | void ClinkDialect::printType(mlir::Type type, 55 | mlir::DialectAsmPrinter &printer) const { 56 | if (type.isa()) { 57 | printer << "model"; 58 | return; 59 | } 60 | 61 | if (type.isa()) { 62 | printer << "vector"; 63 | return; 64 | } 65 | 66 | llvm_unreachable("unknown data type"); 67 | } 68 | 69 | namespace { 70 | 71 | static Type GetModelType(Builder *builder) { 72 | return builder->getType(); 73 | } 74 | 75 | } // namespace 76 | 77 | //===----------------------------------------------------------------------===// 78 | // TransformOp 79 | //===----------------------------------------------------------------------===// 80 | 81 | static ParseResult parseTransformOp(OpAsmParser &parser, 82 | OperationState &result) { 83 | 
SmallVector operands; 84 | SmallVector operand_types; 85 | FunctionType calleeType; 86 | auto calleeLoc = parser.getNameLoc(); 87 | if (parser.parseOperandList(operands) || parser.parseColonType(calleeType) || 88 | parser.addTypesToList(calleeType.getResults(), result.types)) { 89 | return failure(); 90 | } 91 | operand_types.push_back(GetModelType(&parser.getBuilder())); 92 | operand_types.insert(operand_types.end(), calleeType.getInputs().begin(), 93 | calleeType.getInputs().end()); 94 | if (parser.resolveOperands(operands, operand_types, calleeLoc, 95 | result.operands)) { 96 | return failure(); 97 | } 98 | 99 | return success(); 100 | } 101 | 102 | } // namespace clink 103 | 104 | //===----------------------------------------------------------------------===// 105 | // TableGen'd op method definitions 106 | //===----------------------------------------------------------------------===// 107 | 108 | #define GET_OP_CLASSES 109 | #include "clink/kernels/opdefs/clink_kernels.cpp.inc" 110 | -------------------------------------------------------------------------------- /lib/feature/one_hot_encoder.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "clink/feature/one_hot_encoder.h" 18 | 19 | #include 20 | 21 | #include "clink/utils/clink_utils.h" 22 | #include "nlohmann/json.hpp" 23 | 24 | namespace clink { 25 | 26 | llvm::SmallVector, 4> 27 | OneHotEncoderModel::transform(llvm::ArrayRef inputs, 28 | const ExecutionContext &exec_ctx) { 29 | auto output = MakeUnconstructedAsyncValueRef(exec_ctx.host()); 30 | tfrt::EnqueueWork( 31 | exec_ctx, [model = tfrt::FormRef(this), value = inputs[0]->get(), 32 | column_index = inputs[1]->get(), 33 | output = output.CopyRef(), exec_ctx]() { 34 | if (column_index >= model->model_data_.featuresizes_size()) { 35 | output.SetError("Column index out of range."); 36 | return; 37 | } 38 | 39 | int len = model->model_data_.featuresizes(column_index); 40 | if (value >= len) { 41 | output.SetError("Value out of range."); 42 | return; 43 | } 44 | if (model->getDropLast()) { 45 | len -= 1; 46 | } 47 | 48 | SparseVector vector(len); 49 | if (value < len) { 50 | vector.set(value, 1.0); 51 | } 52 | output.emplace(std::move(vector)); 53 | }); 54 | 55 | llvm::SmallVector, 4> result; 56 | result.push_back(std::move(output)); 57 | return result; 58 | } 59 | 60 | void OneHotEncoderModel::setDropLast(const bool is_droplast) { 61 | params_.is_droplast = is_droplast; 62 | } 63 | 64 | bool OneHotEncoderModel::getDropLast() const { return params_.is_droplast; } 65 | 66 | llvm::Expected> OneHotEncoderModel::load( 67 | const std::string &path, tfrt::HostContext *host) { 68 | tfrt::RCReference model = 69 | TakeRef(host->Construct(host)); 70 | 71 | std::ifstream params_input(tfrt::StrCat(path, "/metadata")); 72 | nlohmann::json params; 73 | params << params_input; 74 | std::string is_droplast = params["paramMap"]["dropLast"].get(); 75 | model->setDropLast(is_droplast != "false"); 76 | params_input.close(); 77 | 78 | std::string model_data_filename = 79 | getOnlyFileInDirectory(tfrt::StrCat(path, "/data")); 80 | if (model_data_filename == "") { 81 | return 
tfrt::MakeStringError( 82 | "Failed to load OneHotEncoderModel: model data directory ", path, 83 | "/data does not exist, or it has zero or more than one file."); 84 | } 85 | 86 | std::ifstream model_data_input( 87 | tfrt::StrCat(path, "/data/", model_data_filename)); 88 | model_data_input.seekg(sizeof(int), model_data_input.beg); 89 | if (!model->model_data_.ParseFromIstream(&model_data_input)) { 90 | return tfrt::MakeStringError( 91 | "Failed to load OneHotEncoderModel: Invalid model data file ", path, 92 | "/data/", model_data_filename); 93 | } 94 | model_data_input.close(); 95 | 96 | for (int i = 0; i < model->model_data_.featuresizes_size(); i++) { 97 | if (model->model_data_.featuresizes(i) <= 0) { 98 | return tfrt::MakeStringError( 99 | "Failed to load OneHotEncoderModel: Model " 100 | "data feature size value must be positive."); 101 | } 102 | } 103 | 104 | return model; 105 | } 106 | 107 | } // namespace clink 108 | -------------------------------------------------------------------------------- /cpp_tests/include/clink/cpp_tests/test_util.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | // This file defines utilities related to setting up unit tests. 
18 | #ifndef CLINK_CPP_TESTS_TEST_UTIL_H_ 19 | #define CLINK_CPP_TESTS_TEST_UTIL_H_ 20 | 21 | #include 22 | #include 23 | 24 | #include 25 | 26 | #include "clink/kernels/opdefs/clink_kernels.h" 27 | #include "clink/utils/clink_runner.h" 28 | #include "clink/utils/clink_utils.h" 29 | #include "google/protobuf/message.h" 30 | #include "nlohmann/json.hpp" 31 | #include "tfrt/basic_kernels/opdefs/basic_kernels.h" 32 | 33 | namespace clink { 34 | namespace test { 35 | 36 | // This class represents a temporary folder used for unit tests. The folder will 37 | // be deleted automatically once this object is freed. 38 | class TemporaryFolder { 39 | public: 40 | TemporaryFolder() { 41 | char dir_template[] = "/tmp/clink-test-tmp.XXXXXX"; 42 | dir_name = std::string(mkdtemp(dir_template)); 43 | } 44 | 45 | ~TemporaryFolder() { deleteFolderRecursively(dir_name); } 46 | 47 | const std::string getAbsolutePath() { return dir_name; } 48 | 49 | private: 50 | void deleteFolderRecursively(std::string path) { 51 | struct dirent *entry; 52 | struct stat st; 53 | DIR *dir = opendir(path.c_str()); 54 | 55 | if (dir == NULL) { 56 | return; 57 | } 58 | while ((entry = readdir(dir)) != NULL) { 59 | const std::string full_file_name = tfrt::StrCat(path, "/", entry->d_name); 60 | if (stat(full_file_name.c_str(), &st) == -1) continue; 61 | bool is_directory = (st.st_mode & S_IFDIR) != 0; 62 | if (is_directory) { 63 | if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) 64 | continue; 65 | deleteFolderRecursively(full_file_name); 66 | } else { 67 | remove(full_file_name.c_str()); 68 | } 69 | } 70 | closedir(dir); 71 | remove(path.c_str()); 72 | } 73 | 74 | std::string dir_name; 75 | }; 76 | 77 | // Mocks Flink ML's `Stage.save(String path)` method to save metadata and model 78 | // data of an Clink operator to a given directory. 
79 | void saveMetaDataModelData(std::string dir_name, nlohmann::json params, 80 | google::protobuf::Message &model_data) { 81 | std::ofstream params_output(tfrt::StrCat(dir_name, "/metadata")); 82 | params_output << params; 83 | params_output.close(); 84 | 85 | mkdir(tfrt::StrCat(dir_name, "/data").c_str(), S_IRUSR | S_IWUSR); 86 | std::ofstream model_data_output(tfrt::StrCat(dir_name, "/data/modelData")); 87 | 88 | int model_data_len = 0; 89 | model_data_output.write((char *)&model_data_len, sizeof(int)); 90 | 91 | model_data.SerializeToOstream(&model_data_output); 92 | model_data_output.close(); 93 | } 94 | 95 | // Creates a ClinkRunner, executes the provided mlir script and returns the 96 | // execution result. 97 | llvm::SmallVector, 4> runMlirScript( 98 | tfrt::HostContext *host_context, MLIRContext *mlir_context, 99 | string_view mlir_script, llvm::ArrayRef> inputs) { 100 | // Initializes ClinkRunner. 101 | clink::ClinkRunner::Builder builder; 102 | builder.set_mlir_fn_name("main") 103 | .set_mlir_input(mlir_script) 104 | .set_host_context(host_context) 105 | .set_mlir_context(mlir_context); 106 | auto runner = builder.Compile(); 107 | 108 | return runner.Run(inputs); 109 | } 110 | 111 | } // namespace test 112 | } // namespace clink 113 | 114 | #endif // CLINK_CPP_TESTS_TEST_UTIL_H_ 115 | -------------------------------------------------------------------------------- /lib/executor/main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | #include 16 | 17 | #include "clink/kernels/opdefs/clink_kernels.h" 18 | #include "clink/utils/clink_runner.h" 19 | #include "clink/utils/clink_utils.h" 20 | #include "llvm/Support/CommandLine.h" 21 | #include "llvm/Support/InitLLVM.h" 22 | #include "llvm/Support/SourceMgr.h" 23 | #include "mlir/Support/FileUtilities.h" 24 | #include "tfrt/basic_kernels/opdefs/basic_kernels.h" 25 | #include "tfrt/bef_executor_driver/bef_executor_driver.h" 26 | #include "tfrt/host_context/host_context.h" 27 | 28 | using namespace mlir; 29 | 30 | static llvm::cl::opt inputFilename(llvm::cl::Positional, 31 | llvm::cl::desc(""), 32 | llvm::cl::init("-")); 33 | 34 | static llvm::cl::list cl_functions( // NOLINT 35 | "functions", llvm::cl::desc("Specify MLIR functions to run"), 36 | llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); 37 | 38 | // Enable ConcurrentWorkQueue types to be specified on the command line. 39 | static llvm::cl::opt cl_work_queue_type( // NOLINT 40 | "work_queue_type", 41 | llvm::cl::desc("Specify concurrent work queue type (s, mstd, ...):"), 42 | llvm::cl::init("s")); 43 | 44 | // Enable HostAllocator types to be specified on the command line. 
45 | static llvm::cl::opt cl_host_allocator_type( // NOLINT 46 | "host_allocator_type", llvm::cl::desc("Specify host allocator type:"), 47 | llvm::cl::values( 48 | clEnumValN(tfrt::HostAllocatorType::kMalloc, "malloc", "Malloc."), 49 | clEnumValN(tfrt::HostAllocatorType::kTestFixedSizeMalloc, 50 | "test_fixed_size_1k", 51 | "Fixed size (1 kB) Malloc for testing."), 52 | clEnumValN(tfrt::HostAllocatorType::kProfiledMalloc, 53 | "profiled_allocator", "Malloc with metric profiling."), 54 | clEnumValN(tfrt::HostAllocatorType::kLeakCheckMalloc, 55 | "leak_check_allocator", "Malloc with memory leak check.")), 56 | llvm::cl::init(tfrt::HostAllocatorType::kLeakCheckMalloc)); 57 | 58 | int main(int argc, char **argv) { 59 | llvm::InitLLVM y(argc, argv); 60 | llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n"); 61 | 62 | // Reads mlir source program. 63 | std::string errorMessage; 64 | auto input = openInputFile(inputFilename, &errorMessage); 65 | if (!input) { 66 | llvm::errs() << errorMessage << "\n"; 67 | return 1; 68 | } 69 | llvm::SourceMgr source_mgr; 70 | source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc()); 71 | auto mlir_src = 72 | source_mgr.getMemoryBuffer(source_mgr.getMainFileID())->getBuffer(); 73 | 74 | // Initializes MLIR context. 75 | MLIRContext context; 76 | context.allowUnregisteredDialects(); 77 | context.printOpOnDiagnostic(true); 78 | mlir::DialectRegistry registry; 79 | registry.insert(); 80 | registry.insert(); 81 | context.appendDialectRegistry(registry); 82 | 83 | // Initializes HostContext. 84 | std::unique_ptr host_context = 85 | clink::CreateHostContext(cl_work_queue_type, cl_host_allocator_type); 86 | 87 | // Initializes ClinkRunner. 88 | clink::ClinkRunner::Builder builder; 89 | builder.set_mlir_fn_name("main") 90 | .set_mlir_input(mlir_src.data()) 91 | .set_host_context(host_context.get()) 92 | .set_mlir_context(&context); 93 | auto runner = builder.Compile(); 94 | 95 | // Executes ClinkRunner. 
96 | llvm::SmallVector, 4> inputs; 97 | inputs.push_back(tfrt::MakeAvailableAsyncValueRef(2.0)); 98 | auto results = runner.Run(inputs); 99 | 100 | return 0; 101 | } 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # Clink 6 | 7 | Clink is a library that provides infrastructure to do the following: 8 | - Define C++ functions that can be parallelized by the TFRT thread pool. 9 | - Execute a graph (in the MLIR format) of these C++ functions in parallel. 10 | - Make C++ functions executable as Java functions using 11 | [JNA](https://github.com/java-native-access/jna). 12 | 13 | Furthermore, Clink provides an off-the-shelf library of reusable Feature 14 | Processing functions that can be executed as Java and C++ functions. 15 | 16 | Clink is useful in the scenario where users want to do online feature processing 17 | with low latency (sub-millisecond) in C++, apply the same logic to do offline 18 | feature processing in Java, and implement this logic only once (in C++). 19 | 20 | ## Getting Started 21 | 22 | ### Prerequisites 23 | 24 | Clink uses [TFRT](https://github.com/tensorflow/runtime) as the underlying 25 | execution engine and therefore follows TFRT's Operating System and installation 26 | requirements. 27 | 28 | Currently supported operating systems are as follows: 29 | 30 | - Ubuntu 16.04 31 | - CentOS 7.7.1908 32 | 33 | Here are the prerequisites to build and install Clink: 34 | - Bazel 4.0.0 35 | - Clang 11.1.0 36 | - libstdc++8 or greater 37 | - openjdk-8 38 | 39 | Clink provides dockerfiles and pre-built docker images that satisfy the 40 | installation requirements listed above. You can use one of the following 41 | commands to build the docker image, according to the operating system you expect 42 | to use. 43 | 44 | ```bash 45 | $ docker build -t ubuntu:16.04_clink -f docker/Dockerfile_ubuntu_1604 . 46 | ``` 47 | 48 | ```bash 49 | $ docker build -t centos:centos7.7.1908_clink -f docker/Dockerfile_centos_77 . 50 | ``` 51 | 52 | Or you can use one of the following commands to pull the pre-built Docker image 53 | from Docker Hub.
54 | 55 | ```bash 56 | $ docker pull docker.io/flinkextended/clink:ubuntu16.04 57 | ``` 58 | 59 | ```bash 60 | $ docker pull docker.io/flinkextended/clink:centos7.7.1908 61 | ``` 62 | 63 | If you plan to set up the Clink environment without the docker images provided 64 | above, please check the [TFRT](https://github.com/tensorflow/runtime) README for 65 | more detailed instructions to install, configure and verify Bazel, Clang, and 66 | libstdc++8. 67 | 68 | ### Initializing Submodules before building Clink from Source 69 | 70 | After setting up the environment according to the instructions above and pulling 71 | the Clink repository, please use the following command to initialize submodules like 72 | TFRT before building any Clink target from source. 73 | 74 | ```bash 75 | $ git submodule update --init --recursive 76 | ``` 77 | 78 | ### Executing Examples 79 | 80 | Users can execute the Clink C++ function example in parallel using the 81 | following command. 82 | 83 | ```bash 84 | $ bazel run //:executor -- `pwd`/mlir_test/executor/basic.mlir --work_queue_type=mstd --host_allocator_type=malloc 85 | ``` 86 | 87 | 88 | 89 | ## Developer Guidelines 90 | 91 | ### Running All Tests 92 | 93 | Developers can run the following command to build all targets and to run all 94 | tests. 95 | 96 | ```bash 97 | $ bazel test $(bazel query //...) -c dbg 98 | ``` 99 | 100 | ### Code Formatting 101 | 102 | Changes to Clink C++ code should conform to [Google C++ Style 103 | Guide](https://google.github.io/styleguide/cppguide.html). 104 | 105 | Clink uses [ClangFormat](https://clang.llvm.org/docs/ClangFormat.html) to check 106 | C++ code, [diffplug/spotless](https://github.com/diffplug/spotless) to check 107 | Java code, and [Buildifier](https://github.com/bazelbuild/buildtools) to check 108 | Bazel code. 109 | 110 | Please run the following command to format the code before uploading PRs for 111 | review.
112 | 113 | ```bash 114 | $ ./tools/format-code.sh 115 | ``` 116 | 117 | ### View & Edit Java Code with IDE 118 | 119 | Clink provides maven configuration that allows users to view or edit java code 120 | with IDEs like IntelliJ IDEA. Before IDEs can correctly compile java project, 121 | users need to run the following commands after setting up Clink repo and build 122 | Clink. 123 | 124 | ```bash 125 | $ bazel build //:clink_java_proto 126 | $ cp bazel-bin/libclink_proto-speed.jar java-lib/lib/ 127 | ``` 128 | 129 | Then users can open `java-lib` directory with their IDEs. 130 | -------------------------------------------------------------------------------- /include/clink/kernels/opdefs/clink_kernels.td: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The Clink Runtime Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //===- clink_kernels.td ----------------------------------------------------===// 16 | // 17 | // Operation definitions for clink_kernels. 18 | // 19 | //===----------------------------------------------------------------------===// 20 | 21 | #ifdef CLINK_OPS 22 | #else 23 | #define CLINK_OPS 24 | 25 | include "tfrt/basic_kernels/opdefs/tfrt_base.td" 26 | 27 | // "clink" dialect 28 | def Clink_Dialect : Dialect { 29 | let name = "clink"; 30 | 31 | let description = [{ 32 | This dialect contains common clink operations. 
33 | }]; 34 | 35 | let cppNamespace = "::clink"; 36 | } 37 | 38 | // Base class for Clink dialect ops. 39 | class Clink_Op traits = []> : 40 | Op { 41 | 42 | // Each registered op in the Clink namespace needs to provide a parser. 43 | let parser = [{ return clink::parse$cppClass(parser, result); }]; 44 | } 45 | 46 | //===----------------------------------------------------------------------===// 47 | // Clink types 48 | //===----------------------------------------------------------------------===// 49 | 50 | def Clink_ModelType : 51 | Type()">, "!clink.model type">, 52 | BuildableType<"$_builder.getType()">; 53 | 54 | def Clink_VectorType : 55 | Type()">, "!clink.vector type">, 56 | BuildableType<"$_builder.getType()">; 57 | 58 | //===----------------------------------------------------------------------===// 59 | // Clink ops 60 | //===----------------------------------------------------------------------===// 61 | 62 | def SquareAddF64Op: Clink_Op<"square_add.f64"> { 63 | let summary = "clink.square_add.f64 operation"; 64 | let description = [{ 65 | An operation that takes two inputs and returns their squared sum as the result. 66 | 67 | Example: 68 | %2 = clink.square_add.f64 %0, %1 69 | }]; 70 | let arguments = (ins F64, F64); 71 | let results = (outs F64); 72 | let assemblyFormat = "operands attr-dict"; 73 | let verifier = ?; 74 | } 75 | 76 | def SquareF64Op: Clink_Op<"square.f64"> { 77 | let summary = "clink.square.f64 operation"; 78 | let description = [{ 79 | An operation that returns the square of the input. 80 | 81 | Example: 82 | %1 = clink.square.f64 %0 83 | }]; 84 | let arguments = (ins F64); 85 | let results = (outs F64); 86 | let assemblyFormat = "operands attr-dict"; 87 | let verifier = ?; 88 | } 89 | 90 | class LoadOp : Clink_Op<"load." # suffix> { 91 | let summary = "clink.load operation"; 92 | let description = [{ 93 | An operation that creates a Model's subclass instance based on the 94 | data read from the given path. 
The path should be a directory 95 | containing params and model data saved by Clink's corresponding Java 96 | operator's void save(String path) method. 97 | 98 | Example: 99 | %1 = clink.load.onehotencoder %0 100 | }]; 101 | let arguments = (ins TFRT_StringType:$path); 102 | let results = (outs Clink_ModelType:$model); 103 | let assemblyFormat = "operands attr-dict"; 104 | let verifier = ?; 105 | } 106 | 107 | def OneHotEncoderLoadOp : LoadOp<"onehotencoder">; 108 | 109 | def TransformOp: Clink_Op<"transform"> { 110 | let summary = "transform operation"; 111 | let description = [{ 112 | An operation that transforms data based on a Model. It applies the Model 113 | on any given input AsyncValues and returns the transformation results as 114 | AsyncValues. 115 | 116 | The input (except Model, the first input argument) and result types must 117 | be specified after the colon. 118 | 119 | Example: 120 | %2 = clink.transform %model, %0, %1 : (i32, i32) -> !clink.vector 121 | }]; 122 | let arguments = (ins 123 | Clink_ModelType:$model, 124 | Variadic:$inputs 125 | ); 126 | let results = (outs Variadic:$outputs); 127 | let verifier = ?; 128 | } 129 | 130 | #endif // CLINK_OPS 131 | -------------------------------------------------------------------------------- /java-lib/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 20 | 23 | 4.0.0 24 | 25 | clink-java-lib 26 | org.clink 27 | Clink Java Lib 28 | 0.1-SNAPSHOT 29 | jar 30 | 31 | https://github.com/flink-extended/clink 32 | 2021 33 | 34 | 35 | 36 | The Apache Software License, Version 2.0 37 | https://www.apache.org/licenses/LICENSE-2.0.txt 38 | repo 39 | 40 | 41 | 42 | 43 | https://github.com/flink-extended/clink 44 | git@github.com:flink-extended/clink.git 45 | 46 | 47 | 48 | 1.14.0 49 | 2.0.0 50 | 3.14.0 51 | 2.12 52 | 2.4.2 53 | 5.6.0 54 | 4.12 55 | 56 | 57 | 58 | 59 | net.java.dev.jna 60 | jna 61 | ${jna.version} 62 | 63 | 64 | org.apache.flink 65 | 
flink-streaming-java_${scala.version} 66 | ${flink.version} 67 | 68 | 69 | org.apache.flink 70 | flink-table-api-java 71 | ${flink.version} 72 | 73 | 74 | org.apache.flink 75 | flink-core 76 | ${flink.version} 77 | 78 | 79 | org.apache.flink 80 | flink-table-api-java-bridge_${scala.version} 81 | ${flink.version} 82 | 83 | 84 | org.apache.flink 85 | flink-ml-core_${scala.version} 86 | ${flink.ml.version} 87 | 88 | 89 | org.apache.flink 90 | flink-ml-lib_${scala.version} 91 | ${flink.ml.version} 92 | 93 | 94 | org.clink 95 | clink_proto 96 | ${flink.ml.version} 97 | system 98 | ${basedir}/lib/libclink_proto-speed.jar 99 | 100 | 101 | com.google.protobuf 102 | protobuf-java 103 | ${protobuf.version} 104 | 105 | 106 | 107 | org.apache.flink 108 | flink-table-planner_${scala.version} 109 | ${flink.version} 110 | 111 | 112 | junit 113 | junit 114 | ${junit.version} 115 | test 116 | 117 | 118 | 119 | 120 | 121 | 122 | org.apache.maven.plugins 123 | maven-compiler-plugin 124 | 3.5.1 125 | 126 | 1.8 127 | 1.8 128 | 129 | 130 | 131 | 132 | com.diffplug.spotless 133 | spotless-maven-plugin 134 | 135 | 136 | 137 | 138 | 139 | 140 | com.diffplug.spotless 141 | spotless-maven-plugin 142 | ${spotless.version} 143 | 144 | 145 | 146 | 1.7 147 | 148 | 149 | 150 | 151 | 152 | org.apache.flink,org.apache.flink.shaded,,javax,java,scala,\# 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | spotless-check 161 | validate 162 | 163 | check 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | workspace(name = "clink") 14 | 15 | local_repository( 16 | name = "tf_runtime", 17 | path = "tfrt", 18 | ) 19 | 20 | load("@tf_runtime//:dependencies.bzl", "tfrt_dependencies") 21 | 22 | tfrt_dependencies() 23 | 24 | load("@bazel_skylib//lib:versions.bzl", "versions") 25 | 26 | versions.check(minimum_bazel_version = "4.0.0") 27 | 28 | load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") 29 | load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") 30 | 31 | maybe( 32 | llvm_configure, 33 | name = "llvm-project", 34 | ) 35 | 36 | llvm_disable_optional_support_deps() 37 | 38 | load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") 39 | 40 | protobuf_deps() 41 | 42 | load("@rules_cc//cc:repositories.bzl", "rules_cc_toolchains") 43 | 44 | rules_cc_toolchains() 45 | 46 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 47 | 48 | RULES_JVM_EXTERNAL_TAG = "4.2" 49 | 50 | RULES_JVM_EXTERNAL_SHA = "cd1a77b7b02e8e008439ca76fd34f5b07aecb8c752961f9640dea15e9e5ba1ca" 51 | 52 | http_archive( 53 | name = "rules_jvm_external", 54 | sha256 = RULES_JVM_EXTERNAL_SHA, 55 | strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG, 56 | url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG, 57 | ) 58 | 59 | http_archive( 60 | name = "com_github_nlohmann_json", 61 | build_file = "//third_party:json.BUILD", # see below 62 | sha256 = "4cf0df69731494668bdd6460ed8cb269b68de9c19ad8c27abc24cd72605b2d5b", 63 | 
strip_prefix = "json-3.9.1", 64 | urls = ["https://github.com/nlohmann/json/archive/v3.9.1.tar.gz"], 65 | ) 66 | 67 | # io_bazel_rules_scala defines scala version. 68 | rules_scala_version = "e7a948ad1948058a7a5ddfbd9d1629d6db839933" 69 | 70 | http_archive( 71 | name = "io_bazel_rules_scala", 72 | sha256 = "76e1abb8a54f61ada974e6e9af689c59fd9f0518b49be6be7a631ce9fa45f236", 73 | strip_prefix = "rules_scala-%s" % rules_scala_version, 74 | type = "zip", 75 | url = "https://github.com/bazelbuild/rules_scala/archive/%s.zip" % rules_scala_version, 76 | ) 77 | 78 | load("@rules_jvm_external//:defs.bzl", "maven_install") 79 | load("@io_bazel_rules_scala//:scala_config.bzl", "scala_config") 80 | 81 | scala_config(scala_version = "2.12.7") 82 | 83 | load("@io_bazel_rules_scala//scala:scala.bzl", "scala_repositories") 84 | 85 | scala_repositories() 86 | 87 | FLINK_VERSION = "1.14.0" 88 | 89 | FLINK_ML_VERSION = "2.0.0" 90 | 91 | SCALA_VERSION = "2.12" 92 | 93 | maven_install( 94 | artifacts = [ 95 | "net.java.dev.jna:jna:5.6.0", 96 | "net.java.dev.jna:jna-platform:5.6.0", 97 | "org.apache.flink:flink-connector-files:%s" % FLINK_VERSION, 98 | "org.apache.flink:flink-core:%s" % FLINK_VERSION, 99 | "org.apache.flink:flink-streaming-java_%s:%s" % (SCALA_VERSION, FLINK_VERSION), 100 | "org.apache.flink:flink-shaded-jackson:2.12.4-14.0", 101 | "org.apache.flink:flink-table-api-java:%s" % FLINK_VERSION, 102 | "org.apache.flink:flink-table-api-java-bridge_%s:%s" % (SCALA_VERSION, FLINK_VERSION), 103 | "org.apache.flink:flink-clients_%s:%s" % (SCALA_VERSION, FLINK_VERSION), 104 | "org.apache.flink:flink-table-planner_%s:%s" % (SCALA_VERSION, FLINK_VERSION), 105 | "org.apache.flink:flink-table-runtime_%s:%s" % (SCALA_VERSION, FLINK_VERSION), 106 | "org.apache.flink:flink-test-utils-junit:%s" % FLINK_VERSION, 107 | "org.apache.flink:flink-ml-core_%s:%s" % (SCALA_VERSION, FLINK_ML_VERSION), 108 | "org.apache.flink:flink-ml-iteration_%s:%s" % (SCALA_VERSION, FLINK_ML_VERSION), 109 | 
"org.apache.flink:flink-ml-lib_%s:%s" % (SCALA_VERSION, FLINK_ML_VERSION), 110 | "org.apache.commons:commons-compress:1.21", 111 | "commons-collections:commons-collections:3.2.2", 112 | "org.apache.commons:commons-lang3:3.3.2", 113 | "junit:junit:4.12", 114 | ], 115 | override_targets = { 116 | "org.scala-lang.scala-library": "@io_bazel_rules_scala_scala_library//:io_bazel_rules_scala_scala_library", 117 | "org.scala-lang.scala-reflect": "@io_bazel_rules_scala_scala_reflect//:io_bazel_rules_scala_scala_reflect", 118 | "org.scala-lang.scala-compiler": "@io_bazel_rules_scala_scala_compiler//:io_bazel_rules_scala_scala_compiler", 119 | "org.scala-lang.modules.scala-parser-combinators_2.11": "@io_bazel_rules_scala_scala_parser_combinators//:io_bazel_rules_scala_scala_parser_combinators", 120 | "org.scala-lang.modules.scala-xml_2.11": "@io_bazel_rules_scala_scala_xml//:io_bazel_rules_scala_scala_xml", 121 | }, 122 | repositories = [ 123 | "https://maven.google.com", 124 | "https://repo1.maven.org/maven2", 125 | "http://packages.confluent.io/maven", 126 | "http://mvnrepo.alibaba-inc.com/mvn/repository", 127 | ], 128 | ) 129 | 130 | http_archive( 131 | name = "io_bazel_rules_go", 132 | sha256 = "2b1641428dff9018f9e85c0384f03ec6c10660d935b750e3fa1492a281a53b0f", 133 | urls = [ 134 | "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.29.0/rules_go-v0.29.0.zip", 135 | "https://github.com/bazelbuild/rules_go/releases/download/v0.29.0/rules_go-v0.29.0.zip", 136 | ], 137 | ) 138 | 139 | load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") 140 | 141 | go_rules_dependencies() 142 | 143 | go_register_toolchains(version = "1.17.2") 144 | 145 | http_archive( 146 | name = "com_github_bazelbuild_buildtools", 147 | sha256 = "0d3ca4ed434958dda241fb129f77bd5ef0ce246250feed2d5a5470c6f29a77fa", 148 | strip_prefix = "buildtools-4.0.0", 149 | urls = [ 150 | 
"https://github.com/bazelbuild/buildtools/archive/refs/tags/4.0.0.tar.gz", 151 | ], 152 | ) 153 | -------------------------------------------------------------------------------- /cpp_tests/feature/one_hot_encoder_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "clink/feature/one_hot_encoder.h" 18 | 19 | #include "clink/cpp_tests/test_util.h" 20 | #include "gtest/gtest.h" 21 | 22 | namespace clink { 23 | 24 | namespace { 25 | 26 | class OneHotEncoderTest : public testing::Test { 27 | protected: 28 | static void SetUpTestSuite() { 29 | assert(host_context == nullptr); 30 | host_context = 31 | CreateHostContext("mstd", tfrt::HostAllocatorType::kLeakCheckMalloc) 32 | .release(); 33 | 34 | assert(mlir_context == nullptr); 35 | mlir_context = new MLIRContext(); 36 | mlir_context->allowUnregisteredDialects(); 37 | mlir_context->printOpOnDiagnostic(true); 38 | mlir::DialectRegistry registry; 39 | registry.insert(); 40 | registry.insert(); 41 | mlir_context->appendDialectRegistry(registry); 42 | 43 | assert(exec_context == nullptr); 44 | exec_context = new ExecutionContext( 45 | *tfrt::RequestContextBuilder(host_context, nullptr).build()); 46 | } 47 | 48 | static void TearDownTestSuite() { 49 | delete host_context; 50 | delete mlir_context; 51 | delete exec_context; 52 | 
host_context = nullptr; 53 | mlir_context = nullptr; 54 | exec_context = nullptr; 55 | } 56 | 57 | static tfrt::HostContext *host_context; 58 | static MLIRContext *mlir_context; 59 | static ExecutionContext *exec_context; 60 | }; 61 | 62 | tfrt::HostContext *OneHotEncoderTest::host_context = nullptr; 63 | 64 | MLIRContext *OneHotEncoderTest::mlir_context = nullptr; 65 | 66 | ExecutionContext *OneHotEncoderTest::exec_context = nullptr; 67 | 68 | TEST_F(OneHotEncoderTest, Param) { 69 | RCReference model = 70 | tfrt::TakeRef(host_context->Construct(host_context)); 71 | model->setDropLast(false); 72 | EXPECT_FALSE(model->getDropLast()); 73 | model->setDropLast(true); 74 | EXPECT_TRUE(model->getDropLast()); 75 | } 76 | 77 | TEST_F(OneHotEncoderTest, Transform) { 78 | test::TemporaryFolder tmp_folder; 79 | 80 | nlohmann::json params; 81 | // TODO: Add helper function that converts between json data of structured 82 | // format and that of Flink ML, which wraps all values as strings 83 | params["paramMap"]["dropLast"] = "false"; 84 | 85 | OneHotEncoderModelDataProto model_data; 86 | model_data.add_featuresizes(2); 87 | model_data.add_featuresizes(3); 88 | 89 | test::saveMetaDataModelData(tmp_folder.getAbsolutePath(), params, model_data); 90 | 91 | auto model = 92 | OneHotEncoderModel::load(tmp_folder.getAbsolutePath(), host_context); 93 | EXPECT_FALSE((bool)model.takeError()); 94 | 95 | SparseVector expected_vector(2); 96 | expected_vector.set(1, 1.0); 97 | 98 | tfrt::AsyncValueRef value_ref = MakeAvailableAsyncValueRef(1); 99 | tfrt::AsyncValueRef colum_index_ref = MakeAvailableAsyncValueRef(0); 100 | llvm::SmallVector inputs{ 101 | value_ref.GetAsyncValue(), colum_index_ref.GetAsyncValue()}; 102 | 103 | auto outputs = model.get()->transform(inputs, *exec_context); 104 | host_context->Await(outputs); 105 | SparseVector &actual_vector = outputs[0]->get(); 106 | EXPECT_EQ(actual_vector, expected_vector); 107 | } 108 | 109 | TEST_F(OneHotEncoderTest, Mlir) { 110 | 
test::TemporaryFolder tmp_folder; 111 | 112 | nlohmann::json params; 113 | params["paramMap"]["dropLast"] = "false"; 114 | 115 | OneHotEncoderModelDataProto model_data; 116 | model_data.add_featuresizes(2); 117 | model_data.add_featuresizes(3); 118 | 119 | test::saveMetaDataModelData(tmp_folder.getAbsolutePath(), params, model_data); 120 | 121 | const std::string mlir_script = R"mlir( 122 | func @load_model(%path: !tfrt.string) -> !clink.model { 123 | %model = clink.load.onehotencoder %path 124 | tfrt.return %model : !clink.model 125 | } 126 | 127 | func @transform_inputs(%model: !clink.model, %value: i32, %column_index: i32) -> !clink.vector { 128 | %outputs = clink.transform %model, %value, %column_index : (i32, i32) -> !clink.vector 129 | tfrt.return %outputs : !clink.vector 130 | } 131 | )mlir"; 132 | 133 | clink::ClinkRunner::Builder builder; 134 | builder.set_mlir_fn_name("load_model") 135 | .set_mlir_input(mlir_script) 136 | .set_host_context(host_context) 137 | .set_mlir_context(mlir_context); 138 | auto model_load_runner = builder.Compile(); 139 | 140 | llvm::SmallVector> model_load_inputs{ 141 | tfrt::MakeAvailableAsyncValueRef( 142 | tmp_folder.getAbsolutePath())}; 143 | auto model_ref = model_load_runner.Run(model_load_inputs)[0]; 144 | 145 | builder.set_mlir_fn_name("transform_inputs"); 146 | auto model_transform_runner = builder.Compile(); 147 | 148 | { 149 | llvm::SmallVector, 4> inputs; 150 | inputs.push_back(model_ref); 151 | inputs.push_back(MakeAvailableAsyncValueRef(1)); 152 | inputs.push_back(MakeAvailableAsyncValueRef(0)); 153 | 154 | auto results = model_transform_runner.Run(inputs); 155 | host_context->Await(results); 156 | SparseVector &actual_vector = results[0]->get(); 157 | 158 | SparseVector expected_vector(2); 159 | expected_vector.set(1, 1.0); 160 | EXPECT_EQ(actual_vector, expected_vector); 161 | } 162 | 163 | { 164 | llvm::SmallVector, 4> inputs; 165 | inputs.push_back(model_ref); 166 | 
inputs.push_back(MakeAvailableAsyncValueRef(5)); 167 | inputs.push_back(MakeAvailableAsyncValueRef(5)); 168 | 169 | auto results = model_transform_runner.Run(inputs); 170 | host_context->Await(results); 171 | EXPECT_EQ(results[0]->GetError().message, "Column index out of range."); 172 | } 173 | } 174 | 175 | } // namespace 176 | } // namespace clink 177 | -------------------------------------------------------------------------------- /lib/jna/clink_jna.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | #include 16 | 17 | #include "clink/feature/one_hot_encoder.h" 18 | #include "clink/kernels/clink_kernels.h" 19 | #include "clink/utils/clink_utils.h" 20 | #include "llvm/Support/Errc.h" 21 | #include "llvm/Support/Error.h" 22 | #include "tfrt/host_context/chain.h" 23 | #include "tfrt/support/logging.h" 24 | 25 | #ifdef __cplusplus 26 | extern "C" { 27 | #endif 28 | 29 | // Handles llvm::Error generated in Clink JNA methods. This macro logs the 30 | // corresponding error message at ERROR level and sets errno to -1, which causes 31 | // the Java code to throw LastErrorException.
32 | #define CLINK_JNA_HANDLE_ERROR(ERR, RET_VALUE) \ 33 | do { \ 34 | if (auto err = ERR) { \ 35 | TFRT_LOG(ERROR) << tfrt::StrCat(err) << "\n"; \ 36 | errno = -1; \ 37 | return RET_VALUE; \ 38 | } \ 39 | } while (0); 40 | 41 | // Handles tfrt::RCReference generated in Clink JNA 42 | // methods. This function prints corresponding error message to std::err and set 43 | // errno to 1, which causes Java codes to throw LastErrorException. 44 | #define CLINK_JNA_HANDLE_ASYNC_ERROR(ERR, RET_VALUE) \ 45 | do { \ 46 | if (ERR->IsError()) { \ 47 | TFRT_LOG(ERROR) << tfrt::StrCat(ERR->GetError()) << "\n"; \ 48 | errno = -1; \ 49 | return RET_VALUE; \ 50 | } \ 51 | } while (0); 52 | 53 | namespace { 54 | inline tfrt::HostContext *getJnaHostContext() { 55 | #ifndef NDEBUG 56 | static tfrt::HostAllocatorType allocator_type = 57 | tfrt::HostAllocatorType::kLeakCheckMalloc; 58 | #else 59 | static tfrt::HostAllocatorType allocator_type = 60 | tfrt::HostAllocatorType::kMalloc; 61 | #endif 62 | static tfrt::HostContext *jna_host_context = 63 | clink::CreateHostContext("mstd", allocator_type).release(); 64 | return jna_host_context; 65 | } 66 | 67 | inline ExecutionContext &getJnaExecutionContext() { 68 | static ExecutionContext exec_ctx( 69 | *tfrt::RequestContextBuilder(getJnaHostContext(), nullptr).build()); 70 | return exec_ctx; 71 | } 72 | 73 | } // namespace 74 | 75 | double SquareAdd(double x, double y) { 76 | ExecutionContext &exec_ctx = getJnaExecutionContext(); 77 | 78 | AsyncValueRef x_async = MakeAvailableAsyncValueRef(x); 79 | Argument x_arg(x_async.GetAsyncValue()); 80 | 81 | AsyncValueRef y_async = MakeAvailableAsyncValueRef(y); 82 | Argument y_arg(y_async.GetAsyncValue()); 83 | 84 | AsyncValueRef result_async = clink::SquareAdd(x_arg, y_arg, exec_ctx); 85 | 86 | exec_ctx.host()->Await(result_async.CopyRCRef()); 87 | return result_async.get(); 88 | } 89 | 90 | double Square(double x) { return clink::Square(x); } 91 | 92 | // Struct representation of 
clink::SparseVector. It is only used for JNA to 93 | // transmit data between Java and C++. 94 | typedef struct SparseVectorJNA { 95 | SparseVectorJNA(const clink::SparseVector &vector, 96 | tfrt::HostContext *host_context); 97 | ~SparseVectorJNA(); 98 | 99 | // Total dimensions of the sparse vector. 100 | int n; 101 | int *indices; 102 | double *values; 103 | // Length of indices and values array. 104 | int length; 105 | 106 | tfrt::HostContext *host_; 107 | } SparseVectorJNA; 108 | 109 | SparseVectorJNA::SparseVectorJNA(const clink::SparseVector &sparse_vector, 110 | tfrt::HostContext *host_context) 111 | : host_(host_context) { 112 | this->n = sparse_vector.size(); 113 | this->length = 0; 114 | for (int i = 0; i < this->n; i++) { 115 | if (sparse_vector.get(i).get() != 0.0) { 116 | this->length++; 117 | } 118 | } 119 | 120 | this->indices = host_->Allocate(this->length); 121 | this->values = host_->Allocate(this->length); 122 | int offset = 0; 123 | for (int i = 0; i < this->n; i++) { 124 | if (sparse_vector.get(i).get() != 0.0) { 125 | this->indices[offset] = i; 126 | this->values[offset] = sparse_vector.get(i).get(); 127 | offset++; 128 | } 129 | } 130 | } 131 | 132 | SparseVectorJNA::~SparseVectorJNA() { 133 | host_->Deallocate(this->indices, this->length); 134 | host_->Deallocate(this->values, this->length); 135 | } 136 | 137 | void SparseVector_delete(SparseVectorJNA *vector) { 138 | getJnaHostContext()->Destruct(vector); 139 | } 140 | 141 | SparseVectorJNA *OneHotEncoderModel_transform(clink::OneHotEncoderModel *model, 142 | const int value, 143 | const int column_index) { 144 | tfrt::AsyncValueRef value_ref = MakeAvailableAsyncValueRef(value); 145 | tfrt::AsyncValueRef colum_index_ref = 146 | MakeAvailableAsyncValueRef(column_index); 147 | llvm::SmallVector inputs{ 148 | value_ref.GetAsyncValue(), colum_index_ref.GetAsyncValue()}; 149 | 150 | auto output = model->transform(inputs, getJnaExecutionContext())[0]; 151 | getJnaHostContext()->Await(output); 152 
| CLINK_JNA_HANDLE_ASYNC_ERROR(output, NULL); 153 | clink::SparseVector &actual_vector = output->get(); 154 | 155 | return getJnaHostContext()->Construct(actual_vector, 156 | getJnaHostContext()); 157 | } 158 | 159 | clink::OneHotEncoderModel *OneHotEncoderModel_load(const char *path) { 160 | auto model = clink::OneHotEncoderModel::load(path, getJnaHostContext()); 161 | CLINK_JNA_HANDLE_ERROR(model.takeError(), NULL); 162 | return model->release(); 163 | } 164 | 165 | void OneHotEncoderModel_delete(clink::OneHotEncoderModel *model) { 166 | model->DropRef(); 167 | } 168 | 169 | #ifdef __cplusplus 170 | } 171 | #endif 172 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 14 | load("@tf_runtime//:build_defs.bzl", "tfrt_cc_binary", "tfrt_cc_library") 15 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") 16 | load("@rules_cc//cc:defs.bzl", "cc_proto_library") 17 | load("@rules_java//java:defs.bzl", "java_proto_library") 18 | load("@rules_proto//proto:defs.bzl", "proto_library") 19 | load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier") 20 | 21 | package( 22 | default_visibility = ["//:__subpackages__"], 23 | ) 24 | 25 | licenses(["notice"]) 26 | 27 | package_group( 28 | name = "friends", 29 | packages = [ 30 | "//...", 31 | ], 32 | ) 33 | 34 | buildifier( 35 | name = "buildifier", 36 | exclude_patterns = ["./tfrt/*"], 37 | ) 38 | 39 | tfrt_cc_binary( 40 | name = "executor", 41 | srcs = [ 42 | "lib/executor/main.cc", 43 | ], 44 | deps = [ 45 | ":clink_kernels_alwayslink", 46 | ":clink_kernels_opdefs", 47 | ":clink_utils", 48 | "@tf_runtime//:basic_kernels_alwayslink", 49 | "@tf_runtime//:hostcontext_alwayslink", 50 | ], 51 | ) 52 | 53 | tfrt_cc_library( 54 | name = "clink_utils", 55 | srcs = [ 56 | "lib/utils/clink_runner.cc", 57 | "lib/utils/clink_utils.cc", 58 | ], 59 | hdrs = [ 60 | "include/clink/utils/clink_runner.h", 61 | "include/clink/utils/clink_utils.h", 62 | ], 63 | visibility = [":friends"], 64 | deps = [ 65 | "@llvm-project//llvm:Support", 66 | "@tf_runtime//:bef_executor_driver", 67 | "@tf_runtime//:mlirtobef", 68 | ], 69 | ) 70 | 71 | tfrt_cc_library( 72 | name = "clink_kernels", 73 | srcs = [ 74 | "lib/feature/one_hot_encoder.cc", 75 | "lib/kernels/clink_kernels.cc", 76 | "lib/linalg/sparse_vector.cc", 77 | ], 78 | hdrs = [ 79 | "include/clink/api/model.h", 80 | "include/clink/feature/one_hot_encoder.h", 81 | "include/clink/kernels/clink_kernels.h", 82 | "include/clink/linalg/sparse_vector.h", 83 | "include/clink/linalg/vector.h", 84 | ], 85 | alwayslink_static_registration_src = 
"lib/kernels/static_registration.cc", 86 | visibility = [":friends"], 87 | deps = [ 88 | ":clink_cc_proto", 89 | ":clink_utils", 90 | "@com_github_nlohmann_json//:json", 91 | "@tf_runtime//:hostcontext", 92 | ], 93 | ) 94 | 95 | gentbl_cc_library( 96 | name = "clink_kernels_opdefs_inc_gen", 97 | includes = ["include"], 98 | tbl_outs = [ 99 | ( 100 | ["-gen-op-decls"], 101 | "include/clink/kernels/opdefs/clink_kernels.h.inc", 102 | ), 103 | ( 104 | ["-gen-op-defs"], 105 | "include/clink/kernels/opdefs/clink_kernels.cpp.inc", 106 | ), 107 | ], 108 | tblgen = "@llvm-project//mlir:mlir-tblgen", 109 | td_file = "include/clink/kernels/opdefs/clink_kernels.td", 110 | deps = [ 111 | "@tf_runtime//:OpBaseTdFiles", 112 | ], 113 | ) 114 | 115 | tfrt_cc_library( 116 | name = "clink_kernels_opdefs", 117 | srcs = [ 118 | "lib/kernels/opdefs/clink_kernels.cc", 119 | ], 120 | hdrs = [ 121 | "include/clink/kernels/opdefs/clink_kernels.h", 122 | "include/clink/kernels/opdefs/types.h", 123 | ], 124 | visibility = [":friends"], 125 | deps = [ 126 | ":clink_kernels_opdefs_inc_gen", 127 | "@tf_runtime//:basic_kernels_opdefs", 128 | ], 129 | ) 130 | 131 | tfrt_cc_binary( 132 | name = "clink_jna", 133 | srcs = [ 134 | "lib/jna/clink_jna.cc", 135 | ], 136 | linkshared = True, 137 | visibility = [":friends"], 138 | deps = [ 139 | ":clink_kernels", 140 | "@tf_runtime//:hostcontext_alwayslink", 141 | ], 142 | ) 143 | 144 | FLINK_VERSION = "1_14_0" 145 | 146 | SCALA_VERSION = "2_12" 147 | 148 | java_library( 149 | name = "clink_kernels_java_deps", 150 | exports = [ 151 | "@maven//:commons_collections_commons_collections_3_2_2", 152 | "@maven//:net_java_dev_jna_jna", 153 | "@maven//:org_apache_commons_commons_compress", 154 | "@maven//:org_apache_commons_commons_lang3_3_3_2", 155 | "@maven//:org_apache_flink_flink_clients_%s" % SCALA_VERSION, 156 | "@maven//:org_apache_flink_flink_connector_files", 157 | "@maven//:org_apache_flink_flink_core", 158 | 
"@maven//:org_apache_flink_flink_streaming_java_%s" % SCALA_VERSION, 159 | "@maven//:org_apache_flink_flink_shaded_jackson", 160 | "@maven//:org_apache_flink_flink_table_api_java", 161 | "@maven//:org_apache_flink_flink_table_api_java_bridge_%s" % SCALA_VERSION, 162 | "@maven//:org_apache_flink_flink_table_planner_%s" % SCALA_VERSION, 163 | "@maven//:org_apache_flink_flink_table_runtime_%s" % SCALA_VERSION, 164 | "@maven//:org_apache_flink_flink_ml_core_%s" % SCALA_VERSION, 165 | "@maven//:org_apache_flink_flink_ml_iteration_%s" % SCALA_VERSION, 166 | "@maven//:org_apache_flink_flink_ml_lib_%s" % SCALA_VERSION, 167 | ], 168 | ) 169 | 170 | java_library( 171 | name = "clink_kernels_java_test_deps", 172 | exports = [ 173 | "@maven//:org_apache_flink_flink_test_utils_junit_%s" % FLINK_VERSION, 174 | "@maven//:junit_junit", 175 | ], 176 | ) 177 | 178 | java_library( 179 | name = "clink_kernels_java", 180 | srcs = glob(["java-lib/src/main/**/*.java"]), 181 | visibility = [":friends"], 182 | deps = [ 183 | ":clink_java_proto", 184 | ":clink_jna", 185 | ":clink_kernels_java_deps", 186 | ], 187 | ) 188 | 189 | java_test( 190 | name = "clink_kernels_java_test", 191 | srcs = glob(["java-lib/src/test/**/*.java"]), 192 | jvm_flags = [ 193 | "-Djna.library.path=.", 194 | ], 195 | test_class = "org.flinkextended.clink.util.AllTestsRunner", 196 | visibility = [":friends"], 197 | deps = [ 198 | ":clink_java_proto", 199 | ":clink_jna", 200 | ":clink_kernels_java", 201 | ":clink_kernels_java_deps", 202 | ":clink_kernels_java_test_deps", 203 | ], 204 | ) 205 | 206 | proto_library( 207 | name = "clink_proto", 208 | srcs = ["include/clink/feature/proto/one_hot_encoder.proto"], 209 | ) 210 | 211 | cc_proto_library( 212 | name = "clink_cc_proto", 213 | deps = [":clink_proto"], 214 | ) 215 | 216 | java_proto_library( 217 | name = "clink_java_proto", 218 | deps = [":clink_proto"], 219 | ) 220 | -------------------------------------------------------------------------------- 
/java-lib/src/test/java/org/flinkextended/clink/feature/ClinkOneHotEncoderTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.flinkextended.clink.feature; 18 | 19 | import org.apache.flink.api.common.restartstrategy.RestartStrategies; 20 | import org.apache.flink.api.java.tuple.Tuple2; 21 | import org.apache.flink.configuration.Configuration; 22 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; 23 | import org.apache.flink.ml.linalg.Vectors; 24 | import org.apache.flink.ml.util.ReadWriteUtils; 25 | import org.apache.flink.streaming.api.datastream.DataStream; 26 | import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; 27 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 28 | import org.apache.flink.table.api.Table; 29 | import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; 30 | import org.apache.flink.types.Row; 31 | 32 | import com.sun.jna.LastErrorException; 33 | import org.flinkextended.clink.feature.onehotencoder.ClinkOneHotEncoder; 34 | import org.flinkextended.clink.feature.onehotencoder.ClinkOneHotEncoderModel; 35 | import org.junit.Assert; 36 | import org.junit.Before; 37 | import org.junit.Rule; 38 | import org.junit.Test; 39 | import 
org.junit.rules.TemporaryFolder; 40 | 41 | import static org.junit.Assert.assertEquals; 42 | 43 | /** Tests Java wrapped C++ OneHotEncoder Estimator and Model Operator. */ 44 | public class ClinkOneHotEncoderTest { 45 | @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); 46 | 47 | private StreamExecutionEnvironment env; 48 | private StreamTableEnvironment tEnv; 49 | private ClinkOneHotEncoder estimator; 50 | private String savePath; 51 | private static final Row[] trainInput = new Row[] {Row.of(0, 1), Row.of(2, 3)}; 52 | private static final Row predictInput = Row.of(0, 1); 53 | private static final Row expectedOutput = 54 | Row.of( 55 | 0, 56 | 1, 57 | Vectors.sparse(2, new int[] {0}, new double[] {1.0}), 58 | Vectors.sparse(3, new int[] {1}, new double[] {1.0})); 59 | 60 | @Before 61 | public void before() throws Exception { 62 | Configuration config = new Configuration(); 63 | config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); 64 | env = StreamExecutionEnvironment.getExecutionEnvironment(config); 65 | env.setParallelism(4); 66 | env.enableCheckpointing(100); 67 | env.setRestartStrategy(RestartStrategies.noRestart()); 68 | tEnv = StreamTableEnvironment.create(env); 69 | 70 | estimator = 71 | new ClinkOneHotEncoder() 72 | .setInputCols("a", "b") 73 | .setOutputCols("c", "d") 74 | .setDropLast(false); 75 | 76 | savePath = tempFolder.newFolder().getAbsolutePath(); 77 | } 78 | 79 | @Test 80 | public void testFitAndPredict() throws Exception { 81 | DataStream trainStream = env.fromElements(trainInput); 82 | Table trainTable = tEnv.fromDataStream(trainStream).as("a", "b"); 83 | DataStream predictStream = env.fromElements(predictInput); 84 | Table predictTable = tEnv.fromDataStream(predictStream).as("a", "b"); 85 | 86 | estimator.fit(trainTable).save(savePath); 87 | env.execute(); 88 | ClinkOneHotEncoderModel model = ClinkOneHotEncoderModel.load(env, savePath); 89 | 90 | Table outputTable = 
model.transform(predictTable)[0]; 91 | 92 | Row actual = outputTable.execute().collect().next(); 93 | assertEquals(expectedOutput, actual); 94 | } 95 | 96 | @Test 97 | public void testInvalidInput() throws Exception { 98 | DataStream trainStream = env.fromElements(trainInput); 99 | Table trainTable = tEnv.fromDataStream(trainStream).as("a", "b"); 100 | DataStream predictStream = env.fromElements(Row.of(5, 6)); 101 | Table predictTable = tEnv.fromDataStream(predictStream).as("a", "b"); 102 | 103 | estimator.fit(trainTable).save(savePath); 104 | env.execute(); 105 | ClinkOneHotEncoderModel model = ClinkOneHotEncoderModel.load(env, savePath); 106 | 107 | Table outputTable = model.transform(predictTable)[0]; 108 | 109 | try { 110 | outputTable.execute().collect().next(); 111 | Assert.fail("Expected LastErrorException"); 112 | } catch (Exception e) { 113 | Throwable exception = e; 114 | while (exception.getCause() != null) { 115 | exception = exception.getCause(); 116 | } 117 | assertEquals(LastErrorException.class, exception.getClass()); 118 | } 119 | } 120 | 121 | @Test 122 | public void testGetModelData() throws Exception { 123 | estimator.setInputCols("a").setOutputCols("c"); 124 | DataStream trainStream = env.fromElements(Row.of(0), Row.of(1), Row.of(2)); 125 | Table trainTable = tEnv.fromDataStream(trainStream).as("a"); 126 | 127 | ClinkOneHotEncoderModel model = estimator.fit(trainTable); 128 | Tuple2 expected = new Tuple2<>(0, 2); 129 | Tuple2 actual = 130 | OneHotEncoderModelData.getModelDataStream(model.getModelData()[0]) 131 | .executeAndCollect() 132 | .next(); 133 | assertEquals(expected, actual); 134 | } 135 | 136 | @Test 137 | public void testSetModelData() throws Exception { 138 | DataStream trainStream = env.fromElements(trainInput); 139 | Table trainTable = tEnv.fromDataStream(trainStream).as("a", "b"); 140 | DataStream predictStream = env.fromElements(predictInput); 141 | Table predictTable = tEnv.fromDataStream(predictStream).as("a", "b"); 142 | 143 
| ClinkOneHotEncoderModel modelA = estimator.fit(trainTable); 144 | 145 | Table modelData = modelA.getModelData()[0]; 146 | ClinkOneHotEncoderModel modelB = new ClinkOneHotEncoderModel().setModelData(modelData); 147 | ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); 148 | modelB.save(savePath); 149 | env.execute(); 150 | ClinkOneHotEncoderModel modelC = ClinkOneHotEncoderModel.load(env, savePath); 151 | 152 | Table outputTable = modelC.transform(predictTable)[0]; 153 | 154 | Row actual = outputTable.execute().collect().next(); 155 | Row expected = 156 | Row.of( 157 | 0, 158 | 1, 159 | Vectors.sparse(2, new int[] {0}, new double[] {1.0}), 160 | Vectors.sparse(3, new int[] {1}, new double[] {1.0})); 161 | assertEquals(expected, actual); 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /java-lib/src/main/java/org/flinkextended/clink/feature/onehotencoder/ClinkOneHotEncoderModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Clink Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package org.flinkextended.clink.feature.onehotencoder; 18 | 19 | import org.apache.flink.api.common.functions.FlatMapFunction; 20 | import org.apache.flink.api.common.functions.MapPartitionFunction; 21 | import org.apache.flink.api.common.functions.RichMapFunction; 22 | import org.apache.flink.api.common.typeinfo.TypeInformation; 23 | import org.apache.flink.api.java.tuple.Tuple2; 24 | import org.apache.flink.api.java.typeutils.RowTypeInfo; 25 | import org.apache.flink.configuration.Configuration; 26 | import org.apache.flink.ml.api.Model; 27 | import org.apache.flink.ml.common.datastream.DataStreamUtils; 28 | import org.apache.flink.ml.common.datastream.TableUtils; 29 | import org.apache.flink.ml.common.param.HasHandleInvalid; 30 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; 31 | import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderParams; 32 | import org.apache.flink.ml.linalg.SparseVector; 33 | import org.apache.flink.ml.param.Param; 34 | import org.apache.flink.ml.util.ParamUtils; 35 | import org.apache.flink.ml.util.ReadWriteUtils; 36 | import org.apache.flink.streaming.api.datastream.DataStream; 37 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 38 | import org.apache.flink.table.api.Table; 39 | import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; 40 | import org.apache.flink.table.api.internal.TableImpl; 41 | import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; 42 | import org.apache.flink.types.Row; 43 | import org.apache.flink.util.Collector; 44 | import org.apache.flink.util.Preconditions; 45 | 46 | import com.sun.jna.Pointer; 47 | import org.apache.commons.compress.utils.Lists; 48 | import org.apache.commons.lang3.ArrayUtils; 49 | import org.clink.feature.onehotencoder.OneHotEncoderModelDataProto; 50 | import org.flinkextended.clink.jna.ClinkJna; 51 | import org.flinkextended.clink.jna.SparseVectorJna; 52 | import 
org.flinkextended.clink.util.ByteArrayDecoder;
import org.flinkextended.clink.util.ByteArrayEncoder;
import org.flinkextended.clink.util.ClinkReadWriteUtils;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Wrapper class for Flink ML OneHotEncoderModel which calls equivalent C++ operator to transform.
 *
 * <p>The model data is kept as a Flink {@link Table}, but {@link #transform(Table...)} does not
 * read that table: each parallel task loads the native model through {@link ClinkJna} from the
 * path the model was loaded from. Consequently {@link #transform(Table...)} is only available on
 * instances created by {@link #load(StreamExecutionEnvironment, String)}.
 */
public class ClinkOneHotEncoderModel
        implements Model<ClinkOneHotEncoderModel>, OneHotEncoderParams<ClinkOneHotEncoderModel> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Path the model was loaded from; {@code null} unless created through {@code load()}. */
    private String modelDataPath;

    private Table modelDataTable;

    public ClinkOneHotEncoderModel() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Sets model data using the given list of tables. Each table could be an unbounded stream of
     * model data changes.
     *
     * <p>This method should not be invoked if the {@link ClinkOneHotEncoderModel} is instantiated
     * through {@link ClinkOneHotEncoderModel#load(StreamExecutionEnvironment, String)}, which has
     * already provided model data from persistent storage.
     *
     * @param inputs a list of tables; exactly one table is expected
     * @return this model
     * @throws IllegalArgumentException if the number of tables is not one
     * @throws IllegalStateException if users attempt to modify the model data of a Model loaded
     *     from filesystem.
     */
    @Override
    public ClinkOneHotEncoderModel setModelData(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkState(
                modelDataPath == null,
                "It is not allowed to modify the model data of a Model loaded from filesystem.");
        modelDataTable = inputs[0];
        return this;
    }

    /**
     * Returns the model data table wrapped in a single-element array. The element is {@code null}
     * if no model data has been set yet.
     */
    @Override
    public Table[] getModelData() {
        return new Table[] {modelDataTable};
    }

    /**
     * Applies the one-hot encoding process on the given input tables and returns the result tables.
     *
     * <p>The {@link ClinkOneHotEncoderModel} object from which this method is invoked must have had
     * its model data saved in persistent storage, i.e., must be instantiated through {@link
     * ClinkOneHotEncoderModel#load(StreamExecutionEnvironment, String)} method.
     *
     * @param inputs exactly one table containing the configured input columns
     * @return a single-element array holding the input table extended with one vector column per
     *     configured output column
     * @throws IllegalArgumentException if handleInvalid is not {@link
     *     HasHandleInvalid#ERROR_INVALID}, if the number of input tables is not one, or if the
     *     numbers of input and output columns differ
     * @throws IllegalStateException if this {@link ClinkOneHotEncoderModel} object is not created
     *     through {@link ClinkOneHotEncoderModel#load(StreamExecutionEnvironment, String)} method.
     */
    @Override
    public Table[] transform(Table... inputs) {
        final String[] inputCols = getInputCols();
        final String[] outputCols = getOutputCols();

        Preconditions.checkArgument(getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID));
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkArgument(inputCols.length == outputCols.length);
        Preconditions.checkState(
                modelDataPath != null,
                "transform() can only be invoked after the model is loaded from a path "
                        + "containing the saved model data. Please use ClinkOneHotEncoderModel::save() "
                        + "to save the model data to a certain path in file system and use "
                        + "ClinkOneHotEncoderModel::load() to instantiate the model from that path.");

        // The Flink ML vector type is spelled out in full: this file imports java.util.* and only
        // SparseVector from the linalg package, so a bare "Vector" would silently resolve to
        // java.util.Vector and declare the output columns with the wrong type.
        TypeInformation<?>[] outputColTypes =
                Collections.nCopies(
                                outputCols.length,
                                ExternalTypeInfo.of(org.apache.flink.ml.linalg.Vector.class))
                        .toArray(new TypeInformation[0]);

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputColTypes),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
        DataStream<Row> output =
                input.map(new GenerateOutputsFunction(inputCols, modelDataPath), outputTypeInfo);
        Table outputTable = tEnv.fromDataStream(output);
        return new Table[] {outputTable};
    }

    /**
     * Map function which appends, for every input column, the sparse vector produced by the native
     * one-hot encoder model. The native model is loaded in {@link #open(Configuration)} and
     * released in {@link #close()}.
     */
    private static class GenerateOutputsFunction extends RichMapFunction<Row, Row> {
        private final String[] inputCols;
        private final String modelDataPath;

        // Native handle created in open(); it only has meaning inside the task that created it,
        // so it must never travel with the serialized function.
        private transient Pointer modelPointer;

        private GenerateOutputsFunction(String[] inputCols, String modelDataPath) {
            this.inputCols = inputCols;
            this.modelDataPath = modelDataPath;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            modelPointer = ClinkJna.INSTANCE.OneHotEncoderModel_load(modelDataPath);
        }

        @Override
        public Row map(Row row) {
            Row resultRow = new Row(inputCols.length);
            for (int i = 0; i < inputCols.length; i++) {
                int inputValue = ((Number) row.getField(inputCols[i])).intValue();
                SparseVectorJna.ByReference jnaVector =
                        ClinkJna.INSTANCE.OneHotEncoderModel_transform(modelPointer, inputValue, i);
                try {
                    SparseVector vector = jnaVector.toSparseVector();
                    resultRow.setField(i, vector);
                } finally {
                    // Release the natively allocated vector even if the conversion fails.
                    if (jnaVector != null) {
                        ClinkJna.INSTANCE.SparseVector_delete(jnaVector);
                    }
                }
            }
            return Row.join(row, resultRow);
        }

        @Override
        public void close() throws Exception {
            try {
                super.close();
            } finally {
                // open() may have failed before the model was loaded, and close() may be called
                // more than once; only hand a live handle to the native delete.
                if (modelPointer != null) {
                    ClinkJna.INSTANCE.OneHotEncoderModel_delete(modelPointer);
                    modelPointer = null;
                }
            }
        }
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Saves the parameter metadata and the model data of this model under the given path. The
     * model data is written as a serialized {@link OneHotEncoderModelDataProto} message.
     *
     * @param path the directory to save to
     * @throws IOException if saving fails
     */
    @Override
    public void save(String path) throws IOException {
        ClinkReadWriteUtils.saveMetadata(getParamMap(), getClass(), path);

        DataStream<byte[]> modelDataProtoBuf =
                DataStreamUtils.mapPartition(
                        OneHotEncoderModelData.getModelDataStream(getModelData()[0]),
                        new ModelDataAggregateFunction());
        // Presumably forces a single partition so that all (column index, feature size) pairs
        // end up in one protobuf message.
        modelDataProtoBuf.getTransformation().setParallelism(1);

        ReadWriteUtils.saveModelData(modelDataProtoBuf, path, new ByteArrayEncoder());
    }

    /**
     * Collects the (column index, feature size) pairs of a partition, orders them by column index
     * and emits them as one serialized {@link OneHotEncoderModelDataProto} message.
     */
    private static class ModelDataAggregateFunction
            implements MapPartitionFunction<Tuple2<Integer, Integer>, byte[]> {
        @Override
        public void mapPartition(
                Iterable<Tuple2<Integer, Integer>> iterable, Collector<byte[]> collector) {
            List<Tuple2<Integer, Integer>> list = Lists.newArrayList(iterable.iterator());
            list.sort(Comparator.comparingInt(o -> o.f0));

            OneHotEncoderModelDataProto.Builder builder = OneHotEncoderModelDataProto.newBuilder();
            builder.addAllFeatureSizes(list.stream().map(x -> x.f1).collect(Collectors.toList()));

            collector.collect(builder.build().toByteArray());
        }
    }

    /**
     * Loads a model, including its parameters and model data, from the given path.
     *
     * @param env the execution environment used to read the model data
     * @param path the directory the model was saved to
     * @return the loaded model, on which {@link #transform(Table...)} may be invoked
     * @throws IOException if loading fails
     */
    public static ClinkOneHotEncoderModel load(StreamExecutionEnvironment env, String path)
            throws IOException {
        ClinkOneHotEncoderModel clinkModel = ReadWriteUtils.loadStageParam(path);

        DataStream<byte[]> modelDataProtobuf =
                ReadWriteUtils.loadModelData(env, path, new ByteArrayDecoder());
        DataStream<Tuple2<Integer, Integer>> modelData =
                modelDataProtobuf.flatMap(new ModelDataFlatMapFunction());

        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
        // Order matters: setModelData() rejects calls once modelDataPath is set.
        clinkModel.setModelData(tEnv.fromDataStream(modelData));
        clinkModel.modelDataPath = path;

        return clinkModel;
    }

    /**
     * Parses a serialized {@link OneHotEncoderModelDataProto} message and emits one (column index,
     * feature size) pair per entry.
     */
    private static class ModelDataFlatMapFunction
            implements FlatMapFunction<byte[], Tuple2<Integer, Integer>> {
        @Override
        public void flatMap(byte[] bytes, Collector<Tuple2<Integer, Integer>> collector)
                throws Exception {
            OneHotEncoderModelDataProto modelDataProto =
                    OneHotEncoderModelDataProto.parseFrom(bytes);
            for (int i = 0; i < modelDataProto.getFeatureSizesCount(); i++) {
                collector.collect(new Tuple2<>(i, modelDataProto.getFeatureSizes(i)));
            }
        }
    }
}

--------------------------------------------------------------------------------