├── .bazelrc
├── BUILD
├── LICENSE
├── README.md
├── WORKSPACE
├── common
│   ├── BUILD
│   ├── page.h
│   ├── page_producer.cc
│   └── page_producer.h
├── dedup
│   ├── BUILD
│   ├── dedup_main.cc
│   ├── hash_dumper.cc
│   ├── hash_dumper.h
│   ├── lsh_index.cc
│   ├── lsh_index.h
│   ├── lsh_main.cc
│   ├── page_featurizer.cc
│   └── page_featurizer.h
├── distill
│   ├── BUILD
│   ├── distill_page_main.cc
│   ├── gen_distilled_page.py
│   ├── gen_ex_main.cc
│   ├── marketing_detection.py
│   └── train_marketing_detection.cc
├── thirdparty
│   ├── BUILD
│   ├── concurrentqueue.BUILD
│   └── minhashcuda.BUILD
└── vocab
    ├── sp_tok.model
    └── tokme.model

--------------------------------------------------------------------------------
/.bazelrc:
--------------------------------------------------------------------------------
############################################################################
# All default build options below.

# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
build --macos_minimum_os=10.14

# Make Bazel print out all options from rc files.
build --announce_rc

build --define open_source_build=true

build --spawn_strategy=standalone

build --enable_platform_specific_config

build --experimental_cc_shared_library

# Disable enabled-by-default TensorFlow features that we don't care about.
build --define=no_aws_support=true
build --define=no_gcp_support=true
build --define=no_hdfs_support=true
build --define=no_kafka_support=true
build --define=no_ignite_support=true

build --define=grpc_no_ares=true

build --define=tsl_link_protobuf=true

build -c opt

build --config=short_logs

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
build --check_visibility=false

# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,
# these values are overridden.
build --@xla//xla/python:enable_gpu=false
# build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
# build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false

###########################################################################

build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
build:posix --host_cxxopt=-std=c++17

build:avx_posix --copt=-mavx
build:avx_posix --host_copt=-mavx

build:avx_windows --copt=/arch=AVX

build:avx_linux --copt=-mavx
build:avx_linux --host_copt=-mavx

build:native_arch_posix --copt=-march=native
build:native_arch_posix --host_copt=-march=native

build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

build:cuda --repo_env TF_NEED_CUDA=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_52,sm_60,sm_70,compute_80"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
build:cuda --define=xla_python_enable_gpu=true
build:cuda --jobs=8

# Disable TFRT integration for now unless --config=tfrt is specified.
build --deleted_packages=tensorflow/core/tfrt/stubs,tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"

build:nonccl --define=no_nccl_support=true

# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI
# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
# otherwise, there will be some compiling error due to preprocessing.
build:windows --copt=/Zc:preprocessor
build:windows --cxxopt=/std:c++17
build:windows --host_cxxopt=/std:c++17
# Generate PDB files, to generate useful PDBs, in opt compilation_mode
# --copt /Z7 is needed.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --incompatible_strict_action_env=true

build:linux --config=posix
build:linux --copt=-Wno-unknown-warning-option
# Workaround for gcc 10+ warnings related to upb.
# See https://github.com/tensorflow/tensorflow/issues/39467
build:linux --copt=-Wno-stringop-truncation
build:linux --copt=-Wno-array-parameter

build:macos --config=posix

# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING

build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
build:tpu --define=with_tpu_support=true

build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true

#########################################################################
# RBE config options below.
# Flag to enable remote config
common --experimental_repo_remote_exec

build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel

build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8

# Non-rbe settings we should include because we do not run configure
build:rbe_linux --config=avx_linux
build:rbe_linux --linkopt=-lrt
build:rbe_linux --host_linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_linux --host_linkopt=-lm

# Use the GPU toolchain until the CPU one is ready.
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"

build:rbe_cpu_linux_py37 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7"
build:rbe_cpu_linux_py37 --python_path="/usr/local/bin/python3.7"
build:rbe_cpu_linux_py38 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8"
build:rbe_cpu_linux_py38 --python_path="/usr/local/bin/python3.8"
build:rbe_cpu_linux_py39 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9"
build:rbe_cpu_linux_py39 --python_path="/usr/local/bin/python3.9"
build:rbe_cpu_linux_py310 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10"
build:rbe_cpu_linux_py310 --python_path="/usr/local/bin/python3.10"
build:rbe_cpu_linux_py311 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.11"
build:rbe_cpu_linux_py311 --python_path="/usr/local/bin/python3.11"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda11.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.1_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.1"
build:rbe_linux_cuda11.1_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.1_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
test:rbe_linux_cuda11.1_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64"
build:rbe_linux_cuda11.1_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.1_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.1_nvcc_py3.7 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.8 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.9 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.9"
build:rbe_linux_cuda11.1_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.1_nvcc_py3.10 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.10"
build:rbe_linux_cuda11.1_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda11.1_nvcc_py3.11 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.11"
build:rbe_linux_cuda11.1_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"

build:rbe_linux_cuda11.4_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.4_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.4"
build:rbe_linux_cuda11.4_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.4_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
build:rbe_linux_cuda11.4_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.4_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.4_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.4_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.4_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.4_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.4_nvcc_py3.7 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.4_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda11.4_nvcc_py3.8 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.4_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda11.4_nvcc_py3.9 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9"
build:rbe_linux_cuda11.4_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.4_nvcc_py3.10 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10"
build:rbe_linux_cuda11.4_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda11.4_nvcc_py3.11 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.11"
build:rbe_linux_cuda11.4_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe
#############################################################################

# Load `.jax_configure.bazelrc` file written by build.py
try-import %workspace%/.jax_configure.bazelrc

# Load rc file with user-specific options.
try-import %workspace%/.bazelrc.user

--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EastTower16/LLMDataDistill/3ff93f9cddba0d91099e87281125c29e592937f8/BUILD

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LLM Data Distill

## Preparing the data
- Using the WuDao dataset as an example: download the 200 GB open-source WuDao dataset from https://www.scidb.cn/en/detail?dataSetId=c6a3fe684227415a9db8e21bac4a15ab
- After downloading, extract it into a directory
- The per-category document counts of the dataset are roughly as follows:
{'经济': 1055142, '娱乐': 1538285, '文化': 609237, '军事': 411410, '游戏': 742239, '汽车': 1308636, '科技': 1219031, '农业': 1074709, '体育': 648958, '国际': 596095, '教育': 1051980, '社会': 433914, '旅行': 746855, '房产': 378339, '法律': 34910, '股票': 1134, '豆瓣话题': 169600, '博客': 11628495, '日报': 13571, '评论': 10757, '酒业': 230, '资讯': 1049332, '科普文章': 47066, '孕育常识': 39660, '百科': 9456851, '小红书攻略': 153601, '经验': 456112, '财经': 54040, '健康': 14992, '医学问答': 252159, '亲子': 35, '网页文本': 98745, '新闻': 828978, '生活': 22, '百家号文章': 333216, '黄金': 215, '时尚': 1113, '文旅': 2910, '观点': 1242, '党建': 1, '保险': 70, '期货': 328, '理论': 209, '快讯': 41, '国内': 14, '美容': 7, '国学': 603, '信托': 62, '公益': 14, '能源': 7, '创新': 6, '户外': 5, '海外': 4, '天气': 538, '水利资讯': 9}

## Deduplication
- The largest category in the dataset is blogs ('博客'), and it is also the lowest-quality one, so by default the program deduplicates the blog category
- Install Bazel 5.1 or later
- Install and configure a CUDA driver, version 11.0 or above
- Build the dedup program:
  bazel build --config=cuda dedup:dedup_main
- Run bazel-bin/dedup/dedup_main --wudao_dir --output_path
- Memory usage is around 16 GB
- The program writes the keys of duplicate documents to the output file

## Filtering low-quality content
- Low-quality detection currently only handles pages with heavy marketing content
- Build:
  bazel build --config=cuda distill:distill_page_main
- Download the trained low-quality model:
  Link: https://pan.baidu.com/s/1RvSl1mfUGXJ3Z2WoF8vzdw?pwd=qsbi  Extraction code: qsbi
- Run:
  bazel-bin/distill/distill_page_main --wudao_dir --model_path --distilled_output_path
- By default, the output is the ids of documents with a marketing tendency

## Limitations
- For speed, the programs currently need to run on a CUDA-capable GPU
- They target the open-source WuDao dataset; more flexible configuration is needed to adapt them to custom corpora
- The low-quality model should support more content types
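Note on the dedup parameters: the code below computes 128 weighted-MinHash values per page (kMinHashBins in dedup/hash_dumper.cc), splits them into 8 LSH bands (kBandNum), and treats pairs whose estimated Jaccard similarity reaches 0.85 as duplicates. A back-of-the-envelope sketch of the textbook banding estimate, P(candidate) = 1 - (1 - s^r)^b with b = 8 bands of r = 16 rows, is given here; it is only an upper bound for this repo, since LshIndex additionally requires a collision in at least two bands before a pair is verified against the 0.85 threshold.

```cpp
// Textbook banded-LSH candidate probability for the constants used by the
// dedup stage: 128 minhash slots, b = 8 bands, r = 16 rows per band.
#include <cmath>
#include <cstdio>

int main() {
  const int b = 8, r = 128 / b;
  for (int pct = 50; pct <= 95; pct += 5) {
    double s = pct / 100.0;  // pairwise Jaccard similarity
    double p = 1.0 - std::pow(1.0 - std::pow(s, r), b);
    std::printf("s=%.2f -> P(candidate)=%.4f\n", s, p);
  }
  return 0;
}
```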
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")


# Bazel Skylib.
http_archive(
    name = "bazel_skylib",  # 2022-11-16T18:29:32Z
    sha256 = "a22290c26d29d3ecca286466f7f295ac6cbe32c0a9da3a91176a90e0725e3649",
    strip_prefix = "bazel-skylib-5bfcb1a684550626ce138fe0fe8f5f702b3764c3",
    urls = ["https://github.com/bazelbuild/bazel-skylib/archive/5bfcb1a684550626ce138fe0fe8f5f702b3764c3.zip"],
)


http_archive(
    name = "minhashcuda",
    # sha256 = "7c0101c68422aa038314e07842b2f83aa4c3d2b5520ebba350319277c4ad9c03",
    strip_prefix = "minhashcuda-a0d014aa31b6cdb26bb5cc2b11ccefe137b0193c",
    urls = ["https://github.com/koth/minhashcuda/archive/a0d014aa31b6cdb26bb5cc2b11ccefe137b0193c.tar.gz"],
    build_file = "@//thirdparty:minhashcuda.BUILD",
)

http_archive(
    name = "nlohmann_json",
    # sha256 = "",
    strip_prefix = "json-6af826d0bdb55e4b69e3ad817576745335f243ca",
    urls = ["https://github.com/nlohmann/json/archive/6af826d0bdb55e4b69e3ad817576745335f243ca.zip"],
)


http_archive(
    name = "tokme",
    sha256 = "b055f3a5b3db636277ea56ae2ed8c8313e95bd75897e95e7f0705588460543e1",
    strip_prefix = "tokme-41d6f49e2ac3bddd8116a0ff0c56b35f370a9fae",
    urls = ["https://github.com/koth/tokme/archive/41d6f49e2ac3bddd8116a0ff0c56b35f370a9fae.zip"],
)

http_archive(
    name = "concurrentqueue",
    # sha256 = "",
    strip_prefix = "concurrentqueue-810f6213a2ee3bbd0c2ff647c28996cfff84df06",
    urls = ["https://github.com/koth/concurrentqueue/archive/810f6213a2ee3bbd0c2ff647c28996cfff84df06.zip"],
    build_file = "@//thirdparty:concurrentqueue.BUILD",
)


# http_archive(
#     name = "abseil-cpp",
#     sha256 = "ea1d31db00eb37e607bfda17ffac09064670ddf05da067944c4766f517876390",
#     strip_prefix = "abseil-cpp-c2435f8342c2d0ed8101cb43adfd605fdc52dca2",
#     urls = ["https://github.com/abseil/abseil-cpp/archive/c2435f8342c2d0ed8101cb43adfd605fdc52dca2.zip"],
# )

http_archive(
    name = "australis",
    sha256 = "4627ebc92d8135de11172cd91cb45e408fccb8b27e78d4debd89d75921bf845f",
    strip_prefix = "australis-bea893c8cf08fde4e09b7b2dd893f86b935bfd9d",
    urls = ["https://github.com/EastTower16/australis/archive/bea893c8cf08fde4e09b7b2dd893f86b935bfd9d.zip"],
)


git_repository(
    name = "jax",
    commit = "c3e242700872c2f7e098a07f3911ee6d2de8132c",
    remote = "https://github.com/google/jax.git",
)

http_archive(
    name = "sentencepiece",
    sha256 = "0c28dfd2fad9f215ea276f60e62c41aca7f0ad48fd4bc072dd79180f59b44ec2",
    strip_prefix = "sentencepiece-cf093775361a08dbe8d2a5ec98f548b25d7d6e37",
    urls = [
        "https://github.com/EastTower16/sentencepiece/archive/cf093775361a08dbe8d2a5ec98f548b25d7d6e37.tar.gz",
    ],
)
# local_repository(
#     name = "sentencepiece",
#     path = "/f/workspace/sentencepiece",
# )

http_archive(
    name = "xla",
    sha256 = "4ec16aff3862c5a243db956ce558d7a62eb79f5e20747b0e80802a3b0d12e419",
    strip_prefix = "xla-12de6ec958419b57be248d0acd2d9f757e71748c",
    urls = [
        "https://github.com/openxla/xla/archive/12de6ec958419b57be248d0acd2d9f757e71748c.tar.gz",
    ],
)
load("@xla//third_party/gpus:cuda_configure.bzl", "cuda_configure")

cuda_configure(name = "local_config_cuda")
# For development, one can use a local TF repository instead.
# local_repository(
#     name = "org_tensorflow",
#     path = "tensorflow",
# )


load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()

load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()

load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()

load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()

load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

--------------------------------------------------------------------------------
/common/BUILD:
--------------------------------------------------------------------------------
package(default_visibility = ["//visibility:public"])

cc_library(
    name = "page",
    hdrs = [
        "page.h",
    ],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "page_producer",
    hdrs = [
        "page_producer.h",
    ],
    srcs = [
        "page_producer.cc",
    ],
    deps = [
        ":page",
        "@concurrentqueue//:concurrentqueue",
        "@nlohmann_json//:json",
        "@com_google_absl//absl/log",
    ],
)

--------------------------------------------------------------------------------
/common/page.h:
--------------------------------------------------------------------------------
#ifndef PD_COMMON_PAGE_H_
#define PD_COMMON_PAGE_H_
#include <string>
#include <unordered_map>
#include <vector>

namespace pd {

struct Page {
  std::string title;
  std::string url;
  std::string content;
  std::string idkey;
  std::string category;

  // (k, t) pairs filled by minhashcuda
  std::vector<std::pair<uint32_t, uint32_t>> weighted_hash_values;
  // token id -> weight, filled by the featurizer
  std::unordered_map<int, float> features;
};

}  // namespace pd

#endif  // PD_COMMON_PAGE_H_

--------------------------------------------------------------------------------
/common/page_producer.cc:
--------------------------------------------------------------------------------
#include "common/page_producer.h"

#include <chrono>
#include <fstream>
#include <thread>

#include "absl/log/log.h"
#include "nlohmann/json.hpp"

namespace pd {
std::string PageProducer::EOFPageHashKey = "____";
PageProducer::~PageProducer() {}

bool PageProducer::initByFileList(const std::vector<std::string>& fileList) {
  producer_ = std::thread([this, fileList]() {
    LOG(INFO) << "Start process thread, num file:" << fileList.size();
    for (auto path : fileList) {
      if (this->stop_.load()) {
        break;
      }
      std::ifstream file(path);
      nlohmann::json objList = nlohmann::json::parse(file);
      LOG(INFO) << "got :" << objList.size() << " from path: " << path;
      for (nlohmann::json::iterator it = objList.begin(); it != objList.end(); ++it) {
        if (this->stop_.load()) {
          break;
        }
        nlohmann::json& obj = *it;
        struct Page p;
        p.title = obj["title"].get<std::string>();
        p.content = obj["content"].get<std::string>();
        p.idkey = obj["uniqueKey"].get<std::string>();
        p.category = obj["dataType"].get<std::string>();
        // Only the blog ("博客") category is processed for now.
        if (p.category != "博客") {
          continue;
        }
        // Back off briefly while the bounded queue is full.
        while (!this->page_queue_.try_enqueue(p)) {
          std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
        // LOG(INFO)<<"processed :"<<p.title;
      }
      if (this->stop_.load()) {
        break;
      }
    }
    // A sentinel page whose idkey is EOFPageHashKey marks the end of input.
    struct Page eof;
    eof.idkey = EOFPageHashKey;
    this->page_queue_.enqueue(eof);
  });
  return true;
}
bool PageProducer::takeOnePage(struct Page& page, bool waitForValidPage) {
  if (!waitForValidPage) {
    return this->page_queue_.try_dequeue(page);
  }
  page_queue_.wait_dequeue(page);
  return true;
}
void PageProducer::shutdown() {
  stop_.store(true);
  producer_.join();
}

}  // namespace pd
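Note on the input format: PageProducer above expects each .json file to be a JSON array of records carrying at least the title/content/uniqueKey/dataType fields it reads, and it skips everything outside the '博客' (blog) category. A minimal sketch of a record parsed the same way; the sample values are invented.

```cpp
// Shape of the WuDao records consumed by PageProducer (field names taken from
// page_producer.cc; values here are made up for illustration).
#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  const char* sample = R"([
    {"title": "t1", "content": "some body text", "uniqueKey": "doc-0001", "dataType": "博客"}
  ])";
  nlohmann::json objList = nlohmann::json::parse(sample);
  for (auto& obj : objList) {
    std::cout << obj["uniqueKey"].get<std::string>() << " -> "
              << obj["dataType"].get<std::string>() << "\n";
  }
  return 0;
}
```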
--------------------------------------------------------------------------------
/common/page_producer.h:
--------------------------------------------------------------------------------
#ifndef PD_COMMON_PAGE_PRODUCER_H_
#define PD_COMMON_PAGE_PRODUCER_H_
#include <atomic>
#include <string>
#include <thread>
#include <vector>

#include "blockingconcurrentqueue.h"
#include "common/page.h"
namespace pd {

class PageProducer {
 public:
  virtual ~PageProducer();
  virtual bool initByFileList(const std::vector<std::string>& fileList);
  virtual bool takeOnePage(struct Page& page, bool waitForValidPage = false);
  virtual void shutdown();

 private:
  moodycamel::BlockingConcurrentQueue<Page> page_queue_{1024};
  std::thread producer_;
  std::atomic_bool stop_ = false;

 public:
  static std::string EOFPageHashKey;
};

}  // namespace pd

#endif  // PD_COMMON_PAGE_PRODUCER_H_

--------------------------------------------------------------------------------
/dedup/BUILD:
--------------------------------------------------------------------------------
cc_library(
    name = "page_featurizer",
    srcs = [
        "page_featurizer.cc",
    ],
    hdrs = [
        "page_featurizer.h",
    ],
    deps = [
        "//common:page",
        "@tokme//:tokme",
    ],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "hash_dumper",
    srcs = [
        "hash_dumper.cc",
    ],
    hdrs = [
        "hash_dumper.h",
    ],
    deps = [
        ":page_featurizer",
        ":lsh_index",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@minhashcuda//:minhashcuda",
    ],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "lsh_index",
    srcs = [
        "lsh_index.cc",
    ],
    hdrs = [
        "lsh_index.h",
    ],
    deps = [
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
    visibility = ["//visibility:public"],
)


cc_binary(
    name = "dedup_main",
    srcs = [
        "dedup_main.cc",
    ],
    deps = [
        "//common:page_producer",
        ":page_featurizer",
        ":hash_dumper",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log:initialize",
        "@com_google_absl//absl/flags:parse",
        "@minhashcuda//:minhashcuda",
    ],
    linkopts = [
        '-lm',
    ],
)


cc_binary(
    name = "lsh_main",
    srcs = [
        "lsh_main.cc",
    ],
    deps = [
        ":lsh_index",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log:initialize",
        "@com_google_absl//absl/flags:parse",
        "@com_google_absl//absl/strings",
    ],
    linkopts = [
        '-lm',
    ],
)
#include "absl/log/initialize.h" 8 | #include "absl/log/globals.h" 9 | #include "absl/flags/flag.h" 10 | #include "absl/flags/parse.h" 11 | #include "absl/flags/usage.h" 12 | #include "common/page_producer.h" 13 | #include "dedup/hash_dumper.h" 14 | 15 | ABSL_FLAG(std::string, wudao_dir, "/g/wudao", "wudao dataset dir"); 16 | ABSL_FLAG(std::string, output_path, "/g/duplicate_keys.txt", 17 | "debup output path"); 18 | ABSL_FLAG(std::string, tokenizer_path, "vocab/tokme.model", "the tokenizer model path"); 19 | 20 | namespace fs = std::filesystem; 21 | 22 | void traverseDirectory(const std::string& path, 23 | std::vector& pathList) { 24 | for (const auto& entry : fs::directory_iterator(path)) { 25 | if (entry.is_directory()) { 26 | fs::path newdir = fs::path(path) / entry.path(); 27 | // 如果是子目录,则递归遍历 28 | traverseDirectory(newdir.string(), pathList); 29 | } else if (entry.is_regular_file() && entry.path().string().size() > 5 && 30 | entry.path().string().substr(entry.path().string().size() - 5) == 31 | std::string(".json")) { 32 | fs::path newpath = fs::path(path) / entry.path(); 33 | pathList.emplace_back(newpath.string()); 34 | } 35 | } 36 | } 37 | 38 | void doWork(pd::PageProducer* producer, pd::HashDumper* hasher){ 39 | pd::Page p; 40 | uint64_t count = 0; 41 | producer->takeOnePage(p, true); 42 | while(p.idkey != pd::PageProducer::EOFPageHashKey){ 43 | hasher->process(p); 44 | count +=1; 45 | if(count % 20000 == 0){ 46 | LOG(INFO) << count << " pages processed!!"; 47 | } 48 | producer->takeOnePage(p, true); 49 | } 50 | hasher->doBatch(); 51 | LOG(INFO) << count << " total pages processed!!"; 52 | } 53 | 54 | int main(int argc, char* argv[]) { 55 | absl::SetProgramUsageMessage("Dedup Main"); 56 | absl::ParseCommandLine(argc, argv); 57 | absl::InitializeLog(); 58 | absl::SetStderrThreshold(absl::LogSeverity::kInfo); 59 | pd::PageProducer pageProducer; 60 | pd::HashDumper hashDumper(absl::GetFlag(FLAGS_tokenizer_path),absl::GetFlag(FLAGS_output_path)); 61 | std::vector pathList; 62 | traverseDirectory(absl::GetFlag(FLAGS_wudao_dir), pathList); 63 | LOG(INFO) << "got " << pathList.size() << " paths!"; 64 | pageProducer.initByFileList(pathList); 65 | doWork(&pageProducer, &hashDumper); 66 | pageProducer.shutdown(); 67 | return 0; 68 | } 69 | 70 | -------------------------------------------------------------------------------- /dedup/hash_dumper.cc: -------------------------------------------------------------------------------- 1 | #include "dedup/hash_dumper.h" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "absl/log/check.h" 9 | #include "absl/log/log.h" 10 | #include "dedup/lsh_index.h" 11 | #include "dedup/page_featurizer.h" 12 | #include "minhashcuda.h" 13 | 14 | namespace pd { 15 | 16 | const static uint32_t kVocabDim = 51200; 17 | const static uint16_t kMinHashBins = 128; 18 | const static uint32_t kDevice = 0; 19 | const static int kBatchSize = 10240; 20 | const static int kBandNum = 8; 21 | const static int kNumHashSlot = 99829; 22 | const static float kDedupThreshold = 0.85; 23 | HashDumper::HashDumper(const std::string& tokenizerPath, 24 | const std::string& outFile) { 25 | out_ = fopen(outFile.c_str(), "w"); 26 | featurizer_.reset(new PageFeaturizer()); 27 | CHECK(featurizer_->Init(tokenizerPath)); 28 | indptr_.push_back(0); 29 | uint32_t seed = static_cast(time(NULL)); 30 | MHCUDAResult result = mhcudaSuccess; 31 | minhashptr_ = 32 | mhcuda_init(kVocabDim, kMinHashBins, seed, 0, kDevice, 2, &result); 33 | if (result != mhcudaSuccess) { 34 | LOG(FATAL) << "error 
mhcuda init :" << result; 35 | minhashptr_ = nullptr; 36 | } 37 | result_buffer_ = new uint32_t[kBatchSize * kMinHashBins * 2]; 38 | indexer_.reset(new pd::LshIndex(kBandNum, kNumHashSlot)); 39 | } 40 | HashDumper::~HashDumper() { 41 | if (result_buffer_ != nullptr) { 42 | delete[] result_buffer_; 43 | result_buffer_ = nullptr; 44 | } 45 | if (out_ != nullptr) { 46 | fclose(out_); 47 | out_ = nullptr; 48 | } 49 | if (minhashptr_ != nullptr) { 50 | MinhashCudaGenerator* gen = 51 | reinterpret_cast(minhashptr_); 52 | mhcuda_fini(gen); 53 | minhashptr_ = nullptr; 54 | } 55 | } 56 | 57 | bool HashDumper::doBatch() { 58 | if (idkeys_.empty()) { 59 | LOG(INFO) << "idkey is empty,..."; 60 | return true; 61 | } 62 | MinhashCudaGenerator* gen = 63 | reinterpret_cast(minhashptr_); 64 | MHCUDAResult result = mhcudaSuccess; 65 | auto startTime = std::chrono::high_resolution_clock::now(); 66 | result = mhcuda_calc(gen, data_.data(), indices_.data(), indptr_.data(), 67 | indptr_.size() - 1, result_buffer_); 68 | auto endTime = std::chrono::high_resolution_clock::now(); 69 | if (result != mhcudaSuccess) { 70 | LOG(FATAL) << "error mhcuda calc :" << result; 71 | return false; 72 | } 73 | auto duration = 74 | std::chrono::duration_cast(endTime - startTime) 75 | .count(); 76 | LOG(INFO) << "minhash calc time: " << duration << " ms"; 77 | for (int i = 0; i < static_cast(indptr_.size() - 1); i++) { 78 | size_t offset = 2 * kMinHashBins * i; 79 | std::string key = idkeys_[i]; 80 | pd::WeightedMinHash wmh; 81 | for (int k = 0; k < kMinHashBins; k++) { 82 | wmh.ks.push_back(result_buffer_[2 * k + offset]); 83 | wmh.ts.push_back(result_buffer_[2 * k + offset + 1]); 84 | } 85 | std::vector hashvals; 86 | std::vector cands; 87 | std::vector sims; 88 | bool gotDup = false; 89 | if (indexer_->query(wmh, hashvals, cands, sims)) { 90 | for (size_t i = 0; i < cands.size(); i++) { 91 | if (sims[i] >= kDedupThreshold) { 92 | gotDup = true; 93 | // LOG(INFO)<<"key:["<addWeightedMinHash(key, wmh, hashvals); 101 | } 102 | } 103 | } 104 | buffered_ = 0; 105 | indptr_.clear(); 106 | indptr_.push_back(0); 107 | indices_.clear(); 108 | data_.clear(); 109 | idkeys_.clear(); 110 | return true; 111 | } 112 | bool HashDumper::process(Page& page) { 113 | if (!featurizer_->Featurize(page)) { 114 | LOG(INFO) << "feature error....:" << page.title; 115 | return false; 116 | } 117 | 118 | int preCount = indptr_.back(); 119 | 120 | std::vector> features(page.features.begin(), 121 | page.features.end()); 122 | std::sort(features.begin(), features.end(), 123 | [](const std::pair& a, const std::pair& b) { 124 | return a.first < b.first; 125 | }); 126 | int num_feature = features.size(); 127 | indptr_.push_back(preCount + num_feature); 128 | idkeys_.push_back(page.idkey); 129 | for (int i = 0; i < num_feature; i++) { 130 | int col = features[i].first; 131 | indices_.emplace_back(col); 132 | data_.emplace_back(features[i].second); 133 | } 134 | // LOG(INFO)<<"process:"<= 10240) { 137 | LOG(INFO) << "will do one batch!"; 138 | if (!doBatch()) { 139 | return false; 140 | } 141 | } 142 | return true; 143 | } 144 | } // namespace pd 145 | -------------------------------------------------------------------------------- /dedup/hash_dumper.h: -------------------------------------------------------------------------------- 1 | #ifndef PD_DEDUP_FEATURE_DUMPER_H_ 2 | #define PD_DEDUP_FEATURE_DUMPER_H_ 3 | #include 4 | #include 5 | #include 6 | 7 | #include "common/page.h" 8 | 9 | namespace pd { 10 | class PageFeaturizer; 11 | class LshIndex; 12 | class 
--------------------------------------------------------------------------------
/dedup/hash_dumper.h:
--------------------------------------------------------------------------------
#ifndef PD_DEDUP_FEATURE_DUMPER_H_
#define PD_DEDUP_FEATURE_DUMPER_H_
#include <stdio.h>

#include <memory>
#include <vector>

#include "common/page.h"

namespace pd {
class PageFeaturizer;
class LshIndex;
class HashDumper {
 public:
  HashDumper(const std::string& tokenizerPath, const std::string& outFile);
  virtual ~HashDumper();
  virtual bool process(Page& page);
  virtual bool doBatch();

 private:
  std::unique_ptr<PageFeaturizer> featurizer_;
  FILE* out_ = nullptr;
  std::vector<uint32_t> indptr_;
  std::vector<uint32_t> indices_;
  std::vector<float> data_;
  std::vector<std::string> idkeys_;
  int buffered_ = 0;
  void* minhashptr_ = nullptr;
  uint32_t* result_buffer_ = nullptr;
  std::unique_ptr<LshIndex> indexer_;
};
}  // namespace pd

#endif  // PD_DEDUP_FEATURE_DUMPER_H_

--------------------------------------------------------------------------------
/dedup/lsh_index.cc:
--------------------------------------------------------------------------------
#include "dedup/lsh_index.h"

namespace pd {

static int hash_func(const WeightedMinHash& wmh, int nband, int bidx, size_t mod) {
  size_t hash = 0;
  size_t numPerBand = wmh.ts.size() / nband;
  for (size_t i = 0; i < numPerBand; ++i) {
    // Index with numPerBand*bidx so the bands are disjoint slices of the
    // signature; indexing with nband*bidx would make neighbouring bands
    // overlap and leave the tail of the signature unused.
    hash = (hash * 60607) % mod + (wmh.ks[numPerBand * bidx + i] * 30637) % mod + wmh.ts[numPerBand * bidx + i];
    hash = hash % mod;
  }
  return hash;
}

static void dohash(const WeightedMinHash& wmh, int nband, size_t mod, std::vector<size_t>& hashvals) {
  for (int i = 0; i < nband; i++) {
    hashvals.push_back(hash_func(wmh, nband, i, mod));
  }
}

LshIndex::LshIndex(int nband, int slotsPerBand) {
  band_indexes_.resize(nband, std::vector<absl::flat_hash_set<int>>(slotsPerBand));
}

LshIndex::~LshIndex() {}

bool LshIndex::addWeightedMinHash(const std::string& key, const WeightedMinHash& wmh, const std::vector<size_t>& hashvls) {
  size_t nband = band_indexes_.size();
  if (wmh.ks.size() % nband != 0 || nband == 0) {
    return false;
  }
  if (hashkeys_.count(key) > 0) {
    return false;
  }
  int newIdx = all_keys_.size();
  hashkeys_.insert(key);
  all_keys_.emplace_back(key);
  hash_values_.emplace_back(wmh);
  for (size_t i = 0; i < nband; i++) {
    band_indexes_[i][hashvls[i]].insert(newIdx);
  }
  return true;
}

bool LshIndex::query(const WeightedMinHash& wmh, std::vector<size_t>& hashvals, std::vector<std::string>& cands, std::vector<float>& sims) {
  size_t nband = band_indexes_.size();
  if (wmh.ks.size() % nband != 0 || nband == 0) {
    return false;
  }
  dohash(wmh, nband, band_indexes_[0].size(), hashvals);
  // A key becomes a candidate only if it collides in at least two bands.
  absl::flat_hash_set<int> candidates;
  for (size_t i = 0; i < nband; i++) {
    const absl::flat_hash_set<int>& iset = band_indexes_[i][hashvals[i]];
    for (size_t j = i + 1; j < nband; j++) {
      const absl::flat_hash_set<int>& jset = band_indexes_[j][hashvals[j]];
      for (auto key : jset) {
        if (iset.count(key) > 0) {
          candidates.insert(key);
        }
      }
    }
  }
  for (auto key : candidates) {
    const WeightedMinHash& other = hash_values_[key];
    cands.push_back(all_keys_[key]);
    sims.push_back(wmh.jaccard(other));
  }
  return true;
}
}  // namespace pd

--------------------------------------------------------------------------------
/dedup/lsh_index.h:
--------------------------------------------------------------------------------
#ifndef PD_DEDUP_LSH_INDEX_H_
#define PD_DEDUP_LSH_INDEX_H_
#include <stddef.h>
#include <stdint.h>

#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/flat_hash_map.h"

namespace pd {
struct WeightedMinHash {
  std::vector<uint32_t> ks;
  std::vector<uint32_t> ts;
  float jaccard(const WeightedMinHash& other) const {
    int nn = ks.size();
    int same = 0;
    for (int i = 0; i < nn; i++) {
      if (ks[i] == other.ks[i] && ts[i] == other.ts[i]) {
        same += 1;
      }
    }
    return static_cast<float>(same) / static_cast<float>(nn);
  }
};
class LshIndex {
 public:
  LshIndex(int nband, int slotsPerBand);
  virtual ~LshIndex();
  virtual bool addWeightedMinHash(const std::string& key, const WeightedMinHash& wmh, const std::vector<size_t>& hashvls);
  virtual bool query(const WeightedMinHash& wmh, std::vector<size_t>& hashvals, std::vector<std::string>& cands, std::vector<float>& sims);

 private:
  absl::flat_hash_set<std::string> hashkeys_;
  std::vector<std::string> all_keys_;
  std::vector<WeightedMinHash> hash_values_;
  std::vector<std::vector<absl::flat_hash_set<int>>> band_indexes_;
};
}  // namespace pd

#endif  // PD_DEDUP_LSH_INDEX_H_
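Note on the similarity estimate: WeightedMinHash::jaccard above estimates weighted Jaccard similarity as the fraction of signature slots where both the k and t components agree. A tiny illustration with invented four-slot signatures:

```cpp
// Jaccard estimation from weighted-minhash signatures: the similarity is the
// fraction of slots where both (k, t) components match. Signatures invented.
#include <cassert>
#include "dedup/lsh_index.h"

int main() {
  pd::WeightedMinHash a, b;
  a.ks = {1, 2, 3, 4};  a.ts = {9, 9, 9, 9};
  b.ks = {1, 2, 3, 7};  b.ts = {9, 9, 0, 9};
  // Slots 0 and 1 match fully; slot 2 differs in t, slot 3 differs in k.
  assert(a.jaccard(b) == 0.5f);
  return 0;
}
```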
--------------------------------------------------------------------------------
/dedup/lsh_main.cc:
--------------------------------------------------------------------------------
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
#include "absl/strings/str_split.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/log/log.h"
#include "absl/log/initialize.h"
#include "absl/log/globals.h"
#include "absl/log/check.h"
#include "dedup/lsh_index.h"

ABSL_FLAG(std::string, minhash_path, "/g/wudao_features.txt",
          "calced weighted min hash output file path");
ABSL_FLAG(int, nband, 8, "the band number of lsh");
ABSL_FLAG(int, band_slots, 99829, "hash number of same band");
ABSL_FLAG(float, dup_threshold, 0.85, "the sim hash threshold to judge as duplicate");
ABSL_FLAG(std::string, dup_key_path, "/g/dup_keys.txt",
          "the output duplicated keys file path");
int main(int argc, char* argv[]) {
  absl::SetProgramUsageMessage("Lsh Main");
  absl::ParseCommandLine(argc, argv);
  absl::InitializeLog();
  absl::SetStderrThreshold(absl::LogSeverity::kInfo);
  pd::LshIndex index(absl::GetFlag(FLAGS_nband),
                     absl::GetFlag(FLAGS_band_slots));
  std::ifstream inputFile(absl::GetFlag(FLAGS_minhash_path));
  if (!inputFile) {
    LOG(WARNING) << "Failed to open the file:"
                 << absl::GetFlag(FLAGS_minhash_path);
    return 1;
  }
  float dupThreshold = absl::GetFlag(FLAGS_dup_threshold);
  std::string line;
  uint32_t dupCount = 0;
  uint32_t totalCount = 0;
  std::ofstream outputFile(absl::GetFlag(FLAGS_dup_key_path));
  while (std::getline(inputFile, line)) {
    absl::StripTrailingAsciiWhitespace(&line);
    std::vector<std::string> ss = absl::StrSplit(line, ' ');
    const std::string& key = ss[0];
    // Each line holds a key followed by 128 (k, t) pairs: 257 fields in all.
    CHECK_EQ(static_cast<int>(ss.size()), 257) << "bad line:" << line;
    pd::WeightedMinHash wmh;
    for (int i = 1; i < 257; i += 2) {
      uint32_t k = 0, t = 0;
      CHECK(absl::SimpleAtoi(ss[i], &k));
      CHECK(absl::SimpleAtoi(ss[i + 1], &t));
      wmh.ks.push_back(k);
      wmh.ts.push_back(t);
    }
    std::vector<size_t> hashvals;
    std::vector<std::string> cands;
    std::vector<float> sims;
    bool gotDup = false;
    totalCount += 1;
    if (index.query(wmh, hashvals, cands, sims)) {
      for (size_t i = 0; i < cands.size(); i++) {
        if (sims[i] >= dupThreshold) {
          gotDup = true;
          // LOG(INFO)<<"key:["<<key<<"] dup of:["<<cands[i]<<"] sim:"<<sims[i];
          break;
        }
      }
      if (gotDup) {
        dupCount += 1;
        outputFile << key << std::endl;
      } else {
        index.addWeightedMinHash(key, wmh, hashvals);
      }
    }
  }
  LOG(INFO) << "total:" << totalCount << ", dup:" << dupCount;
  return 0;
}

--------------------------------------------------------------------------------
/dedup/page_featurizer.cc:
--------------------------------------------------------------------------------
#include "dedup/page_featurizer.h"

#include <math.h>

#include "tokme/encoder.h"

namespace pd {

PageFeaturizer::PageFeaturizer() {}

PageFeaturizer::~PageFeaturizer() {}

bool PageFeaturizer::Init(const std::string& vocabPath) {
  vkcom::Status status;
  encoder_.reset(new vkcom::BaseEncoder(vocabPath, 1, &status));
  return status.ok();
}

bool PageFeaturizer::Featurize(Page& page) {
  std::vector<std::vector<int32_t>> ids;
  vkcom::Status status = encoder_->encode_as_ids({page.content}, &ids);
  if (!status.ok()) {
    return false;
  }
  auto pageIds = ids[0];
  for (auto& id : pageIds) {
    float w = log(2.0 + id / 100.0);
    auto it = page.features.find(id);
    if (it == page.features.end()) {
      page.features.insert({id, w});
    } else {
      it->second += w;
    }
  }
  return true;
}

}  // namespace pd

--------------------------------------------------------------------------------
/dedup/page_featurizer.h:
--------------------------------------------------------------------------------
#ifndef PD_DEDUP_PAGE_FEATURIZER_H_
#define PD_DEDUP_PAGE_FEATURIZER_H_
#include <memory>
#include <string>

#include "common/page.h"
namespace vkcom {
class BaseEncoder;
}  // namespace vkcom
namespace pd {
class PageFeaturizer {
 public:
  PageFeaturizer();
  virtual ~PageFeaturizer();
  virtual bool Init(const std::string& vocabPath);
  virtual bool Featurize(Page& page);

 private:
  std::unique_ptr<vkcom::BaseEncoder> encoder_;
};
}  // namespace pd

#endif  // PD_DEDUP_PAGE_FEATURIZER_H_
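Note on the feature weights: Featurize above gives every token occurrence the weight w = log(2.0 + id/100.0) and accumulates repeats into page.features. Under the assumption that larger ids correspond to rarer merges in the BPE vocabulary, this tilts the weighted MinHash toward rarer tokens. A tiny sketch of the weight curve:

```cpp
// Per-occurrence token weight used by PageFeaturizer::Featurize. The mapping
// from id to rarity is an assumption about the tokenizer's id ordering.
#include <cmath>
#include <cstdio>

int main() {
  for (int id : {0, 100, 1000, 50000}) {
    std::printf("id=%5d  w=%.3f\n", id, std::log(2.0 + id / 100.0));
  }
  // Repeated occurrences of the same id accumulate: features[id] += w.
  return 0;
}
```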
--------------------------------------------------------------------------------
/distill/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 | 
3 | load("@australis//:australis.bzl", "australis")
4 | 
5 | licenses(["notice"])
6 | 
7 | australis(
8 |     name = "marketing_detection",
9 |     cc_namespace = "pd",
10 |     py_deps = [],  # Internal flax deps
11 | )
12 | 
13 | cc_binary(
14 |     name = "gen_ex_main",
15 |     srcs = [
16 |         "gen_ex_main.cc",
17 |     ],
18 |     deps = [
19 |         "@com_google_absl//absl/log",
20 |         "@com_google_absl//absl/flags:flag",
21 |         "@com_google_absl//absl/log:initialize",
22 |         "@com_google_absl//absl/flags:parse",
23 |         "@com_google_absl//absl/strings",
24 |         "@tokme//:tokme",
25 |     ],
26 |     linkopts = [
27 |         "-lm",
28 |     ],
29 | )
30 | 
31 | cc_binary(
32 |     name = "train_marketing_detection",
33 |     srcs = [
34 |         "train_marketing_detection.cc",
35 |     ],
36 |     linkstatic = 1,
37 |     deps = [
38 |         ":marketing_detection_cc",
39 |         "@com_google_absl//absl/log",
40 |         "@com_google_absl//absl/log:check",
41 |         "@com_google_absl//absl/flags:flag",
42 |         "@com_google_absl//absl/log:initialize",
43 |         "@com_google_absl//absl/flags:parse",
44 |         "@com_google_absl//absl/strings",
45 |         "@australis//australis:cpu_support",
46 |         "@australis//australis:gpu_support",
47 |         "@australis//australis:petri",
48 |         "@australis//australis",
49 |         "@sentencepiece//src:sentencepiece_processor",
50 |     ],
51 |     linkopts = [
52 |         "-lm",
53 |     ],
54 | )
55 | 
56 | cc_binary(
57 |     name = "distill_page_main",
58 |     srcs = [
59 |         "distill_page_main.cc",
60 |     ],
61 |     linkstatic = 1,
62 |     deps = [
63 |         ":marketing_detection_cc",
64 |         "//common:page_producer",
65 |         "@com_google_absl//absl/log",
66 |         "@com_google_absl//absl/log:check",
67 |         "@com_google_absl//absl/flags:flag",
68 |         "@com_google_absl//absl/log:initialize",
69 |         "@com_google_absl//absl/flags:parse",
70 |         "@com_google_absl//absl/strings",
71 |         "@australis//australis:cpu_support",
72 |         "@australis//australis:gpu_support",
73 |         "@australis//australis:petri",
74 |         "@australis//australis",
75 |         "@sentencepiece//src:sentencepiece_processor",
76 |     ],
77 |     linkopts = [
78 |         "-lm",
79 |     ],
80 | )
81 | 
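The australis rule above lowers the JAX functions exported by marketing_detection.py (further below) at build time into the generated :marketing_detection_cc library; each exported ("flax_init", ...) entry becomes a loadable functor in namespace pd. Condensed from the two binaries that follow, the calling pattern looks roughly like this; a sketch, not a complete program, since it needs the generated header and a linked PJRT backend:

// Load-and-call pattern for the australis-generated entry points.
#include <vector>

#include "australis/australis.h"
#include "australis/petri.h"
#include "distill/marketing_detection.h"

void runOnce() {
  auto client = *aux::Client::GetDefault();       // default PJRT backend
  aux::Device dev = client.LocalDevices()[0];
  auto init_fn = *pd::FlaxInit::Load(client);     // exported "flax_init"
  auto serving = *pd::FlaxServing::Load(client);  // exported "flax_serving"
  // init() was lowered to return the (params, opt_state) tuple.
  auto parts = *aux::PTree::DestructureTuple(init_fn());
  std::vector<int> tokens(2048, 0);               // one zero-padded page
  auto x = *aux::PTree::BufferRN<int>(absl::Span(tokens), {1, 2048}, dev);
  auto logit = *serving(parts[0], x);             // raw logit, shape (1, 1)
  (void)logit;
}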
--------------------------------------------------------------------------------
/distill/distill_page_main.cc:
--------------------------------------------------------------------------------
1 | #include <cmath>
2 | #include <cstdio>
3 | #include <filesystem>
4 | #include <fstream>
5 | #include <memory>
6 | #include <string>
7 | #include <vector>
8 | 
9 | #include "absl/flags/flag.h"
10 | #include "absl/flags/parse.h"
11 | #include "absl/flags/usage.h"
12 | #include "absl/log/check.h"
13 | #include "absl/log/globals.h"
14 | #include "absl/log/initialize.h"
15 | #include "absl/log/log.h"
16 | 
17 | #include "common/page_producer.h"
18 | #include "xla/pjrt/pjrt_client.h"
19 | #include "australis/australis.h"
20 | #include "australis/petri.h"
21 | 
22 | #include "distill/marketing_detection.h"
23 | #include "sentencepiece_processor.h"
24 | 
25 | ABSL_FLAG(std::string, wudao_dir, "/g/wudao", "wudao dataset dir");
26 | ABSL_FLAG(std::string, distilled_output_path, "/g/distilled_pages.txt",
27 |           "distilled output path");
28 | ABSL_FLAG(std::string, tokenizer_path, "vocab/sp_tok.model",
29 |           "the tokenizer model path");
30 | ABSL_FLAG(std::string, model_path, "md_model.jax",
31 |           "the marketing detection model path");
32 | 
33 | namespace fs = std::filesystem;
34 | const static int kMaxLen = 2048;
35 | 
36 | void traverseDirectory(const std::string& path,
37 |                        std::vector<std::string>& pathList) {
38 |   for (const auto& entry : fs::directory_iterator(path)) {
39 |     if (entry.is_directory()) {
40 |       fs::path newdir = fs::path(path) / entry.path();
41 |       traverseDirectory(newdir.string(), pathList);
42 |     } else if (entry.is_regular_file() && entry.path().string().size() > 5 &&
43 |                entry.path().string().substr(entry.path().string().size() - 5) ==
44 |                    std::string(".json")) {
45 |       fs::path newpath = fs::path(path) / entry.path();
46 |       pathList.emplace_back(newpath.string());
47 |     }
48 |   }
49 | }
50 | 
51 | struct DistillClient {
52 |   pd::FlaxServing& serving;
53 |   aux::PTree params;
54 |   aux::Device device;
55 |   FILE* output = nullptr;
56 |   explicit DistillClient(aux::Device dev, pd::FlaxServing& serving_,
57 |                          const std::string& outputPath)
58 |       : serving(serving_), device(dev) {
59 |     output = fopen(outputPath.c_str(), "w");
60 |   }
61 |   ~DistillClient() {
62 |     if (output != nullptr) {
63 |       fclose(output);
64 |       output = nullptr;
65 |     }
66 |   }
67 |   // Reads the binary weight file written by saveModel() in
68 |   // train_marketing_detection.cc (see the format sketch after this file).
69 |   bool LoadModel(const std::string& model_path) {
70 |     std::ifstream file(model_path, std::ios::binary);
71 |     if (!file) {
72 |       return false;
73 |     }
74 |     int numBuf = 0;
75 |     file.read(reinterpret_cast<char*>(&numBuf), sizeof(int));
76 |     std::vector<aux::PTree> modelBuffers;
77 |     for (int i = 0; i < numBuf; i++) {
78 |       int ndim = 0;
79 |       std::vector<int> dims;
80 |       int numData = 0;
81 |       file.read(reinterpret_cast<char*>(&ndim), sizeof(int));
82 |       for (int k = 0; k < ndim; k++) {
83 |         int dim = 0;
84 |         file.read(reinterpret_cast<char*>(&dim), sizeof(int));
85 |         dims.push_back(dim);
86 |       }
87 |       file.read(reinterpret_cast<char*>(&numData), sizeof(int));
88 |       std::vector<float> buffer(numData, 0);
89 |       file.read(reinterpret_cast<char*>(buffer.data()),
90 |                 numData * sizeof(float));
91 |       modelBuffers.push_back(
92 |           *aux::PTree::BufferRN<float>(buffer, dims, device));
93 |     }
94 |     params = aux::PTree::Tuple(std::move(modelBuffers));
95 |     file.close();
96 |     return true;
97 |   }
98 |   bool Predict(const sentencepiece::SentencePieceProcessor* encoder,
99 |                int maxLength, const pd::Page& p, float& marketingScore) {
100 |     std::vector<int> tokens;
101 |     encoder->Encode(p.content, &tokens).IgnoreError();
102 |     int textLen = tokens.size();
103 |     if (textLen > maxLength) {
104 |       tokens = std::vector<int>(tokens.begin(), tokens.begin() + maxLength);
105 |     } else {
106 |       for (int i = textLen; i < maxLength; i++) {
107 |         tokens.push_back(0);  // pad with id 0 up to the fixed serving shape
108 |       }
109 |     }
110 |     auto x = *aux::PTree::BufferRN<int>(absl::Span(tokens),
111 |                                         {1, kMaxLen}, device);
112 |     auto result = *(serving(params, x));
113 |     auto resultLiteral = *result.ToArray();
114 |     float rawy = resultLiteral.data<float>()[0];
115 |     marketingScore = 1.0 / (1.0 + exp(0 - rawy));  // sigmoid of the raw logit
116 |     if (marketingScore >= 0.6) {
117 |       fprintf(output, "%s\t%.4f\n", p.idkey.c_str(), marketingScore);
118 |     }
119 |     return true;
120 |   }
121 | };
122 | 
123 | void doWork(pd::PageProducer* producer,
124 |             const sentencepiece::SentencePieceProcessor* encoder,
125 |             DistillClient* distillClient) {
126 |   pd::Page p;
127 |   uint64_t count = 0;
128 |   producer->takeOnePage(p, true);
129 |   while (p.idkey != pd::PageProducer::EOFPageHashKey) {
130 |     float score = 0;
131 |     CHECK(distillClient->Predict(encoder, kMaxLen, p, score));
132 |     count += 1;
133 |     if (count % 20000 == 0) {
134 |       LOG(INFO) << count << " pages processed!!";
135 |     }
136 |     producer->takeOnePage(p, true);
137 |   }
138 |   LOG(INFO) << count << " total pages processed!!";
139 | }
140 | 
141 | int main(int argc, char* argv[]) {
142 |   absl::SetProgramUsageMessage("Distill Page Main");
143 |   absl::ParseCommandLine(argc, argv);
144 |   absl::InitializeLog();
145 |   absl::SetStderrThreshold(absl::LogSeverity::kInfo);
146 |   std::unique_ptr<sentencepiece::SentencePieceProcessor> encoder;
147 |   encoder.reset(new sentencepiece::SentencePieceProcessor());
148 |   if (!encoder->Load(absl::GetFlag(FLAGS_tokenizer_path)).ok()) {
149 |     LOG(ERROR) << "init tokenizer error,path:"
150 |                << absl::GetFlag(FLAGS_tokenizer_path);
151 |     return -1;
152 |   }
153 |   pd::PageProducer pageProducer;
154 |   std::vector<std::string> pathList;
155 |   traverseDirectory(absl::GetFlag(FLAGS_wudao_dir), pathList);
156 |   LOG(INFO) << "got " << pathList.size() << " paths!";
157 |   auto client = *aux::Client::GetDefault();
158 |   aux::Device dev = client.LocalDevices()[0];
159 |   auto serving = *pd::FlaxServing::Load(client);
160 |   DistillClient distillClient(dev, serving,
161 |                               absl::GetFlag(FLAGS_distilled_output_path));
162 |   pageProducer.initByFileList(pathList);
163 |   CHECK(distillClient.LoadModel(absl::GetFlag(FLAGS_model_path)));
164 |   doWork(&pageProducer, encoder.get(), &distillClient);
165 |   pageProducer.shutdown();
166 |   return 0;
167 | }
168 | 
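DistillClient::LoadModel above and saveModel in train_marketing_detection.cc below agree on a small binary format for md_model.jax: an int32 buffer count, then for each buffer an int32 rank, that many int32 dims, an int32 element count, and the raw float32 payload, all in native endianness. A standalone inspector for such files; a convenience sketch, not part of the repo:

// Dumps the layout of an md_model.jax weight file.
#include <cstdio>
#include <fstream>
#include <vector>

int main(int argc, char** argv) {
  if (argc != 2) {
    std::fprintf(stderr, "usage: %s md_model.jax\n", argv[0]);
    return 1;
  }
  std::ifstream f(argv[1], std::ios::binary);
  int numBuf = 0;
  f.read(reinterpret_cast<char*>(&numBuf), sizeof(int));
  for (int i = 0; i < numBuf && f; ++i) {
    int ndim = 0, numData = 0;
    f.read(reinterpret_cast<char*>(&ndim), sizeof(int));
    std::vector<int> dims(ndim, 0);
    f.read(reinterpret_cast<char*>(dims.data()), ndim * sizeof(int));
    f.read(reinterpret_cast<char*>(&numData), sizeof(int));
    // Skip the payload; we only report the shape metadata.
    f.seekg(static_cast<std::streamoff>(numData) * sizeof(float), std::ios::cur);
    std::printf("buffer %d: rank=%d, %d floats\n", i, ndim, numData);
  }
  return 0;
}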
--------------------------------------------------------------------------------
/distill/gen_distilled_page.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | 
5 | def loadBlacklist(paths):
6 |     retSet = set([])
7 |     for p in paths:
8 |         with open(p, "r", encoding="utf-8") as inp:
9 |             for line in inp:
10 |                 ss = line.strip().split('\t')
11 |                 retSet.add(ss[0])
12 |     return retSet
13 | 
14 | 
15 | def processCorpus(inputdir, dataTypes, blacklistPaths, outputPath):
16 |     total = 0
17 |     valid = 0
18 |     prefilter = 0
19 |     tooShort = 0
20 |     blackSet = loadBlacklist(blacklistPaths)
21 |     types = set(dataTypes)
22 |     print(f"got {len(blackSet)} keys in blacklist")
23 |     with open(outputPath, "w", encoding='utf8') as outp:
24 |         # list inputdir
25 |         for file in os.listdir(inputdir):
26 |             if file.endswith(".json"):
27 |                 with open(os.path.join(inputdir, file), "r") as inp:
28 |                     data = inp.read()
29 |                     objarr = json.loads(data)
30 |                     for obj in objarr:
31 |                         title = obj['title']
32 |                         text = obj['content']
33 |                         dataType = obj['dataType']
34 |                         ukey = obj['uniqueKey']
35 |                         total += 1
36 |                         if total % 100000 == 0:
37 |                             print(f"processed {total}, prefilter:{prefilter}, valid:{valid}, short:{tooShort}.....")
38 |                         if dataType in types:
39 |                             prefilter += 1
40 |                             if ukey not in blackSet:
41 |                                 if len(title) < 5 or len(text) < 10:
42 |                                     tooShort += 1
43 |                                     continue
44 |                                 json.dump(obj, outp, ensure_ascii=False)
45 |                                 outp.write("\n")
46 |                                 valid += 1
47 |     print(f"processed {total}, prefilter:{prefilter}, valid:{valid}, short:{tooShort}.....")
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     # The dataType values below are WuDao corpus category names (in Chinese).
52 |     # processCorpus("/g/wudao/", ["博客"], ["/g/duplicate_keys.txt", "/g/distilled_pages.txt"], '/g/distilled_blog_pages.jsonl')
53 |     # processCorpus("/g/wudao/", ["百科"], [], '/g/baike_pages.jsonl')
54 |     # processCorpus("/g/wudao/", ["经验", "小红书攻略", "健康", "亲子", "医学问答", "生活", "理论", "期货", "观点", "党建", "信托", "国学", "评论", "法律", "百家号文章", "科普文章", "孕育常识"], [], '/g/other2_pages.jsonl')
55 |     processCorpus("/g/wudao/", ["经济", "娱乐", "文化", "军事", "游戏", "汽车", "科技", "农业", "体育", "国际", "教育", "社会", "旅行", "房产", "股票", "豆瓣话题", "资讯", "新闻"], ["/g/duplicate_keys.txt"], '/g/other1_pages.jsonl')
56 | 
--------------------------------------------------------------------------------
/distill/gen_ex_main.cc:
--------------------------------------------------------------------------------
1 | #include <fstream>
2 | #include <iostream>
3 | #include <memory>
4 | #include <string>
5 | #include <vector>
6 | 
7 | #include "absl/flags/flag.h"
8 | #include "absl/flags/parse.h"
9 | #include "absl/flags/usage.h"
10 | #include "absl/log/globals.h"
11 | #include "absl/log/initialize.h"
12 | #include "absl/log/log.h"
13 | #include "absl/strings/str_split.h"
14 | #include "absl/strings/str_join.h"
15 | #include "absl/strings/numbers.h"
16 | 
17 | #include "youtokentome/cpp/bpe.h"
18 | 
19 | ABSL_FLAG(std::string, tagged_corpus_path, "/g/chatgpt_output.txt",
20 |           "the tagged corpus path");
21 | ABSL_FLAG(std::string, output_path, "/g/tagged_dataset.txt",
22 |           "segmented corpus output path");
23 | ABSL_FLAG(std::string, tokenizer_path, "vocab/tokme.model",
24 |           "the tokenizer model path");
25 | ABSL_FLAG(int, max_length, 4096, "the maximum length of text");
26 | 
27 | int main(int argc, char* argv[]) {
28 |   absl::SetProgramUsageMessage("Gen Ex Main");
29 |   absl::ParseCommandLine(argc, argv);
30 |   absl::InitializeLog();
31 |   absl::SetStderrThreshold(absl::LogSeverity::kInfo);
32 |   std::unique_ptr<vkcom::BaseEncoder> encoder;
33 |   vkcom::Status status;
34 |   encoder.reset(new vkcom::BaseEncoder(absl::GetFlag(FLAGS_tokenizer_path), 1, &status));
35 |   if (!status.ok()) {
36 |     LOG(ERROR) << "init tokenizer error,path:"
37 |                << absl::GetFlag(FLAGS_tokenizer_path);
38 |     return -1;
39 |   }
40 |   int maxLength = absl::GetFlag(FLAGS_max_length);
41 |   std::ifstream ifs(absl::GetFlag(FLAGS_tagged_corpus_path));
42 |   std::ofstream ofs(absl::GetFlag(FLAGS_output_path));
43 |   // Histograms of the two 0-100 tags (quality, sale) in buckets of 10.
44 |   std::vector<int> qualityGrades(10, 0);
45 |   std::vector<int> saleGrades(10, 0);
46 |   std::string line;
47 |   while (std::getline(ifs, line)) {
48 |     std::vector<std::string> vs = absl::StrSplit(line, absl::ByChar('\t'));
49 |     int quality = 0, saleGrade = 0;
50 |     absl::SimpleAtoi(vs[0], &quality);
51 |     absl::SimpleAtoi(vs[1], &saleGrade);
52 |     int g = quality / 10;
53 |     if (g >= 10) {
54 |       g = 9;
55 |     }
56 |     qualityGrades[g] += 1;
57 |     g = saleGrade / 10;
58 |     if (g >= 10) {
59 |       g = 9;
60 |     }
61 |     saleGrades[g] += 1;
62 |     std::vector<std::vector<int>> ids;
63 |     status = encoder->encode_as_ids({vs[2]}, &ids);
64 |     if (!status.ok() || ids.empty()) {
65 |       LOG(ERROR) << "encode_as_ids error!";
66 |       continue;
67 |     }
68 |     std::vector<int> tokens = ids[0];
69 |     int textLen = tokens.size();
70 |     if (textLen > maxLength) {
71 |       tokens = std::vector<int>(tokens.begin(), tokens.begin() + maxLength);
72 |     }
73 |     std::string text = absl::StrJoin(tokens, " ");
74 |     ofs << quality << "\t" << saleGrade << "\t" << text << "\n";
75 |   }
76 |   LOG(INFO) << "quality grades:" << absl::StrJoin(qualityGrades, ",");
77 |   LOG(INFO) << "sale grades:" << absl::StrJoin(saleGrades, ",");
78 |   return 0;
79 | }
80 | 
--------------------------------------------------------------------------------
/distill/marketing_detection.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | 
3 | import jax
4 | import jax.numpy as jnp
5 | import optax
6 | 
7 | from australis import exporter
8 | from australis.exporter import Lowered
9 | 
10 | 
11 | # NOTE: MarketingDetectionModel is a flax.linen module taking int32 token ids
12 | # of shape (batch, 2048) and returning (batch, 1) logits; its definition is
13 | # missing from this dump.
14 | 
15 | 
16 | def lower() -> Lowered:
17 |     model = MarketingDetectionModel(vocab_size=64000, hidden_size=256)
18 | 
19 |     tx = optax.adam(0.0003)
20 | 
21 |     @jax.jit
22 |     def init():
23 |         init_rng = jax.random.PRNGKey(42)
24 |         params = model.init(init_rng, jnp.ones((128, 2048), dtype=jnp.int32))
25 |         return params, tx.init(params)
26 | 
27 |     init_fn = init.lower()
28 | 
29 |     @jax.jit
30 |     def serving(params, x):
31 |         return model.apply(params, x, False)
32 | 
33 |     @jax.jit
34 |     def batchServing(params, x):
35 |         return model.apply(params, x, False)
36 | 
37 |     @partial(jax.jit, static_argnums=(1,))
38 |     def save(params, path="best_model.npz"):
39 |         jax.numpy.save(path, jax.device_get(params))
40 | 
41 |     @jax.jit
42 |     def optimizer_step(params, opt_state, x, y):
43 |         def fwd(params):
44 |             logits = model.apply(params, x, True,
45 |                                  rngs={'dropout': jax.random.PRNGKey(42)})
46 |             return optax.sigmoid_binary_cross_entropy(logits, y).mean()
47 | 
48 |         loss, grads = jax.value_and_grad(fwd)(params)
49 |         updates, opt_state = tx.update(grads, opt_state)
50 |         params = optax.apply_updates(params, updates)
51 |         return params, opt_state, loss
52 | 
53 |     params, opt_state = jax.eval_shape(init)
54 |     optimizer_step_lowered = optimizer_step.lower(
55 |         params, opt_state, jax.ShapeDtypeStruct((128, 2048), jnp.int32),
56 |         jax.ShapeDtypeStruct((128, 1), jnp.int32))
57 | 
58 |     serving_lowered = serving.lower(
59 |         params, jax.ShapeDtypeStruct((1, 2048), jnp.int32))
60 | 
61 |     batch_serving_lowered = batchServing.lower(
62 |         params, jax.ShapeDtypeStruct((128, 2048), jnp.int32))
63 |     save_lowered = save.lower(params)
64 |     return [
65 |         ("flax_init", init_fn),
66 |         ("flax_optimizer_step", optimizer_step_lowered),
67 |         ("flax_serving", serving_lowered),
68 |         ("flax_batch_serving", batch_serving_lowered),
69 |         ("flax_save", save_lowered),
70 |     ]
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     exporter.run(lower)
75 | 
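optimizer_step trains against optax.sigmoid_binary_cross_entropy on 0/1 marketing labels. For reference, the same per-example loss in its numerically stable form, in plain C++ (the function name is ours):

// Stable sigmoid binary cross-entropy, equivalent to
// -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(logit).
#include <cmath>
#include <cstdio>

double sigmoidBce(double logit, double label) {
  return std::fmax(logit, 0.0) - logit * label +
         std::log1p(std::exp(-std::fabs(logit)));
}

int main() {
  std::printf("%.4f\n", sigmoidBce(2.0, 1.0));  // confident and correct: ~0.127
  std::printf("%.4f\n", sigmoidBce(2.0, 0.0));  // confident but wrong:  ~2.127
  return 0;
}

Since the sigmoid is monotonic, thresholding the raw logit at 0 (as evalTestDataset does below) is the same as thresholding the sigmoid score at 0.5; distill_page_main.cc instead applies the sigmoid and keeps pages scoring at least 0.6.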
--------------------------------------------------------------------------------
/distill/train_marketing_detection.cc:
--------------------------------------------------------------------------------
1 | #include <algorithm>
2 | #include <cstdint>
3 | #include <fstream>
4 | #include <memory>
5 | #include <string>
6 | #include <tuple>
7 | #include <vector>
8 | 
9 | #include "absl/flags/flag.h"
10 | #include "absl/flags/parse.h"
11 | #include "absl/flags/usage.h"
12 | #include "absl/log/check.h"
13 | #include "absl/log/globals.h"
14 | #include "absl/log/initialize.h"
15 | #include "absl/log/log.h"
16 | #include "absl/strings/numbers.h"
17 | #include "absl/strings/str_join.h"
18 | #include "absl/strings/str_split.h"
19 | 
20 | #include "xla/pjrt/pjrt_client.h"
21 | #include "australis/australis.h"
22 | #include "australis/petri.h"
23 | 
24 | #include "distill/marketing_detection.h"
25 | #include "sentencepiece_processor.h"
26 | // #include "youtokentome/cpp/bpe.h"
27 | 
28 | ABSL_FLAG(std::string, train_data_path, "data/train.txt", "training data path");
29 | ABSL_FLAG(std::string, test_data_path, "data/test.txt", "the test data path");
30 | ABSL_FLAG(std::string, tokenizer_path, "vocab/sp_tok.model",
31 |           "the tokenizer model path");
32 | ABSL_FLAG(std::string, out_model_path, "md_model.jax", "model output path");
33 | ABSL_FLAG(int, max_epoch, 50, "the maximum epoch number of training");
34 | 
35 | const static int kBatchSize = 128;
36 | const static int kMaxLen = 2048;
37 | 
38 | // Reusable host-side staging buffers for one (kBatchSize, kMaxLen) batch.
39 | static std::vector<int> kInputBuffer(kBatchSize * kMaxLen, 0);
40 | static std::vector<int> kTrainBuffer(kBatchSize * kMaxLen, 0);
41 | static std::vector<int> kTargetBuffer(kBatchSize, 0);
42 | 
43 | struct TrainEx {
44 |   int marketing;
45 |   std::vector<int> data;
46 | };
47 | 
48 | absl::StatusOr<std::tuple<aux::PTree, aux::PTree>> Unpack2Tuple(
49 |     absl::StatusOr<aux::PTree> input) {
50 |   auto tmp = *aux::PTree::DestructureTuple(std::move(input));
51 |   if (tmp.size() != 2) {
52 |     return absl::InvalidArgumentError(absl::StrCat("Wrong size: ", tmp.size()));
53 |   }
54 |   return std::tuple(std::move(tmp[0]),
55 |                     std::move(tmp[1]));
56 | }
57 | 
58 | absl::StatusOr<std::tuple<aux::PTree, aux::PTree, aux::PTree>> Unpack3Tuple(
59 |     absl::StatusOr<aux::PTree> input) {
60 |   auto tmp = *aux::PTree::DestructureTuple(std::move(input));
61 |   if (tmp.size() != 3) {
62 |     return absl::InvalidArgumentError(absl::StrCat("Wrong size: ", tmp.size()));
63 |   }
64 |   return std::tuple(
65 |       std::move(tmp[0]), std::move(tmp[1]), std::move(tmp[2]));
66 | }
67 | 
68 | // vector<int> ids;
69 | // sp.Encode("hello world.", &ids).IgnoreError();
70 | static void convertLineToEx(const std::string& line,
71 |                             const sentencepiece::SentencePieceProcessor* encoder,
72 |                             int maxLength, TrainEx& ex) {
73 |   std::vector<std::string> vs = absl::StrSplit(line, absl::ByChar('\t'));
74 |   int quality = 0, saleGrade = 0;
75 |   absl::SimpleAtoi(vs[0], &quality);
76 |   absl::SimpleAtoi(vs[1], &saleGrade);
77 |   ex.marketing = saleGrade >= 30 ? 1 : 0;  // binarize the 0-100 sale score
78 |   std::vector<int> tokens;
79 |   encoder->Encode(vs[2], &tokens).IgnoreError();
80 |   int textLen = tokens.size();
81 |   if (textLen > maxLength) {
82 |     tokens = std::vector<int>(tokens.begin(), tokens.begin() + maxLength);
83 |   }
84 |   ex.data = tokens;
85 | }
86 | 
87 | void loadTestDataset(const std::string& path,
88 |                      const sentencepiece::SentencePieceProcessor* encoder,
89 |                      std::vector<TrainEx>& testData) {
90 |   std::ifstream file(path);
91 |   std::string line;
92 |   while (std::getline(file, line)) {
93 |     TrainEx ex;
94 |     convertLineToEx(line, encoder, kMaxLen, ex);
95 |     testData.push_back(ex);
96 |   }
97 | }
98 | 
99 | void fillTrainData(const std::vector<TrainEx>& inputData) {
100 |   CHECK_EQ(static_cast<int>(inputData.size()), kBatchSize);
101 |   for (int k = 0; k < kBatchSize; k++) {
102 |     for (int j = 0; j < static_cast<int>(inputData[k].data.size()); j++) {
103 |       kTrainBuffer[k * kMaxLen + j] = inputData[k].data[j];
104 |     }
105 |     for (int j = inputData[k].data.size(); j < kMaxLen; j++) {
106 |       kTrainBuffer[k * kMaxLen + j] = 0;  // zero-pad the row
107 |     }
108 |     kTargetBuffer[k] = inputData[k].marketing;
109 |   }
110 | }
111 | 
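// Aside: the truncate-or-pad pattern above (convertLineToEx/fillTrainData,
// also Predict in distill_page_main.cc and evalTestDataset below) collapses
// to a single call, since std::vector::resize both truncates long sequences
// and value-fills newly created elements. A possible shared helper, ours,
// not in the repo:
//
//   void padOrTruncate(std::vector<int>& tokens, int maxLength) {
//     tokens.resize(maxLength, 0);  // keep first maxLength ids, pad with 0
//   }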
112 | float evalTestDataset(const std::vector<TrainEx>& testData,
113 |                       pd::FlaxBatchServing& batchServing, aux::Device dev,
114 |                       const aux::PTree& param, int maxLength) {
115 |   int nn = (testData.size() - 1) / kBatchSize + 1;  // batch count, rounded up
116 |   float sumDiff = 0;
117 |   for (int i = 0; i < nn; i++) {
118 |     int start = i * kBatchSize;
119 |     int end = std::min((i + 1) * kBatchSize, static_cast<int>(testData.size()));
120 |     for (int k = start; k < end; k++) {
121 |       int off = k - start;
122 |       for (int j = 0; j < static_cast<int>(testData[k].data.size()); j++) {
123 |         kInputBuffer[off * maxLength + j] = testData[k].data[j];
124 |       }
125 |       for (int j = testData[k].data.size(); j < maxLength; j++) {
126 |         kInputBuffer[off * maxLength + j] = 0;
127 |       }
128 |     }
129 |     // Zero the unused tail rows of the final, possibly partial batch.
130 |     for (int k = end; k < (i + 1) * kBatchSize; k++) {
131 |       int off = k - start;
132 |       for (int j = 0; j < maxLength; j++) {
133 |         kInputBuffer[off * maxLength + j] = 0;
134 |       }
135 |     }
136 |     auto x = *aux::PTree::BufferRN<int>(absl::Span(kInputBuffer),
137 |                                         {kBatchSize, kMaxLen}, dev);
138 |     auto result = *(batchServing(param, x));
139 |     auto resultLiteral = *result.ToArray();
140 |     for (int k = start; k < end; k++) {
141 |       const TrainEx& ex = testData[k];
142 |       // A positive raw logit corresponds to a sigmoid score above 0.5.
143 |       int predy = resultLiteral.data<float>()[k - start] > 0 ? 1 : 0;
144 |       sumDiff += predy == ex.marketing ? 1 : 0;
145 |     }
146 |   }
147 |   LOG(INFO) << "test accuracy: " << sumDiff / testData.size();
148 |   return sumDiff / testData.size();
149 | }
150 | 
151 | bool saveModel(const std::string& path,
152 |                const std::vector<const aux::DeviceArray*>& params) {
153 |   std::ofstream file(path, std::ios::binary);
154 |   if (!file) {
155 |     return false;
156 |   }
157 |   int numBuf = params.size();
158 |   file.write(reinterpret_cast<const char*>(&numBuf), sizeof(int));
159 |   for (int i = 0; i < numBuf; i++) {
160 |     const aux::DeviceArray* p = params[i];
161 |     auto buffers = p->buffers();
162 |     auto shape = buffers[0]->on_device_shape();
163 |     int ndims = shape.dimensions_size();
164 |     file.write(reinterpret_cast<const char*>(&ndims), sizeof(int));
165 |     int total = 1;
166 |     for (int k = 0; k < ndims; k++) {
167 |       int dim = shape.dimensions(k);
168 |       file.write(reinterpret_cast<const char*>(&dim), sizeof(int));
169 |       total *= dim;
170 |     }
171 |     auto ia = *p->ToArrays();
172 |     auto dataArr = ia.data()->data<float>();
173 |     int numData = dataArr.size();
174 |     CHECK_EQ(total, numData);
175 |     file.write(reinterpret_cast<const char*>(&numData), sizeof(int));
176 |     for (int j = 0; j < numData; j++) {
177 |       file.write(reinterpret_cast<const char*>(&dataArr[j]), sizeof(float));
178 |     }
179 |   }
180 |   file.close();
181 |   return true;
182 | }
183 | 
184 | int main(int argc, char* argv[]) {
185 |   absl::SetProgramUsageMessage("Train MarketingDetection Main");
186 |   absl::ParseCommandLine(argc, argv);
187 |   absl::InitializeLog();
188 |   absl::SetStderrThreshold(absl::LogSeverity::kInfo);
189 |   std::unique_ptr<sentencepiece::SentencePieceProcessor> encoder;
190 |   encoder.reset(new sentencepiece::SentencePieceProcessor());
191 |   if (!encoder->Load(absl::GetFlag(FLAGS_tokenizer_path)).ok()) {
192 |     LOG(ERROR) << "init tokenizer error,path:"
193 |                << absl::GetFlag(FLAGS_tokenizer_path);
194 |     return -1;
195 |   }
196 |   int maxEpoch = absl::GetFlag(FLAGS_max_epoch);
197 |   std::vector<TrainEx> testData;
198 |   loadTestDataset(absl::GetFlag(FLAGS_test_data_path), encoder.get(), testData);
199 |   LOG(INFO) << "load " << testData.size() << " test examples";
200 |   auto client = *aux::Client::GetDefault();
201 |   aux::Device dev = client.LocalDevices()[0];
202 |   auto init_fn = *pd::FlaxInit::Load(client);
203 |   auto optimizer_step_fn = *pd::FlaxOptimizerStep::Load(client);
204 |   auto batchServing = *pd::FlaxBatchServing::Load(client);
205 | 
206 |   auto [params, opt_state] = *Unpack2Tuple(init_fn());
207 |   LOG(INFO) << "inited model, num buffer in weights:" << params.num_buffers();
208 |   float bestAccuracy =
209 |       evalTestDataset(testData, batchServing, dev, params, kMaxLen);
210 |   std::string line;
211 |   std::vector<TrainEx> trainData;
212 |   uint32_t step = 0;
213 |   for (int epoch = 0; epoch < maxEpoch; epoch++) {
214 |     LOG(INFO) << "begin epoch " << epoch;
215 |     std::ifstream file(absl::GetFlag(FLAGS_train_data_path));
216 |     aux::PTree lossPT;
217 |     while (std::getline(file, line)) {
218 |       TrainEx ex;
219 |       convertLineToEx(line, encoder.get(), kMaxLen, ex);
220 |       trainData.push_back(ex);
221 |       if (static_cast<int>(trainData.size()) == kBatchSize) {
222 |         fillTrainData(trainData);
223 |         auto x = *aux::PTree::BufferRN<int>(absl::Span(kTrainBuffer),
224 |                                             {kBatchSize, kMaxLen}, dev);
225 |         auto y = *aux::PTree::BufferRN<int>(absl::Span(kTargetBuffer),
226 |                                             {kBatchSize, 1}, dev);
227 |         std::tie(params, opt_state, lossPT) =
228 |             *Unpack3Tuple(optimizer_step_fn(params, opt_state, x, y));
229 |         step += 1;
230 |         auto resultLiteral = *lossPT.ToArray();
231 |         if (step % 20 == 0) {
232 |           LOG(INFO) << "Step:" << step
233 |                     << ",Loss: " << resultLiteral.data<float>()[0];
234 |         }
235 |         trainData.clear();
236 |       }
237 |     }
238 |     float accuracy =
239 |         evalTestDataset(testData, batchServing, dev, params, kMaxLen);
240 |     if (accuracy > bestAccuracy) {
241 |       bestAccuracy = accuracy;
242 |       std::vector<const aux::DeviceArray*> flattened_params;
243 |       CHECK(params.FlattenTo(&flattened_params).ok());
244 |       saveModel(absl::GetFlag(FLAGS_out_model_path), flattened_params);
245 |       LOG(INFO) << "new best model:" << bestAccuracy;
246 |     }
247 |   }
248 |   LOG(INFO) << "Finished, best accuracy:" << bestAccuracy;
249 |   return 0;
250 | }
251 | 
--------------------------------------------------------------------------------
/thirdparty/BUILD:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EastTower16/LLMDataDistill/3ff93f9cddba0d91099e87281125c29e592937f8/thirdparty/BUILD
--------------------------------------------------------------------------------
/thirdparty/concurrentqueue.BUILD:
--------------------------------------------------------------------------------
1 | 
2 | package(default_visibility = ["//visibility:public"])
3 | 
4 | 
5 | cc_library(
6 |     name = "concurrentqueue",
7 |     srcs = [
8 |         "lightweightsemaphore.h",
9 |     ],
10 |     defines = [
11 |     ],
12 |     hdrs = [
13 |         "concurrentqueue.h",
14 |         "blockingconcurrentqueue.h",
15 |     ],
16 | )
17 | 
--------------------------------------------------------------------------------
/thirdparty/minhashcuda.BUILD:
--------------------------------------------------------------------------------
1 | # load("@rules_cuda//cuda:defs.bzl", "cuda_library")
2 | load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
3 | package(default_visibility = ["//visibility:public"])
4 | 
5 | cuda_library(
6 |     name = "kernel",
7 |     srcs = [
8 |         "wrappers.h",
9 |         "kernel.cu.cc",
10 |     ],
11 |     hdrs = [
12 |         "private.h",
13 |         "minhashcuda.h",
14 |     ],
15 |     deps = [
16 |         "@local_config_cuda//cuda:cuda",
17 |     ],
18 | )
19 | 
20 | cc_library(
21 |     name = "minhashcuda",
22 |     srcs = [
23 |         "minhashcuda.cc",
24 |     ],
25 |     defines = [
26 |         "CUDA_ARCH=86",
27 |     ],
28 |     hdrs = [
29 |     ],
30 |     deps = [
31 |         ":kernel",
32 |     ],
33 | )
34 | 
--------------------------------------------------------------------------------
/vocab/sp_tok.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EastTower16/LLMDataDistill/3ff93f9cddba0d91099e87281125c29e592937f8/vocab/sp_tok.model
--------------------------------------------------------------------------------