├── .env ├── requirements-apt.txt ├── requirements.txt ├── doc ├── architecture.png ├── pytorch_demo │ ├── requirements.txt │ ├── .gitignore │ ├── Makefile │ ├── README.md │ ├── LICENSE │ ├── custom_ops.cpp │ ├── Preprocessing.ipynb │ ├── kge_mapping.py │ └── kge_training.py ├── system_architecture.png ├── resources.md ├── development.md └── design.md ├── tests ├── main.cpp ├── testPoplar.cpp ├── python │ ├── test_poplar_kge_ensemble.py │ ├── test_poplar_kge_dataset.py │ ├── reference_model.py │ └── test_poplar_kge.py ├── testPoplarKge.cpp └── fructose │ ├── testSmallTraining.cpp │ └── testFructose.cpp ├── .clang-format ├── .gitignore ├── requirements-dev.txt ├── .gitmodules ├── setup.cfg ├── scripts ├── wandb_sweep_template.yaml ├── run_dataset_benchmark.py ├── run_training.py └── run_profile.py ├── src ├── fructose │ ├── frnn.hpp │ ├── frnn.cpp │ └── fructose.hpp ├── poplar_extensions │ ├── distance.hpp │ ├── l1distance.codelet.cpp │ ├── l2distance.codelet.cpp │ └── distance.cpp ├── lib.cpp ├── poplar_kge.hpp ├── python │ ├── poplar_kge_ensemble.py │ ├── poplar_kge_utility.py │ └── poplar_kge_dataset.py └── pag │ └── pag.hpp ├── LICENSE ├── .github └── workflows │ └── ci.yaml ├── README.md └── dev /.env: -------------------------------------------------------------------------------- 1 | PYTHONPATH=${PYTHONPATH}:build:src/python:tests/python 2 | -------------------------------------------------------------------------------- /requirements-apt.txt: -------------------------------------------------------------------------------- 1 | ninja-build 2 | clang 3 | clang-format 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ogb==1.3.4 2 | torch==1.10.2 3 | wandb==0.13.3 4 | -------------------------------------------------------------------------------- /doc/architecture.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/distributed-kge-poplar/HEAD/doc/architecture.png -------------------------------------------------------------------------------- /doc/pytorch_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.6.2 2 | ogb==1.3.5 3 | pandas==1.5.2 4 | seaborn==0.12.1 5 | -------------------------------------------------------------------------------- /doc/system_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/distributed-kge-poplar/HEAD/doc/system_architecture.png -------------------------------------------------------------------------------- /tests/main.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | #define CATCH_CONFIG_MAIN 4 | #include 5 | -------------------------------------------------------------------------------- /doc/pytorch_demo/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | .venv 4 | .venvs 5 | .vscode 6 | 7 | /build 8 | /data 9 | /local 10 | /out 11 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Chromium 2 | IndentWidth: 4 3 | ColumnLimit: 100 4 | AllowShortFunctionsOnASingleLine: Inline 5 | AllowShortIfStatementsOnASingleLine: true 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ninja_deps 2 | .ninja_log 3 | .ssub.yaml 4 | .venv 5 | .vscode 6 | __pycache__ 7 | 8 | /build 9 | /data 10 | /local 11 | 
/localdata 12 | /out 13 | /third_party/poplar 14 | /wandb 15 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | black==24.3.0 3 | flake8==5.0.4 4 | isort==5.10.1 5 | jupyterlab 6 | matplotlib 7 | mypy==0.971 8 | pandas 9 | pytest 10 | seaborn 11 | types-dataclasses==0.6.6 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/pybind11"] 2 | path = third_party/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | [submodule "third_party/catch2"] 5 | path = third_party/catch2 6 | url = https://github.com/catchorg/Catch2.git 7 | -------------------------------------------------------------------------------- /doc/pytorch_demo/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | CXX ?= g++ 4 | CXXFLAGS = -Wall -Wextra -Werror -std=c++17 -O2 -g -fPIC -DONNX_NAMESPACE=onnx 5 | 6 | build/custom_ops.so: custom_ops.cpp 7 | mkdir -p build && $(CXX) $(CXXFLAGS) -shared $^ -o $@ -Wl,--no-undefined -lpoplar -lpopart -lgcl 8 | -------------------------------------------------------------------------------- /doc/resources.md: -------------------------------------------------------------------------------- 1 | # Code and supplementary resources for BESS 2 | 3 | - [Technical report](https://arxiv.org/abs/2211.12281) 4 | - [Code](https://github.com/graphcore/distributed-kge-poplar) 5 | - [Paperspace demo](https://ipu.dev/3QwfKJS) 6 | - [Raw results](https://github.com/graphcore/distributed-kge-poplar/blob/resources/2022-ogb-submission/configs_and_results.json) 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [mypy] 2 | strict = true 3 | pretty = true 4 | check_untyped_defs = true 5 | show_error_codes = true 6 | ignore_missing_imports = true 7 | allow_any_generics = true 8 | 9 | [flake8] 10 | # Required to match 'black' 11 | ignore = E203,W503 12 | max-line-length = 120 13 | 14 | [isort] 15 | # See https://black.readthedocs.io/en/stable/compatible_configs.html 16 | multi_line_output = 3 17 | include_trailing_comma = True 18 | force_grid_wrap = 0 19 | use_parentheses = True 20 | ensure_newline_before_comments = True 21 | line_length = 88 22 | 23 | [tool:pytest] 24 | filterwarnings = 25 | ignore::outdated.OutdatedCacheFailedWarning 26 | -------------------------------------------------------------------------------- /doc/pytorch_demo/README.md: -------------------------------------------------------------------------------- 1 | # Interactive KGE in PyTorch 2 | 3 | [![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/3QwfKJS) 4 | 5 | An interactive training and evaluation demo of Knowledge 
Graph Embedding (KGE) model training on IPU in PyTorch. 6 | 7 | This demo is designed to be run interactively using an IPU notebook service. Please see [KgeModelling.ipynb](KgeModelling.ipynb) for a complete walkthrough. 8 | 9 | _Note: requires Poplar SDK >= 3.1._ 10 | 11 | --- 12 | 13 | Included code is licensed under the [MIT license](LICENSE). Supporting data files are licensed as detailed in the accompanying `LICENSE.md`. 14 | -------------------------------------------------------------------------------- /scripts/wandb_sweep_template.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | command: 4 | - ./dev 5 | - train 6 | method: bayes 7 | metric: 8 | goal: maximize 9 | name: valid_mrr 10 | # N.B to specify nested parameters, you must specify 11 | # that the parent parameter has parameters... 12 | parameters: 13 | training: 14 | parameters: 15 | learning_rate: 16 | distribution: log_uniform_values 17 | max: 0.001 18 | min: 1e-05 19 | feature_regularisation: 20 | parameters: 21 | weight: 22 | distribution: log_uniform_values 23 | max: 0.0001 24 | min: 1e-06 25 | learning_rate_modifiers: 26 | value: {} 27 | -------------------------------------------------------------------------------- /scripts/run_dataset_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | import cProfile 4 | import itertools as it 5 | import pstats 6 | 7 | import poplar_kge as kge 8 | import poplar_kge_dataset as kge_data 9 | import poplar_kge_utility as kge_utility 10 | 11 | profile = cProfile.Profile() 12 | 13 | settings = kge.Settings.create_wikikg90mv2() 14 | settings.logging.wandb = False 15 | 16 | settings.prepare() 17 | logger = kge_utility.Logger(settings, __file__) 18 | 19 | data = kge_data.RawData.load(settings) 20 | logger.log("load_data", {}) 21 | 22 | dataset = kge_data.Dataset.load(data, settings) 23 | logger.log("build_index", {}) 24 | 25 | profile.enable() 26 | 27 | count = 10 28 | list(it.islice(dataset.batches(), count)) 29 | logger.log("sample_batch", dict(count=count)) 30 | 31 | profile.disable() 32 | pstats.Stats(profile).sort_stats("cumtime").print_stats() 33 | -------------------------------------------------------------------------------- /src/fructose/frnn.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | #ifndef FRNN_HPP 4 | #define FRNN_HPP 5 | 6 | #include "fructose.hpp" 7 | 8 | /** 9 | * Fructose-Neural-Networks (fr::nn), additional functions for implementing neural nets. 
10 | */ 11 | namespace fr::nn { 12 | 13 | // Ops 14 | Tensor relu(const Tensor& tensor); 15 | Tensor softmaxCrossEntropy(const Tensor& logits, const Tensor& labels); 16 | Tensor dropout(const Tensor& a, float dropProbability); 17 | 18 | // Optimisers 19 | void sgd(const Tensor& tensor, const Tensor& learningRate); 20 | 21 | struct AdamParams { 22 | float betaM; 23 | float betaV; 24 | float epsilon; 25 | float weightDecay; 26 | }; 27 | Tensor adamStepSizeAutoIncrement(const Tensor& step, 28 | const Tensor& learningRate, 29 | const AdamParams& params); 30 | void adam(const Tensor& tensor, 31 | const Tensor& momentum, 32 | const Tensor& variance, 33 | const Tensor& stepSize, 34 | const AdamParams& params); 35 | 36 | } // namespace fr::nn 37 | 38 | #endif // FRNN_HPP 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Graphcore Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /doc/pytorch_demo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Graphcore Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/testPoplar.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | /** 12 | * Not a 'real' test case, but somewhere to play around with Poplar directly. 13 | */ 14 | TEST_CASE("Manual poplar", "[poplar]") { 15 | auto device = poplar::Device::createCPUDevice(); 16 | 17 | poplar::Graph graph(device.getTarget()); 18 | popops::addCodelets(graph); 19 | poplin::addCodelets(graph); 20 | poplar::program::Sequence prog; 21 | 22 | auto a = 23 | graph.addConstant(poplar::FLOAT, {1, 2, 5}, std::vector(1 * 2 * 5, 1.0f)); 24 | graph.setTileMapping(a, 0); 25 | auto b = 26 | graph.addConstant(poplar::FLOAT, {1, 5, 7}, std::vector(1 * 5 * 7, 1.0f)); 27 | graph.setTileMapping(b, 0); 28 | auto c = poplin::matMulGrouped(graph, a, b, prog, poplar::FLOAT); 29 | 30 | // prog.add(poplar::program::PrintTensor("c", c)); 31 | 32 | poplar::Engine engine(graph, prog); 33 | engine.loadAndRun(device); 34 | } 35 | -------------------------------------------------------------------------------- /tests/python/test_poplar_kge_ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | from typing import Any, Callable, Dict 4 | 5 | import numpy as np 6 | import poplar_kge_ensemble as kge_ensemble 7 | import pytest 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "ensemble_fn,ensemble_args", 12 | [ 13 | (kge_ensemble.mean_ensemble, dict(power=-1)), 14 | (kge_ensemble.mean_ensemble, dict(power=-0.5)), 15 | (kge_ensemble.mean_ensemble, dict(power=1)), 16 | (kge_ensemble.mean_ensemble, dict(power=0)), 17 | (kge_ensemble.mean_ensemble, dict(power=-1, model_weight=np.arange(27))), 18 | (kge_ensemble.median_ensemble, dict()), 19 | ], 20 | ids=str, 21 | ) 22 | def test_ensemble( 23 | ensemble_fn: Callable[..., np.ndarray], ensemble_args: Dict[str, Any] 24 | ) -> None: 25 | random = np.random.default_rng(100) 26 | 27 | n_model = 27 28 | n_example = 23 29 | n_prediction = 30 30 | n_entity = 100 31 | 32 | ground_truth = random.integers(n_entity, size=n_example) 33 | scores = random.random((n_model, n_example, n_entity)) 34 | scores[:, np.arange(n_example), ground_truth] += 0.25 35 | predictions = np.argsort(scores, axis=2)[..., ::-1][:, :, :n_prediction] 36 | 37 | ensemble_predictions = ensemble_fn(predictions, count=n_prediction, **ensemble_args) 38 | 39 | _, _, ranks = np.where(predictions == ground_truth[:, np.newaxis]) 40 | model_mrr = np.mean(1 / (1 + ranks)) 41 | 42 | _, ranks = np.where(ensemble_predictions == ground_truth[:, np.newaxis]) 43 | ensemble_mrr = np.mean(1 / (1 + ranks)) 44 | 45 | # print(ensemble_fn.__name__, ensemble_args, model_mrr, ensemble_mrr) 46 | assert model_mrr < ensemble_mrr 47 | -------------------------------------------------------------------------------- /doc/development.md: -------------------------------------------------------------------------------- 1 | # Development (Graphcore) 2 | 3 | ## Contribution process 4 | 5 | To contribute code, please: 6 | - Open a Pull Request 7 | - Add a reviewer 8 | - When you get an LGTM, submitter merges 9 | - Merge-squash by default 10 | - Merge when required (e.g. 
branches-upon-branches) 11 | 12 | 13 | ## Coding guidelines 14 | 15 | A few principles that we feel are important: 16 | - Unless refactoring, keep with the existing pattern of code 17 | - Keep code simple by: 18 | - Minimising the need to reason about state (e.g. functional style) 19 | - Using descriptive, but not overly long, variable names 20 | - Ensuring methods/classes/etc have a single, easy-to-describe responsibility 21 | - Let's not worry about basic formatting - use the autoformatters clang-format & black, as configured for everyone 22 | - Try to avoid dependencies that aren't strictly necessary (both internal & external) 23 | - Keep trying to improve test coverage, but with as little testing code as possible 24 | 25 | 26 | ## Using VSCode 27 | 28 | We strongly recommend taking the time to sort out indexing & autocomplete in your IDE of choice. It should be possible to get suggestions for C++ and Python (including poplar_kge but not libpoplar_kge) using the following. 29 | 30 | ``` 31 | ln -s $POPLAR_SDK_ENABLED third_party/poplar 32 | 33 | # Add the following to .vscode/settings.json 34 | { 35 | "editor.formatOnSave": true, 36 | "python.defaultInterpreterPath": ".venv/bin/python", 37 | "C_Cpp.default.includePath": [ 38 | "${workspaceFolder}/src", 39 | "${workspaceFolder}/third_party/poplar/include", 40 | "${workspaceFolder}/third_party/pybind11/include", 41 | "${workspaceFolder}/third_party/catch2/single_include" 42 | ] 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /src/poplar_extensions/distance.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | #ifndef POPLAR_EXTENSIONS_DISTANCE_HPP 4 | #define POPLAR_EXTENSIONS_DISTANCE_HPP 5 | 6 | #include 7 | #include 8 | 9 | namespace poplar_extensions { 10 | 11 | poplar::Tensor l1distance(poplar::Graph& graph, 12 | const poplar::Tensor& a, 13 | const poplar::Tensor& b, 14 | poplar::program::Sequence& prog, 15 | const poplar::DebugContext& debugContext); 16 | 17 | poplar::Tensor l1distancegrad(poplar::Graph& graph, 18 | const poplar::Tensor& a, 19 | const poplar::Tensor& b, 20 | const poplar::Tensor& gradOutput, 21 | poplar::program::Sequence& prog, 22 | const poplar::DebugContext& debugContext); 23 | 24 | poplar::Tensor l2distance(poplar::Graph& graph, 25 | const poplar::Tensor& a, 26 | const poplar::Tensor& b, 27 | poplar::program::Sequence& prog, 28 | const poplar::DebugContext& debugContext); 29 | 30 | poplar::Tensor l2distancegrad(poplar::Graph& graph, 31 | const poplar::Tensor& a, 32 | const poplar::Tensor& b, 33 | const poplar::Tensor& dist, 34 | const poplar::Tensor& gradOutput, 35 | poplar::program::Sequence& prog, 36 | const poplar::DebugContext& debugContext); 37 | 38 | } // namespace poplar_extensions 39 | 40 | #endif // POPLAR_EXTENSIONS_DISTANCE_HPP 41 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | name: CI 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | schedule: 7 | - cron: "0 1 * * *" 8 | defaults: 9 | run: 10 | shell: bash 11 | jobs: 12 | # Workflow which cancels any previous jobs from the same PR (to conserve resources) 13 | cancel_previous: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Cancel Previous Workflows For Same PR 17 | uses: styfle/cancel-workflow-action@0.10.0 18 | with: 19 | access_token: ${{ github.token }} 20 | ci: 21 | runs-on: [self-hosted, applications, linux, pod] 22 | container: 23 | image: graphcore/pytorch:3.0.0-ubuntu-20.04 24 | options: --pull=always --ulimit memlock=-1:-1 --cap-add=IPC_LOCK --device=/dev/infiniband/ -e IPUOF_VIPU_API_HOST -e IPUOF_VIPU_API_PARTITION_ID --shm-size=128G 25 | steps: 26 | - name: Install Git 27 | run: apt-get update && apt-get install -y git 28 | - name: Checkout Repo 29 | uses: actions/checkout@v3 30 | with: 31 | submodules: true 32 | - name: Attach RDMA Network 33 | run: | 34 | python3 -m pip install docker 35 | python3 -c "import docker; client=docker.from_env(); client.networks.get('macvlan_rdma_swarm').connect(client.containers.get('${{ job.container.id }}'))" 36 | - name: Install requirements 37 | run: | 38 | export DEBIAN_FRONTEND=noninteractive 39 | apt-get install -yq $(cat requirements-apt.txt | tr "\n" " ") 40 | pip3 install -r requirements-dev.txt 41 | - name: Run CI 42 | run: ./dev ci 43 | slack_notify: 44 | name: Slack Notify 45 | needs: [ci] 46 | if: always() 47 | uses: graphcore/github-actions-internal/.github/workflows/slack_notify.yaml@main 48 | secrets: 49 | SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} 50 | CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID }} 51 | with: 52 | ci-result: ${{ needs.ci.result }} 53 | -------------------------------------------------------------------------------- /src/lib.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | // 3 | // Python bindings 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "poplar_kge.hpp" 10 | 11 | namespace py = pybind11; 12 | 13 | namespace { 14 | template 15 | void bindArray(py::module& m, const std::string& typeName) { 16 | using Src = py::array_t; 17 | using Dest = poplar_kge::ArrayView; 18 | std::ostringstream name; 19 | name << "ArrayView[" << typeName << "]"; 20 | py::class_(m, name.str().c_str()).def(py::init([](Src& data) { 21 | return Dest({data.shape(), data.shape() + data.ndim()}, data.mutable_data()); 22 | })); 23 | py::implicitly_convertible(); 24 | } 25 | void bindInt16ToFloat16Array(py::module& m) { 26 | // Hack - use the C++ type int16_t as a dummy placeholder for float16, then convert 27 | // to poplar_kge::float16 28 | using Src = py::array_t; 29 | using Dest = poplar_kge::ArrayView; 30 | py::class_(m, "ArrayView[float16]").def(py::init([](Src& data) { 31 | return Dest({data.shape(), data.shape() + data.ndim()}, 32 | reinterpret_cast(data.mutable_data())); 33 | })); 34 | py::implicitly_convertible(); 35 | } 36 | } // namespace 37 | 38 | PYBIND11_MODULE(libpoplar_kge, m) { 39 | m.doc() = 40 | "Implements the core model of Poplar-KGE, a very low-level interface (please use Python " 41 | "wrapper)"; 42 | bindArray(m, "float32"); 43 | bindArray(m, "uint32"); 44 | bindInt16ToFloat16Array(m); 45 | 46 | py::class_(m, "Engine", 47 | "Build graph, compile and load the executable onto device") 48 | .def(py::init(), py::arg("settings"), 49 | py::arg("gp_folder")) 50 | .def("run", &poplar_kge::Engine::run, "Execute a command (see Python for usage)", 51 | py::arg("command"), py::arg("data")); 52 | } 53 | -------------------------------------------------------------------------------- /scripts/run_training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | from typing import Dict, Union 4 | 5 | import numpy as np 6 | import poplar_kge as kge 7 | import poplar_kge_dataset as kge_data 8 | import poplar_kge_utility as kge_utility 9 | import torch 10 | 11 | # settings = kge.Settings.create_demo() 12 | 13 | settings = kge.Settings.create_wikikg90mv2() 14 | # settings.model.score_fn = "RotatE" # modify settings directly here 15 | 16 | 17 | # Main script 18 | 19 | settings.prepare() 20 | logger = kge_utility.Logger(settings, __file__) 21 | # Grab settings from `logger` in case they've been changed for a sweep 22 | settings = logger.settings 23 | 24 | data = kge_data.RawData.load(settings) 25 | logger.log("load_data", {}) 26 | 27 | dataset = kge_data.Dataset.load(data, settings) 28 | logger.log("build_index", {}) 29 | 30 | engine = kge.Engine(settings, dataset.shard_to_count) 31 | logger.log("compile", {}) 32 | 33 | engine.initialise_all(dataset.entity_features(settings.model.entity_feature_size)) 34 | logger.log("initialise", {}) 35 | 36 | ds = kge_data.DatasetWrapper(dataset) 37 | dl = torch.utils.data.DataLoader( 38 | ds, batch_size=None, num_workers=10, worker_init_fn=ds.worker_init_fn 39 | ) 40 | dl_iter = iter(dl) 41 | 42 | 43 | def validate() -> None: 44 | for part in ["train", "valid"]: 45 | logger.log(f"eval_{part}", {f"{part}_mrr": dataset.mrr(part, engine.predict)}) 46 | 47 | 48 | def predict(name: str) -> None: 49 | results: Dict[str, Union[int, np.ndarray]] = dict(step=logger.step) 50 | for part in ["valid", "test-dev", "test-challenge"]: 51 | entity, score = dataset.predict(part, engine.predict) 52 | results[part] = entity.astype(np.uint32) 53 | results[f"{part}-score"] = score.astype(np.float16) 54 | logger.savez(f"predictions_{name}.npz", results) 55 | 56 | 57 | for n in range(settings.logs_per_training_run): 58 | if n % settings.logs_per_validation == 0: 59 | validate() 60 | if n in settings.predict_at_log: 61 | predict(str(logger.step)) 62 | loss, lr = np.mean( 63 | [ 64 | 
engine.train_step_loop(ds.tensors_to_batch(**next(dl_iter))) 65 | for _ in range(settings.program_runs_per_log) 66 | ], 67 | axis=0, 68 | ) 69 | logger.log("train_step_loop", dict(loss=loss, learning_rate=lr)) 70 | validate() 71 | predict("final") 72 | -------------------------------------------------------------------------------- /scripts/run_profile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | """Profiling script.""" 4 | 5 | import argparse 6 | import json 7 | import os 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import poplar_kge as kge 12 | import poplar_kge_utility as kge_utility 13 | import test_poplar_kge 14 | 15 | 16 | def profile(settings: kge.Settings, run_predict: bool) -> None: 17 | if settings.logging.path: 18 | settings.logging.path.mkdir(parents=True, exist_ok=True) 19 | assert ( 20 | "POPLAR_ENGINE_OPTIONS" not in os.environ 21 | ), "POPLAR_ENGINE_OPTIONS should not be set outside the run_profile script" 22 | os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps( 23 | { 24 | "autoReport.directory": str(settings.logging.path), 25 | "autoReport.all": True, 26 | "autoReport.outputArchive": False, 27 | "autoReport.executionProfileProgramRunCount": 1, 28 | "profiler.replicaToProfile": 0, 29 | } 30 | ) 31 | 32 | logger = kge_utility.Logger(settings, __file__, truncate_train=False) 33 | engine = kge.Engine( 34 | settings, np.full(settings.model.n_shard, settings.model.n_entity - 1) 35 | ) 36 | logger.log("compile", {}) 37 | 38 | engine.initialise_variables() 39 | # hack to avoid having to run additional 'fill' programs 40 | engine.uninitialized_entity_data[:] = False 41 | logger.log("initialise", {}) 42 | 43 | train_batch = test_poplar_kge._random_batch(engine) 44 | predict_query = test_poplar_kge._random_prediction_query( 45 | engine, engine.settings.execution.predict_hr_batch_size 46 | ) 47 | logger.log("sample", {}) 48 | 49 | for _ 
in range(5): 50 | engine.train_step_loop(train_batch) 51 | logger.log("train_step_loop", {}) 52 | 53 | if run_predict: 54 | engine.predict(predict_query) 55 | logger.log("predict_step", {}) 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser(description=__doc__) 60 | parser.add_argument("output", nargs="?", type=Path) 61 | args = parser.parse_args() 62 | 63 | settings = kge.Settings.create_wikikg90mv2() 64 | 65 | settings.logging.wandb = False 66 | settings.logging.path = args.output 67 | settings.prepare() 68 | profile(settings, run_predict=False) 69 | -------------------------------------------------------------------------------- /tests/testPoplarKge.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "fructose/fructose.hpp" 11 | #include "poplar_kge.hpp" 12 | 13 | namespace { 14 | struct HostTensor { 15 | std::vector shape; 16 | std::vector data; 17 | 18 | HostTensor() = default; 19 | HostTensor(const std::vector& shape) 20 | : shape(shape), data(fr::util::numElements(shape)) {} 21 | 22 | float scalar() const { 23 | assert(shape.size() == 0); 24 | return data.front(); 25 | } 26 | }; 27 | struct TestHelper { 28 | poplar::Device device; 29 | fr::RootFrame rootFrame; 30 | std::unique_ptr engine; 31 | 32 | explicit TestHelper(unsigned nIpus = 1) 33 | : device(poplar::Device::createCPUDevice(nIpus)), rootFrame(device.getTarget()) {} 34 | 35 | void load() { 36 | engine.reset(new poplar::Engine(rootFrame.graph.poplar(), rootFrame.tape.prog())); 37 | engine->load(device); 38 | } 39 | std::unordered_map run() { 40 | std::unordered_map result; 41 | for (auto& item : rootFrame.streams) { 42 | assert(item.second.spec().dtype == poplar::FLOAT); 43 | result.insert({item.first, HostTensor(item.second.spec().shape)}); 44 | engine->connectStream(item.first, 
result[item.first].data.data()); 45 | } 46 | engine->run(); 47 | return result; 48 | } 49 | std::unordered_map loadAndRun() { 50 | load(); 51 | return run(); 52 | } 53 | }; 54 | } // namespace 55 | 56 | TEST_CASE("poplar_kge::detachedSoftmax", "[poplar_kge]") { 57 | TestHelper test; 58 | popops::addCodelets(test.rootFrame.graph.poplar()); 59 | 60 | auto xent = 61 | poplar_kge::detachedSoftmax(fr::ops::constant({1, 1, 2, 2, 4, 6}).reshape({2, 3})); 62 | fr::ops::output("xent", xent); 63 | 64 | using Catch::Matchers::Approx; 65 | auto result = test.loadAndRun(); 66 | auto norm1 = std::exp(1.0f) + std::exp(1.0f) + std::exp(2.0f); 67 | auto norm2 = std::exp(2.0f) + std::exp(4.0f) + std::exp(6.0f); 68 | REQUIRE_THAT( 69 | result["xent"].data, 70 | Approx({std::exp(1.0f) / norm1, std::exp(1.0f) / norm1, std::exp(2.0f) / norm1, 71 | std::exp(2.0f) / norm2, std::exp(4.0f) / norm2, std::exp(6.0f) / norm2})); 72 | } 73 | -------------------------------------------------------------------------------- /doc/design.md: -------------------------------------------------------------------------------- 1 | # The design of Poplar KGE, PAG & Fructose 2 | 3 | This document covers code structure and design; for a detailed description of the models and execution scheme, please see our [technical report](https://arxiv.org/abs/2211.12281). 4 | 5 | ## System architecture 6 | 7 | The core model & training code is written in a combination of C++ and Python. The C++ code is built using ninja, and automated using a custom standalone Python script called `./dev`. We recommend consulting the [./dev source code](../dev) if the build process is unclear or problematic. 8 | 9 | At runtime, the Python code requires `libpoplar_kge.so` from `build/`, `src/python` and `tests/python` to be available on the Python path. Optionally, the commands `./dev train` or `./dev python ...` can be used; these will check the C++ library is up-to-date before adding the required paths. 
10 | 11 | We use WandB to track experiments and manage sweeps; however, this is optional and the code can be run locally as required (using `settings.logging.wandb = False`). 12 | 13 | ![system architecture diagram, showing Poplar KGE application code at the centre, depending upon Poplar SDK, OGB dataset and PyTorch data, with an optional dependency on the wandb server](system_architecture.png) 14 | 15 | ## Internal architecture 16 | 17 | Internally the C++ code is divided up into "mini-libraries" which each have a simple responsibility. These are _not intended for direct reuse_, but may be a useful starting point for further work or similar projects. 18 | 19 | In the diagram below, the most application-specific code is shown in green, generic mini-libraries are shown in brown and external dependencies in blue. Briefly: 20 | 21 | - Poplar KGE (poplar_kge*.py, poplar_kge.[ch]pp) is a monolithic training system for distributed KGE training and evaluation. It contains all model-specific and training-specific code. 22 | - Fructose is a syntactic sugar and ease-of-use layer on top of PAG, providing an API similar to machine learning frameworks for constructing Poplar programs. 23 | - Poplar AutoGrad (PAG) is a low-level automatic differentiation layer for PopLibs. It presents a PopLibs-like interface that allows generation of the backwards pass of composed PopLibs functions using the chain rule. 
24 | 25 | ![architecture diagram, showing a chain of dependencies: src/python/poplar_kge.py (batching, initialisation, persistence, logging/wandb), then lib.cpp (generic binding), poplar_kge.cpp (define core programs: train, predict, write_entity, read_entity), Fructose (syntactic sugar via global state), Poplar AutoGrad (define gradients for poplibs ops) and finally the external dependency on Poplar/PopLibs](architecture.png) 26 | -------------------------------------------------------------------------------- /src/poplar_kge.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | #ifndef POPLAR_KGE_HPP 4 | #define POPLAR_KGE_HPP 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "fructose/fructose.hpp" 12 | 13 | namespace poplar_kge { 14 | 15 | fr::Tensor detachedSoftmax(const fr::Tensor& tensor); 16 | 17 | /** 18 | * A C-contiguous N-dimensional array view (does not own 'data'). 19 | */ 20 | template 21 | struct ArrayView { 22 | ArrayView(const std::vector& shape, T* data); 23 | 24 | const std::vector& shape() const; 25 | const T* data() const; 26 | T* data(); 27 | 28 | private: 29 | std::vector m_shape; 30 | T* m_data; 31 | }; 32 | 33 | struct float16 { 34 | uint16_t value; 35 | }; 36 | static_assert(sizeof(float16) == 2); 37 | 38 | using Batch = 39 | std::unordered_map, 44 | ArrayView, 45 | ArrayView, 46 | bool, 47 | float, 48 | unsigned, 49 | std::vector>>, 50 | std::unordered_map>>; 51 | 52 | struct EngineImpl; 53 | 54 | /** 55 | * Interface to training engine. 56 | */ 57 | struct Engine { 58 | /** 59 | * Construct an engine to train & evaluate the KGE model. 60 | * 61 | * See lib.cpp:engineDocstring for full details. 62 | */ 63 | Engine(const Batch& settings, const std::string& gpFolder); 64 | ~Engine(); 65 | 66 | /** 67 | * Run a command, e.g. "train_step_loop", "read", etc. 
def mean_ensemble(
    predictions: np.ndarray,
    count: int,
    power: float,
    default_rank: Optional[float] = None,
    model_weight: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Sort entities by mean(score), where score is -sign(power) * rank ** power.

    predictions -- uint[n_model x n_example x n_prediction]

    count -- number of predictions to return, must be <= n_prediction

    power -- exponent applied to the 1-based rank of each prediction;
             power == 0 scores entities by the number of models that
             predicted them

    default_rank -- rank to assume when a model is not found
                    default (power < 0) = infinity
                    default (power > 0) = 1 + n_prediction

    model_weight -- optional float[n_model], multiplies the rank scores of
                    each model; only supported when the default score is 0

    returns -- uint[n_example x count]
    """
    n_model, n_example, n_prediction = predictions.shape

    # Score contributed by a model that did not predict a given entity
    if power == 0.0:
        default_score = -1.0
    elif default_rank is not None:
        default_score = -np.sign(power) * default_rank**power
    elif power < 0:
        default_score = 0.0
    elif power > 0:
        default_score = -((1 + n_prediction) ** power)
    else:
        # e.g. NaN, which previously fell through and left default_score unbound
        raise ValueError(f"Invalid power {power!r}, expected a real number")

    if model_weight is not None and default_score != 0:
        raise NotImplementedError(
            "Model weighting only implemented for default_rank = inf."
        )

    # rank_scores[model, rank_index] -- score for each model's k-th prediction
    rank_scores = (
        -np.sign(power) * (1 + np.arange(n_prediction, dtype=np.float32)) ** power
    )
    rank_scores = np.tile(rank_scores[np.newaxis, :], (n_model, 1))
    if model_weight is not None:
        rank_scores *= model_weight[:, np.newaxis]

    results = []
    for idx in range(n_example):
        ids, indices, counts = np.unique(
            predictions[:, idx, :].reshape(-1), return_inverse=True, return_counts=True
        )
        # Each entity is missing from (n_model - counts) models, each of which
        # contributes default_score. (The previous `n_prediction - counts` only
        # differed by a constant offset per example, so the ordering is the same.)
        scores = (n_model - counts) * default_score
        # Unbuffered add, as `indices` repeats for entities predicted by many models
        np.add.at(scores, indices, rank_scores.reshape(-1))
        results.append(ids[np.argsort(scores)[::-1][:count]])
    return np.stack(results)


def median_ensemble(predictions: np.ndarray, count: int) -> np.ndarray:
    """Sort entities by median(1/rank).

    An entity that a model did not predict is given a reciprocal rank of 0
    for that model.

    predictions -- uint[n_model x n_example x n_prediction]

    count -- number of predictions to return, must be <= n_prediction

    returns -- uint[n_example x count]
    """
    n_model, n_example, n_prediction = predictions.shape
    # rrank[model, rank_index] -- reciprocal of the 1-based rank
    rrank = np.tile(1 + np.arange(n_prediction, dtype=np.float32), (n_model, 1)) ** -1

    results = []
    for idx in range(n_example):
        ids, indices = np.unique(
            predictions[:, idx, :].reshape(-1), return_inverse=True
        )
        # Dense (model, unique entity) table of reciprocal ranks, zero when missing
        padded_rrank = np.zeros((n_model, len(ids)), dtype=np.float32)
        padded_rrank[
            np.arange(n_model)[:, np.newaxis], indices.reshape(n_model, n_prediction)
        ] = rrank
        scores = np.median(padded_rrank, axis=0)
        results.append(ids[np.argsort(scores)[::-1][:count]])
    return np.stack(results)
namespace fr::nn {

///////////////////////////////////////////////////////////////////////////////
// Ops

// Rectified linear unit: elementwise max(tensor, 0), with the zero constant
// created in the same dtype as `tensor`.
Tensor relu(const Tensor& tensor) {
    Frame f("frnn::relu");
    return ops::max(tensor, ops::constant(0.0f, tensor.dtype()));
}

// Softmax cross entropy of `logits` against class indices `labels`.
//
// Computes -sum(logSoftmax(logits) * oneHot(labels)), reducing over the last
// axis of `logits`.
Tensor softmaxCrossEntropy(const Tensor& logits, const Tensor& labels) {
    Frame f("frnn::softmaxCrossEntropy");
    auto logp = ops::logSoftmax(logits);
    // The one-hot width is the size of the last axis of the log-probabilities
    auto oneHotLabels = ops::oneHot(labels, logp.shape().back(), logp.dtype());
    return -ops::sum(logp * oneHotLabels, {logp.rank() - 1});
}

// Dropout on `a` with drop probability `dropProbability`, delegating to
// pag::ops::dropout. `a` is given a linear default tile mapping first.
Tensor dropout(const Tensor& a, float dropProbability) {
    Frame f("frnn::dropout");
    mapping::setDefault(mapping::Linear(), {a});
    return Tensor::wrap(pag::ops::dropout(f.graph, a.pag(), dropProbability, f.tape, f.di));
}

///////////////////////////////////////////////////////////////////////////////
// Optimisers

// In-place SGD step: subtracts `learningRate` * grad(tensor) from `tensor`
// using popops::scaledSubtractFrom. The program is appended to the frame's tape.
void sgd(const Tensor& tensor, const Tensor& learningRate) {
    Frame f("frnn::sgd");
    auto grad = tensor.grad();
    popops::scaledSubtractFrom(f.graph.poplar(), f.graph.unwrap(tensor.pag()),
                               f.graph.unwrap(grad.pag()), f.graph.unwrap(learningRate.pag()),
                               f.tape.prog(), f.di);
}

// Increment `step` in place by one, then return the bias-corrected Adam step size
//
//     learningRate * sqrt(1 - betaV^step) / (1 - betaM^step)
//
// evaluated using the incremented step. The result does not require a gradient.
Tensor adamStepSizeAutoIncrement(const Tensor& step,
                                 const Tensor& learningRate,
                                 const AdamParams& params) {
    Frame f("frnn::adamStepSizeAutoIncrement");
    mapping::setDefault(mapping::OneTile(), {step, learningRate});
    // step += 1 (in place, so the caller's step variable is updated)
    popops::addInPlace(f.graph.poplar(), f.graph.unwrap(step.pag()),
                       f.graph.unwrap(ops::constant(1u, step.dtype()).pag()), f.tape.prog(), f.di);

    namespace pe = popops::expr;
    // Placeholders: _1 = step (cast to float), _2 = learningRate
    auto numerator =
        pe::Sqrt(1 - pe::Pow(pe::Const(params.betaV), pe::Cast(pe::_1, poplar::FLOAT)));
    auto denominator = 1 - pe::Pow(pe::Const(params.betaM), pe::Cast(pe::_1, poplar::FLOAT));
    return Tensor::wrap(
        f.graph.wrap(popops::map(f.graph.poplar(), (numerator / denominator) * pe::_2,
                                 {f.graph.unwrap(step.pag()), f.graph.unwrap(learningRate.pag())},
                                 f.tape.prog(), f.di),
                     /*requiresGrad*/ false));
}

// In-place Adam update of `tensor` and its optimiser state.
//
//     momentum = betaM * momentum + (1 - betaM) * grad
//     variance = betaV * variance + (1 - betaV) * grad^2
//     tensor   = tensor - stepSize * (momentum / (sqrt(variance) + epsilon)
//                                     + weightDecay * tensor)
//
// `stepSize` is used as given; no bias correction is applied here (see
// adamStepSizeAutoIncrement, which computes a bias-corrected step size).
void adam(const Tensor& tensor,
          const Tensor& momentum,
          const Tensor& variance,
          const Tensor& stepSize,
          const AdamParams& params) {
    Frame f("frnn::adam");
    assert(tensor.valid() && "`tensor` should have a grad, so must have been mapped");

    // {tensor, momentum, variance} are used elementwise, so give them the same default mapping
    mapping::setDefault(mapping::Copy(f.graph.unwrap(tensor.pag())), {momentum, variance});

    namespace pe = popops::expr;

    auto grad = tensor.grad();
    // Momentum update: _1 = momentum, _2 = grad
    popops::mapInPlace(
        f.graph.poplar(), pe::Const(params.betaM) * pe::_1 + pe::Const(1 - params.betaM) * pe::_2,
        {f.graph.unwrap(momentum.pag()), f.graph.unwrap(grad.pag())}, f.tape.prog(), f.di);

    // Variance update: _1 = variance, _2 = grad
    popops::mapInPlace(
        f.graph.poplar(),
        pe::Const(params.betaV) * pe::_1 + pe::Const(1 - params.betaV) * pe::_2 * pe::_2,
        {f.graph.unwrap(variance.pag()), f.graph.unwrap(grad.pag())}, f.tape.prog(), f.di);

    // Parameter update: _1 = tensor, _2 = stepSize, _3 = momentum, _4 = variance
    auto update = pe::_1 - pe::_2 * (pe::_3 / (pe::Sqrt(pe::_4) + pe::Const(params.epsilon)) +
                                     pe::Const(params.weightDecay) * pe::_1);
    popops::mapInPlace(f.graph.poplar(), update,
                       {f.graph.unwrap(tensor.pag()), f.graph.unwrap(stepSize.pag()),
                        f.graph.unwrap(momentum.pag()), f.graph.unwrap(variance.pag())},
                       f.tape.prog(), f.di);
}

}  // namespace fr::nn
// End-to-end check of Fructose: train a small MLP with Adam on a CPU device to
// learn modulo-N addition, and require that the final loss is small.
TEST_CASE("Small training example", "[fr]") {
    auto N = 10u;
    auto hidden_size = 64u;
    std::mt19937 rng(3248924);

    // Dataset - modulo addition
    // All N*N pairs (a, b) with target y = (a + b) % N
    std::vector dataA, dataB, dataY;
    for (auto i = 0u; i < N * N; ++i) {
        auto a = i % N;
        auto b = i / N;
        auto y = (a + b) % N;
        dataA.push_back(a);
        dataB.push_back(b);
        dataY.push_back(y);
    }

    // Build graph
    auto device = poplar::Device::createCPUDevice();
    fr::RootFrame rootFrame(device.getTarget());
    popops::addCodelets(rootFrame.graph.poplar());
    poplin::addCodelets(rootFrame.graph.poplar());

    auto a = fr::ops::input("a", {{N * N}, poplar::UNSIGNED_INT});
    auto b = fr::ops::input("b", {{N * N}, poplar::UNSIGNED_INT});
    auto y = fr::ops::input("y", {{N * N}, poplar::UNSIGNED_INT});

    // Model: shared embedding of both operands, summed, then one hidden layer
    // and a bias-free projection to N output classes
    auto embedding = fr::ops::variable("embedding", {{N, hidden_size}, poplar::FLOAT});
    auto z = fr::nn::relu(fr::ops::gather(embedding, a) + fr::ops::gather(embedding, b));

    auto hiddenW = fr::ops::variable("hidden/W", {{hidden_size, hidden_size}, poplar::FLOAT});
    auto hiddenB = fr::ops::variable("hidden/b", {{hidden_size}, poplar::FLOAT});
    z = fr::nn::relu(fr::ops::matMul(z, hiddenW) + hiddenB);

    auto projectionW = fr::ops::variable("projection/W", {{hidden_size, N}, poplar::FLOAT});
    auto logits = fr::ops::matMul(z, projectionW);
    auto loss = fr::ops::mean(fr::nn::softmaxCrossEntropy(logits, y));

    loss.backward();
    fr::ops::output("loss", loss);

    // Optimiser: one shared step counter / step size, with per-variable Adam state
    fr::nn::AdamParams adamParams{/*betaM*/ 0.9f, /*betaV*/ 0.999f, /*epsilon*/ 1e-8f,
                                  /*weightDecay*/ 0.0f};
    auto step = fr::ops::variable("step", {{}, poplar::UNSIGNED_INT});
    auto adamStepSize =
        fr::nn::adamStepSizeAutoIncrement(step, fr::ops::constant(0.01f), adamParams);
    for (auto& tensor : {embedding, hiddenW, hiddenB, projectionW}) {
        auto momentum =
            fr::ops::variable(tensor.name() + "/adam_m", tensor.spec(), /*requiresGrad*/ false);
        auto variance =
            fr::ops::variable(tensor.name() + "/adam_v", tensor.spec(), /*requiresGrad*/ false);
        fr::nn::adam(tensor, momentum, variance, adamStepSize, adamParams);
        // Host access is needed so the engine can write initial values below
        tensor.hostAccess();
        momentum.hostAccess();
        variance.hostAccess();
    }
    step.hostAccess();

    // Initialise
    poplar::Engine engine(rootFrame.graph.poplar(), rootFrame.tape.prog());
    engine.load(device);
    engine.writeTensor("step", {0u});
    for (auto& tensor : {embedding, hiddenW, hiddenB, projectionW}) {
        std::vector init(tensor.numElements());
        // Per-variable scale of the normal initialiser (zero for the bias)
        auto scale = std::unordered_map{
            {"embedding", 1.0f},
            {"hidden/W", 1.0f / std::sqrt(hidden_size)},
            {"hidden/b", 0.0f},
            {"projection/W", 1.0f / std::sqrt(hidden_size)}}[tensor.name()];
        std::generate(init.begin(), init.end(),
                      [&rng, scale] { return scale * std::normal_distribution()(rng); });
        engine.writeTensor(tensor.name(), init);
        // Adam state starts at zero (value-initialised vectors)
        engine.writeTensor(tensor.name() + "/adam_m", std::vector(init.size()));
        engine.writeTensor(tensor.name() + "/adam_v", std::vector(init.size()));
    }

    // Run
    // Each engine.run() executes one full-batch training step; hostLoss holds
    // the loss streamed out by the most recent step
    float hostLoss;
    engine.connectStream("loss", &hostLoss);
    engine.connectStream("a", dataA);
    engine.connectStream("b", dataB);
    engine.connectStream("y", dataY);
    for (auto i = 0u; i < 30; ++i) {
        engine.run();
    }
    REQUIRE(hostLoss < 0.03f);
}
def _json_default(obj: Any) -> str:
    """Fallback for `json.dumps(..., default=)`, serialising `Path` as a string.

    Raises ValueError for any other type that json cannot serialise.
    """
    if isinstance(obj, Path):
        return str(obj)
    raise ValueError(f"Value {obj} of type {type(obj)} is not JSON-serialisable")


def recursive_replace(d: kge.Settings, u: Dict[str, Any]) -> kge.Settings:
    """Update the (nested) dataclass `d` in place from the (nested) mapping `u`.

    Where the current attribute value is a dataclass, the corresponding entry
    of `u` is applied recursively; otherwise the attribute is overwritten.

    returns -- `d` (the same object, after modification)
    """
    for k, v in u.items():
        if dataclasses.is_dataclass(getattr(d, k)):
            setattr(d, k, recursive_replace(getattr(d, k), v))
        else:
            setattr(d, k, v)
    return d


class Logger:
    """Log to wandb, file and stderr."""

    def __init__(
        self, settings: kge.Settings, code_path: str, truncate_train: bool = True
    ):
        """Set up the enabled log sinks and record the "start" event.

        settings -- experiment settings; `settings.logging.path` enables the
                    file log and `settings.logging.wandb` enables wandb

        code_path -- path passed to `wandb.run.log_code`

        truncate_train -- if set, only print the first of each run of
                          consecutive "train_step_loop" events to stderr
        """
        self.settings = settings
        self.truncate_train = truncate_train
        self.last_t = time.time()
        self.last_event = "start"
        self.step = 0
        self.log_file = None

        if settings.logging.path:
            settings.logging.path.mkdir(parents=True, exist_ok=True)
            (settings.logging.path / "app.json").write_text(
                json.dumps(settings.flatten(), default=_json_default)
            )
            self.log_file = (settings.logging.path / "log.jsonl").open("w")
            self._write_log(
                dict(
                    event="start",
                    start_time=datetime.datetime.now().isoformat(),
                    settings=settings.flatten(),
                )
            )
        if settings.logging.wandb:
            wandb.init(
                entity="ogb-wikiwiki",
                project="poplar-kge-v2",
                config=dataclasses.asdict(settings),
                dir=tempfile.gettempdir(),
            )
            # If being run from a sweep agent, this will update the settings
            # to those specified in the sweep. If being run from a regular
            # training run, this is a no-op
            self.settings = recursive_replace(settings, wandb.config)
            wandb.run.log_code(code_path)  # type:ignore[union-attr]
        print(f"[start] {self.settings.flatten()}", file=sys.stderr, flush=True)

        # Computed after any sweep overrides have been applied
        self.steps_per_log = (
            self.settings.execution.train_steps_per_program_run
            * self.settings.program_runs_per_log
        )
        self.samples_per_step = (
            self.settings.model.n_shard * self.settings.data.batch_size
        )

    def _write_log(self, data: Dict[str, Any]) -> None:
        """Append `data` as one JSON line to the log file, flushing immediately."""
        assert self.log_file
        self.log_file.write(json.dumps(data, default=_json_default) + "\n")
        self.log_file.flush()

    def log(self, event: str, data: Dict[str, Any]) -> None:
        """Log `data` for `event` to every enabled sink.

        The time since the previous call is logged alongside `data`. A
        "train_step_loop" event advances the step counter by `steps_per_log`
        before logging.
        """
        data = data.copy()
        elapsed = time.time() - self.last_t
        self.last_t += elapsed
        if event == "train_step_loop":
            self.step += self.steps_per_log
        sample = self.samples_per_step * self.step

        if self.settings.logging.wandb:
            wandb.log(
                {**data, f"{event}_time": elapsed, "step": self.step, "sample": sample},
                step=self.step,
            )

        if self.log_file:
            self._write_log(
                dict(
                    event=event,
                    step=self.step,
                    time=elapsed,
                    sample=sample,
                    data=data,
                )
            )

        if not self.truncate_train or not (
            event == self.last_event == "train_step_loop"
        ):
            # Use the same fallback serialiser as the file log, so that values
            # accepted there (e.g. Path) cannot make the console log raise
            message = json.dumps(data, default=_json_default)
            print(
                f"[#{self.step // 1000:>06d}k {event} : {elapsed:.3f} s] {message}",
                file=sys.stderr,
                flush=True,
            )

        self.last_event = event

    def savez(self, name: str, data: Dict[str, Union[int, np.ndarray]]) -> None:
        """Save `data` with `np.savez` as `name`, to the wandb run directory
        and/or the local logging path (whichever are enabled)."""
        if self.settings.logging.wandb:
            np.savez(Path(wandb.run.dir) / name, **data)  # type:ignore[union-attr]
        if self.settings.logging.path:
            np.savez(self.settings.logging.path / name, **data)
// AllToAll
//
// A PopART custom op ("ai.graphcore" / "AllToAll", version 1) that is lowered to
// gcl::allToAllCrossReplica.

namespace Onnx::CustomOperators {
const popart::OperatorIdentifier AllToAll = {"ai.graphcore", "AllToAll", 1};
}  // namespace Onnx::CustomOperators

namespace {
namespace all_to_all {
// IR-level op: one input, one output with the same tensor info. The op is its
// own gradient, so the same class is used for the backward pass.
struct CustomOp : popart::Op {
    CustomOp(const popart::OperatorIdentifier& _opid, const popart::Op::Settings& settings_)
        : popart::Op(_opid, settings_) {}
    std::unique_ptr clone() const final { return std::make_unique(*this); }
    float getSubgraphValue() const final { return getLowSubgraphValue(); }
    void setup() { outInfo(0) = inInfo(0); }  // shape inference
    std::vector> getGradOps() {
        std::vector> result;
        result.emplace_back(new CustomOp(*this));  // grad(allToAll) == allToAll
        return result;
    }
    // When used as a grad op: input 0 is the gradient of forward output 0
    const std::vector& gradInputInfo() const {
        static const std::vector inInfo = {
            {0, 0, popart::GradOpInType::GradOut}};
        return inInfo;
    }
    // When used as a grad op: output 0 is the gradient of forward input 0
    const std::map& gradOutToNonGradIn() const {
        static const std::map outInfo = {{0, 0}};
        return outInfo;
    }
};

// Poplar lowering of CustomOp
struct CustomOpx : popart::popx::Opx {
    CustomOpx(popart::Op* op, popart::popx::Devicex* devicex) : popart::popx::Opx(op, devicex) {
        verifyOp(op, Onnx::CustomOperators::AllToAll);
    }
    // Emit the cross-replica all-to-all into `prog` and register its result as output 0
    void grow(poplar::program::Sequence& prog) const final {
        auto input = get(inId(0));
        auto output = gcl::allToAllCrossReplica(graph(), input, prog, {}, debugContext("allToAll"));
        insert(outId(0), output);
    }
};

// Static registration of the op definition (float16/float input and output, no
// attributes) and of its lowering.
popart::OpDefinition::DataTypes T = {popart::DataType::FLOAT16, popart::DataType::FLOAT};
popart::OpCreator opCreator(
    {{Onnx::CustomOperators::AllToAll,
      {popart::OpDefinition::Inputs({{"input", T}}), popart::OpDefinition::Outputs({{"output", T}}),
       popart::OpDefinition::Attributes({})}}},
    [](const popart::OpCreatorInfo& info) {
        return std::make_unique(info.opid, info.settings);
    },
    true);
popart::popx::OpxCreator opxCreator(Onnx::CustomOperators::AllToAll);
}  // namespace all_to_all
}  // namespace

// RemoveAllReducePattern
//
// Graph pattern that removes an all-reduce op whose replica group size is 1,
// rewiring its consumers to read the op's input tensor directly.

namespace {
struct RemoveAllReducePattern : popart::PreAliasPattern {
    bool matches(popart::Op* op) const override {
        return op->isConvertibleTo();
    }

    std::vector touches(popart::Op*) const override { return {}; }

    // Returns true if `op` was removed from the graph, false if it was left alone
    bool apply(popart::Op* op) const override {
        auto rar_op = static_cast(op);
        if (rar_op->getReplicaGrouping().getGroupSize() == 1) {
            popart::Tensor* in_rar = rar_op->inTensor(popart::ReplicatedAllReduceOp::getInIndex());
            popart::Tensor* out_rar =
                rar_op->outTensor(popart::ReplicatedAllReduceOp::getOutIndex());
            // std::cerr << "Removing ReplicatedAllReduceOp with groupSize=1: " << in_rar->id
            //           << std::endl;
            // Reconnect every consumer of the output to the input tensor instead.
            // NOTE(review): disconnectInTensor(out_rar) is called while iterating
            // cons->input->indices(out_rar); if indices() returns a reference into
            // the op's input map this could invalidate the iteration - confirm
            // against the PopART API.
            for (auto cons : out_rar->consumers.getOps()) {
                for (auto in_index : cons->input->indices(out_rar)) {
                    cons->disconnectInTensor(out_rar);
                    cons->connectInTensor(in_index, in_rar->id);
                }
            }
            op->disconnectAllInputs();
            op->disconnectAllOutputs();
            op->getGraph().eraseOp(rar_op->id);
            return true;
        }
        return false;
    }
};

// Static registration of the pattern under the name "RemoveAllReducePattern"
static popart::PatternCreator RemoveAllReducePatternCreator(
    "RemoveAllReducePattern",
    false);
}  // namespace
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed KGE (C++) 2 | 3 | IPU implementation of a sharded knowledge graph embedding (KGE) model, implemented in Poplar for execution using DRAM on an IPU-POD16. 4 | 5 | Note that this is a low-level implementation for advanced IPU usage. 6 | 7 | _See also: [PyTorch KGE demo notebook](doc/pytorch_demo)._ 8 | 9 | 10 | ## Usage 11 | 12 | ### First-time setup 13 | 14 | 0. Ensure `clang++` and `ninja` are installed. 15 | 1. Clone this repository with `--recurse-submodules`. 16 | 2. Install Poplar SDK and activate with `source $POPLAR_SDK_DIR/enable`. 17 | 3. Create and activate a Python virtual environment. 18 | 4. Install Python requirements `pip install -r requirements-dev.txt` 19 | 5. Check everything is working by running `./dev` (see also `./dev --help`). 20 | 21 | For example: 22 | 23 | ```sh 24 | sudo apt-get install clang ninja-build 25 | git clone --recurse-submodules REPO 26 | source $POPLAR_SDK_DIR/enable 27 | virtualenv -p python3 .venv 28 | source .venv/bin/activate 29 | pip install -r requirements-dev.txt 30 | ./dev --help 31 | ./dev 32 | ``` 33 | 34 | ### Training 35 | 36 | Our standard training script is in [scripts/run_training.py](scripts/run_training.py). To build the core C++ code, add it to the path and run training, 37 | 38 | ```sh 39 | ./dev train 40 | ``` 41 | 42 | This trains a TransE model with embedding size 256. 43 | 44 | Note: 45 | - Build and development automation is provided by the `./dev` script, which generates a ninja build file (`build/build.ninja`). 46 | - You may wish to change C++ compiler, e.g. `env CXX=g++ ./dev ...` 47 | - The training script expects the OGB WikiKG90Mv2 dataset to be downloaded to `$OGBWIKIKG_PATH`; see the [OGB WikiKG90Mv2 page](https://ogb.stanford.edu/docs/lsc/wikikg90mv2/) for instructions. 
48 | 49 | 50 | ## About 51 | 52 | The application is a self-contained research platform for KGE models, using Poplar/PopLibs directly for execution on IPU, PyTorch for data loading and numpy for batching and interchange. Since model checkpoints would be very large, all training, evaluation and prediction tasks are run in a single job via `run_training.py`. 53 | 54 | The main components are: 55 | 56 | - [scripts/{run_training.py, run_profile.py}](scripts/) - top-level entry points, note that we use Python configuration in place of a command line interface 57 | - Core model & training 58 | - [src/poplar_kge.cpp](src/poplar_kge.cpp) - core model and training step definition 59 | - [src/python/poplar_kge.py](src/python/poplar_kge.py) - Python glue code & experiment settings 60 | - [src/python/poplar_kge_dataset.py](src/python/poplar_kge_dataset.py) - data sampling & batching 61 | - Library-like components 62 | - [src/pag/](src/pag/) - Poplar AutoGrad (PAG), a self-contained mini-library for adding automatic differentiation to PopLibs programs 63 | - [src/fructose/](src/fructose/) - Fructose, a self-contained mini-library for a friendly, noise-free interface to PAG 64 | - [src/poplar_extensions/](src/poplar_extensions/) - custom device codelets, with a PopLibs-like interface, for efficient L1/L2 distance 65 | 66 | See also [doc/design.md](doc/design.md) for a more detailed description of the design of the application. 67 | 68 | ### Poplar remote buffers 69 | 70 | We rely on Poplar's access to streaming memory in this code (see [IPU memory architecture](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/latest/about_ipu.html#memory-architecture)), which enables sparse access to a much larger memory store. This is accessed via the [remote memory buffers](https://docs.graphcore.ai/projects/poplar-user-guide/en/latest/poplar_programs.html#remote-memory-buffers) API. 
71 | 72 | One implementation detail of interest is that we stack all remote embedding state (consisting of entity features, embeddings and optimiser state) into a single remote buffer, which helps to minimise memory overhead due to padding. 73 | 74 | ### References & license 75 | 76 | The included code is released under an MIT license (see [LICENSE](LICENSE)). 77 | 78 | Copyright (c) 2022 Graphcore Ltd. Licensed under the MIT License 79 | 80 | Our dependencies are: 81 | 82 | | Component | Type | About | License | 83 | | --- | --- | --- | --- | 84 | | pybind11 | submodule | C++/Python interop library ([github](https://github.com/pybind/pybind11)) | BSD 3-Clause | 85 | | Catch2 | submodule | C++ unit testing framework ([github](https://github.com/catchorg/Catch2)) | Boost | 86 | | OGB | `requirements.txt` | Open Graph Benchmark dataset and task definition ([paper](https://arxiv.org/abs/2103.09430), [website](https://ogb.stanford.edu/)) | MIT | 87 | | PyTorch | `requirements.txt` | Machine learning framework ([website](https://pytorch.org/)) | BSD 3-Clause | 88 | | WandB | `requirements.txt` | Weights and Biases client library ([website](https://wandb.ai/)), for optional logging to wandb servers | MIT | 89 | 90 | We also use ninja ([website](https://ninja-build.org/)) with clang++ from LLVM ([website](https://clang.llvm.org/)) to build C++ code and additional Python dependencies for development/testing (see [requirements-dev.txt](requirements-dev.txt)). 91 | 92 | The OGB WikiKG90Mv2 dataset is licensed under CC-0. 93 | -------------------------------------------------------------------------------- /tests/fructose/testFructose.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
namespace {
// Host-side copy of a tensor: a shape plus flat element storage sized to match.
struct HostTensor {
    std::vector shape;
    std::vector data;

    HostTensor() = default;
    HostTensor(const std::vector& shape)
        : shape(shape), data(fr::util::numElements(shape)) {}

    // Value of a rank-0 tensor (asserts that the shape is empty)
    float scalar() const {
        assert(shape.size() == 0);
        return data.front();
    }
};

// Test fixture owning a CPU device, the Fructose root frame and (after load())
// the engine that runs the program recorded on the root frame's tape.
struct TestHelper {
    poplar::Device device;
    fr::RootFrame rootFrame;
    std::unique_ptr engine;

    explicit TestHelper(unsigned nIpus = 1)
        : device(poplar::Device::createCPUDevice(nIpus)), rootFrame(device.getTarget()) {}

    // Build an engine from everything recorded so far and load it onto the device
    void load() {
        engine.reset(new poplar::Engine(rootFrame.graph.poplar(), rootFrame.tape.prog()));
        engine->load(device);
    }
    // Run the program once and return the contents of every stream registered on
    // the root frame, keyed by stream name. Only FLOAT streams are supported.
    std::unordered_map run() {
        std::unordered_map result;
        for (auto& item : rootFrame.streams) {
            assert(item.second.spec().dtype == poplar::FLOAT);
            result.insert({item.first, HostTensor(item.second.spec().shape)});
            engine->connectStream(item.first, result[item.first].data.data());
        }
        engine->run();
        return result;
    }
    std::unordered_map loadAndRun() {
        load();
        return run();
    }
};
}  // namespace

// With logits {-100, 2, 2, 2, 2} and label 1, the expected loss is log(4)
// (four equal logits dominate; the -100 logit is negligible).
TEST_CASE("fr::nn::softmaxCrossEntropy", "[fr]") {
    TestHelper test;
    popops::addCodelets(test.rootFrame.graph.poplar());

    auto xent = fr::nn::softmaxCrossEntropy(fr::ops::constant({-100, 2, 2, 2, 2}),
                                            fr::ops::constant(1u));
    fr::ops::output("xent", xent);

    auto result = test.loadAndRun();
    REQUIRE(result["xent"].scalar() == Approx(std::log(4)));
}

// c = (a + a) + (a + a) = 4a, so dc/da = 4 and dc/db = 2 where b = a + a
TEST_CASE("fructose basics", "[fr]") {
    TestHelper test;

    auto a = fr::ops::variable("a", {{}, poplar::FLOAT});
    auto b = a + a;
    auto c = b + b;
    c.backward();
    fr::ops::output("c", c);
    fr::ops::output("grad_a", a.grad());
    fr::ops::output("grad_b", b.grad());
    a.hostAccess();

    test.load();
    test.engine->writeTensor(a.name(), {10.0f});
    auto result = test.run();
    REQUIRE(result["c"].scalar() == Approx(40.0f));
    REQUIRE(result["grad_a"].scalar() == Approx(4.0f));
    REQUIRE(result["grad_b"].scalar() == Approx(2.0f));
}

namespace {
// For each element of `tensor`, the tile it is mapped to in the current frame's graph.
std::vector getTensorToTileMapping(const fr::Tensor& tensor) {
    auto& frame = fr::Environment::frame();
    auto tileMapping = frame.graph.poplar().getTileMapping(frame.graph.unwrap(tensor.pag()));
    std::vector tensorMapping(tensor.numElements());
    // tileMapping[tile] is a list of element intervals [begin, end) held on that tile
    for (auto tile = 0u; tile < tileMapping.size(); ++tile) {
        for (auto interval : tileMapping[tile]) {
            for (auto i = interval.begin(); i < interval.end(); ++i) {
                tensorMapping[i] = tile;
            }
        }
    }
    return tensorMapping;
}
}  // namespace

// On a 6-tile target, OneTile indices wrap modulo the tile count:
// -1 -> tile 5, 9 -> tile 3, -8 -> tile 4.
TEST_CASE("fr::mapping::OneTile", "[fr]") {
    fr::RootFrame frame(poplar::Target::createIPUTarget(1, 6, "IPU-POD16"));

    auto a = fr::ops::variable("a", {{2}, poplar::FLOAT});
    fr::mapping::setDefault(fr::mapping::OneTile(), {a});  // default = -1
    REQUIRE(getTensorToTileMapping(a) == std::vector{5, 5});

    auto b = fr::ops::variable("b", {{2}, poplar::FLOAT});
    fr::mapping::setDefault(fr::mapping::OneTile(9), {b});  // wraparound
    REQUIRE(getTensorToTileMapping(b) == std::vector{3, 3});

    auto c = fr::ops::variable("c", {{2}, poplar::FLOAT});
    fr::mapping::setDefault(fr::mapping::OneTile(-8), {c});  // negative wraparound
    REQUIRE(getTensorToTileMapping(c) == std::vector{4, 4});
}
def _get_triples(
    dataset: kge_data.Dataset, batch: kge.Batch
) -> Tuple[np.ndarray, np.ndarray]:
    """Recover global-entity (h, r, t) triples and negative tails from a batch.

    The batch addresses entities indirectly (per-shard slots, then positions
    within `remote`, then positions within the all-to-all result); this walks
    that indirection back to global entity ids.

    returns -- (positive[:, {h,r,t}], negatives[:])
    """
    n_shard = dataset.n_shard
    n_step = dataset.train_steps_per_program_run
    n_entity = dataset.data.n_entity

    # Invert the (entity -> shard, slot) assignment; unassigned slots stay at -1.
    slot_to_entity = np.full((n_shard, dataset.n_entity_per_shard), -1)
    slot_to_entity[dataset.entity_to_shard, dataset.entity_to_idx] = np.arange(
        n_entity
    )

    # Broadcastable index grids over the leading (shard, step) axes.
    shard_ix = np.arange(n_shard)[:, np.newaxis, np.newaxis]
    step_ix = np.arange(n_step)[np.newaxis, :, np.newaxis]

    remote = slot_to_entity[shard_ix, batch.remote]
    heads = remote[shard_ix, step_ix, batch.head]

    # Gather what each shard sends, then swap the source/destination shard axes
    # and flatten so that every shard sees all the candidates it receives.
    sent = remote[shard_ix[..., np.newaxis], step_ix[..., np.newaxis], batch.a2a]
    received = np.transpose(sent, (2, 1, 0, 3)).reshape(n_shard, n_step, -1)
    in_range = (received >= 0) & (received < n_entity)
    assert in_range.all()

    tails = received[shard_ix, step_ix, batch.tail]

    # Every received candidate that is not some positive's tail is a negative.
    is_negative = np.ones(received.shape, dtype=bool)
    is_negative[shard_ix, step_ix, batch.tail] = False

    positives = np.stack([heads, batch.relation, tails], axis=-1).reshape(-1, 3)
    negatives = received[is_negative].flatten()
    return positives, negatives
+ n_shard * s.a2a_size) 109 | assert batch.a2a.shape == (n_shard, n_batch, n_shard, s.a2a_size) 110 | for section in [batch.head, batch.relation, batch.tail]: 111 | assert section.shape == (n_shard, n_batch, s.batch_size) 112 | assert {k: v.dtype.name for k, v in batch.__dict__.items()} == dict( 113 | remote="uint32", 114 | a2a="uint32", 115 | head="uint32", 116 | relation="uint32", 117 | tail="uint32", 118 | ) 119 | 120 | # Check contiguous 121 | for k, v in batch.__dict__.items(): 122 | assert v.flags.c_contiguous, f"batch.{k} is not C-contiguous" 123 | 124 | # Check invariants 125 | all_hrt = set(map(tuple, data.train_hrt)) 126 | missed_hrt = all_hrt.copy() 127 | all_negative = set(np.arange(data.n_entity)) 128 | missed_negative = all_negative.copy() 129 | for batch in it.islice(dataset.batches(), 100): 130 | positive, negative = _get_triples(dataset, batch) 131 | positive_hrt = set(map(tuple, positive)) 132 | assert not positive_hrt - all_hrt 133 | missed_hrt -= positive_hrt 134 | assert not set(negative) - all_negative 135 | missed_negative -= set(negative) 136 | # Enough samples make these tests almost certain 137 | assert not missed_hrt, "unlikely to omit any positive (h,r,t)" 138 | assert not missed_negative, "unlikely to omit any negative tail entities" 139 | -------------------------------------------------------------------------------- /src/poplar_extensions/l1distance.codelet.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | #include 4 | 5 | #ifdef __IPU__ 6 | #include 7 | #endif 8 | #include 9 | #include 10 | 11 | static constexpr auto ONE_PTR = poplar::VectorLayout::ONE_PTR; 12 | static constexpr auto SPAN = poplar::VectorLayout::SPAN; 13 | 14 | #ifdef __IPU__ 15 | 16 | float l1dist(const float* a, const float* b, size_t size) { 17 | auto a2 = reinterpret_cast(a); 18 | auto b2 = reinterpret_cast(b); 19 | 20 | float2 sum = {0.0, 0.0}; 21 | for (size_t i = 0; i < size / 2; ++i) { 22 | sum += ipu::fabs(a2[i] - b2[i]); 23 | } 24 | float res = sum[0] + sum[1]; 25 | if (size % 2) res += ipu::fabs(a[size - 1] - b[size - 1]); 26 | return res; 27 | } 28 | 29 | float l1dist(const half* a, const half* b, size_t size) { 30 | auto a4 = reinterpret_cast(a); 31 | auto b4 = reinterpret_cast(b); 32 | 33 | half4 sum = {0.0, 0.0, 0.0, 0.0}; 34 | for (size_t i = 0; i < size / 4; ++i) { 35 | sum += ipu::fabs(a4[i] - b4[i]); 36 | } 37 | float res = float(sum[0]) + float(sum[1]) + float(sum[2]) + float(sum[3]); 38 | size_t rem = size % 4; 39 | if (rem) { 40 | for (size_t i = size - rem; i < size; ++i) { 41 | res += ipu::fabs(float(a[i] - b[i])); 42 | } 43 | } 44 | return res; 45 | } 46 | 47 | #else // !__IPU__ 48 | 49 | template 50 | float l1dist(const T* a, const T* b, size_t size) { 51 | float sum = 0.0; 52 | for (size_t i = 0; i < size; ++i) { 53 | sum += std::fabs(float(a[i] - b[i])); 54 | } 55 | return sum; 56 | } 57 | 58 | #endif // __IPU__ 59 | 60 | template 61 | class L1DistanceSingleVertex : public poplar::Vertex { 62 | public: 63 | poplar::Input> a; 64 | poplar::Input> b; 65 | // Note: Output generates a slower 16-bit write, so we always output float 66 | poplar::Output out; 67 | 68 | bool compute() { 69 | *out = l1dist(&a[0], &b[0], a.size()); 70 | return true; 71 | } 72 | }; 73 | template class L1DistanceSingleVertex; 74 | template class L1DistanceSingleVertex; 75 | 76 | template 77 | static inline T signum(T x) { 78 | return T((T(0.0) < x) - (x < T(0.0))); 79 | } 80 | 81 | #ifdef 
__IPU__ 82 | 83 | static inline half4 signum(half4 x) { 84 | constexpr half4 one = {1.0f, 1.0f, 1.0f, 1.0f}; 85 | const auto pOne = reinterpret_cast(&one); 86 | const auto xNonzero = x != half4{0.0f, 0.0f, 0.0f, 0.0f}; 87 | const auto pXNonzero = reinterpret_cast(&xNonzero); 88 | const uint32_t pZeroOrOne[2] = {pXNonzero[0] & pOne[0], pXNonzero[1] & pOne[1]}; 89 | return ipu::copysign(*reinterpret_cast(pZeroOrOne), x); 90 | } 91 | 92 | static inline float2 signum(float2 x) { 93 | constexpr float2 one = {1.0f, 1.0f}; 94 | const auto pOne = reinterpret_cast(&one); 95 | const auto xNonzero = x != float2{0.0f, 0.0f}; 96 | const auto pXNonzero = reinterpret_cast(&xNonzero); 97 | const uint32_t pZeroOrOne[2] = {pXNonzero[0] & pOne[0], pXNonzero[1] & pOne[1]}; 98 | return ipu::copysign(*reinterpret_cast(pZeroOrOne), x); 99 | } 100 | 101 | float l1distgrad(const float& a, const float* b, const float* grad, size_t size) { 102 | float2 sum = {0.0, 0.0}; 103 | auto grad2 = reinterpret_cast(grad); 104 | auto b2 = reinterpret_cast(b); 105 | float2 a2 = {a, a}; 106 | 107 | for (size_t i = 0; i < size / 2; ++i) { 108 | sum += grad2[i] * signum(a2 - b2[i]); 109 | } 110 | float res = sum[0] + sum[1]; 111 | if (size % 2) { 112 | res += grad[size - 1] * signum(a - b[size - 1]); 113 | } 114 | return res; 115 | } 116 | 117 | float l1distgrad(const half& a, const half* b, const half* grad, size_t size) { 118 | half4 sum = {0.0, 0.0}; 119 | auto grad4 = reinterpret_cast(grad); 120 | auto b4 = reinterpret_cast(b); 121 | half4 a4 = {a, a, a, a}; 122 | 123 | for (size_t i = 0; i < size / 4; ++i) { 124 | sum += grad4[i] * signum(a4 - b4[i]); 125 | } 126 | float res = float(sum[0]) + float(sum[1]) + float(sum[2]) + float(sum[3]); 127 | size_t rem = size % 4; 128 | if (rem) { 129 | for (size_t i = size - rem; i < size; ++i) { 130 | res += float(grad[i] * signum(a - b[i])); 131 | } 132 | } 133 | return res; 134 | } 135 | 136 | #else // !__IPU__ 137 | 138 | template 139 | float l1distgrad(const 
T& a, const T* b, const T* grad, size_t size) { 140 | float sum = 0.0; 141 | for (size_t i = 0; i < size; ++i) { 142 | sum += float(grad[i]) * float(signum(a - b[i])); 143 | } 144 | return sum; 145 | } 146 | 147 | #endif // __IPU__ 148 | 149 | template 150 | class L1DistanceGradSingleVertex : public poplar::Vertex { 151 | public: 152 | poplar::Input a; 153 | poplar::Input> b; 154 | poplar::Input> gradOutput; 155 | // Note: Output generates a slower 16-bit write, so we always output float 156 | poplar::Output grad; 157 | 158 | bool compute() { 159 | *grad = l1distgrad(*a, &b[0], &gradOutput[0], b.size()); 160 | return true; 161 | } 162 | }; 163 | 164 | template class L1DistanceGradSingleVertex; 165 | template class L1DistanceGradSingleVertex; 166 | -------------------------------------------------------------------------------- /doc/pytorch_demo/Preprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3d02d392-75f2-402e-be93-5032fd879b3c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Copyright (c) 2022 Graphcore Ltd. All rights reserved." 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "1ebc7833-d4f4-44c8-8b64-007e1fa99573", 16 | "metadata": {}, 17 | "source": [ 18 | "# Build entity mapping database\n", 19 | "\n", 20 | "_Note: this notebook is included to document the process and so will not run out of the box._\n", 21 | "\n", 22 | "In order to provide an interface for querying the model, we build a full text searchable mapping from string to entity and relation ID. 
To do this, we:\n", 23 | " - Download a full list of entity and relation labels from wikidata.\n", 24 | " - Use OGB's `data/ogbl_wikikg2/mapping/` metadata to filter entities of interest, and map them to contiguous OGB dataset IDs.\n", 25 | " - Build a SQLite database with FTS3 indicies for efficient local retrieval.\n", 26 | " \n", 27 | "Also contains the command to build a faster-loading `.npz` file containing ogbl-wikikg2." 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "f3e2dcd4-da30-4703-9315-73f5f62b4413", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import bz2\n", 38 | "import json\n", 39 | "import tqdm\n", 40 | "import itertools as it\n", 41 | "from pathlib import Path\n", 42 | "import csv\n", 43 | "import gzip\n", 44 | "import sys\n", 45 | "\n", 46 | "import kge_mapping\n", 47 | "import kge_training" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "33addfcb-8a72-4150-8c8b-37c269c00fd2", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "%%time\n", 58 | "\n", 59 | "# Source:\n", 60 | "# https://dumps.wikimedia.org/wikidatawiki/entities/latest-all.json.bz2\n", 61 | "\n", 62 | "labels = {}\n", 63 | "path = Path(\"/localdata/research/scratch/douglaso/latest-all.json.bz2\")\n", 64 | "\n", 65 | "f0 = open(path, \"rb\")\n", 66 | "f = bz2.BZ2File(f0)\n", 67 | "f.readline() # opening \"[\"\n", 68 | "tq = tqdm.tqdm(it.count())\n", 69 | "for n in tq:\n", 70 | " line = f.readline().decode().rstrip(\"\\n ,\")\n", 71 | " if line == \"]\":\n", 72 | " break\n", 73 | " e = json.loads()\n", 74 | " labels[e[\"id\"]] = e[\"labels\"].get(\"en\", dict(value=\"\"))[\"value\"]\n", 75 | " if n % int(1e3) == 0:\n", 76 | " tq.set_description(f\"{f0.tell() / path.stat().st_size:.0%}, {f0.tell() / 2**30:.1f} GiB\")" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "e2d35fe3-fe10-4232-867f-93c135797a2b", 83 | "metadata": {}, 84 | 
"outputs": [], 85 | "source": [ 86 | "%%time\n", 87 | "\n", 88 | "# Source:\n", 89 | "# https://hay.toolforge.org/propbrowse/props.json\n", 90 | "\n", 91 | "with Path(\"/localdata/research/scratch/douglaso/props.json\").open() as f:\n", 92 | " props = {item[\"id\"]: item[\"label\"] for item in json.load(f)}" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "edab16f6-72ef-4f4b-9b2c-17688e1fb190", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "%%time\n", 103 | "\n", 104 | "records = []\n", 105 | "\n", 106 | "with gzip.open(\"data/ogbl_wikikg2/mapping/nodeidx2entityid.csv.gz\", \"rt\") as f:\n", 107 | " for item in csv.DictReader(f):\n", 108 | " records.append(dict(\n", 109 | " type=\"entity\",\n", 110 | " idx=int(item[\"node idx\"]),\n", 111 | " wikidata_id=item[\"entity id\"],\n", 112 | " wikidata_label=labels.get(item[\"entity id\"], \"\"),\n", 113 | " ))\n", 114 | "\n", 115 | "with gzip.open(\"data/ogbl_wikikg2/mapping/reltype2relid.csv.gz\", \"rt\") as f:\n", 116 | " for item in csv.DictReader(f):\n", 117 | " records.append(dict(\n", 118 | " type=\"relation\",\n", 119 | " idx=int(item[\"reltype\"]),\n", 120 | " wikidata_id=item[\"rel id\"],\n", 121 | " wikidata_label=props.get(item[\"rel id\"], \"\"),\n", 122 | " ))\n", 123 | "\n", 124 | "print(f\"Missing entity labels: {sum(not r['wikidata_label'] for r in records if r['type'] == 'entity') / (sum(1 for r in records if r['type'] == 'entity')):.1%}\")\n", 125 | "print(f\"Missing relation labels: {sum(not r['wikidata_label'] for r in records if r['type'] == 'relation') / (sum(1 for r in records if r['type'] == 'relation')):.1%}\")\n", 126 | "\n", 127 | "with gzip.open(\"data/ogbl_wikikg2_mapping.jsonl.gz\", \"wt\") as f:\n", 128 | " for record in records:\n", 129 | " print(json.dumps(record), file=f)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "392c9edd-d861-4efb-8fc9-f84a27763c31", 136 | "metadata": {}, 
137 | "outputs": [], 138 | "source": [ 139 | "%%time\n", 140 | "\n", 141 | "kge_mapping.Database.build(\n", 142 | " Path(\"data/ogbl_wikikg2_mapping.sqlite\"),\n", 143 | " kge_mapping.RawData.load(Path(\"data/ogbl_wikikg2_mapping.jsonl.gz\"), Path(\"data\")),\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "e1b3d8ea-fb6f-40ea-9694-9f44f9b13483", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "%%time\n", 155 | "\n", 156 | "kge_training.Dataset.build_wikikg2(\n", 157 | " Path(\"data\"),\n", 158 | " Path(\"data/ogbl_wikikg2.npz\"),\n", 159 | " seed=1000,\n", 160 | ")" 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3 (ipykernel)", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.8.10" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 5 185 | } 186 | -------------------------------------------------------------------------------- /src/poplar_extensions/l2distance.codelet.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | 3 | #include 4 | 5 | #ifdef __IPU__ 6 | #include 7 | #include 8 | #endif 9 | #include 10 | #include 11 | 12 | static constexpr auto ONE_PTR = poplar::VectorLayout::ONE_PTR; 13 | static constexpr auto SPAN = poplar::VectorLayout::SPAN; 14 | 15 | #ifdef __IPU__ 16 | 17 | float l2dist(const float* a, const float* b, size_t size) { 18 | auto a2 = reinterpret_cast(a); 19 | auto b2 = reinterpret_cast(b); 20 | 21 | float2 sum = {0.0, 0.0}; 22 | for (size_t i = 0; i < size / 2; ++i) { 23 | float2 diff = a2[i] - b2[i]; 24 | sum += diff * diff; 25 | } 26 | float res = sum[0] + sum[1]; 27 | if (size % 2) { 28 | float diff = a[size - 1] - b[size - 1]; 29 | res += diff * diff; 30 | } 31 | return ipu::sqrt(res); 32 | } 33 | 34 | float l2dist(const half* a, const half* b, size_t size) { 35 | auto a2 = reinterpret_cast(a); 36 | auto b2 = reinterpret_cast(b); 37 | 38 | float2 sum = {0.0, 0.0}; 39 | for (size_t i = 0; i < size / 2; ++i) { 40 | auto diff = a2[i] - b2[i]; 41 | float2 diff32 = {float(diff[0]), float(diff[1])}; 42 | sum += diff32 * diff32; 43 | } 44 | auto res = sum[0] + sum[1]; 45 | if (size % 2) { 46 | auto diff = float(a[size - 1] - b[size - 1]); 47 | res += diff * diff; 48 | } 49 | return ipu::sqrt(res); 50 | } 51 | 52 | #else // !__IPU__ 53 | 54 | template 55 | float l2dist(const T* a, const T* b, size_t size) { 56 | float sum = 0.0; 57 | for (size_t i = 0; i < size; ++i) { 58 | float diff = float(a[i] - b[i]); 59 | sum += diff * diff; 60 | } 61 | return std::sqrt(sum); 62 | } 63 | 64 | #endif // __IPU__ 65 | 66 | template 67 | class L2DistanceSingleVertex : public poplar::Vertex { 68 | public: 69 | poplar::Input> a; 70 | poplar::Input> b; 71 | // Note: Output generates a slower 16-bit write, so we always output float 72 | poplar::Output out; 73 | 74 | bool compute() { 75 | *out = l2dist(&a[0], &b[0], a.size()); 76 | return true; 77 | } 78 | }; 79 | template class L2DistanceSingleVertex; 80 | template class L2DistanceSingleVertex; 81 | 82 | #ifdef __IPU__ 83 | 
84 | float l2distgrad(const float& a, 85 | const float* b, 86 | const float* dist, 87 | const float* grad, 88 | size_t size) { 89 | float2 a2 = {a, a}; 90 | auto b2 = reinterpret_cast(b); 91 | auto dist2 = reinterpret_cast(dist); 92 | auto grad2 = reinterpret_cast(grad); 93 | 94 | float2 sum = {0.0, 0.0}; 95 | for (size_t i = 0; i < size / 2; ++i) { 96 | const float2 diff = a2 - b2[i]; 97 | const float2 dist2_i = dist2[i]; 98 | const auto comp = dist2_i == float2{0.0f, 0.0f}; 99 | const uint32_t dist_mask[2] = {0x3f800000, 0x3f800000}; // 1.0f 100 | const uint32_t dist_delta[2] = {comp[0] & dist_mask[0], comp[1] & dist_mask[1]}; 101 | // if dist == 0 -> safe_dist=1.0. 102 | const float2 safe_dist = dist2_i + *reinterpret_cast(dist_delta); 103 | const float2 masked_diff = ipu::andc(diff, *reinterpret_cast(&comp)); 104 | sum += grad2[i] * masked_diff / safe_dist; 105 | } 106 | float res = sum[0] + sum[1]; 107 | if (size % 2) { 108 | float diff = a - b[size - 1]; 109 | float dist_i = float(dist[size - 1]); 110 | res += dist_i == 0.0f ? 0.0f : float(grad[size - 1]) * diff / dist_i; 111 | } 112 | return res; 113 | } 114 | 115 | float l2distgrad(const half& a, const half* b, const half* dist, const half* grad, size_t size) { 116 | half4 a4 = {a, a, a, a}; 117 | auto b4 = reinterpret_cast(b); 118 | auto dist4 = reinterpret_cast(dist); 119 | auto grad4 = reinterpret_cast(grad); 120 | 121 | half4 sum = {0.0, 0.0, 0.0, 0.0}; 122 | for (size_t i = 0; i < size / 4; ++i) { 123 | const half4 diff = a4 - b4[i]; 124 | const half4 dist4_i = dist4[i]; 125 | const auto comp = dist4_i == half4{0.0f, 0.0f, 0.0f, 0.0f}; 126 | const uint16_t dist_mask[4] = {0x3f80, 0x3f80, 0x3f80, 0x3f80}; // 1.0f 127 | const auto p_comp = reinterpret_cast(&comp); 128 | const auto p_mask = reinterpret_cast(dist_mask); 129 | const uint32_t dist_delta[2] = {p_comp[0] & p_mask[0], p_comp[1] & p_mask[1]}; 130 | // if dist == 0 -> safe_dist=1.0. 
131 | const half4 safe_dist = dist4_i + *reinterpret_cast(dist_delta); 132 | const float2 masked_diff = ipu::andc(*reinterpret_cast(&diff), 133 | *reinterpret_cast(&comp)); 134 | sum += grad4[i] * *reinterpret_cast(&masked_diff) / safe_dist; 135 | } 136 | float2 res2 = ipu::sum(sum); 137 | float res = res2[0] + res2[1]; 138 | size_t rem = size % 4; 139 | for (size_t i = size - rem; i < size; ++i) { 140 | float diff = float(a) - float(b[i]); 141 | float dist_i = float(dist[i]); 142 | res += dist_i == 0.0f ? 0.0f : float(grad[i]) * diff / dist_i; 143 | } 144 | return res; 145 | } 146 | 147 | #else // !__IPU__ 148 | 149 | template 150 | float l2distgrad(const T& a, const T* b, const T* dist, const T* grad, size_t size) { 151 | float sum = 0.0; 152 | for (size_t i = 0; i < size; ++i) { 153 | float diff = float(a - b[i]); 154 | float dist_i = float(dist[i]); 155 | sum += dist_i == 0.0f ? 0.0f : float(grad[i]) * diff / dist_i; 156 | } 157 | return sum; 158 | } 159 | 160 | #endif 161 | 162 | template 163 | class L2DistanceGradSingleVertex : public poplar::Vertex { 164 | public: 165 | poplar::Input a; 166 | poplar::Input> b; 167 | poplar::Input> dist; 168 | poplar::Input> gradOutput; 169 | // Note: Output generates a slower 16-bit write, so we always output float 170 | poplar::Output grad; 171 | 172 | bool compute() { 173 | *grad = l2distgrad(*a, &b[0], &dist[0], &gradOutput[0], b.size()); 174 | return true; 175 | } 176 | }; 177 | 178 | template class L2DistanceGradSingleVertex; 179 | template class L2DistanceGradSingleVertex; 180 | -------------------------------------------------------------------------------- /tests/python/reference_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
class Model(T.nn.Module):
    """Plain-PyTorch reference model, written to mirror `poplar_kge.Engine`.

    Entity state is exchanged with the engine as one packed "entity_data" array
    whose last axis is laid out as
    [embedding | adam first moment | sqrt(adam second moment) | features].
    """

    @typing.no_type_check
    def __init__(self, settings: kge.Settings):
        """Allocate and randomly initialise the parameters, then build the Adam optimiser."""
        super().__init__()
        self.settings: kge.Settings = settings
        model = settings.model
        # Registration order fixes the order of `named_parameters()` and of the
        # optimiser's parameter list.
        self.entity_embedding = T.nn.Parameter(
            T.FloatTensor(model.n_shard, model.n_entity, model.embedding_size)
        )
        self.feature_projection = T.nn.Parameter(
            T.FloatTensor(model.entity_feature_size, model.embedding_size)
        )
        self.relation_embedding = T.nn.Parameter(
            T.FloatTensor(model.n_relation_type, model.embedding_size)
        )
        # Fixed (non-trainable) per-entity features, filled in by `set_state`.
        self.entity_features = T.zeros(
            model.n_shard, model.n_entity, model.entity_feature_size
        )
        T.nn.init.normal_(self.entity_embedding)
        T.nn.init.xavier_normal_(self.feature_projection)
        T.nn.init.normal_(self.relation_embedding)

        training = settings.training
        self.opt: T.optim.Optimizer = T.optim.Adam(
            self.parameters(),
            lr=training.learning_rate,
            betas=(training.adam_beta_m, training.adam_beta_v),
            eps=training.adam_epsilon,
            weight_decay=training.weight_decay,
        )

    @typing.no_type_check
    def set_state(self, state: Dict[str, np.ndarray]) -> None:
        """Load parameters, features and Adam moments from an engine state dict.

        state -- must contain "step", "entity_data", and for every parameter
                 other than the entity embedding: "<name>", "<name>/adam_m"
                 and "<name>/adam_v"
        """
        # Unpacking a one-element set requires every entry of "step" to agree.
        (step,) = set(state["step"])
        width = self.settings.model.embedding_size
        with T.no_grad():
            for name, parameter in self.named_parameters():
                if name == "entity_embedding":
                    continue  # unpacked from "entity_data" below
                parameter[...] = T.tensor(state[name])
                self.opt.state[parameter] = {
                    "step": step,
                    "exp_avg": T.tensor(state[f"{name}/adam_m"]),
                    "exp_avg_sq": T.tensor(state[f"{name}/adam_v"]),
                }
            packed = T.tensor(state["entity_data"]).to(T.float)
            self.entity_embedding[...] = packed[:, :, :width]
            self.opt.state[self.entity_embedding] = {
                "step": step,
                "exp_avg": packed[:, :, width : 2 * width],
                # Packed as a square root; Adam wants the second moment itself.
                "exp_avg_sq": packed[:, :, 2 * width : 3 * width] ** 2,
            }
            self.entity_features[...] = packed[:, :, 3 * width :]

    @typing.no_type_check
    def get_state(self) -> Dict[str, np.ndarray]:
        """Export parameters and Adam moments using the key layout `set_state` reads.

        The result has no "step" entry.
        """
        state = {}
        for name, parameter in self.named_parameters():
            if name == "entity_embedding":
                continue  # exported as part of "entity_data" below
            moments = self.opt.state[parameter]
            state[name] = parameter.detach().numpy()
            state[f"{name}/adam_m"] = moments["exp_avg"].detach().numpy()
            state[f"{name}/adam_v"] = moments["exp_avg_sq"].detach().numpy()

        entity_moments = self.opt.state[self.entity_embedding]
        packed = T.cat(
            [
                self.entity_embedding,
                entity_moments["exp_avg"],
                T.sqrt(entity_moments["exp_avg_sq"]),
                self.entity_features,
            ],
            dim=2,
        )
        state["entity_data"] = packed.detach().numpy()
        return state

    @staticmethod
    def all_to_all(tensor: T.Tensor) -> T.Tensor:
        """Swap the first two axes of `tensor`."""
        return tensor.transpose(0, 1)

    @staticmethod
    def one_hot_sign(indices: T.Tensor, num_classes: int) -> T.Tensor:
        """Return +1 at each index in `indices` and -1 elsewhere along a new last axis."""
        one_hot = T.nn.functional.one_hot(indices, num_classes)
        return 2 * one_hot - 1  # type:ignore[no-any-return]

    @staticmethod
    def to_tensor(array: np.ndarray) -> T.Tensor:
        """Convert `array` to a tensor, widening uint32 to int64 first."""
        converted = array.astype(np.int64) if array.dtype == np.uint32 else array
        return T.tensor(converted)

    @typing.no_type_check
    def compute_embedding(self, shard: T.Tensor, idx: T.Tensor) -> T.Tensor:
        """Return learned embedding plus projected features for entities (shard, idx)."""
        learned = self.entity_embedding[shard, idx]
        projected = self.entity_features[shard, idx] @ self.feature_projection
        return learned + projected

    @typing.no_type_check
    def compute_broadcast_score(
        self, pred_tails: T.Tensor, tails: T.Tensor
    ) -> T.Tensor:
        """Score every predicted tail against every candidate: gamma - L1 distance."""
        distance = T.cdist(pred_tails, tails, p=1)
        return self.settings.model.gamma - distance

    @typing.no_type_check
    def loss(self, batch: kge.Batch) -> T.Tensor:
        """Compute the weighted log-sigmoid loss for a single training step.

        batch -- one step's batch (no step axis), holding numpy arrays
        """
        model = self.settings.model
        batch = kge.Batch(
            **{name: self.to_tensor(value) for name, value in batch.__dict__.items()}
        )
        shards = T.arange(model.n_shard)
        per_shard = shards[:, None]

        entity_embeddings = self.compute_embedding(per_shard, batch.remote)
        heads = entity_embeddings[per_shard, batch.head]
        predicted_tails = heads + self.relation_embedding[batch.relation, :]

        outgoing = entity_embeddings[shards[:, None, None], batch.a2a]
        tails = self.all_to_all(outgoing).reshape(
            model.n_shard,
            model.n_shard * self.settings.data.a2a_size,
            model.embedding_size,
        )
        scores = self.compute_broadcast_score(predicted_tails, tails)

        # The single positive carries half of the total weight and the
        # (n_candidate - 1) negatives share the other half.
        n_candidate = scores.shape[2]
        weight_positive = 0.5 * n_candidate
        weight_negative = 0.5 * n_candidate / (n_candidate - 1)
        is_positive = T.nn.functional.one_hot(batch.tail, n_candidate)
        weight = is_positive * (weight_positive - weight_negative) + weight_negative

        signed_scores = self.one_hot_sign(batch.tail, n_candidate) * scores
        return -T.mean(weight * T.nn.functional.logsigmoid(signed_scores))

    @typing.no_type_check
    def train_step_loop(self, batch: kge.Batch) -> float:
        """Run one optimiser step per slice along axis 1 of `batch`; return the mean loss."""
        n_step = self.settings.execution.train_steps_per_program_run
        fields = batch.__dict__
        step_batches = [
            kge.Batch(**{name: value[:, step] for name, value in fields.items()})
            for step in range(n_step)
        ]
        step_losses = []
        for step_batch in step_batches:
            self.opt.zero_grad()
            step_loss = self.loss(step_batch)
            # Seed backpropagation with the loss scale rather than 1.
            step_loss.backward(T.tensor(self.settings.training.loss_scale))
            self.opt.step()
            step_losses.append(float(step_loss))
        return np.mean(step_losses)

    @typing.no_type_check
    def predict(
        self, shard_idx_relation: np.ndarray, shard_to_count: np.ndarray
    ) -> Predictions:
        """Return the best-scoring tail entities for each (shard, idx, relation) query.

        shard_idx_relation -- integer array, one row per query, with columns
                              (head shard, head index, relation)
        shard_to_count -- per shard, how many slots after slot 0 are scored normally
        """
        model = self.settings.model
        with T.no_grad():
            # Penalise slot 0 and every slot past the shard's count so that they
            # drop out of the top-k.
            penalty = T.zeros((model.n_shard, model.n_entity))
            for shard, count in enumerate(shard_to_count):
                penalty[shard, 0] = -1e4
                penalty[shard, 1 + count :] = -1e4
            penalty = penalty.reshape(-1)

            queries = self.to_tensor(shard_idx_relation)
            all_entity_embs = self.compute_embedding(
                np.arange(model.n_shard)[:, None],
                np.arange(model.n_entity),
            )
            head_embs = all_entity_embs[queries[:, 0], queries[:, 1]]
            predicted_tails = head_embs + self.relation_embedding[queries[:, 2]]

            query_scores = self.compute_broadcast_score(
                predicted_tails,
                all_entity_embs.reshape(-1, model.embedding_size),
            )
            query_scores += penalty

            best = T.topk(query_scores, k=self.settings.execution.predict_n_best)

            # Flat candidate index -> (shard, index within shard).
            shard_and_idx = np.divmod(best.indices, model.n_entity)
            return Predictions(
                shard_idx=np.stack(shard_and_idx, axis=-1),
                score=best.values,
            )
def test_predictions_topk() -> None:
    """`Predictions.topk(k).sort()` keeps the k best per row, best first.

    Also checks that an extra leading axis is carried through unchanged.
    """
    candidate_idx = np.array(
        [
            [[10, 100], [20, 200], [30, 300], [40, 400]],
            [[50, 500], [60, 600], [70, 700], [80, 800]],
        ]
    )
    candidate_score = np.array([[2, 1, 3, 0], [-4, -3, -2, -1]])
    expected_idx = [
        [[30, 300], [10, 100], [20, 200]],
        [[80, 800], [70, 700], [60, 600]],
    ]
    expected_score = [[3, 2, 1], [-1, -2, -3]]

    pre = kge.Predictions(shard_idx=candidate_idx, score=candidate_score)
    post = pre.topk(k=3).sort()
    assert post.shard_idx.shape == (2, 3, 2)
    assert post.score.shape == (2, 3)
    np.testing.assert_equal(post.shard_idx, expected_idx)
    np.testing.assert_equal(post.score, expected_score)

    # Same data with a singleton axis inserted after the first one.
    with_extra_axis = kge.Predictions(
        pre.shard_idx[:, np.newaxis], pre.score[:, np.newaxis]
    )
    multidim = with_extra_axis.topk(k=3).sort()
    assert multidim.shard_idx.shape == (2, 1, 3, 2)
    assert multidim.score.shape == (2, 1, 3)
    np.testing.assert_equal(multidim.shard_idx[:, 0], post.shard_idx)
    np.testing.assert_equal(multidim.score[:, 0], post.score)
def _random_features(engine: kge.Engine) -> np.ndarray:
    """Draw float16 entity features of shape (n_shard, n_entity, entity_feature_size).

    The features come from the engine's `random_normal`; the row for entity
    index 0 of every shard is cleared.
    """
    model = engine.settings.model
    shape = (model.n_shard, model.n_entity, model.entity_feature_size)
    features = engine.random_normal(shape, np.float16)
    features[:, 0] = 0  # this will be zeroed in any case
    return features
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_training(dtype: str) -> None:
    """Train the engine and the PyTorch reference side by side and compare them.

    Both start from the same state (read back from the engine) and see the same
    batches. Per-run loss, validation MRR and the overall change in every state
    entry must then agree within dtype-dependent tolerances.

    Skipped when creating the engine fails with a "Could not attach" error.
    """
    settings = kge.Settings.create_demo()
    settings.seed = 100
    settings.execution.dtype = dtype
    # NOTE(review): the listing this was reviewed from has lost indentation, so it
    # is not certain that the device override below belongs to the float16 branch
    # - confirm against the checked-in file.
    if dtype == "float16":
        settings.training.n_step = 20  # since numerical errors compound
        settings.execution.device = "ipu"
    settings.prepare()

    data = kge_data.RawData.load(settings)
    dataset = kge_data.Dataset.load(data, settings)
    try:
        engine = kge.Engine(settings, dataset.shard_to_count)
    except RuntimeError as e:
        # Only a failure to attach to a device is treated as "not available here";
        # any other RuntimeError is a genuine failure.
        if "Could not attach" in str(e):
            pytest.skip("IPU not available")
        raise
    engine.initialise_all(dataset.entity_features(settings.model.entity_feature_size))
    # The engine's freshly initialised state is the shared starting point.
    initial_state = engine.read_all()
    reference = reference_model.Model(settings)
    reference.set_state(initial_state)

    # A difference between reference & poplar_kge is that
    # entity embeddings in poplar_kge are only updated when they are used by
    # the model, so for this test we keep 'remote' indices the same for every batch
    batches = [_random_batch(engine), _random_batch(engine), _random_batch(engine)]
    for batch in batches:
        # Broadcast the first batch's step-0 'remote' over every step of every batch.
        batch.remote[...] = batches[0].remote[:, 0, np.newaxis, :]

    n_loop = settings.program_runs_per_log * settings.logs_per_training_run

    loss = []
    reference_loss = []
    mrr = []
    reference_mrr = []
    # Cycle through the three batches for n_loop program runs, evaluating after each.
    for batch in it.islice(it.cycle(batches), n_loop):
        loss.append(engine.train_step_loop(batch)[0])
        reference_loss.append(reference.train_step_loop(batch))
        mrr.append(dataset.mrr("valid", engine.predict))
        reference_mrr.append(
            dataset.mrr(
                "valid",
                partial(reference.predict, shard_to_count=dataset.shard_to_count),
            )
        )

    # Loss check
    np.testing.assert_allclose(
        loss, reference_loss, rtol=dict(float16=2e-3, float32=1e-4)[dtype]
    )

    # MRR check
    np.testing.assert_allclose(
        mrr, reference_mrr, rtol=dict(float16=2e-2, float32=1e-4)[dtype]
    )

    # Rough final value check
    final_state = engine.read_all()
    reference_final_state = reference.get_state()
    # The reference's get_state() has no "step" entry, so drop it before comparing keys.
    final_state.pop("step")
    assert set(final_state.keys()) == set(reference_final_state.keys())

    for key in final_state:
        # These optimiser states have particularly high error in float16 (cause unknown)
        if dtype == "float16" and key.endswith("/adam_m"):
            continue

        # Compare the change made by training rather than absolute values.
        delta = final_state[key] - initial_state[key]
        reference_delta = reference_final_state[key] - initial_state[key]
        if key == "entity_data":
            # Leave out entity index 0 of every shard.
            delta = delta[:, 1:]
            reference_delta = reference_delta[:, 1:]
        # Euclidean distance, normalised by the reference
        distance = np.sqrt(
            np.sum((delta - reference_delta) ** 2) / np.sum(reference_delta**2)
        )
        tolerance = dict(float16=1e-1, float32=1e-4)[dtype]
        assert distance <= tolerance, (
            "Distance between (final - initial) and reference_(final - initial)"
            f" for parameter '{key}'"
            f" is {distance:.2g} > tolerance {tolerance:.2g}"
            f"\n\n(final - initial) = {np.array2string(delta, threshold=20)}"
            f"\n\nreference_(final - initial) = {np.array2string(reference_delta, threshold=20)}"
        )
28 | */ 29 | struct Tensor { 30 | using ID = unsigned; 31 | static constexpr ID Invalid = 0u; 32 | 33 | Tensor(); 34 | explicit Tensor(ID id); 35 | ID id() const; 36 | bool valid() const; 37 | 38 | private: 39 | ID m_id; 40 | }; 41 | 42 | /** 43 | * A stash of forward pass (activation) and backward pass (gradient) `poplar::Tensor`s. 44 | */ 45 | struct Graph { 46 | explicit Graph(poplar::Graph&); 47 | Graph(const Graph&) = delete; 48 | Graph& operator=(const Graph&) = delete; 49 | ~Graph(); 50 | 51 | const poplar::Graph& poplar() const; 52 | poplar::Graph& poplar(); 53 | poplar::Tensor unwrap(const Tensor& tensor) const; 54 | poplar::Tensor grad(const Tensor& tensor, bool checkValid = true) const; 55 | bool requiresGrad(const Tensor& tensor) const; 56 | 57 | /** 58 | * Add a `poplar::Tensor` to the graph. 59 | */ 60 | Tensor wrap(const poplar::Tensor& tensor, bool requiresGrad); 61 | 62 | /** 63 | * Sets or accumulates a gradient tensor. 64 | * 65 | * If there is already a gradient tensor set (for example when the tensor was used by 66 | * multiple forward pass operations), accumulate the gradient. 67 | */ 68 | void addGrad(const Tensor& tensor, 69 | const poplar::Tensor& grad, 70 | poplar::program::Sequence& prog, 71 | const poplar::DebugContext& debugContext); 72 | 73 | private: 74 | std::unique_ptr m_impl; 75 | }; 76 | 77 | /** 78 | * A sequential program that records forward pass operations, allowing generation of a 79 | * backward pass. 
/**
 * A sequential program that records forward pass operations, allowing generation of a
 * backward pass.
 */
struct Tape {
    // Type of a recorded backward-pass step.
    // NOTE(review): the template arguments of `std::function` (and of `std::vector` below)
    // are not visible here - they look like they were lost in extraction. Restore them
    // from the implementation before relying on this declaration.
    using BackwardOp = std::function;

    // The forward-pass program; ops append to this sequence.
    poplar::program::Sequence& prog();

    // Record a backward step to be used when `backward()` is called.
    void addBackwardOp(const BackwardOp& op);
    // Generate the backward pass. `root` and `rootGrad` default to default-constructed
    // (empty) values.
    // NOTE(review): presumably an empty root means gradients were already seeded elsewhere
    // (e.g. via Graph::addGrad) - confirm against the implementation.
    void backward(Graph& graph, const Tensor& root = {}, const poplar::Tensor& rootGrad = {});

   private:
    poplar::program::Sequence m_prog;
    std::vector m_backwardOps;
};
124 | * 125 | * Requires: sum(sizes) == tensor.dim(dim) 126 | */ 127 | std::vector split(Graph& graph, 128 | const Tensor& tensor, 129 | size_t dim, 130 | const std::vector& sizes, 131 | Tape& tape); 132 | 133 | Tensor add(Graph& graph, 134 | const Tensor& A, 135 | const Tensor& B, 136 | Tape& tape, 137 | const poplar::DebugContext& debugContext = {}, 138 | const poplar::OptionFlags& options = {}); 139 | 140 | Tensor sub(Graph& graph, 141 | const Tensor& A, 142 | const Tensor& B, 143 | Tape& tape, 144 | const poplar::DebugContext& debugContext = {}, 145 | const poplar::OptionFlags& options = {}); 146 | 147 | Tensor mul(Graph& graph, 148 | const Tensor& A, 149 | const Tensor& B, 150 | Tape& tape, 151 | const poplar::DebugContext& debugContext = {}, 152 | const poplar::OptionFlags& options = {}); 153 | 154 | Tensor div(Graph& graph, 155 | const Tensor& A, 156 | const Tensor& B, 157 | Tape& tape, 158 | const poplar::DebugContext& debugContext = {}, 159 | const poplar::OptionFlags& options = {}); 160 | 161 | Tensor neg(Graph& graph, 162 | const Tensor& A, 163 | Tape& tape, 164 | const poplar::DebugContext& debugContext = {}, 165 | const poplar::OptionFlags& options = {}); 166 | 167 | Tensor abs(Graph& graph, 168 | const Tensor& A, 169 | Tape& tape, 170 | const poplar::DebugContext& debugContext = {}, 171 | const poplar::OptionFlags& options = {}); 172 | 173 | Tensor square(Graph& graph, 174 | const Tensor& A, 175 | Tape& tape, 176 | const poplar::DebugContext& debugContext = {}, 177 | const poplar::OptionFlags& options = {}); 178 | 179 | Tensor pow(Graph& graph, 180 | const Tensor& A, 181 | float exponent, 182 | Tape& tape, 183 | const poplar::DebugContext& debugContext = {}, 184 | const poplar::OptionFlags& options = {}); 185 | 186 | Tensor sqrt(Graph& graph, 187 | const Tensor& A, 188 | Tape& tape, 189 | const poplar::DebugContext& debugContext = {}, 190 | const poplar::OptionFlags& options = {}); 191 | 192 | Tensor cbrt(Graph& graph, 193 | const Tensor& A, 194 | 
Tape& tape, 195 | const poplar::DebugContext& debugContext = {}, 196 | const poplar::OptionFlags& options = {}); 197 | 198 | Tensor sin(Graph& graph, 199 | const Tensor& A, 200 | Tape& tape, 201 | const poplar::DebugContext& debugContext = {}, 202 | const poplar::OptionFlags& options = {}); 203 | 204 | Tensor cos(Graph& graph, 205 | const Tensor& A, 206 | Tape& tape, 207 | const poplar::DebugContext& debugContext = {}, 208 | const poplar::OptionFlags& options = {}); 209 | 210 | Tensor dropout(Graph& graph, 211 | const Tensor& A, 212 | float p, 213 | Tape& tape, 214 | const poplar::DebugContext& debugContext = {}); 215 | 216 | Tensor cast(Graph& graph, 217 | const Tensor& A, 218 | poplar::Type type, 219 | Tape& tape, 220 | const poplar::DebugContext& debugContext = {}); 221 | 222 | /** 223 | * Note: if A == B, both tensors will receive gradient in the bwd pass. 224 | */ 225 | Tensor max(Graph& graph, 226 | const Tensor& A, 227 | const Tensor& B, 228 | Tape& tape, 229 | const poplar::DebugContext& debugContext = {}, 230 | const poplar::OptionFlags& options = {}); 231 | 232 | Tensor matMul(Graph& graph, 233 | const Tensor& A, 234 | const Tensor& B, 235 | Tape& tape, 236 | const poplar::DebugContext& debugContext = {}, 237 | const poplar::OptionFlags& options = {}, 238 | poplin::PlanningCache* cache = nullptr); 239 | 240 | /** 241 | * WARNING - tested only for very limited cases (2D-table embedding lookups). 
242 | */ 243 | Tensor multiSlice(Graph& graph, 244 | const Tensor& t, 245 | const Tensor& offsets, 246 | const std::vector& dims, 247 | const std::vector& sizes, 248 | Tape& tape, 249 | const popops::SlicePlan& plan, 250 | const poplar::OptionFlags& options, 251 | const poplar::DebugContext& debugContext = {}); 252 | 253 | Tensor reduce(Graph& graph, 254 | const Tensor& in, 255 | const std::vector& dims, 256 | popops::ReduceParams params, 257 | Tape& tape, 258 | const poplar::DebugContext& debugContext = {}, 259 | const poplar::OptionFlags& options = {}); 260 | 261 | Tensor l1distance(Graph& graph, 262 | const Tensor& A, 263 | const Tensor& B, 264 | Tape& tape, 265 | const poplar::DebugContext& debugContext = {}); 266 | 267 | Tensor l2distance(Graph& graph, 268 | const Tensor& A, 269 | const Tensor& B, 270 | Tape& tape, 271 | const poplar::DebugContext& debugContext = {}); 272 | 273 | /////////////////////////////////////////////////////////////////////////////// 274 | // Neural Networks 275 | 276 | Tensor logSoftmax(Graph& graph, 277 | const Tensor& t, 278 | Tape& tape, 279 | const poplar::DebugContext& debugContext = {}); 280 | 281 | Tensor sigmoid(Graph& graph, 282 | const Tensor& A, 283 | Tape& tape, 284 | const poplar::DebugContext& debugContext = {}, 285 | const poplar::OptionFlags& options = {}); 286 | 287 | Tensor logSigmoid(Graph& graph, 288 | const Tensor& t, 289 | Tape& tape, 290 | const poplar::DebugContext& debugContext = {}); 291 | 292 | /////////////////////////////////////////////////////////////////////////////// 293 | // Collectives 294 | 295 | Tensor allToAllCrossReplica(Graph& graph, 296 | const Tensor& data, 297 | Tape& tape, 298 | const gcl::CommGroup& group, 299 | const poplar::DebugContext& debugContext = {}, 300 | const poplar::OptionFlags& options = {}); 301 | 302 | Tensor reduceScatterCrossReplica(Graph& graph, 303 | const Tensor& data, 304 | gcl::CollectiveOperator op, 305 | Tape& tape, 306 | const gcl::CommGroup& group, 307 | const 
def run(command: Iterable[Any], env: Optional[Dict[str, str]] = None) -> None:
    """Run a command, terminating on failure.

    Args:
        command: Program and arguments. Each item is converted with ``str``;
            ``None`` items are dropped, so callers can write
            ``"--flag" if cond else None`` inline.
        env: Extra environment variables, layered over a copy of
            ``os.environ`` (the parent environment is never modified).

    The child's ``PYTHONPATH`` is extended with the project's ``build``,
    ``src/python`` and ``tests/python`` directories. If the command exits
    with a non-zero status, this process exits with the same status.
    """
    cmd = [str(arg) for arg in command if arg is not None]
    print("$ " + " ".join(cmd), file=sys.stderr)
    environ = os.environ.copy()
    if env:
        environ.update(env)
    cwd = os.getcwd()
    # Only keep the inherited PYTHONPATH when it is non-empty: joining an empty
    # string would create a leading empty entry, which silently puts the
    # current working directory on the child's import path.
    inherited = environ.get("PYTHONPATH")
    search_path = [inherited] if inherited else []
    search_path += [f"{cwd}/build", f"{cwd}/src/python", f"{cwd}/tests/python"]
    environ["PYTHONPATH"] = os.pathsep.join(search_path)
    exit_code = subprocess.call(cmd, env=environ)
    if exit_code:
        sys.exit(exit_code)
class _NinjaFile:
    """Minimal writer for .ninja build files, usable as a context manager.

    The output file (and any missing parent directories) is created on
    construction and closed when the ``with`` block exits.
    """

    def __init__(self, path: Path):
        path.parent.mkdir(parents=True, exist_ok=True)
        self.path = path
        self.file = path.open("w")

    def __enter__(self) -> "_NinjaFile":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Always close; exceptions from the body are not suppressed.
        self.file.close()

    def write(self, *block: str) -> None:
        """Write a declaration or rule to the file.

        The first item starts the statement; every further item goes on its
        own indented continuation line. Calling with no items emits an empty
        line.
        """
        statement = "\n ".join(block)
        self.file.write(statement + "\n")

    def blank(self) -> None:
        """Insert a blank line (for readability)."""
        print(file=self.file)
ninja.write( 119 | "rule linkso", 120 | f"command = {cxx} $linkflags -shared $in -o $out $libs", 121 | ) 122 | ninja.write() 123 | ninja.write( 124 | "rule linkexe", 125 | f"command = {cxx} $linkflags $in -o $out $libs", 126 | ) 127 | ninja.write() 128 | 129 | ninja.write( 130 | "rule popc", 131 | "command = popc -Wall -Wextra -Werror -Wold-style-cast -O2 -g --target=cpu,ipu1,ipu2 $in -o $out", 132 | ) 133 | ninja.write() 134 | 135 | # Compile 136 | objs = collections.defaultdict(list) 137 | for src_root in [Path("src"), Path("tests")]: 138 | for cpp_file in src_root.glob("**/*.cpp"): 139 | if ".codelet" not in cpp_file.suffixes: 140 | obj_file = build_root / "obj" / cpp_file.with_suffix(".obj") 141 | ninja.write(f"build {obj_file}: compile {cpp_file}") 142 | objs[src_root.name].append(obj_file) 143 | ninja.write() 144 | 145 | # Compile codelets 146 | codelets_src = list(Path("src/poplar_extensions").glob("*.codelet.cpp")) 147 | codelets_gp = build_root / "poplar_extensions.gp" 148 | ninja.write(f"build {codelets_gp}: popc {' '.join(map(str, codelets_src))}") 149 | 150 | # Link 151 | ninja.write( 152 | f"build {build_root}/libpoplar_kge.so:" 153 | f" linkso {' '.join(map(str, objs['src']))} | {codelets_gp}" 154 | ) 155 | ninja.write( 156 | f"build {build_root}/tests:" 157 | f" linkexe {' '.join(map(str, objs['src'] + objs['tests']))} | {codelets_gp}" 158 | ) 159 | 160 | run(["ninja", "-f", str(ninja.path)] + targets) 161 | 162 | 163 | @cli("-k", "--filter") 164 | @cli("--gdb", action="store_true") 165 | @cli("--profile", type=Path, help="run profiling, to this path") 166 | def tests_cpp(filter: Optional[str], gdb: bool, profile: Optional[Path]) -> None: 167 | """run C++ tests""" 168 | build(["build/tests"]) 169 | prefix, suffix = [], [] 170 | if gdb: 171 | prefix = ["gdb", "-ex", "catch throw", "-ex", "run", "--args"] 172 | suffix = ["--abort", "--break"] 173 | env = {} 174 | if profile: 175 | profile.mkdir(parents=True, exist_ok=True) 176 | 
@cli("-w", "--wrap", choices=("gdb", "cprofile"))
def train(wrap: str) -> None:
    """run the training script"""
    # The docstring doubles as this sub-command's --help text, so it must
    # describe training (not profiling). `python` builds first and applies the
    # optional gdb/cProfile wrapper before launching the script.
    python(["scripts/run_training.py"], wrap=wrap)
@cli("--check", action="store_true")
def format(check: bool, isort: bool = True) -> None:
    """autoformat all sources"""
    # check: report formatting problems and exit non-zero instead of rewriting files.
    # isort: also run isort (callers such as `ci` can turn this off).
    cpp_files = [*Path("src").glob("**/*.[ch]pp"), *Path("tests").glob("**/*.[ch]pp")]
    if check:
        output = subprocess.check_output(
            ["clang-format", "-output-replacements-xml", *map(str, cpp_files)]
        ).decode()
        # clang-format always emits a <replacements ...> wrapper per file; only
        # files that need changes contain <replacement ...> child elements.
        # The trailing space keeps the wrapper tag from matching.
        if "<replacement " in output:
            print("Some C++ files need formatting, please run ./dev format")
            sys.exit(1)
    else:
        run(["clang-format", "-i", *cpp_files])
    # `run` drops None arguments, so "--check" is only passed in check mode.
    run(["black", "--check" if check else None, *PY_FOLDERS])
    if isort:
        run(["isort", "--check" if check else None, *PY_FOLDERS])
def _main() -> None:
    """Build the command-line parser from the decorated commands and dispatch.

    Every module-level object carrying a ``cli_args`` attribute becomes a
    sub-command (underscores in its name are shown as dashes, its docstring
    is the help text). With no sub-command given, the full ``ci`` run is
    executed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.set_defaults(action=lambda: ci())
    subparsers = parser.add_subparsers()

    for name, command in globals().items():
        if not hasattr(command, "cli_args"):
            continue
        subparser = subparsers.add_parser(name.replace("_", "-"), help=command.__doc__)
        for flags, options in command.cli_args:
            subparser.add_argument(*flags, **options)
        subparser.set_defaults(action=command)

    # Everything left after removing "action" is the keyword arguments
    # of the selected command.
    parsed = vars(parser.parse_args())
    handler = parsed.pop("action")
    handler(**parsed)
2 | 3 | #include "distance.hpp" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | namespace { 13 | 14 | void mapTensor2Dblocks(poplar::Graph& graph, poplar::Tensor& t) { 15 | assert(t.rank() == 2 && "only 2D tensors can use mapTensor2Dblocks"); 16 | auto nTiles = graph.getTarget().getNumTiles(); 17 | auto blockSize0 = 18 | std::max(std::ceil(t.dim(0) / nTiles), 19 | std::ceil(std::sqrt(static_cast(t.numElements()) / nTiles))); 20 | auto nBlocks0 = (t.dim(0) + blockSize0 - 1) / blockSize0; 21 | auto nBlocks1 = std::max(1u, nTiles / nBlocks0); 22 | auto blockSize1 = (t.dim(1) + nBlocks1 - 1) / nBlocks1; 23 | for (auto i = 0u; i < nBlocks0; ++i) { 24 | for (auto j = 0u; j < nBlocks1; ++j) { 25 | auto tile = nBlocks1 * i + j; 26 | graph.setTileMapping(t.slice({std::min(i * blockSize0, t.dim(0)), 27 | std::min(j * blockSize1, t.dim(1))}, 28 | {std::min((i + 1) * blockSize0, t.dim(0)), 29 | std::min((j + 1) * blockSize1, t.dim(1))}), 30 | tile); 31 | } 32 | } 33 | } 34 | 35 | poplar::Tensor getCachedCopy(std::map, poplar::Tensor>& cache, 36 | poplar::Graph& graph, 37 | size_t tile, 38 | size_t index, 39 | const poplar::Tensor& t, 40 | poplar::program::Sequence& prog) { 41 | std::pair key = {tile, index}; 42 | auto iter = cache.find(key); 43 | if (iter != cache.end()) { 44 | return iter->second; 45 | } 46 | poplar::Tensor copy = graph.addVariable(t.elementType(), t.shape()); 47 | graph.setTileMapping(copy, tile); 48 | cache[key] = copy; 49 | prog.add(poplar::program::Copy(t, copy)); 50 | return copy; 51 | } 52 | 53 | // Like popops::cast, but alias rather than copying when the type already matches 54 | poplar::Tensor castMaybe(poplar::Graph& graph, 55 | const poplar::Tensor& tensor, 56 | poplar::Type type, 57 | poplar::program::Sequence& prog) { 58 | if (type == tensor.elementType()) { 59 | return tensor; 60 | } 61 | auto casted = graph.clone(type, tensor); 62 | prog.add(popops::cast(graph, tensor, casted)); 63 | return casted; 64 | } 65 
| 66 | } // namespace 67 | 68 | namespace poplar_extensions { 69 | 70 | poplar::Tensor l1distance(poplar::Graph& graph, 71 | const poplar::Tensor& a, 72 | const poplar::Tensor& b, 73 | poplar::program::Sequence& prog, 74 | const poplar::DebugContext& debugContext) { 75 | if (a.rank() != 2 || b.rank() != 2 || a.dim(1) != b.dim(1)) { 76 | std::ostringstream msg; 77 | msg << "Bad arguments to l1distance, expected a.shape (M, K), b.shape (N, K), actual" 78 | << " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() << "."; 79 | throw std::invalid_argument(msg.str()); 80 | } 81 | const size_t n = b.dim(0); 82 | poplar::Tensor out = 83 | graph.addVariable(poplar::FLOAT, {a.dim(0), b.dim(0)}, {debugContext, "l1dist_out"}); 84 | mapTensor2Dblocks(graph, out); 85 | const auto& mapping = graph.getTileMapping(out); 86 | poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l1dist"}); 87 | const auto vertexName = poputil::templateVertex("L1DistanceSingleVertex", a.elementType()); 88 | for (size_t i = 0; i < mapping.size(); ++i) { 89 | for (const auto& interval : mapping[i]) { 90 | for (size_t j = interval.begin(); j != interval.end(); ++j) { 91 | size_t a_index = j / n; 92 | size_t b_index = j % n; 93 | auto v = graph.addVertex(cs, vertexName); 94 | graph.connect(v["a"], a[a_index]); 95 | graph.connect(v["b"], b[b_index]); 96 | graph.connect(v["out"], out[a_index][b_index]); 97 | graph.setTileMapping(v, i); 98 | graph.setPerfEstimate(v, 0); // placeholder 99 | } 100 | } 101 | } 102 | prog.add(poplar::program::Execute(cs)); 103 | return castMaybe(graph, out, a.elementType(), prog); 104 | } 105 | 106 | poplar::Tensor l1distancegrad(poplar::Graph& graph, 107 | const poplar::Tensor& a, 108 | const poplar::Tensor& b, 109 | const poplar::Tensor& gradOutput, 110 | poplar::program::Sequence& prog, 111 | const poplar::DebugContext& debugContext) { 112 | if (a.rank() != 2 || b.rank() != 2 || gradOutput.rank() != 2 || a.dim(1) != b.dim(1) || 113 | 
gradOutput.dim(0) != a.dim(0) || gradOutput.dim(1) != b.dim(0)) { 114 | std::ostringstream msg; 115 | msg << "Bad arguments to l1distancegrad, expected" 116 | << " a.shape (M, K), b.shape (N, K), gradOutput.shape (M, N), actual" 117 | << " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() 118 | << ", gradOutput.shape = " << gradOutput.shapeToString() << "."; 119 | throw std::invalid_argument(msg.str()); 120 | } 121 | const size_t k = a.dim(1); 122 | poplar::Tensor grad = 123 | graph.addVariable(poplar::FLOAT, a.shape(), {debugContext, "l1dist_grad"}); 124 | mapTensor2Dblocks(graph, grad); 125 | const auto& mapping = graph.getTileMapping(grad); 126 | poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l1dist_grad"}); 127 | const auto vertexName = poputil::templateVertex("L1DistanceGradSingleVertex", a.elementType()); 128 | std::map, poplar::Tensor> bCache, gradCache; 129 | for (size_t i = 0; i < mapping.size(); ++i) { 130 | for (const auto& interval : mapping[i]) { 131 | for (size_t j = interval.begin(); j != interval.end(); ++j) { 132 | size_t a1_index = j / k; 133 | size_t a2_index = j % k; 134 | auto v = graph.addVertex(cs, vertexName); 135 | graph.connect(v["a"], a[a1_index][a2_index]); 136 | graph.connect( 137 | v["b"], getCachedCopy(bCache, graph, i, a2_index, 138 | b.slice({a2_index, a2_index + 1}, 1).squeeze({1}), prog)); 139 | graph.connect(v["gradOutput"], getCachedCopy(gradCache, graph, i, a1_index, 140 | gradOutput[a1_index], prog)); 141 | graph.connect(v["grad"], grad[a1_index][a2_index]); 142 | graph.setTileMapping(v, i); 143 | graph.setPerfEstimate(v, 0); // placeholder 144 | } 145 | } 146 | } 147 | prog.add(poplar::program::Execute(cs)); 148 | return castMaybe(graph, grad, a.elementType(), prog); 149 | } 150 | 151 | poplar::Tensor l2distance(poplar::Graph& graph, 152 | const poplar::Tensor& a, 153 | const poplar::Tensor& b, 154 | poplar::program::Sequence& prog, 155 | const poplar::DebugContext& debugContext) { 156 | if 
(a.rank() != 2 || b.rank() != 2 || a.dim(1) != b.dim(1)) { 157 | std::ostringstream msg; 158 | msg << "Bad arguments to l2distance, expected a.shape (M, K), b.shape (N, K), actual" 159 | << " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() << "."; 160 | throw std::invalid_argument(msg.str()); 161 | } 162 | const size_t n = b.dim(0); 163 | poplar::Tensor out = 164 | graph.addVariable(poplar::FLOAT, {a.dim(0), b.dim(0)}, {debugContext, "l2dist_out"}); 165 | mapTensor2Dblocks(graph, out); 166 | const auto& mapping = graph.getTileMapping(out); 167 | poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l2dist"}); 168 | const auto vertexName = poputil::templateVertex("L2DistanceSingleVertex", a.elementType()); 169 | for (size_t i = 0; i < mapping.size(); ++i) { 170 | for (const auto& interval : mapping[i]) { 171 | for (size_t j = interval.begin(); j != interval.end(); ++j) { 172 | size_t a_index = j / n; 173 | size_t b_index = j % n; 174 | auto v = graph.addVertex(cs, vertexName); 175 | graph.connect(v["a"], a[a_index]); 176 | graph.connect(v["b"], b[b_index]); 177 | graph.connect(v["out"], out[a_index][b_index]); 178 | graph.setTileMapping(v, i); 179 | graph.setPerfEstimate(v, 0); // placeholder 180 | } 181 | } 182 | } 183 | prog.add(poplar::program::Execute(cs)); 184 | return castMaybe(graph, out, a.elementType(), prog); 185 | } 186 | 187 | poplar::Tensor l2distancegrad(poplar::Graph& graph, 188 | const poplar::Tensor& a, 189 | const poplar::Tensor& b, 190 | const poplar::Tensor& dist, 191 | const poplar::Tensor& gradOutput, 192 | poplar::program::Sequence& prog, 193 | const poplar::DebugContext& debugContext) { 194 | if (a.rank() != 2 || b.rank() != 2 || dist.rank() != 2 || gradOutput.rank() != 2 || 195 | a.dim(1) != b.dim(1) || gradOutput.dim(0) != a.dim(0) || gradOutput.dim(1) != b.dim(0) || 196 | dist.dim(0) != a.dim(0) || dist.dim(1) != b.dim(0)) { 197 | std::ostringstream msg; 198 | msg << "Bad arguments to l2distancegrad, expected" 
199 | << " a.shape (M, K), b.shape (N, K), gradOutput.shape (M, N), actual" 200 | << " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() 201 | << ", gradOutput.shape = " << gradOutput.shapeToString() << "."; 202 | throw std::invalid_argument(msg.str()); 203 | } 204 | const size_t k = a.dim(1); 205 | poplar::Tensor grad = 206 | graph.addVariable(poplar::FLOAT, a.shape(), {debugContext, "l2dist_grad"}); 207 | mapTensor2Dblocks(graph, grad); 208 | const auto& mapping = graph.getTileMapping(grad); 209 | poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l2dist_grad"}); 210 | const auto vertexName = poputil::templateVertex("L2DistanceGradSingleVertex", a.elementType()); 211 | for (size_t i = 0; i < mapping.size(); ++i) { 212 | for (const auto& interval : mapping[i]) { 213 | for (size_t j = interval.begin(); j != interval.end(); ++j) { 214 | size_t a1_index = j / k; 215 | size_t a2_index = j % k; 216 | auto v = graph.addVertex(cs, vertexName); 217 | graph.connect(v["a"], a[a1_index][a2_index]); 218 | graph.connect(v["b"], b.slice({a2_index, a2_index + 1}, 1).squeeze({1})); 219 | graph.connect(v["dist"], dist[a1_index]); 220 | graph.connect(v["gradOutput"], gradOutput[a1_index]); 221 | graph.connect(v["grad"], grad[a1_index][a2_index]); 222 | graph.setTileMapping(v, i); 223 | graph.setPerfEstimate(v, 0); // placeholder 224 | } 225 | } 226 | } 227 | prog.add(poplar::program::Execute(cs)); 228 | return castMaybe(graph, grad, a.elementType(), prog); 229 | } 230 | 231 | } // namespace poplar_extensions 232 | -------------------------------------------------------------------------------- /doc/pytorch_demo/kge_mapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
class Database:
    """Uses SQLite FTS3 to provide text search on entity & relation names.

    The database holds one `mapping` table (type, idx, wikidata_id,
    wikidata_label, score) and one FTS3 table per type (`entity_label`,
    `relation_label`) that is filled by an insert trigger on `mapping`.
    """

    # The only values that may be interpolated into SQL as `{type}_label`.
    _TYPES = ("entity", "relation")

    def __init__(self, path: Path):
        self.connection = sqlite3.connect(str(path))

    @staticmethod
    def _check_type(type: str) -> str:
        """Return `type` unchanged if it names a known table prefix.

        Raises:
            ValueError: if `type` is not "entity" or "relation". Table names
                cannot be bound as SQL parameters, so this check is what keeps
                arbitrary text out of the query string.
        """
        if type not in Database._TYPES:
            raise ValueError(
                f"Unknown type {type!r}, expected one of {Database._TYPES}"
            )
        return type

    @staticmethod
    def build(path: Path, data: RawData) -> None:
        """(Re)create the mapping database at `path` from `data`.

        Any existing `mapping`, `entity_label` and `relation_label` tables are
        dropped first. Each row's `score` is the relation count (for
        relations) or the head-entity count (for entities) at `entry.idx`.
        """
        # Look the count arrays up once rather than rebuilding the dict per row.
        counts = dict(relation=data.relation_count, entity=data.head_entity_count)
        rows = [
            dict(**entry.__dict__, score=float(counts[entry.type][entry.idx]))
            for entry in data.mapping
        ]

        connection = sqlite3.connect(str(path))
        try:
            # `with connection` commits on success and rolls back on error.
            with connection:
                connection.execute("BEGIN TRANSACTION")

                # Drop any previous contents
                connection.execute("DROP TABLE IF EXISTS mapping")
                connection.execute("DROP TABLE IF EXISTS entity_label")
                connection.execute("DROP TABLE IF EXISTS relation_label")

                # Create main table & FTS index
                # Note that we create separate FTS tables for entity & relation
                # for performance reasons
                connection.execute(
                    "CREATE TABLE mapping ("
                    "type TEXT NOT NULL"
                    ", idx INTEGER NOT NULL"
                    ", wikidata_id TEXT NOT NULL"
                    ", wikidata_label TEXT NOT NULL"
                    ", score REAL NOT NULL"
                    ")"
                )
                for kind in Database._TYPES:
                    connection.execute(
                        f"CREATE VIRTUAL TABLE {kind}_label USING fts3 (label, id)"
                    )
                    connection.execute(
                        f"CREATE TRIGGER mapping_insert_{kind} AFTER INSERT ON mapping"
                        f" WHEN new.type = '{kind}'"
                        " BEGIN"
                        f" INSERT INTO {kind}_label"
                        " (docid, label, id) VALUES (new.rowid, new.wikidata_label, new.wikidata_id);"
                        " END;"
                    )
                connection.executemany(
                    "INSERT INTO mapping VALUES (:type, :idx, :wikidata_id, :wikidata_label, :score)",
                    rows,
                )
                # Created after the bulk insert, once all rows are present.
                connection.execute(
                    "CREATE UNIQUE INDEX mapping_index_typeidx ON mapping(type, idx)"
                )
        finally:
            # Release the file handle even when a statement above raises.
            connection.close()

    @staticmethod
    def _entries(cursor: sqlite3.Cursor) -> List[Entry]:
        """Convert every remaining row of `cursor` into an `Entry`.

        Column names of the executed SELECT are used as `Entry` field names.
        """
        columns = [c[0] for c in cursor.description]
        return [Entry(**dict(zip(columns, x))) for x in cursor.fetchall()]

    def search(self, query: str, type: str, limit: int) -> List[Entry]:
        """Return up to `limit` entries of `type`, highest score first.

        A non-empty `query` is applied as an FTS `MATCH` against the
        `{type}_label` table; an empty `query` returns the top-scored entries.

        Raises:
            ValueError: if `type` is not "entity" or "relation".
        """
        table = f"{self._check_type(type)}_label"
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT m.type, m.idx, m.wikidata_id, m.wikidata_label"
            " FROM mapping as m"
            f" INNER JOIN {table} as w"
            " WHERE m.rowid = w.docid"
            " AND m.type = :type"
            + (f" AND {table} MATCH :query" if query else "")
            + " ORDER BY m.score DESC"
            " LIMIT :limit",
            dict(type=type, query=query, limit=limit),
        )
        return self._entries(cursor)

    def get_entry(self, idx: int, type: str) -> Entry:
        """Return the single entry with the given `type` and `idx`.

        Raises:
            ValueError: if there is not exactly one matching row.
        """
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT m.type, m.idx, m.wikidata_id, m.wikidata_label"
            " FROM mapping as m"
            " WHERE m.type = ? AND m.idx = ?",
            [type, int(idx)],
        )
        entries = self._entries(cursor)
        if len(entries) != 1:
            # Same exception type as the tuple-unpacking it replaces, but
            # with a message that says what was being looked up.
            raise ValueError(
                f"Expected exactly one {type} with idx {int(idx)}, found {len(entries)}"
            )
        return entries[0]

    @property
    def all_relations(self) -> List[Entry]:
        """All entries of type 'relation', highest score first."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT m.type, m.idx, m.wikidata_id, m.wikidata_label"
            " FROM mapping as m"
            " WHERE m.type = 'relation'"
            " ORDER BY m.score DESC"
        )
        return self._entries(cursor)
    " 226 | for entry in entries: 227 | class_ = "" 228 | if isinstance(entry, Prediction) and entry.in_training: 229 | class_ = "kge-tail-present" 230 | html += f'
  1. {cls._render_entry(entry)}
  2. ' 231 | html += "
" 232 | return html 233 | 234 | def _repr_html_(self) -> str: 235 | html = self.STYLE 236 | for kind, entries in dict( 237 | head=self.heads, 238 | relation=self.relations, 239 | tail_predictions=self.tail_predictions, 240 | ).items(): 241 | html += ( 242 | f'
' 243 | f'

{kind.capitalize().replace("_", " ")}

' 244 | f"{self._render_part(entries)}
class Predictor:
    """A notebook-friendly front end combining text search with tail prediction.

    Building a predictor hashes every training triple, which takes
    non-negligible time, so create one instance and reuse it across queries.
    """

    def __init__(
        self,
        database: Database,
        predict: Callable[[int, int], List[int]],
        train_hrt: np.ndarray,
        n_entity: int,
        n_relation_type: int,
        n_suggestions: int = 10,
    ):
        self.database = database
        self.predict = predict
        self.n_entity = n_entity
        self.n_relation_type = n_relation_type
        self.n_suggestions = n_suggestions
        # Collapse each (h, r, t) row into one integer key:
        #   n_entity * n_relation_type * h + n_entity * r + t
        weights = [n_entity * n_relation_type, n_entity, 1]
        keys = np.sum(weights * train_hrt.astype(np.int64), -1)
        self.train_hrt_hash = {int(key) for key in keys}

    def search(self, query: str, type: str) -> List[Entry]:
        """Search the database, appending "*" to `query` for word completion.

        The "*" is not added when `query` is empty, ends with a space, or
        looks like a WikiData ID (P or Q followed by digits).
        """
        if query and not (query.endswith(" ") or re.match(r"^(P|Q)\d+$", query)):
            query = query + "*"
        return self.database.search(query, type=type, limit=self.n_suggestions)

    def __call__(self, head_query: str, relation_query: str) -> Predictions:
        """Run text search on {head, relation} and a model prediction query for {tail}.

        head_query -- str
                   -- e.g. "england richa" -> "Richard I of England"
                      e.g. "Q62378" -> "Tower of London"
                      (note that "*" is appended for word completion unless
                      the query ends with " " or looks like a WikiData ID)

        relation_query -- str -- same syntax (and ID e.g. "P770")

        returns -- Predictions -- designed for rendering in Jupyter via `display()`
        """
        heads = self.search(head_query, type="entity")
        relations = self.search(relation_query, type="relation")
        if not (heads and relations):
            # Nothing to predict from without both a head and a relation.
            return Predictions(heads, relations, [])

        # Only the top search hit of each kind is fed to the model.
        head = heads[0].idx
        relation = relations[0].idx
        key_prefix = self.n_entity * self.n_relation_type * head + self.n_entity * relation
        predictions = []
        for tail in self.predict(head, relation):
            entry = self.database.get_entry(tail, type="entity")
            predictions.append(
                Prediction(
                    **entry.__dict__,
                    in_training=(key_prefix + tail) in self.train_hrt_hash,
                )
            )
        return Predictions(heads, relations, predictions)
32 | */ 33 | struct Tensor { 34 | using ID = unsigned; 35 | using Shape = std::vector; 36 | static constexpr ID Invalid = 0u; 37 | struct Spec { 38 | Shape shape; 39 | poplar::Type dtype; 40 | }; 41 | 42 | Tensor(); 43 | static Tensor declare(const Spec&, bool requiresGrad, const std::string& name); 44 | static Tensor wrap(const pag::Tensor&); 45 | 46 | void set(const poplar::Tensor&) const; 47 | void hostAccess(bool read = true, bool write = true) const; 48 | void backward(const Tensor& rootGrad = {}) const; 49 | 50 | Tensor transpose() const; 51 | Tensor reshape(const Shape&) const; 52 | Tensor slice(size_t dim, poplar::Interval region) const; 53 | std::vector split(size_t dim, const std::vector& sizes) const; 54 | 55 | ID id() const; 56 | const Spec& spec() const; 57 | const Shape& shape() const; 58 | size_t rank() const; 59 | size_t numElements() const; 60 | poplar::Type dtype() const; 61 | const std::string& name() const; 62 | bool valid() const; 63 | pag::Tensor pag() const; 64 | Tensor grad() const; 65 | Tensor astype(poplar::Type) const; 66 | Tensor operator[](const Tensor& index) const; 67 | 68 | private: 69 | explicit Tensor(ID id); 70 | ID m_id; 71 | }; 72 | 73 | bool operator==(const Tensor::Spec&, const Tensor::Spec&); 74 | bool operator!=(const Tensor::Spec&, const Tensor::Spec&); 75 | std::ostream& operator<<(std::ostream&, const Tensor::Spec&); 76 | 77 | Tensor operator+(const Tensor& lhs, const Tensor& rhs); 78 | Tensor operator-(const Tensor& lhs, const Tensor& rhs); 79 | Tensor operator*(const Tensor& lhs, const Tensor& rhs); 80 | Tensor operator/(const Tensor& lhs, const Tensor& rhs); 81 | Tensor operator-(const Tensor& tensor); 82 | Tensor operator~(const Tensor& tensor); 83 | Tensor operator<(const Tensor& lhs, const Tensor& rhs); 84 | 85 | namespace ops { 86 | 87 | // Sources 88 | template 89 | Tensor constant(T value, poplar::Type type = poplar::equivalent_device_type().value); 90 | template 91 | Tensor constant(const std::vector& value, 92 
| const std::optional& shape = std::nullopt, 93 | poplar::Type type = poplar::equivalent_device_type().value); 94 | template 95 | Tensor full(const Tensor::Shape& shape, 96 | const T value, 97 | poplar::Type type = poplar::equivalent_device_type().value); 98 | Tensor variable(const std::string& name, 99 | const Tensor::Spec& spec, 100 | std::optional requiresGrad = std::nullopt); 101 | Tensor input(const std::string& handle, const Tensor::Spec& spec); 102 | Tensor randomNormal(float mean, 103 | float stdDev, 104 | const Tensor::Shape& shape, 105 | unsigned seed = 0, 106 | poplar::Type type = poplar::FLOAT); 107 | 108 | // Sinks 109 | void output(const std::string& handle, const Tensor& tensor); 110 | void print(const std::string& message, const Tensor& tensor); 111 | 112 | // Transformations 113 | Tensor abs(const Tensor& a); 114 | Tensor max(const Tensor& a, const Tensor& b); 115 | Tensor square(const Tensor& a); 116 | Tensor pow(const Tensor& a, float exponent); 117 | Tensor sqrt(const Tensor& a); 118 | Tensor cbrt(const Tensor& a); 119 | Tensor sin(const Tensor& a); 120 | Tensor cos(const Tensor& a); 121 | Tensor l1distance(const Tensor& a, const Tensor& b); 122 | Tensor l2distance(const Tensor& a, const Tensor& b); 123 | Tensor gather(const Tensor& tensor, const Tensor& indices); 124 | Tensor sum(const Tensor& tensor, const std::vector& dims = {}); 125 | Tensor mean(const Tensor& tensor, const std::vector& dims = {}); 126 | Tensor matMul(const Tensor& lhs, const Tensor& rhs); 127 | Tensor logSoftmax(const Tensor& tensor); 128 | Tensor sigmoid(const Tensor& tensor); 129 | Tensor logSigmoid(const Tensor& tensor); 130 | Tensor oneHot(const Tensor& tensor, size_t N, poplar::Type type); 131 | 132 | Tensor startGrad(const Tensor& tensor); 133 | Tensor concat(const std::vector& tensors, size_t dim); 134 | // Only in the fwd pass 135 | Tensor copyToLinearTensor(const Tensor& tensor, 136 | std::optional minElementsPerTile = std::nullopt, 137 | std::optional grainSize = 
std::nullopt); 138 | 139 | // Collectives 140 | Tensor allGather(const Tensor& tensor); 141 | Tensor allToAll(const Tensor& tensor); 142 | 143 | // Other 144 | void forN(unsigned n, const std::function& body); 145 | 146 | } // namespace ops 147 | 148 | /** 149 | * Reading from or writing to host memory. 150 | */ 151 | struct Stream { 152 | Stream() = default; 153 | Stream(const std::string& handle, const Tensor::Spec&, poplar::DataStreamType); 154 | 155 | std::string handle() const; 156 | Tensor::Spec spec() const; 157 | 158 | Tensor read() const; 159 | void write(const Tensor&); 160 | 161 | private: 162 | poplar::DataStream m_stream; 163 | std::vector m_shape; 164 | }; 165 | 166 | /** 167 | * Reading from and writing to remote buffers. 168 | */ 169 | struct Buffer { 170 | Buffer() = default; 171 | Buffer(const std::string& name, const Tensor::Spec&); 172 | 173 | Tensor read(const Tensor& indices) const; 174 | void write(const Tensor& data, const Tensor& indices); 175 | 176 | size_t totalBytes(const poplar::Target&) const; 177 | 178 | private: 179 | std::vector rwShape(const Tensor& indices) const; 180 | poplar::RemoteBuffer m_buffer; 181 | std::vector m_rowShape; 182 | }; 183 | 184 | /////////////////////////////////////////////////////////////////////////////// 185 | // Helpers & utilities 186 | 187 | namespace util { 188 | 189 | template 190 | struct Seq { 191 | const T& sequence; 192 | }; 193 | 194 | template 195 | Seq seq(const T& sequence); 196 | 197 | template 198 | std::ostream& operator<<(std::ostream&, const Seq&); 199 | 200 | template 201 | bool operator==(const Seq&, const Seq&); 202 | template 203 | bool operator!=(const Seq&, const Seq&); 204 | 205 | void checkArgument(const Tensor&, 206 | const std::string& message, 207 | const Tensor::Shape& shape, 208 | const std::vector& types = {}); 209 | 210 | template 211 | std::vector arange(T start, T end); 212 | 213 | template 214 | std::vector arange(T end); 215 | 216 | unsigned numElements(const 
fr::Tensor::Shape& shape); 217 | 218 | template 219 | auto mapVector(const Collection& items, Func&& func) -> std::vector; 220 | 221 | } // namespace util 222 | 223 | namespace mapping { 224 | 225 | struct Method { 226 | virtual void apply(poplar::Graph&, const poplar::Tensor&) const = 0; 227 | }; 228 | 229 | struct Linear : Method { 230 | void apply(poplar::Graph&, const poplar::Tensor&) const; 231 | }; 232 | 233 | struct OneTile : Method { 234 | /** 235 | * Tile number, can be negative, counting back from the last tile. 236 | * 237 | * Note that the tile number will wrap around modulo target.getNumTiles(). 238 | */ 239 | int tile; 240 | OneTile(); 241 | explicit OneTile(int tile); 242 | void apply(poplar::Graph&, const poplar::Tensor&) const; 243 | }; 244 | 245 | /** 246 | * Copy a tile mapping from another tensor. 247 | */ 248 | struct Copy : Method { 249 | poplar::Tensor base; 250 | Copy(const poplar::Tensor&); 251 | void apply(poplar::Graph&, const poplar::Tensor&) const; 252 | }; 253 | 254 | /** 255 | * For each tensor, set a mapping (if not already set). 256 | */ 257 | void setDefault(const Method& method, const std::vector& tensors); 258 | 259 | } // namespace mapping 260 | 261 | /////////////////////////////////////////////////////////////////////////////// 262 | // State management 263 | 264 | struct TensorPool; 265 | 266 | /** 267 | * A frame adds a new frame to the stack (e.g. debugContext), and gives access to the global 268 | * `pag::Graph`, `pag::Tape` etc. 
269 | * 270 | * Usage: 271 | * Frame f("myOps::foo"); 272 | * foo(f.graph, f.tape, f.di); 273 | */ 274 | struct Frame { 275 | pag::Graph& graph; 276 | poplin::PlanningCache& matMulCache; 277 | pag::Tape& tape; 278 | std::unordered_map& streams; 279 | poplar::DebugInfo di; 280 | 281 | explicit Frame(const std::string& name = "", 282 | poplar::SourceLocation loc = poplar::SourceLocation::Current()); 283 | Frame(const Frame&) = delete; 284 | Frame& operator=(const Frame&) = delete; 285 | ~Frame(); 286 | 287 | unsigned replicationFactor() const; 288 | 289 | protected: 290 | Frame(pag::Graph& graph, 291 | poplin::PlanningCache& matMulCache, 292 | pag::Tape& tape, 293 | std::unordered_map& streams, 294 | const poplar::DebugInfo& di, 295 | const std::string& name, 296 | const poplar::SourceLocation& loc); 297 | }; 298 | 299 | struct SubProgramFrame : Frame { 300 | explicit SubProgramFrame(const std::string& name = "", 301 | poplar::SourceLocation loc = poplar::SourceLocation::Current()); 302 | 303 | private: 304 | pag::Tape m_tape; 305 | std::unordered_map m_streams; 306 | }; 307 | 308 | struct RootFrame : Frame { 309 | std::unordered_map variables; 310 | std::unique_ptr pool; 311 | 312 | explicit RootFrame(const poplar::Target&, 313 | poplar::SourceLocation loc = poplar::SourceLocation::Current()); 314 | ~RootFrame(); 315 | 316 | private: 317 | poplar::Graph m_poplarGraph; 318 | pag::Graph m_pagGraph; 319 | poplin::PlanningCache m_matMulCache; 320 | pag::Tape m_tape; 321 | std::unordered_map m_streams; 322 | }; 323 | 324 | /** 325 | * Global state that defines the current `poplar::Graph` and `poplar::program::Sequence` 326 | * being built. 
327 | */ 328 | struct Environment { 329 | static Frame& frame(); 330 | static RootFrame& rootFrame(); 331 | 332 | Environment(const Environment&) = delete; 333 | Environment& operator=(const Environment&) = delete; 334 | 335 | private: 336 | friend struct Frame; 337 | friend struct RootFrame; 338 | Environment(); 339 | static Environment& instance(); 340 | RootFrame* m_root; 341 | std::vector m_stack; 342 | }; 343 | 344 | } // namespace fr 345 | 346 | /////////////////////////////////////////////////////////////////////////////// 347 | // Template implementations 348 | 349 | namespace fr { 350 | 351 | namespace ops { 352 | 353 | template 354 | Tensor constant(T value, poplar::Type type) { 355 | Frame f("fr::ops::constant"); 356 | auto poplarTensor = f.graph.poplar().addConstant(type, {}, value, f.di); 357 | // Most ops fill from zero, so constants on N-1 seems somewhat reasonable. 358 | mapping::OneTile().apply(f.graph.poplar(), poplarTensor); 359 | return Tensor::wrap(f.graph.wrap(poplarTensor, /*requiresGrad*/ false)); 360 | } 361 | 362 | template 363 | Tensor constant(const std::vector& value, 364 | const std::optional& shape, 365 | poplar::Type type) { 366 | Frame f("fr::ops::constant"); 367 | auto actualShape = shape.value_or(Tensor::Shape{value.size()}); 368 | auto poplarTensor = f.graph.poplar().addConstant(type, actualShape, value); 369 | mapping::Linear().apply(f.graph.poplar(), poplarTensor); 370 | return Tensor::wrap(f.graph.wrap(poplarTensor, /*requiresGrad*/ false)); 371 | } 372 | 373 | template 374 | Tensor full(const Tensor::Shape& shape, const T value, poplar::Type type) { 375 | Frame f("fr::ops::full"); 376 | auto poplarTensor = f.graph.poplar().addConstant(type, shape, value, f.di); 377 | mapping::Linear().apply(f.graph.poplar(), poplarTensor); 378 | return Tensor::wrap(f.graph.wrap(poplarTensor, /*requiresGrad*/ false)); 379 | } 380 | 381 | } // namespace ops 382 | 383 | namespace util { 384 | 385 | template 386 | Seq seq(const T& sequence) { 387 
| return Seq{sequence}; 388 | } 389 | 390 | template 391 | std::ostream& operator<<(std::ostream& out, const Seq& seq) { 392 | bool separator = false; 393 | out << "{"; 394 | for (auto& element : seq.sequence) { 395 | if (separator) out << ", "; 396 | out << element; 397 | separator = true; 398 | } 399 | return out << "}"; 400 | } 401 | 402 | template 403 | bool operator==(const Seq& lhs, const Seq& rhs) { 404 | return std::equal(lhs.sequence.begin(), lhs.sequence.end(), rhs.sequence.begin(), 405 | rhs.sequence.end()); 406 | } 407 | template 408 | bool operator!=(const Seq& lhs, const Seq& rhs) { 409 | return !(lhs == rhs); 410 | } 411 | 412 | template 413 | std::vector arange(T start, T end) { 414 | std::vector result(static_cast(end - start)); 415 | std::iota(result.begin(), result.end(), start); 416 | return result; 417 | } 418 | 419 | template 420 | std::vector arange(T end) { 421 | return arange(T(0), end); 422 | } 423 | 424 | template 425 | auto mapVector(const Collection& items, Func&& func) 426 | -> std::vector { 427 | std::vector result; 428 | result.reserve(items.size()); 429 | std::transform(items.begin(), items.end(), std::back_inserter(result), 430 | std::forward(func)); 431 | return result; 432 | } 433 | 434 | } // namespace util 435 | 436 | } // namespace fr 437 | 438 | #endif // FRUCTOSE_HPP 439 | -------------------------------------------------------------------------------- /src/python/poplar_kge_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
@dataclasses.dataclass
class RawData:
    """Wraps ogb.lsc.WikiKG90Mv2Dataset to allow testing on fake data."""

    n_entity: int
    n_relation_type: int
    entity_features: np.ndarray  # float16[n_entity x feature_size]
    train_hrt: np.ndarray  # uint32[n_train x 3]
    eval_hr_: Dict[str, np.ndarray]  # {str: uint[n_train x (2 or 3)]}

    @classmethod
    def generate(
        cls,
        n_entity: int,
        n_relation_type: int,
        feature_size: int,
        n_train: int,
        n_eval: int,
        seed: int,
    ) -> "RawData":
        """Create a random dataset with the same layout as the real one.

        All triples and features are drawn from `np.random.RandomState(seed)`.
        The "train" evaluation split is `n_eval` training triples chosen
        without replacement; the other splits are independent random triples.
        """
        random = np.random.RandomState(seed)
        # There is no generalisation in this random world
        hrt = {
            key: np.stack(
                [
                    random.randint(n_entity, size=n),
                    random.randint(n_relation_type, size=n),
                    random.randint(n_entity, size=n),
                ],
                axis=-1,
            ).astype(np.uint32)
            for key, n in {
                "train": n_train,
                "valid": n_eval,
                "test-dev": n_eval,
                "test-challenge": n_eval,
            }.items()
        }
        return cls(
            n_entity=n_entity,
            n_relation_type=n_relation_type,
            entity_features=random.randn(n_entity, feature_size).astype(np.float16),
            train_hrt=hrt["train"],
            eval_hr_=dict(
                train=hrt["train"][random.choice(n_train, size=n_eval, replace=False)],
                **{k: hrt[k] for k in ["valid", "test-dev", "test-challenge"]},
            ),
        )

    @staticmethod
    def _select_entities(
        train_hrt: np.ndarray, n_entity: int, entity_limit: int
    ) -> np.ndarray:
        """Return a bool mask over entities with at most `entity_limit` set.

        The `entity_limit // 2` most common heads and the `entity_limit // 2`
        most common tails of `train_hrt` are selected (the two sets may
        overlap, so fewer than `entity_limit` entities can result).
        """
        mask = np.full(n_entity, False)
        # Parenthesised on purpose: `-entity_limit // 2` is `(-entity_limit) // 2`,
        # which rounds away from zero and over-selects for odd limits.
        n_each = entity_limit // 2
        if n_each > 0:  # `[-0:]` would select everything
            for column in (0, 2):
                counts = np.bincount(train_hrt[:, column])
                mask[np.argsort(counts)[-n_each:]] = True
        return mask

    @classmethod
    def load_wikikg90mv2(
        cls, path: Path, seed: int, entity_limit: int = 1 << 63
    ) -> "RawData":
        """Load WikiKG90Mv2 from `path`, optionally truncated to fewer entities.

        When `entity_limit` is below the dataset's entity count, only the most
        common head/tail entities are kept, entity ids are renumbered densely
        and triples touching a dropped entity are removed. The test splits are
        only included when no truncation happens. `seed` selects the 15000
        training triples that form the "train" evaluation split.
        """
        data = ogb.lsc.WikiKG90Mv2Dataset(path)
        n_entity = data.num_entities
        train_hrt = data.train_hrt.astype(np.uint32)
        eval_hr_ = {}
        eval_hr_["valid"] = np.concatenate(
            [
                data.valid_dict["h,r->t"]["hr"],
                data.valid_dict["h,r->t"]["t"][:, np.newaxis],
            ],
            axis=1,
        ).astype(np.uint32)
        entity_features = data.entity_feat

        if entity_limit < n_entity:
            # Select a subset of entities (total <= entity_limit)
            entity_mask = cls._select_entities(
                data.train_hrt, data.num_entities, entity_limit
            )

            # Truncate & re-map dataset
            n_entity = int(np.sum(entity_mask))
            old_to_new_entity = np.full(data.num_entities, -1)
            old_to_new_entity[np.where(entity_mask)] = np.arange(n_entity)
            train_hrt = train_hrt[
                entity_mask[train_hrt[:, 0]] & entity_mask[train_hrt[:, 2]]
            ]
            train_hrt[:, 0] = old_to_new_entity[train_hrt[:, 0]]
            train_hrt[:, 2] = old_to_new_entity[train_hrt[:, 2]]
            eval_hr_["valid"] = eval_hr_["valid"][
                entity_mask[eval_hr_["valid"][:, 0]]
                & entity_mask[eval_hr_["valid"][:, 2]]
            ]
            eval_hr_["valid"][:, 0] = old_to_new_entity[eval_hr_["valid"][:, 0]]
            eval_hr_["valid"][:, 2] = old_to_new_entity[eval_hr_["valid"][:, 2]]
            entity_features = entity_features[np.where(entity_mask)]
        else:
            # Only add 'test' sets when entities haven't been truncated/remapped
            for name in ["test-dev", "test-challenge"]:
                eval_hr_[name] = data.test_dict(name)["h,r->t"]["hr"].astype(np.uint32)

        # After train_hrt has been truncated/remapped
        eval_hr_["train"] = train_hrt[
            np.random.RandomState(seed).choice(train_hrt.shape[0], size=15000)
        ]

        # Note: test first so that .astype() doesn't collapse the memmap
        if entity_features.dtype != np.float16:
            entity_features = entity_features.astype(np.float16)
        return cls(
            n_entity=n_entity,
            n_relation_type=data.num_relations,
            entity_features=entity_features,
            train_hrt=train_hrt,
            eval_hr_=eval_hr_,
        )

    @classmethod
    def load(cls, settings: "kge.Settings") -> "RawData":
        """Load or generate the dataset selected by `settings.data.dataset`.

        For WikiKG90Mv2 the dataset root is read from the `OGBWIKIKG_PATH`
        environment variable, with a built-in default path.

        Raises:
            ValueError: if the WikiKG90Mv2 directory is missing, or the
                dataset settings are of an unknown type.
        """
        if isinstance(settings.data.dataset, kge.WikiKg90Mv2Settings):
            path = Path(
                os.environ.get("OGBWIKIKG_PATH", "/localdata/research/datasets/ogb/lsc")
            )
            if not (path / "wikikg90m-v2").exists():
                raise ValueError(
                    f"Dataset 'wikikg90mv2' was not found at {path}.\n"
                    f" On the farm, try: mkdir -p {path}"
                    " && rsync -a --info=progress2 --chmod=D0770,F660"
                    f" /home/research-datasets/ogb/lsc/wikikg90m-v2/ {path}/wikikg90m-v2/"
                )
            return cls.load_wikikg90mv2(
                path,
                seed=settings.data.seed,
                entity_limit=settings.model.n_shard * (settings.model.n_entity - 1),
            )
        if isinstance(settings.data.dataset, kge.GeneratedDataSettings):
            return cls.generate(
                n_entity=(settings.model.n_entity - 1) * settings.model.n_shard - 1,
                n_relation_type=settings.model.n_relation_type,
                feature_size=settings.model.entity_feature_size,
                n_train=settings.data.dataset.n_train,
                n_eval=settings.data.dataset.n_eval,
                seed=settings.data.dataset.seed,
            )
        raise ValueError(f"Unknown dataset '{settings.data.dataset}'")
= train_steps_per_program_run 182 | self.settings = settings 183 | 184 | # Partition entities into shards 185 | self.random = np.random.RandomState(settings.seed) 186 | self.entity_to_idx, self.entity_to_shard = np.divmod( 187 | self.random.permutation(data.n_entity), n_shard 188 | ) 189 | self.entity_to_idx += 1 # index zero is for padding 190 | self.shard_to_count = np.bincount(self.entity_to_shard, minlength=n_shard) 191 | self.n_entity_per_shard = 1 + np.max(self.shard_to_count) 192 | self.shard_idx_to_entity = np.full( 193 | (n_shard, self.n_entity_per_shard), 1 << 31, dtype=np.uint32 194 | ) 195 | self.shard_idx_to_entity[self.entity_to_shard, self.entity_to_idx] = np.arange( 196 | data.n_entity 197 | ) 198 | 199 | # Indexed by shardpair=(head_shard, tail_shard) for sampling 200 | triple_to_head_shard = self.entity_to_shard[data.train_hrt[:, 0]] 201 | triple_to_tail_shard = self.entity_to_shard[data.train_hrt[:, 2]] 202 | 203 | if isinstance(self.settings.sampling_strategy, kge.CubicRootRelationSampling): 204 | triple_to_shardpair = ( 205 | n_shard * data.n_relation_type * triple_to_head_shard 206 | + data.n_relation_type * triple_to_tail_shard 207 | + data.train_hrt[:, 1] 208 | ) 209 | shardpair_shape = (n_shard * n_shard, data.n_relation_type) 210 | sort_idx = np.argsort(triple_to_shardpair) 211 | else: 212 | triple_to_shardpair = n_shard * triple_to_head_shard + triple_to_tail_shard 213 | shardpair_shape = (n_shard, n_shard) 214 | assert ( 215 | n_shard <= 16 216 | ), f"cannot use uint8 to speed up sorting when n_shard ({n_shard}) > 16" 217 | sort_idx = np.argsort(triple_to_shardpair.astype(np.uint8)) 218 | 219 | self.shardpair_to_count = np.bincount( 220 | triple_to_shardpair, 221 | minlength=np.prod(shardpair_shape), 222 | ).reshape(shardpair_shape) 223 | self.shardpair_to_offset = np.concatenate( 224 | [[0], np.cumsum(self.shardpair_to_count)[:-1]] 225 | ).reshape(shardpair_shape) 226 | 227 | if isinstance(self.settings.sampling_strategy, 
def _sharded_entity_features(self, features: np.ndarray) -> np.ndarray:
    """Scatter per-entity feature rows into a (shard, index, feature) array.

    Row `e` of `features` is written to
    `[self.entity_to_shard[e], self.entity_to_idx[e], :]`; every slot that no
    entity maps to stays zero. The result keeps the dtype of `features`.
    """
    shape = (self.n_shard, self.n_entity_per_shard, features.shape[-1])
    sharded = np.zeros(shape, dtype=features.dtype)
    sharded[self.entity_to_shard, self.entity_to_idx, :] = features
    return sharded
def entity_features(self, feature_size: int) -> np.ndarray:
    """Build the sharded entity feature array requested by the settings.

    The behaviour is selected by `self.settings.entity_feature_mapping`:

      "zero"              -- float16 zeros of shape
                             (n_shard, n_entity_per_shard, feature_size)
      "full"              -- the dataset features unchanged; their last
                             dimension must equal `feature_size`
      "random_projection" -- the dataset features multiplied by a random
                             (data_feature_size, feature_size) matrix drawn
                             from `self.entity_projection_seed`

    For "full" and "random_projection" the result is laid out per shard by
    `self._sharded_entity_features`.

    Raises ValueError for a size mismatch under "full", or for an unknown
    mapping name.
    """
    mapping = self.settings.entity_feature_mapping

    if mapping == "zero":
        zero_shape = (self.n_shard, self.n_entity_per_shard, feature_size)
        return np.zeros(zero_shape, dtype=np.float16)

    if mapping not in ("full", "random_projection"):
        raise ValueError(f"Unknown entity_feature_mapping: {mapping}")

    features = self.data.entity_features
    data_feature_size = features.shape[-1]

    if mapping == "full":
        if feature_size != data_feature_size:
            raise ValueError(
                "Entity_feature_mapping 'full' requires dataset "
                f"feature_size ({data_feature_size}) == model feature_size ({feature_size})"
            )
        return self._sharded_entity_features(features)

    # mapping == "random_projection"
    # Note - CPU float16 is slow, it might be faster to do this in float32
    generator = np.random.RandomState(self.entity_projection_seed)
    projection = generator.randn(data_feature_size, feature_size).astype(
        np.float16
    ) / np.sqrt(data_feature_size)
    return self._sharded_entity_features(features @ projection)
np.newaxis], shardpair_relat 328 | ].reshape(n_shard, n_shard, -1) 329 | shardpair_count = self.shardpair_to_count[ 330 | np.arange(n_shard * n_shard)[:, np.newaxis], shardpair_relat 331 | ].reshape(n_shard, n_shard, -1) 332 | sample_idx = ( 333 | ( 334 | shardpair_offset 335 | + self.random.randint( 336 | 1 << 63, 337 | size=(n_shard, n_shard, n_batch * self.positives_per_shardpair), 338 | ) 339 | % shardpair_count 340 | ) 341 | .reshape(n_shard, n_shard, n_batch, -1) 342 | .transpose(0, 2, 1, 3) 343 | ) 344 | else: 345 | sample_idx = ( 346 | self.shardpair_to_offset[:, np.newaxis, :, np.newaxis] 347 | + self.random.randint( 348 | 1 << 63, 349 | size=(n_shard, n_batch, n_shard, self.positives_per_shardpair), 350 | ) 351 | % self.shardpair_to_count[:, np.newaxis, :, np.newaxis] 352 | ) 353 | 354 | h, r, t = self.shardpair_to_flat_hrt[:, sample_idx] 355 | t_negative = 1 + ( 356 | self.random.randint( 357 | 1 << 63, size=(n_shard, n_batch, n_shard, self.negatives_per_shardpair) 358 | ) 359 | % self.shard_to_count[np.newaxis, np.newaxis, :, np.newaxis] 360 | ).astype(np.uint32) 361 | 362 | # Remote indices, deduplication & routing indices 363 | remote_duplicated = np.concatenate( 364 | [ 365 | h.reshape(n_shard * n_batch, -1), 366 | t.transpose(2, 1, 0, 3).reshape(n_shard * n_batch, -1), 367 | t_negative.transpose(2, 1, 0, 3).reshape(n_shard * n_batch, -1), 368 | ], 369 | axis=1, 370 | ) 371 | remote, gather = map(np.stack, zip(*[unique_pad(a) for a in remote_duplicated])) 372 | remote = remote.reshape(n_shard, n_batch, -1) 373 | head = gather[:, :batch_size].reshape(n_shard, n_batch, batch_size) 374 | a2a = np.concatenate( 375 | [ 376 | gather[:, batch_size : 2 * batch_size].reshape( 377 | n_shard, n_batch, n_shard, self.positives_per_shardpair 378 | ), 379 | gather[:, 2 * batch_size :].reshape( 380 | n_shard, n_batch, n_shard, self.negatives_per_shardpair 381 | ), 382 | ], 383 | axis=3, 384 | ) 385 | relation = r.reshape(n_shard, n_batch, batch_size) 386 | 387 | 
class DatasetWrapper(torch.utils.data.Dataset[Dict[str, np.ndarray]]):
    """Expose a `Dataset` batch sampler through the torch `Dataset` interface.

    Each item is a freshly sampled training batch; the requested index is
    ignored. Arrays are converted to int32 on the way out and
    `tensors_to_batch` converts them back to uint32 when rebuilding a
    `kge.Batch`.
    """

    def __init__(self, dataset: Dataset) -> None:
        self.ds = dataset

    def __len__(self) -> int:
        # Fixed nominal length; items are sampled, not indexed.
        # NOTE(review): presumably chosen so a DataLoader never runs out - confirm.
        return (1 << 31) - 1

    def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
        batch = self.ds.sample_batch()
        return {
            field: getattr(batch, field).astype(np.int32)
            for field in ("remote", "a2a", "head", "relation", "tail")
        }

    @staticmethod
    def tensors_to_batch(
        remote: torch.Tensor,
        a2a: torch.Tensor,
        head: torch.Tensor,
        relation: torch.Tensor,
        tail: torch.Tensor,
    ) -> kge.Batch:
        """Rebuild a `kge.Batch` of uint32 arrays from the given tensors."""

        def as_uint32(tensor: torch.Tensor) -> np.ndarray:
            return tensor.numpy().astype(np.uint32)

        return kge.Batch(
            remote=as_uint32(remote),
            head=as_uint32(head),
            relation=as_uint32(relation),
            a2a=as_uint32(a2a),
            tail=as_uint32(tail),
        )

    @staticmethod
    def worker_init_fn(worker_id: int) -> None:
        """Give the current worker's dataset copy its own random state.

        The state is seeded with `settings.seed + worker_id`, so workers with
        different ids use different seeds.
        """
        info = torch.utils.data.get_worker_info()  # type:ignore[no-untyped-call]
        wrapped = info.dataset.ds
        wrapped.random = np.random.RandomState(wrapped.settings.seed + worker_id)
@dataclass
class Dataset:
    """Represents a complete knowledge graph dataset of (head, relation, tail) triples."""

    n_entity: int
    n_relation_type: int
    triples: Dict[str, np.ndarray]  # {"train|valid": int64[n_triple, {h,r,t}]}

    @classmethod
    def build_wikikg2(cls, root: Path, out: Path, seed: int) -> None:
        """Build the OGB dataset into a simple .npz file, for faster loading.

        root -- directory handed to OGB as its dataset root

        out -- destination .npz path

        seed -- seeds the in-place row shuffle applied to each part
        """
        data = ogb.linkproppred.LinkPropPredDataset("ogbl-wikikg2", root=root)
        split = data.get_edge_split()
        rng = np.random.default_rng(seed)
        parts = {}
        for part in ["train", "valid"]:
            hrt = split[part]
            parts[part] = np.stack([hrt["head"], hrt["relation"], hrt["tail"]], axis=-1)
            # Shuffles whole (h, r, t) rows in place.
            rng.shuffle(parts[part])
        np.savez(
            out,
            n_entity=int(data[0]["num_nodes"]),
            n_relation_type=int(1 + np.max(data.graph["edge_reltype"])),
            **{f"part_{k}": v for k, v in parts.items()},
        )

    @classmethod
    def load(cls, path: Path) -> "Dataset":
        """Load a dataset from an .npz file saved by `Dataset.build_wikikg2`.

        Every array stored under a key starting with "part_" becomes an entry of
        `triples`, keyed by the remainder of the name.
        """
        prefix = "part_"
        # `np.load` on an .npz returns a lazy archive holding an open file handle;
        # the context manager closes it once all arrays have been read.
        with np.load(path) as data:
            return cls(
                n_entity=int(data["n_entity"]),
                n_relation_type=int(data["n_relation_type"]),
                triples={
                    # Strip only the leading marker, not later occurrences.
                    k[len(prefix) :]: data[k]
                    for k in data.files
                    if k.startswith(prefix)
                },
            )

    def sample(self, part: str, n: int, seed: Optional[int]) -> np.ndarray:
        """Draw a random sample of triples, without replacement.

        part -- key into `triples`

        n -- number of rows to draw (numpy raises ValueError if it exceeds the
             number of available rows)

        seed -- seed for the generator; None gives a non-deterministic sample
        """
        triples = self.triples[part]
        idx = np.random.default_rng(seed).choice(
            triples.shape[0], size=n, replace=False
        )
        return triples[idx]
119 | 120 | Generates batches of numpy arrays containing: 121 | 122 | head : int64[n_batch_per_call, n_shard, n_shard, n_positive] 123 | 124 | relation : int64[n_batch_per_call, n_shard, n_shard, n_positive] 125 | 126 | src_tails : int64[n_batch_per_call, n_shard, n_shard, n_positive + n_negative] 127 | """ 128 | 129 | def __init__( 130 | self, 131 | triples: np.ndarray, 132 | sharding: Sharding, 133 | n_positive: int, 134 | n_negative: int, 135 | n_batch_per_call: int, 136 | seed: Optional[int], 137 | ): 138 | self.n_shard = sharding.n_shard 139 | self.n_positive = n_positive 140 | self.n_negative = n_negative 141 | self.n_batch_per_call = n_batch_per_call 142 | self.rng = np.random.default_rng(seed=seed) 143 | 144 | # Entity count per shard, useful for sampling negatives 145 | self.shard_to_count = np.bincount( 146 | sharding.entity_to_shard, minlength=self.n_shard 147 | ) 148 | 149 | # Build a mapping (shard(head), shard(tail)) -> [(idx(h), r, idx(t))] 150 | triple_to_shardpair = ( 151 | self.n_shard * sharding.entity_to_shard[triples[:, 0]] 152 | + sharding.entity_to_shard[triples[:, 2]] 153 | ) 154 | self.shardpair_to_count = np.bincount( 155 | triple_to_shardpair, minlength=self.n_shard * self.n_shard 156 | ).reshape(self.n_shard, self.n_shard) 157 | self.shardpair_to_offset = np.concatenate( 158 | [[0], np.cumsum(self.shardpair_to_count)[:-1]] 159 | ).reshape(self.n_shard, self.n_shard) 160 | train_hrt_sorted = triples[np.argsort(triple_to_shardpair)] 161 | self.flat_shardpair_to_hrt = np.stack( 162 | [ 163 | sharding.entity_to_idx[train_hrt_sorted[:, 0]], 164 | train_hrt_sorted[:, 1], 165 | sharding.entity_to_idx[train_hrt_sorted[:, 2]], 166 | ], 167 | axis=0, 168 | ) 169 | 170 | def __iter__(self) -> "BatchSampler": 171 | return self 172 | 173 | def __next__(self) -> Dict[str, np.ndarray]: 174 | # Use a flattened sampling trick to draw a uniform random sample of 175 | # training triples, stratified by (shard(head), shard(tail)) 176 | sample_idx = ( 177 | 
def l1_distance(a: T.Tensor, b: T.Tensor) -> T.Tensor:
    """Batched pairwise L1 (Manhattan) distance between two sets of vectors.

    a -- float[*group_shape, n_a, embedding_size]

    b -- float[*group_shape, n_b, embedding_size]

    returns -- float[*group_shape, n_a, n_b]

    Off-IPU this defers to `torch.cdist`; on IPU the distance is written out as
    an explicit broadcast-subtract, abs and sum.
    """
    if not poptorch.isRunningOnIpu():
        return T.cdist(a, b, p=1.0)
    # Broadcast to [*group_shape, n_a, n_b, embedding_size]
    a_rows = a[..., :, None, :]
    b_cols = b[..., None, :, :]
    return T.sum(T.abs(a_rows - b_cols), dim=-1)
def sampled_softmax_cross_entropy(
    score: T.Tensor, label: T.Tensor, total_classes: int
) -> T.Tensor:
    """Softmax cross-entropy with a correction for uniformly sampled negatives.

    Every non-target candidate has its score raised by the constant
    `log(total_classes) - log(n_candidate - 1)` before the usual softmax
    cross-entropy is taken; the target's score is left untouched. This assumes
    negatives are drawn from a flat distribution over `total_classes`.

    score -- float[batch_size, candidate]

    label -- int[batch_size]

    total_classes -- int -- total number of classes that negative samples are drawn from

    returns -- float[] -- sum softmax cross entropy loss
    """
    n_candidate = score.shape[1]
    # The adjustment for `class` is `log(1 / E(count(candidate == class)))`, constant
    # over all negative classes and zero for the target class.
    log_ratio = np.log(total_classes) - np.log(n_candidate - 1)
    adjustment = T.tensor(log_ratio, device=score.device, dtype=score.dtype)
    candidate_ids = T.arange(n_candidate, device=label.device, dtype=label.dtype)
    is_negative = candidate_ids[None, :] != label[:, None]
    adjusted_score = score + adjustment * is_negative
    return T.nn.functional.cross_entropy(adjusted_score, label, reduction="sum")
321 | 322 | Forward (CPU): 323 | head -- int[n_shard, n_shard, n_positive] 324 | relation -- int[n_shard, n_shard, n_positive] 325 | src_tails -- int[n_shard, n_shard, n_tail] 326 | """ 327 | 328 | def __init__(self, sharding: Sharding, n_relation_type: int, embedding_size: int): 329 | super().__init__() 330 | self.sharding = sharding 331 | self.embedding_size = embedding_size 332 | self.score = transe_score 333 | self.entity_embedding = T.nn.Parameter( 334 | T.FloatTensor( 335 | sharding.n_shard, 336 | sharding.max_entity_per_shard, 337 | embedding_size, 338 | ) 339 | ) 340 | T.nn.init.normal_(self.entity_embedding, std=1 / embedding_size) 341 | self.relation_embedding = T.nn.Parameter( 342 | T.FloatTensor(n_relation_type, embedding_size) 343 | ) 344 | T.nn.init.normal_(self.relation_embedding, std=1 / embedding_size) 345 | 346 | def forward( 347 | self, head: T.Tensor, relation: T.Tensor, src_tails: T.Tensor 348 | ) -> T.Tensor: 349 | if poptorch.isRunningOnIpu(): 350 | return self.forward_ipu(head, relation, src_tails) 351 | return self.forward_cpu(head, relation, src_tails) 352 | 353 | def forward_cpu( 354 | self, head: T.Tensor, relation: T.Tensor, src_tails: T.Tensor 355 | ) -> T.Tensor: 356 | n_shard, _, n_positive = head.shape 357 | _, _, n_tails = src_tails.shape 358 | shards = T.arange(n_shard)[:, None, None] 359 | score = self.score( 360 | head=self.entity_embedding[shards, head] 361 | .float() 362 | .view(n_shard, -1, self.embedding_size), 363 | relation=self.relation_embedding[relation, :] 364 | .float() 365 | .view(n_shard, -1, self.embedding_size), 366 | tail=self.entity_embedding[shards, src_tails] 367 | .float() 368 | .permute(1, 0, 2, 3) 369 | .reshape(n_shard, -1, self.embedding_size), 370 | ) 371 | true_tail_idx = ( 372 | (T.arange(n_shard)[:, None] * n_tails + T.arange(n_positive)[None, :]) 373 | .view(-1) 374 | .repeat(n_shard) 375 | ) 376 | return sampled_softmax_cross_entropy( 377 | score.view(-1, n_shard * n_tails), 378 | true_tail_idx, 379 
def mrr(predictions: np.ndarray, target: np.ndarray) -> float:
    """Mean reciprocal rank (MRR) of each target within its ranked prediction row.

    A target found at column `c` of its row contributes `1 / (1 + c)`; a target
    that is absent from its row contributes nothing. The total is divided by
    the number of rows.

    predictions -- int[batch_size, n_predictions]

    target -- int[batch_size]

    returns -- float -- mean over the batch
    """
    is_hit = predictions == target[:, None]
    hit_rows, hit_cols = np.nonzero(is_hit)
    n_distinct_rows = len(np.unique(hit_rows))
    assert (
        len(hit_rows) == n_distinct_rows
    ), "target should never appear twice in predictions"
    reciprocal_ranks = 1 / (1 + hit_cols)
    return np.sum(reciprocal_ranks) / predictions.shape[0]
def create_train_step(
    model: "Model",
    optimiser: str,
    lr: float,
    sgdm_momentum: Optional[float],
    weight_decay: float,
    device: str,
    device_iterations: int,
) -> Callable[..., T.Tensor]:
    """Create a 'stepper function' for training, with the same interface across {CPU, IPU}.

    Note: returns the final loss (not summed over device_iterations).

    Usage:

        stepper = create_train_step(...)
        for batch in batches:
            stepper(batch)

    optimiser -- {"adamw" | "sgdm"}

    sgdm_momentum -- momentum passed to the SGD optimiser (only read for "sgdm")

    device -- {"ipu" | "cpu"} -- note that the CPU implementation is slow, only included for testing

    device_iterations -- int -- the number of optimiser steps to take for each call
                         (this must match the batch shape passed to `step()`)

    returns -- fn(batch) -- training `step()` function
               -- batch is a dict of numpy arrays {head, relation, src_tails}, as
                  generated by `BatchSampler`

    Raises AssertionError for an unknown `optimiser` or `device`.
    """
    if device == "cpu":
        if optimiser == "adamw":
            opt = T.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        elif optimiser == "sgdm":
            opt = T.optim.SGD(
                model.parameters(),
                lr=lr,
                momentum=sgdm_momentum,
                weight_decay=weight_decay,
            )
        else:
            assert False, f"unexpected optimiser {optimiser!r}"

        def step(batch: Dict[str, np.ndarray]) -> T.Tensor:
            for i in range(device_iterations):
                # Clear gradients before every backward pass. `backward()`
                # accumulates into `.grad`, so clearing only once per call would
                # make each update after the first apply the sum of all earlier
                # gradients instead of being an independent optimiser step.
                opt.zero_grad()
                loss = model(**{k: T.tensor(v[i]) for k, v in batch.items()})
                loss.backward()
                opt.step()
            return loss

        return step

    if device == "ipu":
        options = poptorch.Options()
        options.replication_factor = model.sharding.n_shard
        options.deviceIterations(device_iterations)
        options.Precision.enableStochasticRounding(True)
        if "POPLAR_EXECUTABLE_CACHE_DIR" in os.environ:
            options.enableExecutableCaching(
                str(Path(os.environ["POPLAR_EXECUTABLE_CACHE_DIR"]) / "kge_training")
            )

        # Add a memory saving optimisation pattern. This removes an unnecessary
        # entity_embedding gradient all-reduce, which is a no-op since it is fully
        # sharded across replicas.
        ctypes.cdll.LoadLibrary(Path(__file__).parent / "build/custom_ops.so")
        options._popart.setPatterns(dict(RemoveAllReducePattern=True))

        # All parameters must share one dtype; the unpacking fails otherwise.
        (dtype,) = {p.dtype for p in model.parameters()}
        if optimiser == "adamw":
            opt = poptorch.optim.AdamW(
                model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                accum_type=T.float32,
                first_order_momentum_accum_type=dtype,
                second_order_momentum_accum_type=T.float32,
            )
        elif optimiser == "sgdm":
            opt = poptorch.optim.SGD(
                model.parameters(),
                lr=lr,
                momentum=sgdm_momentum,
                weight_decay=weight_decay,
                accum_type=T.float32,
                velocity_accum_type=dtype,
            )
        else:
            assert False, f"unexpected optimiser {optimiser!r}"

        ipu_model = poptorch.trainingModel(model, options=options, optimizer=opt)

        # Set `entity_embedding` as a fully sharded parameter across replicas. This
        # disables the gradient all-reduce, and allows the parameters to be initialised
        # and read separately, so that they are treated as separate parameters.
        #
        # (Compare with `relation_embedding`, which keeps default settings, a replicated
        # parameter across replicas. This means that the gradients across replica are
        # all-reduced together, such that the parameter value is kept synchronised.)
        ipu_model.entity_embedding.replicaGrouping(
            poptorch.CommGroupType.NoGrouping,
            0,
            poptorch.VariableRetrievalMode.OnePerGroup,
        )

        def step(batch: Dict[str, np.ndarray]) -> T.Tensor:
            # PopTorch expects device_iterations & shard to be flattened
            # into a single dimension
            return ipu_model(
                **{
                    k: T.tensor(np.ascontiguousarray(v), dtype=T.int32).flatten(
                        end_dim=1
                    )
                    for k, v in batch.items()
                }
            )

        return step

    assert False, f"device '{device}' unexpected, expected {{'cpu', 'ipu'}}"
637 | 638 | yields dict( 639 | step -- int -- number of optimiser update steps taken 640 | example -- int -- number of positive examples consumed so far 641 | loss -- Optional[float] -- mean loss for this step (missing in final validation dict) 642 | mrr -- Optional[float] -- validation MRR (periodically, based on `valid_interval`) 643 | elapsed -- float -- elapsed seconds since last dict (includes training and validaiton time) 644 | ) 645 | """ 646 | n_best = 10 647 | t0 = time.time() 648 | step = create_train_step( 649 | model=model, 650 | lr=lr, 651 | optimiser=optimiser, 652 | sgdm_momentum=sgdm_momentum, 653 | weight_decay=weight_decay, 654 | device=device, 655 | device_iterations=batch_sampler.n_batch_per_call, 656 | ) 657 | n_loop = n_step // batch_sampler.n_batch_per_call 658 | valid_interval_loop = valid_interval // batch_sampler.n_batch_per_call 659 | examples_per_step = batch_sampler.n_positive * batch_sampler.n_shard**2 660 | for n in range(n_loop + 1): 661 | record = dict(step=n * batch_sampler.n_batch_per_call) 662 | record["example"] = record["step"] * examples_per_step 663 | if n and (n % valid_interval_loop == 0) or n == n_loop: 664 | record["mrr"] = evaluate_mrr(model, valid_triples, n_best=n_best) 665 | if n < n_loop: 666 | record["loss"] = float(T.sum(step(next(batch_sampler)))) / examples_per_step 667 | record["elapsed"] = time.time() - t0 668 | t0 += record["elapsed"] 669 | yield record 670 | --------------------------------------------------------------------------------