├── .Rbuildignore ├── .editorconfig ├── .gitattributes ├── .github ├── .gitignore ├── dependabot.yaml └── workflows │ ├── dev-cmd-check.yml │ ├── pkgdown.yml │ └── r-cmd-check.yml ├── .gitignore ├── .ignore ├── .lintr ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── CallbackSet.R ├── CallbackSetCheckpoint.R ├── CallbackSetEarlyStopping.R ├── CallbackSetHistory.R ├── CallbackSetLRScheduler.R ├── CallbackSetProgress.R ├── CallbackSetTB.R ├── CallbackSetUnfreeze.R ├── ContextTorch.R ├── DataBackendLazy.R ├── DataDescriptor.R ├── LearnerTorch.R ├── LearnerTorchFeatureless.R ├── LearnerTorchImage.R ├── LearnerTorchMLP.R ├── LearnerTorchModel.R ├── LearnerTorchModule.R ├── LearnerTorchTabResNet.R ├── LearnerTorchVision.R ├── ModelDescriptor.R ├── PipeOpModule.R ├── PipeOpTaskPreprocTorch.R ├── PipeOpTorch.R ├── PipeOpTorchActivation.R ├── PipeOpTorchAdaptiveAvgPool.R ├── PipeOpTorchAvgPool.R ├── PipeOpTorchBatchNorm.R ├── PipeOpTorchBlock.R ├── PipeOpTorchCallbacks.R ├── PipeOpTorchConv.R ├── PipeOpTorchConvTranspose.R ├── PipeOpTorchDropout.R ├── PipeOpTorchFTCLS.R ├── PipeOpTorchFn.R ├── PipeOpTorchHead.R ├── PipeOpTorchIdentity.R ├── PipeOpTorchIngress.R ├── PipeOpTorchLayerNorm.R ├── PipeOpTorchLinear.R ├── PipeOpTorchLoss.R ├── PipeOpTorchMaxPool.R ├── PipeOpTorchMerge.R ├── PipeOpTorchModel.R ├── PipeOpTorchOptimizer.R ├── PipeOpTorchReshape.R ├── PipeOpTorchSoftmax.R ├── PipeOpTorchTokenizer.R ├── Select.R ├── TaskClassif_cifar.R ├── TaskClassif_lazy_iris.R ├── TaskClassif_melanoma.R ├── TaskClassif_mnist.R ├── TaskClassif_tiny_imagenet.R ├── TorchCallback.R ├── TorchDescriptor.R ├── TorchLoss.R ├── TorchOptimizer.R ├── aaa.R ├── bibentries.R ├── cache.R ├── lazy_tensor.R ├── learner_torch_methods.R ├── materialize.R ├── merge_graphs.R ├── multi_tensor_dataset.R ├── nn.R ├── nn_graph.R ├── paramset_torchlearner.R ├── preprocess.R ├── rd_info.R ├── shape.R ├── task_dataset.R ├── utils.R ├── with_torch_settings.R └── zzz.R ├── README.Rmd ├── 
README.md ├── _pkgdown.yml ├── attic ├── .gitignore ├── LearnerTorch.R ├── PipeOpImageTrafo.R ├── TorchOp.R ├── TorchOpRepeat.R ├── TorchOpTabTokenizer.R ├── Trafo.R ├── dataset │ ├── DataBackendDataset.R │ ├── TaskTorch.R │ ├── datasubset.R │ └── test_DataBackendDataset.R ├── debug-issue-12.R ├── debug-issue-8.R ├── dev-reset-layer.R ├── devel-img-task.R ├── experiments.R ├── image-task.R ├── imageuri.R ├── luz-reprex.R ├── make_image_dataset.R ├── models │ ├── LearnerClassifTabNet.R │ ├── LearnerClassifTabResNet.R │ ├── LearnerRegrTabNet.R │ └── TorchOpTabResNetBlocks.R ├── nn_cls.R ├── nn_rtdl_attention.R ├── operators.R ├── paramset_tabnet.R ├── paramsets_image_trafo.R ├── pot.R ├── reset.R ├── serialize.R ├── task-img.R ├── test-lrn-alexnet.R ├── test_GraphLearnerTorch.R ├── test_LearnerClassifTabResNet.R ├── test_LearnerClassifTorch.R ├── test_PipeOpImageTrafo.R ├── test_TorchOpFlatten.R ├── test_TorchOpInput.R ├── test_TorchOpModel.R ├── test_TorchOpOutput.R ├── test_TorchOpRepeat.R ├── test_TorchOpSelect.R ├── test_TorchOpTabResNetBlocks.R ├── test_TorchOpTabTokenizer.R ├── test_TorchOptimizer.R ├── test_as_dataloader.R ├── test_as_dataset.R ├── test_dataloader.R ├── test_nn_attention.R ├── test_nn_cls.R ├── test_paramset_activation.R ├── test_paramset_torchlearner.R ├── test_paramsets_image_trafo.R ├── test_pot.R ├── test_tabnet.R ├── test_tuning.R ├── testdrive.R ├── threading-repro.R ├── torch_reflections.R ├── torchvision-example-alexnet.R ├── tuning-test.R ├── util-torch.R ├── utils.R ├── utils_task.R └── vignettes │ ├── torchops.qmd │ └── torchops_neo.md ├── benchmarks ├── image_loaders │ └── benchmark_image_loaders.R └── tensor_dataset.R ├── cran-comments.md ├── data-raw ├── cifar.R ├── melanoma.R ├── mnist.R └── tiny_imagenet.R ├── info └── testdata.md ├── inst ├── COPYRIGHTS ├── WORDLIST └── col_info │ ├── cifar10.rds │ ├── cifar100.rds │ ├── melanoma.rds │ ├── mnist.rds │ └── tiny_imagenet.rds ├── man-roxygen ├── field_id.R ├── field_label.R ├── 
field_man.R ├── field_packages.R ├── field_param_set.R ├── field_phash.R ├── field_task_type.R ├── field_task_types.R ├── learner.R ├── learner_example.R ├── param_callbacks.R ├── param_dataset_shapes.R ├── param_feature_types.R ├── param_id.R ├── param_label.R ├── param_loss.R ├── param_man.R ├── param_module_generator.R ├── param_optimizer.R ├── param_packages.R ├── param_param_set.R ├── param_param_vals.R ├── param_predict_types.R ├── param_properties.R ├── param_task_type.R ├── params_learner.R ├── params_pipelines.R ├── paramset_torchlearner.R ├── pipeop_torch.R ├── pipeop_torch.html ├── pipeop_torch_channels_default.R ├── pipeop_torch_example.R ├── pipeop_torch_state_default.R ├── preprocess_torchvision.R ├── pretrained.R └── task_download.R ├── man ├── DataDescriptor.Rd ├── ModelDescriptor.Rd ├── PipeOpPreprocTorchTrafoNop.Rd ├── PipeOpPreprocTorchTrafoReshape.Rd ├── Select.Rd ├── TorchCallback.Rd ├── TorchDescriptor.Rd ├── TorchIngressToken.Rd ├── TorchLoss.Rd ├── TorchOptimizer.Rd ├── as_data_descriptor.Rd ├── as_lazy_tensor.Rd ├── as_lr_scheduler.Rd ├── as_torch_callback.Rd ├── as_torch_callbacks.Rd ├── as_torch_loss.Rd ├── as_torch_optimizer.Rd ├── assert_lazy_tensor.Rd ├── auto_device.Rd ├── batchgetter_categ.Rd ├── batchgetter_num.Rd ├── callback_set.Rd ├── cross_entropy.Rd ├── equals-.lazy_tensor.Rd ├── figures │ ├── lifecycle-archived.svg │ ├── lifecycle-defunct.svg │ ├── lifecycle-deprecated.svg │ ├── lifecycle-experimental.svg │ ├── lifecycle-maturing.svg │ ├── lifecycle-questioning.svg │ ├── lifecycle-stable.svg │ ├── lifecycle-superseded.svg │ ├── logo.svg │ ├── logo_old.png │ ├── logo_svg_old.svg │ └── mlr3torch.svg ├── infer_shapes.Rd ├── ingress_categ.Rd ├── ingress_ltnsr.Rd ├── ingress_num.Rd ├── is_lazy_tensor.Rd ├── lazy_shape.Rd ├── lazy_tensor.Rd ├── materialize.Rd ├── materialize_internal.Rd ├── mlr3torch-package.Rd ├── mlr3torch_callbacks.Rd ├── mlr3torch_losses.Rd ├── mlr3torch_optimizers.Rd ├── mlr_backends_lazy.Rd ├── 
mlr_callback_set.Rd ├── mlr_callback_set.checkpoint.Rd ├── mlr_callback_set.history.Rd ├── mlr_callback_set.lr_scheduler.Rd ├── mlr_callback_set.progress.Rd ├── mlr_callback_set.tb.Rd ├── mlr_callback_set.unfreeze.Rd ├── mlr_context_torch.Rd ├── mlr_learners.mlp.Rd ├── mlr_learners.module.Rd ├── mlr_learners.tab_resnet.Rd ├── mlr_learners.torch_featureless.Rd ├── mlr_learners.torchvision.Rd ├── mlr_learners_torch.Rd ├── mlr_learners_torch_image.Rd ├── mlr_learners_torch_model.Rd ├── mlr_pipeops_augment_center_crop.Rd ├── mlr_pipeops_augment_color_jitter.Rd ├── mlr_pipeops_augment_crop.Rd ├── mlr_pipeops_augment_hflip.Rd ├── mlr_pipeops_augment_random_affine.Rd ├── mlr_pipeops_augment_random_choice.Rd ├── mlr_pipeops_augment_random_crop.Rd ├── mlr_pipeops_augment_random_horizontal_flip.Rd ├── mlr_pipeops_augment_random_order.Rd ├── mlr_pipeops_augment_random_resized_crop.Rd ├── mlr_pipeops_augment_random_vertical_flip.Rd ├── mlr_pipeops_augment_resized_crop.Rd ├── mlr_pipeops_augment_rotate.Rd ├── mlr_pipeops_augment_vflip.Rd ├── mlr_pipeops_module.Rd ├── mlr_pipeops_nn_adaptive_avg_pool1d.Rd ├── mlr_pipeops_nn_adaptive_avg_pool2d.Rd ├── mlr_pipeops_nn_adaptive_avg_pool3d.Rd ├── mlr_pipeops_nn_avg_pool1d.Rd ├── mlr_pipeops_nn_avg_pool2d.Rd ├── mlr_pipeops_nn_avg_pool3d.Rd ├── mlr_pipeops_nn_batch_norm1d.Rd ├── mlr_pipeops_nn_batch_norm2d.Rd ├── mlr_pipeops_nn_batch_norm3d.Rd ├── mlr_pipeops_nn_block.Rd ├── mlr_pipeops_nn_celu.Rd ├── mlr_pipeops_nn_conv1d.Rd ├── mlr_pipeops_nn_conv2d.Rd ├── mlr_pipeops_nn_conv3d.Rd ├── mlr_pipeops_nn_conv_transpose1d.Rd ├── mlr_pipeops_nn_conv_transpose2d.Rd ├── mlr_pipeops_nn_conv_transpose3d.Rd ├── mlr_pipeops_nn_dropout.Rd ├── mlr_pipeops_nn_elu.Rd ├── mlr_pipeops_nn_flatten.Rd ├── mlr_pipeops_nn_fn.Rd ├── mlr_pipeops_nn_ft_cls.Rd ├── mlr_pipeops_nn_geglu.Rd ├── mlr_pipeops_nn_gelu.Rd ├── mlr_pipeops_nn_glu.Rd ├── mlr_pipeops_nn_hardshrink.Rd ├── mlr_pipeops_nn_hardsigmoid.Rd ├── mlr_pipeops_nn_hardtanh.Rd ├── 
mlr_pipeops_nn_head.Rd ├── mlr_pipeops_nn_identity.Rd ├── mlr_pipeops_nn_layer_norm.Rd ├── mlr_pipeops_nn_leaky_relu.Rd ├── mlr_pipeops_nn_linear.Rd ├── mlr_pipeops_nn_log_sigmoid.Rd ├── mlr_pipeops_nn_max_pool1d.Rd ├── mlr_pipeops_nn_max_pool2d.Rd ├── mlr_pipeops_nn_max_pool3d.Rd ├── mlr_pipeops_nn_merge.Rd ├── mlr_pipeops_nn_merge_cat.Rd ├── mlr_pipeops_nn_merge_prod.Rd ├── mlr_pipeops_nn_merge_sum.Rd ├── mlr_pipeops_nn_prelu.Rd ├── mlr_pipeops_nn_reglu.Rd ├── mlr_pipeops_nn_relu.Rd ├── mlr_pipeops_nn_relu6.Rd ├── mlr_pipeops_nn_reshape.Rd ├── mlr_pipeops_nn_rrelu.Rd ├── mlr_pipeops_nn_selu.Rd ├── mlr_pipeops_nn_sigmoid.Rd ├── mlr_pipeops_nn_softmax.Rd ├── mlr_pipeops_nn_softplus.Rd ├── mlr_pipeops_nn_softshrink.Rd ├── mlr_pipeops_nn_softsign.Rd ├── mlr_pipeops_nn_squeeze.Rd ├── mlr_pipeops_nn_tanh.Rd ├── mlr_pipeops_nn_tanhshrink.Rd ├── mlr_pipeops_nn_threshold.Rd ├── mlr_pipeops_nn_tokenizer_categ.Rd ├── mlr_pipeops_nn_tokenizer_num.Rd ├── mlr_pipeops_nn_unsqueeze.Rd ├── mlr_pipeops_preproc_torch.Rd ├── mlr_pipeops_torch.Rd ├── mlr_pipeops_torch_callbacks.Rd ├── mlr_pipeops_torch_ingress.Rd ├── mlr_pipeops_torch_ingress_categ.Rd ├── mlr_pipeops_torch_ingress_ltnsr.Rd ├── mlr_pipeops_torch_ingress_num.Rd ├── mlr_pipeops_torch_loss.Rd ├── mlr_pipeops_torch_model.Rd ├── mlr_pipeops_torch_model_classif.Rd ├── mlr_pipeops_torch_model_regr.Rd ├── mlr_pipeops_torch_optimizer.Rd ├── mlr_pipeops_trafo_adjust_brightness.Rd ├── mlr_pipeops_trafo_adjust_gamma.Rd ├── mlr_pipeops_trafo_adjust_hue.Rd ├── mlr_pipeops_trafo_adjust_saturation.Rd ├── mlr_pipeops_trafo_grayscale.Rd ├── mlr_pipeops_trafo_normalize.Rd ├── mlr_pipeops_trafo_pad.Rd ├── mlr_pipeops_trafo_resize.Rd ├── mlr_pipeops_trafo_rgb_to_grayscale.Rd ├── mlr_tasks_cifar.Rd ├── mlr_tasks_lazy_iris.Rd ├── mlr_tasks_melanoma.Rd ├── mlr_tasks_mnist.Rd ├── mlr_tasks_tiny_imagenet.Rd ├── model_descriptor_to_learner.Rd ├── model_descriptor_to_module.Rd ├── model_descriptor_union.Rd ├── nn.Rd ├── nn_ft_cls.Rd ├── 
nn_geglu.Rd ├── nn_graph.Rd ├── nn_merge_cat.Rd ├── nn_merge_prod.Rd ├── nn_merge_sum.Rd ├── nn_reglu.Rd ├── nn_reshape.Rd ├── nn_squeeze.Rd ├── nn_tokenizer_categ.Rd ├── nn_tokenizer_num.Rd ├── nn_unsqueeze.Rd ├── output_dim_for.Rd ├── pipeop_preproc_torch.Rd ├── replace_head.Rd ├── t_clbk.Rd ├── t_loss.Rd ├── t_opt.Rd ├── task_dataset.Rd └── torch_callback.Rd ├── mlr3torch.Rproj ├── paper └── finetuning.R ├── pkgdown └── favicon │ ├── apple-touch-icon-120x120.png │ ├── apple-touch-icon-152x152.png │ ├── apple-touch-icon-180x180.png │ ├── apple-touch-icon-60x60.png │ ├── apple-touch-icon-76x76.png │ ├── apple-touch-icon.png │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ └── favicon.ico ├── tests ├── testthat.R └── testthat │ ├── assets │ ├── nano_dogs_vs_cats │ │ ├── cat.1.jpg │ │ ├── cat.2.jpg │ │ ├── cat.3.jpg │ │ ├── cat.4.jpg │ │ ├── cat.5.jpg │ │ ├── dog.1.jpg │ │ ├── dog.2.jpg │ │ ├── dog.3.jpg │ │ ├── dog.4.jpg │ │ └── dog.5.jpg │ ├── nano_imagenet │ │ ├── torch_n04456115_0.JPEG │ │ ├── torch_n04456115_1.JPEG │ │ ├── torch_n04456115_10.JPEG │ │ ├── torch_n04456115_100.JPEG │ │ ├── torch_n04456115_101.JPEG │ │ ├── wok_n04596742_0.JPEG │ │ ├── wok_n04596742_1.JPEG │ │ ├── wok_n04596742_10.JPEG │ │ ├── wok_n04596742_100.JPEG │ │ └── wok_n04596742_101.JPEG │ └── nano_mnist │ │ ├── 0-1.jpg │ │ ├── 0-14481.jpg │ │ ├── 1-10006.jpg │ │ ├── 1-10007.jpg │ │ ├── 2-10009.jpg │ │ ├── 2-10016.jpg │ │ └── data.rds │ ├── helper_autotest.R │ ├── helper_functions.R │ ├── helper_learner.R │ ├── helper_mlr3.R │ ├── helper_mlr3pipelines.R │ ├── helper_module.R │ ├── helper_pipeop.R │ ├── helper_tasks.R │ ├── setup.R │ ├── teardown.R │ ├── test_CallbackSet.R │ ├── test_CallbackSetCheckpoint.R │ ├── test_CallbackSetHistory.R │ ├── test_CallbackSetLRScheduler.R │ ├── test_CallbackSetProgress.R │ ├── test_CallbackSetTB.R │ ├── test_CallbackSetUnfreeze.R │ ├── test_DataBackendLazy.R │ ├── test_DataDescriptor.R │ ├── test_LearnerTorch.R │ ├── test_LearnerTorchFeatureless.R │ ├── 
test_LearnerTorchImage.R │ ├── test_LearnerTorchMLP.R │ ├── test_LearnerTorchModel.R │ ├── test_LearnerTorchModule.R │ ├── test_LearnerTorchTabResNet.R │ ├── test_LearnerTorchVision.R │ ├── test_ModelDescriptor.R │ ├── test_PipeOpModule.R │ ├── test_PipeOpTaskPreprocTorch.R │ ├── test_PipeOpTorch.R │ ├── test_PipeOpTorchActivation.R │ ├── test_PipeOpTorchAdaptiveAvgPool.R │ ├── test_PipeOpTorchAvgPool.R │ ├── test_PipeOpTorchBatchNorm.R │ ├── test_PipeOpTorchBlock.R │ ├── test_PipeOpTorchCallbacks.R │ ├── test_PipeOpTorchConv.R │ ├── test_PipeOpTorchConvTranspose.R │ ├── test_PipeOpTorchDropout.R │ ├── test_PipeOpTorchFTCLS.R │ ├── test_PipeOpTorchFn.R │ ├── test_PipeOpTorchHead.R │ ├── test_PipeOpTorchIdentity.R │ ├── test_PipeOpTorchIngress.R │ ├── test_PipeOpTorchLayerNorm.R │ ├── test_PipeOpTorchLinear.R │ ├── test_PipeOpTorchLoss.R │ ├── test_PipeOpTorchMaxPool.R │ ├── test_PipeOpTorchMerge.R │ ├── test_PipeOpTorchModel.R │ ├── test_PipeOpTorchOptimizer.R │ ├── test_PipeOpTorchReshape.R │ ├── test_PipeOpTorchTokenizer.R │ ├── test_Select.R │ ├── test_TaskClassif_cifar.R │ ├── test_TaskClassif_lazy_iris.R │ ├── test_TaskClassif_melanoma.R │ ├── test_TaskClassif_mnist.R │ ├── test_TaskClassif_tiny_imagenet.R │ ├── test_TorchCallback.R │ ├── test_TorchDescriptor.R │ ├── test_TorchLoss.R │ ├── test_TorchOptimizer.R │ ├── test_autotests.R │ ├── test_cache.R │ ├── test_dictionaries.R │ ├── test_helper_tasks.R │ ├── test_lazy_tensor.R │ ├── test_learner_torch_methods.R │ ├── test_materialize.R │ ├── test_multi_tensor_dataset.R │ ├── test_nn.R │ ├── test_nn_graph.R │ ├── test_paramset_torchlearner.R │ ├── test_preprocess.R │ ├── test_shape.R │ ├── test_task_dataset.R │ ├── test_utils.R │ └── test_with_torch_settings.R └── vignettes └── articles ├── callback_list.Rmd ├── callbacks.Rmd ├── get_started.Rmd ├── internals_pipeop_torch.Rmd ├── layer_list.Rmd ├── lazy_tensor.Rmd ├── loss_list.Rmd ├── optimizer_list.Rmd ├── pipeop_torch.Rmd ├── preprocessing_list.Rmd ├── 
tabular_learner_list.Rmd ├── task_list.Rmd └── vision_learner_list.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^renv$ 2 | ^renv\.lock$ 3 | ^mlr3torch\.Rproj$ 4 | ^\.Rproj\.user$ 5 | ^README\.Rmd$ 6 | ^README.html$ 7 | ^\.github$ 8 | ^LICENSE\.md$ 9 | ^LICENSE$ 10 | ^attic$ 11 | ^man-roxygen$ 12 | ^mlr3torch\.Rcheck$ 13 | ^mlr3torch.*\.tar\.gz$ 14 | ^mlr3torch.*\.tgz$ 15 | ^.pre-commit-config.yaml 16 | ^\.pre-commit-config\.yaml$ 17 | ^bugs.md$ 18 | ^\.vscode$ 19 | ^benchmarks$ 20 | ^info$ 21 | ^_pkgdown\.yml$ 22 | ^docs$ 23 | ^pkgdown$ 24 | ^\.git$ 25 | ^\.editorconfig$ 26 | ^\.lintr$ 27 | ^info$ 28 | ^README_files$ 29 | ^TODO\.md$ 30 | ^\.ignore$ 31 | ^inst/data_scripts$ 32 | ^vignettes/articles$ 33 | data-raw 34 | ^doc$ 35 | ^Meta$ 36 | ^CRAN-SUBMISSION$ 37 | ^paper$ 38 | ^cran-comments\.md$ 39 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # See http://editorconfig.org 2 | root = true 3 | 4 | [*] 5 | charset = utf-8 6 | end_of_line = lf 7 | insert_final_newline = true 8 | indent_style = space 9 | trim_trailing_whitespace = true 10 | 11 | [*.{r,R,md,Rmd}] 12 | indent_size = 2 13 | 14 | [*.{c,h}] 15 | indent_size = 4 16 | 17 | [*.{cpp,hpp}] 18 | indent_size = 4 19 | 20 | [{NEWS.md,DESCRIPTION,LICENSE}] 21 | max_line_length = 80 22 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.html linguist-detectable=false 2 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/dependabot.yaml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yml: -------------------------------------------------------------------------------- 1 | # pkgdown workflow of the mlr3 ecosystem v0.1.0 2 | # https://github.com/mlr-org/actions 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | release: 11 | types: 12 | - published 13 | workflow_dispatch: 14 | 15 | name: pkgdown 16 | 17 | jobs: 18 | pkgdown: 19 | runs-on: ubuntu-latest 20 | 21 | concurrency: 22 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 23 | env: 24 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 25 | TORCH_INSTALL: 1 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - uses: r-lib/actions/setup-pandoc@v2 30 | 31 | - uses: r-lib/actions/setup-r@v2 32 | 33 | - uses: r-lib/actions/setup-r-dependencies@v2 34 | with: 35 | extra-packages: any::pkgdown, local::. 
36 | needs: website 37 | 38 | - name: Install template 39 | run: pak::pkg_install("mlr-org/mlr3pkgdowntemplate") 40 | shell: Rscript {0} 41 | 42 | - name: Build site 43 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 44 | shell: Rscript {0} 45 | 46 | - name: Deploy 47 | if: github.event_name != 'pull_request' 48 | uses: JamesIves/github-pages-deploy-action@v4.7.3 49 | with: 50 | clean: false 51 | branch: gh-pages 52 | folder: docs 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rprofile 3 | .Rhistory 4 | mlr3torch.Rcheck/ 5 | mlr3torch*.tar.gz 6 | mlr3torch*.tgz 7 | .vscode/tasks.json 8 | ^README\.html$ 9 | *~ 10 | docs 11 | inst/doc 12 | *.html 13 | **/.DS_Store 14 | /doc/ 15 | /Meta/ 16 | CRAN-SUBMISSION 17 | paper/data 18 | .idea/ 19 | .vsc/ 20 | paper/data -------------------------------------------------------------------------------- /.ignore: -------------------------------------------------------------------------------- 1 | man/ 2 | docs/ 3 | inst/doc/ 4 | attic/ 5 | vignettes/*.html 6 | info/ 7 | -------------------------------------------------------------------------------- /.lintr: -------------------------------------------------------------------------------- 1 | linters: linters_with_defaults( 2 | assignment_linter=NULL, 3 | object_name_linter=NULL, 4 | cyclocomp_linter=NULL, 5 | commented_code_linter=NULL, 6 | line_length_linter=line_length_linter(180), 7 | indentation_linter=NULL, 8 | object_length_linter=object_length_linter(40)) 9 | 10 | -------------------------------------------------------------------------------- /R/CallbackSetEarlyStopping.R: -------------------------------------------------------------------------------- 1 | CallbackSetEarlyStopping = R6Class("CallbackSetEarlyStopping", 2 | inherit = CallbackSet, 3 | lock_objects = FALSE, 4 | public = list( 5 
| initialize = function(patience, min_delta) { 6 | self$patience = assert_int(patience, lower = 1L) 7 | self$min_delta = assert_number(min_delta, lower = 0) # single non-missing number >= 0; integer input is accepted too 8 | self$stagnation = 0L 9 | self$best_score = NULL 10 | self$epoch_at_best_score = NULL 11 | }, 12 | on_valid_end = function() { 13 | if (is.null(self$ctx$last_scores_valid)) { 14 | return(NULL) 15 | } 16 | if (is.null(self$best_score) || is.na(self$best_score)) { # (re-)initialize the baseline: an NA baseline would make every later improvement NA 17 | self$best_score = self$ctx$last_scores_valid[[1L]] 18 | self$epoch_at_best_score = self$ctx$epoch 19 | return(NULL) 20 | } 21 | multiplier = if (self$ctx$measures_valid[[1L]]$minimize) -1 else 1 22 | improvement = multiplier * (self$ctx$last_scores_valid[[1L]] - self$best_score) 23 | 24 | if (is.na(improvement)) { 25 | lg$warn("Learner %s in epoch %s: Difference between subsequent validation performances is NA", 26 | self$ctx$learner$id, self$ctx$epoch) 27 | return(NULL) 28 | } 29 | 30 | if (improvement <= self$min_delta) { 31 | self$stagnation = self$stagnation + 1L 32 | if (self$stagnation >= self$patience) { # >= so termination is still requested if training continues past patience 33 | self$ctx$terminate = TRUE 34 | } 35 | } else { 36 | self$stagnation = 0L 37 | self$best_score = self$ctx$last_scores_valid[[1L]] 38 | self$epoch_at_best_score = self$ctx$epoch 39 | } 40 | }, 41 | state_dict = function() { 42 | list(best_epochs = self$epoch_at_best_score) 43 | } 44 | ) 45 | ) 46 | -------------------------------------------------------------------------------- /R/PipeOpTorchDropout.R: -------------------------------------------------------------------------------- 1 | #' @title Dropout 2 | #' @inherit torch::nnf_dropout description 3 | #' @section nn_module: 4 | #' Calls [`torch::nn_dropout()`] when trained. 5 | #' @section Parameters: 6 | #' * `p` :: `numeric(1)`\cr 7 | #' Probability of an element to be zeroed. Default: 0.5. 8 | #' * `inplace` :: `logical(1)`\cr 9 | #' If set to `TRUE`, will do this operation in-place. Default: `FALSE`.
10 | #' 11 | #' @templateVar id nn_dropout 12 | #' @template pipeop_torch_channels_default 13 | #' @template pipeop_torch 14 | #' @template pipeop_torch_example 15 | #' 16 | #' 17 | #' @export 18 | PipeOpTorchDropout = R6Class("PipeOpTorchDropout", 19 | inherit = PipeOpTorch, 20 | public = list( 21 | #' @description 22 | #' Creates a new instance of this [R6][R6::R6Class] class. 23 | #' @template params_pipelines 24 | initialize = function(id = "nn_dropout", param_vals = list()) { 25 | param_set = ps( 26 | p = p_dbl(default = 0.5, lower = 0, upper = 1, tags = "train"), 27 | inplace = p_lgl(default = FALSE, tags = "train") 28 | ) 29 | super$initialize( 30 | id = id, 31 | param_set = param_set, 32 | param_vals = param_vals, 33 | module_generator = nn_dropout 34 | ) 35 | } 36 | ) 37 | ) 38 | 39 | #' @include aaa.R 40 | register_po("nn_dropout", PipeOpTorchDropout) 41 | -------------------------------------------------------------------------------- /R/PipeOpTorchIdentity.R: -------------------------------------------------------------------------------- 1 | #' @title Identity Layer 2 | #' @inherit torch::nn_identity description 3 | #' @section nn_module: 4 | #' Calls [`torch::nn_identity()`] when trained, which passes the input unchanged to the output. 5 | #' 6 | #' @templateVar id nn_identity 7 | #' @template pipeop_torch_channels_default 8 | #' @template pipeop_torch 9 | #' @template pipeop_torch_example 10 | #' 11 | #' @export 12 | PipeOpTorchIdentity = R6Class("PipeOpTorchIdentity", 13 | inherit = PipeOpTorch, 14 | public = list( 15 | #' @description Creates a new instance of this [R6][R6::R6Class] class. 
16 | #' @template params_pipelines 17 | initialize = function(id = "nn_identity", param_vals = list()) { 18 | param_set = ps() 19 | super$initialize( 20 | id = id, 21 | param_set = param_set, 22 | param_vals = param_vals, 23 | module_generator = nn_identity 24 | ) 25 | } 26 | ), 27 | private = list( 28 | .shape_dependent_params = function(shapes_in, param_vals, task) { 29 | param_vals 30 | }, 31 | .shapes_out = function(shapes_in, param_vals, task) { 32 | shapes_in 33 | } 34 | ) 35 | ) 36 | 37 | #' @include aaa.R 38 | register_po("nn_identity", PipeOpTorchIdentity) -------------------------------------------------------------------------------- /R/PipeOpTorchSoftmax.R: -------------------------------------------------------------------------------- 1 | #' @title Softmax 2 | #' @inherit torch::nnf_softmax description 3 | #' @section nn_module: 4 | #' Calls [`torch::nn_softmax()`] when trained. 5 | #' @section Parameters: 6 | #' * `dim` :: `integer(1)`\cr 7 | #' A dimension along which Softmax will be computed (so every slice along dim will sum to 1). 8 | #' 9 | #' @templateVar id nn_softmax 10 | #' @template pipeop_torch_channels_default 11 | #' @template pipeop_torch 12 | #' @template pipeop_torch_example 13 | #' 14 | #' 15 | #' @export 16 | PipeOpTorchSoftmax = R6::R6Class("PipeOpTorchSoftmax", 17 | inherit = PipeOpTorch, 18 | public = list( 19 | #' @description 20 | #' Creates a new instance of this [R6][R6::R6Class] class. 
21 | #' @template params_pipelines 22 | initialize = function(id = "nn_softmax", param_vals = list()) { 23 | param_set = ps( 24 | dim = p_int(1L, Inf, tags = c("train", "required")) 25 | ) 26 | super$initialize( 27 | id = id, 28 | module_generator = nn_softmax, 29 | param_set = param_set, 30 | param_vals = param_vals 31 | ) 32 | } 33 | ) 34 | ) 35 | 36 | #' @include aaa.R 37 | register_po("nn_softmax", PipeOpTorchSoftmax) 38 | -------------------------------------------------------------------------------- /R/TaskClassif_lazy_iris.R: -------------------------------------------------------------------------------- 1 | #' @title Iris Classification Task 2 | #' 3 | #' @name mlr_tasks_lazy_iris 4 | #' 5 | #' @format [R6::R6Class] inheriting from [mlr3::TaskClassif]. 6 | #' @include aaa.R 7 | #' 8 | #' @description 9 | #' A classification task for the popular [datasets::iris] data set. 10 | #' Just like the iris task, but the features are represented as one lazy tensor column. 11 | #' 12 | #' @section Construction: 13 | #' ``` 14 | #' tsk("lazy_iris") 15 | #' ``` 16 | #' @source 17 | #' \url{https://en.wikipedia.org/wiki/Iris_flower_data_set} 18 | #' 19 | #' @section Properties: 20 | #' `r rd_info_task_torch("lazy_iris", missings = FALSE)` 21 | #' 22 | #' @references 23 | #' `r format_bib("anderson_1936")` 24 | #' @examplesIf torch::torch_is_installed() 25 | #' task = tsk("lazy_iris") 26 | #' task 27 | #' df = task$data() 28 | #' materialize(df$x[1:6], rbind = TRUE) 29 | NULL 30 | 31 | load_task_lazy_iris = function(id = "lazy_iris") { 32 | d = load_dataset("iris", "datasets") 33 | target = d[, 5] 34 | features = as.matrix(d[, -5]) 35 | 36 | d = data.table( 37 | Species = target, 38 | x = as_lazy_tensor(features) 39 | ) 40 | 41 | b = as_data_backend(d) 42 | task = TaskClassif$new(id, b, target = "Species", label = "Iris Flowers") 43 | b$hash = task$man = "mlr3torch::mlr_tasks_lazy_iris" 44 | task 45 | } 46 | 47 | #' @include aaa.R 48 | mlr3torch_tasks[["lazy_iris"]] = 
load_task_lazy_iris 49 | -------------------------------------------------------------------------------- /R/aaa.R: -------------------------------------------------------------------------------- 1 | register_po = function(name, constructor, metainf = NULL) { 2 | if (name %in% names(mlr3torch_pipeops)) stopf("pipeop %s registered twice", name) 3 | mlr3torch_pipeops[[name]] = list(constructor = constructor, metainf = substitute(metainf)) 4 | } 5 | 6 | register_learner = function(.name, .constructor, ...) { 7 | assert_multi_class(.constructor, c("function", "R6ClassGenerator")) 8 | if (is.function(.constructor)) { 9 | mlr3torch_learners[[.name]] = list(fn = .constructor, prototype_args = list(...)) 10 | return(NULL) 11 | } 12 | task_type = if (startsWith(.name, "classif")) "classif" else "regr" 13 | # What I am doing here: 14 | # The problem is that we wan't to set the task_type when creating the learner from the dictionary 15 | # The initial idea was to add functions function(...) LearnerClass$new(..., task_type = "") 16 | # This did not work because mlr3misc does not work with ... arguments (... 
arguments are not 17 | # passed further to the initialize method) 18 | # For this reason, we need this hacky solution here, might change in the future in mlr3misc 19 | fn = crate(function() { 20 | invoke(.constructor$new, task_type = task_type, .args = as.list(match.call()[-1])) 21 | }, .constructor, task_type, .parent = topenv()) 22 | fmls = formals(.constructor$public_methods$initialize) 23 | fmls$task_type = NULL 24 | formals(fn) = fmls 25 | if (.name %in% names(mlr3torch_learners)) stopf("learner %s registered twice", .name) 26 | mlr3torch_learners[[.name]] = list(fn = fn, prototype_args = list(...)) 27 | } 28 | 29 | register_task = function(name, constructor) { 30 | if (name %in% names(mlr3torch_tasks)) stopf("task %s registered twice", name) 31 | mlr3torch_tasks[[name]] = constructor 32 | } 33 | 34 | mlr3torch_pipeops = new.env() 35 | mlr3torch_learners = new.env() 36 | mlr3torch_tasks = new.env() 37 | -------------------------------------------------------------------------------- /R/merge_graphs.R: -------------------------------------------------------------------------------- 1 | merge_graphs = function(g1, g2) { 2 | graph = g1$clone(deep = FALSE) 3 | 4 | # if graphs are identical, we don't need to worry about copying stuff 5 | if (!identical(g1, g2)) { 6 | # PipeOps that have the same ID that occur in both graphs must be identical. 
7 | common_names = intersect(names(graph$pipeops), names(g2$pipeops)) 8 | if (!identical(graph$pipeops[common_names], g2$pipeops[common_names])) { 9 | not_identical = map_lgl(common_names, function(name) { 10 | !identical(graph$pipeops[[name]], g2$pipeops[[name]]) 11 | }) 12 | stopf("Both graphs have PipeOps with ID(s) %s but they are not identical.", 13 | paste0("'", common_names[not_identical], "'", collapse = ", ") 14 | ) 15 | } 16 | 17 | # copy all PipeOps that are in g2 but not in g1 18 | graph$pipeops = c(graph$pipeops, g2$pipeops[setdiff(names(g2$pipeops), common_names)]) 19 | 20 | # clear param_set cache 21 | graph$.__enclos_env__$private$.param_set = NULL 22 | 23 | # edges that are in g2 but not in g1 (anti-join on all four edge columns) 24 | new_edges = g2$edges[!graph$edges, on = c("src_id", "src_channel", "dst_id", "dst_channel")] 25 | 26 | # (dst_id, dst_channel) pairs that receive new incoming edges from g2. These channels must not already have incoming edges in g1. 27 | new_input_edges = unique(new_edges[, c("dst_id", "dst_channel"), with = FALSE]) 28 | 29 | forbidden_edges = graph$edges[new_input_edges, on = c("dst_id", "dst_channel"), nomatch = NULL] 30 | 31 | if (nrow(forbidden_edges)) { 32 | stopf("PipeOp(s) %s have differing incoming edges in g1 and g2", 33 | paste(unique(forbidden_edges$dst_id), collapse = ", ")) # a PipeOp can match once per conflicting edge, so deduplicate 34 | 35 | } 36 | graph$edges = rbind(graph$edges, new_edges) 37 | } 38 | 39 | return(graph) 40 | } 41 | -------------------------------------------------------------------------------- /R/multi_tensor_dataset.R: -------------------------------------------------------------------------------- 1 | multi_tensor_dataset = dataset("multi_tensor_dataset", 2 | initialize = function(dataset, device = "cpu") { 3 | assert_class(dataset, "dataset") 4 | # the return of dataset is list(x = list, y = torch_float, .index = torch_long/integer) 5 | self$data = if (!is.null(dataset$.getbatch)) { 6 | dataset$.getbatch(seq_len(length(dataset))) 7 | } else { 8 | batches =
lapply(seq_len(length(dataset)), dataset$.getitem) 9 | list( 10 | x = set_names( 11 | map(names(batches[[1]]$x), function(nm) torch_stack(map(batches, function(batch) batch$x[[nm]]))), 12 | nm = names(batches[[1]]$x) 13 | ), 14 | y = torch_stack(map(batches, "y")), 15 | .index = do.call(c, map(batches, ".index")) 16 | ) 17 | } 18 | self$data$x = lapply(self$data$x, function(x) x$to(device = device)) 19 | self$data$y = self$data$y$to(device = device) 20 | }, 21 | .getbatch = function(i) { 22 | list( 23 | x = map(self$data$x, function(x) x[i, drop = FALSE]), 24 | y = self$data$y[i, drop = FALSE], 25 | .index = self$data$.index[i, drop = FALSE] 26 | ) 27 | }, 28 | .length = function() { 29 | nrow(self$data$x[[1L]]) 30 | } 31 | ) 32 | -------------------------------------------------------------------------------- /R/nn.R: -------------------------------------------------------------------------------- 1 | #' @title Create a Neural Network Layer 2 | #' @description 3 | #' Retrieve a neural network layer from the 4 | #' [`mlr_pipeops`][mlr3pipelines::mlr_pipeops] dictionary. 5 | #' @param .key (`character(1)`)\cr Key of the layer, i.e. the key of the PipeOp in the [`mlr_pipeops`][mlr3pipelines::mlr_pipeops] dictionary without the `"nn_"` prefix (e.g. `"linear"` for `"nn_linear"`). Also used as the `id` of the resulting PipeOp unless `id` is passed explicitly. 6 | #' @param ... (any)\cr 7 | #' Additional parameters, constructor arguments or fields. 8 | #' @export 9 | #' @examples 10 | #' po1 = po("nn_linear", id = "linear") 11 | #' # is the same as: 12 | #' po2 = nn("linear") 13 | nn = function(.key, ...) { 14 | args = list(...)
15 | if (is.null(args$id)) { 16 | args$id = .key 17 | } 18 | invoke(po, .obj = paste0("nn_", .key), .args = args) 19 | } 20 | -------------------------------------------------------------------------------- /R/rd_info.R: -------------------------------------------------------------------------------- 1 | rd_info_learner_torch = function(name, task_types = "classif, regr") { 2 | task_types = gsub(" ", "", task_types) 3 | task_types = strsplit(task_types, split = ",")[[1L]] 4 | predict_types = c() 5 | if ("classif" %in% task_types) { 6 | learner = mlr_learners$get(paste0("classif.", name), .prototype = TRUE) 7 | predict_types = c(predict_types, 8 | sprintf(" * classif: %s", paste0("'", learner$predict_types, "'", collapse = ", ")) 9 | ) 10 | } 11 | if ("regr" %in% task_types) { 12 | learner = mlr_learners$get(paste0("regr.", name), .prototype = TRUE) # bound as `learner` so the feature types / packages lines below also work for regr-only learners 13 | predict_types = c(predict_types, 14 | sprintf(" * regr: %s", paste0("'", learner$predict_types, "'", collapse = ", ")) 15 | ) 16 | } 17 | x = c("", 18 | sprintf("* Supported task types: %s", paste0("'", task_types, "'", collapse = ", ")), 19 | sprintf("* Predict Types:"), 20 | predict_types, 21 | sprintf("* Feature Types: %s", rd_format_string(learner$feature_types)), 22 | sprintf("* Required Packages: %s", rd_format_packages(learner$packages)) 23 | ) 24 | paste(x, collapse = "\n") 25 | } 26 | 27 | rd_info_task_torch = function(name, missings) { 28 | obj = tsk(name) 29 | x = c("", 30 | sprintf("* Task type: %s", rd_format_string(obj$task_type)), 31 | sprintf("* Properties: %s", rd_format_string(obj$properties)), 32 | sprintf("* Has Missings: %s", if (missings) "yes" else "no"), 33 | sprintf("* Target: %s", rd_format_string(obj$target_names)), 34 | sprintf("* Features: %s", rd_format_string(obj$feature_names)), 35 | sprintf("* Data Dimension: %ix%i", obj$backend$nrow, obj$backend$ncol) 36 | ) 37 | paste(x, collapse = "\n") 38 | } 39 | --------------------------------------------------------------------------------
/R/with_torch_settings.R:
--------------------------------------------------------------------------------
1 | with_torch_settings = function(seed, num_threads = 1, num_interop_threads = 1, expr) { # evaluates `expr` under the given torch thread settings and (unless NULL) seed
2 |   old_num_threads = torch_get_num_threads() # remembered so the on.exit() below can restore it
3 |   if (running_on_mac()) {
4 |     if (!isTRUE(all.equal(num_threads, 1L))) { # on macOS the thread count is left unchanged; warn unless num_threads is 1
5 |       lg$warn("Cannot set number of threads on macOS.")
6 |     }
7 |   } else {
8 |     on.exit({torch_set_num_threads(old_num_threads)}, # restore the caller's intra-op thread count when this function exits
9 |       add = TRUE
10 |     )
11 |     torch_set_num_threads(num_threads)
12 |   }
13 |
14 |   if (num_interop_threads != torch_get_num_interop_threads()) { # only attempt a change if the value differs; NOTE(review): unlike num_threads, this is not restored on exit
15 |     result = try(torch::torch_set_num_interop_threads(num_interop_threads), silent = TRUE) # may fail: per the warning below, interop threads can only be set once
16 |     if (inherits(result, "try-error")) {
17 |       lg$warn(sprintf("Can only set the interop threads once, keeping the previous value %s", torch_get_num_interop_threads()))
18 |     }
19 |   }
20 |   # the seed is local to this call: local_torch_manual_seed() sets it back when exiting the function
21 |   if (!is.null(seed)) {
22 |     local_torch_manual_seed(seed)
23 |   }
24 |   force(expr) # `expr` is a lazy argument and is evaluated only here, after all settings are applied
25 | }
26 |
--------------------------------------------------------------------------------
/attic/.gitignore:
--------------------------------------------------------------------------------
1 | *.rds
2 |
--------------------------------------------------------------------------------
/attic/LearnerTorch.R:
--------------------------------------------------------------------------------
1 | #' @title Classification Torch Learner
2 | #'
3 | #' @name mlr_learners_classif.torch
4 | #'
5 | #' @description
6 | #' Custom torch classification network.
7 | #'
8 | #'
9 | #' @section Parameters:
10 | #' The union of:
11 | #' * The construction `param_set` (inferred from `module` via `inferps()` if not supplied).
12 | #' * The parameters from [`LearnerClassifTorch`].
13 | #' @section Internals:
14 | #' TODO
15 | #'
16 | #' @export
17 | #' @include LearnerTorch.R
18 | #' @examples
19 | #' # TODO:
20 | LearnerClassifTorchModule = R6Class("LearnerClassifTorchModule",
21 |   inherit = LearnerClassifTorch,
22 |   public = list(
23 |     #' @description Creates a new instance of this [R6][R6::R6Class] class.
24 |     initialize = function(module, param_set = NULL, optimizer = t_opt("adam"), loss = t_loss("cross_entropy"),
25 |       param_vals = list(), feature_types = NULL, dataset) { # NOTE(review): `param_vals` is accepted but never used here
26 |       private$.module = module
27 |       private$.dataset = assert_function(dataset, args = c("task", "param_vals")) # `dataset` must accept arguments `task` and `param_vals`
28 |       if (is.null(param_set)) {
29 |         param_set = inferps(module)
30 |         param_set$set_id = "net"
31 |       } else { # NOTE(review): empty branch, so a supplied param_set is used as-is
32 |       }
33 |       assert_true(FALSE) # NOTE(review): always fails, so this unfinished prototype cannot be constructed
34 |
35 |       super$initialize(
36 |         id = "classif.torch_module",
37 |         properties = c("twoclass", "multiclass"),
38 |         label = "Torch Module Classifier",
39 |         feature_types = feature_types %??% mlr_reflections$task_feature_types,
40 |         optimizer = optimizer,
41 |         loss = loss,
42 |         man = "mlr3torch::mlr_learners_classif.torch",
43 |         param_set = param_set
44 |       )
45 |     }
46 |   ),
47 |   private = list(
48 |     .network = function(task, param_vals) { # builds the network by calling the stored module with `task` and the parameter values
49 |       invoke(
50 |         private$.module,
51 |         task = task,
52 |         .args = param_vals
53 |       )
54 |     },
55 |     .module = NULL,
56 |     .dataset = NULL
57 |   )
58 | )
59 |
60 | #' @include zzz.R
61 | register_learner("classif.torch_module", LearnerClassifTorchModule)
62 |
--------------------------------------------------------------------------------
/attic/PipeOpImageTrafo.R:
--------------------------------------------------------------------------------
1 | #' @title Image Transformation
2 | #' @description
3 | #' Applies the image transformation `.trafo` to all `imageuri` columns of a task.
4 | #' @export
5 | PipeOpImageTrafo = R6Class("TorchOpImageTrafo", # NOTE(review): R6 class name differs from the object name PipeOpImageTrafo; confirm
6 |   inherit = mlr3pipelines::PipeOpTaskPreprocSimple,
7 |   public = list(
8 |     initialize = function(id = .trafo, param_vals = list(), .trafo) {
9 |       assert_choice(.trafo, torch_reflections$image_trafos)
10 |       private$.trafo = .trafo
11 |       param_set = paramsets_image_trafo$get(.trafo)
12 |       super$initialize(
13 |         id = id,
14 |         param_set = param_set,
15 |         param_vals = param_vals
16 |       )
17 |     }
18 |   ),
19 |   private = list(
20 |     .train_task = function(task) {
21 |       pars = self$param_set$get_values(tags = "train")
22 |       .data = task$backend$.__enclos_env__$private$.data # NOTE(review): reads the backend's private data; assumes a data.table-based backend (TODO confirm)
23 |       image_cols = colnames(.data)[map_lgl(.data, function(x) inherits(x, "imageuri"))]
24 |       torch_trafo = get_image_trafo(private$.trafo)
25 |       trafo = function(img) { # binds the "train"-tagged parameter values to the transformation
26 |         invoke(torch_trafo, img = img, .args = pars)
27 |       }
28 |       for (image_col in image_cols) {
29 |         .data[, (image_col) := transform_imageuri(get(..image_col), ..trafo)]
30 |       }
31 |       task$backend$.__enclos_env__$private$.data = .data # NOTE(review): writes into the backend object in place; other holders of this backend presumably see the change (verify)
32 |       task
33 |     },
34 |     .predict_task = function(task) {
35 |       pars = self$param_set$get_values(tags = "predict") # NOTE(review): `pars` is unused and prediction tasks are returned untransformed
36 |       task
37 |     },
38 |     .trafo = NULL
39 |   )
40 | )
41 |
--------------------------------------------------------------------------------
/attic/dataset/datasubset.R:
--------------------------------------------------------------------------------
1 | # FIXME: This is adapted from torch and not yet fully understood.
2 | datasubset <- dataset("datasubset",
3 |   initialize = function(dataset, rows, cols, device = "cpu", structure = NULL) { # wraps `dataset`, exposing only `rows` and the elements named in `cols`
4 |     if (!is.null(structure)) {
5 |       assert_list(structure, types = "character", names = "unique")
6 |       assert_permutation(names(structure), c("x", "y"))
7 |     }
8 |     self$structure = structure
9 |     self$x_names = structure$x
10 |     self$y_names = structure$y
11 |
12 |     self$rows = rows
13 |     self$cols = cols
14 |     self$dataset = dataset
15 |     self$device = assert_choice(device, mlr_reflections$torch$devices)
16 |     if (!is.null(dataset$.getbatch)) { # if the wrapped dataset has .getbatch, expose .getitem under that name as well
17 |       self$.getbatch <- self$.getitem
18 |     }
19 |     classes = class(dataset)
20 |     classes_to_append = classes[classes != "R6"]
21 |     class(self) = c(paste0(classes_to_append, "_subset_toway"), class(self)) # NOTE(review): "_subset_toway" looks like a typo; confirm the intended suffix
22 |   },
23 |   .getitem = function(idx) {
24 |     # FIXME: disallow .index name for DataBackendDataset
25 |     out = self$dataset[self$rows[idx]][self$cols] # map the subset index to a row of the wrapped dataset, then keep the `cols` elements
26 |     out = lapply(out, function(x) x$to(device = self$device)) # move each selected tensor to the configured device
27 |     if (!is.null(self$structure)) {
28 |       list(x = out[self$x_names], y = out[self$y_names], .index = idx)
29 |     } else {
30 |       c(out, list(.index = idx))
31 |     }
32 |   },
33 |   .length = function() {
34 |     return(length(self$rows))
35 |   },
36 |   ncol = function() {
37 |     return(length(self$cols))
38 |   },
39 |   nrow = function() {
40 |     return(length(self$rows))
41 |   }
42 | )
43 |
44 |
--------------------------------------------------------------------------------
/attic/dataset/test_DataBackendDataset.R:
--------------------------------------------------------------------------------
1 | test_that("DataBackendDataset works", {
2 |   test_dataset = torch::dataset(
3 |     initialize = function(n = 20, shapes = list(feature1 = 3, target = 1)) {
4 |       self$n = n
5 |       self$tensors = lapply(shapes, function(shape) {
6 |         torch_randn(n, shape)
7 |       })
8 |     },
9 |     .getitem = function(index) {
10 |       lapply(self$tensors, function(feature) feature[index,])
11 |     },
12 |     .length = function() {
13 |       self$n
14 |     }
15 |   )
16 |
17 |   ds = test_dataset()
18 |   indices = 1:2
19 |
20 |   ds_sub = datasubset(
21 |     test_dataset(),
22 |     rows = indices,
23 |     cols = "target",
24 |     device = "cpu"
25 |   )
26 |   ds_sub$.getitem(1)
27 |   expect_equal(length(indices), length(ds_sub))
28 |   length(ds_sub) # NOTE(review): value unused (debugging leftover)
29 |
30 |   expect_equal(names(ds_sub$.getitem(1)), "target") # NOTE(review): datasubset's .getitem() also appends `.index`, so this presumably fails; confirm
31 |
32 |   backend = DataBackendDataset$new(
33 |     data = ds,
34 |     colnames = c("feature1", "target"),
35 |     primary_key = "..my_row_id"
36 |   )
37 |   backend
38 |
39 |   task = TaskClassifTorch$new(backend, "test_task", "target")
40 |   dataset = task$data()
41 |   expect_class(dataset, "datasubset")
42 |
43 |   learner = lrn("classif.mlp", batch_size = 10, device = "cpu", epochs = 1)
44 |
45 |   learner$train(task)
46 |
47 |
48 |
49 | })
50 |
--------------------------------------------------------------------------------
/attic/debug-issue-8.R:
--------------------------------------------------------------------------------
1 | # Test without helper.debugging.R -----------------------------------------
2 | library(mlr3)
3 | library(mlr3torch)
4 | # Load test scaffolding without helper_debugging.R
5 | lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]$", full.names = TRUE)[-2], source) # NOTE(review): `[-2]` drops a helper by position, presumably helper_debugging.R; depends on list.files() ordering (confirm)
6 | learner <- lrn("classif.tabnet")
7 |
8 | # Extracted from run_experiment()
9 | tasks <- generate_tasks(learner, N = 30L)
10 |
11 | # Pick the test which fails
12 | task <- tasks[["feat_all_binary"]]
13 |
14 | # Training works fine
15 | learner$train(task)
16 |
17 | # No errors
18 | learner$errors
19 |
20 |
21 | # With helper_debugging.R --------------------------------------------------
22 | library(mlr3)
23 | library(mlr3torch)
24 |
25 | # Load mlr3 testthat scaffolding
26 | lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]$", full.names = TRUE), source)
27 |
28 | learner <- lrn("classif.tabnet")
29 |
30 | # Extracted from run_experiment()
31 | tasks <- generate_tasks(learner, N = 30L)
32 |
33 | # Pick the test which fails
34 | task <- tasks[["feat_all_binary"]]
35 |
36 | # Error is triggered on training
37 | learner$train(task)
38 |
39 |
--------------------------------------------------------------------------------
/attic/dev-reset-layer.R:
--------------------------------------------------------------------------------
1 | library(mlr3torch)
2 |
3 | # AlexNet
4 | model <- torchvision::model_alexnet(pretrained = TRUE)
5 | model$classifier[[7]]$out_feature # NOTE(review): the nn_linear field is presumably `out_features`; verify this does not return NULL
6 |
7 | model <- reset_last_layer(model, num_classes = 10)
8 | model$classifier[[7]]$out_feature
9 |
10 | model <- torchvision::model_resnet18(pretrained = TRUE)
11 | model$fc$out_feature
12 |
13 | model <- reset_last_layer(model, num_classes = 10)
14 | model$fc$out_feature
15 |
16 |
17 | # Reprex maybe for torch(vision)
18 |
19 |
20 | model_alexnet <- torchvision::model_alexnet(pretrained = TRUE)
21 |
22 | model_alexnet
23 |
24 | model_alexnet$classifier
25 |
26 | model_alexnet$classifier[[7]]
27 |
28 | model_alexnet$classifier$`6` <- torch::nn_linear(4096, 10) # children appear to be named from "0", so `6` is presumably classifier[[7]] (confirm)
29 |
30 | model_alexnet$classifier[[7]]
31 |
32 | model_alexnet$classifier
33 |
34 | # Works as intended
35 | model_alexnet$classifier$`6` <- torch::nn_linear(4096, 10)
36 |
--------------------------------------------------------------------------------
/attic/imageuri.R:
--------------------------------------------------------------------------------
1 | #' @title Create an object of class "imageuri"
2 | #' @description
3 | #' Creates an object of class "imageuri" that contains the URIs of images.
4 | #' @param x (`character()`)\cr
5 | #' Character vector containing the paths to the images.
6 | #' @export
7 | imageuri = function(x) {
8 |   assert_character(x)
9 |   # input = torchvision::transform_to_tensor(input)$to(device = "cpu")
10 |   structure(
11 |     x,
12 |     class = c("imageuri", "character")
13 |   )
14 | }
15 |
16 | #' @export
17 | `[.imageuri` = function(obj, ...) {
18 |   imageuri(unclass(obj)[...]) # NOTE(review): `[` drops the "trafos"/"transformed" attributes set by transform_imageuri(); confirm this is intended
19 | }
20 |
21 | #' @title Transforms an Imageuri
22 | #'
23 | #' @description
24 | #' Applies `trafo` to the example image stored in the "transformed" attribute and records `trafo` in the "trafos" attribute.
25 | #'
26 | #' @param x (`imageuri()`)\cr
27 | #' Imageuri vector.
28 | #' @param trafo (`function`)\cr
29 | #' The image transformation. Must be a function with one argument.
30 | #'
31 | #' @export
32 | transform_imageuri = function(x, trafo) {
33 |   assert_true(class(x)[[1L]] == "imageuri")
34 |   assert_function(trafo, nargs = 1L)
35 |   transformed = attr(x, "transformed") # NOTE(review): imageuri() sets no "transformed" attribute, so for a fresh object this is NULL; confirm trafo(NULL) is intended
36 |   transformed = trafo(transformed)
37 |   # transformed = try(trafo(input), silent = TRUE)
38 |   if (inherits(transformed, "try-error")) { # NOTE(review): unreachable while the try() above is commented out; errors from trafo propagate directly
39 |     stopf("Transformation invalid for example input.")
40 |   }
41 |   if (!inherits(transformed, "torch_tensor")) {
42 |     stopf("Transformation did not produce an object of class 'torch_tensor' for the example input.")
43 |   }
44 |   attr(x, "trafos") = c(attr(x, "trafos"), trafo) # recorded so make_image_dataset() can replay the transformations per image
45 |   attr(x, "transformed") = transformed
46 |   return(x)
47 | }
48 |
49 | #' @export
50 | print.imageuri = function(x, ...) {
51 |   n_trafos = length(attr(x, "trafos"))
52 |   catf("")
53 |   catf(" * N(Trafos): %d", n_trafos)
54 |   catf(" * Data Type: %s", class(attr(x, "transformed"))[[1L]])
55 |   catf(" * Example uri: %s", x[[1L]])
56 |   invisible(x)
57 | }
58 |
--------------------------------------------------------------------------------
/attic/make_image_dataset.R:
--------------------------------------------------------------------------------
1 | make_image_dataset = function(task, augmentation = NULL) { # builds a torch dataset that reads the task's image URIs and replays their recorded transformations
2 |   y = task$data(cols = task$target_names)[[1L]]
3 |   y = as.integer(y) # presumably a factor target, converted to its integer level codes
4 |   y = torch_tensor(y, dtype = torch_long())
5 |   uri = task$data(cols = task$feature_names)[[1L]] # only the first feature column is used as the image URI column
6 |   trafos = c(attr(uri, "trafos"), augmentation)
7 |   image = attr(uri, "transformed")
8 |   assert_true(inherits(image, "magick-image") || inherits(image, "torch_tensor"))
9 |   if (inherits(image, "magick-image")) {
10 |     trafos = c(trafos, torchvision::transform_to_tensor) # the example is still a magick image, so a tensor conversion is appended last
11 |   }
12 |
13 |   dataset(
14 |     name = task$id,
15 |     initialize = function(y, uri, trafos) {
16 |       self$trafos = trafos
17 |       self$y = y
18 |       self$uri = uri
19 |     },
20 |     # Here we implement .getitem because this probably makes parallelization easier (?)
21 |     .getitem = function(index) {
22 |       y = self$y[index, drop = TRUE]
23 |       uri = self$uri[index]
24 |       image = magick::image_read(uri)
25 |       for (trafo in self$trafos) { # apply the recorded transformations in order
26 |         image = trafo(image)
27 |       }
28 |       list(
29 |         y = y,
30 |         x = image
31 |       )
32 |     },
33 |     .length = function() {
34 |       length(self$y)
35 |     }
36 |   )(y, uri, trafos) # the dataset generator is instantiated immediately
37 | }
38 |
--------------------------------------------------------------------------------
/attic/nn_cls.R:
--------------------------------------------------------------------------------
1 | #' \[CLS\] Token
2 | #'
3 | #' Appends a trainable \[CLS\] token as the last element along the second dimension.
4 | #'
5 | #' @param d_token (`integer(1)`)\cr
6 | #' The token dimensionality.
7 | #'
8 | #' `r format_bib("devlin2018bert", "gorishniy2021revisiting")`
9 | nn_cls = nn_module("nn_cls",
10 |   initialize = function(d_token) {
11 |     self$weight = nn_parameter(torch_empty(d_token))
12 |     d_sqrt_inv = 1 / sqrt(d_token)
13 |     nn_init_uniform_(self$weight, a = -d_sqrt_inv, b = d_sqrt_inv) # uniform init in [-1/sqrt(d_token), 1/sqrt(d_token)]
14 |   },
15 |   forward = function(x) {
16 |     cls = self$weight$view(c(1L, -1L))$expand(c(nrow(x), 1L, -1L)) # one copy of the token per row of x: shape (nrow(x), 1, d_token)
17 |     x = torch_cat(list(x, cls), dim = 2L) # appended after the existing entries along dim 2
18 |     return(x)
19 |   }
20 | )
21 |
--------------------------------------------------------------------------------
/attic/pot.R:
--------------------------------------------------------------------------------
1 | #' @title Convenience Function to Retrieve a PipeOp without the "nn" or "torch" Prefix
2 | #'
3 | #' @description
4 | #' Sugar function that allows to construct all [PipeOp]s relating to torch without needing to
5 | #' write out the prefixes `"torch_"` and `"nn_"`.
6 | #' Note that the assigned id has the prefix.
7 | #'
8 | #' @name pot
9 | #' @export
10 | #' @examples
11 | #' po("nn_linear")
12 | #' # is the same as
13 | #' pot("linear")
14 | #'
15 | #' po("torch_optimizer", "adam")
16 | #' # is the same as
17 | #' pot("optimizer", "adam")
18 | #'
19 | pot = function(.key, ...) {
20 |   assert_string(.key, min.chars = 1L)
21 |   if (grepl("^nn_", .key) || grepl("^torch_", .key)) { # already prefixed: the caller should use po() directly
22 |     stopf("You probably wanted po(\"%s\", ...).", .key)
23 |   }
24 |   hits = 0
25 |   torch_key = paste0("torch_", .key)
26 |   nn_key = paste0("nn_", .key)
27 |   if (torch_key %in% mlr_pipeops$keys()) {
28 |     hits = hits + 1
29 |     key = torch_key
30 |   }
31 |   if (nn_key %in% mlr_pipeops$keys()) {
32 |     hits = hits + 1
33 |     key = nn_key
34 |   }
35 |   if (hits == 0) {
36 |     stopf("PipeOp with neither id %s or %s exists.", torch_key, nn_key) # NOTE(review): message should read "neither ... nor"
37 |   } else if (hits == 2) {
38 |     stopf("PipeOp with both id %s or %s exists, which is ambiguous, use `po(...)` instead.", torch_key, nn_key) # NOTE(review): message should read "both ... and ... exist"
39 |   }
40 |   po(key, ...) # exactly one prefixed key matched
41 | }
42 |
43 | #' @rdname pot
44 | #' @export
45 | pots = function(.keys, ...) {
46 |   objs = map(.keys, pot, ...) # NOTE(review): the assignment makes the returned list invisible, and `objs` is otherwise unused
47 | }
48 |
--------------------------------------------------------------------------------
/attic/reset.R:
--------------------------------------------------------------------------------
1 | reset_parameters = function(nn) { # calls reset_parameters() on each module visited by nn$apply() that defines it
2 |   f = function(x) {
3 |     if (!is.null(x$reset_parameters)) {
4 |       x$reset_parameters()
5 |     }
6 |   }
7 |   nn$apply(f)
8 | }
9 |
10 | reset_running_stats = function(nn) { # calls reset_running_stats() on each module visited by nn$apply() that defines it
11 |   f = function(x) {
12 |     if (!is.null(x$reset_running_stats)) {
13 |       x$reset_running_stats()
14 |     }
15 |   }
16 |   nn$apply(f)
17 | }
18 |
--------------------------------------------------------------------------------
/attic/serialize.R:
--------------------------------------------------------------------------------
1 | make_serialization = function(network, optimizer, loss_fn) {
2 |   serialize_tensors <- function(x, f) { # NOTE(review): parameter `f` is unused
3 |     lapply(x, function(x) {
4 |       if (inherits(x, "torch_tensor")) {
5 |         torch:::tensor_to_raw_vector_with_class(x)
6 |       }
7 |       else if (is.list(x)) {
8 |         serialize_tensors(x)
9 |       }
10 |       else {
11 |         x
12 |       }
13 |     })
14 |   }
15 |   f = function(x) {
16 |     serialized = serialize_tensors(x)
17 |     list(values = serialized, type = "list", version = torch:::use_ser_version())
18 |
19 |   }
20 |   list(
21 |     state_dict = list(
22 |       network = f(network$state_dict()),
23 |       optimizer = f(optimizer$state_dict()),
24 |       loss_fn = f(loss_fn$state_dict())
25 |     ),
26 |     prototype # NOTE(review): `prototype` is not defined in this file; this errors unless it exists in an enclosing scope
27 |   )
28 | }
29 | serialize_tensors <- function(x, f) { # top-level duplicate of the helper inside make_serialization()
30 |   lapply(x, function(x) {
31 |     if (inherits(x, "torch_tensor")) {
32 |       torch:::tensor_to_raw_vector_with_class(x)
33 |     }
34 |     else if (is.list(x)) {
35 |       serialize_tensors(x)
36 |     }
37 |     else {
38 |       x
39 |     }
40 |   })
41 | }
42 |
43 | torch_load_list = function(obj, device = NULL) { # presumably the inverse of f() in make_serialization(): rebuilds tensors found in obj$values
44 |   reload = function(values) {
45 |     lapply(values, function(x) {
46 |       if (inherits(x, "torch_serialized_tensor")) {
47 |         torch:::load_tensor_from_raw(x, device = device)
48 |       } else if (is.list(x)) {
49 |         reload(x)
50 |       }
51 |       else {
52 |         x
53 |       }
54 |     })
55 |   }
56 |   reload(obj$values)
57 | }
58 |
59 | load_broken_state_dict = function(module, state_dict) { # NOTE(review): empty stub that returns NULL
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/attic/task-img.R:
--------------------------------------------------------------------------------
1 | # Experimenting with image tasks
2 | mlr_reflections$task_feature_types[["img"]] <- "imageuri" # registers "imageuri" as a feature type under the short name "img" (modifies the session-wide reflections)
3 | # add class "imageuri" to column
4 |
5 | # torchvision dataset format, stores binary files on disk
6 | cifar10_bin <- torchvision::cifar10_dataset(root = "/opt/example-data", download = TRUE)
7 |
8 | # Keras version: Not sure where files are stored
9 | cifar10_keras <- keras::dataset_cifar10()
10 |
--------------------------------------------------------------------------------
/attic/test-lrn-alexnet.R:
--------------------------------------------------------------------------------
1 | library(mlr3)
2 | library(mlr3torch)
3 | library(torch)
4 |
5 | # to_device <- function(x, device) x$to(device = device)
6 |
7 | img_task_df <- df_from_imagenet_dir("/opt/example-data/imagenette2-160/train/")
8 |
9 | # Subsample for faster testing, 5 imgs per class
10 | img_task_df <- img_task_df[img_task_df[, .I[sample.int(.N, min(min(5L, .N), .N))], by = .(target)]$V1] # NOTE(review): min(min(5L, .N), .N) equals min(5L, .N); sampling is unseeded
11 |
12 | # Make it a task
13 | img_task <- mlr3::as_task_classif(img_task_df, target = "target")
14 |
15 | img_transforms <- function(img) {
16 |   img %>%
17 |     # first convert image to tensor
18 |     torchvision::transform_to_tensor() %>%# $to(device = choose_device()) %>%
19 |     # # then move to the GPU (if available)
20 |     (function(x) x$to(device = choose_device())) %>%
21 |     # Required resize for alexnet
22 |     torchvision::transform_resize(c(64, 64))
23 | }
24 |
25 |
26 | lrn_alexnet <- lrn("classif.alexnet",
27 |   predict_type = "prob",
28 |   num_threads = 15,
29 |   # Can't use pretrained on 10-class dataset yet, expects 1000
30 |   pretrained = TRUE, # NOTE(review): contradicts the comment above; confirm which is intended
31 |   img_transform_train = img_transforms,
32 |   img_transform_val = img_transforms,
33 |   img_transform_predict = img_transforms,
34 |   batch_size = 10,
35 |   epochs = 15,
36 |   device = choose_device()
37 | )
38 |
39 |
40 | lrn_alexnet$train(img_task)
41 |
42 |
43 | # lrn_alexnet$model
44 |
45 |
46 | img_test <- df_from_imagenet_dir("/opt/example-data/imagenette2-160/val/") # NOTE(review): never used below
47 |
48 | # Make it a task
49 | img_task_test <- mlr3::as_task_classif(img_task_df, target = "target") # NOTE(review): built from the training data, so the scores below are training-set scores; img_test was presumably intended
50 |
51 | preds <- lrn_alexnet$predict(img_task_test)
52 |
53 | preds$score(msr("classif.acc"))
54 | preds$score(msr("classif.ce"))
55 |
--------------------------------------------------------------------------------
/attic/test_GraphLearnerTorch.R:
--------------------------------------------------------------------------------
1 | test_that("GraphLearnerTorch works", {
2 |   g = top("input") %>>%
3 |     top("select", items = "num") %>>%
4 |     top("linear", out_features = 10L) %>>%
5 |     top("relu") %>>%
6 |     top("output") %>>%
7 |     top("optimizer", "adam", lr = 0.1) %>>%
8 |     top("loss", "cross_entropy")
9 |
10 |   g_classif = g %>>% top("model.classif", epochs = 1L, batch_size = 16L)
11 |   g_regr = g %>>% top("model.regr", epochs = 1L, batch_size = 16L)
12 |
13 |   lrn_classif = as_learner_torch(g_classif)
14 |   assert_true(lrn_classif$task_type == "classif") # NOTE(review): assert_true() aborts instead of recording a failed expectation; expect_true() was presumably intended
15 |   task = tsk("iris")
16 |   lrn_classif$train(task)
17 |   expect_class(lrn_classif$network, "nn_graph")
18 |   expect_class(lrn_classif$history, "History")
19 |   expect_class(lrn_classif$optimizer, "torch_optimizer")
20 |   expect_class(lrn_classif$loss_fn, "nn_loss")
21 |
22 |   lrn_regr = as_learner_torch(g_regr)
23 |   assert_true(lrn_regr$task_type == "regr")
24 | })
25 |
--------------------------------------------------------------------------------
/attic/test_LearnerClassifTabResNet.R:
--------------------------------------------------------------------------------
1 | test_that("LearnerClassifTabResNet works", {
2 |   l = lrn("classif.tab_resnet",
3 |     loss = "cross_entropy",
4 |     optimizer = "adam",
5 |     n_blocks = 2L,
6 |     d_hidden = 30L,
7 |     d_main = 30L,
8 |     dropout_first = 0.2,
9 |     dropout_second = 0.2,
10 |     activation = "relu",
11 |     activation_args = list(),
12 |     skip_connection = TRUE,
13 |     bn.momentum = 0.1,
14 |     # training args
15 |     batch_size = 16L,
16 |     epochs = 10L,
17 |     opt.lr = 0.03,
18 |     callbacks = list(),
19 |     shuffle = TRUE
20 |   )
21 |
22 |   result = run_autotest(learner = l, check_replicable = FALSE)
23 |
24 |   expect_true(result, info = result$error)
25 | })
26 |
27 | if (FALSE) { # NOTE(review): `l` is created inside test_that() above and is not visible here
28 |   task = tsk("iris")
29 |   l$train(task)
30 | }
31 |
--------------------------------------------------------------------------------
/attic/test_LearnerClassifTorch.R:
--------------------------------------------------------------------------------
1 | test_that("autotest: classification", {
2 |   module = nn_module(
3 |     initialize = function(task) {
4 |       out = if (task$task_type == "classif") length(task$class_names) else 1
5 |       self$linear = nn_linear(length(task$feature_names), out)
6 |     },
7 |     forward = function(x) {
8 |       self$linear(x)
9 |     }
10 |   )
11 |
12 |   learner = lrn("classif.torch_module", module, feature_types = c("numeric", "integer"),
13 |     batch_size = 16, epochs = 5
14 |   )
15 |   # task = tsk("iris")
16 |   # learner$train(task)
17 |   # learner$predict_type = "prob"
18 |   # learner$predict(task)
19 |   #
20 |   expect_learner(learner)
21 |
22 |   result = run_autotest(learner, check_replicable = FALSE, exclude = "sanity")
23 |   expect_true(result, info = result$error)
24 | })
25 |
26 |
--------------------------------------------------------------------------------
/attic/test_PipeOpImageTrafo.R:
--------------------------------------------------------------------------------
1 | test_that("PipeOpImageTrafo works", {
2 |   task = tsk("tiny_imagenet")
3 |   batch = get_batch(task, 1L, device = "cpu")
4 |   graph = po("imagetrafo", .trafo = "to_tensor") %>>%
5 |     po("imagetrafo", .trafo = "vflip")
6 |
7 |   out = graph$train(task)
8 |   task_vflip = out$vflip.output # output of the second PipeOp, whose id defaults to its .trafo ("vflip")
9 |   batch_vflip = get_batch(task_vflip, batch_size = 1L, device = "cpu")
10 |   img = batch$x
11 |   img_vflip = batch_vflip$x
12 |   expect_true(torch_equal(torchvision::transform_vflip(img), img_vflip))
13 | })
14 |
--------------------------------------------------------------------------------
/attic/test_TorchOpFlatten.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpFlatten works", {
2 |   op = top("flatten")
3 |   task = tsk("iris")
4 |   param_vals = list(
5 |     start_dim = 2L,
6 |     end_dim = 3L
7 |   )
8 |   inputs = list(input = torch_randn(16, 8, 9))
9 |   op$param_set$values = insert_named(op$param_set$values, param_vals)
10 |   expect_torchop(
11 |     op = op,
12 |     inputs = inputs,
13 |     task = task,
14 |     class = "nn_flatten"
15 |   )
16 | })
17 |
--------------------------------------------------------------------------------
/attic/test_TorchOpInput.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpInput works", {
2 |   task = tsk("iris")
3 |   to = top("input")
4 |   model_args = to$train(list(task))$output
5 |   expect_true(inherits(model_args, "ModelConfig"))
6 |
7 |   g = top("input") %>>% top("tab_tokenizer", d_token = 3L)
8 |   out = g$train(task) # NOTE(review): no expectation on the result
9 | })
10 |
--------------------------------------------------------------------------------
/attic/test_TorchOpModel.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpModel works", {
2 |   task = tsk("iris")
3 |   graph = po("pca") %>>%
4 |     top("input") %>>%
5 |     top("tab_tokenizer", d_token = 1L) %>>%
6 |     top("flatten") %>>%
7 |     top("linear_1", out_features = 20L) %>>%
8 |     top("relu_1") %>>%
9 |     top("output") %>>%
10 |     top("loss", loss = "cross_entropy") %>>%
11 |     top("optimizer", optimizer = "adam") %>>%
12 |     top("model.classif",
13 |       batch_size = 16L,
14 |       device = "cpu",
15 |       epochs = 1L,
16 |       callbacks = list()
17 |     )
18 |   glrn = as_learner_torch(graph)
19 |   glrn$id = "net"
20 |   glrn$train(task) # NOTE(review): trains once outside expect_error(); the next line trains again
21 |   expect_error(glrn$train(task), regexp = NA)
22 |   expect_error(glrn$predict(task), regexp = NA)
23 |
24 |   resampling = rsmp("cv", folds = 2)
25 |   expect_error(resample(task, glrn, resampling), regexp = NA)
26 | })
27 |
--------------------------------------------------------------------------------
/attic/test_TorchOpOutput.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpOutput works", {
2 |   task = tsk("iris")
3 |
4 |   out = top("output")
5 |   debugonce(get_private(out)$.train) # NOTE(review): interactive debugging leftover
6 |
7 |   graph = top("input") %>>%
8 |     top("select", items = "num")
9 |
10 |   res = graph$train(task)
11 |
12 |   out$train(list(input = res[[1L]]))
13 |
14 |
15 |   de # NOTE(review): undefined symbol; errors when the test runs
16 |
17 |   b = top("linear", out_features = 10L) %>>% top("relu")
18 |   op = top("repeat", block = b, times = 1L)
19 |   out = op$build(list(input = torch_randn(16, 3)), task)
20 |   x = out$output$output
21 |   expect_true(all(x$shape == c(16, 10)))
22 |
23 |
24 |   top("repeat", block = b, times = 2) %>>% # NOTE(review): result is not assigned, so the final check trains the two-PipeOp `graph` from above
25 |     top("loss", loss = "cross_entropy") %>>%
26 |     top("optimizer", optimizer = "adam", lr = 0.01) %>>%
27 |     top("model.classif", epochs = 0L, batch_size = 16L)
28 |
29 |   expect_error(graph$train(task), regexp = NA)
30 | })
31 |
--------------------------------------------------------------------------------
/attic/test_TorchOpRepeat.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpRepeat works", {
2 |   task = tsk("iris")
3 |
4 |   b = top("linear", out_features = 10L) %>>% top("relu")
5 |   op = top("repeat", block = b, times = 1L)
6 |   out = op$build(list(input = torch_randn(16, 3)), task)
7 |   x = out$output$output
8 |   expect_true(all(x$shape == c(16, 10)))
9 |
10 |
11 |   graph = top("input") %>>%
12 |     top("select", items = "num") %>>%
13 |     top("repeat", block = b, times = 2) %>>%
14 |     top("loss", loss = "cross_entropy") %>>%
15 |     top("optimizer", optimizer = "adam", lr = 0.01) %>>%
16 |     top("model.classif", epochs = 0L, batch_size = 16L)
17 |
18 |   expect_error(graph$train(task), regexp = NA)
19 | })
20 |
--------------------------------------------------------------------------------
/attic/test_TorchOpSelect.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpSelect works", {
2 |   task = tsk("iris")
3 |   y = torch_randn(10) # NOTE(review): unused
4 |   inputs = list(
5 |     input = list(img = torch_randn(10, 8, 8), num = torch_randn(10, 5), categ = torch_randn(10, 7))
6 |   )
7 |   to = top("select", items = "img")
8 |   res = to$build(inputs, task)
9 |   out = res$output$output
10 |   layer = res$layer
11 |
12 |   x = layer(list(img = 1))
13 |   expect_true(x == 1)
14 |   expect_true(all(out$shape == c(10, 8, 8)))
15 | })
16 |
--------------------------------------------------------------------------------
/attic/test_TorchOpTabResNetBlocks.R:
--------------------------------------------------------------------------------
1 | test_that("TorchOpTabResNetBlocks works", {
2 |   op = top("tab_resnet_blocks")
3 |   task = tsk("iris")
4 |
5 |   for (i in seq_len(3)) { # NOTE(review): parameters are drawn at random without a seed, so failures are hard to reproduce
6 |     param_vals = list(
7 |       n_blocks = sample(1:3, 1),
8 |       d_main = sample(1:10, 1),
9 |       d_hidden = sample(1:10, 1),
10 |       dropout_first = runif(1),
11 |       dropout_second = runif(1),
12 |       activation = sample(c("relu", "elu"), 1),
13 |       activation_args = list(inplace = sample(c(TRUE, FALSE), 1)),
14 |       bn.momentum = 0.2,
15 |       skip_connection = sample(c(TRUE, FALSE), 1)
16 |     )
17 |     n_features = sample(1:10, 1)
18 |     n_batch = sample(1:3, 1)
19 |     inputs = list(input = torch_randn(n_batch, n_features))
20 |     op$param_set$values = insert_named(op$param_set$values, param_vals)
21 |
22 |     expect_torchop(
23 |       op = op,
24 |       inputs = inputs,
25 |       task = task,
26 |       "nn_tab_resnet_blocks"
27 |     )
28 |
29 |   }
30 | })
31 |
--------------------------------------------------------------------------------
/attic/test_as_dataloader.R:
--------------------------------------------------------------------------------
1 | test_that("as_dataloader works", {
2 |   task = tsk("iris")
3 |   expect_error(as_dataloader(task, device = "cpu", batch_size = 16L, shuffle = TRUE),
4 |     regexp = NA
5 |   )
6 |   dl = as_dataloader(task, device = "cpu", batch_size = 50L)
7 |   batch = dl$.iter()$.next()
8 |   expect_equal(batch$y$shape, 50)
9 | })
10 |
11 | test_that("shuffling works for dataloader", {
12 |   task = tsk("iris")
13 |   dl = as_dataloader(task, device = "cpu", batch_size = 50L, shuffle = TRUE)
14 |   batch = dl$.iter()$.next()
15 |   expect_equal(batch$y$shape, 50)
16 |   expect_true(!all(as.array(batch$y) == 1))
17 | })
18 |
--------------------------------------------------------------------------------
/attic/test_as_dataset.R:
--------------------------------------------------------------------------------
1 | test_that("as_dataset works for tabular data", {
2 |   task = tsk("iris")
3 |   ds = as_dataset(task, row_ids = 1:10)
4 |   expect_true(length(ds) == 10L)
5 |   expect_r6(ds, "dataset")
6 | })
7 |
8 | test_that("as_dataset works for image data", {
9 |   task = tsk("tiny_imagenet")
10 |   ds = as_dataset(task)
11 |   expect_true(inherits(ds, "dataset"))
12 |   batch = get_batch(ds, 16L)
13 |   expect_true(all.equal(names(batch), c("y", "x"))) # NOTE(review): all.equal() returns a character description on mismatch, so expect_equal() would report failures more clearly
14 |   y = batch$y
15 |   x = batch$x
16 |   expect_true(inherits(y, "torch_tensor"))
17 |   expect_true(inherits(x, "torch_tensor"))
18 |   expect_true(x$shape[1] == 16)
19 |   expect_true(y$shape[1] == 16)
20 | })
21 |
--------------------------------------------------------------------------------
/attic/test_dataloader.R:
--------------------------------------------------------------------------------
1 | make_data = function() {
2 |   data.table(
3 |     x_lgl = c(FALSE, TRUE, FALSE),
4 |     x_fct = as.factor(c("l1", "l2", "l1")),
5 |     x_int = 3:1,
6 |     x_num = c(0.5, 0.3, 1.2),
7 |     ..row_id = 1:3,
8 |     y = c(1, -1.2, 0.3)
9 |   )
10 | }
11 |
12 | test_that("cat2tensor works", {
13 |   dat = data.table(
14 |     x_lgl = c(TRUE, FALSE, FALSE),
15 |     x_fct = as.factor(c("l1", "l2", "l1"))
16 |   )
17 |   tensor_expected = torch_tensor(
18 |     matrix(
19 |       c(1L, 1L,
20 |         0L, 2L,
21 |         0L, 1L),
22 |       byrow = TRUE,
23 |       ncol = 2
24 |     )
25 |   )
26 |   tensor = cat2tensor(dat, device = "cpu")
27 |   expect_true(all(as.logical(tensor_expected == tensor)))
28 | })
29 |
30 | test_that("make_dataset works", {
31 |   dat = make_data()
32 |   data_set = make_tabular_dataset(
33 |     data = dat,
34 |     target = "y",
35 |     features = paste0("x_", c("lgl", "fct", "int", "num")),
36 |     device = "cpu"
37 |   )
38 |
39 |   expect_r6(data_set, "dataset")
40 |   expect_true(length(data_set) == 3L)
41 | })
42 |
--------------------------------------------------------------------------------
/attic/test_nn_attention.R:
--------------------------------------------------------------------------------
1 | test_that("nn_ft_attention works", {
2 |   d_token = 10
3 |   n_heads = 1
4 |   n_batch = 32
5 |   bias = TRUE
6 |   n_features = 11
7 |   dropout = 0.2
8 |   initialization = "xavier"
9 |   attention = nn_ft_attention(d_token = d_token, n_heads = n_heads, dropout = dropout,
10 |     bias = bias, initialization = initialization
11 |   )
12 |   query = torch_randn(n_batch, n_features, d_token)
13 |   key = torch_randn(n_batch, n_features, d_token)
14 |   output = attention$forward(query, key)
15 |   expect_equal(output$shape, c(n_batch, n_features, d_token))
16 |
17 | })
18 |
--------------------------------------------------------------------------------
/attic/test_nn_cls.R:
--------------------------------------------------------------------------------
1 | test_that("nn_cls works", {
2 |   n_batch = 16
3 |   d_token = 7
4 |   n_features = 9
5 |   cls = nn_cls(d_token)
6 |   x = torch_empty(n_batch, n_features, d_token)
7 |   y = cls(x)
8 |   expect_true(all(y$shape == c(16, 10, 7))) # n_features + 1 entries along dim 2 after appending the [CLS] token
9 | })
10 |
--------------------------------------------------------------------------------
/attic/test_paramset_activation.R:
--------------------------------------------------------------------------------
1 | test_that("Can construc all paramsets and all parameters are covered", { # NOTE(review): typo "construc" in the test description
2 |   for (act in mlr3torch_activations) {
3 |     param_set = paramsets_activation$get(act)
4 |     constructor = get_activation(act)
5 |     expected_ids = formalArgs(constructor) %??% character(0)
6 |     ids = param_set$ids()
7 |     expect_true(setequal(ids, expected_ids), info = act)
8 |   }
9 | })
10 |
--------------------------------------------------------------------------------
/attic/test_paramset_torchlearner.R:
--------------------------------------------------------------------------------
1 | test_that("paramset_torchlearner has parameters properly documented", {
2 |   param_set = paramset_torchlearner()
3 |
4 |   template = readLines(system.file("./man-roxygen/paramset_torchlearner.R", package = "mlr3torch")) # NOTE(review): man-roxygen/ is a top-level repo directory, not under inst/, so system.file() likely returns "" for an installed package (confirm)
5 |
6 |   pardoc = regmatches(template, regexpr("(?<=\\`)[a-zA-Z][a-zA-Z0-9_]+(?=\\` ::)", template, perl = TRUE))
7 |
8 |   expect_true(length(pardoc) == param_set$length)
9 |   expect_set_equal(pardoc, param_set$ids())
10 | })
11 |
--------------------------------------------------------------------------------
/attic/test_paramsets_image_trafo.R:
--------------------------------------------------------------------------------
1 | test_that("Can construc all paramsets and all parameters
are covered", { 2 | for (act in torch_reflections$image_trafos) { 3 | param_set = paramsets_image_trafo$get(act) 4 | constructor = get_image_trafo(act) 5 | expected_ids = formalArgs(constructor) %??% character(0) 6 | expected_ids = setdiff(expected_ids, "img") 7 | ids = param_set$ids() 8 | expect_true(setequal(ids, expected_ids), info = act) 9 | } 10 | }) 11 | -------------------------------------------------------------------------------- /attic/test_pot.R: -------------------------------------------------------------------------------- 1 | test_that("pot works", { 2 | # for nn_ 3 | obj = pot("linear") 4 | expect_r6(obj, "PipeOpTorchLinear") 5 | expect_true(obj$id == "nn_linear") 6 | # for torch_optimizer 7 | obj = pot("optimizer", "adam") 8 | expect_r6(obj, "PipeOpTorchOptimizer") 9 | expect_true(obj$id == "torch_optimizer") 10 | }) 11 | 12 | test_that("pot gives informative error message", { 13 | expect_error(pot("nn_linear"), regexp = "You probably wanted po(\"nn_", fixed = TRUE) 14 | expect_error(pot("torch_optimizer"), regexp = "You probably wanted po(\"torch_", fixed = TRUE) 15 | expect_error(pot("xyz"), regexp = "PipeOp with neither id torch_xyz or nn_xyz exists.", fixed = TRUE) 16 | }) 17 | -------------------------------------------------------------------------------- /attic/test_tabnet.R: -------------------------------------------------------------------------------- 1 | # Classification ---------------------------------------------------------- 2 | test_that("Learner can be instantiated", { 3 | skip_if_not_installed("tabnet") 4 | 5 | lrn = LearnerClassifTabNet$new() 6 | 7 | lrn$param_set$values$epochs = 10L 8 | lrn$param_set$values$decision_width = NULL 9 | lrn$param_set$values$attention_width = 8L 10 | 11 | expect_learner(lrn) 12 | expect_identical(lrn$param_set$values$epochs, 10L) 13 | expect_identical(lrn$param_set$values$attention_width, 8L) 14 | expect_identical(lrn$param_set$values$decision_width, NULL) 15 | }) 16 | 17 | test_that("autotest", { 
18 | learner = LearnerClassifTabNet$new() 19 | learner$param_set$values$epochs = 3L 20 | expect_learner(learner) 21 | result = run_autotest(learner, exclude = "(feat_single|sanity)", check_replicable = FALSE) 22 | expect_true(result, info = result$error) 23 | }) 24 | 25 | 26 | # Regression -------------------------------------------------------------- 27 | test_that("Learner can be instantiated", { 28 | skip_if_not_installed("tabnet") 29 | 30 | lrn = LearnerRegrTabNet$new() 31 | 32 | lrn$param_set$values$epochs = 10L 33 | lrn$param_set$values$decision_width = NULL 34 | lrn$param_set$values$attention_width = 8L 35 | 36 | expect_learner(lrn) 37 | expect_identical(lrn$param_set$values$epochs, 10L) 38 | expect_identical(lrn$param_set$values$attention_width, 8L) 39 | expect_identical(lrn$param_set$values$decision_width, NULL) 40 | }) 41 | 42 | test_that("autotest", { 43 | learner = LearnerRegrTabNet$new() 44 | learner$param_set$values$epochs = 3L 45 | expect_learner(learner) 46 | result = run_autotest(learner, exclude = "(feat_single|sanity)", check_replicable = FALSE) 47 | expect_true(result, info = result$error) 48 | }) 49 | -------------------------------------------------------------------------------- /attic/test_tuning.R: -------------------------------------------------------------------------------- 1 | test_that("Tuning works", { 2 | task = tsk("mtcars") 3 | }) 4 | -------------------------------------------------------------------------------- /attic/torch_reflections.R: -------------------------------------------------------------------------------- 1 | #' @title Reflections mechanism for torch 2 | #' 3 | #' @details 4 | #' Used to store / extend available hyperparameter levels for options used throughout torch, 5 | #' e.g. the available 'loss' for a given Learner. 6 | #' 7 | #' @format [environment]. 
8 | #' @export 9 | torch_reflections = new.env(parent = emptyenv()) 10 | 11 | torch_reflections$callback_steps = c( 12 | "start", 13 | "before_train_epoch", 14 | "before_train_batch", 15 | "after_train_batch", 16 | "before_valid_epoch", 17 | "before_valid_batch", 18 | "after_valid_batch", 19 | "after_valid_epoch", 20 | "end" 21 | ) 22 | 23 | -------------------------------------------------------------------------------- /attic/tuning-test.R: -------------------------------------------------------------------------------- 1 | ## Regression ---- 2 | library(mlr3) 3 | library(paradox) 4 | library(mlr3torch) 5 | library(mlr3pipelines) 6 | library(mlr3tuning) 7 | 8 | data("kc_housing", package = "mlr3data") 9 | # Drop columns with missings and class POSIXct for simplicity 10 | kc_housing <- kc_housing[, setdiff(names(kc_housing), c("date", "sqft_basement", "yr_renovated"))] 11 | task_regr <- TaskRegr$new("kc_housing", kc_housing, target = "price") 12 | 13 | lrn = LearnerRegrTorchTabnet$new() 14 | 15 | lrn_graph <- po("scale") %>>% 16 | po("learner", learner = lrn) 17 | lrn_graph <- as_learner(lrn_graph) 18 | 19 | lrn_graph$param_set$values <- list( 20 | scale.robust = FALSE, 21 | # verbose = TRUE, 22 | regr.torch.tabnet.epochs = 50, 23 | regr.torch.tabnet.lr_scheduler = "step", 24 | regr.torch.tabnet.device = "cuda" 25 | ) 26 | # Set attention_width to NULL and tune over decision_width, as model authors 27 | # recommend N_d == N_a and {tabnet} sets N_d == N_a if either is NULL 28 | lrn_graph$param_set$values$regr.torch.tabnet.attention_width <- NULL 29 | 30 | # Define search space 31 | search_space <- ps( 32 | #penalty = p_dbl(lower = 0.0001, upper = 0.001), 33 | regr.torch.tabnet.decision_width = p_int(lower = log2(8), upper = log2(64), trafo = function(x) 2^x), 34 | #attention_width = p_int(lower = 8, upper = 64), 35 | regr.torch.tabnet.num_steps = p_int(lower = 3, upper = 10) 36 | ) 37 | 38 | instance <- TuningInstanceSingleCrit$new( 39 | task = task_regr, 40 | 
learner = lrn_graph, 41 | resampling = rsmp("holdout"), 42 | measure = msr("regr.rmse"), 43 | search_space = search_space, 44 | terminator = trm("evals", n_evals = 50) 45 | ) 46 | 47 | tuner <- tnr("grid_search", resolution = 4, batch_size = 4) 48 | 49 | future::plan("multisession", workers = 4) 50 | 51 | tuner$optimize(instance) 52 | 53 | instance$result_learner_param_vals 54 | 55 | saveRDS(instance, here::here("attic/tuning-result.rds")) 56 | -------------------------------------------------------------------------------- /benchmarks/image_loaders/benchmark_image_loaders.R: -------------------------------------------------------------------------------- 1 | library(torch) 2 | library(torchvision) 3 | library(mlr3torch) 4 | library(here) 5 | 6 | library(data.table) 7 | setDTthreads(threads = 1) 8 | 9 | training_metadata = fread(here::here("cache", "ISIC_2020_Training_GroundTruth.csv")) 10 | 11 | # hard-coded cache directory that I use locally 12 | cache_dir = here("cache") 13 | 14 | ds_base_loader = torch::dataset( 15 | initialize = function(n_images) { 16 | self$.metadata = fread(here(cache_dir, "ISIC_2020_Training_GroundTruth.csv"))[1:n_images, ] 17 | self$.path = file.path(here(cache_dir), "train") 18 | }, 19 | .getitem = function(idx) { 20 | force(idx) 21 | 22 | x = torchvision::base_loader(file.path(self$.path, paste0(self$.metadata[idx, ]$image_name, ".jpg"))) 23 | x = torchvision::transform_to_tensor(x) 24 | 25 | return(list(x = x)) 26 | }, 27 | .length = function() { 28 | nrow(self$.metadata) 29 | } 30 | ) 31 | 32 | ds_magick_loader = torch::dataset( 33 | initialize = function(n_images) { 34 | self$.metadata = fread(here(cache_dir, "ISIC_2020_Training_GroundTruth.csv"))[1:n_images, ] 35 | self$.path = file.path(here(cache_dir), "train") 36 | }, 37 | .getitem = function(idx) { 38 | force(idx) 39 | 40 | image_name = self$.metadata[idx, ]$image_name 41 | 42 | x = magick::image_read(file.path(self$.path, paste0(image_name, ".jpg"))) 43 | x = 
torchvision::transform_to_tensor(x) 44 | 45 | return(list(x = x, image_name = image_name)) 46 | }, 47 | .length = function() { 48 | nrow(self$.metadata) 49 | } 50 | ) 51 | 52 | n_images = 10 53 | 54 | ds_base = ds_base_loader(n_images) 55 | ds_magick = ds_magick_loader(n_images) 56 | 57 | bmr = bench::mark( 58 | for (i in 1:n_images) ds_base$.getitem(i), 59 | for (i in 1:n_images) ds_magick$.getitem(i), 60 | memory = FALSE 61 | ) 62 | 63 | print(bmr) -------------------------------------------------------------------------------- /benchmarks/tensor_dataset.R: -------------------------------------------------------------------------------- 1 | devtools::load_all(".") 2 | library(bench) 3 | 4 | task = tsk("spam")$filter(1:2000) 5 | 6 | learner1 = lrn("classif.mlp", 7 | epochs = 20L, 8 | batch_size = 256, 9 | device = "cpu" 10 | ) 11 | learner2 = learner1$clone(deep = TRUE) 12 | learner2$param_set$set_values( 13 | tensor_dataset = TRUE 14 | ) 15 | 16 | f1 = function() learner1$train(task) 17 | f2 = function() learner2$train(task) 18 | 19 | 20 | x = mark( 21 | multi_tensor_dataset = f2(), 22 | no_tensor_dataset = f1(), 23 | check = FALSE 24 | ) 25 | 26 | print(x) 27 | 28 | # 1 tensor_dataset 1.05m 1.05m 0.0158 41.28MB 1.64 1 104 29 | # 2 no_tensor_dataset 1.52m 1.52m 0.0110 3.67GB 0.920 1 84 30 | 31 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## R CMD check results 2 | 3 | 0 errors | 0 warnings | 1 note 4 | 5 | Maintainer: 'Sebastian Fischer ' 6 | 7 | Days since last update: 6 8 | 9 | I am sorry for another upload in such a short timeframe. 10 | However, because I am going on vacation, I would like to submit a new release 11 | with important bugfixes before that and have already discussed this 12 | with Uwe Ligges per e-mail. 
13 | -------------------------------------------------------------------------------- /data-raw/mnist.R: -------------------------------------------------------------------------------- 1 | devtools::load_all() 2 | 3 | ci = col_info(get_private(tsk("mnist")$backend)$.constructor()) 4 | saveRDS(ci, here::here("inst/col_info/mnist.rds")) 5 | -------------------------------------------------------------------------------- /data-raw/tiny_imagenet.R: -------------------------------------------------------------------------------- 1 | devtools::load_all() 2 | 3 | ci = col_info(get_private(tsk("tiny_imagenet")$backend)$.constructor()) 4 | 5 | saveRDS(ci, here::here("inst/col_info/tiny_imagenet.rds")) -------------------------------------------------------------------------------- /info/testdata.md: -------------------------------------------------------------------------------- 1 | The data in `./tests/testthat/assets/nano_cats_vs_dogs` comes from the training set of `torchvision::dogs_vs_cats_dataset()`. 2 | The names were left unchanged. 3 | 4 | The data in `./tests/testthat/assets/nano_mnist` was downloaded from https://github.com/teavanist/MNIST-JPG. 5 | Six random pictures (two each from classes 1, 2, and 3) were taken, and the true label was prepended to each file name. 6 | 7 | The data in `./tests/testthat/assets/nano_imagenet` was created as follows.
8 | It contains 5 "wok" and 5 "torch" images. 9 | 10 | ```{r} 11 | task = tsk("tiny_imagenet") 12 | task = tsk("tiny_imagenet") 13 | 14 | dt = task$data() 15 | dt = dt[class %in% c("wok", "torch"), ] 16 | dt = dt[,.SD[1:5], by = class] 17 | 18 | for (i in seq_len(nrow(dt))) { 19 | path = as.character(dt[i, "image"][[1L]][1L]) 20 | new_name = paste0(dt[i, "class"][[1L]][1L], "_", basename(path)) 21 | system(sprintf("cp %s /home/sebi/mlr/mlr3torch/tests/testthat/assets/nano_imagenet/%s", path, new_name)) 22 | } 23 | ``` 24 | -------------------------------------------------------------------------------- /inst/COPYRIGHTS: -------------------------------------------------------------------------------- 1 | Some parts of the code and documentation were copied from the R package torch, which is distributed under the MIT License. 2 | 3 | # MIT License 4 | 5 | Copyright (c) 2020 Daniel Falbel 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | -------------------------------------------------------------------------------- /inst/col_info/cifar10.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/inst/col_info/cifar10.rds -------------------------------------------------------------------------------- /inst/col_info/cifar100.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/inst/col_info/cifar100.rds -------------------------------------------------------------------------------- /inst/col_info/melanoma.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/inst/col_info/melanoma.rds -------------------------------------------------------------------------------- /inst/col_info/mnist.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/inst/col_info/mnist.rds -------------------------------------------------------------------------------- /inst/col_info/tiny_imagenet.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/inst/col_info/tiny_imagenet.rds -------------------------------------------------------------------------------- /man-roxygen/field_id.R: -------------------------------------------------------------------------------- 1 | #' @field id (`character(1)`)\cr 2 | #' Identifier of the object. 3 | #' Used in tables, plot and text output. 
4 | -------------------------------------------------------------------------------- /man-roxygen/field_label.R: -------------------------------------------------------------------------------- 1 | #' @field label (`character(1)`)\cr 2 | #' Label for this object. 3 | #' Can be used in tables, plot and text output instead of the ID. 4 | -------------------------------------------------------------------------------- /man-roxygen/field_man.R: -------------------------------------------------------------------------------- 1 | #' @field man (`character(1)`)\cr 2 | #' String in the format `[pkg]::[topic]` pointing to a manual page for this object. 3 | -------------------------------------------------------------------------------- /man-roxygen/field_packages.R: -------------------------------------------------------------------------------- 1 | #' @field packages (`character(1)`)\cr 2 | #' Set of required packages. 3 | #' These packages are loaded, but not attached. 4 | -------------------------------------------------------------------------------- /man-roxygen/field_param_set.R: -------------------------------------------------------------------------------- 1 | #' @field param_set ([`ParamSet`][paradox::ParamSet])\cr 2 | #' Set of hyperparameters. 3 | -------------------------------------------------------------------------------- /man-roxygen/field_phash.R: -------------------------------------------------------------------------------- 1 | #' @field phash (`character(1)`)\cr 2 | #' Hash (unique identifier) for this partial object, excluding some components 3 | #' which are varied systematically (e.g. the parameter values). 4 | -------------------------------------------------------------------------------- /man-roxygen/field_task_type.R: -------------------------------------------------------------------------------- 1 | #' @field task_type (`character(1)`)\cr 2 | #' Task type, e.g. `"classif"` or `"regr"`. 
3 | #' 4 | #' For a complete list of possible task types (depending on the loaded packages), 5 | #' see [`mlr_reflections$task_types$type`][mlr3::mlr_reflections]. 6 | -------------------------------------------------------------------------------- /man-roxygen/field_task_types.R: -------------------------------------------------------------------------------- 1 | #' @field task_types (`character()`)\cr 2 | #' The task types this object supports. 3 | -------------------------------------------------------------------------------- /man-roxygen/learner.R: -------------------------------------------------------------------------------- 1 | #' <% task_types_vec <- strsplit(task_types, split = ", ")[[1L]]%> 2 | #' 3 | #' @name <%=paste0("mlr_learners.", name)%> 4 | #' 5 | #' @include LearnerTorch.R 6 | #' 7 | #' @section Dictionary: 8 | #' This [Learner][mlr3::Learner] can be instantiated using the sugar function [`lrn()`][mlr3::lrn]: 9 | #' ``` 10 | #' <%= if ("classif" %in% task_types_vec) paste0("lrn(\"classif.", name, "\", ...)") else ""%> 11 | #' <%= if ("regr" %in% task_types_vec) paste0("lrn(\"regr.", name, "\", ...)") else ""%> 12 | #' ``` 13 | #' 14 | #' @section Properties: 15 | #' `r mlr3torch:::rd_info_learner_torch("<%=name%>", "<%=task_types%>")` 16 | #' @md 17 | #' 18 | #' @family Learner 19 | -------------------------------------------------------------------------------- /man-roxygen/learner_example.R: -------------------------------------------------------------------------------- 1 | #' <% task_types_vec = strsplit(task_types, ", ")[[1]] %> 2 | #' <% param_vals = if (exists("param_vals", inherits = FALSE)) param_vals else character(0) %> 3 | #' <% id = paste0(task_types_vec[1], ".", name)%> 4 | #' <% task_id = default_task_id(lrn(id)) %> 5 | #' 6 | #' @examplesIf torch::torch_is_installed() 7 | #' # Define the Learner and set parameter values 8 | #' <%= sprintf("learner = lrn(\"%s\")", id)%> 9 | #' learner$param_set$set_values( 10 | #' epochs = 1, 
batch_size = 16, device = "cpu"<%= if (length(param_vals)) "," else character()%> 11 | #' <%= if (length(param_vals)) paste0(param_vals, collapse = ",\n ") else character()%> 12 | #' ) 13 | #' 14 | #' # Define a Task 15 | #' <%= sprintf("task = tsk(\"%s\")", task_id)%> 16 | #' 17 | #' # Create train and test set 18 | #' <%= sprintf("ids = partition(task)")%> 19 | #' 20 | #' # Train the learner on the training ids 21 | #' <%= sprintf("learner$train(task, row_ids = ids$train)")%> 22 | #' 23 | #' # Make predictions for the test rows 24 | #' <%= sprintf("predictions = learner$predict(task, row_ids = ids$test)")%> 25 | #' 26 | #' # Score the predictions 27 | #' predictions$score() 28 | -------------------------------------------------------------------------------- /man-roxygen/param_callbacks.R: -------------------------------------------------------------------------------- 1 | #' @param callbacks (`list()` of [`TorchCallback`]s)\cr 2 | #' The callbacks used during training. 3 | #' Must have unique ids. 4 | #' They are executed in the order in which they are provided 5 | -------------------------------------------------------------------------------- /man-roxygen/param_dataset_shapes.R: -------------------------------------------------------------------------------- 1 | #' @param dataset_shapes (named `list()` of (`integer()` or `NULL`))\cr 2 | #' The shapes of the output. 3 | #' Names are the elements of the list returned by the dataset. 4 | #' If the shape is not `NULL` (unknown, e.g. for images of different sizes) the first dimension must be `NA` to 5 | #' indicate the batch dimension. 6 | -------------------------------------------------------------------------------- /man-roxygen/param_feature_types.R: -------------------------------------------------------------------------------- 1 | #' @param feature_types (`character()`)\cr 2 | #' The feature types. 
3 | #' See [`mlr_reflections$task_feature_types`][mlr3::mlr_reflections] for available values. 4 | #' Additionally, `"lazy_tensor"` is supported. 5 | -------------------------------------------------------------------------------- /man-roxygen/param_id.R: -------------------------------------------------------------------------------- 1 | #' @param id (`character(1)`)\cr 2 | #' The id of the new object. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_label.R: -------------------------------------------------------------------------------- 1 | #' @param label (`character(1)`)\cr 2 | #' Label for the new instance. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_loss.R: -------------------------------------------------------------------------------- 1 | #' @param loss ([`TorchLoss`])\cr 2 | #' The loss to use for training. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_man.R: -------------------------------------------------------------------------------- 1 | #' @param man (`character(1)`)\cr 2 | #' String in the format `[pkg]::[topic]` pointing to a manual page for this object. 3 | #' The referenced help page can be opened via method `$help()`. 4 | -------------------------------------------------------------------------------- /man-roxygen/param_module_generator.R: -------------------------------------------------------------------------------- 1 | #' @param module_generator (`nn_module_generator`)\cr 2 | #' The torch module generator. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_optimizer.R: -------------------------------------------------------------------------------- 1 | #' @param optimizer ([`TorchOptimizer`])\cr 2 | #' The torch optimizer.
3 | -------------------------------------------------------------------------------- /man-roxygen/param_packages.R: -------------------------------------------------------------------------------- 1 | #' @param packages (`character()`)\cr 2 | #' The R packages this object depends on. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_param_set.R: -------------------------------------------------------------------------------- 1 | #' @param param_set ([`ParamSet`][paradox::ParamSet])\cr 2 | #' The parameter set. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_param_vals.R: -------------------------------------------------------------------------------- 1 | #' @param param_vals (named `list()`)\cr 2 | #' Parameter values to be set after construction. 3 | -------------------------------------------------------------------------------- /man-roxygen/param_predict_types.R: -------------------------------------------------------------------------------- 1 | #' @param predict_types (`character()`)\cr 2 | #' The predict types. 3 | #' See [`mlr_reflections$learner_predict_types`][mlr3::mlr_reflections] for available values. 4 | -------------------------------------------------------------------------------- /man-roxygen/param_properties.R: -------------------------------------------------------------------------------- 1 | #' @param properties (`character()`)\cr 2 | #' The properties of the object. 3 | #' See [`mlr_reflections$learner_properties`][mlr3::mlr_reflections] for available values. 4 | -------------------------------------------------------------------------------- /man-roxygen/param_task_type.R: -------------------------------------------------------------------------------- 1 | #' @param task_type (`character(1)`)\cr 2 | #' The task type. 
3 | -------------------------------------------------------------------------------- /man-roxygen/params_learner.R: -------------------------------------------------------------------------------- 1 | #' @param task_type (`character(1)`)\cr 2 | #' The task type, either `"classif"` or `"regr"`. 3 | #' @param optimizer ([`TorchOptimizer`])\cr 4 | #' The optimizer to use for training. 5 | #' Per default, *adam* is used. 6 | #' @param loss ([`TorchLoss`])\cr 7 | #' The loss used to train the network. 8 | #' Per default, *mse* is used for regression and *cross_entropy* for classification. 9 | #' @param callbacks (`list()` of [`TorchCallback`]s)\cr 10 | #' The callbacks. Must have unique ids. 11 | -------------------------------------------------------------------------------- /man-roxygen/params_pipelines.R: -------------------------------------------------------------------------------- 1 | #' @param id (`character(1)`)\cr 2 | #' Identifier of the resulting object. 3 | #' @param param_vals (`list()`)\cr 4 | #' List of hyperparameter settings, overwriting the hyperparameter settings that would 5 | #' otherwise be set during construction. 6 | -------------------------------------------------------------------------------- /man-roxygen/pipeop_torch.R: -------------------------------------------------------------------------------- 1 | #' @name <%=paste0("mlr_pipeops_", id)%> 2 | #' 3 | #' @section State: 4 | #' The state is the value calculated by the public method `$shapes_out()`. 5 | #' @family PipeOps 6 | -------------------------------------------------------------------------------- /man-roxygen/pipeop_torch_channels_default.R: -------------------------------------------------------------------------------- 1 | #' @section Input and Output Channels: 2 | #' One input channel called `"input"` and one output channel called `"output"`. 3 | #' For an explanation see [`PipeOpTorch`].
4 | -------------------------------------------------------------------------------- /man-roxygen/pipeop_torch_example.R: -------------------------------------------------------------------------------- 1 | #' <% param_vals = if (exists("param_vals", inherits = FALSE)) paste0(", ", param_vals)%> 2 | #' 3 | #' @examplesIf torch::torch_is_installed() 4 | #' # Construct the PipeOp 5 | #' pipeop = po("<%=id %>"<%=param_vals%>) 6 | #' pipeop 7 | #' # The available parameters 8 | #' pipeop$param_set 9 | 10 | 11 | 12 | # Maybe finish this at some point 13 | # @examplesIf torch::torch_is_installed() 14 | # md = ModelDescriptor( 15 | # graph = 16 | # ) 17 | # 18 | # # The graph of the model descriptor is modified in-place 19 | # graph = md$graph 20 | # 21 | # 22 | # # These are the output shapes generated by the pipeop for the given task 23 | # pipeop$shapes_out(md$pointer_shape) 24 | # 25 | # # Apply the pipeop to the model descriptor 26 | # mdout = pipeop$train(list(md))[[1L]] 27 | # mdout 28 | # 29 | # # A PipeOpModule was generated by the PipeOpTorch and added to the graph 30 | # graph$edges 31 | # graph$pipeops$<%=id%>$module 32 | # 33 | # # The pointer is updated to the output channel of the applied pipeop 34 | # c(pipeop$id, pipeop$output$name) 35 | # mdout$pointer 36 | # 37 | # # The shapes are equal to those calculated by `$shapes_out()` above 38 | # mdout$pointer_shape 39 | -------------------------------------------------------------------------------- /man-roxygen/pipeop_torch_state_default.R: -------------------------------------------------------------------------------- 1 | #' @section State: 2 | #' The state is the value calculated by the public method `shapes_out()`. 
3 | -------------------------------------------------------------------------------- /man-roxygen/preprocess_torchvision.R: -------------------------------------------------------------------------------- 1 | #' <% pipeop = po(id) %> 2 | #' @aliases <%= class(pipeop)[[1L]] %> 3 | #' @usage NULL 4 | #' @name mlr_pipeops_<%= id %> 5 | #' @rdname mlr_pipeops_<%= id %> 6 | #' @format [`R6Class`][R6::R6Class] inheriting from [`PipeOpTaskPreprocTorch`]. 7 | #' @section Construction: 8 | #' ```r 9 | #' po("<%= id%>") 10 | #' ``` 11 | #' 12 | #' @description 13 | #' Calls [`<%= paste0("torchvision::", gsub("^(augment|trafo)", "transform", id)) %>`], 14 | #' see there for more information on the parameters. 15 | #' <%= if (pipeop$rowwise) "The preprocessing is applied to each element of a batch individually." else "The preprocessing is applied to the whole batch."%> 16 | #' 17 | #' @section Parameters: 18 | #' `r mlr3misc::rd_info(po("<%= id%>")$param_set)` 19 | -------------------------------------------------------------------------------- /man-roxygen/pretrained.R: -------------------------------------------------------------------------------- 1 | #' @section Architecture: 2 | #' Calls [<%= pkg[1] %>::<%= model %>] from package \CRANpkg{<%= pkg %>} to load the 3 | #' network architecture. If the parameter `pretrained == TRUE`, the learner is initialized 4 | #' with pretrained weights and the final layer is replaced with a simple linear 5 | #' output layer tailored to the task when the method `$train()` is called. 6 | -------------------------------------------------------------------------------- /man-roxygen/task_download.R: -------------------------------------------------------------------------------- 1 | #' @section Download: 2 | #' The [task][mlr3::Task]'s backend is a [`DataBackendLazy`] which will download the data once it is requested. 3 | #' Other meta-data is already available before that.
4 | #' You can cache these datasets by setting the `mlr3torch.cache` option to `TRUE` or to a specific path to be used 5 | #' as the cache directory. 6 | -------------------------------------------------------------------------------- /man/PipeOpPreprocTorchTrafoNop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_nop} 4 | \alias{mlr_pipeops_trafo_nop} 5 | \alias{PipeOpPreprocTorchTrafoNop} 6 | \title{No Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Does nothing. 12 | } 13 | -------------------------------------------------------------------------------- /man/PipeOpPreprocTorchTrafoReshape.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_reshape} 4 | \alias{mlr_pipeops_trafo_reshape} 5 | \alias{PipeOpPreprocTorchTrafoReshape} 6 | \title{Reshaping Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Reshapes the tensor according to the parameter \code{shape}, by calling \code{torch_reshape()}. 12 | This preprocessing function is applied batch-wise. 13 | } 14 | \section{Parameters}{ 15 | 16 | \itemize{ 17 | \item \code{shape} :: \code{integer()}\cr 18 | The desired output shape. The first dimension is the batch dimension and should usually be \code{-1}. 
19 | } 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/Select.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/Select.R 3 | \name{Select} 4 | \alias{Select} 5 | \alias{select_all} 6 | \alias{select_none} 7 | \alias{select_grep} 8 | \alias{select_name} 9 | \alias{select_invert} 10 | \title{Selector Functions for Character Vectors} 11 | \usage{ 12 | select_all() 13 | 14 | select_none() 15 | 16 | select_grep(pattern, ignore.case = FALSE, perl = FALSE, fixed = FALSE) 17 | 18 | select_name(param_names, assert_present = TRUE) 19 | 20 | select_invert(select) 21 | } 22 | \arguments{ 23 | \item{pattern}{See \code{grep()}} 24 | 25 | \item{ignore.case}{See \code{grep()}} 26 | 27 | \item{perl}{See \code{grep()}} 28 | 29 | \item{fixed}{See \code{grep()}} 30 | 31 | \item{param_names}{The names of the parameters that you want to select} 32 | 33 | \item{assert_present}{Whether to check that \code{param_names} is a subset of the full vector of names} 34 | 35 | \item{select}{A \code{Select}} 36 | } 37 | \description{ 38 | A \code{\link{Select}} function subsets a character vector. They are used by the callback \code{CallbackSetUnfreeze} to select parameters to freeze or unfreeze during training. 
39 | } 40 | \section{Functions}{ 41 | \itemize{ 42 | \item \code{select_all()}: \code{select_all} selects all elements 43 | 44 | \item \code{select_none()}: \code{select_none} selects no elements 45 | 46 | \item \code{select_grep()}: \code{select_grep} selects elements with names matching a regular expression 47 | 48 | \item \code{select_name()}: \code{select_name} selects elements with names matching the given names 49 | 50 | \item \code{select_invert()}: \code{select_invert} selects the elements NOT selected by the given selector 51 | 52 | }} 53 | \examples{ 54 | select_all()(c("a", "b")) 55 | select_none()(c("a", "b")) 56 | select_grep("b$")(c("ab", "ac")) 57 | select_name("a")(c("a", "b")) 58 | select_invert(select_all())(c("a", "b")) 59 | } 60 | -------------------------------------------------------------------------------- /man/as_data_descriptor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/DataDescriptor.R 3 | \name{as_data_descriptor} 4 | \alias{as_data_descriptor} 5 | \title{Convert to Data Descriptor} 6 | \usage{ 7 | as_data_descriptor(x, dataset_shapes, ...) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to convert.} 12 | 13 | \item{dataset_shapes}{(named \code{list()} of (\code{integer()} or \code{NULL}))\cr 14 | The shapes of the output. 15 | Names are the elements of the list returned by the dataset. 16 | If the shape is not \code{NULL} (unknown, e.g. for images of different sizes) the first dimension must be \code{NA} to 17 | indicate the batch dimension.} 18 | 19 | \item{...}{(any)\cr 20 | Further arguments passed to the \code{\link{DataDescriptor}} constructor.} 21 | } 22 | \description{ 23 | Converts the input to a \code{\link{DataDescriptor}}. 
24 | } 25 | \examples{ 26 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 27 | ds = dataset("example", 28 | initialize = function() self$iris = iris[, -5], 29 | .getitem = function(i) list(x = torch_tensor(as.numeric(self$iris[i, ]))), 30 | .length = function() nrow(self$iris) 31 | )() 32 | as_data_descriptor(ds, list(x = c(NA, 4L))) 33 | 34 | # if the dataset has a .getbatch method, the shapes are inferred 35 | ds2 = dataset("example", 36 | initialize = function() self$iris = iris[, -5], 37 | .getbatch = function(i) list(x = torch_tensor(as.matrix(self$iris[i, ]))), 38 | .length = function() nrow(self$iris) 39 | )() 40 | as_data_descriptor(ds2) 41 | \dontshow{\}) # examplesIf} 42 | } 43 | -------------------------------------------------------------------------------- /man/as_lr_scheduler.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CallbackSetLRScheduler.R 3 | \name{as_lr_scheduler} 4 | \alias{as_lr_scheduler} 5 | \title{Convert to CallbackSetLRScheduler} 6 | \usage{ 7 | as_lr_scheduler(x, step_on_epoch) 8 | } 9 | \arguments{ 10 | \item{x}{(\code{function})\cr 11 | The \code{torch} scheduler generator defined using \code{torch::lr_scheduler()}.} 12 | 13 | \item{step_on_epoch}{(\code{logical(1)})\cr 14 | Whether the scheduler steps after every epoch} 15 | } 16 | \description{ 17 | Convert a \code{torch} scheduler generator to a \code{CallbackSetLRScheduler}. 
18 | } 19 | -------------------------------------------------------------------------------- /man/as_torch_callback.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchCallback.R 3 | \name{as_torch_callback} 4 | \alias{as_torch_callback} 5 | \title{Convert to a TorchCallback} 6 | \usage{ 7 | as_torch_callback(x, clone = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to be converted.} 12 | 13 | \item{clone}{(\code{logical(1)})\cr 14 | Whether to make a deep clone.} 15 | 16 | \item{...}{(any)\cr 17 | Additional arguments} 18 | } 19 | \value{ 20 | \code{\link{TorchCallback}}. 21 | } 22 | \description{ 23 | Converts an object to a \code{\link{TorchCallback}}. 24 | } 25 | \seealso{ 26 | Other Callback: 27 | \code{\link{TorchCallback}}, 28 | \code{\link{as_torch_callbacks}()}, 29 | \code{\link{callback_set}()}, 30 | \code{\link{mlr3torch_callbacks}}, 31 | \code{\link{mlr_callback_set}}, 32 | \code{\link{mlr_callback_set.checkpoint}}, 33 | \code{\link{mlr_callback_set.progress}}, 34 | \code{\link{mlr_callback_set.tb}}, 35 | \code{\link{mlr_callback_set.unfreeze}}, 36 | \code{\link{mlr_context_torch}}, 37 | \code{\link{t_clbk}()}, 38 | \code{\link{torch_callback}()} 39 | } 40 | \concept{Callback} 41 | -------------------------------------------------------------------------------- /man/as_torch_callbacks.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchCallback.R 3 | \name{as_torch_callbacks} 4 | \alias{as_torch_callbacks} 5 | \title{Convert to a list of Torch Callbacks} 6 | \usage{ 7 | as_torch_callbacks(x, clone, ...) 
8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to convert.} 12 | 13 | \item{clone}{(\code{logical(1)})\cr 14 | Whether to create a deep clone.} 15 | 16 | \item{...}{(any)\cr 17 | Additional arguments.} 18 | } 19 | \value{ 20 | \code{list()} of \code{\link{TorchCallback}}s 21 | } 22 | \description{ 23 | Converts an object to a list of \code{\link{TorchCallback}}. 24 | } 25 | \seealso{ 26 | Other Callback: 27 | \code{\link{TorchCallback}}, 28 | \code{\link{as_torch_callback}()}, 29 | \code{\link{callback_set}()}, 30 | \code{\link{mlr3torch_callbacks}}, 31 | \code{\link{mlr_callback_set}}, 32 | \code{\link{mlr_callback_set.checkpoint}}, 33 | \code{\link{mlr_callback_set.progress}}, 34 | \code{\link{mlr_callback_set.tb}}, 35 | \code{\link{mlr_callback_set.unfreeze}}, 36 | \code{\link{mlr_context_torch}}, 37 | \code{\link{t_clbk}()}, 38 | \code{\link{torch_callback}()} 39 | 40 | Other Torch Descriptor: 41 | \code{\link{TorchCallback}}, 42 | \code{\link{TorchDescriptor}}, 43 | \code{\link{TorchLoss}}, 44 | \code{\link{TorchOptimizer}}, 45 | \code{\link{as_torch_loss}()}, 46 | \code{\link{as_torch_optimizer}()}, 47 | \code{\link{mlr3torch_losses}}, 48 | \code{\link{mlr3torch_optimizers}}, 49 | \code{\link{t_clbk}()}, 50 | \code{\link{t_loss}()}, 51 | \code{\link{t_opt}()} 52 | } 53 | \concept{Callback} 54 | \concept{Torch Descriptor} 55 | -------------------------------------------------------------------------------- /man/as_torch_loss.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchLoss.R 3 | \name{as_torch_loss} 4 | \alias{as_torch_loss} 5 | \title{Convert to TorchLoss} 6 | \usage{ 7 | as_torch_loss(x, clone = FALSE, ...) 
8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to convert to a \code{\link{TorchLoss}}.} 12 | 13 | \item{clone}{(\code{logical(1)})\cr 14 | Whether to make a deep clone.} 15 | 16 | \item{...}{(any)\cr 17 | Additional arguments. 18 | Currently used to pass additional constructor arguments to \code{\link{TorchLoss}} for objects of type \code{nn_loss}.} 19 | } 20 | \value{ 21 | \code{\link{TorchLoss}}. 22 | } 23 | \description{ 24 | Converts an object to a \code{\link{TorchLoss}}. 25 | } 26 | \seealso{ 27 | Other Torch Descriptor: 28 | \code{\link{TorchCallback}}, 29 | \code{\link{TorchDescriptor}}, 30 | \code{\link{TorchLoss}}, 31 | \code{\link{TorchOptimizer}}, 32 | \code{\link{as_torch_callbacks}()}, 33 | \code{\link{as_torch_optimizer}()}, 34 | \code{\link{mlr3torch_losses}}, 35 | \code{\link{mlr3torch_optimizers}}, 36 | \code{\link{t_clbk}()}, 37 | \code{\link{t_loss}()}, 38 | \code{\link{t_opt}()} 39 | } 40 | \concept{Torch Descriptor} 41 | -------------------------------------------------------------------------------- /man/as_torch_optimizer.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchOptimizer.R 3 | \name{as_torch_optimizer} 4 | \alias{as_torch_optimizer} 5 | \title{Convert to TorchOptimizer} 6 | \usage{ 7 | as_torch_optimizer(x, clone = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to convert to a \code{\link{TorchOptimizer}}.} 12 | 13 | \item{clone}{(\code{logical(1)})\cr 14 | Whether to make a deep clone. Default is \code{FALSE}.} 15 | 16 | \item{...}{(any)\cr 17 | Additional arguments. 18 | Currently used to pass additional constructor arguments to \code{\link{TorchOptimizer}} for objects of type 19 | \code{torch_optimizer_generator}.} 20 | } 21 | \value{ 22 | \code{\link{TorchOptimizer}} 23 | } 24 | \description{ 25 | Converts an object to a \code{\link{TorchOptimizer}}. 
26 | } 27 | \seealso{ 28 | Other Torch Descriptor: 29 | \code{\link{TorchCallback}}, 30 | \code{\link{TorchDescriptor}}, 31 | \code{\link{TorchLoss}}, 32 | \code{\link{TorchOptimizer}}, 33 | \code{\link{as_torch_callbacks}()}, 34 | \code{\link{as_torch_loss}()}, 35 | \code{\link{mlr3torch_losses}}, 36 | \code{\link{mlr3torch_optimizers}}, 37 | \code{\link{t_clbk}()}, 38 | \code{\link{t_loss}()}, 39 | \code{\link{t_opt}()} 40 | } 41 | \concept{Torch Descriptor} 42 | -------------------------------------------------------------------------------- /man/assert_lazy_tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lazy_tensor.R 3 | \name{assert_lazy_tensor} 4 | \alias{assert_lazy_tensor} 5 | \title{Assert Lazy Tensor} 6 | \usage{ 7 | assert_lazy_tensor(x) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to check.} 12 | } 13 | \description{ 14 | Asserts whether something is a lazy tensor. 15 | } 16 | -------------------------------------------------------------------------------- /man/auto_device.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils.R 3 | \name{auto_device} 4 | \alias{auto_device} 5 | \title{Auto Device} 6 | \usage{ 7 | auto_device(device = NULL) 8 | } 9 | \arguments{ 10 | \item{device}{(\code{character(1)})\cr 11 | The device. If not \code{NULL}, is returned as is.} 12 | } 13 | \description{ 14 | First tries cuda, then cpu. 
15 | } 16 | -------------------------------------------------------------------------------- /man/batchgetter_categ.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/task_dataset.R 3 | \name{batchgetter_categ} 4 | \alias{batchgetter_categ} 5 | \title{Batchgetter for Categorical data} 6 | \usage{ 7 | batchgetter_categ(data, ...) 8 | } 9 | \arguments{ 10 | \item{data}{(\code{data.table})\cr 11 | \code{data.table} to be converted to a \code{tensor}.} 12 | 13 | \item{...}{(any)\cr 14 | Unused.} 15 | } 16 | \description{ 17 | Converts a data frame of categorical data into a long tensor by converting the data to integers. 18 | No input checks are performed. 19 | } 20 | -------------------------------------------------------------------------------- /man/batchgetter_num.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/task_dataset.R 3 | \name{batchgetter_num} 4 | \alias{batchgetter_num} 5 | \title{Batchgetter for Numeric Data} 6 | \usage{ 7 | batchgetter_num(data, ...) 8 | } 9 | \arguments{ 10 | \item{data}{(\code{data.table()})\cr 11 | \code{data.table} to be converted to a \code{tensor}.} 12 | 13 | \item{...}{(any)\cr 14 | Unused.} 15 | } 16 | \description{ 17 | Converts a data frame of numeric data into a float tensor by calling \code{as.matrix()}. 
18 | No input checks are performed. 19 | } 20 | -------------------------------------------------------------------------------- /man/cross_entropy.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchLoss.R 3 | \name{cross_entropy} 4 | \alias{cross_entropy} 5 | \title{Cross Entropy Loss} 6 | \description{ 7 | The \code{cross_entropy} loss function selects the multi-class (\code{\link[torch:nn_cross_entropy_loss]{nn_cross_entropy_loss}}) 8 | or binary (\code{\link[torch:nn_bce_with_logits_loss]{nn_bce_with_logits_loss}}) cross entropy 9 | loss based on the number of classes. 10 | Because of this, there is a slight reparameterization of the loss arguments, see \emph{Parameters}. 11 | } 12 | \section{Parameters}{ 13 | 14 | \itemize{ 15 | \item \code{class_weight}:: \code{\link[torch:torch_tensor]{torch_tensor}}\cr 16 | The class weights. For multi-class problems, this must be a \code{torch_tensor} of length \code{num_classes} 17 | (and is passed as argument \code{weight} to \code{\link[torch:nn_cross_entropy_loss]{nn_cross_entropy_loss}}). 18 | For binary problems, this must be a scalar (and is passed as argument \code{pos_weight} to 19 | \code{\link[torch:nn_bce_with_logits_loss]{nn_bce_with_logits_loss}}). 20 | } 21 | \itemize{ 22 | \item \code{ignore_index}:: \code{integer(1)}\cr 23 | Index of the class to ignore; this class does not contribute to the gradient. 24 | This is only available for multi-class loss. 25 | \item \code{reduction} :: \code{character(1)}\cr 26 | The reduction to apply. It is either \code{"mean"} or \code{"sum"} and is passed as argument \code{reduction} 27 | to either loss function. The default is \code{"mean"}.
28 | } 29 | } 30 | 31 | \examples{ 32 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 33 | loss = t_loss("cross_entropy") 34 | # multi-class 35 | multi_ce = loss$generate(tsk("iris")) 36 | multi_ce 37 | 38 | # binary 39 | binary_ce = loss$generate(tsk("sonar")) 40 | binary_ce 41 | \dontshow{\}) # examplesIf} 42 | } 43 | -------------------------------------------------------------------------------- /man/equals-.lazy_tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lazy_tensor.R 3 | \name{==.lazy_tensor} 4 | \alias{==.lazy_tensor} 5 | \title{Compare lazy tensors} 6 | \usage{ 7 | \method{==}{lazy_tensor}(x, y) 8 | } 9 | \arguments{ 10 | \item{x, y}{(\code{\link{lazy_tensor}})\cr 11 | Values to compare.} 12 | } 13 | \description{ 14 | Compares lazy tensors using their indices and the data descriptor's hash. 15 | This means that if two \code{\link{lazy_tensor}}s: 16 | \itemize{ 17 | \item are equal: they will materialize to the same tensors. 18 | \item are unequal: they might materialize to the same tensors.
19 | } 20 | } 21 | \keyword{internal} 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-archived.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclearchivedarchived -------------------------------------------------------------------------------- /man/figures/lifecycle-defunct.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledefunctdefunct -------------------------------------------------------------------------------- /man/figures/lifecycle-deprecated.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledeprecateddeprecated -------------------------------------------------------------------------------- /man/figures/lifecycle-experimental.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycleexperimentalexperimental -------------------------------------------------------------------------------- /man/figures/lifecycle-maturing.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclematuringmaturing -------------------------------------------------------------------------------- /man/figures/lifecycle-questioning.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclequestioningquestioning -------------------------------------------------------------------------------- /man/figures/lifecycle-stable.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclestablestable -------------------------------------------------------------------------------- /man/figures/lifecycle-superseded.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclesupersededsuperseded 
-------------------------------------------------------------------------------- /man/figures/logo_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/man/figures/logo_old.png -------------------------------------------------------------------------------- /man/infer_shapes.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/shape.R 3 | \name{infer_shapes} 4 | \alias{infer_shapes} 5 | \title{Infer Shapes} 6 | \usage{ 7 | infer_shapes(shapes_in, param_vals, output_names, fn, rowwise, id) 8 | } 9 | \arguments{ 10 | \item{shapes_in}{(\code{list()})\cr 11 | A list of shapes of the input tensors.} 12 | 13 | \item{param_vals}{(\code{list()})\cr 14 | A list of named parameters for the function.} 15 | 16 | \item{output_names}{(\code{character()})\cr 17 | The names of the output tensors.} 18 | 19 | \item{fn}{(\verb{function()})\cr 20 | The function to infer the shapes for.} 21 | 22 | \item{rowwise}{(\code{logical(1)})\cr 23 | Whether the function is rowwise.} 24 | 25 | \item{id}{(\code{character(1)})\cr 26 | The id of the PipeOp (for error messages).} 27 | } 28 | \value{ 29 | (\code{list()})\cr 30 | A list of shapes of the output tensors. 31 | } 32 | \description{ 33 | Infer the shapes of the output of a function based on the shapes of the input. 34 | This is done as follows: 35 | \enumerate{ 36 | \item All \code{NA}s are replaced with values \code{1}, \code{2}, \code{3}. 37 | \item Three tensors are generated for the three shapes of step 1. 38 | \item The function is called on these three tensors and the shapes are calculated. 39 | \item If: 40 | \itemize{ 41 | \item the number of dimensions varies, an error is thrown. 
42 | \item the number of dimensions is the same, values are set to \code{NA} if the dimension is varying 43 | between the three tensors and otherwise set to the unique value. 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /man/ingress_categ.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchIngress.R 3 | \name{ingress_categ} 4 | \alias{ingress_categ} 5 | \title{Ingress Token for Categorical Features} 6 | \usage{ 7 | ingress_categ(shape = NULL) 8 | } 9 | \arguments{ 10 | \item{shape}{(\code{integer()} or \code{NULL})\cr 11 | Shape that \code{batchgetter} will produce. Batch-dimension should be included as \code{NA}.} 12 | } 13 | \value{ 14 | \code{\link{TorchIngressToken}} 15 | } 16 | \description{ 17 | Represents an entry point representing a tensor containing all categorical (\code{factor()}, \code{ordered()}, \code{logical()}) 18 | features of a task. 19 | } 20 | -------------------------------------------------------------------------------- /man/ingress_ltnsr.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchIngress.R 3 | \name{ingress_ltnsr} 4 | \alias{ingress_ltnsr} 5 | \title{Ingress Token for Lazy Tensor Feature} 6 | \usage{ 7 | ingress_ltnsr(feature_name = NULL, shape = NULL) 8 | } 9 | \arguments{ 10 | \item{feature_name}{(\code{character(1)})\cr 11 | Which lazy tensor feature to select if there is more than one.} 12 | 13 | \item{shape}{(\code{integer()} or \code{NULL})\cr 14 | Shape that \code{batchgetter} will produce. 
Batch-dimension should be included as \code{NA}.} 15 | } 16 | \value{ 17 | \code{\link{TorchIngressToken}} 18 | } 19 | \description{ 20 | Represents an entry point representing a tensor containing a single lazy tensor feature. 21 | } 22 | -------------------------------------------------------------------------------- /man/ingress_num.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchIngress.R 3 | \name{ingress_num} 4 | \alias{ingress_num} 5 | \title{Ingress Token for Numeric Features} 6 | \usage{ 7 | ingress_num(shape = NULL) 8 | } 9 | \arguments{ 10 | \item{shape}{(\code{integer()} or \code{NULL})\cr 11 | Shape that \code{batchgetter} will produce. Batch-dimension should be included as \code{NA}.} 12 | } 13 | \value{ 14 | \code{\link{TorchIngressToken}} 15 | } 16 | \description{ 17 | Represents an entry point representing a tensor containing all numeric (\code{integer()} and \code{double()}) 18 | features of a task. 19 | } 20 | -------------------------------------------------------------------------------- /man/is_lazy_tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lazy_tensor.R 3 | \name{is_lazy_tensor} 4 | \alias{is_lazy_tensor} 5 | \title{Check for lazy tensor} 6 | \usage{ 7 | is_lazy_tensor(x) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | Object to check.} 12 | } 13 | \description{ 14 | Checks whether an object is a lazy tensor. 
15 | } 16 | -------------------------------------------------------------------------------- /man/lazy_shape.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lazy_tensor.R 3 | \name{lazy_shape} 4 | \alias{lazy_shape} 5 | \title{Shape of Lazy Tensor} 6 | \usage{ 7 | lazy_shape(x) 8 | } 9 | \arguments{ 10 | \item{x}{(\code{\link{lazy_tensor}})\cr 11 | Lazy tensor.} 12 | } 13 | \value{ 14 | (\code{integer()} or \code{NULL}) 15 | } 16 | \description{ 17 | Shape of a lazy tensor. Might be \code{NULL} if the shape is not known or varies between rows. 18 | Batch dimension is always \code{NA}. 19 | } 20 | \examples{ 21 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 22 | lt = as_lazy_tensor(1:10) 23 | lazy_shape(lt) 24 | lt = as_lazy_tensor(matrix(1:10, nrow = 2)) 25 | lazy_shape(lt) 26 | \dontshow{\}) # examplesIf} 27 | } 28 | -------------------------------------------------------------------------------- /man/lazy_tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lazy_tensor.R 3 | \name{lazy_tensor} 4 | \alias{lazy_tensor} 5 | \title{Create a lazy tensor} 6 | \usage{ 7 | lazy_tensor(data_descriptor = NULL, ids = NULL) 8 | } 9 | \arguments{ 10 | \item{data_descriptor}{(\code{\link{DataDescriptor}} or \code{NULL})\cr 11 | The data descriptor or \code{NULL} for a lazy tensor of length 0.} 12 | 13 | \item{ids}{(\code{integer()})\cr 14 | The elements of the \code{data_descriptor} to be included in the lazy tensor.} 15 | } 16 | \description{ 17 | Create a lazy tensor.
18 | } 19 | \examples{ 20 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 21 | ds = dataset("example", 22 | initialize = function() self$iris = iris[, -5], 23 | .getitem = function(i) list(x = torch_tensor(as.numeric(self$iris[i, ]))), 24 | .length = function() nrow(self$iris) 25 | )() 26 | dd = as_data_descriptor(ds, list(x = c(NA, 4L))) 27 | lt = as_lazy_tensor(dd) 28 | \dontshow{\}) # examplesIf} 29 | } 30 | -------------------------------------------------------------------------------- /man/mlr3torch-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/zzz.R 3 | \docType{package} 4 | \name{mlr3torch-package} 5 | \alias{mlr3torch} 6 | \alias{mlr3torch-package} 7 | \title{mlr3torch: Deep Learning with 'mlr3'} 8 | \description{ 9 | Deep Learning library that extends the mlr3 framework by building upon the 'torch' package. It allows to conveniently build, train, and evaluate deep learning models without having to worry about low level details. Custom architectures can be created using the graph language defined in 'mlr3pipelines'. 10 | } 11 | \section{Options}{ 12 | 13 | \itemize{ 14 | \item \code{mlr3torch.cache}: 15 | Whether to cache the downloaded data (\code{TRUE}) or not (\code{FALSE}, default). 16 | This can also be set to a specific folder on the file system to be used as the cache directory. 
17 | } 18 | } 19 | 20 | \seealso{ 21 | Useful links: 22 | \itemize{ 23 | \item \url{https://mlr3torch.mlr-org.com/} 24 | \item \url{https://github.com/mlr-org/mlr3torch/} 25 | \item Report bugs at \url{https://github.com/mlr-org/mlr3torch/issues} 26 | } 27 | 28 | } 29 | \author{ 30 | \strong{Maintainer}: Sebastian Fischer \email{sebf.fischer@gmail.com} (\href{https://orcid.org/0000-0002-9609-3197}{ORCID}) 31 | 32 | Authors: 33 | \itemize{ 34 | \item Martin Binder \email{mlr.developer@mb706.com} 35 | } 36 | 37 | Other contributors: 38 | \itemize{ 39 | \item Bernd Bischl \email{bernd_bischl@gmx.net} (\href{https://orcid.org/0000-0001-6002-6980}{ORCID}) [contributor] 40 | \item Lukas Burk \email{github@quantenbrot.de} (\href{https://orcid.org/0000-0001-7528-3795}{ORCID}) [contributor] 41 | \item Florian Pfisterer \email{pfistererf@googlemail.com} (\href{https://orcid.org/0000-0001-8867-762X}{ORCID}) [contributor] 42 | \item Carson Zhang \email{carsonzhang4@gmail.com} [contributor] 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /man/mlr3torch_callbacks.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchCallback.R 3 | \docType{data} 4 | \name{mlr3torch_callbacks} 5 | \alias{mlr3torch_callbacks} 6 | \title{Dictionary of Torch Callbacks} 7 | \format{ 8 | An object of class \code{DictionaryMlr3torchCallbacks} (inherits from \code{Dictionary}, \code{R6}) of length 12. 9 | } 10 | \usage{ 11 | mlr3torch_callbacks 12 | } 13 | \description{ 14 | A \code{\link[mlr3misc:Dictionary]{mlr3misc::Dictionary}} of torch callbacks. 15 | Use \code{\link[=t_clbk]{t_clbk()}} to conveniently retrieve callbacks. 16 | Can be converted to a \code{\link[data.table:data.table]{data.table}} using 17 | \code{\link[data.table:as.data.table]{as.data.table}}. 
18 | } 19 | \examples{ 20 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 21 | mlr3torch_callbacks$get("checkpoint") 22 | # is the same as 23 | t_clbk("checkpoint") 24 | # convert to a data.table 25 | as.data.table(mlr3torch_callbacks) 26 | \dontshow{\}) # examplesIf} 27 | } 28 | \seealso{ 29 | Other Callback: 30 | \code{\link{TorchCallback}}, 31 | \code{\link{as_torch_callback}()}, 32 | \code{\link{as_torch_callbacks}()}, 33 | \code{\link{callback_set}()}, 34 | \code{\link{mlr_callback_set}}, 35 | \code{\link{mlr_callback_set.checkpoint}}, 36 | \code{\link{mlr_callback_set.progress}}, 37 | \code{\link{mlr_callback_set.tb}}, 38 | \code{\link{mlr_callback_set.unfreeze}}, 39 | \code{\link{mlr_context_torch}}, 40 | \code{\link{t_clbk}()}, 41 | \code{\link{torch_callback}()} 42 | 43 | Other Dictionary: 44 | \code{\link{mlr3torch_losses}}, 45 | \code{\link{mlr3torch_optimizers}}, 46 | \code{\link{t_opt}()} 47 | } 48 | \concept{Callback} 49 | \concept{Dictionary} 50 | \keyword{datasets} 51 | -------------------------------------------------------------------------------- /man/mlr3torch_losses.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchLoss.R 3 | \docType{data} 4 | \name{mlr3torch_losses} 5 | \alias{mlr3torch_losses} 6 | \title{Loss Functions} 7 | \format{ 8 | An object of class \code{DictionaryMlr3torchLosses} (inherits from \code{Dictionary}, \code{R6}) of length 12. 9 | } 10 | \usage{ 11 | mlr3torch_losses 12 | } 13 | \description{ 14 | Dictionary of torch loss descriptors. 15 | See \code{\link[=t_loss]{t_loss()}} for conveniently retrieving a loss function. 16 | Can be converted to a \code{\link[data.table:data.table]{data.table}} using 17 | \code{\link[data.table:as.data.table]{as.data.table}}. 
18 | } 19 | \section{Available Loss Functions}{ 20 | 21 | cross_entropy, l1, mse 22 | } 23 | 24 | \examples{ 25 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 26 | mlr3torch_losses$get("mse") 27 | # is equivalent to 28 | t_loss("mse") 29 | # convert to a data.table 30 | as.data.table(mlr3torch_losses) 31 | \dontshow{\}) # examplesIf} 32 | } 33 | \seealso{ 34 | Other Torch Descriptor: 35 | \code{\link{TorchCallback}}, 36 | \code{\link{TorchDescriptor}}, 37 | \code{\link{TorchLoss}}, 38 | \code{\link{TorchOptimizer}}, 39 | \code{\link{as_torch_callbacks}()}, 40 | \code{\link{as_torch_loss}()}, 41 | \code{\link{as_torch_optimizer}()}, 42 | \code{\link{mlr3torch_optimizers}}, 43 | \code{\link{t_clbk}()}, 44 | \code{\link{t_loss}()}, 45 | \code{\link{t_opt}()} 46 | 47 | Other Dictionary: 48 | \code{\link{mlr3torch_callbacks}}, 49 | \code{\link{mlr3torch_optimizers}}, 50 | \code{\link{t_opt}()} 51 | } 52 | \concept{Dictionary} 53 | \concept{Torch Descriptor} 54 | \keyword{datasets} 55 | -------------------------------------------------------------------------------- /man/mlr3torch_optimizers.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchOptimizer.R 3 | \docType{data} 4 | \name{mlr3torch_optimizers} 5 | \alias{mlr3torch_optimizers} 6 | \title{Optimizers} 7 | \format{ 8 | An object of class \code{DictionaryMlr3torchOptimizers} (inherits from \code{Dictionary}, \code{R6}) of length 12. 9 | } 10 | \usage{ 11 | mlr3torch_optimizers 12 | } 13 | \description{ 14 | Dictionary of torch optimizers. 15 | Use \code{\link{t_opt}} for conveniently retrieving optimizers. 16 | Can be converted to a \code{\link[data.table:data.table]{data.table}} using 17 | \code{\link[data.table:as.data.table]{as.data.table}}. 
18 | } 19 | \section{Available Optimizers}{ 20 | 21 | adagrad, adam, adamw, rmsprop, sgd 22 | } 23 | 24 | \examples{ 25 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 26 | mlr3torch_optimizers$get("adam") 27 | # is equivalent to 28 | t_opt("adam") 29 | # convert to a data.table 30 | as.data.table(mlr3torch_optimizers) 31 | \dontshow{\}) # examplesIf} 32 | } 33 | \seealso{ 34 | Other Torch Descriptor: 35 | \code{\link{TorchCallback}}, 36 | \code{\link{TorchDescriptor}}, 37 | \code{\link{TorchLoss}}, 38 | \code{\link{TorchOptimizer}}, 39 | \code{\link{as_torch_callbacks}()}, 40 | \code{\link{as_torch_loss}()}, 41 | \code{\link{as_torch_optimizer}()}, 42 | \code{\link{mlr3torch_losses}}, 43 | \code{\link{t_clbk}()}, 44 | \code{\link{t_loss}()}, 45 | \code{\link{t_opt}()} 46 | 47 | Other Dictionary: 48 | \code{\link{mlr3torch_callbacks}}, 49 | \code{\link{mlr3torch_losses}}, 50 | \code{\link{t_opt}()} 51 | } 52 | \concept{Dictionary} 53 | \concept{Torch Descriptor} 54 | \keyword{datasets} 55 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_center_crop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_center_crop} 4 | \alias{mlr_pipeops_augment_center_crop} 5 | \alias{PipeOpPreprocTorchAugmentCenterCrop} 6 | \title{Center Crop Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_center_crop]{torchvision::transform_center_crop}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_center_crop") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | size \tab untyped \tab - \tab \cr 26 | stages \tab character \tab - \tab train, predict, both \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_color_jitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_color_jitter} 4 | \alias{mlr_pipeops_augment_color_jitter} 5 | \alias{PipeOpPreprocTorchAugmentColorJitter} 6 | \title{Color Jitter Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_color_jitter]{torchvision::transform_color_jitter}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_color_jitter") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | brightness \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 26 | contrast \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 27 | saturation \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 28 | hue \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 29 | stages \tab character \tab - \tab train, predict, both \tab - \cr 30 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_crop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_crop} 4 | \alias{mlr_pipeops_augment_crop} 5 | \alias{PipeOpPreprocTorchAugmentCrop} 6 | \title{Crop Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_crop]{torchvision::transform_crop}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_crop") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | top \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 26 | left \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 27 | height \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 28 | width \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 29 | stages \tab character \tab - \tab train, predict, both \tab - \cr 30 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_hflip.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_hflip} 4 | \alias{mlr_pipeops_augment_hflip} 5 | \alias{PipeOpPreprocTorchAugmentHflip} 6 | \title{Horizontal Flip Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_hflip]{torchvision::transform_hflip}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_hflip") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | stages \tab character \tab - \tab train, predict, both \cr 26 | affect_columns \tab untyped \tab selector_all() \tab \cr 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_affine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_affine} 4 | \alias{mlr_pipeops_augment_random_affine} 5 | \alias{PipeOpPreprocTorchAugmentRandomAffine} 6 | \title{Random Affine Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_affine]{torchvision::transform_random_affine}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_affine") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | degrees \tab untyped \tab - \tab \tab - \cr 26 | translate \tab untyped \tab NULL \tab \tab - \cr 27 | scale \tab untyped \tab NULL \tab \tab - \cr 28 | resample \tab integer \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 29 | fillcolor \tab untyped \tab 0 \tab \tab - \cr 30 | stages \tab character \tab - \tab train, predict, both \tab - \cr 31 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_choice.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_choice} 4 | \alias{mlr_pipeops_augment_random_choice} 5 | \alias{PipeOpPreprocTorchAugmentRandomChoice} 6 | \title{Random Choice Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_choice]{torchvision::transform_random_choice}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_choice") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | transforms \tab untyped \tab - \tab \cr 26 | stages \tab character \tab - \tab train, predict, both \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_crop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_crop} 4 | \alias{mlr_pipeops_augment_random_crop} 5 | \alias{PipeOpPreprocTorchAugmentRandomCrop} 6 | \title{Random Crop Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_crop]{torchvision::transform_random_crop}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_crop") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | size \tab untyped \tab - \tab \cr 26 | padding \tab untyped \tab NULL \tab \cr 27 | pad_if_needed \tab logical \tab FALSE \tab TRUE, FALSE \cr 28 | fill \tab untyped \tab 0L \tab \cr 29 | padding_mode \tab character \tab constant \tab constant, edge, reflect, symmetric \cr 30 | stages \tab character \tab - \tab train, predict, both \cr 31 | affect_columns \tab untyped \tab selector_all() \tab \cr 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_horizontal_flip.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_horizontal_flip} 4 | \alias{mlr_pipeops_augment_random_horizontal_flip} 5 | \alias{PipeOpPreprocTorchAugmentRandomHorizontalFlip} 6 | \title{Random Horizontal Flip Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_horizontal_flip]{torchvision::transform_random_horizontal_flip}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_horizontal_flip") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | p \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_order.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_order} 4 | \alias{mlr_pipeops_augment_random_order} 5 | \alias{PipeOpPreprocTorchAugmentRandomOrder} 6 | \title{Random Order Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_order]{torchvision::transform_random_order}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_order") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | transforms \tab untyped \tab - \tab \cr 26 | stages \tab character \tab - \tab train, predict, both \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_resized_crop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_resized_crop} 4 | \alias{mlr_pipeops_augment_random_resized_crop} 5 | \alias{PipeOpPreprocTorchAugmentRandomResizedCrop} 6 | \title{Random Resized Crop Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_resized_crop]{torchvision::transform_random_resized_crop}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_resized_crop") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | size \tab untyped \tab - \tab \tab - \cr 26 | scale \tab untyped \tab c(0.08, 1) \tab \tab - \cr 27 | ratio \tab untyped \tab c(3/4, 4/3) \tab \tab - \cr 28 | interpolation \tab integer \tab 2 \tab \tab \eqn{[0, 3]}{[0, 3]} \cr 29 | stages \tab character \tab - \tab train, predict, both \tab - \cr 30 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_random_vertical_flip.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_random_vertical_flip} 4 | \alias{mlr_pipeops_augment_random_vertical_flip} 5 | \alias{PipeOpPreprocTorchAugmentRandomVerticalFlip} 6 | \title{Random Vertical Flip Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_random_vertical_flip]{torchvision::transform_random_vertical_flip}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_random_vertical_flip") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | p \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_resized_crop.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_resized_crop} 4 | \alias{mlr_pipeops_augment_resized_crop} 5 | \alias{PipeOpPreprocTorchAugmentResizedCrop} 6 | \title{Resized Crop Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_resized_crop]{torchvision::transform_resized_crop}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_resized_crop") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | top \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 26 | left \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 27 | height \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 28 | width \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 29 | size \tab untyped \tab - \tab \tab - \cr 30 | interpolation \tab integer \tab 2 \tab \tab \eqn{[0, 3]}{[0, 3]} \cr 31 | stages \tab character \tab - \tab train, predict, both \tab - \cr 32 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 33 | } 34 | } 35 | 36 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_rotate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_rotate} 4 | \alias{mlr_pipeops_augment_rotate} 5 | \alias{PipeOpPreprocTorchAugmentRotate} 6 | \title{Rotate Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_rotate]{torchvision::transform_rotate}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_rotate") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | angle \tab untyped \tab - \tab \tab - \cr 26 | resample \tab integer \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 27 | expand \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr 28 | center \tab untyped \tab NULL \tab \tab - \cr 29 | fill \tab untyped \tab NULL \tab \tab - \cr 30 | stages \tab character \tab - \tab train, predict, both \tab - \cr 31 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /man/mlr_pipeops_augment_vflip.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_augment_vflip} 4 | \alias{mlr_pipeops_augment_vflip} 5 | \alias{PipeOpPreprocTorchAugmentVflip} 6 | \title{Vertical Flip Augmentation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_vflip]{torchvision::transform_vflip}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("augment_vflip") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | stages \tab character \tab - \tab train, predict, both \cr 26 | affect_columns \tab untyped \tab selector_all() \tab \cr 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_adjust_brightness.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_adjust_brightness} 4 | \alias{mlr_pipeops_trafo_adjust_brightness} 5 | \alias{PipeOpPreprocTorchTrafoAdjustBrightness} 6 | \title{Adjust Brightness Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_adjust_brightness]{torchvision::transform_adjust_brightness}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_adjust_brightness") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | brightness_factor \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_adjust_gamma.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_adjust_gamma} 4 | \alias{mlr_pipeops_trafo_adjust_gamma} 5 | \alias{PipeOpPreprocTorchTrafoAdjustGamma} 6 | \title{Adjust Gamma Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_adjust_gamma]{torchvision::transform_adjust_gamma}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_adjust_gamma") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | gamma \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr 26 | gain \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 27 | stages \tab character \tab - \tab train, predict, both \tab - \cr 28 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 29 | } 30 | } 31 | 32 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_adjust_hue.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_adjust_hue} 4 | \alias{mlr_pipeops_trafo_adjust_hue} 5 | \alias{PipeOpPreprocTorchTrafoAdjustHue} 6 | \title{Adjust Hue Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_adjust_hue]{torchvision::transform_adjust_hue}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_adjust_hue") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | hue_factor \tab numeric \tab - \tab \tab \eqn{[-0.5, 0.5]}{[-0.5, 0.5]} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_adjust_saturation.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_adjust_saturation} 4 | \alias{mlr_pipeops_trafo_adjust_saturation} 5 | \alias{PipeOpPreprocTorchTrafoAdjustSaturation} 6 | \title{Adjust Saturation Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_adjust_saturation]{torchvision::transform_adjust_saturation}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_adjust_saturation") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | saturation_factor \tab numeric \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_grayscale.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_grayscale} 4 | \alias{mlr_pipeops_trafo_grayscale} 5 | \alias{PipeOpPreprocTorchTrafoGrayscale} 6 | \title{Grayscale Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_grayscale]{torchvision::transform_grayscale}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_grayscale") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{lllll}{ 24 | Id \tab Type \tab Default \tab Levels \tab Range \cr 25 | num_output_channels \tab integer \tab - \tab \tab \eqn{[1, 3]}{[1, 3]} \cr 26 | stages \tab character \tab - \tab train, predict, both \tab - \cr 27 | affect_columns \tab untyped \tab selector_all() \tab \tab - \cr 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_normalize.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_normalize} 4 | \alias{mlr_pipeops_trafo_normalize} 5 | \alias{PipeOpPreprocTorchTrafoNormalize} 6 | \title{Normalization Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_normalize]{torchvision::transform_normalize}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_normalize") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | mean \tab untyped \tab - \tab \cr 26 | std \tab untyped \tab - \tab \cr 27 | stages \tab character \tab - \tab train, predict, both \cr 28 | affect_columns \tab untyped \tab selector_all() \tab \cr 29 | } 30 | } 31 | 32 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_pad.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_pad} 4 | \alias{mlr_pipeops_trafo_pad} 5 | \alias{PipeOpPreprocTorchTrafoPad} 6 | \title{Padding Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_pad]{torchvision::transform_pad}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_pad") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | padding \tab untyped \tab - \tab \cr 26 | fill \tab untyped \tab 0 \tab \cr 27 | padding_mode \tab character \tab constant \tab constant, edge, reflect, symmetric \cr 28 | stages \tab character \tab - \tab train, predict, both \cr 29 | affect_columns \tab untyped \tab selector_all() \tab \cr 30 | } 31 | } 32 | 33 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_resize.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_resize} 4 | \alias{mlr_pipeops_trafo_resize} 5 | \alias{PipeOpPreprocTorchTrafoResize} 6 | \title{Resizing Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_resize]{torchvision::transform_resize}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to the whole batch. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_resize") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | size \tab untyped \tab - \tab \cr 26 | interpolation \tab character \tab 2 \tab Undefined, Bartlett, Blackman, Bohman, Box, Catrom, Cosine, Cubic, Gaussian, Hamming, \link{...} \cr 27 | stages \tab character \tab - \tab train, predict, both \cr 28 | affect_columns \tab untyped \tab selector_all() \tab \cr 29 | } 30 | } 31 | 32 | -------------------------------------------------------------------------------- /man/mlr_pipeops_trafo_rgb_to_grayscale.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/preprocess.R 3 | \name{mlr_pipeops_trafo_rgb_to_grayscale} 4 | \alias{mlr_pipeops_trafo_rgb_to_grayscale} 5 | \alias{PipeOpPreprocTorchTrafoRgbToGrayscale} 6 | \title{RGB to Grayscale Transformation} 7 | \format{ 8 | \code{\link[R6:R6Class]{R6Class}} inheriting from \code{\link{PipeOpTaskPreprocTorch}}. 9 | } 10 | \description{ 11 | Calls \code{\link[torchvision:transform_rgb_to_grayscale]{torchvision::transform_rgb_to_grayscale}}, 12 | see there for more information on the parameters. 13 | The preprocessing is applied to each element of a batch individually. 14 | } 15 | \section{Construction}{ 16 | 17 | 18 | \if{html}{\out{
}}\preformatted{po("trafo_rgb_to_grayscale") 19 | }\if{html}{\out{
}} 20 | } 21 | 22 | \section{Parameters}{ 23 | \tabular{llll}{ 24 | Id \tab Type \tab Default \tab Levels \cr 25 | stages \tab character \tab - \tab train, predict, both \cr 26 | affect_columns \tab untyped \tab selector_all() \tab \cr 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /man/mlr_tasks_lazy_iris.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TaskClassif_lazy_iris.R 3 | \name{mlr_tasks_lazy_iris} 4 | \alias{mlr_tasks_lazy_iris} 5 | \title{Iris Classification Task} 6 | \format{ 7 | \link[R6:R6Class]{R6::R6Class} inheriting from \link[mlr3:TaskClassif]{mlr3::TaskClassif}. 8 | } 9 | \source{ 10 | \url{https://en.wikipedia.org/wiki/Iris_flower_data_set} 11 | } 12 | \description{ 13 | A classification task for the popular \link[datasets:iris]{datasets::iris} data set. 14 | Just like the iris task, but the features are represented as one lazy tensor column. 15 | } 16 | \section{Construction}{ 17 | 18 | 19 | \if{html}{\out{
}}\preformatted{tsk("lazy_iris") 20 | }\if{html}{\out{
}} 21 | } 22 | 23 | \section{Properties}{ 24 | 25 | \itemize{ 26 | \item Task type: \dQuote{classif} 27 | \item Properties: \dQuote{multiclass} 28 | \item Has Missings: no 29 | \item Target: \dQuote{Species} 30 | \item Features: \dQuote{x} 31 | \item Data Dimension: 150x3 32 | } 33 | } 34 | 35 | \examples{ 36 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 37 | task = tsk("lazy_iris") 38 | task 39 | df = task$data() 40 | materialize(df$x[1:6], rbind = TRUE) 41 | \dontshow{\}) # examplesIf} 42 | } 43 | \references{ 44 | Anderson E (1936). 45 | \dQuote{The Species Problem in Iris.} 46 | \emph{Annals of the Missouri Botanical Garden}, \bold{23}(3), 457. 47 | \doi{10.2307/2394164}. 48 | } 49 | -------------------------------------------------------------------------------- /man/model_descriptor_to_learner.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/nn_graph.R 3 | \name{model_descriptor_to_learner} 4 | \alias{model_descriptor_to_learner} 5 | \title{Create a Torch Learner from a ModelDescriptor} 6 | \usage{ 7 | model_descriptor_to_learner(model_descriptor) 8 | } 9 | \arguments{ 10 | \item{model_descriptor}{(\code{\link{ModelDescriptor}})\cr 11 | The model descriptor.} 12 | } 13 | \value{ 14 | \code{\link[mlr3:Learner]{Learner}} 15 | } 16 | \description{ 17 | First a \code{\link{nn_graph}} is created using \code{\link{model_descriptor_to_module}} and then a learner is created from this 18 | module and the remaining information from the model descriptor, which must include the optimizer and loss function 19 | and optionally callbacks. 
20 | } 21 | \seealso{ 22 | Other Graph Network: 23 | \code{\link{ModelDescriptor}()}, 24 | \code{\link{TorchIngressToken}()}, 25 | \code{\link{mlr_learners_torch_model}}, 26 | \code{\link{mlr_pipeops_module}}, 27 | \code{\link{mlr_pipeops_torch}}, 28 | \code{\link{mlr_pipeops_torch_ingress}}, 29 | \code{\link{mlr_pipeops_torch_ingress_categ}}, 30 | \code{\link{mlr_pipeops_torch_ingress_ltnsr}}, 31 | \code{\link{mlr_pipeops_torch_ingress_num}}, 32 | \code{\link{model_descriptor_to_module}()}, 33 | \code{\link{model_descriptor_union}()}, 34 | \code{\link{nn_graph}()} 35 | } 36 | \concept{Graph Network} 37 | -------------------------------------------------------------------------------- /man/model_descriptor_to_module.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/nn_graph.R 3 | \name{model_descriptor_to_module} 4 | \alias{model_descriptor_to_module} 5 | \title{Create a nn_graph from ModelDescriptor} 6 | \usage{ 7 | model_descriptor_to_module( 8 | model_descriptor, 9 | output_pointers = NULL, 10 | list_output = FALSE 11 | ) 12 | } 13 | \arguments{ 14 | \item{model_descriptor}{(\code{\link{ModelDescriptor}})\cr 15 | Model Descriptor. \code{pointer} is ignored; the \code{output_pointers} values are used instead. \verb{$graph} member is 16 | modified by-reference.} 17 | 18 | \item{output_pointers}{(\code{list} of \code{character})\cr 19 | Collection of \code{pointer}s that indicate what part of the \code{model_descriptor$graph} is being used for output. 20 | Entries have the format of \code{ModelDescriptor$pointer}.} 21 | 22 | \item{list_output}{(\code{logical(1)})\cr 23 | Whether output should be a list of tensors. If \code{FALSE}, then \code{length(output_pointers)} must be 1.} 24 | } 25 | \value{ 26 | \code{\link{nn_graph}} 27 | } 28 | \description{ 29 | Creates the \code{\link{nn_graph}} from a \code{\link{ModelDescriptor}}. 
Mostly for internal use, since the \code{\link{ModelDescriptor}} is in 30 | most circumstances harder to use than just creating \code{\link{nn_graph}} directly. 31 | } 32 | \seealso{ 33 | Other Graph Network: 34 | \code{\link{ModelDescriptor}()}, 35 | \code{\link{TorchIngressToken}()}, 36 | \code{\link{mlr_learners_torch_model}}, 37 | \code{\link{mlr_pipeops_module}}, 38 | \code{\link{mlr_pipeops_torch}}, 39 | \code{\link{mlr_pipeops_torch_ingress}}, 40 | \code{\link{mlr_pipeops_torch_ingress_categ}}, 41 | \code{\link{mlr_pipeops_torch_ingress_ltnsr}}, 42 | \code{\link{mlr_pipeops_torch_ingress_num}}, 43 | \code{\link{model_descriptor_to_learner}()}, 44 | \code{\link{model_descriptor_union}()}, 45 | \code{\link{nn_graph}()} 46 | } 47 | \concept{Graph Network} 48 | -------------------------------------------------------------------------------- /man/nn.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/nn.R 3 | \name{nn} 4 | \alias{nn} 5 | \title{Create a Neural Network Layer} 6 | \usage{ 7 | nn(.key, ...) 8 | } 9 | \arguments{ 10 | \item{.key}{(\code{character(1)})\cr Name of the layer without the \code{"nn_"} prefix, e.g. \code{"linear"}; it is also used as the id of the resulting \code{PipeOp}.} 11 | 12 | \item{...}{(any)\cr 13 | Additional parameters, constructor arguments or fields.} 14 | } 15 | \description{ 16 | Retrieve a neural network layer from the 17 | \code{\link[mlr3pipelines:mlr_pipeops]{mlr_pipeops}} dictionary.
18 | } 19 | \examples{ 20 | po1 = po("nn_linear", id = "linear") 21 | # is the same as: 22 | po2 = nn("linear") 23 | } 24 | -------------------------------------------------------------------------------- /man/nn_ft_cls.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchFTCLS.R 3 | \name{nn_ft_cls} 4 | \alias{nn_ft_cls} 5 | \title{CLS Token for FT-Transformer} 6 | \usage{ 7 | nn_ft_cls(d_token, initialization) 8 | } 9 | \arguments{ 10 | \item{d_token}{(\code{integer(1)})\cr 11 | The dimension of the embedding.} 12 | 13 | \item{initialization}{(\code{character(1)})\cr 14 | The initialization method for the embedding weights. Possible values are \code{"uniform"} 15 | and \code{"normal"}.} 16 | } 17 | \description{ 18 | Concatenates a CLS token to the input as the last feature. 19 | The input shape is expected to be \verb{(batch, n_features, d_token)} and the output shape is 20 | \verb{(batch, n_features + 1, d_token)}. 21 | 22 | This is used in the FT-Transformer. 23 | } 24 | \references{ 25 | Devlin, Jacob, Chang, Ming-Wei, Lee, Kenton, Toutanova, Kristina (2018). 26 | \dQuote{Bert: Pre-training of deep bidirectional transformers for language understanding.} 27 | \emph{arXiv preprint arXiv:1810.04805}. 28 | } 29 | -------------------------------------------------------------------------------- /man/nn_geglu.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchActivation.R 3 | \name{nn_geglu} 4 | \alias{nn_geglu} 5 | \title{GeGLU Module} 6 | \usage{ 7 | nn_geglu() 8 | } 9 | \description{ 10 | This module implements the Gaussian Error Linear Unit Gated Linear Unit (GeGLU) activation function. 
11 | It computes \eqn{\text{GeGLU}(x, g) = x \cdot \text{GELU}(g)} 12 | where \eqn{x} and \eqn{g} are created by splitting the input tensor in half along the last dimension. 13 | } 14 | \examples{ 15 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 16 | x = torch::torch_randn(10, 10) 17 | glu = nn_geglu() 18 | glu(x) 19 | \dontshow{\}) # examplesIf} 20 | } 21 | \references{ 22 | Shazeer N (2020). 23 | \dQuote{GLU Variants Improve Transformer.} 24 | 2002.05202, \url{https://arxiv.org/abs/2002.05202}. 25 | } 26 | -------------------------------------------------------------------------------- /man/nn_merge_cat.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchMerge.R 3 | \name{nn_merge_cat} 4 | \alias{nn_merge_cat} 5 | \title{Concatenates multiple tensors} 6 | \usage{ 7 | nn_merge_cat(dim = -1) 8 | } 9 | \arguments{ 10 | \item{dim}{(\code{integer(1)})\cr 11 | The dimension for the concatenation.} 12 | } 13 | \description{ 14 | Concatenates multiple tensors on a given dimension. 15 | No broadcasting rules are applied here; the tensors must be reshaped beforehand so that they have the same shape. 16 | } 17 | -------------------------------------------------------------------------------- /man/nn_merge_prod.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchMerge.R 3 | \name{nn_merge_prod} 4 | \alias{nn_merge_prod} 5 | \title{Product of multiple tensors} 6 | \usage{ 7 | nn_merge_prod() 8 | } 9 | \description{ 10 | Calculates the product of all input tensors.
11 | } 12 | -------------------------------------------------------------------------------- /man/nn_merge_sum.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchMerge.R 3 | \name{nn_merge_sum} 4 | \alias{nn_merge_sum} 5 | \title{Sum of multiple tensors} 6 | \usage{ 7 | nn_merge_sum() 8 | } 9 | \description{ 10 | Calculates the sum of all input tensors. 11 | } 12 | -------------------------------------------------------------------------------- /man/nn_reglu.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchActivation.R 3 | \name{nn_reglu} 4 | \alias{nn_reglu} 5 | \title{ReGLU Module} 6 | \usage{ 7 | nn_reglu() 8 | } 9 | \description{ 10 | Rectified Gated Linear Unit (ReGLU) module. 11 | Computes the output as \eqn{\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)} 12 | where \eqn{x} and \eqn{g} are created by splitting the input tensor in half along the last dimension. 13 | } 14 | \examples{ 15 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 16 | x = torch::torch_randn(10, 10) 17 | reglu = nn_reglu() 18 | reglu(x) 19 | \dontshow{\}) # examplesIf} 20 | } 21 | \references{ 22 | Shazeer N (2020). 23 | \dQuote{GLU Variants Improve Transformer.} 24 | 2002.05202, \url{https://arxiv.org/abs/2002.05202}.
25 | } 26 | -------------------------------------------------------------------------------- /man/nn_reshape.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchReshape.R 3 | \name{nn_reshape} 4 | \alias{nn_reshape} 5 | \title{Reshape} 6 | \usage{ 7 | nn_reshape(shape) 8 | } 9 | \arguments{ 10 | \item{shape}{(\code{integer()})\cr 11 | The desired output shape.} 12 | } 13 | \description{ 14 | Reshape a tensor to the given shape. 15 | } 16 | -------------------------------------------------------------------------------- /man/nn_squeeze.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchReshape.R 3 | \name{nn_squeeze} 4 | \alias{nn_squeeze} 5 | \title{Squeeze} 6 | \usage{ 7 | nn_squeeze(dim) 8 | } 9 | \arguments{ 10 | \item{dim}{(\code{integer()})\cr 11 | The dimension to squeeze.} 12 | } 13 | \description{ 14 | Squeezes a tensor by calling \code{\link[torch:torch_squeeze]{torch::torch_squeeze()}} with the given dimension \code{dim}. 
15 | } 16 | -------------------------------------------------------------------------------- /man/nn_tokenizer_categ.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchTokenizer.R 3 | \name{nn_tokenizer_categ} 4 | \alias{nn_tokenizer_categ} 5 | \title{Categorical Tokenizer} 6 | \usage{ 7 | nn_tokenizer_categ(cardinalities, d_token, bias, initialization) 8 | } 9 | \arguments{ 10 | \item{cardinalities}{(\code{integer()})\cr 11 | The number of categories for each feature.} 12 | 13 | \item{d_token}{(\code{integer(1)})\cr 14 | The dimension of the embedding.} 15 | 16 | \item{bias}{(\code{logical(1)})\cr 17 | Whether to use a bias.} 18 | 19 | \item{initialization}{(\code{character(1)})\cr 20 | The initialization method for the embedding weights. Possible values are \code{"uniform"} 21 | and \code{"normal"}.} 22 | } 23 | \description{ 24 | Tokenizes categorical features into a dense embedding. 25 | For an input of shape \verb{(batch, n_features)} the output shape is \verb{(batch, n_features, d_token)}. 26 | } 27 | \references{ 28 | Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). 29 | \dQuote{Revisiting Deep Learning for Tabular Data.} 30 | \emph{arXiv}, \bold{2106.11959}. 
31 | } 32 | -------------------------------------------------------------------------------- /man/nn_tokenizer_num.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchTokenizer.R 3 | \name{nn_tokenizer_num} 4 | \alias{nn_tokenizer_num} 5 | \title{Numeric Tokenizer} 6 | \usage{ 7 | nn_tokenizer_num(n_features, d_token, bias, initialization) 8 | } 9 | \arguments{ 10 | \item{n_features}{(\code{integer(1)})\cr 11 | The number of features.} 12 | 13 | \item{d_token}{(\code{integer(1)})\cr 14 | The dimension of the embedding.} 15 | 16 | \item{bias}{(\code{logical(1)})\cr 17 | Whether to use a bias.} 18 | 19 | \item{initialization}{(\code{character(1)})\cr 20 | The initialization method for the embedding weights. Possible values are \code{"uniform"} 21 | and \code{"normal"}.} 22 | } 23 | \description{ 24 | Tokenizes numeric features into a dense embedding. 25 | For an input of shape \verb{(batch, n_features)} the output shape is \verb{(batch, n_features, d_token)}. 26 | } 27 | \references{ 28 | Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). 29 | \dQuote{Revisiting Deep Learning for Tabular Data.} 30 | \emph{arXiv}, \bold{2106.11959}. 31 | } 32 | -------------------------------------------------------------------------------- /man/nn_unsqueeze.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpTorchReshape.R 3 | \name{nn_unsqueeze} 4 | \alias{nn_unsqueeze} 5 | \title{Unsqueeze} 6 | \usage{ 7 | nn_unsqueeze(dim) 8 | } 9 | \arguments{ 10 | \item{dim}{(\code{integer(1)})\cr 11 | The dimension to unsqueeze.} 12 | } 13 | \description{ 14 | Unsqueezes a tensor by calling \code{\link[torch:torch_unsqueeze]{torch::torch_unsqueeze()}} with the given dimension \code{dim}. 
15 | } 16 | -------------------------------------------------------------------------------- /man/output_dim_for.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils.R 3 | \name{output_dim_for} 4 | \alias{output_dim_for} 5 | \title{Network Output Dimension} 6 | \usage{ 7 | output_dim_for(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{(any)\cr 11 | The task.} 12 | 13 | \item{...}{(any)\cr 14 | Additional arguments. Not used yet.} 15 | } 16 | \description{ 17 | Calculates the output dimension of a neural network for a given task that is expected by 18 | \pkg{mlr3torch}. 19 | For classification, this is the number of classes (unless it is a binary classification task, 20 | where it is 1). For regression, it is 1. 21 | } 22 | -------------------------------------------------------------------------------- /man/replace_head.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils.R 3 | \name{replace_head} 4 | \alias{replace_head} 5 | \title{Replace the head of a network 6 | Replaces the head of the network with a linear layer with d_out classes.} 7 | \usage{ 8 | replace_head(network, d_out) 9 | } 10 | \arguments{ 11 | \item{network}{(\code{\link[torch:nn_module]{torch::nn_module}})\cr 12 | The network} 13 | 14 | \item{d_out}{(\code{integer(1)})\cr 15 | The number of output classes.} 16 | } 17 | \description{ 18 | Replace the head of a network 19 | Replaces the head of the network with a linear layer with d_out classes. 
20 | } 21 | \keyword{internal} 22 | -------------------------------------------------------------------------------- /man/t_loss.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchLoss.R 3 | \name{t_loss} 4 | \alias{t_loss} 5 | \alias{t_losses} 6 | \title{Loss Function Quick Access} 7 | \usage{ 8 | t_loss(.key, ...) 9 | 10 | t_losses(.keys, ...) 11 | } 12 | \arguments{ 13 | \item{.key}{(\code{character(1)})\cr 14 | Key of the object to retrieve.} 15 | 16 | \item{...}{(any)\cr 17 | See description of \code{\link[mlr3misc:dictionary_sugar_get]{dictionary_sugar_get}}.} 18 | 19 | \item{.keys}{(\code{character()})\cr 20 | The keys of the losses.} 21 | } 22 | \value{ 23 | A \code{\link{TorchLoss}} 24 | } 25 | \description{ 26 | Retrieve one or more \code{\link{TorchLoss}}es from \code{\link{mlr3torch_losses}}. 27 | Works like \code{\link[mlr3:mlr_sugar]{mlr3::lrn()}} and \code{\link[mlr3:mlr_sugar]{mlr3::lrns()}}. 
28 | } 29 | \examples{ 30 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 31 | t_loss("mse", reduction = "mean") 32 | # get the dictionary 33 | t_loss() 34 | \dontshow{\}) # examplesIf} 35 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 36 | t_losses(c("mse", "l1")) 37 | # get the dictionary 38 | t_losses() 39 | \dontshow{\}) # examplesIf} 40 | } 41 | \seealso{ 42 | Other Torch Descriptor: 43 | \code{\link{TorchCallback}}, 44 | \code{\link{TorchDescriptor}}, 45 | \code{\link{TorchLoss}}, 46 | \code{\link{TorchOptimizer}}, 47 | \code{\link{as_torch_callbacks}()}, 48 | \code{\link{as_torch_loss}()}, 49 | \code{\link{as_torch_optimizer}()}, 50 | \code{\link{mlr3torch_losses}}, 51 | \code{\link{mlr3torch_optimizers}}, 52 | \code{\link{t_clbk}()}, 53 | \code{\link{t_opt}()} 54 | } 55 | \concept{Torch Descriptor} 56 | -------------------------------------------------------------------------------- /man/t_opt.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TorchOptimizer.R 3 | \name{t_opt} 4 | \alias{t_opt} 5 | \alias{t_opts} 6 | \title{Optimizers Quick Access} 7 | \usage{ 8 | t_opt(.key, ...) 9 | 10 | t_opts(.keys, ...) 11 | } 12 | \arguments{ 13 | \item{.key}{(\code{character(1)})\cr 14 | Key of the object to retrieve.} 15 | 16 | \item{...}{(any)\cr 17 | See description of \code{\link[mlr3misc:dictionary_sugar_get]{dictionary_sugar_get}}.} 18 | 19 | \item{.keys}{(\code{character()})\cr 20 | The keys of the optimizers.} 21 | } 22 | \value{ 23 | A \code{\link{TorchOptimizer}} 24 | } 25 | \description{ 26 | Retrieves one or more \code{\link{TorchOptimizer}}s from \code{\link{mlr3torch_optimizers}}. 27 | Works like \code{\link[mlr3:mlr_sugar]{mlr3::lrn()}} and \code{\link[mlr3:mlr_sugar]{mlr3::lrns()}}. 
28 | } 29 | \examples{ 30 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 31 | t_opt("adam", lr = 0.1) 32 | # get the dictionary 33 | t_opt() 34 | \dontshow{\}) # examplesIf} 35 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 36 | t_opts(c("adam", "sgd")) 37 | # get the dictionary 38 | t_opts() 39 | \dontshow{\}) # examplesIf} 40 | } 41 | \seealso{ 42 | Other Torch Descriptor: 43 | \code{\link{TorchCallback}}, 44 | \code{\link{TorchDescriptor}}, 45 | \code{\link{TorchLoss}}, 46 | \code{\link{TorchOptimizer}}, 47 | \code{\link{as_torch_callbacks}()}, 48 | \code{\link{as_torch_loss}()}, 49 | \code{\link{as_torch_optimizer}()}, 50 | \code{\link{mlr3torch_losses}}, 51 | \code{\link{mlr3torch_optimizers}}, 52 | \code{\link{t_clbk}()}, 53 | \code{\link{t_loss}()} 54 | 55 | Other Dictionary: 56 | \code{\link{mlr3torch_callbacks}}, 57 | \code{\link{mlr3torch_losses}}, 58 | \code{\link{mlr3torch_optimizers}} 59 | } 60 | \concept{Dictionary} 61 | \concept{Torch Descriptor} 62 | -------------------------------------------------------------------------------- /mlr3torch.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: XeLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | LineEndingConversion: Posix 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /paper/finetuning.R: -------------------------------------------------------------------------------- 1 
| library(mlr3torch) 2 | library(torch) 3 | library(torchdatasets) 4 | 5 | 6 | # Define the path to the Cats vs. Dogs dataset 7 | # Ensure that the dataset is organized in subdirectories for each class 8 | # For example: 9 | # data/ 10 | # ├── cats/ 11 | # │ ├── cat001.jpg 12 | # │ ├── cat002.jpg 13 | # │ └── ... 14 | # └── dogs/ 15 | # ├── dog001.jpg 16 | # ├── dog002.jpg 17 | # └── ... 18 | 19 | data_dir <- here::here("data") 20 | 21 | dogs_vs_cats_dataset(data_dir, train = TRUE, download = TRUE) 22 | 23 | 24 | # Create a Torch dataset using ImageFolder structure 25 | torchdatasets::dogs_vs_cats_dataset( 26 | data_dir = data_dir, 27 | train = FALSE 28 | ) 29 | 30 | dataset_cv <- torchvision::dataset_image_folder( 31 | 32 | path = data_dir, 33 | transform = torchvision::transform_to_tensor() 34 | ) 35 | 36 | 37 | # the files in ./data/dogs-vs-cats/train are like dog.100.jpg, cat.100.jpg, 38 | # collect their paths and derive each label (dog or cat) from the file name 39 | paths <- list.files(file.path(data_dir, "dogs-vs-cats/train"), pattern = "\\.jpg$", recursive = TRUE, full.names = TRUE) 40 | # regex that matches file names such as dog.100.jpg; every other file is labelled cat 41 | labels <- ifelse(grepl("^dog\\.\\d+\\.jpg$", basename(paths)), "dog", "cat") 42 | 43 | 44 | # Wrap the dataset in a lazy_tensor 45 | # This allows for efficient on-the-fly data loading without storing all tensors in memory 46 | lazy_cv <- lazy_tensor(dataset_cv) 47 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon-180x180.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | if (Sys.getenv("TORCH_TEST", unset = "0") == "1") { # opt-in: the suite only runs when TORCH_TEST is "1" 2 | library(checkmate) 3 | library(testthat) 4 | test_check("mlr3torch") 5 | } 6 | -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/cat.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/cat.1.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/cat.2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/cat.2.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/cat.3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/cat.3.jpg
-------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/cat.4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/cat.4.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/cat.5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/cat.5.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/dog.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/dog.1.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/dog.2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/dog.2.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/dog.3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/dog.3.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/dog.4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/dog.4.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_dogs_vs_cats/dog.5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_dogs_vs_cats/dog.5.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/torch_n04456115_0.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/torch_n04456115_0.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/torch_n04456115_1.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/torch_n04456115_1.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/torch_n04456115_10.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/torch_n04456115_10.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/torch_n04456115_100.JPEG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/torch_n04456115_100.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/torch_n04456115_101.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/torch_n04456115_101.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/wok_n04596742_0.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/wok_n04596742_0.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/wok_n04596742_1.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/wok_n04596742_1.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/wok_n04596742_10.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/wok_n04596742_10.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/wok_n04596742_100.JPEG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/wok_n04596742_100.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_imagenet/wok_n04596742_101.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_imagenet/wok_n04596742_101.JPEG -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/0-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/0-1.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/0-14481.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/0-14481.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/1-10006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/1-10006.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/1-10007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/1-10007.jpg -------------------------------------------------------------------------------- 
/tests/testthat/assets/nano_mnist/2-10009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/2-10009.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/2-10016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/2-10016.jpg -------------------------------------------------------------------------------- /tests/testthat/assets/nano_mnist/data.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mlr3torch/d4de0bf85790a0fb012bb862332f6ac39a16baff/tests/testthat/assets/nano_mnist/data.rds -------------------------------------------------------------------------------- /tests/testthat/helper_mlr3.R: -------------------------------------------------------------------------------- 1 | # load mlr3 helper files 2 | 3 | library("mlr3") 4 | 5 | mlr_helpers = list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]$", full.names = TRUE) 6 | # skip helper_debugging: torch assumes that one can access null fields of R6 instances 7 | mlr_helpers = mlr_helpers[!grepl("helper_debugging", mlr_helpers)] 8 | 9 | # source every remaining mlr3 helper file so its test expectations are available here 10 | lapply(mlr_helpers, FUN = source) 11 | -------------------------------------------------------------------------------- /tests/testthat/helper_module.R: -------------------------------------------------------------------------------- 1 | testmodule_linear = nn_module( 2 | initialize = function(task) { # a single linear layer: n_features -> output_dim_for(task) 3 | out = output_dim_for(task) 4 | self$linear = nn_linear(length(task$feature_names), out) 5 | }, 6 | forward = function(x) { 7 | self$linear(x) 8 | } 9 | ) 10 |
-------------------------------------------------------------------------------- /tests/testthat/helper_pipeop.R: -------------------------------------------------------------------------------- 1 | nn_debug = nn_module( 2 | initialize = function(d_in1, d_in2, d_out1, d_out2, bias = TRUE) { 3 | self$linear1 = nn_linear(d_in1, d_out1, bias) 4 | self$linear2 = nn_linear(d_in2, d_out2, bias) 5 | }, 6 | forward = function(input1, input2) { # two independent linear maps, one per input 7 | output1 = self$linear1(input1) 8 | output2 = self$linear2(input2) 9 | 10 | list(output1 = output1, output2 = output2) 11 | } 12 | ) 13 | 14 | PipeOpTorchDebug = R6Class("PipeOpTorchDebug", 15 | inherit = PipeOpTorch, 16 | public = list( 17 | initialize = function(id = "nn_debug", param_vals = list(), inname = c("input1", "input2"), 18 | outname = c("output1", "output2")) { 19 | param_set = ps( 20 | d_out1 = p_int(lower = 1, tags = c("required", "train")), 21 | d_out2 = p_int(lower = 1, tags = c("required", "train")), 22 | bias = p_lgl(default = TRUE, tags = "train") 23 | ) 24 | 25 | super$initialize( 26 | id = id, 27 | param_vals = param_vals, 28 | param_set = param_set, 29 | inname = inname, 30 | outname = outname, 31 | module_generator = nn_debug 32 | ) 33 | } 34 | ), 35 | private = list( 36 | .shape_dependent_params = function(shapes_in, param_vals, task) { # in-features = last dim of each input (positional, like .shapes_out) 37 | c(param_vals, list(d_in1 = tail(shapes_in[[1L]], 1), d_in2 = tail(shapes_in[[2L]], 1))) 38 | }, 39 | .shapes_out = function(shapes_in, param_vals, task) { # only the last dim changes, to d_out1 / d_out2 40 | list( 41 | output1 = c(head(shapes_in[[1]], -1), param_vals$d_out1), 42 | output2 = c(head(shapes_in[[2]], -1), param_vals$d_out2) 43 | ) 44 | } 45 | ) 46 | ) 47 | 48 | PipeOpPreprocTorchAddSome = pipeop_preproc_torch("trafo_some", 49 | param_set = ps(some = p_dbl(default = 1L, tags = "train")), 50 | fn = crate(function(x, some = 1L) x + some), 51 | shapes_out = "infer" 52 | ) 53 | -------------------------------------------------------------------------------- /tests/testthat/setup.R:
-------------------------------------------------------------------------------- 1 | old_opts = options( 2 | warnPartialMatchArgs = TRUE, 3 | warnPartialMatchAttr = TRUE, 4 | warnPartialMatchDollar = TRUE 5 | ) 6 | 7 | # replace NULL (previously unset) values by FALSE so teardown can restore them, see https://github.com/HenrikBengtsson/Wishlist-for-R/issues/88 8 | old_opts = lapply(old_opts, function(x) if (is.null(x)) FALSE else x) 9 | 10 | lg = lgr::get_logger("mlr3") 11 | old_threshold = lg$threshold 12 | lg$set_threshold("warn") 13 | 14 | old_plan = future::plan() 15 | future::plan("sequential") 16 | -------------------------------------------------------------------------------- /tests/testthat/teardown.R: -------------------------------------------------------------------------------- 1 | options(old_opts) 2 | lg$set_threshold(old_threshold) 3 | future::plan(old_plan) 4 | -------------------------------------------------------------------------------- /tests/testthat/test_CallbackSetCheckpoint.R: -------------------------------------------------------------------------------- 1 | test_that("Autotest", { 2 | cb = t_clbk("checkpoint", freq = 1, path = tempfile()) 3 | expect_torch_callback(cb) 4 | }) 5 | 6 | test_that("manual", { 7 | cb = t_clbk("checkpoint", freq = 1) 8 | task = tsk("iris") 9 | task$row_roles$use = 1 10 | 11 | pth0 = tempfile() 12 | learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb) 13 | learner$param_set$set_values(cb.checkpoint.path = pth0) 14 | 15 | learner$train(task) 16 | 17 | expect_set_equal( 18 | c(paste0("network", 1:3, ".pt"), paste0("optimizer", 1:3, ".pt")), 19 | list.files(pth0) 20 | ) 21 | 22 | 23 | learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb) 24 | pth2 = tempfile() 25 | learner$param_set$set_values(cb.checkpoint.path = pth2, cb.checkpoint.freq = 2) 26 | learner$train(task) 27 | 28 | expect_set_equal( # freq = 2: a checkpoint at epoch 2 plus one for the final epoch 3 29 | c("network2.pt", "optimizer2.pt", "network3.pt", "optimizer3.pt"), 30 | list.files(pth2) 31 | ) 32 | expect_prediction(learner$predict(tsk("iris"))) 33 | 34 | opt_state =
torch_load(file.path(pth2, "optimizer3.pt")) 35 | expect_list(opt_state, types = c("numeric", "list", "torch_tensor")) 36 | }) 37 | 38 | test_that("error when using existing directory", { 39 | path = tempfile() 40 | dir.create(path) 41 | cb = t_clbk("checkpoint", freq = 1, path = path) 42 | expect_error(cb$generate(), "already exists") 43 | }) 44 | -------------------------------------------------------------------------------- /tests/testthat/test_CallbackSetHistory.R: -------------------------------------------------------------------------------- 1 | test_that("Autotest", { 2 | cb = t_clbk("history") 3 | expect_torch_callback(cb) 4 | }) 5 | 6 | test_that("CallbackSetHistory works", { 7 | cb = t_clbk("history") 8 | task = tsk("iris") 9 | task$internal_valid_task = task$clone(deep = TRUE)$filter(2) 10 | task$filter(1) 11 | 12 | learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb, validate = "predefined") 13 | 14 | learner$train(task) 15 | 16 | expect_data_table(learner$model$callbacks$history, nrows = 0) # no measures configured yet 17 | 18 | learner$param_set$set_values( 19 | measures_train = msrs(c("classif.acc", "classif.ce")), 20 | measures_valid = msr("classif.ce")) 21 | learner$train(task) 22 | 23 | expect_equal(colnames(learner$model$callbacks$history), c("epoch", "train.classif.acc", "train.classif.ce", "valid.classif.ce")) 24 | expect_data_table(learner$model$callbacks$history, nrows = 3) 25 | }) 26 | 27 | test_that("history works with eval_freq", { 28 | learner = lrn("regr.torch_featureless", epochs = 10, batch_size = 50, eval_freq = 4, callbacks = "history", 29 | measures_train = msrs("regr.mse")) 30 | task = tsk("mtcars") 31 | learner$train(task) 32 | expect_equal(learner$model$callbacks$history$epoch, c(4, 8, 10)) # every eval_freq epochs plus the final epoch 33 | 34 | learner$param_set$set_values(eval_freq = 5) 35 | learner$train(task) 36 | expect_equal(learner$model$callbacks$history$epoch, c(5, 10)) 37 | }) 38 |
-------------------------------------------------------------------------------- /tests/testthat/test_CallbackSetProgress.R: -------------------------------------------------------------------------------- 1 | test_that("autotest", { 2 | cb = t_clbk("progress") 3 | expect_torch_callback(cb) 4 | }) 5 | 6 | test_that("manual test", { 7 | learner = lrn("classif.mlp", epochs = 1, batch_size = 1, 8 | measures_train = msr("classif.acc"), measures_valid = msr("classif.ce"), callbacks = t_clbk("progress"), 9 | drop_last = FALSE, shuffle = TRUE, validate = "predefined" 10 | ) 11 | task = tsk("iris") 12 | task$internal_valid_task = task$clone(deep = TRUE)$filter(2) 13 | task$filter(1) 14 | 15 | # Because the validation is so short, its progress bar does not show up in this test 16 | # It can be made longer by adding some sleep through a callback 17 | # Still, this is not captured by capture.output(), so one has to manually inspect that it works 18 | # cbutil = torch_callback("util", on_batch_valid_begin = function() Sys.sleep(1)) 19 | # callbacks = list(t_clbk("progress"), cbutil) 20 | 21 | stdout = suppressMessages(capture.output(learner$train(task))) 22 | 23 | expected = c( 24 | "Epoch 1 started", 25 | "Validation for epoch 1 started", 26 | "", 27 | "[Summary epoch 1]", 28 | "------------------", 29 | "Measures (Train):", 30 | " * classif.acc =", 31 | "Measures (Valid):", 32 | " * classif.ce =", 33 | "", 34 | "Finished training for 1 epochs" 35 | ) 36 | 37 | expect_length(stdout, length(expected)) 38 | expect_true(all(map_lgl(seq_along(stdout), function(i) startsWith(stdout[[i]], expected[[i]])))) # prefix match only: measure values vary 39 | 40 | # does not throw with different eval_freq 41 | learner$param_set$set_values(eval_freq = 2) 42 | expect_error(capture.output(learner$train(task)), regexp = NA) 43 | }) 44 | -------------------------------------------------------------------------------- /tests/testthat/test_LearnerTorchFeatureless.R:
-------------------------------------------------------------------------------- 1 | test_that("dataset_featureless works", { 2 | task = tsk("iris") 3 | ds = dataset_featureless(task = task) 4 | expect_equal(ds$.length(), 150) 5 | batch = ds$.getbatch(1) 6 | expect_true(torch_equal(batch$x$n, torch_tensor(1L))) 7 | expect_true(inherits(batch$y, "torch_tensor")) 8 | expect_equal(batch$.index, torch_tensor(1, torch_long())) 9 | }) 10 | 11 | test_that("Basic checks: Classification", { 12 | learner = lrn("classif.torch_featureless", epochs = 1, batch_size = 50) 13 | expect_learner_torch(learner, task = tsk("iris")) 14 | }) 15 | 16 | test_that("LearnerTorchFeatureless works", { 17 | learner = lrn("classif.torch_featureless", batch_size = 50, epochs = 100, seed = 1) 18 | task = tsk("iris") 19 | task$row_roles$use = c(1:50, 51:60, 101:110) # 50 setosa vs. 10 + 10 others: setosa is the majority class 20 | learner$train(task) 21 | pred = learner$predict(task) 22 | expect_equal(as.character(pred$response[[1L]]), "setosa") 23 | }) 24 | 25 | 26 | test_that("Basic checks: Regression", { 27 | learner = lrn("regr.torch_featureless", epochs = 1, batch_size = 50) 28 | expect_learner_torch(learner, task = tsk("mtcars")) 29 | }) 30 | -------------------------------------------------------------------------------- /tests/testthat/test_LearnerTorchImage.R: -------------------------------------------------------------------------------- 1 | test_that("LearnerTorchImage works", { 2 | learner = LearnerTorchImageTest$new(task_type = "classif") 3 | learner$param_set$set_values( 4 | epochs = 1, 5 | batch_size = 1 6 | ) 7 | task = nano_imagenet()$filter(1) 8 | 9 | expect_equal(learner$man, "mlr3torch::mlr_learners.test") 10 | expect_r6(learner, c("Learner", "LearnerTorchImage", "LearnerTorch")) 11 | expect_equal(learner$label, "Test Learner Image") 12 | expect_identical(learner$feature_types, "lazy_tensor") 13 | expect_set_equal(learner$predict_types, c("response", "prob")) 14 | expect_subset("R6", learner$packages) 15 | 16 | task = po("trafo_resize", size =
c(64, 64))$train(list(task))[[1L]] 17 | 18 | learner$train(task) 19 | expect_class(learner$network, "nn_module") 20 | expect_true(is.null(learner$network$modules$`1`$bias)) 21 | pred = learner$predict(task) 22 | expect_class(pred, "PredictionClassif") 23 | }) 24 | 25 | test_that("unknown shapes work", { 26 | learner = lrn("classif.mobilenet_v2", batch_size = 1, epochs = 1) 27 | task = nano_dogs_vs_cats()$filter(1) 28 | learner$train(task) 29 | expect_prediction(learner$predict(task)) 30 | }) 31 | 32 | test_that("single lazy tensor is expected", { 33 | task = as_task_classif(data.table( 34 | x1 = as_lazy_tensor(matrix(1:4, 2, 2)), 35 | x2 = as_lazy_tensor(matrix(1:4, 2, 2)), 36 | y = as.factor(rep(c("a", "b"), each = 2)) 37 | ), target = "y", id = "test") 38 | learner = lrn("classif.mobilenet_v2", batch_size = 1, epochs = 1, pretrained = FALSE) 39 | expect_error(learner$train(task), "expects exactly one feature") 40 | }) 41 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchDropout.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchDropout autotest", { 2 | po_test = po("nn_dropout") 3 | graph = po("torch_ingress_num") %>>% po_test 4 | 5 | expect_pipeop_torch(graph, "nn_dropout", tsk("iris")) 6 | }) 7 | 8 | 9 | test_that("PipeOpTorchDropout paramtest", { 10 | res = expect_paramset(po("nn_dropout"), nn_dropout) 11 | expect_paramtest(res) 12 | }) 13 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchFTCLS.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchFTCLS autotest", { 2 | task = tsk("iris") 3 | graph = po("torch_ingress_num") %>>% 4 | po("nn_tokenizer_num", d_token = 10) %>>% 5 | po("nn_ft_cls", initialization = "uniform") 6 | 7 | expect_pipeop_torch(graph, "nn_ft_cls", task) 8 | }) 9 | 10 | test_that("PipeOpTorchFTCLS
works for tensors of specified dimensions", { 11 | # the canonical case: tensor of shape c(batch_size, n_features, d_token) 12 | task = tsk("iris") 13 | batch_size = 3 14 | d_token = 10 15 | tnsr = torch_tensor(as.matrix(task$data()[seq_len(batch_size), .(Petal.Width, Petal.Length, Sepal.Width, Sepal.Length)])) 16 | 17 | graph = po("torch_ingress_num") %>>% 18 | po("nn_tokenizer_num", d_token = d_token) %>>% 19 | po("nn_ft_cls", initialization = "uniform") 20 | md = graph$train(task)[[1L]] 21 | net = nn_graph(md$graph, shapes_in = list(torch_ingress_num.input = c(NA, task$n_features, d_token))) 22 | 23 | tnsr_out = net(tnsr) 24 | 25 | # the resulting tensor has an extra feature 26 | expect_equal(tnsr_out$shape, c(batch_size, task$n_features + 1, d_token)) 27 | }) 28 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchHead.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchHead autotest", { 2 | po_test = po("nn_head") 3 | task = tsk("iris") 4 | graph = po("torch_ingress_num") %>>% po_test 5 | 6 | expect_pipeop_torch(graph, "nn_head", task, "nn_linear") 7 | }) 8 | 9 | 10 | test_that("PipeOpTorchHead paramtest", { 11 | po_test = po("nn_head") 12 | res = expect_paramset(po_test, torch::nn_linear, exclude = c("out_features", "in_features")) 13 | expect_paramtest(res) 14 | }) 15 | 16 | 17 | test_that("correct error message", { 18 | task = nano_imagenet() 19 | graph = po("torch_ingress_ltnsr") %>>% po("nn_head") 20 | expect_error(graph$train(task), "expects 2D input") 21 | }) 22 | 23 | test_that("correct output dim", { 24 | po_test = po("nn_head") 25 | graph = po("torch_ingress_num") %>>% po_test 26 | # multiclass: one output unit per class 27 | task = tsk("iris") 28 | expect_equal(po_test$shapes_out(list(c(NA, 4)), task = task), list(output = c(NA, 3))) 29 | expect_equal(graph$train(task)[[1L]]$graph$pipeops$nn_head$module$weight$shape[1], 3) 30 | # binary: a single output unit 31 | task =
tsk("sonar") 32 | expect_equal(po_test$shapes_out(list(c(NA, 60)), task = task), list(output = c(NA, 1))) 33 | expect_equal(graph$train(task)[[1L]]$graph$pipeops$nn_head$module$weight$shape[1], 1) 34 | # regression 35 | task = tsk("mtcars") 36 | expect_equal(po_test$shapes_out(list(c(NA, 11)), task = task), list(output = c(NA, 1))) 37 | expect_equal(graph$train(task)[[1L]]$graph$pipeops$nn_head$module$weight$shape[1], 1) 38 | }) 39 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchIdentity.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchIdentity works", { 2 | po_identity = po("nn_identity") 3 | graph = po("torch_ingress_num") %>>% po_identity 4 | task = tsk("iris") 5 | 6 | expect_pipeop_torch(graph, "nn_identity", task, "nn_identity") 7 | }) 8 | 9 | test_that("PipeOpTorchIdentity paramtest", { 10 | po_identity = po("nn_identity") 11 | res = expect_paramset(po_identity, nn_identity) 12 | expect_paramtest(res) 13 | }) -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchLayerNorm.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchLayerNorm autotest", { 2 | po_test = po("nn_layer_norm", dims = 1) 3 | task = tsk("iris") 4 | graph = po("torch_ingress_num") %>>% po_test 5 | expect_pipeop_torch(graph, "nn_layer_norm", task, "nn_layer_norm") 6 | }) 7 | 8 | test_that("PipeOpTorchLayerNorm paramtest", { 9 | res = expect_paramset(po("nn_layer_norm", dims = 1), nn_layer_norm, exclude = c("normalized_shape", "dims")) 10 | expect_paramtest(res) 11 | }) 12 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchLinear.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchLinear works", { 2 | po_linear = po("nn_linear", out_features = 10) 3 |
graph = po("torch_ingress_num") %>>% po_linear 4 | task = tsk("iris") 5 | 6 | expect_pipeop_torch(graph, "nn_linear", task, "nn_linear") 7 | }) 8 | 9 | test_that("PipeOpTorchLinear paramtest", { 10 | po_linear = po("nn_linear", out_features = 10) 11 | res = expect_paramset(po_linear, nn_linear, exclude = "in_features") 12 | expect_paramtest(res) 13 | }) 14 | 15 | test_that("NA in second dimension", { 16 | ds = dataset( 17 | initialize = function() { 18 | self$xs = lapply(1:10, function(i) torch_randn(sample(1:10, 1), 10)) 19 | }, 20 | .getitem = function(i) { 21 | list(x = self$xs[[i]]) 22 | }, 23 | .length = function() { 24 | length(self$xs) 25 | } 26 | )() 27 | 28 | task = as_task_regr(data.table( 29 | x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, NA))), 30 | y = rnorm(10) 31 | ), target = "y", id = "test") 32 | 33 | graph = po("torch_ingress_ltnsr") %>>% po("nn_linear", out_features = 10) 34 | 35 | expect_error(graph$train(task), "Please provide an input with a known last dimension") 36 | 37 | task = as_task_regr(data.table( 38 | x = as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, NA, 10))), 39 | y = rnorm(10) 40 | ), target = "y", id = "test") 41 | 42 | md = graph$train(task)[[1L]] 43 | expect_equal(md$pointer_shape, c(NA, NA, 10)) 44 | net = model_descriptor_to_module(md) 45 | expect_equal(net(torch_randn(1, 2, 10))$shape, c(1, 2, 10)) 46 | }) 47 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchReshape.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOpTorchReshape autotest", { 2 | obj = po("nn_reshape", shape = c(-1, 2, 2)) 3 | task = tsk("iris") 4 | graph = po("torch_ingress_num") %>>% obj 5 | 6 | expect_pipeop_torch(graph, "nn_reshape", task) 7 | 8 | out = po("nn_reshape", shape = c(NA, 2, 2))$shapes_out(list(input = c(1, 4))) 9 | expect_true(!is.character(all.equal(out[[1L]], c(NA, 2, 2)))) 10 | }) 11 | 12 |
test_that("PipeOpTorchReshape paramtest", { 13 | res = expect_paramset(po("nn_reshape"), nn_reshape) 14 | expect_paramtest(res) 15 | }) 16 | 17 | test_that("PipeOpTorchUnsqueeze autotest", { 18 | obj = po("nn_unsqueeze", dim = 3) 19 | task = tsk("iris") 20 | graph = po("torch_ingress_num") %>>% obj 21 | 22 | expect_pipeop_torch(graph, "nn_unsqueeze", task) 23 | }) 24 | 25 | test_that("PipeOpTorchUnsqueeze paramtest", { 26 | res = expect_paramset(po("nn_unsqueeze"), nn_unsqueeze) 27 | expect_paramtest(res) 28 | }) 29 | 30 | test_that("PipeOpTorchSqueeze autotest", { 31 | obj = po("nn_squeeze", dim = 3) 32 | task = tsk("iris") 33 | graph = po("torch_ingress_num") %>>% po("nn_unsqueeze", dim = 3) %>>% obj 34 | 35 | # nn_squeeze(dim = 3) removes the size-1 dimension inserted by nn_unsqueeze(dim = 3) 36 | 37 | expect_pipeop_torch(graph, "nn_squeeze", task) 38 | }) 39 | 40 | test_that("PipeOpTorchSqueeze paramtest", { 41 | res = expect_paramset(po("nn_squeeze"), nn_squeeze) 42 | expect_paramtest(res) 43 | }) 44 | 45 | 46 | test_that("PipeOpTorchFlatten autotest", { 47 | obj = po("nn_flatten", start_dim = 2, end_dim = 4) 48 | task = nano_imagenet() 49 | graph = po("torch_ingress_ltnsr") %>>% obj 50 | expect_pipeop_torch(graph, "nn_flatten", task) 51 | }) 52 | 53 | test_that("PipeOpTorchFlatten paramtest", { 54 | res = expect_paramset(po("nn_flatten"), nn_flatten) 55 | expect_paramtest(res) 56 | }) 57 | -------------------------------------------------------------------------------- /tests/testthat/test_PipeOpTorchTokenizer.R: -------------------------------------------------------------------------------- 1 | test_that("nn_tokenizer_num works properly", { 2 | x = torch_randn(4, 2) 3 | n_objects = x$shape[1] 4 | n_features = x$shape[2] 5 | d_token = 3 6 | tokenizer = nn_tokenizer_num(n_features, d_token, TRUE, "uniform") 7 | tokens = tokenizer(x) 8 | expect_equal(tokens$shape, c(n_objects, n_features, d_token)) 9 | }) 10 | 11 | test_that("pipeop numeric tokenizer", { 12 | po_tokenize = po("nn_tokenizer_num", d_token = 10) 13 | graph =
po("torch_ingress_num") %>>% po_tokenize 14 | task = tsk("iris") 15 | expect_pipeop_torch(graph, "nn_tokenizer_num", task, "nn_tokenizer_num") 16 | res = expect_paramset(po_tokenize, nn_tokenizer_num, exclude = "n_features") 17 | expect_paramtest(res) 18 | }) 19 | 20 | test_that("nn_tokenizer_categ works properly", { 21 | cardinalities = c(3, 10) 22 | mat = matrix(nrow = 4, ncol = 2) 23 | mat[1, ] = c(1L, 6L) 24 | mat[2, ] = c(2L, 8L) 25 | mat[3, ] = c(1L, 3L) 26 | mat[4, ] = c(3L, 5L) 27 | x = torch_tensor(mat) 28 | n_objects = x$shape[1] 29 | n_features = x$shape[2] 30 | d_token = 3 31 | tokenizer = nn_tokenizer_categ(cardinalities, d_token, TRUE, "uniform") 32 | tokens = tokenizer(x) 33 | expect_equal(tokens$shape, c(n_objects, n_features, d_token)) 34 | }) 35 | 36 | test_that("pipeop categ tokenizer", { 37 | po_tokenize = po("nn_tokenizer_categ", d_token = 10) 38 | graph = po("torch_ingress_categ") %>>% po_tokenize 39 | task = tsk("breast_cancer") 40 | expect_pipeop_torch(graph, "nn_tokenizer_categ", task, "nn_tokenizer_categ") 41 | res = expect_paramset(po_tokenize, nn_tokenizer_categ, exclude = "cardinalities") 42 | expect_paramtest(res) 43 | }) 44 | -------------------------------------------------------------------------------- /tests/testthat/test_Select.R: -------------------------------------------------------------------------------- 1 | test_that("selectors work", { 2 | all_params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias", "9.weight", "9.bias") 3 | 4 | expect_equal(select_none()(all_params), character(0)) 5 | expect_equal(select_all()(all_params), all_params) 6 | expect_equal(select_grep("weight")(all_params), c("0.weight", "3.weight", "6.weight", "9.weight")) 7 | expect_equal(select_invert(select_none())(all_params), all_params) 8 | expect_equal(select_name("0.bias")(all_params), "0.bias") 9 | }) -------------------------------------------------------------------------------- /tests/testthat/test_TaskClassif_cifar.R:
-------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | test_that("CIFAR-10 works", { 4 | withr::local_options(mlr3torch.cache = TRUE) 5 | task = tsk("cifar10") 6 | 7 | expect_equal(task$nrow, 60000) 8 | 9 | task$filter(1:10) 10 | expect_equal(task$id, "cifar10") 11 | expect_equal(task$label, "CIFAR-10 Classification") 12 | expect_equal(task$feature_names, "image") 13 | expect_equal(task$target_names, "class") 14 | expect_equal(task$man, "mlr3torch::mlr_tasks_cifar") 15 | expect_equal(task$backend$hash, "mlr3torch::mlr_tasks_cifar10") 16 | task$data() 17 | expect_true("cifar-10-batches-bin" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar10", "raw"))) 18 | expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar10"))) 19 | expect_equal(task$backend$nrow, 60000) 20 | expect_equal(task$backend$ncol, 4) 21 | }) 22 | 23 | test_that("CIFAR-100 works", { 24 | withr::local_options(mlr3torch.cache = TRUE) 25 | task = tsk("cifar100") 26 | 27 | expect_equal(task$nrow, 60000) 28 | 29 | task$filter(1:10) 30 | expect_equal(task$id, "cifar100") 31 | expect_equal(task$label, "CIFAR-100 Classification") 32 | expect_equal(task$feature_names, "image") 33 | expect_equal(task$target_names, "class") 34 | expect_equal(task$man, "mlr3torch::mlr_tasks_cifar") 35 | expect_equal(task$backend$hash, "mlr3torch::mlr_tasks_cifar100") 36 | task$data() 37 | expect_true("cifar-100-binary" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100", "raw"))) 38 | expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100"))) 39 | expect_equal(task$backend$nrow, 60000) 40 | expect_equal(task$backend$ncol, 4) 41 | }) -------------------------------------------------------------------------------- /tests/testthat/test_TaskClassif_lazy_iris.R: -------------------------------------------------------------------------------- 1 | test_that("lazy_iris task works", { 2 | task =
tsk("lazy_iris") 3 | expect_task(task) 4 | expect_equal(task$id, "lazy_iris") 5 | expect_equal(task$label, "Iris Flowers") 6 | expect_equal(task$target_names, "Species") 7 | expect_equal(task$feature_names, "x") 8 | expect_equal(task$man, "mlr3torch::mlr_tasks_lazy_iris") 9 | expect_equal(task$nrow, 150) 10 | expect_equal(task$ncol, 2) 11 | }) 12 | 13 | -------------------------------------------------------------------------------- /tests/testthat/test_TaskClassif_melanoma.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | skip_if_not_installed("curl") 3 | 4 | test_that("melanoma task works", { 5 | withr::local_options(mlr3torch.cache = TRUE) 6 | task = tsk("melanoma") 7 | 8 | expect_equal(task$id, "melanoma") 9 | expect_equal(task$label, "Melanoma Classification") 10 | expect_equal(task$feature_names, c("sex", "anatom_site_general_challenge", "age_approx", "image")) 11 | expect_equal(task$target_names, "outcome") 12 | expect_equal(task$man, "mlr3torch::mlr_tasks_melanoma") 13 | expect_equal(task$properties, c("twoclass", "groups")) 14 | 15 | task$data() 16 | x = materialize(task$data(1, "image")[[1]])[[1L]] 17 | expect_class(x, "torch_tensor") 18 | expect_equal(x$shape, c(3, 128, 128)) 19 | 20 | expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "melanoma"))) 21 | expect_equal(task$backend$nrow, 32701 + 10982) 22 | expect_equal(task$backend$ncol, 11) 23 | expect_equal(task$ncol, 5) 24 | }) 25 | -------------------------------------------------------------------------------- /tests/testthat/test_TaskClassif_mnist.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | test_that("mnist task works", { 4 | withr::local_options(mlr3torch.cache = TRUE) 5 | task = tsk("mnist") 6 | # this makes the test faster 7 | task$row_roles$use = 1:10 8 | expect_equal(task$id, "mnist") 9 | expect_equal(task$label, "MNIST Digit Classification") 10
| expect_equal(task$feature_names, "image") 11 | expect_equal(task$target_names, "label") 12 | expect_equal(task$man, "mlr3torch::mlr_tasks_mnist") 13 | expect_equal(task$properties, "multiclass") 14 | 15 | x = materialize(task$data(task$row_ids[1:2], cols = "image")[[1L]], rbind = TRUE) 16 | expect_equal(x$shape, c(2, 1, 28, 28)) 17 | expect_equal(x$dtype, torch_float32()) 18 | }) 19 | -------------------------------------------------------------------------------- /tests/testthat/test_TaskClassif_tiny_imagenet.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | test_that("tiny_imagenet works", { 4 | withr::local_options(mlr3torch.cache = TRUE) 5 | task = tsk("tiny_imagenet") 6 | 7 | expect_equal(task$nrow, 110000) 8 | # this makes the test faster 9 | task$filter(1:10) 10 | expect_equal(task$id, "tiny_imagenet") 11 | expect_equal(task$label, "ImageNet Subset") 12 | expect_equal(task$feature_names, "image") 13 | expect_equal(task$target_names, "class") 14 | expect_equal(task$man, "mlr3torch::mlr_tasks_tiny_imagenet") 15 | task$data() 16 | expect_true("tiny-imagenet-200" %in% list.files(file.path(get_cache_dir(), "datasets", "tiny_imagenet", "raw"))) 17 | expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "tiny_imagenet"))) 18 | expect_equal(task$backend$nrow, 120000) 19 | expect_equal(task$backend$ncol, 4) 20 | }) 21 | -------------------------------------------------------------------------------- /tests/testthat/test_TorchDescriptor.R: -------------------------------------------------------------------------------- 1 | test_that("TorchDescriptor basic checks", { 2 | descriptor = TorchDescriptor$new( 3 | generator = nn_mse_loss, 4 | id = "mse", 5 | param_set = ps(reduction = p_uty()), 6 | packages = "R6", 7 | label = "MSE Loss", 8 | man = "torch::nn_mse_loss" 9 | ) 10 | 11 | # train tag is added 12 | expect_true("train" %in% descriptor$param_set$tags$reduction) 13 |
expect_identical(descriptor$generator, nn_mse_loss) 14 | expect_identical(descriptor$id, "mse") 15 | expect_identical(descriptor$param_set$ids(), "reduction") 16 | expect_set_equal(descriptor$packages, c("R6", "torch", "mlr3torch")) 17 | expect_identical(descriptor$man, "torch::nn_mse_loss") 18 | 19 | expect_class(descriptor, "TorchDescriptor") 20 | 21 | observed = capture.output(descriptor) 22 | 23 | expected = c( 24 | " MSE Loss", 25 | "* Generator: nn_mse_loss", 26 | "* Parameters: list()", 27 | "* Packages: R6,torch,mlr3torch" 28 | ) 29 | expect_identical(observed, expected) 30 | 31 | expect_error(TorchDescriptor$new( 32 | generator = nn_mse_loss, 33 | id = "mse", 34 | param_set = ps(reduction = p_uty(), x = p_uty()), 35 | packages = "R6", 36 | label = "MSE Loss", 37 | man = "torch::nn_mse_loss" 38 | ), 39 | regexp = "Parameter values with ids 'x' are missing in generator.", 40 | fixed = TRUE 41 | ) 42 | }) 43 | -------------------------------------------------------------------------------- /tests/testthat/test_dictionaries.R: -------------------------------------------------------------------------------- 1 | test_that("mlr_pipeops can be converted to a table", { 2 | tbl = as.data.table(mlr_pipeops) 3 | expect_data_table(tbl) 4 | }) 5 | 6 | test_that("mlr_learners can be converted to a table", { 7 | tbl = as.data.table(mlr_learners) 8 | expect_data_table(tbl) 9 | }) 10 | 11 | test_that("mlr3torch_callbacks can be converted to a table", { 12 | tbl = as.data.table(mlr3torch_callbacks) 13 | expect_data_table(tbl) 14 | }) 15 | 16 | test_that("mlr3torch_optimizers can be converted to a table", { 17 | tbl = as.data.table(mlr3torch_optimizers) 18 | expect_data_table(tbl) 19 | }) 20 | 21 | test_that("mlr3torch_losses can be converted to a table", { 22 | tbl = as.data.table(mlr3torch_losses) 23 | expect_data_table(tbl) 24 | }) 25 | 26 | test_that("mlr_tasks can be converted to a table without downloading imagenet", { 27 | dir = tempfile() 28 |
withr::local_options(mlr3torch.cache = dir) 29 | expect_data_table(as.data.table(mlr_tasks)) 30 | # the cache directory was never created -> imagenet was not downloaded 31 | expect_false(dir.exists(dir)) 32 | }) 33 | -------------------------------------------------------------------------------- /tests/testthat/test_helper_tasks.R: -------------------------------------------------------------------------------- 1 | test_that("nano_mnist", { 2 | task = nano_mnist() 3 | x = materialize(task$data(cols = task$feature_names)[[1L]], rbind = TRUE) 4 | expect_equal(x$shape, c(10, 1, 28, 28)) 5 | 6 | xout = materialize(po("augment_vflip")$train(list(task))[[1L]]$data(cols = "image")[[1L]], rbind = TRUE) 7 | expect_equal(xout$shape, c(10, 1, 28, 28)) 8 | }) 9 | 10 | test_that("nano_imagenet", { 11 | task = nano_imagenet() 12 | task$row_roles$use = c(2, 3) 13 | x = materialize(task$data(cols = task$feature_names)[[1L]], rbind = TRUE) 14 | expect_equal(x$shape, c(2, 3, 64, 64)) 15 | 16 | xout = materialize(po("augment_vflip")$train(list(task))[[1L]]$data(cols = "image")[[1L]], rbind = TRUE) 17 | expect_equal(xout$shape, c(2, 3, 64, 64)) 18 | }) 19 | 20 | test_that("nano_dogs_vs_cats", { 21 | task = nano_dogs_vs_cats() 22 | task$row_roles$use = c(2, 3) 23 | x = materialize(task$data(cols = task$feature_names)[[1L]], rbind = FALSE) 24 | expect_list(x, types = "torch_tensor") 25 | }) 26 | -------------------------------------------------------------------------------- /tests/testthat/test_nn.R: -------------------------------------------------------------------------------- 1 | test_that("nn works", { 2 | x = nn("linear", out_features = 3) 3 | expect_equal(x$id, "linear") 4 | expect_class(x, "PipeOpTorchLinear") 5 | expect_equal(x$param_set$values$out_features, 3) 6 | }) 7 | 8 | test_that("overwrite id", { 9 | obj = nn("linear", id = "abc") 10 | expect_equal(obj$id, "abc") 11 | }) 12 | 13 | test_that("unnamed arg", { 14 | graph = po("torch_ingress_num") %>>% nn("block", nn("linear",
out_features = 3), n_blocks = 2) 15 | md = graph$train(tsk("iris"))[[1L]] 16 | network = model_descriptor_to_module(md) 17 | expect_equal(network$module_list[[1]]$out_features, 3) 18 | expect_equal(network$module_list[[2]]$out_features, 3) # with n_blocks = 2, both repetitions of the unnamed nn("linear", ...) argument have out_features = 3 19 | }) 20 | -------------------------------------------------------------------------------- /vignettes/articles/callback_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Callbacks" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Callbacks} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | Below is a list of the predefined callbacks that are available in `mlr3torch`. 11 | 12 | ```{r, echo = FALSE, message = FALSE} 13 | library(data.table) 14 | library(mlr3torch) 15 | content = as.data.table(mlr3torch::mlr3torch_callbacks) 16 | content$key = sapply(content$key, function(key) sprintf("[%s](https://mlr3torch.mlr-org.com/reference/mlr_callbacks_%s.html)", key, key)) # turn each key into a Markdown link to its mlr3torch reference page 17 | content$packages = lapply(content$packages, function(ps) setdiff(ps, "mlr3torch")) # omit mlr3torch itself from each package list 18 | content = content[, .(key, label, packages)] 19 | names(content) = c("Key", "Label", "Packages") 20 | knitr::kable(content) 21 | ``` 22 | -------------------------------------------------------------------------------- /vignettes/articles/layer_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Network Layers" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Network Layers} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | Below is a list of neural network layers that are available in `mlr3torch`.
11 | 12 | ```{r, echo = FALSE, message = FALSE} 13 | library(data.table) 14 | library(mlr3torch) 15 | content = as.data.table(mlr3pipelines::mlr_pipeops) 16 | mlr3torch_ids = names(getFromNamespace("mlr3torch_pipeops", "mlr3torch")) 17 | content = content[key %in% mlr3torch_ids & grepl("^nn_", key)] # keep only mlr3torch's own PipeOps whose key starts with nn_ 18 | content$key = sapply(content$key, function(key) sprintf("[%s](https://mlr3torch.mlr-org.com/reference/mlr_pipeops_%s.html)", key, key)) # turn each key into a Markdown link to its reference page 19 | content = content[, .(key, label)] 20 | names(content) = c("Key", "Label") 21 | knitr::kable(content) 22 | ``` 23 | -------------------------------------------------------------------------------- /vignettes/articles/loss_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Loss Functions" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Loss Functions} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | The table below shows all loss functions that are available in `mlr3torch`.
11 | 12 | ```{r, echo = FALSE, message = FALSE} 13 | library(data.table) 14 | library(mlr3torch) 15 | content = as.data.table(mlr3torch::mlr3torch_losses)[, c("key", "label", "task_types")] 16 | content$key = sapply(content$key, function(key) sprintf("[%s](https://torch.mlverse.org/docs/reference/%s.html)", key, gsub("torch::", "", t_loss(key)$man))) # link each key to the torch reference page named by the loss's man field 17 | names(content) = c("Key", "Label", "Task Type") 18 | knitr::kable(content) 19 | ``` 20 | -------------------------------------------------------------------------------- /vignettes/articles/optimizer_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Optimizers" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Optimizers} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | The table below shows all optimizers that are available in `mlr3torch`: 11 | 12 | ```{r, echo = FALSE, message = FALSE} 13 | library(data.table) 14 | library(mlr3torch) 15 | content = as.data.table(mlr3torch::mlr3torch_optimizers)[, c("key", "label")] 16 | content$key = sapply(content$key, function(key) sprintf("[%s](https://torch.mlverse.org/docs/reference/%s.html)", key, gsub("torch::", "", t_opt(key)$man))) # link each key to the torch reference page named by the optimizer's man field 17 | names(content) = c("Key", "Label") 18 | knitr::kable(content) 19 | ``` 20 | -------------------------------------------------------------------------------- /vignettes/articles/preprocessing_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Preprocessing & Augmentation" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Preprocessing & Augmentation} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | The table below shows all preprocessing and augmentation operations that are available in `mlr3torch`.
11 | 12 | ```{r, echo = FALSE, message = FALSE} 13 | library(data.table) 14 | library(mlr3torch) 15 | content = as.data.table(mlr3pipelines::mlr_pipeops) 16 | content$packages = lapply(content$packages, function(ps) setdiff(ps, c("mlr3torch", "mlr3pipelines"))) 17 | mlr3torch_ids = names(getFromNamespace("mlr3torch_pipeops", "mlr3torch")) 18 | content = content[key %in% mlr3torch_ids & grepl("^(trafo|augment)", key)] 19 | content$key = sapply(content$key, function(key) sprintf("[%s](https://mlr3torch.mlr-org.com/reference/mlr_pipeops_%s.html)", key, key)) 20 | content = content[, .(key, label, packages, feature_types)] 21 | names(content) = c("Key", "Label", "Packages", "Feature Types") 22 | knitr::kable(content) 23 | ``` 24 | -------------------------------------------------------------------------------- /vignettes/articles/tabular_learner_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Tabular Learners" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Tabular Learners} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | The table below shows all tabular deep learning learners that are added by `mlr3torch`. 11 | An overview of all learners can also be found on the [mlr-org website](https://mlr-org.com/learners.html).
12 | 13 | ```{r, echo = FALSE, message = FALSE} 14 | library(data.table) 15 | library(mlr3torch) 16 | content = as.data.table(mlr3::mlr_learners)[, c("key", "label", "task_type", "feature_types", "packages")] 17 | mlr3torch_ids = names(getFromNamespace("mlr3torch_learners", "mlr3torch")) 18 | content = content[key %in% mlr3torch_ids] 19 | is_vision = sapply(content$packages, function(ps) "torchvision" %in% ps) # learners depending on torchvision are excluded here; vision_learner_list.Rmd lists them 20 | content$packages = lapply(content$packages, function(ps) setdiff(ps, c("mlr3torch", "torch", "mlr3"))) 21 | content_rest = content[!is_vision, ] 22 | knitr::kable(content_rest) 23 | ``` 24 | -------------------------------------------------------------------------------- /vignettes/articles/task_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Tasks" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Tasks} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | The table below shows all tasks that are added by `mlr3torch`. 11 | An overview of all tasks can also be found on the [mlr-org website](https://mlr-org.com/tasks.html).
12 | 13 | ```{r, echo = FALSE, message = FALSE} 14 | library(data.table) 15 | library(mlr3torch) 16 | content = as.data.table(mlr3::mlr_tasks)[, .(key, label, task_type, nrow, ncol)] 17 | mlr3torch_ids = names(getFromNamespace("mlr3torch_tasks", "mlr3torch")) 18 | content = content[key %in% mlr3torch_ids] 19 | content$key = sapply(content$key, function(key) sprintf("[%s](https://mlr3torch.mlr-org.com/reference/%s.html)", key, sub("mlr3torch::", "", tsk(key)$man))) # link each key to the reference page named by the task's man field 20 | colnames(content) = c("Key", "Label", "Task Type", "Rows", "Columns") 21 | knitr::kable(content) 22 | ``` 23 | -------------------------------------------------------------------------------- /vignettes/articles/vision_learner_list.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Vision Learners" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Vision Learners} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | Below are the image learners from [`torchvision`](https://github.com/mlverse/torchvision). 11 | An overview of all learners can also be found on the [mlr-org website](https://mlr-org.com/learners.html). 12 | 13 | ```{r, echo = FALSE, message = FALSE} 14 | library(data.table) 15 | library(mlr3torch) 16 | content = as.data.table(mlr3::mlr_learners)[, c("key", "label", "task_type", "feature_types", "packages")] 17 | mlr3torch_ids = names(getFromNamespace("mlr3torch_learners", "mlr3torch")) 18 | content = content[key %in% mlr3torch_ids] 19 | is_vision = sapply(content$packages, function(ps) "torchvision" %in% ps) # keep only learners that depend on torchvision 20 | content$packages = lapply(content$packages, function(ps) setdiff(ps, c("mlr3torch", "torch", "mlr3"))) 21 | content_vision = content[is_vision, ] 22 | knitr::kable(content_vision) 23 | ``` 24 | --------------------------------------------------------------------------------