├── .ci └── view_component_trigger │ ├── Jenkinsfile │ └── jobs.groovy ├── .clang-format ├── .clang-tidy ├── .github ├── CODEOWNERS └── workflows │ └── apply_linters.yml ├── .gitignore ├── .linters └── activate_buildenv.sh ├── .pre-commit-config.yaml ├── .pylintrc ├── .style.yapf ├── CMakeLists.txt ├── License.txt ├── MANIFEST.in ├── README.md ├── config.buildenv.py ├── docs ├── common │ ├── _static │ │ └── css │ │ │ └── custom_rtd.css │ ├── conf.py │ ├── custom_dic │ ├── graphcorelogo-html.png │ └── graphcorelogo-pdf.png ├── poptorch_geometric │ ├── common │ │ └── conf.py │ └── user_guide │ │ ├── index.rst │ │ ├── installation.rst │ │ ├── intro.rst │ │ ├── legal.rst │ │ ├── performance.rst │ │ ├── reference.rst │ │ ├── supported_operations.rst │ │ └── tutorials.rst └── user_guide │ ├── CMakeLists.txt │ ├── IPU-pipeline.jpg │ ├── api.py │ ├── batching.rst │ ├── buffers.py │ ├── comm-group-types.png │ ├── debugging.py │ ├── debugging.rst │ ├── device_iterations.py │ ├── error_handling.py │ ├── example.rst │ ├── experimental.rst │ ├── hostio_optimisation.rst │ ├── index.rst │ ├── inferenceModel.py │ ├── installation.rst │ ├── intro.rst │ ├── legal.rst │ ├── mnist.py │ ├── no-buffering-profile.png │ ├── overview.rst │ ├── phased_execution.py │ ├── pipeline_simple.py │ ├── pipelined_execution.png │ ├── poptorch.conf │ ├── poptorch_training_simple.py │ ├── precompilation.py │ ├── pytorch-software-stack.png │ ├── pytorch_to_poptorch.rst │ ├── reference.rst │ ├── replica_grouped_weights.py │ ├── sharded_execution.png │ ├── stages_summary.png │ ├── sumAnchorReturnType.py │ ├── supported_ops.rst │ ├── trainingModel.py │ ├── with-buffering-overlap-profile.png │ └── with-buffering-profile.png ├── examples ├── CMakeLists.txt ├── bert_ipu.py ├── lstm.py ├── mnist.py ├── simple_adder.py └── zeus.jpg ├── popart_compiler ├── CMakeLists.txt ├── include │ └── popart_compiler │ │ ├── CodeletsCompilation.hpp │ │ ├── Compiler.hpp │ │ ├── CompilerOperationMacros.inc.hpp │ │ ├── 
ManuallyAddedOperations.inc.hpp │ │ ├── SupportedOperations.inc.hpp │ │ └── Utils.hpp ├── source │ ├── CodeletsCompilation.cpp │ ├── Compiler.cpp │ ├── CompilerImpl.cpp │ ├── SessionOptions.cpp │ ├── Utils.cpp │ ├── custom_operations │ │ ├── Embedding.cpp │ │ ├── FastGatherLastDim.cpp │ │ ├── FastGatherLastDim.hpp │ │ ├── FastGatherLastDimBwdCodelets.inc.cpp │ │ ├── FastGatherLastDimFwdCodelets.inc.cpp │ │ ├── HostOp.cpp │ │ ├── TorchSoftplus.cpp │ │ ├── TorchSoftplus.hpp │ │ ├── UpsampleBilinear2d.cpp │ │ └── UpsampleBilinear2dCodelets.inc.cpp │ └── include │ │ └── popart_compiler │ │ ├── CompilerImpl.hpp │ │ ├── CompilerOptions.hpp │ │ ├── CustomOps.hpp │ │ ├── MultiConvBuilder.hpp │ │ └── SessionOptionsImpl.hpp └── types │ └── include │ └── popart_compiler │ ├── CompilerTypes.hpp │ └── PopartEnums.hpp ├── poptorch ├── CMakeLists.txt ├── include │ └── poptorch │ │ ├── DispatchTracer.hpp │ │ ├── InplaceOps.hpp │ │ ├── LowerToPopart.hpp │ │ ├── LowerToPopartFactories.hpp │ │ ├── PoplarExecutable.hpp │ │ ├── SessionOptionsParser.hpp │ │ └── Utils.hpp └── source │ ├── AddDetachOperations.cpp │ ├── AddSubgraphConnectionNodes.cpp │ ├── AliasProcessing.cpp │ ├── CPUOffloadingCleanUp.cpp │ ├── CompilerOps.cpp.inc │ ├── ErrorOnUnsupportedAten.cpp │ ├── FixupSetAvailableMemory.cpp │ ├── GNNOptimizations.cpp │ ├── GatherWithExpandedIndicesOptimization.cpp │ ├── ImplicitCasting.cpp │ ├── InplaceOps.cpp │ ├── LowerToPopart.cpp │ ├── LowerToPopartFactories.cpp │ ├── OpBuilder.cpp │ ├── OverlappedIO.cpp │ ├── PopartCanonicalization.cpp │ ├── PopartLateCanonicalization.cpp │ ├── PoplarExecutable.cpp │ ├── PoptorchStaticInit.hpp │ ├── PoptorchSymbols.cpp │ ├── PoptorchSymbols.hpp │ ├── RemoveSurplusIdentityLosses.cpp │ ├── RequiresGrad.cpp │ ├── SessionOptionsParser.cpp │ ├── Utils.cpp │ ├── dispatch_tracer │ ├── CMakeLists.txt │ ├── CommonHelperFunctions.cpp │ ├── CommonHelperFunctions.hpp │ ├── InplaceAliasMapper.cpp │ ├── InplaceAliasMapper.hpp │ ├── README.md │ ├── 
RegisterAtenOverloads.cpp │ ├── RegisterMetaOps.cpp.inc │ ├── RegisterOptionalAtenOps.cpp.inc │ ├── Tensor.cpp │ ├── Tensor.hpp │ ├── TypeInferenceHandler.cpp │ ├── TypeInferenceHandler.hpp │ ├── ValueMapper.cpp │ ├── ValueMapper.hpp │ └── dispatchers │ │ ├── IDispatch.cpp │ │ ├── IDispatch.hpp │ │ ├── JitDispatch.cpp │ │ └── JitDispatch.hpp │ ├── include │ └── poptorch │ │ ├── AliasProcessing.hpp │ │ ├── CompilerOps.inc.hpp │ │ ├── ImplicitCasting.hpp │ │ ├── InplaceOpsPyTorch.hpp_nolint │ │ ├── OpBuilder.hpp │ │ ├── OverlappedIO.hpp │ │ ├── PopartCanonicalization.hpp │ │ ├── RequiresGrad.hpp │ │ └── TypeAndConstantCanonicalization.hpp │ ├── popart_canonicalization │ ├── ActivationOps.cpp │ ├── ArithmeticOps.cpp │ ├── AtenHandlers.gen.cpp │ ├── BilinearOps.cpp │ ├── BitwiseOps.cpp │ ├── BlasOps.cpp │ ├── ConstantOps.cpp │ ├── ConvolutionOps.cpp │ ├── CustomOps.cpp │ ├── DistanceOps.cpp │ ├── DropoutOps.cpp │ ├── EinsumOp.cpp │ ├── EinsumOp.hpp │ ├── EmbeddingOps.cpp │ ├── IndexOps.cpp │ ├── LossOps.cpp │ ├── NormalizationOps.cpp │ ├── OtherOps.cpp │ ├── PoolingOps.cpp │ ├── PopartCanonicalizationUtils.cpp │ ├── PopartCanonicalizationUtils.hpp │ ├── PoptorchHandlers.gen.cpp │ ├── PyGTorchScatterOps.cpp │ ├── PyGTorchSplineConvOps.cpp │ ├── RNNOps.cpp │ ├── RandomSamplingOps.cpp │ ├── ReduceOps.cpp │ ├── ReshapeOps.cpp │ ├── ScatterReduction.cpp │ ├── ScatterReduction.hpp │ ├── SliceOps.cpp │ ├── SoftmaxOps.cpp │ ├── TensorOps.cpp │ └── pyg_torch_cluster │ │ ├── FpsOp.cpp │ │ ├── GridOp.cpp │ │ └── NearestOp.cpp │ └── type_and_constant_canonicalization │ ├── AddListNumElements.cpp │ ├── CanonicaliseConstants.cpp │ ├── CastUnsupportedInputs.cpp │ ├── CheckAndChangeOutputTypes.cpp │ ├── EvaluateConstexprs.cpp │ └── MakeConstantIntParams.cpp ├── poptorch_compiler └── pytorch_bridge │ ├── CMakeLists.txt │ ├── IpuSession.cpp │ └── include │ └── pytorch_bridge │ ├── CompilerOptions.hpp │ ├── CompilerTypes.hpp │ ├── DebugInfo.hpp │ └── IpuSession.hpp ├── poptorch_err ├── 
CMakeLists.txt ├── exception_info │ └── poptorch_err │ │ └── ExceptionInfo.hpp ├── include │ └── poptorch_err │ │ └── ExceptionHandling.hpp └── source │ └── ExceptionHandling.cpp ├── poptorch_geometric ├── CMakeLists.txt ├── License.txt ├── MANIFEST.in ├── README.md ├── config.buildenv.py ├── poptorch_geometric_third_party_licenses.txt ├── pyproject.toml ├── python │ ├── CMakeLists.txt │ ├── __init__.py │ ├── cluster_loader.py │ ├── collate.py │ ├── common.py │ ├── dataloader.py │ ├── fixed_size_options.py │ ├── masker.py │ ├── neighbor_loader.py │ ├── ops │ │ ├── __init__.py │ │ ├── aggregation_base.py │ │ ├── cluster_gcn_conv.py │ │ ├── hetero_linear.py │ │ ├── instance_norm.py │ │ ├── knn.py │ │ ├── knn_graph.py │ │ ├── knn_interpolate.py │ │ ├── mf_conv.py │ │ └── radius.py │ ├── override.py │ ├── py.typed │ ├── pyg_cluster_loader.py │ ├── pyg_collate.py │ ├── pyg_dataloader.py │ ├── stream_packing_sampler.py │ ├── types.py │ └── utils.py ├── requirements.txt ├── setup.cfg └── setup.py ├── poptorch_logging ├── CMakeLists.txt ├── include │ └── poptorch_logging │ │ ├── Error.hpp │ │ ├── Logging.hpp │ │ ├── LoggingLight.hpp │ │ └── Tracepoint.hpp └── source │ ├── Error.cpp │ ├── Logging.cpp │ └── Tracepoint.cpp ├── poptorch_third_party_licenses.txt ├── pyproject.toml ├── python ├── CMakeLists.txt ├── __init__.py ├── _args_parser.py ├── _dataloader.py ├── _impl.py ├── _logging.py ├── _optimizer_attributes.py ├── _options_config.py ├── _options_impl.py ├── _poplar_executor.py ├── _poptorch_data.py ├── _printing.py ├── _utils.py ├── enums.py ├── ops.py ├── optim.py ├── options.py ├── poptorch.cpp ├── profiling.py ├── py.typed └── testing.py ├── requirements.txt ├── scripts ├── PopAtenHandlers.py ├── PopParse.py ├── PopTorchHandlers.py ├── __init__.py ├── apply_linters.py ├── check_spelling.py ├── create_buildenv.py ├── docs_build.py ├── download_external_datasets.py ├── enable.sh.in ├── generate_poppyg_package.py ├── generate_python_package.py ├── popgen │ ├── 
__init__.py │ ├── api.py │ ├── generator.py │ ├── helpers.py │ ├── onnx.py │ ├── operatorfactory.py │ ├── poptorch.py │ ├── registry.py │ ├── transform.py │ └── values.py ├── set_version.py └── utils │ └── _utils.py ├── setup.cfg ├── setup.py ├── tests ├── .gitignore ├── CMakeLists.txt ├── activations_test.py ├── attach_detach_test.py ├── attach_detach_wait_for_ipu_test.py ├── batching_test.py ├── bert_small_and_medium_test.py ├── blas_test.py ├── bool_support_test.py ├── buffers_test.py ├── conftest.py ├── convs_test.py ├── cpp │ ├── CMakeLists.txt │ └── GNNOptimizationsTest.cpp ├── cpu_op_test.py ├── ctc_decoder_test.py ├── custom_loss_test.py ├── custom_ops │ ├── CMakeLists.txt │ ├── custom_add_scalar_op.cpp │ ├── custom_add_scalar_vec_op.cpp │ ├── custom_add_vec_scalar_mul_op.cpp │ ├── custom_cube_op.cpp │ ├── custom_leaky_relu_op.cpp │ ├── custom_many_attribute_op.cpp │ ├── custom_reduce_op.cpp │ └── custom_three_input_reduce_op.cpp ├── custom_ops_attributes_test.py ├── custom_ops_test.py ├── dataloader_test.py ├── debug_tensors_test.py ├── distance_ops_test.py ├── exception_test.py ├── fine_tuning_test.py ├── functional_test.py ├── generate_test_file.py ├── gnn │ ├── .gitignore │ ├── benchgnn │ │ ├── README.md │ │ ├── benchgnn.py │ │ ├── datasets.py │ │ ├── models.py │ │ ├── requirements.txt │ │ └── utils.py │ ├── benchgnn_ops │ │ ├── README.md │ │ ├── benchgnn_ops.py │ │ ├── builder.py │ │ ├── example_configs │ │ │ ├── common.yaml │ │ │ ├── scatter_testcase1.yaml │ │ │ └── scatter_testcase2.yaml │ │ ├── metrics.py │ │ ├── ops.py │ │ └── requirements.txt │ ├── conftest.py │ ├── nn │ │ ├── aggr │ │ │ ├── aggr_utils.py │ │ │ ├── conftest.py │ │ │ ├── test_attention.py │ │ │ ├── test_basic.py │ │ │ ├── test_deep_sets.py │ │ │ ├── test_equilibrium.py │ │ │ ├── test_fused.py │ │ │ ├── test_gmt.py │ │ │ ├── test_gru.py │ │ │ ├── test_lstm.py │ │ │ ├── test_mlp_aggr.py │ │ │ ├── test_multi.py │ │ │ ├── test_quantile.py │ │ │ ├── test_scaler.py │ │ │ ├── 
test_set2set.py │ │ │ ├── test_set_transformer.py │ │ │ └── test_sort.py │ │ ├── conftest.py │ │ ├── conv │ │ │ ├── conv_utils.py │ │ │ ├── test_agnn_conv.py │ │ │ ├── test_antisymmetric_conv.py │ │ │ ├── test_appnp.py │ │ │ ├── test_arma_conv.py │ │ │ ├── test_cg_conv.py │ │ │ ├── test_cheb_conv.py │ │ │ ├── test_cluster_gcn_conv.py │ │ │ ├── test_dna_conv.py │ │ │ ├── test_edge_conv.py │ │ │ ├── test_eg_conv.py │ │ │ ├── test_fa_conv.py │ │ │ ├── test_feast_conv.py │ │ │ ├── test_film_conv.py │ │ │ ├── test_gat_conv.py │ │ │ ├── test_gated_graph_conv.py │ │ │ ├── test_gatv2_conv.py │ │ │ ├── test_gcn2_conv.py │ │ │ ├── test_gcn_conv.py │ │ │ ├── test_gen_conv.py │ │ │ ├── test_general_conv.py │ │ │ ├── test_gin_conv.py │ │ │ ├── test_gmm_conv.py │ │ │ ├── test_gps_conv.py │ │ │ ├── test_graph_conv.py │ │ │ ├── test_gravnet_conv.py │ │ │ ├── test_han_conv.py │ │ │ ├── test_heat_conv.py │ │ │ ├── test_hetero_conv.py │ │ │ ├── test_hgt_conv.py │ │ │ ├── test_hypergraph_conv.py │ │ │ ├── test_le_conv.py │ │ │ ├── test_lg_conv.py │ │ │ ├── test_mf_conv.py │ │ │ ├── test_nn_conv.py │ │ │ ├── test_pan_conv.py │ │ │ ├── test_pdn_conv.py │ │ │ ├── test_pna_conv.py │ │ │ ├── test_point_conv.py │ │ │ ├── test_point_gnn_conv.py │ │ │ ├── test_point_transformer_conv.py │ │ │ ├── test_ppf_conv.py │ │ │ ├── test_res_gated_graph_conv.py │ │ │ ├── test_rgat_conv.py │ │ │ ├── test_rgcn_conv.py │ │ │ ├── test_sage_conv.py │ │ │ ├── test_sg_conv.py │ │ │ ├── test_signed_conv.py │ │ │ ├── test_simple_conv.py │ │ │ ├── test_spline_conv.py │ │ │ ├── test_ssg_conv.py │ │ │ ├── test_supergat_conv.py │ │ │ ├── test_tag_conv.py │ │ │ ├── test_transformer_conv.py │ │ │ ├── test_wl_conv.py │ │ │ ├── test_wl_conv_continuous.py │ │ │ └── test_x_conv.py │ │ ├── dense │ │ │ ├── dense_utils.py │ │ │ └── test_convs.py │ │ ├── functional │ │ │ ├── test_bro.py │ │ │ └── test_gini.py │ │ ├── kge │ │ │ ├── kge_utils.py │ │ │ ├── test_complex.py │ │ │ ├── test_distmult.py │ │ │ ├── test_rotate.py │ │ │ 
└── test_transe.py │ │ ├── nn_utils.py │ │ ├── norm │ │ │ ├── norm_utils.py │ │ │ ├── test_batch_norm.py │ │ │ ├── test_diff_group_norm.py │ │ │ ├── test_graph_norm.py │ │ │ ├── test_graph_size_norm.py │ │ │ ├── test_instance_norm.py │ │ │ ├── test_layer_norm.py │ │ │ ├── test_mean_subtraction_norm.py │ │ │ ├── test_msg_norm.py │ │ │ └── test_pair_norm.py │ │ ├── pool │ │ │ ├── pool_utils.py │ │ │ ├── test_asap.py │ │ │ ├── test_avg_pool.py │ │ │ ├── test_consecutive.py │ │ │ ├── test_decimation.py │ │ │ ├── test_edge_pool.py │ │ │ ├── test_fps.py │ │ │ ├── test_glob.py │ │ │ ├── test_graclus.py │ │ │ ├── test_max_pool.py │ │ │ ├── test_mem_pool.py │ │ │ ├── test_pan_pool.py │ │ │ ├── test_pool_knn.py │ │ │ ├── test_radius.py │ │ │ ├── test_sag_pool.py │ │ │ ├── test_select_topk.py │ │ │ ├── test_topk_pool.py │ │ │ └── test_voxel_grid.py │ │ ├── test_linear.py │ │ ├── test_loss.py │ │ ├── test_mish.py │ │ ├── test_sequential.py │ │ └── unpool │ │ │ └── test_interpolate.py │ ├── ops │ │ ├── test_knn.py │ │ ├── test_knn_graph.py │ │ ├── test_knn_interpolate.py │ │ ├── test_nearest.py │ │ ├── test_radius_op.py │ │ ├── test_spline_conv_ops.py │ │ └── test_to_dense_batch.py │ ├── test_basic_gnn.py │ ├── test_cluster_loader.py │ ├── test_collate.py │ ├── test_dataloader.py │ ├── test_encoding.py │ ├── test_fixed_size_options.py │ ├── test_masker.py │ ├── test_model_args.py │ ├── test_neighbor_loader.py │ ├── test_register_custom_args.py │ ├── test_stream_packing_sampler.py │ └── utils.py ├── grouping_scatters_gathers_test.py ├── gru_test.py ├── half_float_test.py ├── half_test.py ├── helpers.py ├── hooks_test.py ├── if_test.py ├── index_ops_test.py ├── inplace_test.py ├── inputs_test.py ├── io_performance_test.py ├── ipu_print_tensor_test.py ├── loop_test.py ├── losses_test.py ├── lstm_test.py ├── math_ops_test.py ├── misc_nn_layers_test.py ├── misc_test.py ├── multiconv_test.py ├── non_contiguous_tensors_test.py ├── norms_test.py ├── ops_test.py ├── optimizers_test.py 
├── options_test.py ├── other_ops_test.py ├── outputs_test.py ├── overlapped_io_test.py ├── phased_execution_test.py ├── pipelining_test.py ├── pooling_and_padding_test.py ├── popdist_test.py ├── poplar_executor_test.py ├── precompilation_test.py ├── pyg_torch_scatter_test.py ├── random_sampling_test.py ├── reduce_ops_test.py ├── replicated_graph_test.py ├── requires_grad_test.py ├── rnn_test.py ├── sharding_test.py ├── slice_test.py ├── tensor_ops_test.py ├── test_doc_urls.py ├── test_perf_counters.py ├── timeout_handler.py ├── torch_nn_test.py ├── torchvision_inference_test.py ├── type_support_test.py └── weights_writing_test.py └── version.json /.ci/view_component_trigger/Jenkinsfile: -------------------------------------------------------------------------------- 1 | @Library('sw-jenkins-library@view-component-trigger') _ 2 | 3 | viewComponentTrigger(jobsFilepath: '.ci/view_component_trigger/jobs.groovy') 4 | -------------------------------------------------------------------------------- /.ci/view_component_trigger/jobs.groovy: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | job: '/poptorch/poptorch_pr', 4 | parameters: [ 5 | string(name: 'GCCI_BRANCH', value: 'mk2-main') 6 | ] 7 | ], 8 | ] 9 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | Language: Cpp 2 | BasedOnStyle: llvm 3 | -------------------------------------------------------------------------------- /.clang-tidy: -------------------------------------------------------------------------------- 1 | Checks: '*, -abseil*, -altera*, -android*, -cppcoreguidelines*, -cert*, -modernize*, -boost*, -google*, -fuchsia*, -hicpp*, -objc*, -llvm*, -bugprone-exception-escape, -readability-uppercase-literal-suffix, -misc-non-private-member-variables-in-classes, -fuchsia-default-arguments-declarations, -fuchsia-default-arguments-calls, 
-readability-magic-numbers, -fuchsia-overloaded-operator, -performance-noexcept-move-constructor, -concurrency-mt-unsafe, -readability-function-cognitive-complexity, -misc-throw-by-value-catch-by-reference, -misc-no-recursion, -bugprone-narrowing-conversions, -bugprone-easily-swappable-parameters, -readability-make-member-function-const, -readability-use-anyofallof, -readability-identifier-length,-misc-confusable-identifiers,-bugprone-reserved-identifier,-misc-unused-using-decls' 2 | WarningsAsErrors: '*' 3 | HeaderFilterRegex: '' 4 | AnalyzeTemporaryDtors: false 5 | CheckOptions: 6 | - key: readability-identifier-naming.NamespaceCase 7 | value: lower_case 8 | - key: readability-identifier-naming.ClassCase 9 | value: CamelCase 10 | - key: readability-identifier-naming.StructCase 11 | value: CamelCase 12 | - key: readability-identifier-naming.PrivateMemberPrefix 13 | value: _ 14 | - key: readability-identifier-naming.ProtectedMemberPrefix 15 | value: _ 16 | - key: readability-identifier-naming.MemberCase 17 | value: lower_case 18 | - key: readability-identifier-naming.StructCase 19 | value: CamelCase 20 | - key: readability-identifier-naming.MethodCase 21 | value: camelBack 22 | - key: readability-identifier-naming.FunctionCase 23 | value: camelBack 24 | - key: readability-identifier-naming.VariableCase 25 | value: lower_case 26 | - key: misc-throw-by-value-catch-by-reference.MaxSize 27 | value: '8' 28 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @Software-GCAI/poptorch 2 | -------------------------------------------------------------------------------- /.github/workflows/apply_linters.yml: -------------------------------------------------------------------------------- 1 | name: apply_linters.py git trailer check 2 | 3 | on: 4 | push: 5 | branches: [mk2-main] 6 | pull_request: 7 | branches: [mk2-main] 8 | 9 | jobs: 10 | 
apply_linters: 11 | timeout-minutes: 10 12 | name: apply_linters.py git trailer check 13 | runs-on: [self-hosted, linux] 14 | steps: 15 | - uses: actions/checkout@v3 16 | with: 17 | # 0 means fetch the full history for all branches and tags. 18 | # By default the checkout action only checks out the PR 19 | # ref. However, apply_linters.py needs to run git commands 20 | # that reference origin/mk2-main. 21 | fetch-depth: 0 22 | # Check out the head instead of the merge commit 23 | ref: ${{ github.event.pull_request.head.sha }} 24 | - name: Verify most recent commit's git trailer 25 | run: python scripts/apply_linters.py --check-trailer 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | __pycache__ 3 | .linters 4 | .cache 5 | .vscode 6 | test_data 7 | -------------------------------------------------------------------------------- /.linters/activate_buildenv.sh: -------------------------------------------------------------------------------- 1 | ../../build/activate_buildenv.sh -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: check-merge-conflict 8 | - id: trailing-whitespace 9 | - repo: local 10 | hooks: 11 | - id: apply_linters 12 | name: apply_linters 13 | entry: scripts/apply_linters.py 14 | language: python 15 | args: [-a, --add-trailer-on-success, --debug, --git-strategy=pre-commit] 16 | additional_dependencies: [pyyaml==6.0.0, packaging==23.0.0, colorama==0.4.6] 17 | # For the git trailer to be correct, apply_linters.py must be applied to all the files.
18 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Graphcore Limited 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 19 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include include *.hpp 2 | include poptorch/lib/* 3 | include poptorch/lib/poplar_rt/* 4 | include poptorch/lib/graphcore/lib/*.a 5 | include *.py 6 | include *.toml 7 | include License.txt 8 | include poptorch_third_party_licenses.txt 9 | -------------------------------------------------------------------------------- /docs/common/custom_dic: -------------------------------------------------------------------------------- 1 | accessor 2 | AdamW 3 | AMSGrad 4 | AsyncRebatched 5 | autograd 6 | backend 7 | booleans 8 | bwd 9 | checkpointed 10 | checkpointing 11 | codepaths 12 | config 13 | connectionist 14 | const 15 | constness 16 | CTC 17 | dict 18 | EOF 19 | float16 20 | float32 21 | FP16 22 | InputChannels 23 | ints 24 | IO 25 | ipu 26 | IPU 27 | IPUs 28 | iterable 29 | L2 30 | libpvti 31 | matmul 32 | Mk1 33 | Mk2 34 | mpirun 35 | Nesterov 36 | num 37 | OpenMPI 38 | OutputChannels 39 | PopART 40 | PopART's 41 | PopDist 42 | PopLibs 43 | PopRun 44 | PopTorch 45 | precompile 46 | pvti 47 | PyTorch 48 | PyTorch's 49 | rebatch 50 | rebatched 51 | rebatching 52 | recomputation 53 | ReducingDim 54 | replan 55 | RMSprop 56 | RTS 57 | serializable 58 | SGD 59 | sharded 60 | sharding 61 | stdout 62 | str 63 | submodules 64 | TODO 65 | tracepoint 66 | tracepoints 67 | tracepointsints 68 | unrounded 69 | unroundedPopRunsubmodules 70 | bool 71 | -------------------------------------------------------------------------------- /docs/common/graphcorelogo-html.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/common/graphcorelogo-html.png -------------------------------------------------------------------------------- /docs/common/graphcorelogo-pdf.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/common/graphcorelogo-pdf.png -------------------------------------------------------------------------------- /docs/poptorch_geometric/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | PyTorch Geometric for the IPU: User Guide 2 | ========================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | :numbered: 3 7 | 8 | intro 9 | installation 10 | performance 11 | tutorials 12 | supported_operations 13 | reference 14 | legal 15 | -------------------------------------------------------------------------------- /docs/poptorch_geometric/user_guide/intro.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Introduction 3 | ============ 4 | 5 | .. admonition:: Experimental Release 6 | 7 | This is an experimental release of PopTorch Geometric. Not all features of PyTorch Geometric are supported, and some functions may not work as expected. The implementation may change without warning in future releases in ways that are not backwards compatible. 8 | 9 | PopTorch Geometric is a set of extensions for PyTorch Geometric, enabling Graph 10 | Neural Network models to be trained, evaluated and used on Graphcore IPU 11 | hardware. 12 | 13 | PopTorch Geometric has been designed to require as few changes as possible to 14 | your models to run on the IPU. 15 | However, it does have some differences from native PyTorch Geometric execution, 16 | in order to get the most out of IPU hardware. 17 | 18 | PopTorch Geometric depends on the functionality provided by PopTorch. 19 | 20 | PopTorch and PopTorch Geometric are included in the `Poplar SDK `__. See the `Getting Started guide `_ for your system for how to 21 | install the Poplar SDK. 
Refer to :numref:`installation` for how to install the PopTorch and PopTorch Geometric wheels. 22 | -------------------------------------------------------------------------------- /docs/poptorch_geometric/user_guide/legal.rst: -------------------------------------------------------------------------------- 1 | Legal notices 2 | ============= 3 | 4 | |LEGAL:TRADEMARKS| 5 | 6 | |LEGAL:EULA| 7 | 8 | © Copyright 2023 Graphcore Ltd. All rights reserved. 9 | -------------------------------------------------------------------------------- /docs/poptorch_geometric/user_guide/reference.rst: -------------------------------------------------------------------------------- 1 | .. _reference: 2 | 3 | ============= 4 | API reference 5 | ============= 6 | 7 | .. _api_options: 8 | 9 | Data loaders 10 | ============ 11 | 12 | .. autoclass:: poptorch_geometric.dataloader.DataLoader 13 | 14 | .. autoclass:: poptorch_geometric.dataloader.FixedSizeDataLoader 15 | 16 | .. autoclass:: poptorch_geometric.pyg_dataloader.FixedSizeStrategy 17 | 18 | .. autoclass:: poptorch_geometric.pyg_dataloader.OverSizeStrategy 19 | 20 | Cluster data loaders 21 | ==================== 22 | 23 | .. autoclass:: poptorch_geometric.cluster_loader.FixedSizeClusterLoader 24 | 25 | Collators 26 | ========= 27 | 28 | .. autoclass:: poptorch_geometric.collate.FixedSizeCollater 29 | 30 | Batch samplers 31 | ============== 32 | 33 | .. autoclass:: poptorch_geometric.stream_packing_sampler.StreamPackingSampler 34 | 35 | Fixed size options 36 | ================== 37 | 38 | .. autoclass:: poptorch_geometric.fixed_size_options.FixedSizeOptions 39 | -------------------------------------------------------------------------------- /docs/poptorch_geometric/user_guide/tutorials.rst: -------------------------------------------------------------------------------- 1 | .. 
_examples_and_tutorials: 2 | 3 | ====================== 4 | Examples and tutorials 5 | ====================== 6 | 7 | Examples demonstrating different use scenarios for PopTorch Geometric are 8 | available in the 9 | `Graphcore examples repository on GitHub `_. 10 | 11 | Tutorials in the form of Jupyter notebooks are available in the `PyTorch Geometric tutorials directory `__. These tutorials show how to get the maximum benefit from IPU systems with PopTorch Geometric. 12 | -------------------------------------------------------------------------------- /docs/user_guide/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LONG_TESTS mnist inferenceModel) 2 | 3 | function(add_poptorch_py_user_guide_example name path) 4 | message(STATUS "Adding python example '${name}'") 5 | set(extra_labels "") 6 | if("${name}" STREQUAL "pipeline_simple") 7 | set(extra_labels ";external_data") 8 | else() 9 | if("${name}" IN_LIST LONG_TESTS) 10 | set(extra_labels "") 11 | else() 12 | set(extra_labels ";short") 13 | endif() 14 | endif() 15 | 16 | add_test(NAME "${name}_user_guide_example" 17 | COMMAND python3 ${path}/${name}.py 18 | WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) 19 | set_tests_properties("${name}_user_guide_example" PROPERTIES LABELS "user_guide_examples${extra_labels}") 20 | endfunction() 21 | 22 | install(FILES "poptorch.conf" DESTINATION "${PROJECT_BINARY_DIR}/tmp") 23 | 24 | file(GLOB EXAMPLES "${CMAKE_CURRENT_SOURCE_DIR}/*.py") 25 | if(COPY_TESTS) 26 | install(FILES ${EXAMPLES} DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") 27 | set(DOC_EXAMPLES_PATH "${CMAKE_CURRENT_BINARY_DIR}") 28 | else() 29 | set(DOC_EXAMPLES_PATH "${CMAKE_CURRENT_SOURCE_DIR}") 30 | endif() 31 | 32 | foreach(EXAMPLE ${EXAMPLES}) 33 | get_filename_component(NAME ${EXAMPLE} NAME_WE) 34 | add_poptorch_py_user_guide_example(${NAME} ${DOC_EXAMPLES_PATH}) 35 | endforeach() 36 | 37 | if(BUILD_DOCS) 38 | run_poptorch_install_command( 39 | "python3 
${PROJECT_SOURCE_DIR}/scripts/docs_build.py --install-dir ${CMAKE_INSTALL_PREFIX} --add-to-sys-path ${CMAKE_INSTALL_PREFIX}" 40 | "${PROJECT_BINARY_DIR}" 41 | "docs_build.py") 42 | endif() 43 | -------------------------------------------------------------------------------- /docs/user_guide/IPU-pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/IPU-pipeline.jpg -------------------------------------------------------------------------------- /docs/user_guide/buffers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 2 | import torch 3 | import poptorch 4 | 5 | 6 | # counter_model_wrong_start 7 | class CounterModel(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.i = torch.tensor([0.], dtype=torch.float) 11 | 12 | def forward(self): 13 | self.i += 1 14 | return self.i 15 | 16 | 17 | model = CounterModel() 18 | poptorch_model = poptorch.inferenceModel(model) 19 | print(poptorch_model()) # tensor([1.]) 20 | print(poptorch_model()) # tensor([1.]) 21 | # counter_model_wrong_end 22 | 23 | torch.testing.assert_close(model.i, torch.tensor([1.], dtype=torch.float)) 24 | 25 | 26 | # pragma pylint: disable=function-redefined,no-member 27 | # counter_model_correct_start 28 | class CounterModel(torch.nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.register_buffer("i", torch.tensor([0.], dtype=torch.float)) 32 | 33 | def forward(self): 34 | self.i += 1 35 | return self.i 36 | 37 | 38 | model = CounterModel() 39 | poptorch_model = poptorch.inferenceModel(model) 40 | 41 | print(poptorch_model()) # tensor([1.]) 42 | print(poptorch_model()) # tensor([2.]) 43 | # counter_model_correct_end 44 | 45 | # Because the model is running in inference mode, we will need to manually 46 | # call 
copyWeightsToHost 47 | poptorch_model.copyWeightsToHost() 48 | torch.testing.assert_close(model.i, torch.tensor([2.], dtype=torch.float)) 49 | -------------------------------------------------------------------------------- /docs/user_guide/comm-group-types.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/comm-group-types.png -------------------------------------------------------------------------------- /docs/user_guide/debugging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 3 | import torch 4 | import poptorch 5 | 6 | 7 | class Model(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.fc1 = torch.nn.Linear(10, 10) 11 | self.relu = torch.nn.ReLU() 12 | self.fc2 = torch.nn.Linear(10, 10) 13 | self.loss = torch.nn.MSELoss(reduction="mean") 14 | 15 | def forward(self, x, labels=None): 16 | out = self.fc2(self.relu(self.fc1(x))) 17 | if self.training: 18 | return self.loss(out, labels) 19 | return out 20 | 21 | 22 | # tensor_names_start 23 | input = torch.rand(10, 10) 24 | label = torch.rand(10, 10) 25 | 26 | model = Model() 27 | poptorch_model = poptorch.trainingModel(model) 28 | poptorch_model(input, label) 29 | 30 | tensor_names = poptorch_model.getTensorNames() 31 | # tensor_names_end 32 | 33 | # tensor_anchor_start 34 | opts = poptorch.Options() 35 | opts.anchorTensor('grad_bias', 'Gradient___fc2.bias') 36 | opts.anchorTensor('update_weight', 'UpdatedVar___fc2.weight') 37 | # tensor_anchor_end 38 | 39 | poptorch_model.destroy() 40 | 41 | # tensor_retrieve_start 42 | poptorch_model = poptorch.trainingModel(model, opts) 43 | poptorch_model(input, label) 44 | 45 | grad = poptorch_model.getAnchoredTensor('grad_bias') 46 | update = 
poptorch_model.getAnchoredTensor('update_weight') 47 | # tensor_retrieve_end 48 | 49 | poptorch_model.destroy() 50 | 51 | # optim_state_dict_start 52 | optim = poptorch.optim.SGD(model.parameters(), lr=0.01) 53 | poptorch_model = poptorch.trainingModel(model, opts, optim) 54 | poptorch_model(input, label) 55 | 56 | state = optim.state_dict() 57 | # optim_state_dict_end 58 | -------------------------------------------------------------------------------- /docs/user_guide/example.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | You can find PyTorch examples and tutorials in the Graphcore GitHub `examples repository `__. 5 | This contains: 6 | 7 | * Examples of popular machine learning models for training and inference 8 | * :tutorials-repo:`Tutorials ` 9 | * :tutorials-repo:`Examples of PopTorch and IPU features ` 10 | * :tutorials-repo:`Examples of simple models ` 11 | * Source code from videos, blogs and other documents 12 | 13 | MNIST example 14 | _____________ 15 | 16 | The example in :numref:`mnist-example-code` shows how an MNIST model can be run on the IPU. The highlighted lines show the PopTorch-specific code required to run the example on multiple IPUs. 17 | 18 | You can download the full source code from GitHub: :github-poptorch:`mnist.py `. 19 | 20 | To run this example you will need to install the Poplar SDK (see the `Getting Started Guide `_ for your IPU system) and the appropriate version of ``torchvision``: 21 | 22 | .. code-block:: console 23 | 24 | $ python3 -m pip install torchvision==0.11.1 25 | 26 | .. 
literalinclude:: ../../examples/mnist.py 27 | :caption: MNIST example 28 | :name: mnist-example-code 29 | :start-after: mnist_start 30 | :end-before: mnist_end 31 | :emphasize-lines: 12, 15, 17, 20, 35, 96, 99 32 | :language: python 33 | :dedent: 3 34 | :linenos: 35 | :lineno-match: 36 | -------------------------------------------------------------------------------- /docs/user_guide/experimental.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Experimental features 3 | ===================== 4 | 5 | Distributed execution without PopRun 6 | ==================================== 7 | 8 | PopTorch supports distributed execution on a Pod using the IPU over Fabric 9 | (IPUoF). 10 | 11 | If you run a program using your own distributed processing tool instead of PopRun, the only change you need to make to your code is to set the ID of the current process and 12 | the total number of processes the execution is distributed across, using 13 | :py:meth:`~poptorch.options._DistributedOptions.configureProcessId`. 14 | 15 | Note that :py:meth:`~poptorch.Options.replicationFactor` should 16 | be used to set the number of local replicas (per host) not the total (global) 17 | number of replicas. 18 | 19 | .. literalinclude:: device_iterations.py 20 | :caption: Changes required for distributed execution 21 | :start-after: distributed_execution_start 22 | :end-before: distributed_execution_end 23 | :emphasize-lines: 9, 12, 18 24 | :linenos: 25 | 26 | .. note:: ``DataLoader`` will automatically select a different subset of the 27 | dataset based on the process ID. 28 | 29 | .. warning:: All the processes must use the same seed if ``shuffle=True`` is used 30 | for the ``DataLoader``. 31 | 32 | torch.nn.CTCLoss 33 | ================ 34 | 35 | The CTCLoss operator is supported, with some limitations: 36 | 37 | #. The ``reduction`` parameter must be set to either ``sum`` or ``mean`` 38 | #. 
The ``targets`` tensor must be 2D, corresponding to stacked, padded layout 39 | -------------------------------------------------------------------------------- /docs/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | PyTorch for the IPU: User Guide 2 | =============================== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | :numbered: 3 7 | 8 | intro 9 | installation 10 | pytorch_to_poptorch 11 | overview 12 | batching 13 | supported_ops 14 | debugging 15 | hostio_optimisation 16 | example 17 | experimental 18 | reference 19 | legal 20 | -------------------------------------------------------------------------------- /docs/user_guide/inferenceModel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | import os 3 | import poptorch 4 | # If running on the model then make sure to run on the full size model to 5 | # avoid running out of memory. 6 | if not poptorch.ipuHardwareIsAvailable(): 7 | os.environ["POPTORCH_IPU_MODEL"] = "1" 8 | 9 | # pylint: disable=reimported 10 | # pylint: disable=ungrouped-imports 11 | # pylint: disable=wrong-import-order 12 | # pylint: disable=wrong-import-position 13 | 14 | # inference_model_start 15 | import torch 16 | import torchvision 17 | import poptorch 18 | 19 | # Some dummy imagenet sized input. 20 | picture_of_a_cat_here = torch.randn([1, 3, 224, 224]) 21 | 22 | # The model, in this case a MobileNet model with pretrained weights that comes 23 | # canned with PyTorch. 24 | model = torchvision.models.mobilenet_v2(pretrained=True) 25 | model.train(False) 26 | 27 | # Wrap in the PopTorch inference wrapper 28 | inference_model = poptorch.inferenceModel(model) 29 | 30 | # Execute on IPU. 31 | out_tensor = inference_model(picture_of_a_cat_here) 32 | 33 | # Get the top 5 ImageNet classes. 
34 | top_five_classes = torch.topk(torch.softmax(out_tensor, 1), 5) 35 | print(top_five_classes) 36 | 37 | # Try the same on native PyTorch 38 | native_out = model(picture_of_a_cat_here) 39 | 40 | native_top_five_classes = torch.topk(torch.softmax(native_out, 1), 5) 41 | 42 | # Models should be very close to native output although some operations are 43 | # numerically different and floating point differences can accumulate. 44 | assert any(top_five_classes[1][0] == native_top_five_classes[1][0]) 45 | # inference_half_start 46 | model = torch.nn.Linear(1, 10) 47 | 48 | # Cast the parameters (weights) to half. 49 | model.half() 50 | 51 | t1 = torch.tensor([1.]).half() 52 | 53 | opts = poptorch.Options() 54 | 55 | inference_model = poptorch.inferenceModel(model, opts) 56 | out = inference_model(t1) 57 | 58 | assert out.dtype == torch.half 59 | # inference_half_end 60 | -------------------------------------------------------------------------------- /docs/user_guide/legal.rst: -------------------------------------------------------------------------------- 1 | Trademarks & copyright 2 | ====================== 3 | 4 | |LEGAL:TRADEMARKS| 5 | 6 | |LEGAL:EULA| 7 | 8 | Copyright © 2020-|YEAR| Graphcore Ltd. All rights reserved. 
9 | -------------------------------------------------------------------------------- /docs/user_guide/no-buffering-profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/no-buffering-profile.png -------------------------------------------------------------------------------- /docs/user_guide/pipelined_execution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/pipelined_execution.png -------------------------------------------------------------------------------- /docs/user_guide/poptorch.conf: -------------------------------------------------------------------------------- 1 | deviceIterations(1) 2 | setExecutionStrategy(poptorch.ShardedExecution()) 3 | replicationFactor(1) 4 | enableSyntheticData(True) 5 | -------------------------------------------------------------------------------- /docs/user_guide/pytorch-software-stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/pytorch-software-stack.png -------------------------------------------------------------------------------- /docs/user_guide/sharded_execution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/sharded_execution.png -------------------------------------------------------------------------------- /docs/user_guide/stages_summary.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/stages_summary.png -------------------------------------------------------------------------------- /docs/user_guide/with-buffering-overlap-profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/with-buffering-overlap-profile.png -------------------------------------------------------------------------------- /docs/user_guide/with-buffering-profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/docs/user_guide/with-buffering-profile.png -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(add_poptorch_py_example name path) 2 | message(STATUS "Adding python example '${name}'") 3 | 4 | set(extra_labels "") 5 | if("${name}" STREQUAL "bert_ipu") 6 | set(extra_labels ";external_data") 7 | else() 8 | set(extra_labels ";short") 9 | endif() 10 | add_test(NAME "${name}_example" 11 | COMMAND python3 ${path}/${name}.py 12 | WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) 13 | set_tests_properties("${name}_example" PROPERTIES LABELS "examples${extra_labels}") 14 | endfunction() 15 | 16 | file(GLOB EXAMPLES "${CMAKE_CURRENT_SOURCE_DIR}/*.py") 17 | if(COPY_TESTS) 18 | install(FILES ${EXAMPLES} DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") 19 | set(EXAMPLES_PATH "${CMAKE_CURRENT_BINARY_DIR}") 20 | else() 21 | set(EXAMPLES_PATH "${CMAKE_CURRENT_SOURCE_DIR}") 22 | endif() 23 | 24 | foreach(EXAMPLE ${EXAMPLES}) 25 | get_filename_component(NAME ${EXAMPLE} NAME_WE) 26 | add_poptorch_py_example(${NAME} ${EXAMPLES_PATH}) 27 | endforeach() 28 | 
-------------------------------------------------------------------------------- /examples/lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import poptorch 7 | 8 | 9 | class SimpleLSTM(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.lstm = nn.LSTM(3, 3) 13 | 14 | def forward(self, input_tensors, hidden): 15 | Y, (Y_h, Y_c) = self.lstm(input_tensors, hidden) 16 | return Y, (Y_h, Y_c) 17 | 18 | 19 | inputs = [torch.randn(1, 3) for _ in range(5)] 20 | # Add the extra 2nd dimension 21 | inputs = torch.cat(inputs).view(len(inputs), 1, -1) 22 | hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3)) # clean out hidden state 23 | 24 | inference_lstm = poptorch.inferenceModel(SimpleLSTM()) 25 | out, hidden = inference_lstm(inputs, hidden) 26 | 27 | print(out) 28 | print(hidden) 29 | -------------------------------------------------------------------------------- /examples/simple_adder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import poptorch 7 | 8 | # This simple example demonstrates compiling a model to add 9 | # two tensors together using the IPU. 
10 | 11 | 12 | class SimpleAdder(nn.Module): 13 | def forward(self, x, y): 14 | return x + y 15 | 16 | 17 | model = SimpleAdder() 18 | inference_model = poptorch.inferenceModel(model) 19 | 20 | t1 = torch.tensor([1.]) 21 | t2 = torch.tensor([2.]) 22 | 23 | assert inference_model(t1, t2) == 3.0 24 | print("Success") 25 | -------------------------------------------------------------------------------- /examples/zeus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/examples/zeus.jpg -------------------------------------------------------------------------------- /popart_compiler/include/popart_compiler/CodeletsCompilation.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | #ifndef POPART_COMPILER_CODELETS_COMPILATION_HPP 3 | #define POPART_COMPILER_CODELETS_COMPILATION_HPP 4 | 5 | #include 6 | 7 | namespace poptorch { 8 | namespace popart_compiler { 9 | 10 | // Called from python on each 'import poptorch'. Cache path is expected to be 11 | // a true filesystem path of the installed python package where codelet sources 12 | // are stored. 13 | void setCustomCodeletsPath(const char *cache_path); 14 | 15 | // Compile a custom codelet (if not already compiled) and store the output 16 | // file to the path specified with 'setCustomCodeletsPath' above. This can 17 | // safely be called from multiple threads/processes. 
18 | std::unique_ptr compileCustomCodeletIfNeeded(const char *src_file_name, 19 | bool hw_only_codelet); 20 | 21 | } // namespace popart_compiler 22 | } // namespace poptorch 23 | 24 | #endif // POPART_COMPILER_CODELETS_COMPILATION_HPP 25 | -------------------------------------------------------------------------------- /popart_compiler/include/popart_compiler/SupportedOperations.inc.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | /* 3 | OP_DECLS are in the following form: 4 | OP_DECL(namespace, funcName, function, onnx implementation, arguments, body argument) 5 | - namespace is the op's namespace 6 | - funcName is the op name 7 | - function is the actual op part of the : pair and will be 8 | used to name/call the given function. 9 | - Onnx implementation is the underlying onnx function which will be 10 | called. 11 | - Arguments are the arguments to the op which will be parsed by different 12 | macros depending on which file this is in. 13 | - Body arguments are just the names of the arguments so they can be used in 14 | the cpp file. 15 | */ 16 | #include "CompilerOperationMacros.inc.hpp" 17 | #include "ManuallyAddedOperations.inc.hpp" 18 | -------------------------------------------------------------------------------- /popart_compiler/include/popart_compiler/Utils.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved.
2 | #ifndef POPART_COMPILER_UTILS_HPP 3 | #define POPART_COMPILER_UTILS_HPP 4 | 5 | #include 6 | #include 7 | 8 | namespace poptorch { 9 | namespace popart_compiler { 10 | 11 | bool ipuModelEnvironmentVariableIsEnabled(); 12 | 13 | bool ipuSmallModelEnvironmentVariableIsEnabled(); 14 | 15 | std::string getIpuModelVersion(); 16 | 17 | int getNumTilesPerIpu(const std::string &ipu_model_version); 18 | 19 | std::uint64_t roundUpNumIPUs(std::uint64_t num_ipus); 20 | 21 | bool waitIfIpuIsUnavailable(); 22 | 23 | bool waitForAWhile(); 24 | 25 | /** Returns the IPU version of the device if the system contains a device with 26 | * num_ipus, -1 if there is a device but the architecture is unknown, or 0 if there 27 | * is no device with num_ipus. 28 | * 29 | * Note: This function doesn't check if the devices are currently in use. 30 | */ 31 | std::int64_t ipuHardwareVersion(std::uint64_t num_ipus = 1); 32 | 33 | // Converts a C++ string to a unique pointer of the string array; the purpose 34 | // is to return a "string" without using the non ABI-compatible std::string 35 | std::unique_ptr stringToUniquePtr(const std::string &str); 36 | 37 | // Returns the dtype int corresponding to the onnx type string 38 | int64_t dtypeIntFromOnnxStr(const char *onnx_type); 39 | 40 | // Returns the Onnx datatype as string corresponding to the dtype int used in Onnx 41 | // and Popart ops which take an int64_t dtype argument, e.g. "randomnormal" 42 | const char *onnxStrFromDtypeInt(int64_t dtype); 43 | 44 | } // namespace popart_compiler 45 | } // namespace poptorch 46 | 47 | #endif // POPART_COMPILER_UTILS_HPP 48 | -------------------------------------------------------------------------------- /popart_compiler/source/include/popart_compiler/SessionOptionsImpl.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved.
2 | #pragma once 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "popart/sessionoptions.hpp" 11 | #include "popart_compiler/CompilerOptions.hpp" 12 | 13 | namespace poptorch { 14 | namespace popart_compiler { 15 | namespace detail { 16 | 17 | struct SessionOptionsImpl { 18 | SessionOptionsImpl(); 19 | 20 | std::map> bool_options; 21 | std::map> uint64_options; 22 | std::map> string_options; 23 | std::map> double_options; 24 | 25 | std::map)>> 27 | container_options; 28 | std::set options_set; 29 | 30 | popart::SessionOptions popart_options; 31 | CompilerOptions poptorch_options; 32 | 33 | void setMemoryProportion(std::uint32_t ipu, float memory) { 34 | poptorch_options.available_memory_proportion[ipu] = memory; 35 | } 36 | 37 | template 38 | void set(const std::string &key, ValueType value, 39 | std::map> &options, 40 | const std::string &typeStr) { 41 | const auto it = options.find(key); 42 | ERROR_ON_MSG(it == options.end(), 43 | "Unknown " << typeStr << " option " << key); 44 | 45 | it->second(value); 46 | options_set.insert(key); 47 | } 48 | }; 49 | 50 | } // namespace detail 51 | } // namespace popart_compiler 52 | } // namespace poptorch 53 | -------------------------------------------------------------------------------- /poptorch/include/poptorch/LowerToPopartFactories.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | #ifndef INCLUDE_POPTORCH_LOWER_TO_POPART_FACTORIES_H 3 | #define INCLUDE_POPTORCH_LOWER_TO_POPART_FACTORIES_H 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "poptorch/LowerToPopart.hpp" 12 | #include "poptorch/SessionOptionsParser.hpp" 13 | 14 | namespace poptorch { 15 | 16 | poptorch::LowerToPopart lowerToPopartFromDispatch( 17 | SessionOptionsParser &parser, bool training, AnchorList &&anchors_list, 18 | const std::function &initCallbackBuffers, 19 | std::vector &&optimizers, 20 | const AttributeAccessor &attribute_accessor, CPUCallbackMap &callbacks); 21 | } // namespace poptorch 22 | 23 | #endif // INCLUDE_POPTORCH_LOWER_TO_POPART_FACTORIES_H 24 | -------------------------------------------------------------------------------- /poptorch/source/AliasProcessing.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021 Graphcore Ltd. All rights reserved. 2 | 3 | #include 4 | #include 5 | 6 | #include "poptorch/AliasProcessing.hpp" 7 | 8 | namespace poptorch { 9 | 10 | void resolveAliases(torch::jit::Graph *graph) { 11 | std::vector to_delete; 12 | 13 | for (auto *node : graph->nodes()) { 14 | if (node->kind() != c10::aten::alias) { 15 | continue; 16 | } 17 | 18 | // Replace the alias output with the direct input 19 | node->output()->replaceAllUsesWith(node->input()); 20 | to_delete.push_back(node); 21 | } 22 | 23 | for (auto *node : to_delete) { 24 | node->destroy(); 25 | } 26 | } 27 | } // namespace poptorch 28 | -------------------------------------------------------------------------------- /poptorch/source/ErrorOnUnsupportedAten.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 
2 | 3 | #include 4 | 5 | #include "poptorch/PopartCanonicalization.hpp" 6 | #include "poptorch_logging/Error.hpp" 7 | #include "poptorch_logging/Logging.hpp" 8 | 9 | namespace poptorch { 10 | 11 | void errorOnUnsupportedAten(torch::jit::Graph *graph) { 12 | // Check that all of the "aten::" ops have been eliminated. 13 | std::unordered_set unsupported_ops; 14 | 15 | for (torch::jit::Node *node : graph->nodes()) { 16 | if (node->kind().is_aten()) { 17 | unsupported_ops.insert(node->kind()); 18 | } 19 | } 20 | 21 | // Terminate compilation via error. 22 | if (!unsupported_ops.empty()) { 23 | std::stringstream ss; 24 | std::string sep; 25 | for (const auto &op : unsupported_ops) { 26 | ss << sep << op.toQualString(); 27 | sep = ", "; 28 | } 29 | 30 | ERROR("Unsupported ops found in compiled model: [" 31 | << ss.str() 32 | << "]. Not all operations are supported yet by Graphcore's PyTorch " 33 | "compiler. If you believe any of these should be, please report " 34 | "this message to support@graphcore.ai."); 35 | } 36 | } 37 | 38 | } // namespace poptorch 39 | -------------------------------------------------------------------------------- /poptorch/source/PoptorchStaticInit.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | #ifndef SOURCE_POPTORCH_STATIC_INIT_H 3 | #define SOURCE_POPTORCH_STATIC_INIT_H 4 | 5 | // The constants below set priorities for constructor functions used to 6 | // initialize static data. Functions with lower numbers run first. 
7 | 8 | // Priority value for symbol initialisation functions 9 | #define SYMBOL_INIT_PRIORITY 101 10 | 11 | // Priority value for shape inference registration functions 12 | #define SHAPE_INFERENCE_INIT_PRIORITY 102 13 | 14 | // Priority value for handler registration functions 15 | #define HANDLER_INIT_PRIORITY 103 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /poptorch/source/RequiresGrad.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | #include 4 | #include 5 | 6 | #include "poptorch/DispatchTracer.hpp" 7 | #include "poptorch/RequiresGrad.hpp" 8 | #include "poptorch/Utils.hpp" 9 | #include "poptorch_logging/Logging.hpp" 10 | 11 | namespace poptorch { 12 | 13 | void fixRequiresGradFromDispatch(torch::jit::Graph *graph) { 14 | // For each output of each node in the graph. 15 | for (auto *node : graph->nodes()) { 16 | for (auto *output : node->outputs()) { 17 | auto tensor_type = output->type()->cast(); 18 | if (!tensor_type) { 19 | continue; 20 | } 21 | auto device = tensor_type->device(); 22 | if (!device) { 23 | continue; 24 | } 25 | if (device->type() != at::DeviceType::IPU) { 26 | continue; 27 | } 28 | // If the output is an IPU floating-point tensor, check if any 29 | // of the inputs has requires_grad set, and update the Value if 30 | // needed. 
31 | bool requires_grad = false; 32 | if (tensor_type->scalarType().has_value() && 33 | c10::isFloatingType(tensor_type->scalarType().value())) { 34 | for (auto *input : node->inputs()) { 35 | if (input->requires_grad()) { 36 | requires_grad = true; 37 | break; 38 | } 39 | } 40 | } 41 | if (requires_grad != output->requires_grad()) { 42 | logging::trace("[requires_grad] Set requires_grad={} on node {}", 43 | requires_grad, nodeToString(node)); 44 | output->setType(tensor_type->withRequiresGrad(requires_grad)); 45 | } 46 | } 47 | } 48 | } 49 | 50 | } // namespace poptorch 51 | -------------------------------------------------------------------------------- /poptorch/source/dispatch_tracer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 3 | 4 | add_library(dispatch_tracer STATIC 5 | RegisterAtenOverloads.cpp 6 | CommonHelperFunctions.cpp 7 | dispatchers/IDispatch.cpp 8 | dispatchers/JitDispatch.cpp 9 | InplaceAliasMapper.cpp 10 | ValueMapper.cpp 11 | Tensor.cpp 12 | TypeInferenceHandler.cpp 13 | ) 14 | 15 | target_link_libraries(dispatch_tracer 16 | PUBLIC 17 | torch 18 | PRIVATE 19 | poptorch_internal_headers 20 | poptorch_logging 21 | poptorch_compiler 22 | popart_compiler 23 | poptorch_err 24 | ) 25 | 26 | set_property(TARGET dispatch_tracer PROPERTY CXX_STANDARD 17) 27 | -------------------------------------------------------------------------------- /poptorch/source/dispatch_tracer/InplaceAliasMapper.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | #include 3 | 4 | #include "InplaceAliasMapper.hpp" 5 | 6 | namespace poptorch { 7 | 8 | InplaceArgAliasMapper &InplaceArgAliasMapper::getInstance() { 9 | static InplaceArgAliasMapper instance; 10 | return instance; 11 | } 12 | 13 | void InplaceArgAliasMapper::registerInplaceArgId( 14 | const std::string &operator_name, std::size_t alias_arg_id) { 15 | 16 | std::string key = 17 | _namespace ? fmt::format("{}::{}", _namespace.value(), operator_name) 18 | : operator_name; 19 | _operator_name_to_arg_id.emplace(key, alias_arg_id); 20 | } 21 | 22 | std::optional 23 | InplaceArgAliasMapper::getInplaceArg(const std::string &operator_name) { 24 | auto &operator_name_to_arg_id = getInstance()._operator_name_to_arg_id; 25 | const auto it = operator_name_to_arg_id.find(operator_name); 26 | if (it != operator_name_to_arg_id.end()) { 27 | return it->second; 28 | } 29 | return std::nullopt; 30 | } 31 | 32 | void InplaceArgAliasMapper::setNamespace(const std::string &p_namespace) { 33 | _namespace = p_namespace; 34 | } 35 | 36 | void InplaceArgAliasMapper::unsetNamespace() { _namespace = std::nullopt; } 37 | 38 | InplaceArgAliasMapperInit::InplaceArgAliasMapperInit( 39 | void (*init_mapper)(InplaceArgAliasMapper &), 40 | const std::string &p_namespace) { 41 | auto &alias_mapper = InplaceArgAliasMapper::getInstance(); 42 | alias_mapper.setNamespace(p_namespace); 43 | init_mapper(alias_mapper); 44 | alias_mapper.unsetNamespace(); 45 | } 46 | 47 | INPLACE_ARG_MAPPER_IMPL(torch_scatter, mapper) { 48 | mapper.registerInplaceArgId("scatter_mul", 3); 49 | mapper.registerInplaceArgId("scatter_max", 3); 50 | mapper.registerInplaceArgId("scatter_min", 3); 51 | } 52 | 53 | } // namespace poptorch 54 | -------------------------------------------------------------------------------- /poptorch/source/include/poptorch/AliasProcessing.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021 Graphcore Ltd. All rights reserved. 
2 | #ifndef INCLUDE_POPTORCH_ALIAS_PROCESSING_H 3 | #define INCLUDE_POPTORCH_ALIAS_PROCESSING_H 4 | 5 | namespace torch { 6 | namespace jit { 7 | struct Graph; 8 | } // namespace jit 9 | } // namespace torch 10 | 11 | namespace poptorch { 12 | 13 | // Remove instances of aten::alias in the graph by replacing the outputs with 14 | // the original (aliased) output. The known source of aliases is when an 15 | // operation takes place on a wrapped buffer, for which the return value tensor 16 | // is aliased and then set to be a member of the original (wrapper) subclass. 17 | void resolveAliases(torch::jit::Graph *graph); 18 | 19 | } // namespace poptorch 20 | 21 | #endif // INCLUDE_POPTORCH_ALIAS_PROCESSING_H 22 | -------------------------------------------------------------------------------- /poptorch/source/include/poptorch/ImplicitCasting.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | #ifndef INCLUDE_POPTORCH_IMPLICIT_CASTING_HPP 3 | #define INCLUDE_POPTORCH_IMPLICIT_CASTING_HPP 4 | #include 5 | 6 | namespace c10 { 7 | template class ArrayRef; 8 | } // namespace c10 9 | 10 | namespace torch { 11 | namespace jit { 12 | template using ArrayRef = c10::ArrayRef; 13 | struct Graph; 14 | struct Value; 15 | } // namespace jit 16 | } // namespace torch 17 | 18 | namespace poptorch { 19 | 20 | enum class ImplicitCast { 21 | None, 22 | All, 23 | ExceptFirst, 24 | ExceptSecond, 25 | ExceptThird, 26 | ExceptFourthFifth 27 | }; 28 | 29 | enum class ImplicitCastOutput { None, AsPromoted, AlwaysBool, AlwaysFloat }; 30 | 31 | std::vector 32 | implicitCastInputs(torch::jit::ArrayRef *inputs, 33 | ImplicitCast implicit_cast); 34 | 35 | // TODO(T55228): remove after we use our own dispatch key. 
36 | // With the dispatcher we catch implicit torch casts (intercepted with 37 | // JitDispatch::toCopyInplace) but it seems that in the case of CPU tensors, 38 | // the returned (cast) ATen tensors are not reflected in the later ops, i.e. 39 | // we might end up with dead implicit casts in the IR which we clean with this 40 | // pass. The actual poptorch casting is done in our canonicalization handlers 41 | // anyway. 42 | void removeDeadImplicitCasts(torch::jit::Graph *graph); 43 | 44 | } // namespace poptorch 45 | 46 | #endif // INCLUDE_POPTORCH_IMPLICIT_CASTING_HPP 47 | -------------------------------------------------------------------------------- /poptorch/source/include/poptorch/OverlappedIO.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021 Graphcore Ltd. All rights reserved. 2 | #ifndef INCLUDE_POPTORCH_OVERLAPPED_IO_H 3 | #define INCLUDE_POPTORCH_OVERLAPPED_IO_H 4 | 5 | namespace torch { 6 | namespace jit { 7 | struct Graph; 8 | 9 | } // namespace jit 10 | } // namespace torch 11 | 12 | namespace poptorch { 13 | 14 | // Turns any set_overlap_for_input nodes applied to inputs into attributes of 15 | // the parameter node. These attributes specify any overlapped host IO for the 16 | // input. 17 | void attributiseOverlappedIO(torch::jit::Graph *graph); 18 | 19 | } // namespace poptorch 20 | 21 | #endif // INCLUDE_POPTORCH_OVERLAPPED_IO_H 22 | -------------------------------------------------------------------------------- /poptorch/source/include/poptorch/RequiresGrad.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | #ifndef INCLUDE_POPTORCH_REQUIRES_GRAD_H 3 | #define INCLUDE_POPTORCH_REQUIRES_GRAD_H 4 | 5 | namespace torch { 6 | namespace jit { 7 | struct Graph; 8 | } // namespace jit 9 | } // namespace torch 10 | 11 | namespace poptorch { 12 | 13 | // Autograd sets the requires_grad flag on the ATen tensors 14 | // after we've instantiated the corresponding ATen node in the dispatcher. 15 | // This pass goes through all the nodes in the ATen graph and sets the 16 | // requires_grad flag on a node's outputs if any of its inputs has 17 | // requires_grad set. 18 | void fixRequiresGradFromDispatch(torch::jit::Graph *graph); 19 | 20 | } // namespace poptorch 21 | 22 | #endif // INCLUDE_POPTORCH_REQUIRES_GRAD_H 23 | -------------------------------------------------------------------------------- /poptorch/source/popart_canonicalization/CustomOps.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | #include 3 | 4 | #include "../PoptorchStaticInit.hpp" 5 | #include "../PoptorchSymbols.hpp" 6 | #include "PopartCanonicalizationUtils.hpp" 7 | #include "poptorch/OpBuilder.hpp" 8 | #include "poptorch/Utils.hpp" 9 | #include "poptorch_logging/Error.hpp" 10 | 11 | namespace poptorch { 12 | namespace { 13 | 14 | torch::jit::Node *customOpHandler(torch::jit::Graph *graph, 15 | torch::jit::Node *node) { 16 | std::vector inputs = 17 | handleTensorList(node->input(0)->node()); 18 | std::string name = constantToString(node->input(1)->node()); 19 | std::string domain = constantToString(node->input(2)->node()); 20 | 21 | // Get the domain version. 22 | std::int64_t domain_version = constantToLong(node->input(3)->node()); 23 | 24 | // Get the number of outputs.
25 | std::int64_t num_outputs = constantToLong(node->input(4)->node()); 26 | 27 | // The attributes are in the Python dict represented by an id within a string 28 | auto attributes_id_str = constantToString(node->input(6)->node()); 29 | 30 | // Add the custom op with a variadic number of outputs. 31 | torch::jit::Node *custom_op = 32 | createCustomOperation(graph, inputs, name, domain, domain_version, 33 | num_outputs, attributes_id_str); 34 | 35 | // It is replacing an operation which returned a list so add a list 36 | // construct to keep the IR legal. 37 | return createAndInsertNode(graph, at::prim::ListConstruct, 38 | custom_op->outputs()); 39 | } 40 | 41 | } // namespace 42 | 43 | __attribute__((constructor(HANDLER_INIT_PRIORITY))) static void registration() { 44 | registerHandler(symbols::poptorch::custom_operation, customOpHandler); 45 | } 46 | 47 | } // namespace poptorch 48 | -------------------------------------------------------------------------------- /poptorch/source/popart_canonicalization/PyGTorchSplineConvOps.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | #include "../PoptorchStaticInit.hpp" 3 | #include "../PoptorchSymbols.hpp" 4 | #include "PopartCanonicalizationUtils.hpp" 5 | 6 | #include "poptorch/OpBuilder.hpp" 7 | #include "poptorch/Utils.hpp" 8 | #include "poptorch_logging/Error.hpp" 9 | #include "poptorch_logging/Logging.hpp" 10 | 11 | namespace poptorch { 12 | namespace { 13 | 14 | torch::jit::Node *torchSplineBasisHandler(torch::jit::Graph *graph, 15 | torch::jit::Node *node) { 16 | // Signature of spline_basis: 17 | // (Tensor pseudo, Tensor kernelSize, Tensor isOpenSpline, int degree) 18 | 19 | const std::vector args{node->input(0), node->input(1), 20 | node->input(2)}; 21 | const std::int32_t degree = constantToInt(node->input(3)->node()); 22 | 23 | auto *result = createSplinebasis(graph, args, degree); 24 | 25 | return result; 26 | } 27 | 28 | torch::jit::Node *torchSplineWeightingHandler(torch::jit::Graph *graph, 29 | torch::jit::Node *node) { 30 | // Signature of spline_weighting: 31 | // (Tensor input, Tensor weight, Tensor basis, Tensor weightIndex) 32 | 33 | const std::vector args{node->input(0), node->input(1), 34 | node->input(2), node->input(3)}; 35 | 36 | auto *result = createSplineweighting(graph, args); 37 | 38 | return result; 39 | } 40 | 41 | } // namespace 42 | 43 | __attribute__((constructor(HANDLER_INIT_PRIORITY))) static void registration() { 44 | registerHandler(torch_spline_conv::spline_basis, torchSplineBasisHandler); 45 | registerHandler(torch_spline_conv::spline_weighting, 46 | torchSplineWeightingHandler); 47 | } 48 | 49 | } // namespace poptorch 50 | -------------------------------------------------------------------------------- /poptorch/source/popart_canonicalization/ScatterReduction.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | 3 | #include "ScatterReduction.hpp" 4 | #include "PopartCanonicalizationUtils.hpp" 5 | 6 | namespace poptorch { 7 | 8 | std::int32_t getReductionMethod(torch::jit::Node *node) { 9 | const auto reduce = constantToString(node); 10 | if (reduce == "sum" || reduce == "add") { 11 | return static_cast(ScatterReduction::Sum); 12 | } 13 | if (reduce == "amax") { 14 | return static_cast(ScatterReduction::Max); 15 | } 16 | if (reduce == "amin") { 17 | return static_cast(ScatterReduction::Min); 18 | } 19 | if (reduce == "mean") { 20 | return static_cast(ScatterReduction::Mean); 21 | } 22 | if (reduce == "prod" || reduce == "multiply") { 23 | return static_cast(ScatterReduction::Mul); 24 | } 25 | 26 | ERROR("Unsupported reduction type for scatter_reduce: " << reduce); 27 | } 28 | 29 | float getReductionInitValue(std::int32_t reduce) { 30 | float init_val; 31 | switch (reduce) { 32 | case static_cast(ScatterReduction::Sum): 33 | case static_cast(ScatterReduction::Mean): 34 | init_val = 0.0; 35 | break; 36 | case static_cast(ScatterReduction::Mul): 37 | init_val = 1.0; 38 | break; 39 | case static_cast(ScatterReduction::Max): 40 | init_val = -std::numeric_limits::infinity(); 41 | break; 42 | case static_cast(ScatterReduction::Min): 43 | init_val = std::numeric_limits::infinity(); 44 | break; 45 | default: 46 | ERROR("Unsupported reduction type for scatter_reduce: " << reduce); 47 | break; 48 | } 49 | return init_val; 50 | } 51 | 52 | } // namespace poptorch 53 | -------------------------------------------------------------------------------- /poptorch/source/popart_canonicalization/ScatterReduction.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | #ifndef SCATTER_REDUCTION_H 3 | #define SCATTER_REDUCTION_H 4 | 5 | #include 6 | 7 | namespace torch { 8 | namespace jit { 9 | class Node; 10 | } // namespace jit 11 | } // namespace torch 12 | 13 | namespace poptorch { 14 | 15 | enum class ScatterReduction { Sum = 0, Max, Min, Mul, None, Mean }; 16 | 17 | std::int32_t getReductionMethod(torch::jit::Node *node); 18 | float getReductionInitValue(std::int32_t reduce); 19 | 20 | } // namespace poptorch 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /poptorch/source/type_and_constant_canonicalization/CastUnsupportedInputs.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 Graphcore Ltd. All rights reserved. 2 | 3 | #include "torch/csrc/jit/ir/ir.h" 4 | 5 | #include "poptorch_logging/Error.hpp" 6 | #include "poptorch_logging/Logging.hpp" 7 | 8 | #include "poptorch/OpBuilder.hpp" 9 | #include "poptorch/TypeAndConstantCanonicalization.hpp" 10 | #include "poptorch/Utils.hpp" 11 | 12 | #include "../PoptorchSymbols.hpp" 13 | 14 | namespace poptorch { 15 | namespace type_and_constant_canonicalization { 16 | namespace { 17 | void processInputTensor(torch::jit::Graph *graph, torch::jit::Value *input) { 18 | auto tensor_type = input->type()->expect(); 19 | auto current_type = tensor_type->scalarType().value(); 20 | 21 | at::ScalarType new_type = coerceToSupportedType(current_type); 22 | 23 | if (current_type == at::ScalarType::BFloat16) { 24 | new_type = at::ScalarType::Half; 25 | } else if (new_type == current_type) { 26 | // No need for a host side cast 27 | return; 28 | } 29 | 30 | auto *earliest_user = findEarliestUser(input); 31 | if (earliest_user == nullptr) { 32 | logging::warn("Graph contains an unused input %{} : {}", input->debugName(), 33 | *tensor_type); 34 | return; 35 | } 36 | 37 | // This is an identity op but used just to make sure the implicit cast 38 | // does not end up promoting to a
Double/Long 39 | auto *new_node = graph->create(symbols::poptorch::host_side_cast); 40 | // Add the input only after replaceAllUsesWith, or the cast would read itself. 41 | insertNodeBeforeNode(new_node, earliest_user); 42 | input->replaceAllUsesWith(new_node->output()); 43 | new_node->addInput(input); 44 | 45 | new_node->output()->setType(tensor_type->withScalarType(new_type)); 46 | } 47 | } // namespace 48 | 49 | void castUnsupportedInputs(torch::jit::Graph *graph) { 50 | auto collapsed_inputs = collapsedGraphInputHierachy(graph); 51 | 52 | for (auto *input : collapsed_inputs) { 53 | if (input != nullptr) { 54 | processInputTensor(graph, input); 55 | } 56 | } 57 | } 58 | 59 | } // namespace type_and_constant_canonicalization 60 | } // namespace poptorch 61 | -------------------------------------------------------------------------------- /poptorch_compiler/pytorch_bridge/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE poptorch_compiler_public_headers "${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp*") 2 | 3 | add_library(poptorch_compiler SHARED 4 | IpuSession.cpp 5 | ) 6 | 7 | target_link_libraries(poptorch_compiler 8 | PRIVATE 9 | poptorch_logging 10 | ) 11 | 12 | set_property(TARGET poptorch_compiler PROPERTY CXX_STANDARD 17) 13 | 14 | set_target_properties(poptorch_compiler PROPERTIES 15 | PUBLIC_HEADER "${poptorch_compiler_public_headers}") 16 | 17 | target_include_directories(poptorch_compiler 18 | PUBLIC 19 | $ 20 | $ 21 | ) 22 | install(TARGETS poptorch_compiler 23 | LIBRARY 24 | DESTINATION ${CMAKE_INSTALL_LIBDIR} 25 | PUBLIC_HEADER 26 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/pytorch_bridge 27 | ) 28 | -------------------------------------------------------------------------------- /poptorch_compiler/pytorch_bridge/include/pytorch_bridge/CompilerOptions.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | #ifndef POPTORCH_COMPILER_PYTORCH_BRIDGE_COMPILER_OPTIONS_HPP_ 3 | #define POPTORCH_COMPILER_PYTORCH_BRIDGE_COMPILER_OPTIONS_HPP_ 4 | 5 | #include 6 | 7 | namespace poptorch { 8 | 9 | struct CompilerOptions { 10 | struct Dispatcher { 11 | // NOTE: std::string-s are avoided here due to ABI issues 12 | std::vector> source_location_excludes; 13 | bool check_added_ops = true; 14 | }; 15 | Dispatcher dispatcher; 16 | }; 17 | 18 | } // namespace poptorch 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /poptorch_compiler/pytorch_bridge/include/pytorch_bridge/DebugInfo.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | #ifndef POPTORCH_COMPILER_PYTORCH_BRIDGE_DEBUG_INFO_HPP_ 3 | #define POPTORCH_COMPILER_PYTORCH_BRIDGE_DEBUG_INFO_HPP_ 4 | 5 | #include 6 | #include 7 | 8 | namespace poptorch_ir { 9 | 10 | struct GraphDebugInfo { 11 | // Note: these are shared with the tensor details. 12 | std::shared_ptr> initial_graph; 13 | std::shared_ptr> cached_graph; 14 | }; 15 | 16 | struct TensorDebugInfo { 17 | GraphDebugInfo debug_info; 18 | std::size_t output_idx = 0; 19 | }; 20 | 21 | } // namespace poptorch_ir 22 | 23 | #endif // POPTORCH_COMPILER_PYTORCH_BRIDGE_DEBUG_INFO_HPP_ 24 | -------------------------------------------------------------------------------- /poptorch_compiler/pytorch_bridge/include/pytorch_bridge/IpuSession.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | #ifndef POPTORCH_COMPILER_PYTORCH_BRIDGE_IPU_SESSION_HPP_ 3 | #define POPTORCH_COMPILER_PYTORCH_BRIDGE_IPU_SESSION_HPP_ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "pytorch_bridge/CompilerTypes.hpp" 11 | #include "pytorch_bridge/DebugInfo.hpp" 12 | #include 13 | 14 | namespace poptorch_ir { 15 | 16 | struct FunctionIO { 17 | std::vector inputs; 18 | std::vector outputs; 19 | }; 20 | 21 | class Buffer { 22 | // TODO(T70841): since Buffer is stored as a shared pointer it should be 23 | // possible to at least stop CpuBuffer from being a shared pointer. 24 | std::variant _store = std::monostate{}; 25 | 26 | public: 27 | Buffer() = default; 28 | explicit Buffer(CpuBuffer buf) noexcept; 29 | 30 | Buffer &operator=(CpuBuffer buf) noexcept; 31 | 32 | const CpuBuffer &getCpuData(); 33 | const CpuBuffer &getCpuData() const; 34 | 35 | bool hasData() const; 36 | }; 37 | 38 | class IIpuSession { 39 | public: 40 | virtual ~IIpuSession() = default; 41 | 42 | virtual Buffer allocate(const TensorType &type) = 0; 43 | virtual void copyDataFromCpuSource(Buffer &ipu_dest, const char *cpu_src) = 0; 44 | virtual void copyDataToCpu(char *cpu_dest, Buffer &ipu_src) = 0; 45 | virtual void copyDataOnDevice(Buffer &dest, const Buffer &src) = 0; 46 | }; 47 | 48 | std::shared_ptr createStaticSession(); 49 | 50 | } // namespace poptorch_ir 51 | 52 | #endif // POPTORCH_COMPILER_PYTORCH_BRIDGE_IPU_SESSION_HPP_ 53 | -------------------------------------------------------------------------------- /poptorch_err/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14 FATAL_ERROR) 2 | project(poptorch_err) 3 | 4 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 5 | 6 | add_library(poptorch_exception_info INTERFACE) 7 | 8 | target_include_directories(poptorch_exception_info 9 | INTERFACE 10 | exception_info) 11 | 12 | add_library(poptorch_err STATIC 13 | "source/ExceptionHandling.cpp") 14 | 15 |
target_include_directories(poptorch_err SYSTEM PUBLIC 16 | $ 17 | $ 18 | ) 19 | file(GLOB_RECURSE poptorch_err_public_headers "${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp*" "exception_info/*.hpp*") 20 | 21 | set_target_properties(poptorch_err PROPERTIES 22 | PUBLIC_HEADER "${poptorch_err_public_headers}") 23 | target_link_libraries(poptorch_err 24 | PUBLIC 25 | torch 26 | poptorch_exception_info 27 | PRIVATE 28 | popart_compiler 29 | poptorch_logging) 30 | 31 | install(TARGETS poptorch_err 32 | LIBRARY 33 | DESTINATION ${CMAKE_INSTALL_LIBDIR} 34 | PUBLIC_HEADER 35 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/poptorch_err 36 | ) 37 | -------------------------------------------------------------------------------- /poptorch_err/exception_info/poptorch_err/ExceptionInfo.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | #pragma once 3 | 4 | #include 5 | 6 | namespace poptorch { 7 | 8 | enum class ErrorCategory { RuntimeRecoverable, RuntimeUnrecoverable, Other }; 9 | 10 | /* 11 | * A subclass of this class is used to pass exception information across the ABI 12 | * boundary between popart_compiler and the pybind11 interface. It has to use 13 | * POD data types to cross the boundary successfully. We then unpack it into a 14 | * PoptorchError on the pybind11 side and rethrow it.
15 | */ 16 | class ExceptionInfo { 17 | public: 18 | virtual ~ExceptionInfo(); 19 | const virtual char *what() const = 0; 20 | const virtual char *type() const = 0; 21 | virtual int64_t stackDepth() const = 0; 22 | const virtual char *stack(int64_t level) const = 0; 23 | const virtual char *filename() const = 0; 24 | virtual uint64_t line() const = 0; 25 | const virtual char *recoveryAction() const = 0; 26 | virtual ErrorCategory category() const = 0; 27 | }; 28 | 29 | } // namespace poptorch 30 | -------------------------------------------------------------------------------- /poptorch_geometric/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14 FATAL_ERROR) 2 | project(poptorch-geometric) 3 | 4 | set(INSTALL_POPPYG_PYDIR ${CMAKE_INSTALL_PREFIX}/poptorch_geometric) 5 | 6 | add_subdirectory(python) 7 | 8 | add_custom_target(poptorch_geometric_wheel 9 | WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX} 10 | COMMAND python3 ${PROJECT_SOURCE_DIR}/../scripts/generate_poppyg_package.py bdist_wheel --output-dir ${CMAKE_INSTALL_PREFIX}/dist --python-dir ${INSTALL_POPPYG_PYDIR} 11 | ) 12 | 13 | add_custom_target(poptorch_geometric_sdist 14 | WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX} 15 | COMMAND python3 ${PROJECT_SOURCE_DIR}/../scripts/generate_poppyg_package.py sdist --output-dir ${CMAKE_INSTALL_PREFIX}/dist --python-dir ${INSTALL_POPPYG_PYDIR} 16 | ) 17 | 18 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/README.md 19 | DESTINATION .)
20 | -------------------------------------------------------------------------------- /poptorch_geometric/License.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Graphcore Limited 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 19 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /poptorch_geometric/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.py 2 | include *.toml 3 | include License.txt 4 | -------------------------------------------------------------------------------- /poptorch_geometric/README.md: -------------------------------------------------------------------------------- 1 | # poptorch-geometric 2 | Set of extensions for PyTorch Geometric, enabling GNN models to be trained, evaluated and used on the Graphcore IPU. 
3 | 4 | :warning: This project is under active development. All APIs should be considered volatile and any feedback is welcome. 5 | -------------------------------------------------------------------------------- /poptorch_geometric/config.buildenv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | installers.add(PipRequirements("requirements.txt")) 4 | -------------------------------------------------------------------------------- /poptorch_geometric/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /poptorch_geometric/python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(GNUInstallDirs) 2 | file(GLOB poppyg_python_files "${CMAKE_CURRENT_SOURCE_DIR}/*.py") 3 | 4 | # __init__.py needs to be edited by set_version.py so don't copy it over. 
5 | list(REMOVE_ITEM poppyg_python_files "${CMAKE_CURRENT_SOURCE_DIR}/__init__.py") 6 | 7 | install(CODE 8 | " execute_process( 9 | COMMAND python3 ${PROJECT_SOURCE_DIR}/../scripts/set_version.py --input-file ${CMAKE_CURRENT_SOURCE_DIR}/__init__.py ${CMAKE_CURRENT_BINARY_DIR}/__init__.py 10 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 11 | RESULT_VARIABLE RETVAL OUTPUT_VARIABLE OUTPUT ERROR_VARIABLE OUTPUT) 12 | if(RETVAL AND NOT RETVAL EQUAL 0) 13 | message(FATAL_ERROR \"set_version.py FAILED: \${OUTPUT}\") 14 | endif()") 15 | 16 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/__init__.py DESTINATION "${INSTALL_POPPYG_PYDIR}") 17 | install(FILES ${poppyg_python_files} py.typed DESTINATION "${INSTALL_POPPYG_PYDIR}") 18 | install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ops DESTINATION "${INSTALL_POPPYG_PYDIR}") 19 | 20 | install(CODE 21 | " execute_process( 22 | COMMAND python3 ${PROJECT_SOURCE_DIR}/../scripts/generate_poppyg_package.py install --output-dir ${CMAKE_INSTALL_PREFIX} --python-dir ${INSTALL_POPPYG_PYDIR} 23 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 24 | RESULT_VARIABLE RETVAL OUTPUT_VARIABLE OUTPUT ERROR_VARIABLE OUTPUT) 25 | if(RETVAL AND NOT RETVAL EQUAL 0) 26 | message(FATAL_ERROR \"generate_poppyg_package.py FAILED: \${OUTPUT}\") 27 | endif()") 28 | -------------------------------------------------------------------------------- /poptorch_geometric/python/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | import importlib 3 | 4 | from .collate import make_exclude_keys 5 | from .common import call_once 6 | from .dataloader import (FixedSizeDataLoader, FixedSizeStrategy, 7 | OverSizeStrategy) 8 | from .fixed_size_options import FixedSizeOptions 9 | from .types import PyGArgsParser, registerCustomArgParsers 10 | from .utils import TrainingStepper, set_aggregation_dim_size 11 | from .override import _TorchGeometricOpsSubstitutionManager 12 | 13 | __version__ = "@VERSION@-@SNAPSHOT@" 14 | 15 | __all__ = [ 16 | '__version__', 'FixedSizeDataLoader', 'FixedSizeOptions', 17 | 'FixedSizeStrategy', 'set_aggregation_dim_size', 'TrainingStepper', 18 | 'make_exclude_keys', 'OverSizeStrategy', 'PyGArgsParser' 19 | ] 20 | 21 | 22 | @call_once 23 | def registerOverrideManager(): 24 | """Register the PyG ops substitution manager with poptorch. 25 | 26 | Looks up the ``poptorch._poplar_executor`` module spec; when the spec and 27 | its loader are both non-None, loads that module and registers 28 | ``_TorchGeometricOpsSubstitutionManager`` with its 29 | ``_OverwriteContextManager``. Wrapped in ``call_once``, so only the first 30 | call has any effect. 31 | """ 32 | # NOTE(review): `import importlib` alone does not guarantee that the 33 | # `importlib.util` submodule is loaded; `import importlib.util` is safer. 34 | poplar_executor_spec = importlib.util.find_spec( 35 | "poptorch._poplar_executor") 36 | if poplar_executor_spec is not None: 37 | loader = poplar_executor_spec.loader 38 | if loader is not None: 39 | # NOTE(review): Loader.load_module() is deprecated, and if the 40 | # module is already in sys.modules it may be re-executed in place. 41 | # importlib.import_module() would reuse it; confirm intent. 42 | poplar_executor = loader.load_module() 43 | poplar_executor._OverwriteContextManager.registerSubsitutionManager( # pylint: disable=protected-access 44 | _TorchGeometricOpsSubstitutionManager) 45 | 46 | 47 | registerOverrideManager() 48 | registerCustomArgParsers() 49 | -------------------------------------------------------------------------------- /poptorch_geometric/python/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | 3 | from torch_geometric.data import Batch, Data, HeteroData 4 | 5 | DataBatch = type(Batch(_base_cls=Data)) 6 | HeteroDataBatch = type(Batch(_base_cls=HeteroData)) 7 | 8 | 9 | def call_once(f): 10 | """Wrap ``f`` so that only its first call executes it. 11 | 12 | The first call forwards its arguments to ``f`` and returns its result; 13 | every later call returns ``None`` without calling ``f``. ``has_run`` is 14 | set before ``f`` is invoked, so ``f`` is not retried if that call raises. 15 | """ 16 | def wrapper(*args, **kwargs): 17 | if not wrapper.has_run: 18 | wrapper.has_run = True 19 | return f(*args, **kwargs) 20 | return None 21 | 22 | wrapper.has_run = False 23 | return wrapper 24 | -------------------------------------------------------------------------------- /poptorch_geometric/python/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from .aggregation_base import Aggregation 4 | from .cluster_gcn_conv import ClusterGCNConv 5 | from .hetero_linear import HeteroLinear 6 | from .instance_norm import InstanceNorm 7 | from .knn import knn 8 | from .knn_graph import knn_graph 9 | from .knn_interpolate import knn_interpolate 10 | from .mf_conv import MFConv 11 | from .radius import radius, radius_graph 12 | 13 | __all__ = [ 14 | 'Aggregation', 15 | 'ClusterGCNConv', 16 | 'HeteroLinear', 17 | 'InstanceNorm', 18 | 'knn', 19 | 'knn_graph', 20 | 'knn_interpolate', 21 | 'MFConv', 22 | 'radius', 23 | 'radius_graph', 24 | ] 25 | -------------------------------------------------------------------------------- /poptorch_geometric/python/ops/aggregation_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | from typing import Optional 3 | from torch import Tensor 4 | import torch_geometric 5 | 6 | 7 | class Aggregation(torch_geometric.nn.aggr.Aggregation): 8 | def assert_sorted_index(self, index: Optional[Tensor]): 9 | """Override the inherited check with a no-op; ``index`` is ignored.""" 10 | pass 11 | -------------------------------------------------------------------------------- /poptorch_geometric/python/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for PEP 561. 2 | -------------------------------------------------------------------------------- /poptorch_geometric/python/pyg_collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from torch_geometric.loader.dataloader import Collater as PyGCollater 4 | 5 | 6 | # TODO: Upstream that change (default arguments) to PyG when upstreaming 7 | # DataLoaders. 8 | class Collater(PyGCollater): 9 | def __init__(self, follow_batch=None, exclude_keys=None): 10 | follow_batch = follow_batch or [] 11 | exclude_keys = exclude_keys or [] 12 | super().__init__(follow_batch, exclude_keys) 13 | -------------------------------------------------------------------------------- /poptorch_geometric/requirements.txt: -------------------------------------------------------------------------------- 1 | # Install pre-built wheels for PyTorch Geometric that are compatible with 2 | # poptorch which is currently pinned to torch 2.0.1 3 | --find-links https://data.pyg.org/whl/torch-2.0.1+cpu.html 4 | 5 | pyg-nightly==2.4.0.dev20230613 6 | 7 | torch-scatter==2.1.1+pt20cpu 8 | torch-sparse==0.6.17+pt20cpu 9 | torch-cluster==1.6.1+pt20cpu 10 | torch-spline-conv==1.2.2+pt20cpu 11 | 12 | pytest-benchmark==4.0.0 13 | pytest-cov==4.0.0 14 | nbconvert==7.2.9 15 | nbformat==5.7.3 16 | pandas==2.0.1 17 | 18 | singledispatchmethod==1.0; python_version < '3.8' 19 | --------------------------------------------------------------------------------
/poptorch_geometric/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = 3 | License.txt 4 | poptorch_geometric_third_party_licenses.txt 5 | -------------------------------------------------------------------------------- /poptorch_geometric/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import sys 3 | 4 | from setuptools import setup, find_packages 5 | 6 | REQUIRES = [ 7 | '@PYG_DEPENDENCY@', 8 | '@POPTORCH_DEPENDENCY@', 9 | '@TORCH_SCATTER_DEPENDENCY@', 10 | '@TORCH_SPARSE_DEPENDENCY@', 11 | ] 12 | 13 | python_version = f'{sys.version_info.major}.{sys.version_info.minor}' 14 | 15 | if python_version == '3.7': # functools.singledispatchmethod is 3.8+ 16 | REQUIRES.append('singledispatchmethod==1.0') 17 | 18 | VERSION = '@VERSION@' 19 | 20 | LONG_DESCRIPTION = ( 21 | 'PopTorch Geometric is a set of extensions for PyTorch Geometric, enabling ' 22 | 'GNN models to be trained, evaluated and used on the Graphcore IPU.') 23 | 24 | setup(name='poptorch_geometric', 25 | version=VERSION, 26 | description=LONG_DESCRIPTION, 27 | long_description=LONG_DESCRIPTION, 28 | long_description_content_type='text/markdown', 29 | license='MIT License', 30 | license_files=('License.txt', 31 | 'poptorch_geometric_third_party_licenses.txt'), 32 | author='Graphcore Ltd.', 33 | author_email='contact@graphcore.ai', 34 | url='http://graphcore.ai', 35 | classifiers=[ 36 | 'Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Developers', 38 | 'Intended Audience :: Science/Research', 39 | 'Topic :: Scientific/Engineering', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'License :: OSI Approved :: MIT License', 42 | 'Programming Language :: Python :: 3', 43 | ], 44 | platforms='@PLATFORM@', 45 | install_requires=REQUIRES, 46 | python_requires=f'=={python_version}.*', 47 | packages=find_packages()) 48 |
-------------------------------------------------------------------------------- /poptorch_logging/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14 FATAL_ERROR) 2 | project(poptorch_logging) 3 | 4 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 5 | 6 | find_package(spdlog 1.8.0 EXACT REQUIRED) 7 | 8 | # Packages provided by Poplar 9 | find_package(libpvti REQUIRED) 10 | find_package(gccs REQUIRED) 11 | 12 | add_library(poptorch_logging STATIC 13 | "source/Error.cpp" 14 | "source/Logging.cpp" 15 | "source/Tracepoint.cpp") 16 | 17 | file(GLOB_RECURSE poptorch_logging_public_headers "${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp*") 18 | 19 | set_target_properties(poptorch_logging PROPERTIES 20 | CXX_STANDARD 14 21 | PUBLIC_HEADER "${poptorch_logging_public_headers}") 22 | 23 | target_include_directories(poptorch_logging SYSTEM 24 | PUBLIC 25 | $ 26 | $) 27 | 28 | # Unfortunately, there seems to be an issue with using the `spdlog::*` targets 29 | # directly with `target_link_libraries()`, which breaks dependencies of 30 | # `poptorch_logging` adding any other include directories. Instead, we'll 31 | # manually add spdlog's include directories and compile definitions here. 32 | target_include_directories(poptorch_logging SYSTEM 33 | PUBLIC 34 | $) 35 | target_compile_definitions(poptorch_logging 36 | PUBLIC 37 | $) 38 | 39 | target_link_libraries(poptorch_logging 40 | PRIVATE 41 | libpvti 42 | gccs_stacktrace) 43 | 44 | install(TARGETS poptorch_logging 45 | LIBRARY 46 | DESTINATION ${CMAKE_INSTALL_LIBDIR} 47 | PUBLIC_HEADER 48 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/poptorch_logging) 49 | -------------------------------------------------------------------------------- /poptorch_logging/include/poptorch_logging/LoggingLight.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
2 | #ifndef INCLUDE_POPTORCH_LOGGING_LIGHT_H 3 | #define INCLUDE_POPTORCH_LOGGING_LIGHT_H 4 | 5 | #include 6 | #include 7 | 8 | // This header is a lighter version of poptorch_logging which doesn't require 9 | // spdlog and therefore doesn't support formatting. 10 | // 11 | // For the full version of the logging API use 12 | // poptorch_logging/Logging.hpp instead. 13 | namespace poptorch { 14 | namespace logging { 15 | 16 | enum class Level { 17 | Trace = 0, 18 | Debug = 1, 19 | Info = 2, 20 | Warn = 3, 21 | Err = 4, 22 | // Level 5 is "critical" in spdlog; we don't use it, so it isn't exposed here. 23 | Off = 6, 24 | }; 25 | 26 | // Set the current log level to one of the above levels. The default 27 | // log level is set by the POPTORCH_LOG_LEVEL environment variable 28 | // and is off by default. 29 | void setLogLevel(Level l); 30 | 31 | // Return true if the passed log level is currently enabled. 32 | bool shouldLog(Level l); 33 | 34 | // Return true if the Popart IR should be dumped. 35 | bool outputPopartIR(); 36 | 37 | // Return the number of times a log message is allowed to repeat. 38 | std::uint64_t repeatLimit(); 39 | 40 | void setRepeatLimit(std::uint64_t limit); 41 | 42 | // Flush the log. By default it is only flushed when the underlying libc 43 | // decides to. 44 | void flush(); 45 | 46 | // Log a message. You should probably use the MAKE_LOG_TEMPLATE macros 47 | // instead, e.g. logging::debug("A debug message"). 48 | void log(Level l, const char *msg); 49 | 50 | } // namespace logging 51 | } // namespace poptorch 52 | 53 | #endif // INCLUDE_POPTORCH_LOGGING_LIGHT_H 54 | -------------------------------------------------------------------------------- /poptorch_logging/source/Tracepoint.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | #include "poptorch_logging/Tracepoint.hpp" 3 | 4 | #include 5 | 6 | #include "poptorch_logging/Error.hpp" 7 | 8 | namespace poptorch { 9 | 10 | namespace logging { 11 | 12 | namespace detail { 13 | 14 | class TracepointImpl : public pvti::Tracepoint { 15 | public: 16 | explicit TracepointImpl(const std::string &label_) 17 | : pvti::Tracepoint(&TracepointImpl::channel, label_), ctx(label_) {} 18 | ~TracepointImpl() = default; 19 | static pvti::TraceChannel channel; 20 | LogContext ctx; 21 | }; 22 | 23 | pvti::TraceChannel TracepointImpl::channel = {"poptorch"}; 24 | } // namespace detail 25 | 26 | Tracepoint::Tracepoint(const char *label) 27 | : _impl(std::make_unique(std::string(label))) {} 28 | 29 | void Tracepoint::begin(const char *label) { 30 | pvti::Tracepoint::begin(&detail::TracepointImpl::channel, label); 31 | } 32 | 33 | void Tracepoint::end(const char *label) { 34 | pvti::Tracepoint::end(&detail::TracepointImpl::channel, label); 35 | } 36 | 37 | Tracepoint::~Tracepoint() = default; 38 | 39 | } // namespace logging 40 | } // namespace poptorch 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | # NOTE(review): "python_version" is a PEP 508 environment marker, not a 4 | # distribution name, so it is parsed as a requirement on a package of that 5 | # name. The interpreter floor normally goes in `requires-python`; confirm. 6 | "python_version>=3.7", 7 | "setuptools>=42", 8 | "wheel", 9 | "pybind11>=2.8.0", 10 | "@TORCH_DEPENDENCY@", 11 | ] 12 | build-backend = "setuptools.build_meta" 13 | 14 | [tool.pytest.ini_options] 15 | # Required to suppress a warning from the package `ruamel` using a deprecated pkg_resources function. 16 | filterwarnings = [ 17 | "ignore::DeprecationWarning:pkg_resources.*", 18 | # Deprecation warnings from pillow in torchvision.
19 | "ignore:.*Pillow.*:DeprecationWarning:torchvision", 20 | ] 21 | -------------------------------------------------------------------------------- /python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(GNUInstallDirs) 2 | file(GLOB python_files "${CMAKE_CURRENT_SOURCE_DIR}/*.py") 3 | 4 | # __init__.py needs to be edited by set_version.py so don't copy it over. 5 | list(REMOVE_ITEM python_files "${CMAKE_CURRENT_SOURCE_DIR}/__init__.py") 6 | 7 | run_poptorch_install_command("python3 ${PROJECT_SOURCE_DIR}/scripts/set_version.py --torch-version ${TORCH_VERSION} ${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "${PROJECT_SOURCE_DIR}" "Generate __init__.py") 8 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/__init__.py DESTINATION "${INSTALL_PYDIR}") 9 | install(FILES ${python_files} py.typed DESTINATION "${INSTALL_PYDIR}") 10 | 11 | # Compile the Pybind11 module using setup.py (called by generate_python_package.py). 12 | run_poptorch_install_command( 13 | "python3 ${PROJECT_SOURCE_DIR}/scripts/generate_python_package.py install --include-dir ${CMAKE_INSTALL_PREFIX}/include --lib-dir ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} --output-dir ${CMAKE_INSTALL_PREFIX} --python-dir ${INSTALL_PYDIR}" "${PROJECT_SOURCE_DIR}" "poptorch_core.so module compilation") 14 | -------------------------------------------------------------------------------- /python/_options_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 3 | 4 | import os 5 | import poptorch 6 | 7 | 8 | def parseAndSetOptions(options, filepath): 9 | """Apply the option settings listed in the config file ``filepath``. 10 | 11 | Each non-empty line not starting with ``#`` is prefixed with ``options.`` 12 | and the lines are run together via ``exec()`` with ``poptorch`` and 13 | ``options`` in scope. A ``SyntaxError`` is re-raised as 14 | ``poptorch.ConfigFileError``; other exceptions propagate unchanged. 15 | NOTE(review): the reported line number counts only executed lines, so it 16 | can be lower than the file line when blank/comment lines are skipped. 17 | """ 18 | cmds = [] 19 | with open(filepath) as f: 20 | filename = os.path.basename(f.name) 21 | prefix = "options."
22 | for line in f: 23 | # Remove leading and trailing whitespace 24 | stripped = line.strip() 25 | # Skip empty lines and comments 26 | if not stripped or stripped.startswith("#"): 27 | continue 28 | cmd = prefix + stripped 29 | cmds.append(cmd) 30 | 31 | code = "\n".join(cmds) 32 | try: 33 | # NOTE(review): exec() runs arbitrary Python from the config file; 34 | # only load trusted files. 35 | # pylint: disable=exec-used 36 | exec(code, {}, {"poptorch": poptorch, "options": options}) 37 | except SyntaxError as err: 38 | err_class = err.__class__.__name__ 39 | detail = err.args[0] 40 | lineno = err.lineno 41 | line = err.text 42 | # pylint: disable=no-member 43 | raise poptorch.ConfigFileError("{} at line {} of {}: {}\n> {}".format( 44 | err_class, lineno, filename, detail, line)) 45 | -------------------------------------------------------------------------------- /python/_printing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 2 | import torch 3 | 4 | 5 | # Override torch's repr function to provide information on the pre-hooks as 6 | # well. The pre-hooks are where BeginBlock is added. 7 | def module_repr(m: torch.nn.Module): 8 | """ 9 | Provide a string representation of a torch.nn.Module along with the 10 | corresponding pre-hooks. 11 | 12 | This will show any BeginBlocks that have been added to the model which 13 | otherwise wouldn't be displayed.
14 | """ 15 | 16 | def _add_indent(s_, numSpaces): 17 | return f'\n{numSpaces}'.join(s_.split('\n')) 18 | 19 | # pylint: disable=protected-access 20 | 21 | # We treat the extra repr like the sub-module, one item per line 22 | extra_lines = [] 23 | extra_repr = m.extra_repr() 24 | # empty string will be split into list [''] 25 | if extra_repr: 26 | extra_lines = extra_repr.split('\n') 27 | child_lines = [] 28 | for key, module in m._modules.items(): 29 | mod_str = module_repr(module) 30 | mod_str = _add_indent(mod_str, 2) 31 | child_lines.append('(' + key + '): ' + mod_str) 32 | lines = extra_lines + child_lines 33 | 34 | pre_hooks = ''.join( 35 | map(lambda x: repr(x) + ' ', m._forward_pre_hooks.values())) 36 | 37 | main_str = pre_hooks + m._get_name() + '(' 38 | if lines: 39 | # simple one-liner info, which most builtin Modules will use 40 | if len(extra_lines) == 1 and not child_lines: 41 | main_str += extra_lines[0] 42 | else: 43 | main_str += '\n ' + '\n '.join(lines) + '\n' 44 | 45 | main_str += ')' 46 | return main_str 47 | 48 | 49 | _global_print = print 50 | 51 | 52 | def print(m): 53 | """ 54 | Prints a torch.nn.Module along with the corresponding pre-hooks. 55 | 56 | This will print any BeginBlocks that have been added to the model which 57 | otherwise wouldn't be displayed. 58 | """ 59 | if isinstance(m, torch.nn.Module): 60 | _global_print(module_repr(m)) 61 | _global_print(m) 62 | -------------------------------------------------------------------------------- /python/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for PEP 561. 2 | -------------------------------------------------------------------------------- /python/testing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved. 
2 | 3 | import torch 4 | 5 | 6 | # Return true if both the structure and the content of ref and other match 7 | def allclose(ref, other): 8 | if isinstance(ref, torch.Tensor): 9 | return torch.allclose(other, ref) 10 | if isinstance(ref, tuple): 11 | if not isinstance(other, tuple) or len(ref) != len(other): 12 | return False 13 | elif isinstance(ref, list): 14 | if not isinstance(other, list) or len(ref) != len(other): 15 | return False 16 | else: 17 | assert "%s not supported" % type(ref) 18 | return all([allclose(r, other[i]) for i, r in enumerate(ref)]) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # IMPORTANT: Keep requirements in sync with ./config.buildenv.py 2 | 3 | --extra-index-url https://download.pytorch.org/whl/cpu 4 | 5 | torch==2.0.1 6 | torchaudio==2.0.2 7 | torchvision==0.15.2 8 | 9 | expecttest==0.1.3 10 | lit==0.11.1 11 | pytest==6.2.5 12 | setuptools==58.0.4 13 | tqdm==4.46.1 14 | transformers==4.12.2 15 | typing-extensions==4.1.1 16 | # Use old version for wheel.pep425tags support (new versions removed it). 17 | wheel<0.35 18 | 19 | -r poptorch_geometric/requirements.txt 20 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore/poptorch/c2a8b17762f1d7106f7bf048ff011ab4fee7685e/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/download_external_datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
3 | # Download the external datasets used by the tests (see the CLI help below).
4 | import argparse
5 | import os.path as osp
6 | import torch_geometric as pyg
7 |
8 | parser = argparse.ArgumentParser(description="Download external datasets")
9 | parser.add_argument(
10 |     "external_datasets_dir",
11 |     help="The directory where the external datasets will be downloaded.")
12 |
13 | args = parser.parse_args()
14 | # Each dataset is stored in its own subdirectory of external_datasets_dir.
15 | pyg.datasets.QM9(root=osp.join(args.external_datasets_dir, "qm9"))
16 | pyg.datasets.Planetoid(osp.join(args.external_datasets_dir, "planetoid"),
17 |                        "Cora")
18 |
--------------------------------------------------------------------------------
/scripts/enable.sh.in:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=@CMAKE_INSTALL_PREFIX@:$PYTHONPATH
3 | @ENABLE_POPLAR_CMD@
4 | @ENABLE_POPART_CMD@
5 |
--------------------------------------------------------------------------------
/scripts/popgen/poptorch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved
2 | # Signatures for manually added operators, keyed by operator name.
3 | # NOTE(review): 'Args' presumably means the op's tensor inputs; confirm.
4 | signatures = {
5 |     'beginIpuBlock': [['clong'], ['clong'], ['clong']],
6 |     'cast': ['Args', ['scalar_type']],
7 |     'internalCast': ['Args', ['cstr']],
8 |     'constantPad': ['Args', ['clong_list'], ['cfloat']],
9 |     'edgePad': ['Args', ['clong_list']],
10 |     'optimizerGroup': [['clong'], ['tensor_list']],
11 |     'printIpuTensor': ['Args', ['cstr']],
12 |     'callCpuOp': [['tensor_list'], ['cstr'], ['node']],
13 |     'randomNormal': [
14 |         'Args', ['tensor_shape'], ['cfloat'], ['cfloat'],
15 |         ['scalar_type', 'None']
16 |     ],
17 |     'randomUniform': [
18 |         'Args', ['tensor_shape'], ['cfloat'], ['cfloat'],
19 |         ['scalar_type', 'None']
20 |     ],
21 |     'recomputationCheckpoint': ['Args'],
22 |     'reflectionPad': ['Args', ['clong_list']],
23 |     'setAvailableMemory': ['Args', ['cfloat']],
24 |     'setMatMulSerialization': ['Args', ['cstr'], ['clong'], ['cint']],
25 |     'startForLoop': ['Args'],
26 |     'endForLoop': ['Args', ['clong']],
27 |     'startIfBlock': ['Args'],
28 |     'startElseBlock': ['Args'],
29 |     'endIfBlock': ['Args'],
30 | }
31 |
--------------------------------------------------------------------------------
/scripts/set_version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved.
3 | import argparse 4 | import logging 5 | import os 6 | 7 | from utils import _utils 8 | 9 | logger = logging.getLogger(os.path.basename(__file__)) 10 | _utils.set_logger(logger) 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--debug", 15 | "-d", 16 | action="store_true", 17 | help="Print debug messages") 18 | parser.add_argument("--torch-version", type=str) 19 | parser.add_argument("--input-file", type=str) 20 | parser.add_argument("output", help="File to create") 21 | 22 | args = parser.parse_args() 23 | 24 | logging_level = logging.DEBUG if args.debug else logging.INFO 25 | logging.basicConfig(level=logging_level) 26 | logger.debug("Args: %s", str(args)) 27 | 28 | pkg_info = _utils.PkgInfo.load_from_file(must_exist=False) 29 | 30 | # Copy the content of python/__init__.py and replace the occurrences of 31 | # @VERSION@ / @SNAPSHOT@ with the actual version / snapshot 32 | with open(args.output, "w") as f: 33 | if args.input_file is None: 34 | args.input_file = os.path.join(_utils.sources_dir(), "python", 35 | "__init__.py") 36 | for line in open(args.input_file): 37 | line = line.replace("@VERSION@", pkg_info.version_long) 38 | line = line.replace("@SNAPSHOT@", pkg_info.snapshot) 39 | if args.torch_version: 40 | line = line.replace("@TORCH_VERSION@", args.torch_version) 41 | f.write(line) 42 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = 3 | License.txt 4 | poptorch_third_party_licenses.txt 5 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | .datasets 2 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: 
--------------------------------------------------------------------------------
1 | add_subdirectory(custom_ops)
2 | add_subdirectory(cpp)
3 |
4 | # Copy tests to the build folder if requested.
5 | if(COPY_TESTS)
6 |   # NOTE: all test files are installed into one directory, so same-named files from different subdirectories conflict.
7 |   file(GLOB_RECURSE TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.py")
8 |   install(FILES ${TEST_FILES} DESTINATION "${CMAKE_CURRENT_BINARY_DIR}")
9 |   set(TESTS_PATH "${CMAKE_CURRENT_BINARY_DIR}")
10 | else()
11 |   set(TESTS_PATH "${CMAKE_CURRENT_SOURCE_DIR}")
12 | endif()
13 | # Passed to generate_test_file.py below via --external-datasets-dir.
14 | set(EXTERNAL_DATASETS_DIR "${CMAKE_BINARY_DIR}/buildenv/external_datasets")
15 |
16 | # Generate CTestTestfile.cmake from the Python tests found in TESTS_PATH.
17 | run_poptorch_install_command(
18 |   "python3 ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_file.py \
19 |   ${TESTS_PATH} \
20 |   ${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.cmake \
21 |   --add-to-sys-path ${CMAKE_INSTALL_PREFIX} \
22 |   --external-datasets-dir ${EXTERNAL_DATASETS_DIR} \
23 |   --extra-pytest-args=\"${EXTRA_PYTEST_ARGS}\" "
24 |   "${PROJECT_BINARY_DIR}"
25 |   "generate_test_file.py")
26 |
--------------------------------------------------------------------------------
/tests/bool_support_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved.
3 | 4 | import torch 5 | import pytest 6 | import helpers 7 | import poptorch 8 | 9 | # Not need for mean or logsumexp 10 | reduce_ops = [torch.sum, torch.prod] 11 | test_tensors = [ 12 | torch.tensor([1.0, 2.0, 3.1]), 13 | torch.tensor([1.1, 2.0, 3.0]), 14 | torch.tensor([0.0, 0.0, 0.0]) 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("op", reduce_ops) 19 | @pytest.mark.parametrize("t_1", test_tensors) 20 | @pytest.mark.parametrize("t_2", test_tensors) 21 | def test_reduce_two_bool_types(op, t_1, t_2): 22 | class Model(torch.nn.Module): 23 | def forward(self, x, y): 24 | return op(x == y) 25 | 26 | model = Model() 27 | 28 | poptorch_model = poptorch.inferenceModel(model) 29 | native_out = model(t_1, t_2) 30 | poptorch_out = poptorch_model(t_1, t_2) 31 | #expected = no dims (scalar) 32 | helpers.assert_allclose(actual=poptorch_out, expected=native_out) 33 | 34 | assert native_out.dtype == torch.int64 35 | assert poptorch_out.dtype == torch.int32 36 | 37 | 38 | def test_logits(): 39 | class Model(torch.nn.Module): 40 | def forward(self, logits, y): 41 | acc = torch.sum(torch.argmax(logits, -1) == y) / float(y.size(0)) 42 | return acc 43 | 44 | model = Model() 45 | 46 | logits = torch.tensor([[1.0, 2.0, 3.0], [3.0, 1.0, 2.0], [2.0, 3.0, 1.0]]) 47 | y = torch.tensor([[0], [2], [1]]) 48 | 49 | poptorch_model = poptorch.inferenceModel(model) 50 | native_out = model(logits, y) 51 | poptorch_out = poptorch_model(logits, y) 52 | 53 | helpers.assert_allclose(actual=poptorch_out, expected=native_out) 54 | -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | set(Boost_USE_STATIC_LIBS OFF) 4 | set(Boost_USE_MULTITHREADED ON) 5 | set(Boost_USE_STATIC_RUNTIME OFF) 6 | find_package(Boost 1.70 REQUIRED COMPONENTS unit_test_framework) 7 | 8 | # Ensure ABI matches that of PyTorch 9 | add_definitions(${TORCH_CXX_FLAGS}) 10 | 11 | function(add_unit_test name) 12 | add_executable(${name} ${ARGN}) 13 | 14 | target_link_libraries(${name} Boost::unit_test_framework torch poptorch poptorch_logging pthread) 15 | 16 | target_include_directories(${name} PRIVATE 17 | ${CMAKE_SOURCE_DIR}/poptorch/source/include/) 18 | 19 | add_test(${name} ${name}) 20 | 21 | endfunction() 22 | 23 | add_unit_test(GNNOptimizationsTest GNNOptimizationsTest.cpp) 24 | -------------------------------------------------------------------------------- /tests/ctc_decoder_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 3 | 4 | import torch 5 | import poptorch 6 | 7 | 8 | class SimpleModel(torch.nn.Module): 9 | def forward(self, log_probs, lengths): 10 | return poptorch.ctc_beam_search_decoder(log_probs, lengths) 11 | 12 | 13 | def test_ctc_decoder(): 14 | input_size = 9 15 | batch_size = 3 16 | num_classes = 10 17 | 18 | torch.manual_seed(42) 19 | log_probs = torch.randn(input_size, batch_size, num_classes) 20 | lengths = torch.randint(5, input_size, (batch_size, ), dtype=torch.int) 21 | 22 | model = SimpleModel() 23 | poptorch_model = poptorch.inferenceModel(model) 24 | 25 | result = poptorch_model(log_probs, lengths) 26 | 27 | # note we have no reference implementation so only the most basic 28 | # test is possible - relying on popart/poplibs which are validated 29 | # against tensorflow's implementation 30 | assert result[0].shape == (batch_size, 1) 31 | assert result[1].shape == (batch_size, 1) 32 | assert result[2].shape == (batch_size, 1, input_size) 33 | 
-------------------------------------------------------------------------------- /tests/gnn/.gitignore: -------------------------------------------------------------------------------- 1 | .datasets 2 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn/README.md: -------------------------------------------------------------------------------- 1 | # benchgnn 2 | 3 | Benchmark tool for testing performance of GNN models 4 | 5 | ## Usage example 6 | 7 | ``benchgnn --dataset FakeDataset --model GAT --bs 1 100 --cpu --output outfile`` 8 | 9 | Type ``benchgnn --help`` to print detailed information about supported options. 10 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | import os.path as osp 4 | 5 | from torch_geometric import seed_everything 6 | from torch_geometric.datasets import Entities 7 | from torch_geometric.datasets import FakeDataset as FDS 8 | from torch_geometric.datasets import Planetoid 9 | from torch_geometric.transforms import Compose, GCNNorm, NormalizeFeatures 10 | 11 | 12 | class DataSets: 13 | def __init__(self, root): 14 | self.root = root 15 | 16 | def Cora(self): 17 | return Planetoid(osp.join(self.root, 'Cora'), 'Cora') 18 | 19 | def CiteSeer(self): 20 | return Planetoid(osp.join(self.root, 'CiteSeer'), 'CiteSeer') 21 | 22 | def PubMed(self): 23 | return Planetoid(osp.join(self.root, 'PubMed'), 'PubMed') 24 | 25 | def mutag(self): 26 | return Entities(osp.join(self.root, 'EntitiesMUTAG'), 'mutag') 27 | 28 | def FakeDataset(self): 29 | seed_everything(0) 30 | 31 | transform = Compose([GCNNorm(), NormalizeFeatures()]) 32 | 33 | dataset = FDS( 34 | num_graphs=1000, 35 | avg_num_nodes=16, 36 | avg_degree=5, 37 | transform=transform, 38 | num_channels=64, 39 | ) 40 | setattr(dataset, 
'name', 'FakeDataset') 41 | return dataset 42 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | pytest-benchmark 3 | pytest-cov 4 | nbformat 5 | nbconvert 6 | pandas 7 | rdflib 8 | tabulate 9 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | 3 | from tabulate import _table_formats, tabulate 4 | 5 | all_formats = sorted(list(_table_formats.keys())) 6 | 7 | 8 | def merge_results(results, prev_results): 9 | if prev_results: 10 | keys = prev_results[0].keys() 11 | time_keys = {'ipu_time', 'gpu_time', 'cpu_time'}.intersection(keys) 12 | for curr, prev in zip(results, prev_results): 13 | for key in time_keys: 14 | curr['prev_' + key.split('_')[0]] = prev[key] 15 | return results 16 | 17 | 18 | def include_speedups_ratio(results): 19 | keys = list(results[0].keys()) 20 | 21 | # Calculate speedup over other times 22 | if 'ipu_time' in keys: 23 | other = filter(lambda x: x in keys, 24 | ('cpu_time', 'prev_cpu', 'prev_gpu', 'prev_ipu')) 25 | for t in other: 26 | for res in results: 27 | res['ipu/' + t] = res[t] / res["ipu_time"] 28 | 29 | return results 30 | 31 | 32 | def print_results(results, format): 33 | results = include_speedups_ratio(results) 34 | 35 | content = [list(results[0].keys())] 36 | prev_model = None 37 | for res in results: 38 | curr_model = res['model'] 39 | if prev_model != curr_model: 40 | if prev_model is not None: 41 | content.append([]) 42 | prev_model = curr_model 43 | else: 44 | res['model'] = '' 45 | 46 | row = [f'{x:.2f}' if isinstance(x, float) else x for x in res.values()] 47 | 48 | content.append(row) 49 | 50 | body = 
tabulate(content, headers='firstrow', tablefmt=format) 51 | print('\n', body, sep='') 52 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn_ops/README.md: -------------------------------------------------------------------------------- 1 | # benchgnn 2 | 3 | Benchmark tool for testing performance of GNN operators 4 | 5 | ## Usage example 6 | 7 | Running single benchmark test case scenario from command line: 8 | ``python3 benchgnn_ops.py --num_sample_rounds 10 scatter --src_shape [1,12] --input_shape [1,12] --index_shape [1,12] --dim 0`` 9 | 10 | Running multiple benchmark test case scenarios from yaml configuration files from given directory: 11 | ``python3 benchgnn_ops.py --common_config=example_configs/common.yaml --config_dir=example_configs`` 12 | 13 | Running multiple benchmark test case scenarios from given yaml configuration files: 14 | ``python3 benchgnn_ops.py --common_config=example_configs/common.yaml --config_files=[example_configs/scatter_testcase1.yaml,example_configs/scatter_testcase2.yaml]`` 15 | 16 | Running multiple benchmark test case scenarios - combining all available options: 17 | ``python3 benchgnn_ops.py --common_config=example_configs/common.yaml --config_dir=example_configs --config_files=[example_configs/scatter_testcase1.yaml,example_configs/scatter_testcase2.yaml] scatter --src_shape [1,12] --input_shape [1,12] --index_shape [1,12] --dim 0`` 18 | 19 | Type ``python3 benchgnn_ops.py --help`` to print detailed information about supported options. 
20 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn_ops/example_configs/common.yaml: -------------------------------------------------------------------------------- 1 | num_sample_rounds: 25 2 | compile_options: 3 | num_repeats: 100 4 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn_ops/example_configs/scatter_testcase1.yaml: -------------------------------------------------------------------------------- 1 | scatter: 2 | src_shape: [1,120] 3 | input_shape: [1,120] 4 | index_shape: [1,120] 5 | dim: 0 6 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn_ops/example_configs/scatter_testcase2.yaml: -------------------------------------------------------------------------------- 1 | scatter: 2 | src_shape: [1,12] 3 | input_shape: [1,12] 4 | index_shape: [1,12] 5 | dim: 0 6 | -------------------------------------------------------------------------------- /tests/gnn/benchgnn_ops/requirements.txt: -------------------------------------------------------------------------------- 1 | jsonargparse==4.19.0 2 | docstring-parser==0.15 3 | tqdm==4.64.1 4 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | from torch_geometric import seed_everything 5 | from torch_geometric.datasets import FakeDataset 6 | from torch_geometric.transforms import NormalizeFeatures 7 | 8 | from poptorch_geometric.dataloader import FixedSizeDataLoader 9 | from poptorch_geometric.fixed_size_options import FixedSizeOptions 10 | from poptorch_geometric.pyg_dataloader import FixedSizeStrategy 11 | 12 | 13 | @pytest.fixture 14 | def dataloader(): 15 | seed_everything(42) 16 | 17 | dataset = FakeDataset(num_graphs=4, 18 | avg_num_nodes=8, 19 | avg_degree=3, 20 | transform=NormalizeFeatures(), 21 | num_channels=8) 22 | 23 | dataloader = FixedSizeDataLoader( 24 | dataset, 25 | fixed_size_options=FixedSizeOptions(num_nodes=12, num_edges=32), 26 | fixed_size_strategy=FixedSizeStrategy.StreamPack, 27 | add_pad_masks=True) 28 | 29 | return dataloader 30 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import torch 4 | from torch_geometric.nn import MLP 5 | from torch_geometric.nn.aggr import AttentionalAggregation 6 | 7 | from aggr_utils import aggr_harness 8 | 9 | 10 | def test_attentional_aggregation(dataloader): 11 | first_sample = next(iter(dataloader)) 12 | in_channels = first_sample.num_node_features 13 | out_channels = in_channels * 2 14 | 15 | gate_nn = MLP([in_channels, 1], act='relu') 16 | nn = MLP([in_channels, in_channels], act='relu') 17 | aggr = AttentionalAggregation(gate_nn, nn) 18 | post_proc = torch.nn.Linear(in_channels, out_channels) 19 | 20 | aggr_harness(aggr, 21 | first_sample.num_nodes, 22 | dataloader, 23 | post_proc, 24 | atol=1e-3, 25 | rtol=5e-3) 26 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_deep_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from torch_geometric.nn import DeepSetsAggregation, Linear 4 | 5 | from aggr_utils import aggr_harness 6 | 7 | 8 | def test_deep_sets_aggregation(dataloader): 9 | first_sample = next(iter(dataloader)) 10 | channels = first_sample.num_node_features 11 | 12 | aggr = DeepSetsAggregation( 13 | local_nn=Linear(channels, channels * 2), 14 | global_nn=Linear(channels * 2, channels * 4), 15 | ) 16 | aggr.reset_parameters() 17 | 18 | aggr_harness(aggr, first_sample.num_nodes, dataloader) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_equilibrium.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | from torch_geometric.nn import EquilibriumAggregation 5 | 6 | from aggr_utils import aggr_harness 7 | 8 | 9 | @pytest.mark.skip(reason="TODO(AFS-354)") 10 | @pytest.mark.parametrize('grad_iter', [0, 1, 5]) 11 | def test_equilibrium(dataloader, grad_iter): 12 | first_sample = next(iter(dataloader)) 13 | channels = first_sample.num_node_features 14 | 15 | aggr = EquilibriumAggregation(channels, 16 | channels // 2, 17 | num_layers=[10, 10], 18 | grad_iter=grad_iter) 19 | 20 | aggr_harness(aggr, first_sample.num_nodes, dataloader) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_fused.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | import torch 5 | from torch_geometric.nn.aggr.fused import FusedAggregation 6 | 7 | from aggr_utils import aggr_harness 8 | 9 | 10 | @pytest.mark.parametrize('aggrs', [ 11 | ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'], 12 | ['sum', 'min', 'max', 'mul', 'var', 'std'], 13 | ['min', 'max', 'mul', 'var', 'std'], 14 | ['mean', 'min', 'max', 'mul', 'var', 'std'], 15 | ['sum', 'min', 'max', 'mul', 'std'], 16 | ['mean', 'min', 'max', 'mul', 'std'], 17 | ['min', 'max', 'mul', 'std'], 18 | ]) 19 | def test_fused_aggregation(dataloader, aggrs): 20 | first_sample = next(iter(dataloader)) 21 | in_channels = first_sample.num_node_features 22 | out_channels = in_channels * 2 23 | 24 | aggr = FusedAggregation(aggrs) 25 | post_proc = torch.nn.Linear(in_channels, out_channels) 26 | 27 | aggr_harness(aggr, first_sample.num_nodes, dataloader, post_proc) 28 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_gmt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | from torch_geometric.nn.aggr import GraphMultisetTransformer 5 | 6 | from aggr_utils import aggr_harness 7 | 8 | 9 | @pytest.mark.skip(reason="TODO(AFS-351)") 10 | def test_graph_multiset_transformer(dataloader): 11 | first_sample = next(iter(dataloader)) 12 | print(first_sample) 13 | print(first_sample.num_nodes) 14 | channels = first_sample.num_node_features 15 | aggr = GraphMultisetTransformer(channels, k=2, heads=2) 16 | aggr.reset_parameters() 17 | 18 | aggr_harness(aggr, 19 | first_sample.num_nodes, 20 | dataloader, 21 | sorted_index=True, 22 | enable_fp_exception=False, 23 | equal_nan=True) 24 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_gru.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from torch_geometric.nn import GRUAggregation 4 | 5 | from aggr_utils import aggr_harness 6 | 7 | 8 | def test_gru_aggregation(dataloader): 9 | first_sample = next(iter(dataloader)) 10 | channels = first_sample.num_node_features 11 | 12 | aggr = GRUAggregation(channels, channels * 2) 13 | aggr.reset_parameters() 14 | 15 | aggr_harness(aggr, first_sample.num_nodes, dataloader, sorted_index=True) 16 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from torch_geometric.nn import LSTMAggregation 4 | 5 | from aggr_utils import aggr_harness 6 | 7 | 8 | def test_lstm_aggregation(dataloader): 9 | first_sample = next(iter(dataloader)) 10 | channels = first_sample.num_node_features 11 | 12 | aggr = LSTMAggregation(channels, channels * 2) 13 | 14 | aggr_harness(aggr, first_sample.num_nodes, dataloader, sorted_index=True) 15 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_mlp_aggr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from torch_geometric.nn import MLPAggregation 4 | 5 | from aggr_utils import aggr_harness 6 | 7 | 8 | def test_mlp_aggregation(dataloader): 9 | first_sample = next(iter(dataloader)) 10 | channels = first_sample.num_node_features 11 | 12 | aggr = MLPAggregation( 13 | in_channels=channels, 14 | out_channels=channels * 2, 15 | max_num_elements=first_sample.num_nodes, 16 | num_layers=1, 17 | ) 18 | 19 | aggr_harness(aggr, first_sample.num_nodes, dataloader, sorted_index=True) 20 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_multi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | import torch 5 | from torch_geometric.nn import MultiAggregation 6 | 7 | from aggr_utils import aggr_harness 8 | 9 | 10 | @pytest.mark.parametrize('mode', [ 11 | 'cat', 'proj', 'attn', 'sum', 'mean', 'max', 'min', 'logsumexp', 'std', 12 | 'var' 13 | ]) 14 | def test_multi_aggr(dataloader, mode): 15 | first_sample = next(iter(dataloader)) 16 | in_channels = first_sample.num_node_features 17 | out_channels = in_channels * 2 18 | 19 | mode_kwargs = None 20 | if mode == 'proj': 21 | mode_kwargs = dict(in_channels=in_channels, out_channels=in_channels) 22 | elif mode == 'attn': 23 | mode_kwargs = dict(in_channels=in_channels, 24 | out_channels=in_channels, 25 | num_heads=in_channels // 4) 26 | 27 | aggrs = ['mean', 'sum', 'max'] 28 | aggr = MultiAggregation(aggrs, mode=mode, mode_kwargs=mode_kwargs) 29 | aggr.reset_parameters() 30 | 31 | if mode == 'cat': 32 | # The 'cat' combine mode will expand the output dimensions 33 | # the number of aggregators. 34 | in_channels = in_channels * len(aggrs) 35 | out_channels = out_channels * len(aggrs) 36 | 37 | post_proc = torch.nn.Linear(in_channels, out_channels) 38 | 39 | aggr_harness(aggr, 40 | first_sample.num_nodes, 41 | dataloader, 42 | post_proc, 43 | atol=1e-3) 44 | -------------------------------------------------------------------------------- /tests/gnn/nn/aggr/test_quantile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
import pytest
import torch
from torch_geometric.nn import MedianAggregation, QuantileAggregation

from aggr_utils import aggr_harness


# q takes the values 0.0, 0.1, ..., 1.0 (exact same floats as the literals).
@pytest.mark.parametrize('q', [step / 10 for step in range(11)])
@pytest.mark.parametrize('interpolation', QuantileAggregation.interpolations)
def test_quantile_aggregation(dataloader, q, interpolation):
    """Run ``QuantileAggregation`` for each quantile and interpolation mode.

    The RNG is seeded first so the post-processing ``Linear`` is initialised
    identically in every parametrization.
    """
    torch.manual_seed(42)
    sample = next(iter(dataloader))
    width = sample.num_node_features

    aggr = QuantileAggregation(q=q, interpolation=interpolation)
    post_proc = torch.nn.Linear(width, 2 * width)

    aggr_harness(aggr,
                 sample.num_nodes,
                 dataloader,
                 post_proc,
                 sorted_index=True)


def test_median_aggregation(dataloader):
    """Run ``MedianAggregation`` followed by a widening ``Linear`` layer."""
    sample = next(iter(dataloader))
    width = sample.num_node_features

    aggr = MedianAggregation()
    post_proc = torch.nn.Linear(width, 2 * width)

    aggr_harness(aggr, sample.num_nodes, dataloader, post_proc)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/aggr/test_scaler.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import DegreeScalerAggregation

from aggr_utils import aggr_harness


# Each parametrization passes a one-element scaler list.
@pytest.mark.parametrize('scaler', [[name] for name in (
    'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear')])
@pytest.mark.parametrize('train_norm', [True, False])
def test_degree_scaler_aggregation(dataloader, scaler, train_norm):
    """Run ``DegreeScalerAggregation`` over mean/sum/max for each scaler.

    The aggregator output is ``len(basic_aggrs)`` times wider than the node
    features, so the post-processing ``Linear`` is sized accordingly.
    """
    sample = next(iter(dataloader))
    width = sample.num_node_features

    # Hard-coded ``deg`` tensor handed to DegreeScalerAggregation.
    deg = torch.tensor([2, 5, 3, 1, 2, 3, 4, 1, 2, 0])

    basic_aggrs = ['mean', 'sum', 'max']
    num_aggrs = len(basic_aggrs)
    aggr = DegreeScalerAggregation(basic_aggrs,
                                   scaler,
                                   deg,
                                   train_norm=train_norm)
    post_proc = torch.nn.Linear(width * num_aggrs, 2 * width * num_aggrs)

    aggr_harness(aggr, sample.num_nodes, dataloader, post_proc)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/aggr/test_set2set.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from torch_geometric.nn.aggr import Set2Set

from aggr_utils import aggr_harness


def test_set2set(dataloader):
    """Run a one-step ``Set2Set`` aggregation through ``aggr_harness``."""
    sample = next(iter(dataloader))
    set2set = Set2Set(in_channels=sample.num_node_features,
                      processing_steps=1)
    aggr_harness(set2set, sample.num_nodes, dataloader)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/aggr/test_set_transformer.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn.aggr import SetTransformerAggregation

from aggr_utils import aggr_harness


@pytest.mark.skip(reason="TODO(AFS-351)")
def test_set_transformer_aggregation(dataloader):
    """Run ``SetTransformerAggregation`` (2 seed points, 2 heads).

    Floating-point exceptions are disabled in the harness and NaNs are
    treated as equal when comparing results.
    """
    sample = next(iter(dataloader))
    transformer_aggr = SetTransformerAggregation(sample.num_node_features,
                                                 num_seed_points=2,
                                                 heads=2)
    transformer_aggr.reset_parameters()

    aggr_harness(transformer_aggr,
                 sample.num_nodes,
                 dataloader,
                 sorted_index=True,
                 enable_fp_exception=False,
                 equal_nan=True)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/aggr/test_sort.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import torch
from torch_geometric.nn.aggr import SortAggregation

from aggr_utils import aggr_harness


def test_sort_aggregation(dataloader):
    """Run ``SortAggregation`` keeping ``k`` elements per group.

    The aggregated output is ``k`` node feature vectors wide, so the
    post-processing ``Linear`` takes ``k * width`` inputs.
    """
    sample = next(iter(dataloader))
    width = sample.num_node_features

    k = 5
    sort_aggr = SortAggregation(k=k)
    post_proc = torch.nn.Linear(k * width, k * (2 * width))

    aggr_harness(sort_aggr,
                 sample.num_nodes,
                 dataloader,
                 post_proc,
                 sorted_index=True)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conftest.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest

from torch_geometric import seed_everything
from torch_geometric.datasets import FakeDataset
from torch_geometric.transforms import Compose, GCNNorm, NormalizeFeatures

from poptorch_geometric.dataloader import FixedSizeDataLoader, DataLoader
from poptorch_geometric.fixed_size_options import FixedSizeOptions
from poptorch_geometric.pyg_dataloader import FixedSizeStrategy


def get_dataset(num_channels=16):
    """Return the first graph of a seeded ``FakeDataset``.

    The graph goes through ``GCNNorm`` then ``NormalizeFeatures`` and is
    tagged with the dataset's ``num_classes``.

    Args:
        num_channels: node feature width of the generated graph.
    """
    seed_everything(0)
    fake = FakeDataset(avg_num_nodes=32,
                       avg_degree=5,
                       transform=Compose([GCNNorm(), NormalizeFeatures()]),
                       num_channels=num_channels)
    graph = fake[0]
    graph.num_classes = fake.num_classes
    return graph


@pytest.fixture
def dataset():
    """A single seeded fake graph with the default 16 feature channels."""
    return get_dataset()


@pytest.fixture
def fake_dataset():
    """Four small seeded fake graphs with 10 normalised feature channels."""
    seed_everything(0)
    return FakeDataset(num_graphs=4,
                       avg_num_nodes=8,
                       avg_degree=3,
                       transform=NormalizeFeatures(),
                       num_channels=10)


@pytest.fixture
def fixed_size_dataloader(fake_dataset):
    """Stream-packing fixed-size loader (12 nodes) with pad masks."""
    return FixedSizeDataLoader(
        fake_dataset,
        fixed_size_options=FixedSizeOptions(num_nodes=12),
        fixed_size_strategy=FixedSizeStrategy.StreamPack,
        add_pad_masks=True)


@pytest.fixture
def dataloader(fake_dataset):
    """Unshuffled ``DataLoader`` over ``fake_dataset``."""
    return DataLoader(fake_dataset, shuffle=False)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_agnn_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import AGNNConv
from conv_utils import conv_harness

conv_kwargs = {"add_self_loops": False}


def test_agnn_conv(dataset):
    """Run ``AGNNConv`` (self-loops disabled) through ``conv_harness``."""
    conv_harness(AGNNConv(**conv_kwargs), dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_antisymmetric_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import AntiSymmetricConv
from torch_geometric.nn.conv import GCNConv

from conv_utils import conv_harness


def test_antisymmetric_conv(dataset):
    """Run ``AntiSymmetricConv`` with a bias-free ``GCNConv`` as ``phi``."""
    width = dataset.num_node_features
    layer = AntiSymmetricConv(
        width, phi=GCNConv(width, width, bias=False, add_self_loops=False))
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_appnp.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import torch
from torch_geometric.nn import APPNP
from conv_utils import conv_harness

out_channels = 16
conv_kwargs = {"add_self_loops": False}


def test_appnp(dataset):
    """Run ``APPNP`` (K=10, alpha=0.1) followed by a ``Linear`` layer."""
    post_proc = torch.nn.Linear(dataset.num_node_features, out_channels)
    propagation = APPNP(K=10, alpha=0.1, dropout=0.0, **conv_kwargs)
    conv_harness(propagation, dataset, post_proc=post_proc)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_arma_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import ARMAConv
from conv_utils import conv_harness

out_channels = 32


def test_arma_conv(dataset):
    """Run ``ARMAConv`` with 8 stacks of 4 layers through ``conv_harness``."""
    layer = ARMAConv(dataset.num_node_features,
                     out_channels,
                     num_stacks=8,
                     num_layers=4)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_cg_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import CGConv
from conv_utils import conv_harness


@pytest.mark.parametrize('batch_norm', [False])
def test_cg_conv(dataset, batch_norm):
    """Run ``CGConv`` for each ``batch_norm`` setting (currently only off)."""
    layer = CGConv(dataset.num_node_features, batch_norm=batch_norm)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_cheb_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import torch
import pytest
from torch_geometric.nn import ChebConv

from conv_utils import conv_harness


@pytest.mark.skip(
    reason="ChebConv won't work, because algorithm requires removing "
    "self loops and we are adding self loops to ensure that "
    "tensors have fixed size.")
def test_cheb_conv(dataset):
    """Run ``ChebConv`` (K=3) with increasingly explicit forward inputs.

    The inputs are: the dataset alone; (x, edge_index, edge_weight); the
    same plus a scalar ``lambda_max``; and finally a two-graph batch vector
    with a per-graph ``lambda_max`` tensor.
    """
    layer = ChebConv(dataset.num_node_features, 32, K=3, add_self_loops=False)
    conv_harness(layer, dataset)

    base = (dataset.x, dataset.edge_index, dataset.edge_weight)
    conv_harness(layer, batch=base)
    conv_harness(layer, batch=base + (None, 3.0))

    num_nodes = dataset.num_nodes
    # Nodes past the midpoint belong to graph 1, the rest to graph 0.
    graph_ids = torch.tensor(
        [int(node > num_nodes // 2) for node in range(num_nodes)])
    lambda_max = torch.tensor([2.0, 3.0])
    conv_harness(layer, batch=base + (graph_ids, lambda_max))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_cluster_gcn_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import ClusterGCNConv
from conv_utils import conv_harness


def test_cluster_gcn_conv(dataset):
    """Run ``ClusterGCNConv`` with ``diag_lambda=1`` and no self-loops."""
    layer = ClusterGCNConv(dataset.num_node_features,
                           32,
                           diag_lambda=1.,
                           add_self_loops=False)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_dna_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import DNAConv

from conv_utils import conv_harness

# One DNAConv configuration per test case. Keep the entries distinct: two
# identical dicts would only run the same case twice.
conv_kwargs_list = [{
    'heads': 4,
    'groups': 8,
}, {
    'heads': 1,
    'groups': 1,
}, {
    'heads': 4,
    'groups': 8,
    'cached': True
}]


@pytest.mark.parametrize('conv_kwargs', conv_kwargs_list)
def test_dna_conv(conv_kwargs):
    """Run ``DNAConv`` on a small star graph (node 0 linked to 1, 2 and 3).

    ``x`` is built with shape ``(num_nodes, num_layers, channels)``, i.e.
    one feature vector per node per previous layer.
    """
    channels = 32
    num_layers = 3
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, num_layers, channels))

    conv = DNAConv(channels, dropout=0.0, add_self_loops=False, **conv_kwargs)
    conv_harness(conv, batch=(x, edge_index))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_edge_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq
from torch_geometric.nn import DynamicEdgeConv, EdgeConv
from conv_utils import conv_harness

out_channels = 32


def _edge_mlp(in_channels):
    """Build the MLP shared by both edge-conv tests.

    Its input width is twice the node feature width, and it emits
    ``out_channels`` features.
    """
    return Seq(Lin(in_channels * 2, in_channels), ReLU(),
               Lin(in_channels, out_channels))


def test_edge_conv(dataset):
    """Run ``EdgeConv`` on the dataset's fixed graph."""
    conv_harness(EdgeConv(_edge_mlp(dataset.num_node_features)), dataset)


def test_dynamic_edge_conv(dataset):
    """Run ``DynamicEdgeConv`` (k=2), which receives only the node features."""
    layer = DynamicEdgeConv(_edge_mlp(dataset.num_node_features), k=2)
    conv_harness(layer, dataset, batch=(dataset.x, ))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_eg_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import EGConv
from conv_utils import conv_harness

conv_kwargs_list = [
    {"add_self_loops": False},
    {"add_self_loops": False, "aggregators": ["max", "min"]},
]


@pytest.mark.parametrize('conv_kwargs', conv_kwargs_list)
def test_eg_conv(dataset, conv_kwargs):
    """Run ``EGConv`` with default and max/min aggregators."""
    layer = EGConv(dataset.num_node_features, 32, **conv_kwargs)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_fa_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import FAConv
from conv_utils import conv_harness

conv_kwargs = {"add_self_loops": False}


def test_fa_conv(dataset):
    """Run ``FAConv`` (eps=1.0) with the node features passed twice.

    The batch is ``(x, x, edge_index)``; presumably the second copy fills
    FAConv's initial-features argument — confirm against the PyG signature.
    """
    layer = FAConv(dataset.num_node_features, eps=1.0, **conv_kwargs)
    x = dataset.x
    conv_harness(layer, dataset, batch=(x, x, dataset.edge_index))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_feast_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import FeaStConv
from conv_utils import conv_harness

out_channels = 32
conv_kwargs = {"add_self_loops": False}


def test_feast_conv(dataset):
    """Run a two-head ``FeaStConv`` through ``conv_harness``."""
    layer = FeaStConv(dataset.num_node_features,
                      out_channels,
                      heads=2,
                      **conv_kwargs)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_film_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import FiLMConv
from conv_utils import conv_harness

out_channels = 32


@pytest.mark.parametrize('num_relations', [1])
def test_film_conv(dataset, num_relations):
    """Run ``FiLMConv`` for each relation count (currently a single one)."""
    layer = FiLMConv(dataset.num_node_features,
                     out_channels,
                     num_relations=num_relations)
    conv_harness(layer, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gat_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import GATConv
from conv_utils import conv_harness

out_channels = 32
# Edge-feature configurations exercised by test_gat_conv. These dicts are
# shared module-level objects, so tests must not mutate them.
conv_kwargs_list = [
    {
        'edge_dim': None
    },
    {
        'edge_dim': 1,
        'fill_value': 0.5
    },
    {
        'edge_dim': 1,
        'fill_value': 'mean'
    },
    {
        'edge_dim': 4,
        'fill_value': 0.5
    },
    {
        'edge_dim': 4,
        'fill_value': 'mean'
    },
]


@pytest.mark.parametrize('conv_kwargs', conv_kwargs_list)
def test_gat_conv(dataset, conv_kwargs):
    """Run a two-head ``GATConv`` for each edge-feature configuration.

    Self-loops are always disabled. The override goes into a per-test copy
    so the shared parametrized dict is left untouched.
    """
    in_channels = dataset.num_node_features
    kwargs = {**conv_kwargs, 'add_self_loops': False}

    conv = GATConv(in_channels, out_channels, heads=2, **kwargs)
    conv_harness(conv, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gated_graph_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import GatedGraphConv
from conv_utils import conv_harness

out_channels = 32


def test_gated_graph_conv(dataset):
    """Run a three-layer ``GatedGraphConv`` through ``conv_harness``.

    NOTE(review): the first constructor argument is given the input feature
    width, and the module-level ``out_channels`` is not used in this test.
    """
    in_channels = dataset.num_node_features
    conv = GatedGraphConv(in_channels, num_layers=3)

    conv_harness(conv, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gatv2_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import GATv2Conv
from conv_utils import conv_harness

out_channels = 32
# Edge-feature configurations exercised by test_gatv2_conv. These dicts are
# shared module-level objects, so tests must not mutate them.
conv_kwargs_list = [
    {
        'edge_dim': None
    },
    {
        'edge_dim': 1,
        'fill_value': 0.5
    },
    {
        'edge_dim': 1,
        'fill_value': 'mean'
    },
    {
        'edge_dim': 4,
        'fill_value': 0.5
    },
    {
        'edge_dim': 4,
        'fill_value': 'mean'
    },
]


@pytest.mark.parametrize('conv_kwargs', conv_kwargs_list)
def test_gatv2_conv(dataset, conv_kwargs):
    """Run a two-head ``GATv2Conv`` for each edge-feature configuration.

    Self-loops are always disabled. The override goes into a per-test copy
    so the shared parametrized dict is left untouched.
    """
    in_channels = dataset.num_node_features
    kwargs = {**conv_kwargs, 'add_self_loops': False}
    conv = GATv2Conv(in_channels, out_channels, heads=2, **kwargs)

    conv_harness(conv, dataset, atol=1e-4, rtol=1e-3)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gcn2_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import torch
from torch_geometric.nn import GCN2Conv
from conv_utils import conv_harness

out_channels = 16


def test_gcn2_conv(dataset):
    """Run ``GCN2Conv`` (alpha=0.2) for a single step.

    The batch is ``(x, x2, edge_index)``, where ``x2`` is a random tensor
    shaped like ``x``.
    """
    in_channels = dataset.num_node_features
    conv = GCN2Conv(in_channels, alpha=0.2, add_self_loops=False)
    x2 = torch.randn_like(dataset.x)
    batch = (dataset.x, x2, dataset.edge_index)
    conv_harness(conv, dataset, batch=batch, num_steps=1)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gcn_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
from torch_geometric.nn import GCNConv
from conv_utils import conv_harness

out_channels = 32
conv_kwargs = {'add_self_loops': False}


@pytest.mark.parametrize('flow', ['source_to_target', 'target_to_source'])
def test_gcn_conv(dataset, flow):
    """Run ``GCNConv`` for both message-passing directions.

    ``flow`` must be passed by keyword. GCNConv's third positional parameter
    is ``improved``, so a positional ``flow`` string would silently turn
    ``improved`` on and leave ``flow`` at its default in both cases.
    """
    in_channels = dataset.num_node_features
    conv = GCNConv(in_channels, out_channels, flow=flow, **conv_kwargs)

    conv_harness(conv, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gen_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import GENConv

from conv_utils import conv_harness


@pytest.mark.parametrize('aggr', ['softmax', 'powermean'])
def test_gen_conv(aggr, dataset):
    """Run ``GENConv`` with layer norm for each aggregation scheme.

    The cases are: the dataset with ``edge_dim=16``; the same layer on a
    ``(x, x2)`` pair input; and a layer built from a channel tuple on the
    same pair input.
    """
    in_channels = dataset.num_node_features

    conv = GENConv(in_channels,
                   32,
                   aggr,
                   edge_dim=16,
                   add_self_loops=False,
                   norm='layer')
    conv_harness(conv, dataset)

    x2 = torch.randn(dataset.x.shape)
    batch = ((dataset.x, x2), dataset.edge_index)
    conv_harness(conv, dataset, batch=batch)

    conv = GENConv((in_channels, in_channels),
                   32,
                   aggr,
                   add_self_loops=False,
                   norm='layer')
    conv_harness(conv, dataset, batch=batch)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_general_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import GeneralConv
from conv_utils import conv_harness

out_channels = 32
num_edge_attr = 16

conv_kwargs_list = [
    {'skip_linear': True},
    {'directed_msg': False},
    {'heads': 3},
    {'attention': True},
    {'heads': 3, 'attention': True},
    {'heads': 3, 'attention': True, 'attention_type': 'dot_product'},
    {'l2_normalize': True},
]


@pytest.mark.parametrize('conv_kwargs', conv_kwargs_list)
def test_general_conv(dataset, conv_kwargs):
    """Run ``GeneralConv`` with random edge attributes per configuration.

    Edge attributes are ``num_edge_attr`` wide, matching the edge width the
    layer is constructed with.
    """
    layer = GeneralConv(dataset.num_node_features, out_channels,
                        num_edge_attr, **conv_kwargs)
    edge_attr = torch.randn(dataset.num_edges, num_edge_attr)
    conv_harness(layer,
                 dataset,
                 batch=(dataset.x, dataset.edge_index, edge_attr))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gin_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import torch
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq
from torch_geometric.nn import GINConv, GINEConv
from conv_utils import conv_harness

out_channels = 32


def _gin_mlp(in_channels):
    """Build the two-layer MLP (``in_channels`` -> ``out_channels``) for GIN."""
    return Seq(Lin(in_channels, out_channels), ReLU(),
               Lin(out_channels, out_channels))


def test_gin_conv(dataset):
    """Run ``GINConv`` with a trainable epsilon."""
    in_channels = dataset.num_node_features
    conv = GINConv(_gin_mlp(in_channels), train_eps=True)

    conv_harness(conv, dataset)


def test_gine_conv(dataset):
    """Run ``GINEConv`` with a trainable epsilon and random edge features.

    No ``edge_dim`` is given, so GINEConv combines edge features directly
    with node features. Their widths must therefore match, and the edge
    features are sized from the dataset rather than hard-coded.
    """
    in_channels = dataset.num_node_features
    conv = GINEConv(_gin_mlp(in_channels), train_eps=True)

    edge_attr = torch.randn(dataset.num_edges, in_channels)
    batch = (dataset.x, dataset.edge_index, edge_attr)

    conv_harness(conv, dataset, batch=batch)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gmm_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import GMMConv

from conv_utils import conv_harness


@pytest.mark.parametrize('separate_gaussians', [True, False])
def test_gmm_conv(separate_gaussians, dataset):
    """Run ``GMMConv`` (25 kernels) with random 3-dim edge attributes."""
    edge_dim = 3
    layer = GMMConv(dataset.num_node_features,
                    32,
                    dim=edge_dim,
                    kernel_size=25,
                    separate_gaussians=separate_gaussians,
                    add_self_loops=False)
    edge_attr = torch.rand(dataset.num_edges, edge_dim)
    conv_harness(layer, batch=(dataset.x, dataset.edge_index, edge_attr))


@pytest.mark.parametrize('separate_gaussians', [True, False])
def test_gmm_conv_bipartite(separate_gaussians, dataset):
    """Run ``GMMConv`` (5 kernels) built with a channel tuple.

    It is fed a pair of node feature tensors, first with both present and
    then with the second one set to ``None``.
    """
    width = dataset.num_node_features
    edge_dim = 3
    layer = GMMConv((width, width),
                    32,
                    dim=edge_dim,
                    kernel_size=5,
                    separate_gaussians=separate_gaussians,
                    add_self_loops=False)
    edge_attr = torch.rand(dataset.num_edges, edge_dim)
    second_x = torch.randn(dataset.x.shape)

    first_x, edge_index = dataset.x, dataset.edge_index
    conv_harness(layer, batch=((first_x, second_x), edge_index, edge_attr))
    conv_harness(layer, batch=((first_x, None), edge_index, edge_attr))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gps_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import GPSConv, SAGEConv

from conv_utils import conv_harness


def _make_gps_conv(in_channels, norm):
    """Build a 4-head ``GPSConv`` whose local message passing is a SAGEConv.

    The inner SAGEConv is sized from ``in_channels`` on both sides, so it
    agrees with the width GPSConv is built with for any dataset.
    """
    local_conv = SAGEConv(in_channels, in_channels, add_self_loops=False)
    conv = GPSConv(in_channels, conv=local_conv, heads=4, norm=norm)
    conv.reset_parameters()
    return conv


@pytest.mark.skip(reason="TODO(AFS-279, AFS-162)")
@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm'])
def test_gps_conv(norm, dataset):
    """Run ``GPSConv`` on the dataset for each normalisation choice."""
    conv = _make_gps_conv(dataset.num_node_features, norm)

    conv_harness(conv, dataset)


@pytest.mark.skip(reason="TODO(AFS-279, AFS-162)")
@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm'])
def test_gps_conv_with_batch_index_tensor(norm, dataset):
    """Run ``GPSConv`` with an explicit int64 batch vector (two graphs)."""
    conv = _make_gps_conv(dataset.num_node_features, norm)

    # Nodes past the midpoint belong to graph 1, the rest to graph 0.
    batch_index = [
        i > dataset.num_nodes // 2 for i in range(dataset.num_nodes)
    ]
    batch_index = torch.tensor(batch_index, dtype=torch.int64)

    batch = (dataset.x, dataset.edge_index, batch_index)
    conv_harness(conv, batch=batch)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_graph_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import GraphConv
from conv_utils import conv_harness

out_channels = 16


def test_graph_conv(dataset):
    """Run ``GraphConv`` through ``conv_harness``."""
    in_channels = dataset.num_node_features
    conv = GraphConv(in_channels, out_channels)

    conv_harness(conv, dataset)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_gravnet_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import torch
from torch_geometric.nn import GravNetConv
from torch_geometric.testing import withPackage

from conv_utils import conv_harness


@withPackage('torch_cluster')
def test_gravnet_conv(dataset):
    """Run ``GravNetConv`` (k=2) with and without batch vectors.

    The cases cover single and paired node feature inputs; the paired
    cases use looser tolerances.
    """
    in_channels = dataset.num_node_features
    out_channels = 32
    conv = GravNetConv(in_channels,
                       out_channels,
                       space_dimensions=4,
                       propagate_dimensions=8,
                       k=2,
                       add_self_loops=False)
    conv_harness(conv, batch=(dataset.x, ))

    num_nodes = dataset.num_nodes
    # Nodes past the midpoint belong to graph 1, the rest to graph 0.
    # NOTE(review): this is a plain Python list, not a tensor; presumably
    # conv_harness converts it — confirm.
    batch_index = [1 if i > num_nodes // 2 else 0 for i in range(num_nodes)]
    conv_harness(conv, batch=(dataset.x, batch_index))

    x2 = torch.randn_like(dataset.x)
    conv_harness(conv, batch=((dataset.x, x2), ), atol=5e-05, rtol=0.001)
    # NOTE(review): torch.Tensor(...) yields float32 batch vectors, while
    # PyG batch vectors are normally int64 — confirm this is intended.
    conv_harness(conv,
                 batch=((dataset.x, x2), (torch.Tensor(batch_index),
                                          torch.Tensor(batch_index))),
                 atol=5e-03,
                 rtol=0.1)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_han_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from torch_geometric.nn import HANConv

from conv_utils import hetero_conv_harness, random_heterodata


def test_han_conv():
    """Run a two-head ``HANConv`` on random heterogeneous data."""
    hetero_data, node_channels = random_heterodata()
    layer = HANConv(node_channels,
                    16,
                    hetero_data.metadata(),
                    heads=2,
                    add_self_loops=False)
    hetero_conv_harness(layer, hetero_data, 'author')


def test_han_conv_lazy():
    """Run a lazily initialised ``HANConv`` (``in_channels=-1``).

    One eager forward pass is made first to materialise the lazy weights.
    """
    hetero_data, _ = random_heterodata()
    layer = HANConv(-1,
                    16,
                    hetero_data.metadata(),
                    heads=2,
                    add_self_loops=False)
    layer(hetero_data.x_dict, hetero_data.edge_index_dict)
    hetero_conv_harness(layer, hetero_data, 'author')


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_heat_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import pytest
import torch
from torch_geometric.nn import HEATConv

from conv_utils import conv_harness


@pytest.mark.parametrize('concat', [True, False])
def test_heat_conv(concat):
    """Run ``HEATConv`` on a 4-node graph with 3 node and 3 edge types."""
    node_feats = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    edge_feats = torch.randn((4, 2))
    node_type = torch.tensor([0, 0, 1, 2])
    edge_type = torch.tensor([0, 2, 1, 2])

    layer = HEATConv(in_channels=8,
                     out_channels=16,
                     num_node_types=3,
                     num_edge_types=3,
                     edge_type_emb_dim=5,
                     edge_dim=2,
                     edge_attr_emb_dim=6,
                     heads=2,
                     concat=concat,
                     add_self_loops=False)

    forward_args = (node_feats, edge_index, node_type, edge_type, edge_feats)
    conv_harness(layer, batch=forward_args, atol=5e-4, rtol=0.3)


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_hgt_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from collections import defaultdict

from torch_geometric.nn import HGTConv

from conv_utils import hetero_conv_harness, random_heterodata


def _node_channels(paper=None):
    """Return a per-node-type width map defaulting to 16 channels.

    Args:
        paper: optional width override for the 'paper' node type.
    """
    channels = defaultdict(lambda: 16)
    if paper is not None:
        channels['paper'] = paper
    return channels


def test_hgt_conv_same_dimensions():
    """Run ``HGTConv`` when every node type has the same width."""
    channels = _node_channels()
    hetero_data, _ = random_heterodata(channels)

    layer = HGTConv(channels['author'],
                    channels['paper'],
                    metadata=hetero_data.metadata(),
                    heads=2)
    hetero_conv_harness(layer, hetero_data, 'author')


def test_hgt_conv_different_dimensions():
    """Run ``HGTConv`` given a per-type width map ('paper' is 32 wide)."""
    channels = _node_channels(paper=32)
    hetero_data, _ = random_heterodata(channels)

    layer = HGTConv(in_channels=channels,
                    out_channels=32,
                    metadata=hetero_data.metadata(),
                    heads=2)
    hetero_conv_harness(layer, hetero_data, 'author')


def test_hgt_conv_lazy():
    """Run a lazily initialised ``HGTConv`` after one eager forward pass."""
    channels = _node_channels(paper=32)
    hetero_data, _ = random_heterodata(channels)

    layer = HGTConv(-1, 32, metadata=hetero_data.metadata(), heads=2)
    layer(hetero_data.x_dict, hetero_data.edge_index_dict)
    hetero_conv_harness(layer, hetero_data, 'author')


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_hypergraph_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import torch
from torch_geometric.nn import HypergraphConv

from conv_utils import conv_harness


def test_hypergraph_conv_with_more_nodes_than_edges():
    """Run ``HypergraphConv`` with 4 nodes and 2 hyperedges.

    The layer is run first without weights or attributes, then as a
    two-head attention layer with hyperedge weights and attributes.
    """
    torch.manual_seed(42)
    in_channels, out_channels = 16, 32
    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]])
    hyperedge_weight = torch.tensor([1.0, 0.5])
    # Row 0 holds node ids and row 1 hyperedge ids, both 0-based.
    num_nodes, num_edges = (hyperedge_index.max(dim=1).values + 1).tolist()
    x = torch.randn((num_nodes, in_channels))
    hyperedge_attr = torch.randn((num_edges, in_channels))

    plain = HypergraphConv(in_channels, out_channels, add_self_loops=False)
    conv_harness(plain, batch=(x, hyperedge_index, None, None, num_edges))

    attentive = HypergraphConv(in_channels,
                               out_channels,
                               use_attention=True,
                               heads=2,
                               add_self_loops=False)
    conv_harness(attentive,
                 batch=(x, hyperedge_index, hyperedge_weight, hyperedge_attr,
                        num_edges))


def test_hypergraph_conv_with_more_edges_than_nodes():
    """Run ``HypergraphConv`` with 4 nodes and 5 hyperedges.

    The layer is run without and then with hyperedge weights.
    """
    torch.manual_seed(42)
    in_channels, out_channels = 16, 32
    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 3, 2, 1, 2],
                                    [0, 1, 2, 1, 2, 1, 0, 3, 3, 4, 4]])
    hyperedge_weight = torch.tensor([1.0, 0.5, 0.8, 0.2, 0.7])
    # Row 0 holds node ids and row 1 hyperedge ids, both 0-based.
    num_nodes, num_edges = (hyperedge_index.max(dim=1).values + 1).tolist()
    x = torch.randn((num_nodes, in_channels))

    layer = HypergraphConv(in_channels, out_channels)

    conv_harness(layer, batch=(x, hyperedge_index, None, None, num_edges))
    conv_harness(layer,
                 batch=(x, hyperedge_index, hyperedge_weight, None, num_edges))


# -----------------------------------------------------------------------------
# /tests/gnn/nn/conv/test_le_conv.py
# -----------------------------------------------------------------------------
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_le_conv.py (LEConv run on the
#   `dataset` fixture);
# - all of tests/gnn/nn/conv/test_lg_conv.py;
# - line 1 of tests/gnn/nn/conv/test_mf_conv.py.
# LGConv is constructed without channel arguments. Both LGConv tests therefore
# append a Linear(in_channels, out_channels) through post_proc. The second test
# also feeds dataset.edge_weight.
2 | from torch_geometric.nn import LEConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | 8 | def test_le_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | conv = LEConv(in_channels, out_channels) 11 | 12 | conv_harness(conv, dataset) 13 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_lg_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch_geometric.nn import LGConv 4 | from conv_utils import conv_harness 5 | 6 | out_channels = 16 7 | 8 | 9 | def test_lg_conv(dataset): 10 | in_channels = dataset.num_node_features 11 | conv = LGConv() 12 | lin = torch.nn.Linear(in_channels, out_channels) 13 | 14 | conv_harness(conv, dataset, post_proc=lin) 15 | 16 | 17 | def test_lg_edge_weights_conv(dataset): 18 | in_channels = dataset.num_node_features 19 | conv = LGConv() 20 | lin = torch.nn.Linear(in_channels, out_channels) 21 | 22 | batch = (dataset.x, dataset.edge_index, dataset.edge_weight) 23 | conv_harness(conv, dataset, batch=batch, post_proc=lin) 24 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_mf_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import torch 3 | from torch_geometric.nn import MFConv 4 | 5 | from conv_utils import conv_harness 6 | 7 | 8 | def test_mf_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | out_channels = 32 11 | 12 | conv = MFConv(in_channels, out_channels, add_self_loops=False) 13 | 14 | conv_harness(conv, dataset) 15 | 16 | conv = MFConv((in_channels, in_channels), 17 | out_channels, 18 | add_self_loops=False) 19 | 20 | x2 = torch.randn(dataset.x.shape) 21 | batch = ((dataset.x, x2), dataset.edge_index) 22 | conv_harness(conv, batch=batch) 23 | 24 | batch = ((dataset.x, None), dataset.edge_index) 25 | conv_harness(conv, batch=batch) 26 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_nn_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch.nn import Linear as Lin 4 | from torch.nn import ReLU 5 | from torch.nn import Sequential as Seq 6 | from torch_geometric.nn import NNConv 7 | from conv_utils import conv_harness 8 | 9 | out_channels = 16 10 | 11 | 12 | def test_nn_conv(dataset): 13 | in_channels = dataset.num_node_features 14 | nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32)) 15 | conv = NNConv(in_channels, out_channels, nn=nn) 16 | 17 | value = torch.rand(dataset.num_edges, 3) 18 | batch = (dataset.x, dataset.edge_index, value) 19 | 20 | conv_harness(conv, dataset, batch=batch) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_pan_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_pan_conv.py. The PANConv test is
#   skipped pending TODO(AFS-262);
# - all of tests/gnn/nn/conv/test_pdn_conv.py. PDNConv uses 8-dim random
#   edge_attr, matching edge_dim=8;
# - line 1 of tests/gnn/nn/conv/test_pna_conv.py.
2 | import pytest 3 | from torch_geometric.nn import PANConv 4 | 5 | from conv_utils import conv_harness 6 | 7 | 8 | @pytest.mark.skip(reason="TODO(AFS-262)") 9 | def test_pan_conv(dataset): 10 | in_channels = dataset.num_node_features 11 | conv = PANConv(in_channels, 32, filter_size=2, add_self_loops=False) 12 | 13 | conv_harness(conv, dataset) 14 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_pdn_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch_geometric.nn import PDNConv 4 | from conv_utils import conv_harness 5 | 6 | out_channels = 16 7 | 8 | 9 | def test_pdn_conv(dataset): 10 | in_channels = dataset.num_node_features 11 | conv = PDNConv(in_channels, 12 | out_channels, 13 | edge_dim=8, 14 | hidden_channels=128, 15 | add_self_loops=False) 16 | 17 | edge_attr = torch.randn(dataset.num_edges, 8) 18 | batch = (dataset.x, dataset.edge_index, edge_attr) 19 | conv_harness(conv, dataset, batch=batch) 20 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_pna_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | import torch 3 | from torch_geometric.nn import PNAConv 4 | from conv_utils import conv_harness 5 | 6 | out_channels = 16 7 | 8 | aggregators = ['sum', 'mean', 'min', 'max', 'var', 'std'] 9 | scalers = [ 10 | 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' 11 | ] 12 | 13 | 14 | def test_pna_conv(dataset): 15 | in_channels = dataset.num_node_features 16 | deg = PNAConv.get_degree_histogram([dataset]) 17 | 18 | conv = PNAConv(in_channels, 19 | out_channels, 20 | aggregators, 21 | scalers, 22 | deg=deg, 23 | edge_dim=3, 24 | towers=4) 25 | 26 | value = torch.rand(dataset.num_edges, 3) 27 | batch = (dataset.x, dataset.edge_index, value) 28 | conv_harness(conv, dataset, batch=batch) 29 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_point_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch.nn import Linear as Lin 4 | from torch.nn import ReLU 5 | from torch.nn import Sequential as Seq 6 | from torch_geometric.nn import PointNetConv 7 | from conv_utils import conv_harness 8 | 9 | out_channels = 16 10 | 11 | 12 | def test_point_net_conv(dataset): 13 | 14 | local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32)) 15 | global_nn = Seq(Lin(32, 32)) 16 | conv = PointNetConv(local_nn, global_nn, add_self_loops=False) 17 | 18 | pos = torch.rand(dataset.num_nodes, 3) 19 | batch = (dataset.x, pos, dataset.edge_index) 20 | conv_harness(conv, dataset, batch=batch) 21 | 22 | 23 | def test_point2_net_conv(dataset): 24 | 25 | local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32)) 26 | global_nn = Seq(Lin(32, 32)) 27 | conv = PointNetConv(local_nn, global_nn, add_self_loops=False) 28 | 29 | pos1 = torch.rand(dataset.num_nodes, 3) 30 | pos2 = torch.rand(dataset.num_nodes, 3) 31 | 32 | batch = (dataset.x, (pos1, pos2), dataset.edge_index) 33 | conv_harness(conv, dataset, 
batch=batch) 34 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_point_gnn_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch_geometric import seed_everything 4 | from torch_geometric.nn import MLP, PointGNNConv 5 | 6 | from conv_utils import conv_harness 7 | 8 | 9 | def test_pointgnn_conv(): 10 | seed_everything(42) 11 | x = torch.rand(6, 8) 12 | pos = torch.rand(6, 3) 13 | edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]]) 14 | 15 | conv = PointGNNConv( 16 | mlp_h=MLP([8, 16, 3], norm='LayerNorm'), 17 | mlp_f=MLP([3 + 8, 16, 8], norm='LayerNorm'), 18 | mlp_g=MLP([8, 16, 8], norm='LayerNorm'), 19 | ) 20 | 21 | batch = (x, pos, edge_index) 22 | conv_harness(conv, batch=batch) 23 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_point_transformer_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
# Dump fragment: the rest of tests/gnn/nn/conv/test_point_transformer_conv.py,
# then line 1 of tests/gnn/nn/conv/test_ppf_conv.py.
# Both PointTransformerConv tests use random 3-d positions and pass explicit
# tolerances to conv_harness: atol=1e-4/rtol=1e-3 and atol=1e-3/rtol=1e-2.
# In the second test, pos_nn maps 3 -> 32 and attn_nn maps 32 -> 32, matching
# the module-level out_channels = 32.
2 | import torch 3 | from torch.nn import Linear as Lin 4 | from torch.nn import ReLU 5 | from torch.nn import Sequential as Seq 6 | from torch_geometric.nn import PointTransformerConv 7 | from conv_utils import conv_harness 8 | 9 | out_channels = 32 10 | 11 | 12 | def test_point_transformer_conv(dataset): 13 | in_channels = dataset.num_node_features 14 | conv = PointTransformerConv(in_channels, 15 | out_channels, 16 | add_self_loops=False) 17 | 18 | pos = torch.rand(dataset.num_nodes, 3) 19 | 20 | batch = (dataset.x, pos, dataset.edge_index) 21 | conv_harness(conv, dataset, batch=batch, atol=1e-4, rtol=1e-3) 22 | 23 | 24 | def test_point_transformer_nn_conv(dataset): 25 | in_channels = dataset.num_node_features 26 | pos_nn = Seq(Lin(3, 16), ReLU(), Lin(16, 32)) 27 | attn_nn = Seq(Lin(32, 32), ReLU(), Lin(32, 32)) 28 | conv = PointTransformerConv(in_channels, 29 | out_channels, 30 | pos_nn, 31 | attn_nn, 32 | add_self_loops=False) 33 | 34 | pos = torch.rand(dataset.num_nodes, 3) 35 | 36 | batch = (dataset.x, pos, dataset.edge_index) 37 | conv_harness(conv, dataset, batch=batch, atol=1e-3, rtol=1e-2) 38 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_ppf_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Linear as Lin 5 | from torch.nn import ReLU 6 | from torch.nn import Sequential as Seq 7 | from torch_geometric.nn import PPFConv 8 | from conv_utils import conv_harness 9 | 10 | 11 | def test_ppf_conv(dataset): 12 | 13 | local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32)) 14 | global_nn = Seq(Lin(32, 32)) 15 | conv = PPFConv(local_nn, global_nn, add_self_loops=False) 16 | 17 | pos = torch.rand(dataset.num_nodes, 3) 18 | n = F.normalize(torch.rand(dataset.num_nodes, 3), dim=-1) 19 | 20 | batch = (dataset.x, pos, n, dataset.edge_index) 21 | conv_harness(conv, dataset, batch=batch) 22 | 23 | 24 | def test_ppf2_conv(dataset): 25 | 26 | local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32)) 27 | global_nn = Seq(Lin(32, 32)) 28 | conv = PPFConv(local_nn, global_nn, add_self_loops=False) 29 | 30 | pos1 = torch.rand(dataset.num_nodes, 3) 31 | pos2 = torch.rand(dataset.num_nodes, 3) 32 | n1 = F.normalize(torch.rand(dataset.num_nodes, 3), dim=-1) 33 | n2 = F.normalize(torch.rand(dataset.num_nodes, 3), dim=-1) 34 | 35 | batch = (dataset.x, (pos1, pos2), (n1, n2), dataset.edge_index) 36 | conv_harness(conv, dataset, batch=batch) 37 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_res_gated_graph_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
# Dump fragment (spans two lines; the break falls inside the first
# pytest.skip string of test_rgat_conv) contents:
# - the rest of tests/gnn/nn/conv/test_res_gated_graph_conv.py;
# - all of tests/gnn/nn/conv/test_rgat_conv.py;
# - all of tests/gnn/nn/conv/test_rgcn_conv.py;
# - line 1 of tests/gnn/nn/conv/test_sage_conv.py.
# test_rgat_conv is parametrised over 4 x 2 x 2 = 16 combinations. Only
# mod='additive', attention_mechanism='across-relation',
# attention_mode='additive-self-attention' actually runs; every other
# combination hits one of the two pytest.skip calls.
# test_rgcn_conv runs only FastRGCNConv; RGCNConv is skipped.
# NOTE(review): test_rgcn_conv's local out_channels = 32 shadows the
# module-level out_channels = 16, which is otherwise unused. x is passed as
# None, presumably so the conv uses its own per-node weights; confirm.
2 | from torch_geometric.nn import ResGatedGraphConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | 8 | def test_res_gated_graph_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | 11 | conv = ResGatedGraphConv(in_channels, out_channels) 12 | conv_harness(conv, dataset) 13 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_rgat_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | from torch_geometric import seed_everything 5 | from torch_geometric.nn import RGATConv 6 | 7 | from conv_utils import conv_harness 8 | 9 | 10 | @pytest.mark.parametrize('mod', [ 11 | 'additive', 12 | 'scaled', 13 | 'f-additive', 14 | 'f-scaled', 15 | ]) 16 | @pytest.mark.parametrize('attention_mechanism', [ 17 | 'within-relation', 18 | 'across-relation', 19 | ]) 20 | @pytest.mark.parametrize('attention_mode', [ 21 | 'additive-self-attention', 22 | 'multiplicative-self-attention', 23 | ]) 24 | def test_rgat_conv(mod, attention_mechanism, attention_mode): 25 | seed_everything(0) 26 | 27 | if attention_mechanism == 'within-relation': 28 | pytest.skip("Condition from torch.nonzero is used to compute softmax. 
" 29 | "Fixed size tensor can change softmax result.") 30 | 31 | if mod != 'additive' or attention_mode != 'additive-self-attention': 32 | pytest.skip("TODO(AFS-200)") 33 | 34 | x = torch.randn(4, 8) 35 | edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) 36 | edge_type = torch.tensor([0, 2, 1, 2]) 37 | edge_attr = torch.randn((4, 8)) 38 | 39 | conv = RGATConv(8, 40 | 20, 41 | num_relations=4, 42 | num_bases=4, 43 | mod=mod, 44 | attention_mechanism=attention_mechanism, 45 | attention_mode=attention_mode, 46 | heads=2, 47 | dim=1, 48 | edge_dim=8, 49 | add_self_loops=False) 50 | 51 | batch = (x, edge_index, edge_type, edge_attr) 52 | conv_harness(conv, batch=batch) 53 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_rgcn_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | from torch_geometric.nn import FastRGCNConv, RGCNConv 5 | from conv_utils import conv_harness 6 | 7 | out_channels = 16 8 | 9 | 10 | @pytest.mark.parametrize('rgcn', [FastRGCNConv, RGCNConv]) 11 | def test_rgcn_conv(rgcn): 12 | if rgcn == RGCNConv: 13 | pytest.skip("RGCNConv uses dynamic shapes") 14 | 15 | in_channels = 4 16 | out_channels = 32 17 | num_relations = 4 18 | edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]]) 19 | edge_type = torch.tensor([0, 1, 1, 0, 0, 1]) 20 | conv = rgcn(in_channels, 21 | out_channels, 22 | num_relations, 23 | num_bases=15, 24 | add_self_loops=False) 25 | 26 | batch = (None, edge_index, edge_type) 27 | conv_harness(conv, batch=batch) 28 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_sage_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_sage_conv.py. SAGEConv is given a list
#   of six aggregators, with normalize/root_weight/project/bias all enabled;
# - all of tests/gnn/nn/conv/test_sg_conv.py. SGConv uses K=10 and is run
#   with and without dataset.edge_weight;
# - line 1 of tests/gnn/nn/conv/test_signed_conv.py.
2 | from torch_geometric.nn import SAGEConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | aggregators = ['sum', 'mean', 'min', 'max', 'var', 'std'] 8 | 9 | 10 | def test_sage_conv(dataset): 11 | in_channels = dataset.num_node_features 12 | 13 | conv = SAGEConv(in_channels, 14 | out_channels, 15 | aggr=aggregators, 16 | normalize=True, 17 | root_weight=True, 18 | project=True, 19 | bias=True) 20 | 21 | conv_harness(conv, dataset) 22 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_sg_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | from torch_geometric.nn import SGConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | 8 | def test_sg_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | conv = SGConv(in_channels, out_channels, K=10, add_self_loops=False) 11 | 12 | conv_harness(conv, dataset) 13 | 14 | 15 | def test_sg_weights_conv(dataset): 16 | in_channels = dataset.num_node_features 17 | conv = SGConv(in_channels, out_channels, K=10, add_self_loops=False) 18 | 19 | batch = (dataset.x, dataset.edge_index, dataset.edge_weight) 20 | conv_harness(conv, dataset, batch=batch) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_signed_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_signed_conv.py;
# - all of tests/gnn/nn/conv/test_simple_conv.py;
# - line 1 of tests/gnn/nn/conv/test_spline_conv.py.
# test_signed_conv stacks two SignedConv layers (first_aggr=True, then
# first_aggr=False) inside a local Module.
# NOTE(review): the same dataset.edge_index is passed as both the positive and
# the negative edge set. That is fine for a smoke test, but confirm it is
# intentional.
# test_simple_conv: for combine_root='cat' the post-processing Linear takes
# 2 * in_channels inputs. The test assumes the root features are concatenated
# onto the aggregated ones.
2 | import torch 3 | from torch_geometric.nn import SignedConv 4 | from conv_utils import conv_harness 5 | 6 | out_channels = 16 7 | 8 | 9 | def test_signed_conv(dataset): 10 | 11 | in_channels = dataset.num_node_features 12 | 13 | class Convs(torch.nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.conv1 = SignedConv(in_channels, 17 | out_channels, 18 | first_aggr=True, 19 | add_self_loops=False) 20 | 21 | self.conv2 = SignedConv(out_channels, 22 | 32, 23 | first_aggr=False, 24 | add_self_loops=False) 25 | 26 | def forward(self, x, pos_edge_index, neg_edge_index): 27 | x = self.conv1(x, pos_edge_index, neg_edge_index) 28 | x = self.conv2(x, pos_edge_index, neg_edge_index) 29 | return x 30 | 31 | conv = Convs() 32 | 33 | batch = (dataset.x, dataset.edge_index, dataset.edge_index) 34 | conv_harness(conv, dataset, batch=batch) 35 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_simple_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | 4 | import torch 5 | from torch_geometric.nn import SimpleConv 6 | 7 | from conv_utils import conv_harness 8 | 9 | 10 | @pytest.mark.parametrize('combine_root', ['sum', 'cat', 'self_loop', None]) 11 | def test_simple_conv(dataset, combine_root): 12 | in_channels = dataset.num_node_features 13 | out_channels = 64 14 | 15 | if combine_root == 'cat': 16 | in_channels = in_channels * 2 17 | 18 | lin = torch.nn.Linear(in_channels, out_channels) 19 | conv = SimpleConv(combine_root=combine_root) 20 | 21 | conv_harness(conv, dataset, post_proc=lin) 22 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_spline_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import pytest 3 | import torch 4 | from torch_geometric.nn import SplineConv 5 | from torch_geometric.testing import withPackage 6 | 7 | from conv_utils import conv_harness 8 | 9 | 10 | @pytest.mark.parametrize("training", [True, False]) 11 | @withPackage('torch_spline_conv') 12 | def test_spline_conv(training): 13 | if training: 14 | pytest.skip('reason="TODO(AFS-216, AFS-218)') 15 | x1 = torch.randn(4, 4) 16 | x2 = torch.randn(2, 8) 17 | edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) 18 | value = torch.rand(edge_index[0].size(0), 3) 19 | conv = SplineConv(4, 32, dim=3, kernel_size=5) 20 | 21 | conv_harness(conv, batch=(x1, edge_index, value), training=training) 22 | 23 | conv = SplineConv((4, 8), 32, dim=3, kernel_size=5) 24 | batch = ((x1, x2), edge_index, value) 25 | conv_harness(conv, batch=batch, training=training) 26 | 27 | batch = ((x1, None), edge_index, value, (4, 2)) 28 | conv_harness(conv, batch=batch, training=training) 29 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_ssg_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch_geometric.nn import SSGConv 4 | 5 | from conv_utils import conv_harness 6 | 7 | 8 | def test_ssg_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | out_channels = 32 11 | 12 | conv = SSGConv(in_channels, 13 | out_channels, 14 | alpha=0.1, 15 | K=10, 16 | add_self_loops=False) 17 | conv_harness(conv, dataset) 18 | 19 | value = torch.rand(dataset.num_edges) 20 | conv_harness(conv, batch=(dataset.x, dataset.edge_index, value)) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_supergat_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_supergat_conv.py. The test is skipped
#   pending TODO(AFS-36) and parametrised over attention_type 'MX' and 'SD';
# - all of tests/gnn/nn/conv/test_tag_conv.py. TAGConv is run with and without
#   dataset.edge_weight;
# - line 1 of tests/gnn/nn/conv/test_transformer_conv.py.
2 | import pytest 3 | from torch_geometric.nn import SuperGATConv 4 | from conv_utils import conv_harness 5 | 6 | out_channels = 16 7 | 8 | 9 | @pytest.mark.skip(reason="TODO(AFS-36)") 10 | @pytest.mark.parametrize('att_type', ['MX', 'SD']) 11 | def test_supergat_conv(dataset, att_type): 12 | in_channels = dataset.num_node_features 13 | conv = SuperGATConv(in_channels, 14 | out_channels, 15 | heads=2, 16 | attention_type=att_type, 17 | neg_sample_ratio=1.0, 18 | edge_sample_ratio=1.0, 19 | add_self_loops=False) 20 | 21 | conv_harness(conv, dataset) 22 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_tag_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | from torch_geometric.nn import TAGConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | 8 | def test_tag_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | conv = TAGConv(in_channels, out_channels) 11 | 12 | conv_harness(conv, dataset) 13 | 14 | 15 | def test_tag_weights_conv(dataset): 16 | in_channels = dataset.num_node_features 17 | conv = TAGConv(in_channels, out_channels) 18 | 19 | batch = (dataset.x, dataset.edge_index, dataset.edge_weight) 20 | conv_harness(conv, dataset, batch=batch) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_transformer_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_transformer_conv.py (TransformerConv
#   with heads=2, beta=True);
# - all of tests/gnn/nn/conv/test_wl_conv.py. Per its reason string, the test
#   is skipped because the algorithm reads IPU-resident tensors;
# - line 1 of tests/gnn/nn/conv/test_wl_conv_continuous.py.
2 | from torch_geometric.nn import TransformerConv 3 | from conv_utils import conv_harness 4 | 5 | out_channels = 16 6 | 7 | 8 | def test_transformer_conv(dataset): 9 | in_channels = dataset.num_node_features 10 | conv = TransformerConv(in_channels, out_channels, heads=2, beta=True) 11 | 12 | conv_harness(conv, dataset) 13 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_wl_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | from torch_geometric.nn import WLConv 5 | 6 | from conv_utils import conv_harness 7 | 8 | 9 | @pytest.mark.skip(reason="Algorithm requires reading tensors which " 10 | "are placed on the IPU.") 11 | def test_wl_conv(): 12 | x = torch.tensor([1, 0, 0, 1]) 13 | edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) 14 | conv = WLConv() 15 | _ = conv(x, edge_index) 16 | conv_harness(conv, batch=(x, edge_index), training=False) 17 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_wl_conv_continuous.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/conv/test_wl_conv_continuous.py;
# - all of tests/gnn/nn/conv/test_x_conv.py;
# - line 1 of tests/gnn/nn/dense/dense_utils.py.
# test_wl_conv_cont: WLConvContinuous is constructed without channel
# arguments, so a Linear(in_channels, 8) post_proc is appended. It is run on
# the fixture, on (x, None), and on a bipartite (x, x2) pair.
# test_x_conv: the local name `batch` appears to be the per-node graph
# assignment (two graphs of 4 nodes), not a harness batch tuple.
# NOTE(review): torch.manual_seed(0) is called after x, pos and the XConv
# parameters already exist, so it does not make them deterministic. If
# reproducible inputs are intended, move it to the top of the test.
2 | import torch 3 | from torch_geometric.nn import WLConvContinuous 4 | 5 | from conv_utils import conv_harness 6 | 7 | 8 | def test_wl_conv_cont(dataset): 9 | in_channels = dataset.num_node_features 10 | conv = WLConvContinuous() 11 | 12 | lin = torch.nn.Linear(in_channels, 8) 13 | conv_harness(conv, dataset, post_proc=lin) 14 | 15 | batch = ((dataset.x, None), dataset.edge_index, dataset.edge_weight) 16 | conv_harness(conv, batch=batch, post_proc=lin) 17 | 18 | x2 = torch.randn(dataset.x.shape) 19 | batch = ((dataset.x, x2), dataset.edge_index, dataset.edge_weight) 20 | conv_harness(conv, batch=batch, post_proc=lin) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/conv/test_x_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | from torch_geometric.nn import XConv 4 | from torch_geometric.testing import withPackage 5 | 6 | from conv_utils import conv_harness 7 | 8 | 9 | @withPackage('torch_cluster') 10 | def test_x_conv(): 11 | x = torch.randn(8, 16) 12 | pos = torch.rand(8, 5) 13 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) 14 | 15 | conv = XConv(16, 32, dim=5, kernel_size=2, dilation=2) 16 | 17 | torch.manual_seed(0) 18 | # We need to pass very loose atol and rtol here due to TODO(AFS-276) 19 | conv_harness(conv, batch=(x, pos), atol=0.1, rtol=0.1) 20 | conv_harness(conv, batch=(x, pos, batch), atol=0.1, rtol=0.1) 21 | -------------------------------------------------------------------------------- /tests/gnn/nn/dense/dense_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# Dump fragment: the rest of tests/gnn/nn/dense/dense_utils.py, then line 1 of
# tests/gnn/nn/functional/test_bro.py.
# dense_harness wraps `dense` in a local DenseWrapper Module, followed by the
# optional post_proc. In training mode the wrapper returns
# (x, loss_fn(x, ones_like(x))); otherwise it returns x alone. The wrapper is
# handed to poptorch_geometric.TrainingStepper with atol/rtol, and
# stepper.run(num_steps, batch) is called.
# NOTE(review): the loss_fn default, torch.nn.MSELoss(), is built once at
# definition time and shared by every call. That is harmless for a
# parameter-free loss, but a None default would avoid the shared-default
# pitfall.
2 | 3 | import torch 4 | 5 | from poptorch_geometric import TrainingStepper 6 | 7 | 8 | def dense_harness(dense, 9 | batch=None, 10 | post_proc=None, 11 | loss_fn=torch.nn.MSELoss(), 12 | num_steps=4, 13 | atol=1e-5, 14 | rtol=1e-4): 15 | class DenseWrapper(torch.nn.Module): 16 | def __init__(self, dense, loss_fn, post_proc=None): 17 | super().__init__() 18 | self.dense = dense 19 | self.loss_fn = loss_fn 20 | self.post_proc = post_proc 21 | 22 | def forward(self, *args): 23 | x = self.dense(*args) 24 | if self.post_proc is not None: 25 | x = self.post_proc(x) 26 | if self.training: 27 | target = torch.ones_like(x) 28 | loss = self.loss_fn(x, target) 29 | return x, loss 30 | 31 | return x 32 | 33 | model = DenseWrapper(dense, loss_fn=loss_fn, post_proc=post_proc) 34 | 35 | stepper = TrainingStepper(model, atol=atol, rtol=rtol) 36 | 37 | stepper.run(num_steps, batch) 38 | -------------------------------------------------------------------------------- /tests/gnn/nn/functional/test_bro.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | 3 | import pytest 4 | import torch 5 | from torch_geometric.nn.functional import bro 6 | import poptorch 7 | 8 | 9 | @pytest.mark.skip(reason="TODO(AFS-269)") 10 | def test_bro(): 11 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2]) 12 | 13 | g1 = torch.tensor([ 14 | [0.2, 0.2, 0.2, 0.2], 15 | [0.0, 0.2, 0.2, 0.2], 16 | [0.2, 0.0, 0.2, 0.2], 17 | [0.2, 0.2, 0.0, 0.2], 18 | ]) 19 | 20 | g2 = torch.tensor([ 21 | [0.2, 0.2, 0.2, 0.2], 22 | [0.0, 0.2, 0.2, 0.2], 23 | [0.2, 0.0, 0.2, 0.2], 24 | ]) 25 | 26 | g3 = torch.tensor([ 27 | [0.2, 0.2, 0.2, 0.2], 28 | [0.2, 0.0, 0.2, 0.2], 29 | ]) 30 | 31 | class Model(torch.nn.Module): 32 | def forward(self, g1, g2, g3, batch): 33 | return bro(torch.cat([g1, g2, g3], dim=0), batch) 34 | 35 | model = Model() 36 | poptorch_model = poptorch.inferenceModel(model) 37 | 38 | ipu_out = poptorch_model(g1, g2, g3, batch) 39 | 40 | s = 0. 41 | for g in [torch.cat([g1, g2, g3]) / 3]: 42 | s += torch.norm(g @ g.t() - torch.eye(g.shape[0]), p=2) 43 | 44 | assert torch.isclose(s / 3., ipu_out) 45 | -------------------------------------------------------------------------------- /tests/gnn/nn/functional/test_gini.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
# Dump fragment (spans two lines; the break falls inside kge_harness's loop)
# contents:
# - the rest of tests/gnn/nn/functional/test_gini.py;
# - all of tests/gnn/nn/kge/kge_utils.py;
# - all of tests/gnn/nn/kge/test_complex.py;
# - line 1 of tests/gnn/nn/kge/test_distmult.py.
# test_gini runs gini() through poptorch.inferenceModel on a 2x4 weight matrix
# (one all-zero row, one row with a single non-zero entry) and expects 0.5.
# kge_harness:
# - wraps the KGE model so that list outputs are joined with torch.cat
#   before post_proc and before the loss;
# - in training mode returns (result, loss_fn(result, ones_like(result)));
# - feeds at most num_steps batches from `dataloader` to
#   TrainingStepper.run(1, batch), one step per batch;
# - does nothing when dataloader is None.
# NOTE(review): isinstance(result, List) uses the deprecated typing.List alias;
# `list` is the idiomatic check. loss_fn's default MSELoss instance is shared
# across calls.
# test_complex_scoring sets the real and imaginary node/relation embeddings by
# hand but only runs the harness on them; no score value is asserted.
2 | 3 | import torch 4 | from torch_geometric.nn.functional import gini 5 | 6 | import poptorch 7 | 8 | 9 | def test_gini(): 10 | 11 | w = torch.tensor([[0., 0., 0., 0.], [0., 0., 0., 1000.0]]) 12 | 13 | class Model(torch.nn.Module): 14 | def forward(self, w): 15 | return gini(w) 16 | 17 | model = Model() 18 | poptorch_model = poptorch.inferenceModel(model) 19 | 20 | ipu_out = poptorch_model(w) 21 | 22 | assert torch.isclose(ipu_out, torch.tensor(0.5)) 23 | -------------------------------------------------------------------------------- /tests/gnn/nn/kge/kge_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from typing import List 4 | import torch 5 | 6 | from poptorch_geometric import TrainingStepper 7 | 8 | 9 | def kge_harness(kge, 10 | dataloader, 11 | post_proc=None, 12 | loss_fn=torch.nn.MSELoss(), 13 | num_steps=4, 14 | atol=5e-3, 15 | rtol=5e-3, 16 | equal_nan=False, 17 | enable_fp_exception=True): 18 | class KgeWrapper(torch.nn.Module): 19 | def __init__(self, kge, loss_fn, post_proc=None): 20 | super().__init__() 21 | self.model = kge 22 | self.loss_fn = loss_fn 23 | self.post_proc = post_proc 24 | 25 | def forward(self, *args): 26 | result = self.model(*args) 27 | 28 | if self.post_proc is not None: 29 | if isinstance(result, List): 30 | result = torch.cat(result) 31 | result = self.post_proc(result) 32 | 33 | if self.training: 34 | if isinstance(result, List): 35 | result = torch.cat(result) 36 | target = torch.ones_like(result) 37 | 38 | loss = self.loss_fn(result, target) 39 | return result, loss 40 | 41 | return result 42 | 43 | model = KgeWrapper(kge, loss_fn=loss_fn, post_proc=post_proc) 44 | 45 | stepper = TrainingStepper(model, 46 | atol=atol, 47 | rtol=rtol, 48 | equal_nan=equal_nan, 49 | enable_fp_exception=enable_fp_exception) 50 | 51 | if dataloader is not None: 52 | for step, batch in enumerate(dataloader): 53 | if step == num_steps: 
54 | break 55 | stepper.run(1, batch) 56 | -------------------------------------------------------------------------------- /tests/gnn/nn/kge/test_complex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import torch 4 | 5 | from torch_geometric.nn import ComplEx 6 | from kge_utils import kge_harness 7 | 8 | 9 | def test_complex_scoring(): 10 | model = ComplEx(num_nodes=5, num_relations=2, hidden_channels=1) 11 | 12 | model.node_emb.weight.data = torch.tensor([ 13 | [2.], 14 | [3.], 15 | [5.], 16 | [1.], 17 | [2.], 18 | ]) 19 | model.node_emb_im.weight.data = torch.tensor([ 20 | [4.], 21 | [1.], 22 | [3.], 23 | [1.], 24 | [2.], 25 | ]) 26 | model.rel_emb.weight.data = torch.tensor([ 27 | [2.], 28 | [3.], 29 | ]) 30 | model.rel_emb_im.weight.data = torch.tensor([ 31 | [3.], 32 | [1.], 33 | ]) 34 | 35 | head_index = torch.tensor([1, 3]) 36 | rel_type = torch.tensor([1, 0]) 37 | tail_index = torch.tensor([2, 4]) 38 | 39 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 40 | kge_harness(model, loader) 41 | 42 | 43 | def test_complex(): 44 | model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32) 45 | assert str(model) == 'ComplEx(10, num_relations=5, hidden_channels=32)' 46 | 47 | head_index = torch.tensor([0, 2, 4, 6, 8]) 48 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 49 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 50 | 51 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 52 | kge_harness(model, loader) 53 | -------------------------------------------------------------------------------- /tests/gnn/nn/kge/test_distmult.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# Dump fragment contents:
# - the rest of tests/gnn/nn/kge/test_distmult.py;
# - all of tests/gnn/nn/kge/test_rotate.py;
# - line 1 of tests/gnn/nn/kge/test_transe.py.
# Each test does three things:
# - checks the model's repr;
# - builds model.loader over five (head, relation, tail) index triples with
#   batch_size=5;
# - runs kge_harness on that loader.
2 | 3 | import torch 4 | 5 | from torch_geometric.nn import DistMult 6 | from kge_utils import kge_harness 7 | 8 | 9 | def test_distmult(): 10 | model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32) 11 | assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)' 12 | 13 | head_index = torch.tensor([0, 2, 4, 6, 8]) 14 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 15 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 16 | 17 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 18 | kge_harness(model, loader) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/kge/test_rotate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import torch 4 | 5 | from torch_geometric.nn import RotatE 6 | from kge_utils import kge_harness 7 | 8 | 9 | def test_rotate(): 10 | model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32) 11 | assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)' 12 | 13 | head_index = torch.tensor([0, 2, 4, 6, 8]) 14 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 15 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 16 | 17 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 18 | kge_harness(model, loader) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/kge/test_transe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | 3 | import torch 4 | 5 | from torch_geometric.nn import TransE 6 | from kge_utils import kge_harness 7 | 8 | 9 | def test_transe(): 10 | model = TransE(num_nodes=10, num_relations=5, hidden_channels=32) 11 | assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)' 12 | 13 | head_index = torch.tensor([0, 2, 4, 6, 8]) 14 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 15 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 16 | 17 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 18 | kge_harness(model, loader) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/norm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import torch 4 | 5 | import helpers 6 | from torch_geometric.data import Batch, Data 7 | 8 | from gnn.nn.nn_utils import op_harness 9 | 10 | 11 | def assert_(native_out, poptorch_out): 12 | def check_inner_field(x, y): 13 | assert isinstance(x, type(y)), \ 14 | f"x type={type(x)} is different than y type={type(y)}" 15 | if isinstance(x, torch.Tensor): 16 | helpers.assert_allclose(actual=x, 17 | expected=y, 18 | atol=1e-04, 19 | rtol=1e-04, 20 | equal_nan=True) 21 | elif isinstance(x, (list, tuple)): 22 | for t, ct in zip(x, y): 23 | check_inner_field(t, ct) 24 | elif isinstance(x, (Batch, Data)): 25 | assert x.keys == y.keys, "Objects have different keys." 
26 | for k in x.keys: 27 | check_inner_field(x[k], y[k]) 28 | elif x is not None: 29 | assert False, f"Unsupported types: x type={type(x)}, y type=" \ 30 | f"{type(y)}" 31 | 32 | check_inner_field(native_out, poptorch_out) 33 | 34 | 35 | def norm_harness(op, inputs, assert_func=None, inference=False): 36 | 37 | if assert_func is None: 38 | assert_func = assert_ 39 | poptorch_out = op_harness(op, inputs, assert_func, inference) 40 | 41 | return poptorch_out 42 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_batch_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from torch_geometric.nn import BatchNorm 6 | 7 | from norm_utils import norm_harness 8 | 9 | 10 | @pytest.mark.parametrize('conf', [True, False]) 11 | def test_batch_norm(conf): 12 | x = torch.randn(100, 16) 13 | 14 | norm = BatchNorm(16, affine=conf, track_running_stats=conf) 15 | assert str(norm) == 'BatchNorm(16)' 16 | 17 | out = norm_harness(norm, [x]) 18 | assert out.size() == (100, 16) 19 | 20 | 21 | def test_batch_norm_single_element(): 22 | x = torch.randn(1, 16) 23 | 24 | norm = BatchNorm(16, track_running_stats=True, allow_single_element=True) 25 | assert str(norm) == 'BatchNorm(16)' 26 | 27 | out = norm_harness(norm, [x], inference=True) 28 | assert torch.allclose(out, x) 29 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_diff_group_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import torch 3 | 4 | from norm_utils import norm_harness 5 | 6 | from torch_geometric.nn import DiffGroupNorm 7 | 8 | 9 | def test_diff_group_norm(): 10 | x = torch.randn(6, 16) 11 | 12 | norm = DiffGroupNorm(16, groups=4, lamda=0.01) 13 | assert str(norm) == 'DiffGroupNorm(16, groups=4)' 14 | 15 | out = norm_harness(norm, [x]) 16 | assert out.size() == x.size() 17 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_graph_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | 4 | from norm_utils import norm_harness 5 | 6 | from torch_geometric.nn import GraphNorm 7 | 8 | 9 | def test_graph_norm(): 10 | torch.manual_seed(42) 11 | x = torch.randn(200, 16) 12 | batch = torch.arange(4).view(-1, 1).repeat(1, 50).view(-1) 13 | batch_size = int(batch.max() + 1) 14 | 15 | norm = GraphNorm(16) 16 | 17 | norm_harness(norm, [x]) 18 | norm_harness(norm, [x, batch, batch_size]) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_graph_size_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import torch 3 | 4 | from norm_utils import norm_harness 5 | 6 | from torch_geometric.nn import GraphSizeNorm 7 | 8 | 9 | def test_graph_size_norm(): 10 | x = torch.randn(100, 16) 11 | batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long)) 12 | batch_size = int(batch.max()) + 1 13 | 14 | norm = GraphSizeNorm() 15 | assert str(norm) == 'GraphSizeNorm()' 16 | 17 | out = norm_harness(norm, [x, batch, batch_size]) 18 | assert out.size() == (100, 16) 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_instance_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | from torch_geometric.nn import InstanceNorm 5 | 6 | import helpers 7 | from gnn.nn.nn_utils import ModelWW 8 | import poptorch 9 | 10 | 11 | @pytest.mark.parametrize('conf', [True, False]) 12 | def test_instance_norm(conf): 13 | atol = None 14 | rtol = None 15 | if conf is True: 16 | # These values are based on torch_nn_test.py file 17 | # where InstanceNorm is tested from torch package. 
18 | atol = 1e-3 19 | rtol = 0.05 20 | 21 | nodes_list = torch.randn(5, 100, 16) 22 | 23 | def test_body(inputs): 24 | 25 | norm = InstanceNorm(16, affine=conf, track_running_stats=conf) 26 | 27 | cpu_model = ModelWW(norm, inputs[0][0].shape) 28 | ipu_model = poptorch.trainingModel(ModelWW(norm, inputs[0][0].shape)) 29 | 30 | for x in inputs[0]: 31 | cpu_out = None 32 | ipu_out = None 33 | if len(inputs) > 1: 34 | model_inputs = [x] + inputs[1:] 35 | cpu_out = cpu_model(model_inputs) 36 | ipu_out = ipu_model(model_inputs) 37 | else: 38 | cpu_out = cpu_model([x]) 39 | ipu_out = ipu_model([x]) 40 | helpers.assert_allclose(actual=ipu_out[0], 41 | expected=cpu_out[0], 42 | atol=atol, 43 | rtol=rtol) 44 | 45 | test_body([nodes_list]) 46 | 47 | batch = torch.zeros(100, dtype=torch.long) 48 | batch_size = 1 49 | test_body([nodes_list, batch, batch_size]) 50 | 51 | batch[:50] = torch.ones(50, dtype=torch.long) 52 | batch_size = 2 53 | test_body([nodes_list, batch, batch_size]) 54 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from norm_utils import norm_harness 6 | 7 | from torch_geometric.nn import LayerNorm 8 | 9 | 10 | @pytest.mark.parametrize('affine', [True, False]) 11 | @pytest.mark.parametrize('mode', ['graph', 'node']) 12 | def test_layer_norm(affine, mode): 13 | x = torch.randn(100, 16) 14 | 15 | norm = LayerNorm(16, affine=affine, mode=mode) 16 | 17 | norm_harness(norm, [x]) 18 | 19 | batch = torch.zeros(100, dtype=torch.int64) 20 | batch_size = 1 21 | norm_harness(norm, [x, batch, batch_size]) 22 | 23 | batch_size = 2 24 | norm_harness(norm, [ 25 | torch.cat([x, x], dim=0), 26 | torch.cat([batch, batch + 1], dim=0), batch_size 27 | ]) 28 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_mean_subtraction_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | 4 | from norm_utils import norm_harness 5 | 6 | from torch_geometric.nn import MeanSubtractionNorm 7 | 8 | 9 | def test_mean_subtraction_norm_no_batch(): 10 | x = torch.randn(6, 16) 11 | 12 | norm = MeanSubtractionNorm() 13 | assert str(norm) == 'MeanSubtractionNorm()' 14 | 15 | out = norm_harness(norm, [x]) 16 | assert out.size() == (6, 16) 17 | assert torch.allclose(out.mean(), torch.tensor(0.), atol=1e-04) 18 | 19 | 20 | def test_mean_subtraction_norm(): 21 | x = torch.randn(6, 16) 22 | batch = torch.tensor([0, 0, 1, 1, 1, 2]) 23 | 24 | norm = MeanSubtractionNorm() 25 | assert str(norm) == 'MeanSubtractionNorm()' 26 | 27 | out = norm_harness(norm, [x, batch, 3]) 28 | assert out.size() == (6, 16) 29 | assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-04) 30 | assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-04) 31 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_msg_norm.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | 4 | from norm_utils import norm_harness 5 | 6 | from torch_geometric.nn import MessageNorm 7 | 8 | 9 | def test_message_norm(): 10 | norm = MessageNorm(learn_scale=True) 11 | assert str(norm) == 'MessageNorm(learn_scale=True)' 12 | x = torch.randn(100, 16) 13 | msg = torch.randn(100, 16) 14 | 15 | out = norm_harness(norm, [x, msg]) 16 | assert out.size() == (100, 16) 17 | 18 | norm = MessageNorm(learn_scale=False) 19 | assert str(norm) == 'MessageNorm(learn_scale=False)' 20 | out = norm_harness(norm, [x, msg]) 21 | assert out.size() == (100, 16) 22 | -------------------------------------------------------------------------------- /tests/gnn/nn/norm/test_pair_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from norm_utils import norm_harness 6 | 7 | from torch_geometric.nn import PairNorm 8 | 9 | 10 | @pytest.mark.parametrize('scale_individually', [False, True]) 11 | def test_pair_norm_no_batch(scale_individually): 12 | x = torch.randn(100, 16) 13 | 14 | norm = PairNorm(scale_individually=scale_individually) 15 | assert str(norm) == 'PairNorm()' 16 | 17 | out1 = norm_harness(norm, [x]) 18 | assert out1.size() == (100, 16) 19 | 20 | 21 | @pytest.mark.parametrize('scale_individually', [False, True]) 22 | def test_pair_norm(scale_individually): 23 | x = torch.randn(100, 16) 24 | batch = torch.zeros(100, dtype=torch.long) 25 | 26 | norm = PairNorm(scale_individually=scale_individually) 27 | assert str(norm) == 'PairNorm()' 28 | 29 | out1 = norm_harness(norm, [x]) 30 | 31 | batch_size = 2 32 | out2 = norm_harness(norm, [ 33 | torch.cat([x, x], dim=0), 34 | torch.cat([batch, batch + 1], dim=0), batch_size 35 | ]) 36 | assert torch.allclose(out1, out2[:100], atol=1e-04) 
37 | assert torch.allclose(out1, out2[100:], atol=1e-04) 38 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/pool_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import dataclasses 3 | import torch 4 | 5 | import helpers 6 | from torch_geometric.data import Batch, Data 7 | 8 | from gnn.nn.nn_utils import op_harness 9 | 10 | 11 | def assert_(native_out, poptorch_out): 12 | def check_inner_field(x, y): 13 | assert isinstance(x, type(y)), \ 14 | f"x type={type(x)} is different than y type={type(y)}" 15 | if isinstance(x, torch.Tensor): 16 | helpers.assert_allclose(actual=x, 17 | expected=y, 18 | atol=1e-04, 19 | rtol=1e-04, 20 | equal_nan=True) 21 | elif isinstance(x, (list, tuple)): 22 | for t, ct in zip(x, y): 23 | check_inner_field(t, ct) 24 | elif isinstance(x, (Batch, Data)): 25 | assert x.keys == y.keys, "Objects have different keys." 26 | for k in x.keys: 27 | check_inner_field(x[k], y[k]) 28 | elif dataclasses.is_dataclass(x): 29 | for att in dir(x): 30 | x_field = getattr(x, att, None) 31 | if not callable(x_field) and isinstance(x_field, torch.Tensor): 32 | check_inner_field(x_field, getattr(y, att, None)) 33 | elif x is not None: 34 | assert False, f"Unsupported types: x type={type(x)}, y type=" \ 35 | f"{type(y)}" 36 | 37 | check_inner_field(native_out, poptorch_out) 38 | 39 | 40 | def pool_harness(op, inputs, assert_func=None): 41 | 42 | if assert_func is None: 43 | assert_func = assert_ 44 | poptorch_out = op_harness(op, inputs, assert_func) 45 | 46 | return poptorch_out 47 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_asap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_geometric.nn import ASAPooling, GCNConv, GraphConv 7 | 8 | from pool_utils import pool_harness 9 | 10 | 11 | @pytest.mark.skip(reason="TODO(AFS-229, AFS-230, AFS-232, AFS-262)") 12 | def test_asap(): 13 | in_channels = 16 14 | edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], 15 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) 16 | num_nodes = edge_index.max().item() + 1 17 | x = torch.randn((num_nodes, in_channels)) 18 | 19 | for GNN in [GraphConv, GCNConv]: 20 | pool = ASAPooling(in_channels, 21 | ratio=0.5, 22 | GNN=GNN, 23 | add_self_loops=False) 24 | assert pool.__repr__() == ('ASAPooling(16, ratio=0.5)') 25 | out = pool_harness(pool, [x, edge_index]) 26 | assert out[0].size() == (num_nodes // 2, in_channels) 27 | assert out[1].size() == (2, 2) 28 | 29 | pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True) 30 | assert pool.__repr__() == ('ASAPooling(16, ratio=0.5)') 31 | out = pool_harness(pool, [x, edge_index]) 32 | assert out[0].size() == (num_nodes // 2, in_channels) 33 | assert out[1].size() == (2, 4) 34 | 35 | pool = ASAPooling(in_channels, ratio=2, GNN=GNN, add_self_loops=False) 36 | assert pool.__repr__() == ('ASAPooling(16, ratio=2)') 37 | out = pool_harness(pool, [x, edge_index]) 38 | assert out[0].size() == (2, in_channels) 39 | assert out[1].size() == (2, 2) 40 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_consecutive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 6 | 7 | from pool_utils import pool_harness 8 | 9 | 10 | @pytest.mark.skip( 11 | reason="consecutive_cluster uses torch.unique instruction which produces " 12 | "tensor with dynamic shape. 
This is not supported for Mk2.") 13 | def test_consecutive_cluster(): 14 | src = torch.tensor([8, 2, 10, 15, 100, 1, 100]) 15 | 16 | out, perm = pool_harness(consecutive_cluster, [src]) 17 | assert out.tolist() == [2, 1, 3, 4, 5, 0, 5] 18 | assert perm.tolist() == [5, 1, 0, 2, 3, 6] 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_graclus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_geometric.nn import graclus 7 | from torch_geometric.testing import withPackage 8 | 9 | from pool_utils import pool_harness 10 | 11 | 12 | @pytest.mark.skip(reason="TODO(AFS-245)") 13 | @withPackage('torch_cluster') 14 | def test_graclus(): 15 | edge_index = torch.tensor([[0, 1], [1, 0]]) 16 | weight = torch.tensor([1., 1.]) 17 | out = pool_harness(graclus, [edge_index, weight, 2]) 18 | assert out.tolist() == [0, 0] 19 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_pan_pool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_geometric.nn import PANConv, PANPooling 7 | 8 | from pool_utils import pool_harness 9 | 10 | 11 | @pytest.mark.skip(reason="The class is using filter_adj which produces " 12 | "tensors with dynamic shapes. 
It is not supported " 13 | "on Mk2.") 14 | def test_pan_pooling(): 15 | edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], 16 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) 17 | num_nodes = edge_index.max().item() + 1 18 | x = torch.randn((num_nodes, 16)) 19 | 20 | conv = PANConv(16, 32, filter_size=2) 21 | pool = PANPooling(32, ratio=0.5) 22 | assert str(pool) == 'PANPooling(32, ratio=0.5, multiplier=1.0)' 23 | 24 | x, M = conv(x, edge_index) 25 | row, col, edge_weight = M.coo() 26 | h, edge_index, edge_weight, _, perm, score = pool_harness( 27 | pool, [x, row, col, edge_weight]) 28 | 29 | assert h.size() == (2, 32) 30 | assert edge_index.size() == (2, 4) 31 | assert edge_weight.size() == (4, ) 32 | assert perm.size() == (2, ) 33 | assert score.size() == (2, ) 34 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_pool_knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import torch 4 | from torch_geometric.nn import knn, knn_graph 5 | 6 | import helpers 7 | import poptorch 8 | 9 | 10 | class KnnModel(torch.nn.Module): 11 | def __init__(self, op) -> None: 12 | super().__init__() 13 | self.op = op 14 | 15 | def forward(self, *args, **kwargs): 16 | return self.op(*args, **kwargs) 17 | 18 | 19 | def test_knn(): 20 | x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) 21 | batch_x = torch.tensor([0, 0, 0, 0]) 22 | y = torch.Tensor([[-1, 0], [1, 0]]) 23 | batch_y = torch.tensor([0, 0]) 24 | 25 | assign_index_cpu = knn(x, y, 2, batch_x, batch_y) 26 | 27 | model = poptorch.inferenceModel(KnnModel(knn)) 28 | assign_index_ipu = model(x, y, 2, batch_x, batch_y) 29 | 30 | # There is no guarantee that indexes that knn returns must be in any 31 | # particualr order if there are multiple identical elements so we can't 32 | # compare results directly as one can be permutation of the other. 
33 | helpers.assert_allequal(actual=assign_index_ipu.sort()[0], 34 | expected=assign_index_cpu.sort()[0]) 35 | 36 | 37 | def test_knn_graph(): 38 | x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) 39 | batch = torch.tensor([0, 0, 0, 0]) 40 | 41 | edge_index_cpu = knn_graph(x, k=2, batch=batch, loop=True) 42 | model = poptorch.inferenceModel(KnnModel(knn_graph)) 43 | edge_index_ipu = model(x, k=2, batch=batch, loop=True) 44 | 45 | # There is no guarantee that indexes that knn returns must be in any 46 | # particualr order if there are multiple identical elements so we can't 47 | # compare results directly as one can be permutation of the other. 48 | helpers.assert_allequal(actual=edge_index_cpu.sort()[0], 49 | expected=edge_index_ipu.sort()[0]) 50 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_radius.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from typing import Optional 4 | 5 | import torch 6 | import torch_geometric 7 | from torch import Tensor 8 | import poptorch 9 | 10 | 11 | def to_set(edge_index): 12 | # pylint: disable=R1721 13 | return {(i, j) for i, j in edge_index.t().tolist()} 14 | 15 | 16 | def assert_fn(native_out, poptorch_out): 17 | poptorch_out = poptorch_out[poptorch_out != -1] 18 | dim = poptorch_out.size(0) // 2 19 | poptorch_out = poptorch_out.reshape((2, dim)) 20 | 21 | native_out = native_out[native_out != -1] 22 | dim = native_out.size(0) // 2 23 | native_out = native_out.reshape((2, dim)) 24 | 25 | assert to_set(poptorch_out) == to_set(native_out) 26 | 27 | 28 | def op_harness(*args, **kwargs): 29 | class Model(torch.nn.Module): 30 | def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor: 31 | return torch_geometric.nn.radius_graph(x, 32 | r=2.5, 33 | batch=batch, 34 | loop=True) 35 | 36 | native_out = Model()(*args, **kwargs) 37 | model = poptorch.inferenceModel(Model()) 38 | poptorch_out = model(*args, **kwargs) 39 | assert_fn(native_out, poptorch_out) 40 | 41 | 42 | def test_radius_graph(): 43 | 44 | x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float) 45 | batch = torch.tensor([0, 0, 0, 0]) 46 | 47 | op_harness(x, batch) 48 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_sag_pool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_geometric.nn import ( 7 | GATConv, 8 | GCNConv, 9 | GraphConv, 10 | SAGEConv, 11 | SAGPooling, 12 | ) 13 | 14 | from pool_utils import pool_harness 15 | 16 | 17 | @pytest.mark.skip(reason="The class is using filter_adj which produces " 18 | "tensors with dynamic shapes. 
It is not supported " 19 | "on Mk2.") 20 | @pytest.mark.parametrize('GNN', [GraphConv, GCNConv, GATConv, SAGEConv]) 21 | def test_sag_pooling(GNN): 22 | conv_kwargs = {'add_self_loops': False} 23 | 24 | in_channels = 16 25 | edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], 26 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) 27 | num_nodes = edge_index.max().item() + 1 28 | x = torch.randn((num_nodes, in_channels)) 29 | 30 | pool1 = SAGPooling(in_channels, ratio=0.5, GNN=GNN, **conv_kwargs) 31 | out1 = pool_harness(pool1, [x, edge_index]) 32 | assert out1[0].size() == (num_nodes // 2, in_channels) 33 | assert out1[1].size() == (2, 2) 34 | 35 | pool2 = SAGPooling(in_channels, 36 | ratio=None, 37 | GNN=GNN, 38 | min_score=0.1, 39 | **conv_kwargs) 40 | out2 = pool_harness(pool2, [x, edge_index]) 41 | assert out2[0].size(0) <= x.size(0) and out2[0].size(1) == (16) 42 | assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1) 43 | 44 | pool3 = SAGPooling(in_channels, ratio=2, GNN=GNN, **conv_kwargs) 45 | out3 = pool_harness(pool3, [x, edge_index]) 46 | assert out3[0].size() == (2, in_channels) 47 | assert out3[1].size() == (2, 2) 48 | -------------------------------------------------------------------------------- /tests/gnn/nn/pool/test_voxel_grid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import torch 4 | 5 | from torch_geometric.data import Batch 6 | from torch_geometric.nn import avg_pool, voxel_grid 7 | from torch_geometric.testing import withPackage 8 | 9 | from pool_utils import pool_harness 10 | 11 | 12 | @withPackage('torch_cluster') 13 | def test_voxel_grid(): 14 | pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]]) 15 | batch = torch.tensor([0, 0, 0, 1, 1]) 16 | 17 | out = pool_harness(voxel_grid, [pos, 5, batch]) 18 | assert out.tolist() == [0, 5, 3, 6, 7] 19 | out = pool_harness(voxel_grid, [pos, 5]) 20 | assert out.tolist() == [0, 5, 3, 0, 1] 21 | 22 | 23 | @withPackage('torch_cluster') 24 | def test_voxel_grid_with_optional_args(): 25 | pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]]) 26 | batch = torch.tensor([0, 0, 0, 1, 1]) 27 | 28 | cluster = pool_harness(voxel_grid, [pos, 5, batch, -1, [18, 14]]) 29 | assert cluster.tolist() == [0, 10, 4, 16, 17] 30 | 31 | cluster_no_batch = pool_harness(voxel_grid, [pos, 5, None, -1, [18, 14]]) 32 | assert cluster_no_batch.tolist() == [0, 10, 4, 0, 1] 33 | 34 | 35 | @withPackage('torch_cluster') 36 | def test_single_voxel_grid(): 37 | pos = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]) 38 | edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]]) 39 | batch = torch.tensor([0, 0, 0, 1, 1]) 40 | x = torch.randn(5, 16) 41 | 42 | cluster = pool_harness(voxel_grid, [pos, 5, batch]) 43 | assert cluster.tolist() == [0, 0, 0, 1, 1] 44 | 45 | data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch) 46 | data = avg_pool(cluster, data) 47 | 48 | cluster_no_batch = pool_harness(voxel_grid, [pos, 5]) 49 | assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0] 50 | -------------------------------------------------------------------------------- /tests/gnn/nn/test_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from itertools import product 4 | 5 | import pytest 6 | import torch 7 | from torch_geometric.nn import HeteroLinear, Linear 8 | 9 | from dense.dense_utils import dense_harness 10 | 11 | weight_inits = ['glorot', "uniform", 'kaiming_uniform', None] 12 | bias_inits = ['zeros', None] 13 | 14 | 15 | @pytest.mark.parametrize('weight,bias', product(weight_inits, bias_inits)) 16 | def test_linear(weight, bias): 17 | lin = Linear(16, 32, weight_initializer=weight, bias_initializer=bias) 18 | x = torch.randn(1, 4, 16) 19 | 20 | dense_harness(lin, x) 21 | 22 | 23 | @pytest.mark.parametrize('with_bias', [True, False]) 24 | def test_hetero_linear(with_bias): 25 | x = torch.randn(10, 16) 26 | type_vec = torch.tensor([0, 0, 2, 1, 0, 2, 2, 2, 1, 2]) 27 | 28 | lin = HeteroLinear(16, 32, num_types=3, bias=with_bias) 29 | 30 | dense_harness(lin, (x, type_vec)) 31 | -------------------------------------------------------------------------------- /tests/gnn/nn/test_sequential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from collections import OrderedDict 4 | from torch.nn import ReLU 5 | from torch_geometric.nn import Sequential, GCNConv, Linear 6 | 7 | from conv.conv_utils import conv_harness 8 | 9 | conv_kwargs = {"add_self_loops": False} 10 | 11 | 12 | def test_sequential(dataset): 13 | out_channels = in_channels = dataset.num_node_features 14 | 15 | model = Sequential('x, edge_index', [ 16 | (GCNConv(in_channels, 64, **conv_kwargs), 'x, edge_index -> x'), 17 | ReLU(inplace=True), 18 | (GCNConv(64, 64, **conv_kwargs), 'x, edge_index -> x'), 19 | ReLU(inplace=True), 20 | Linear(64, out_channels), 21 | ]) 22 | 23 | conv_harness(model, dataset) 24 | 25 | 26 | def test_sequential_with_ordered_dict(dataset): 27 | in_channels = dataset.num_node_features 28 | 29 | model = Sequential('x, edge_index', 30 | modules=OrderedDict([ 31 | ('conv1', (GCNConv(in_channels, 32, **conv_kwargs), 32 | 'x, edge_index -> x')), 33 | ('conv2', (GCNConv(32, 64, **conv_kwargs), 34 | 'x, edge_index -> x')), 35 | ])) 36 | 37 | conv_harness(model, dataset) 38 | -------------------------------------------------------------------------------- /tests/gnn/nn/unpool/test_interpolate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import helpers 3 | 4 | import torch 5 | import torch_geometric 6 | 7 | import poptorch 8 | import poptorch_geometric # pylint: disable=unused-import 9 | 10 | 11 | def test_knn_interpolate(): 12 | x = torch.Tensor([[1], [10], [100], [-1], [-10], [-100]]) 13 | pos_x = torch.Tensor([[-1, 0], [0, 0], [1, 0], [-2, 0], [0, 0], [2, 0]]) 14 | pos_y = torch.Tensor([[-1, -1], [1, 1], [-2, -2], [2, 2]]) 15 | batch_x = torch.tensor([0, 0, 0, 1, 1, 1]) 16 | batch_y = torch.tensor([0, 0, 1, 1]) 17 | k = 2 18 | 19 | class Model(torch.nn.Module): 20 | def forward(self, *args, **kwargs): 21 | return torch_geometric.nn.knn_interpolate(*args, **kwargs) 22 | 23 | model = poptorch.inferenceModel(Model()) 24 | 25 | poptorch_out = model(x, pos_x, pos_y, batch_x, batch_y, k) 26 | torch_geometric_out = torch_geometric.nn.knn_interpolate( 27 | x, pos_x, pos_y, batch_x, batch_y, k) 28 | 29 | helpers.assert_allclose(actual=poptorch_out, expected=torch_geometric_out) 30 | -------------------------------------------------------------------------------- /tests/gnn/ops/test_knn_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | import torch 5 | import helpers 6 | 7 | from torch_geometric.nn import knn_graph 8 | from poptorch_geometric.ops.knn_graph import knn_graph as pyg_knn_graph 9 | 10 | import poptorch 11 | 12 | 13 | @pytest.mark.parametrize('flow', ['source_to_target', 'target_to_source']) 14 | def test_knn_graph(flow): 15 | x = torch.Tensor([[1], [10], [100], [-1], [-10], [-100]]) 16 | batch = torch.tensor([0, 0, 0, 1, 1, 1]) 17 | k = 2 18 | 19 | class Model(torch.nn.Module): 20 | def forward(self, *args, **kwargs): 21 | return pyg_knn_graph(*args, **kwargs) 22 | 23 | model = poptorch.inferenceModel(Model()) 24 | 25 | poptorch_out = model(x, k, batch, True, flow) 26 | torch_geometric_out = knn_graph(x, k, batch, True, flow) 27 | pyg_cpu_out = pyg_knn_graph(x, k, batch, True, flow) 28 | 29 | helpers.assert_allclose(actual=poptorch_out, expected=pyg_cpu_out) 30 | helpers.assert_allclose(actual=poptorch_out, expected=torch_geometric_out) 31 | -------------------------------------------------------------------------------- /tests/gnn/ops/test_knn_interpolate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import helpers 4 | 5 | import torch 6 | from torch_geometric.nn import knn_interpolate 7 | 8 | import poptorch 9 | 10 | from poptorch_geometric.ops.knn_interpolate import knn_interpolate as pyg_knn_interpolate 11 | 12 | 13 | def test_knn_interpolate(): 14 | x = torch.Tensor([[1], [10], [100], [-1], [-10], [-100]]) 15 | pos_x = torch.Tensor([[-1, 0], [0, 0], [1, 0], [-2, 0], [0, 0], [2, 0]]) 16 | pos_y = torch.Tensor([[-1, -1], [1, 1], [-2, -2], [2, 2]]) 17 | batch_x = torch.tensor([0, 0, 0, 1, 1, 1]) 18 | batch_y = torch.tensor([0, 0, 1, 1]) 19 | k = 2 20 | 21 | class Model(torch.nn.Module): 22 | def forward(self, *args, **kwargs): 23 | return pyg_knn_interpolate(*args, **kwargs) 24 | 25 | model = poptorch.inferenceModel(Model()) 26 | 27 | poptorch_out = model(x, pos_x, pos_y, batch_x, batch_y, k) 28 | torch_geometric_out = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k) 29 | pyg_cpu_out = pyg_knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k) 30 | 31 | helpers.assert_allclose(actual=poptorch_out, expected=pyg_cpu_out) 32 | helpers.assert_allclose(actual=poptorch_out, expected=torch_geometric_out) 33 | -------------------------------------------------------------------------------- /tests/gnn/ops/test_to_dense_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | import torch 5 | import torch_geometric 6 | from torch_geometric.utils import to_dense_batch 7 | 8 | import helpers 9 | import poptorch 10 | 11 | 12 | def op_harness(reference_op, *args, **kwargs): 13 | class Model(torch.nn.Module): 14 | def forward(self, *args, **kwargs): 15 | return torch_geometric.utils.to_dense_batch(*args, **kwargs) 16 | 17 | model = poptorch.inferenceModel(Model()) 18 | 19 | poptorch_out = model(*args, **kwargs) 20 | 21 | native_out = reference_op(*args, **kwargs) 22 | 23 | helpers.assert_allclose(actual=poptorch_out, expected=native_out) 24 | 25 | 26 | def test_basic(): 27 | x = torch.arange(12).view(6, 2) 28 | 29 | op_harness(to_dense_batch, x, batch_size=1, max_num_nodes=11) 30 | 31 | 32 | def test_batch_size_not_set(): 33 | x = torch.arange(12).view(6, 2) 34 | batch = torch.tensor([0, 0, 1, 2, 2, 2]) 35 | 36 | with pytest.raises( 37 | ValueError, 38 | match= 39 | "Dynamic shapes disabled. Argument 'batch_size' needs to be set"): 40 | op_harness(to_dense_batch, x, batch) 41 | 42 | 43 | def test_batch_size_set(): 44 | x = torch.arange(12).view(6, 2) 45 | batch = torch.tensor([0, 0, 1, 2, 2, 2]) 46 | 47 | with pytest.raises( 48 | ValueError, 49 | match= 50 | "Dynamic shapes disabled. Argument 'max_num_nodes' needs to be set" 51 | ): 52 | op_harness(to_dense_batch, x, batch, batch_size=3) 53 | 54 | 55 | def test_batch_size_and_max_num_nodes_set(): 56 | x = torch.arange(12).view(6, 2) 57 | batch = torch.tensor([0, 0, 1, 2, 2, 2]) 58 | batch_size = int(batch.max()) + 1 59 | max_num_nodes = 11 60 | 61 | op_harness(to_dense_batch, 62 | x, 63 | batch, 64 | max_num_nodes=max_num_nodes, 65 | batch_size=batch_size) 66 | -------------------------------------------------------------------------------- /tests/gnn/test_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import torch 4 | from torch_geometric.nn import PositionalEncoding, TemporalEncoding 5 | from gnn.nn.nn_utils import op_harness 6 | 7 | 8 | def test_positional_encoding(): 9 | encoder = PositionalEncoding(64) 10 | 11 | x = torch.tensor([1.0, 2.0, 3.0]) 12 | 13 | op_harness(encoder, [x]) 14 | 15 | 16 | def test_temporal_encoding(): 17 | encoder = TemporalEncoding(64) 18 | 19 | x = torch.tensor([1.0, 2.0, 3.0]) 20 | 21 | op_harness(encoder, [x]) 22 | -------------------------------------------------------------------------------- /tests/gnn/test_masker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import pytest 3 | import torch 4 | import torch_geometric as pyg 5 | 6 | from poptorch_geometric import masker 7 | 8 | 9 | @pytest.fixture(params=[True, False]) 10 | def entries(request) -> masker.Entries: 11 | """Returns something which looks like an entry""" 12 | pyg.seed_everything(1) 13 | is_tuple = request.param 14 | entry = torch.rand([2, 3, 4]) 15 | return (entry, entry) if is_tuple else entry 16 | 17 | 18 | class TestNoOpMasker: 19 | """Tests the No Op masker, makes sure it does nothing.""" 20 | 21 | @pytest.mark.parametrize("masker_name", ["node", "graph", "edge"]) 22 | def test_masker_does_not_change_the_object(self, masker_name: str, 23 | entries: masker.Entries): 24 | mask = masker.NoMasker() 25 | output_entries = getattr(mask, f"{masker_name}_masker")(entries) 26 | assert entries is output_entries 27 | 28 | 29 | class TestNoOpLayerMasker: 30 | @pytest.fixture 31 | def layer(self): 32 | def layer_function(*args): 33 | total = 0 34 | for arg in args: 35 | total += torch.sum(arg) 36 | return total 37 | 38 | return layer_function 39 | 40 | @pytest.mark.parametrize("masker_name", ["node", "graph", "edge"]) 41 | def test_masker_does_not_change_the_layer_result( 42 | self, 43 | masker_name: str, 44 | entries: masker.Entries, 45 | layer: masker.Layer, 46 | ): 
47 | mask = masker.PreLayerMasker(masker=masker.NoMasker()) 48 | masked_layer = getattr(mask, f"{masker_name}_masker")(layer) 49 | if not isinstance(entries, (tuple, list)): 50 | entries = (entries, ) 51 | reference_output = layer(*entries) 52 | masked_output = masked_layer(*entries) 53 | assert reference_output == masked_output, ( 54 | "For the No-op layer masker," + 55 | " the result of a layer should be unchanged") 56 | -------------------------------------------------------------------------------- /tests/gnn/test_register_custom_args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import unittest 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.nn import GCNConv 7 | 8 | import helpers 9 | import poptorch 10 | 11 | 12 | class GCN(torch.nn.Module): 13 | def __init__(self, in_channels: int, out_channels: int): 14 | super().__init__() 15 | self.conv1 = GCNConv(in_channels, 16, add_self_loops=False) 16 | self.conv2 = GCNConv(16, out_channels, add_self_loops=False) 17 | 18 | def forward(self, data): 19 | x = data.x 20 | edge_index = data.edge_index 21 | 22 | x = self.conv1(x, edge_index).relu() 23 | x = F.dropout(x, training=self.training) 24 | x = self.conv2(x, edge_index).relu() 25 | x = F.log_softmax(x, dim=1) 26 | 27 | return x 28 | 29 | 30 | @unittest.mock.patch.dict("os.environ", helpers.disableSmallModel()) 31 | def test_register_custom_parsers(planetoid_cora): 32 | data = planetoid_cora[0] 33 | model = GCN(planetoid_cora.num_node_features, planetoid_cora.num_classes) 34 | model.eval() 35 | poptorch_model = poptorch.inferenceModel(model) 36 | result = poptorch_model(data) 37 | assert result is not None 38 | -------------------------------------------------------------------------------- /tests/gru_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
Copyright (c) 2020 Graphcore Ltd. All rights reserved. 3 | 4 | import pytest 5 | import torch 6 | import helpers 7 | import poptorch 8 | 9 | 10 | @pytest.mark.parametrize("bias", [True, False]) 11 | @pytest.mark.parametrize("batch_first", [True, False]) 12 | def test_gru(bias, batch_first): 13 | length = 1 14 | batches = 3 15 | input_size = 5 16 | hidden_size = 7 17 | 18 | layers = 1 19 | directions = 1 20 | 21 | torch.manual_seed(42) 22 | if batch_first: 23 | inp = torch.randn(batches, length, input_size) 24 | else: 25 | inp = torch.randn(length, batches, input_size) 26 | h0 = torch.randn(layers * directions, batches, hidden_size) 27 | 28 | op = torch.nn.GRU(input_size, 29 | hidden_size, 30 | bias=bias, 31 | batch_first=batch_first) 32 | 33 | out_fn = lambda x: x[0] 34 | model = helpers.ModelWithWeights(op, inp.shape, out_fn) 35 | 36 | poptorch_model = poptorch.trainingModel(model) 37 | 38 | (native_out, native_hn), _ = model((inp, h0)) 39 | (poptorch_out, poptorch_hn), _ = poptorch_model((inp, h0)) 40 | 41 | # Inference test - check outputs 42 | helpers.assert_allclose(actual=poptorch_out, expected=native_out) 43 | helpers.assert_allclose(actual=poptorch_hn, expected=native_hn) 44 | 45 | # Training test - check weights changed 46 | poptorch_model.assert_weights_changed() 47 | -------------------------------------------------------------------------------- /tests/non_contiguous_tensors_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Graphcore Ltd. All rights reserved. 
3 | 4 | import torch 5 | import helpers 6 | import poptorch 7 | 8 | 9 | class FiveAdder(torch.nn.Module): 10 | def forward(self, in_1, in_2, in_3, in_4, in_5): 11 | return in_1 + in_2 + in_3 + in_4 + in_5 12 | 13 | 14 | def test_non_contiguous(): 15 | torch.manual_seed(23148) 16 | 17 | model = FiveAdder() 18 | poptorch_model = poptorch.inferenceModel(model) 19 | 20 | OUTER_DIM = 1000 21 | INNER_DIM = 40 22 | 23 | nc1 = torch.randn([OUTER_DIM, INNER_DIM + 1])[:, 0:INNER_DIM] 24 | nc2 = torch.transpose(torch.randn([INNER_DIM, OUTER_DIM]), 0, 1) 25 | nc3 = torch.tensor([1.0]).expand([OUTER_DIM, INNER_DIM]) 26 | 27 | c1 = torch.randn([OUTER_DIM, INNER_DIM]) 28 | c2 = torch.randn([2, OUTER_DIM, INNER_DIM])[0, :, :] 29 | 30 | assert not nc1.is_contiguous() 31 | assert not nc2.is_contiguous() 32 | assert not nc3.is_contiguous() 33 | 34 | assert c1.is_contiguous() 35 | assert c2.is_contiguous() 36 | 37 | native_out = model(nc1, c1, nc2, c2, nc3) 38 | poptorch_out = poptorch_model(nc1, c1, nc2, c2, nc3) 39 | 40 | assert native_out.shape == (OUTER_DIM, INNER_DIM) 41 | 42 | print(native_out) 43 | print(poptorch_out) 44 | 45 | helpers.assert_allclose(actual=poptorch_out, expected=native_out) 46 | -------------------------------------------------------------------------------- /tests/rnn_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | import helpers 7 | import poptorch 8 | 9 | 10 | @pytest.mark.parametrize("nonlinearity", ['tanh', 'relu']) 11 | @pytest.mark.parametrize("batch_first", [True, False]) 12 | def test_rnn(nonlinearity, batch_first): 13 | torch.manual_seed(42) 14 | num_batches = 10 15 | sequence_length = 5 16 | batch_size = 8 17 | input_size = 4 18 | hidden_size = 3 19 | num_layers = 1 20 | 21 | if batch_first: 22 | input_shape = (batch_size, sequence_length, input_size) 23 | else: 24 | input_shape = (sequence_length, batch_size, input_size) 25 | 26 | inputs = [torch.randn(input_shape) for _ in range(num_batches)] 27 | h = torch.randn((num_layers, batch_size, hidden_size)) 28 | 29 | rnn = nn.RNN( 30 | input_size, 31 | hidden_size, 32 | num_layers, 33 | nonlinearity=nonlinearity, 34 | batch_first=batch_first, 35 | ) 36 | model = helpers.ModelWithWeights(rnn, inputs[0].shape, lambda x: x[0]) 37 | ipu_model = poptorch.trainingModel(model) 38 | 39 | for input in inputs: 40 | (out_cpu, h_cpu), _ = model((input, h)) 41 | (out_ipu, h_ipu), _ = ipu_model((input, h)) 42 | helpers.assert_allclose(actual=out_ipu, expected=out_cpu) 43 | helpers.assert_allclose(actual=h_ipu, expected=h_cpu) 44 | ipu_model.assert_weights_changed() 45 | h = h_cpu 46 | -------------------------------------------------------------------------------- /tests/sharding_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 
3 | import torch 4 | import helpers 5 | import poptorch 6 | 7 | 8 | def test_sharded_execution(): 9 | class Model(torch.nn.Module): 10 | def forward(self, x): 11 | with poptorch.Block("0", ipu_id=0): 12 | x = x * 2 13 | with poptorch.Block("1", ipu_id=1): 14 | x = x * 3 15 | with poptorch.Block("2", ipu_id=2): 16 | x = x * 4 17 | with poptorch.Block("3", ipu_id=3): 18 | x = x * 5 19 | return x 20 | 21 | native = Model() 22 | stages = [poptorch.Stage(f"{k}") for k in range(0, 4)] 23 | strategy = poptorch.ShardedExecution(*stages) 24 | 25 | opts = poptorch.Options() 26 | opts.setExecutionStrategy(strategy) 27 | ipu = poptorch.inferenceModel(native, opts) 28 | 29 | torch.manual_seed(42) 30 | inp = torch.randn(3, 7) 31 | 32 | native_out = native(inp) 33 | ipu_out = ipu(inp) 34 | helpers.assert_allclose(actual=ipu_out, expected=native_out) 35 | -------------------------------------------------------------------------------- /version.json: -------------------------------------------------------------------------------- 1 | {"major": "3", "minor": "4", "point": "0"} 2 | --------------------------------------------------------------------------------