├── benchmark_utils ├── __init__.py ├── preprocessing │ ├── requirements_preprocess.txt │ └── preprocess_twentynewsgroups.py ├── digit_no_da_experiment.py ├── scorers.py ├── generate_config │ ├── generate_base_estim_config.py │ └── generate_config_simulated.py ├── deep_base_solver.py └── extract_best_base_estim.py ├── config ├── datasets │ ├── Mushrooms.yml │ ├── Phishing.yml │ ├── Simulated.yml │ ├── 20NewsGroups.yml │ ├── AmazonReview.yml │ ├── Office31Decaf.yml │ ├── bci_projected.yml │ ├── mnist_usps_pca.yml │ └── OfficeHomeResnet.yml ├── solvers │ ├── NO_DA_SOURCE_ONLY │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── NO_DA_TARGET_ONLY │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── NO_DA_SOURCE_ONLY_BASE_ESTIM │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── DASVM │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── BCI.yml │ │ ├── Phishing.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ └── OfficeHomeResnet.yml │ ├── gaussian_reweight │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── CORAL │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── PCA │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── 
Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── discriminator_reweight │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── subspace_alignment │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── nearest_neighbor_reweight │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── density_reweight │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── linear_ot_mapping │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── ot_mapping │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── JDOT_SVC │ │ ├── Mushrooms.yml │ │ ├── Phishing.yml │ │ ├── OfficeHomeResnet.yml │ │ ├── Office31.yml │ │ ├── BCI.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ └── AmazonReview.yml │ ├── transfer_component_analysis │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── OfficeHomeResnet.yml │ │ ├── Phishing.yml │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ └── AmazonReview.yml │ ├── MMDSConS │ │ ├── 
20NewsGroups.yml │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── TarS │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── entropic_ot_mapping │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── 20NewsGroups.yml │ │ ├── mnist_usps.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── KMM │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── KLIEP │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Office31.yml │ │ ├── Simulated.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ ├── transfer_subspace_learning │ │ ├── Mushrooms.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── OfficeHomeResnet.yml │ │ ├── Phishing.yml │ │ ├── BCI.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ └── AmazonReview.yml │ ├── OTLabelProp │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml │ └── class_regularizer_ot_mapping │ │ ├── BCI.yml │ │ ├── Mushrooms.yml │ │ ├── Simulated.yml │ │ ├── Office31.yml │ │ ├── mnist_usps.yml │ │ ├── 20NewsGroups.yml │ │ ├── AmazonReview.yml │ │ ├── OfficeHomeResnet.yml │ │ └── Phishing.yml ├── example_config.yml ├── config_mnist_usps_test.yml ├── config_twentynewsgroups_test.yml ├── example_config_deep.yml ├── example_config_shallow.yml ├── best_base_estimators.yml ├── example_choleski_slurm.yaml ├── 
margaret_slurm.yaml ├── margaret_slurm_base_estimators.yaml ├── Bench_Simulated.yml └── example_config_simulated.yml ├── data └── .gitignore ├── outputs ├── base_estimators │ └── .gitignore ├── real_datasets │ └── .gitignore └── simulated_datasets │ └── .gitignore ├── visualize └── requirements_plot.txt ├── test_config.py ├── .gitignore ├── .github └── workflows │ └── main.yml ├── requirements_all.txt ├── select_base_estimators.sh ├── solvers ├── no_da_source_only_base_estim.py ├── gaussian_reweight.py ├── coral.py ├── subspace_alignment.py ├── kmm.py ├── kliep.py ├── linear_ot_mapping.py ├── ot_mapping.py ├── no_da_source_only.py ├── no_da_target_only.py ├── nearest_neighbor_reweight.py ├── discriminator_reweight.py ├── pca.py ├── deep_dan.py ├── tars.py ├── density_reweight.py ├── deep_coral.py ├── entropic_ot_mapping.py ├── deep_no_da_source_only.py ├── deep_no_da_target_only.py ├── mmdscons.py ├── transfer_component_analysis.py ├── deep_mcc.py ├── otlabelprop.py ├── deep_jdot.py ├── jdot_svc.py ├── deep_spa.py ├── transfer_subspace_learning.py ├── deep_mdd.py ├── deep_can.py └── deep_dann.py └── datasets ├── simulated.py └── office31_decaf.py /benchmark_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/datasets/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: [] 
4 | -------------------------------------------------------------------------------- /config/datasets/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/Office31Decaf.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31Decaf 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/bci_projected.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - bci_projected 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/mnist_usps_pca.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps_pca 3 | solver: [] 4 | -------------------------------------------------------------------------------- /config/datasets/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: [] 4 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /outputs/base_estimators/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 
in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /outputs/real_datasets/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /outputs/simulated_datasets/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /visualize/requirements_plot.txt: -------------------------------------------------------------------------------- 1 | # Mandatory for visualization scripts 2 | pandas==1.5.3 3 | numpy==1.24.4 4 | matplotlib==3.8.2 5 | scipy==1.10.1 6 | seaborn==0.12.2 7 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | 
-------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - 
NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | 
-------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | -------------------------------------------------------------------------------- /benchmark_utils/preprocessing/requirements_preprocess.txt: -------------------------------------------------------------------------------- 1 | # For preprocessing scripts 2 | numpy==1.24.4 3 | sentence_transformers==3.0.0 4 | skrub==0.1.1 5 | torch==2.2.1 6 | torchvision==0.17.1 7 | scikit-learn==1.5.0 -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | -------------------------------------------------------------------------------- 
/config/solvers/NO_DA_SOURCE_ONLY/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - NO_DA_SOURCE_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_TARGET_ONLY/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - NO_DA_TARGET_ONLY: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/20NewsGroups.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | -------------------------------------------------------------------------------- /config/example_config.yml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file for the experiments 2 | 3 | # We define the solvers 4 | solver: 5 | - CORAL 6 | 7 | # We define the datasets 8 | dataset: 9 | - Office31SURF 10 | 11 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | -------------------------------------------------------------------------------- /config/solvers/DASVM/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/Office31.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/NO_DA_SOURCE_ONLY_BASE_ESTIM/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - NO_DA_SOURCE_ONLY_BASE_ESTIM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | -------------------------------------------------------------------------------- /config/solvers/DASVM/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C10.0_Gamma0.1 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/Mushrooms.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C100.0_Gamma0.1 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | dasvmclassifier__max_iter: 9 | - 30 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/mnist_usps.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/DASVM/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - DASVM: 5 | param_grid: 6 | - dasvmclassifier__base_estimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | dasvmclassifier__max_iter: 9 | - 200 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- 
/config/solvers/gaussian_reweight/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | gaussianreweightadapter__reg: 9 | - 0.9 10 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/config_mnist_usps_test.yml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file for the experiments 2 | 3 | # We define the solvers 4 | solver: 5 | - NO_DA_SOURCE_ONLY 6 | - NO_DA_TARGET_ONLY 7 | - CORAL 8 | 9 | # We define the datasets 10 | dataset: 11 | - mnist_usps 12 | -------------------------------------------------------------------------------- /config/solvers/CORAL/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - LR_C2.0 13 | -------------------------------------------------------------------------------- /config/solvers/gaussian_reweight/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - gaussian_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | 
gaussianreweightadapter__reg: 9 | - auto 10 | -------------------------------------------------------------------------------- /config/solvers/CORAL/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - LR 13 | -------------------------------------------------------------------------------- /config/solvers/CORAL/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - SVC 13 | -------------------------------------------------------------------------------- /config/solvers/PCA/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/CORAL/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - LR_C0.01 13 | -------------------------------------------------------------------------------- /config/solvers/PCA/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - 
finalestimator__estimator_name: 7 | - LR 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/PCA/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/PCA/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/CORAL/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma10.0 13 | -------------------------------------------------------------------------------- /config/solvers/CORAL/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma0.01 13 | -------------------------------------------------------------------------------- /config/solvers/CORAL/AmazonReview.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - SVC_C1000.0_Gamma0.001 13 | -------------------------------------------------------------------------------- /config/solvers/PCA/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/CORAL/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma0.001 13 | -------------------------------------------------------------------------------- /config/solvers/PCA/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/config_twentynewsgroups_test.yml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file for the experiments 2 | 3 | # We define the solvers 4 | solver: 5 | - NO_DA_SOURCE_ONLY 6 | - NO_DA_TARGET_ONLY 7 | - 
CORAL 8 | 9 | # We define the datasets 10 | dataset: 11 | - 20NewsGroups[preprocessing='sentence_transformers'] 12 | -------------------------------------------------------------------------------- /config/solvers/CORAL/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - CORAL: 5 | param_grid: 6 | - coraladapter__assume_centered: 7 | - false 8 | - true 9 | coraladapter__reg: 10 | - auto 11 | finalestimator__estimator_name: 12 | - XGB_subsample0.8_colsample0.65_maxdepth20 13 | -------------------------------------------------------------------------------- /config/solvers/PCA/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/PCA/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - LR_C2.0 13 | -------------------------------------------------------------------------------- 
/config/solvers/PCA/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - PCA: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | pca__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - LR 13 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - SVC 13 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | 
solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - LR_C0.01 13 | -------------------------------------------------------------------------------- /config/example_config_deep.yml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file 2 | # for the deep experiments 3 | 4 | # We define the objective parameters 5 | objective: 6 | - 'SKADA Domain Adaptation Benchmark': 7 | n_splits_data: 1 8 | test_size_data: 0.2 9 | random_state: 0 10 | solver: 11 | - "deep_*" 12 | 13 | 14 | -------------------------------------------------------------------------------- /config/example_config_shallow.yml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file 2 | # for the shallow experiments 3 | 4 | # We define the objective parameters 5 | objective: 6 | - 'SKADA Domain Adaptation Benchmark': 7 | n_splits_data: 5 8 | test_size_data: 0.2 9 | random_state: 0 10 | solver: 11 | - "^(?!deep_)*" 12 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 
8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma10.0 13 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma0.01 13 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | 
finalestimator__estimator_name: 12 | - SVC_C1000.0_Gamma0.001 13 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /test_config.py: -------------------------------------------------------------------------------- 1 | import sys # noqa: F401 2 | 3 | import pytest # noqa: F401 4 | 5 | 6 | def check_test_solver_install(solver_class): 7 | """Hook called in `test_solver_install`. 8 | 9 | If one solver needs to be skip/xfailed on some 10 | particular architecture, call pytest.xfail when 11 | detecting the situation. 
12 | """ 13 | pass 14 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - LR_C2.0 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - SVC_C10.0_Gamma0.001 13 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | 
-------------------------------------------------------------------------------- /config/solvers/subspace_alignment/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/discriminator_reweight/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - discriminator_reweight: 5 | param_grid: 6 | - discriminatorreweightadapter__domain_classifier__estimator_name: 7 | - LR 8 | - SVC 9 | - KNN 10 | - XGB 11 | finalestimator__estimator_name: 12 | - XGB_subsample0.8_colsample0.65_maxdepth20 13 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | 
-------------------------------------------------------------------------------- /config/solvers/subspace_alignment/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - LR 16 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - SVC 16 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | 
-------------------------------------------------------------------------------- /config/solvers/density_reweight/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - LR_C0.01 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/subspace_alignment/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - subspace_alignment: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | subspacealignmentadapter__n_components: 9 | - 1 10 | - 2 11 | - 5 12 | - 10 13 | - 20 14 | - 50 15 | - 100 16 | 
-------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 
16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | 
finalestimator__estimator_name: 15 | - SVC_C10.0_Gamma10.0 16 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - SVC_C10.0_Gamma0.01 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.5 8 | jdotclassifier__base_estimator__estimator_name: 9 | - SVC 10 | 
jdotclassifier__n_iter_max: 11 | - 3 12 | jdotclassifier__thr_weights: 13 | - 1.0e-07 14 | jdotclassifier__tol: 15 | - 1.0e-06 16 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - SVC_C1000.0_Gamma0.001 16 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - nearest_neighbor_reweight: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - SVC_C10.0_Gamma0.001 16 | -------------------------------------------------------------------------------- /config/solvers/nearest_neighbor_reweight/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - nearest_neighbor_reweight: 5 | 
param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | nearestneighborreweightadapter__laplace_smoothing: 9 | - true 10 | - false 11 | nearestneighborreweightadapter__n_neighbors: 12 | - 1 13 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Cache directories 2 | *.ipynb_checkpoints 3 | *.ipynb 4 | .pytest_cache 5 | __pycache__ 6 | *__cache__ 7 | *.egg-info 8 | .coverage 9 | **/outputs 10 | joblib/ 11 | data/ 12 | slurm_output.txt 13 | 14 | # IDE specific folders 15 | .vscode 16 | 17 | # Config files 18 | benchopt.ini 19 | 20 | .DS_Store 21 | coverage.xml 22 | 23 | # Results 24 | benchopt_run/ 25 | 20newsgroup/ 26 | mnist_usps/ 27 | -------------------------------------------------------------------------------- /config/solvers/density_reweight/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - density_reweight: 5 | param_grid: 6 | - densityreweightadapter__weight_estimator__bandwidth: 7 | - 0.01 8 | - 0.1 9 | - 1.0 10 | - 10.0 11 | - 100.0 12 | - scott 13 | - silverman 14 | finalestimator__estimator_name: 15 | - XGB_subsample0.8_colsample0.65_maxdepth20 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/20NewsGroups.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.5 8 | jdotclassifier__base_estimator__estimator_name: 9 | - SVC_C100.0_Gamma0.1 10 | jdotclassifier__n_iter_max: 11 | - 3 12 | jdotclassifier__thr_weights: 13 | - 1.0e-07 14 | jdotclassifier__tol: 15 | - 1.0e-06 16 | 
-------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | 
otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/linear_ot_mapping/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - linear_ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | linearotmappingadapter__bias: 9 | - true 10 | - false 11 | linearotmappingadapter__reg: 12 | - 1.0e-08 13 | - 1.0e-06 14 | - 0.1 15 | - 1 16 | - 10 17 | -------------------------------------------------------------------------------- /config/solvers/ot_mapping/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - ot_mapping: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | otmappingadapter__max_iter: 9 | - 1000000 10 | otmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | otmappingadapter__norm: 15 | - median 16 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.5 8 | jdotclassifier__base_estimator__estimator_name: 9 | - SVC_C10.0_Gamma0.001 10 | jdotclassifier__n_iter_max: 11 | - 3 12 | jdotclassifier__thr_weights: 13 | - 1.0e-07 14 | jdotclassifier__tol: 15 | - 1.0e-06 16 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - JDOT_SVC: 5 
| param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC_C10.0_Gamma0.1 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 5 15 | - 10 16 | 
- 20 17 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 5 15 | - 10 16 | - 20 17 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC_C10.0_Gamma0.01 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC_C10.0_Gamma10.0 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | 
- OfficeHomeResnet 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 5 15 | - 10 16 | - 20 17 | -------------------------------------------------------------------------------- /config/solvers/JDOT_SVC/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - JDOT_SVC: 5 | param_grid: 6 | - jdotclassifier__alpha: 7 | - 0.1 8 | - 0.3 9 | - 0.5 10 | - 0.7 11 | - 0.9 12 | jdotclassifier__base_estimator__estimator_name: 13 | - SVC_C1000.0_Gamma0.001 14 | jdotclassifier__n_iter_max: 15 | - 100 16 | jdotclassifier__thr_weights: 17 | - 1.0e-07 18 | jdotclassifier__tol: 19 | - 1.0e-06 20 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | mmdlsconsmappingadapter__gamma: 9 | - 1 10 | mmdlsconsmappingadapter__max_iter: 11 | - 4 12 | mmdlsconsmappingadapter__reg_k: 13 | - 1.0e-08 14 | mmdlsconsmappingadapter__reg_m: 15 | - 1.0e-08 16 | mmdlsconsmappingadapter__tol: 17 | - 1.0e-05 18 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | 
transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 5 15 | - 10 16 | - 20 17 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 1 15 | - 2 16 | - 5 17 | - 10 18 | - 20 19 | - 50 20 | - 100 21 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 1 15 | - 2 16 | - 5 17 | - 10 18 | - 20 19 | - 50 20 | - 100 21 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 1 15 | - 2 16 | - 5 17 | - 10 18 | - 20 19 | - 50 20 | - 100 21 | 
-------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 1 15 | - 2 16 | - 5 17 | - 10 18 | - 20 19 | - 50 20 | - 100 21 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/Simulated.yml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/TarS/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/TarS/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - TarS: 5 | param_grid: 6 | 
- finalestimator__estimator_name: 7 | - LR 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/TarS/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/TarS/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/transfer_component_analysis/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - transfer_component_analysis: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | 
transfercomponentanalysisadapter__kernel: 9 | - rbf 10 | transfercomponentanalysisadapter__mu: 11 | - 10 12 | - 100 13 | transfercomponentanalysisadapter__n_components: 14 | - 1 15 | - 2 16 | - 5 17 | - 10 18 | - 20 19 | - 50 20 | - 100 21 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/TarS/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - 
null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/TarS/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/TarS/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | 
mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - LR_C2.0 22 | -------------------------------------------------------------------------------- /config/solvers/MMDSConS/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - MMDSConS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | mmdlsconsmappingadapter__gamma: 9 | - 0.01 10 | - 0.1 11 | - 1 12 | - 10 13 | - 100 14 | mmdlsconsmappingadapter__max_iter: 15 | - 20 16 | mmdlsconsmappingadapter__reg_k: 17 | - 1.0e-08 18 | mmdlsconsmappingadapter__reg_m: 19 | - 1.0e-08 20 | mmdlsconsmappingadapter__tol: 21 | - 1.0e-05 22 | -------------------------------------------------------------------------------- /config/solvers/TarS/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 
1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/KMM/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/TarS/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - TarS: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | mmdtarsreweightadapter__gamma: 9 | - 0.0001 10 | - 0.001 11 | - 0.01 12 | - 0.1 13 | - 1.0 14 | - 10.0 15 | - 100.0 16 | - 1000.0 17 | - null 18 | mmdtarsreweightadapter__max_iter: 19 | - 1000 20 | mmdtarsreweightadapter__reg: 21 | - 1.0e-06 22 | mmdtarsreweightadapter__tol: 23 | - 1.0e-06 24 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 
21 | - LR 22 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - SVC 22 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - LR_C0.01 22 | -------------------------------------------------------------------------------- /config/solvers/KMM/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | 
-------------------------------------------------------------------------------- /config/solvers/KMM/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/KMM/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - SVC_C10.0_Gamma10.0 22 | 
-------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - SVC_C10.0_Gamma0.01 22 | -------------------------------------------------------------------------------- /config/solvers/KMM/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 
21 | - SVC_C1000.0_Gamma0.001 22 | -------------------------------------------------------------------------------- /config/solvers/KMM/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - SVC_C10.0_Gamma0.001 22 | -------------------------------------------------------------------------------- /config/solvers/KMM/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | 
kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/entropic_ot_mapping/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - entropic_ot_mapping: 5 | param_grid: 6 | - entropicotmappingadapter__max_iter: 7 | - 1000 8 | entropicotmappingadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | entropicotmappingadapter__norm: 13 | - median 14 | entropicotmappingadapter__reg_e: 15 | - 0.1 16 | - 0.5 17 | - 1.0 18 | entropicotmappingadapter__tol: 19 | - 1.0e-06 20 | finalestimator__estimator_name: 21 | - XGB_subsample0.8_colsample0.65_maxdepth20 22 | -------------------------------------------------------------------------------- /config/solvers/KMM/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/KMM/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - KMM: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | kmmreweightadapter__B: 9 | - 1000.0 10 | kmmreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 
19 | - null 20 | kmmreweightadapter__max_iter: 21 | - 1000 22 | kmmreweightadapter__smooth_weights: 23 | - false 24 | kmmreweightadapter__tol: 25 | - 1.0e-06 26 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 
0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | 
- KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/KLIEP/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | 
-------------------------------------------------------------------------------- /config/solvers/KLIEP/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - KLIEP: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | kliepreweightadapter__cv: 9 | - 5 10 | kliepreweightadapter__gamma: 11 | - 0.0001 12 | - 0.001 13 | - 0.01 14 | - 0.1 15 | - 1.0 16 | - 10.0 17 | - 100.0 18 | - 1000.0 19 | - auto 20 | - scale 21 | kliepreweightadapter__max_iter: 22 | - 1000 23 | kliepreweightadapter__n_centers: 24 | - 100 25 | kliepreweightadapter__random_state: 26 | - 0 27 | kliepreweightadapter__tol: 28 | - 1.0e-06 29 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 4 14 | transfersubspacelearningadapter__mu: 15 | - 1 16 | transfersubspacelearningadapter__n_components: 17 | - 5 18 | - 10 19 | - 20 20 | transfersubspacelearningadapter__reg: 21 | - 0.0001 22 | transfersubspacelearningadapter__tol: 23 | - 0.0001 24 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | create: 8 | tags: 9 | - '**' 10 | pull_request: 11 | branches: 12 | - main 13 | schedule: 14 | # Run every day at 7:42am UTC. 
15 | - cron: '42 7 * * *' 16 | 17 | jobs: 18 | benchopt_dev: 19 | uses: benchopt/template_benchmark/.github/workflows/test_benchmarks.yml@main 20 | with: 21 | benchopt_branch: benchopt@main 22 | benchopt_release: 23 | uses: benchopt/template_benchmark/.github/workflows/test_benchmarks.yml@main 24 | with: 25 | benchopt_version: latest 26 | lint: 27 | uses: benchopt/template_benchmark/.github/workflows/lint_benchmarks.yml@main -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - LR_C2.0 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - LR_C2.0 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - LR 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - LR 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- 
/config/solvers/OTLabelProp/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - SVC 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - SVC 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /requirements_all.txt: -------------------------------------------------------------------------------- 1 | # Mandatory to run benchmarks 2 | benchopt==1.6.0 3 | 4 | # Mandatory to use slurm 5 | #benchopt[slurm]==1.6.0 6 | 7 | # Mandatory for visualization scripts 8 | seaborn==0.12.2 9 | 10 | # Already required-by benchopt 11 | # pandas==1.5.3 12 | # numpy==1.26.0 13 | # matplotlib==3.8.2 14 | # scipy==1.10.1 15 | 16 | # For preprocessing scripts 17 | sentence_transformers==3.0.0 18 | skrub==0.1.1 19 | torch==2.2.1 20 | torchvision==0.17.1 21 | scikit-learn==1.5.0 22 | 23 | # Imported with benchopt install 24 | skada==0.4.0 25 | torch==2.2.1 26 | torchvision==0.17.1 27 | #POT==0.9.1 28 | xgboost==2.0.3 29 | mne==1.6.1 30 | braindecode==0.8.1 31 | moabb==0.5 32 | pyriemann==0.3 33 | sentence_transformers==3.0.0 34 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - LR_C0.01 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | 
otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - LR_C0.01 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 4 14 | transfersubspacelearningadapter__mu: 15 | - 1 16 | transfersubspacelearningadapter__n_components: 17 | - 5 18 | - 10 19 | - 20 20 | transfersubspacelearningadapter__reg: 21 | - 0.0001 22 | transfersubspacelearningadapter__tol: 23 | - 0.0001 24 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 4 14 | transfersubspacelearningadapter__mu: 15 | - 1 16 | transfersubspacelearningadapter__n_components: 17 | - 5 18 | - 10 19 | - 20 20 | transfersubspacelearningadapter__reg: 21 | - 0.0001 22 | transfersubspacelearningadapter__tol: 23 | - 0.0001 24 | 
-------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 4 14 | transfersubspacelearningadapter__mu: 15 | - 1 16 | transfersubspacelearningadapter__n_components: 17 | - 5 18 | - 10 19 | - 20 20 | transfersubspacelearningadapter__reg: 21 | - 0.0001 22 | transfersubspacelearningadapter__tol: 23 | - 0.0001 24 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 4 14 | transfersubspacelearningadapter__mu: 15 | - 1 16 | transfersubspacelearningadapter__n_components: 17 | - 5 18 | - 10 19 | - 20 20 | transfersubspacelearningadapter__reg: 21 | - 0.0001 22 | transfersubspacelearningadapter__tol: 23 | - 0.0001 24 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.01 8 | 
otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - SVC_C10.0_Gamma0.01 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma10.0 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - SVC_C10.0_Gamma10.0 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - SVC_C1000.0_Gamma0.001 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | 
-------------------------------------------------------------------------------- /config/solvers/OTLabelProp/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - SVC_C10.0_Gamma0.001 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - SVC_C10.0_Gamma0.001 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/OTLabelProp/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - OTLabelProp: 5 | param_grid: 6 | - - finalestimator__estimator_name: 7 | - XGB_subsample0.8_colsample0.65_maxdepth20 8 | otlabelpropadapter__metric: 9 | - sqeuclidean 10 | - cosine 11 | - cityblock 12 | otlabelpropadapter__n_iter_max: 13 | - 10000 14 | otlabelpropadapter__reg: 15 | - null 16 | - finalestimator__estimator_name: 17 | - XGB_subsample0.8_colsample0.65_maxdepth20 18 | otlabelpropadapter__metric: 19 | - sqeuclidean 20 | - cosine 21 | - cityblock 22 | otlabelpropadapter__n_iter_max: 23 | - 100 24 | otlabelpropadapter__reg: 25 | - 0.1 26 | - 1 27 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C2.0 8 | 
transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 300 14 | transfersubspacelearningadapter__mu: 15 | - 0.1 16 | - 1 17 | - 10 18 | transfersubspacelearningadapter__n_components: 19 | - 1 20 | - 2 21 | - 5 22 | - 10 23 | - 20 24 | - 50 25 | - 100 26 | transfersubspacelearningadapter__reg: 27 | - 0.0001 28 | transfersubspacelearningadapter__tol: 29 | - 0.0001 30 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 300 14 | transfersubspacelearningadapter__mu: 15 | - 0.1 16 | - 1 17 | - 10 18 | transfersubspacelearningadapter__n_components: 19 | - 1 20 | - 2 21 | - 5 22 | - 10 23 | - 20 24 | - 50 25 | - 100 26 | transfersubspacelearningadapter__reg: 27 | - 0.0001 28 | transfersubspacelearningadapter__tol: 29 | - 0.0001 30 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - LR_C0.01 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 300 14 | transfersubspacelearningadapter__mu: 15 | - 0.1 16 | - 1 17 | - 10 18 | 
transfersubspacelearningadapter__n_components: 19 | - 1 20 | - 2 21 | - 5 22 | - 10 23 | - 20 24 | - 50 25 | - 100 26 | transfersubspacelearningadapter__reg: 27 | - 0.0001 28 | transfersubspacelearningadapter__tol: 29 | - 0.0001 30 | -------------------------------------------------------------------------------- /config/solvers/transfer_subspace_learning/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - transfer_subspace_learning: 5 | param_grid: 6 | - finalestimator__estimator_name: 7 | - SVC_C1000.0_Gamma0.001 8 | transfersubspacelearningadapter__base_method: 9 | - flda 10 | transfersubspacelearningadapter__length_scale: 11 | - 2 12 | transfersubspacelearningadapter__max_iter: 13 | - 300 14 | transfersubspacelearningadapter__mu: 15 | - 0.1 16 | - 1 17 | - 10 18 | transfersubspacelearningadapter__n_components: 19 | - 1 20 | - 2 21 | - 5 22 | - 10 23 | - 20 24 | - 50 25 | - 100 26 | transfersubspacelearningadapter__reg: 27 | - 0.0001 28 | transfersubspacelearningadapter__tol: 29 | - 0.0001 30 | -------------------------------------------------------------------------------- /config/best_base_estimators.yml: -------------------------------------------------------------------------------- 1 | 20NewsGroups: 2 | Best: SVC_C10.0_Gamma10.0 3 | BestSVC: SVC_C10.0_Gamma10.0 4 | AmazonReview: 5 | Best: SVC_C1000.0_Gamma0.001 6 | BestSVC: SVC_C1000.0_Gamma0.001 7 | BCI: 8 | Best: LR_C2.0 9 | BestSVC: SVC_C10.0_Gamma0.1 10 | Mushrooms: 11 | Best: LR 12 | BestSVC: SVC 13 | Office31: 14 | Best: LR_C0.01 15 | BestSVC: SVC 16 | Office31Decaf: 17 | Best: LR_C0.01 18 | BestSVC: SVC 19 | OfficeHomeResnet: 20 | Best: SVC_C10.0_Gamma0.001 21 | BestSVC: SVC_C10.0_Gamma0.001 22 | Phishing: 23 | Best: XGB_subsample0.8_colsample0.65_maxdepth20 24 | BestSVC: SVC_C10.0_Gamma0.1 25 | Simulated: 26 | Best: SVC 27 | BestSVC: SVC 28 | bci_projected: 29 | Best: LR_C2.0 30 | BestSVC: 
SVC_C10.0_Gamma0.1 31 | mnist_usps: 32 | Best: SVC_C10.0_Gamma0.01 33 | BestSVC: SVC_C10.0_Gamma0.01 34 | mnist_usps_pca: 35 | Best: SVC_C10.0_Gamma0.01 36 | BestSVC: SVC_C10.0_Gamma0.01 37 | -------------------------------------------------------------------------------- /config/example_choleski_slurm.yaml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file to launch the experiments on a SLURM cluster 2 | # Before running the experiments, you need to change the following parameters: 3 | # - mail-user: the email address to receive notifications 4 | # - source activate env_name: the name of the conda environment to activate 5 | 6 | slurm_time: 01:00:00 # max runtime 1 hour 7 | slurm_additional_parameters: 8 | job-name: skada-bench # Job name 9 | output: slurm_output.txt # Output file 10 | partition: cpu_shared 11 | mail-type: ALL 12 | mail-user: name@email.com # TO CHANGE TO YOUR EMAIL 13 | ntasks: 1 # Number of tasks per job 14 | cpus-per-task: 5 # requires 5 CPUs per job 15 | slurm_setup: # sbatch script commands added before the main job 16 | - module purge 17 | - echo "current directory is $(pwd)" 18 | - module load anaconda3 19 | - source activate env_name # TO CHANGE TO YOUR ENVIRONMENT NAME 20 | -------------------------------------------------------------------------------- /config/margaret_slurm.yaml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file to launch the experiments on a SLURM cluster 2 | # Before running the experiments, you need to change the following parameters: 3 | # - mail-user: the email address to receive notifications 4 | # - source activate env_name: the name of the conda environment to activate 5 | 6 | # slurm_time: 7-00:00:00 # max runtime 7 days 7 | slurm_additional_parameters: 8 | job-name: skada-bench # Job name 9 | output: slurm_output.txt # Output file 10 | partition: 
normal,parietal 11 | mail-type: ALL 12 | mail-user: name@email.com # TO CHANGE TO YOUR EMAIL 13 | ntasks: 1 # Number of tasks per job 14 | cpus-per-task: 5 # requires 5 CPUs per job 15 | slurm_setup: # sbatch script commands added before the main job 16 | - module purge 17 | - echo "current directory is $(pwd)" 18 | - module load anaconda3 19 | - source activate env # TO CHANGE TO YOUR ENVIRONMENT NAME 20 | -------------------------------------------------------------------------------- /config/margaret_slurm_base_estimators.yaml: -------------------------------------------------------------------------------- 1 | # This is an example of a configuration file to launch the experiments on a SLURM cluster 2 | # Before running the experiments, you need to change the following parameters: 3 | # - mail-user: the email address to receive notifications 4 | # - source activate env_name: the name of the conda environment to activate 5 | 6 | slurm_time: 7-00:00:00 # max runtime 7 days 7 | slurm_additional_parameters: 8 | job-name: skada-bench # Job name 9 | output: slurm_output.txt # Output file 10 | partition: normal,parietal 11 | mail-type: ALL 12 | mail-user: name@email.com # TO CHANGE TO YOUR EMAIL 13 | ntasks: 1 # Number of tasks per job 14 | cpus-per-task: 5 # requires 5 CPUs per job 15 | slurm_setup: # sbatch script commands added before the main job 16 | - module purge 17 | - echo "current directory is $(pwd)" 18 | - module load anaconda3 19 | - source activate env # TO CHANGE TO YOUR ENVIRONMENT NAME 20 | -------------------------------------------------------------------------------- /config/Bench_Simulated.yml: -------------------------------------------------------------------------------- 1 | # This is the Config file for the Simulated dataset 2 | 3 | # We define the solvers 4 | solver: 5 | # Baseline 6 | - NO_DA_TARGET_ONLY 7 | - NO_DA_SOURCE_ONLY 8 | 9 | # Reweighting 10 | - KLIEP 11 | - density_reweight 12 | - gaussian_reweight 13 | - discriminator_reweight 14 
| - KMM 15 | - nearest_neighbor_reweight 16 | - TarS 17 | 18 | # Mapping 19 | - CORAL 20 | - MMDSConS 21 | - ot_mapping 22 | - entropic_ot_mapping 23 | - linear_ot_mapping 24 | - class_regularizer_ot_mapping 25 | 26 | # Subspace 27 | - PCA 28 | - subspace_alignment 29 | - transfer_component_analysis 30 | - transfer_joint_matching 31 | - transfer_subspace_learning 32 | 33 | # Other 34 | - DASVM 35 | - JDOT_SVC 36 | - OTLabelProp 37 | 38 | # We define the datasets 39 | dataset: 40 | - Simulated 41 | # debug 42 | # - Simulated[random_state=0,shift=covariate_shift,label=binary] 43 | 44 | -------------------------------------------------------------------------------- /config/example_config_simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated[label=binary,n_samples_source=100,n_samples_target=100,random_state=0,shift=covariate_shift] 3 | solver: 4 | # - KLIEP[param_grid=None] 5 | - KMM[param_grid=default] 6 | # - class_regularizer_ot_mapping[param_grid=None] 7 | # - CORAL[param_grid=default] 8 | # param_grid: 9 | # - finalestimator__estimator_name: 10 | # - LR 11 | # 'coraladapter__reg': 12 | # - auto 13 | # - discriminator_reweight[param_grid=None] 14 | # - entropic_ot_mapping[param_grid=None] 15 | # - gaussian_reweight[param_grid=None] 16 | # - linear_ot_mapping[param_grid=None] 17 | # - MMDSConS[param_grid=None] 18 | # - NO_DA_SOURCE_ONLY[param_grid=None] 19 | # - NO_DA_TARGET_ONLY[param_grid=None] 20 | # - ot_mapping[param_grid=None] 21 | # - PCA[param_grid=None] 22 | # - subspace_alignment[param_grid=None] 23 | # - TarS[param_grid=None] 24 | # - transfer_component_analysis[param_grid=None] 25 | # - transfer_joint_matching[param_grid=None] 26 | -------------------------------------------------------------------------------- /select_base_estimators.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Default SLURM config file 4 | 
SLURM_CONFIG="" 5 | 6 | # Parse command line arguments 7 | while [[ "$#" -gt 0 ]]; do 8 | case $1 in 9 | --slurm) SLURM_CONFIG="$2"; shift ;; 10 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 11 | esac 12 | shift 13 | done 14 | 15 | # Set the config file for base estimator experiments for No DA Source Only 16 | python benchmark_utils/generate_base_estim_config.py 17 | 18 | # Run base estimator experiments for No DA Source Only. Store the results in `results_base_estimators/` 19 | if [ -n "$SLURM_CONFIG" ]; then 20 | benchopt run --config config/find_best_base_estimators_per_dataset.yml --slurm $SLURM_CONFIG --no-plot --no-html 21 | else 22 | benchopt run --config config/find_best_base_estimators_per_dataset.yml --no-plot --no-html 23 | fi 24 | 25 | # Clean the results and store them in `results_base_estimators/results_base_estim_experiments.csv` 26 | python visualize/convert_benchopt_output_to_readable_csv.py --domain source --directory outputs --output results_base_estimators --file_name results_base_estim_experiments 27 | 28 | # Find the best base estimator and best SVC per dataset and store them in `config/best_base_estimators.yml` 29 | python benchmark_utils/extract_best_base_estim.py 30 | 31 | # Update the config file per dataset with the best base estimator as final estimator 32 | python benchmark_utils/generate_config_per_dataset.py 33 | -------------------------------------------------------------------------------- /solvers/no_da_source_only_base_estim.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.base_solver import DASolver, FinalEstimator 8 | from skada.base import SelectSource 9 | from skada import make_da_pipeline 10 | from skada.metrics import SupervisedScorer 11 | 12 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'NO_DA_SOURCE_ONLY_BASE_ESTIM' 23 | 24 | default_param_grid = { 25 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"] 26 | } 27 | 28 | def get_estimator(self, **kwargs): 29 | self.criterions = { 30 | 'supervised': SupervisedScorer(), 31 | } 32 | 33 | # The estimator passed should have a 'predict_proba' method. 34 | return make_da_pipeline( 35 | ('finalestimator', SelectSource(FinalEstimator())), 36 | ) 37 | -------------------------------------------------------------------------------- /solvers/gaussian_reweight.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import GaussianReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'gaussian_reweight' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'gaussianreweightadapter__reg': ["auto"], 27 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 28 | } 29 | 30 | def get_estimator(self, **kwargs): 31 | # The estimator passed should have a 'predict_proba' method. 32 | return make_da_pipeline( 33 | GaussianReweightAdapter(), 34 | FinalEstimator(), 35 | ) 36 | -------------------------------------------------------------------------------- /solvers/coral.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import CORALAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'CORAL' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'coraladapter__reg': ["auto"], 27 | 'coraladapter__assume_centered': [False, True], 28 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 29 | } 30 | 31 | def get_estimator(self, **kwargs): 32 | # return CORAL() 33 | # The estimator passed should have a 'predict_proba' method. 34 | return make_da_pipeline( 35 | CORALAdapter(), 36 | FinalEstimator(), 37 | ) 38 | -------------------------------------------------------------------------------- /solvers/subspace_alignment.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import SubspaceAlignmentAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'subspace_alignment' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'subspacealignmentadapter__n_components': [1, 2, 5, 10, 20, 50, 100], 27 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 28 | } 29 | 30 | def get_estimator(self, **kwargs): 31 | # The estimator passed should have a 'predict_proba' method. 
32 | return make_da_pipeline( 33 | SubspaceAlignmentAdapter(), 34 | FinalEstimator(), 35 | ) 36 | -------------------------------------------------------------------------------- /benchmark_utils/digit_no_da_experiment.py: -------------------------------------------------------------------------------- 1 | """ File to test the models on the digit dataset 2 | - preprocessed data are available in data/digit.pkl file 3 | - a model is trained on the preprocessed data 4 | - Hyperparameters are tuned using GridSearchCV 5 | - Model is evaluated on the test data 6 | """ 7 | 8 | import pickle 9 | from sklearn.decomposition import PCA # noqa: F401 10 | from sklearn.model_selection import GridSearchCV, train_test_split 11 | from sklearn.pipeline import make_pipeline 12 | from sklearn.svm import SVC 13 | 14 | 15 | # Load the preprocessed data 16 | with open('data/digit.pkl', 'rb') as f: 17 | data = pickle.load(f) 18 | mnist = data['svhn'] 19 | X, y = mnist['X'], mnist['y'] 20 | 21 | # Split the data into train and test sets 22 | X_train, X_test, y_train, y_test = train_test_split( 23 | X, y, 24 | test_size=0.2, 25 | random_state=42, 26 | stratify=y 27 | ) 28 | X_train = X_train[::5] 29 | y_train = y_train[::5] 30 | print(f"Train data shape: {X_train.shape}") 31 | print(f"Test data shape: {X_test.shape}") 32 | 33 | # Create a pipeline 34 | # pipe = make_pipeline(PCA(whiten=True), SVC(kernel='rbf', C=10, gamma=0.001)) 35 | pipe = make_pipeline(SVC(kernel='rbf', C=100, gamma=0.01)) 36 | 37 | # Perform GridSearchCV 38 | param_grid = { 39 | 'svc__C': [100], 40 | 'svc__gamma': [0.01] 41 | } 42 | grid = GridSearchCV(pipe, param_grid, cv=5, n_jobs=-1) 43 | grid.fit(X_train, y_train) 44 | print(f"Best parameters: {grid.best_params_}") 45 | print(f"Training accuracy: {grid.score(X_train, y_train)*100:.2f}%") 46 | print(f"Test accuracy: {grid.score(X_test, y_test)*100:.2f}%") 47 | -------------------------------------------------------------------------------- /solvers/kmm.py: 
-------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 6 | with safe_import_context() as import_ctx: 7 | from skada import KMMReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'KMM' 22 | 23 | default_param_grid = { 24 | 'kmmreweightadapter__gamma': [ 25 | 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100., 1000., None 26 | ], 27 | 'kmmreweightadapter__B': [1000.0], 28 | 'kmmreweightadapter__tol': [1e-6], 29 | 'kmmreweightadapter__max_iter': [1000], 30 | 'kmmreweightadapter__smooth_weights': [False], 31 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 32 | } 33 | 34 | def get_estimator(self, **kwargs): 35 | # The estimator passed should have a 'predict_proba' method. 36 | return make_da_pipeline( 37 | KMMReweightAdapter(), 38 | FinalEstimator(), 39 | ) 40 | -------------------------------------------------------------------------------- /solvers/kliep.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import KLIEPReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'KLIEP' 22 | 23 | default_param_grid = { 24 | 'kliepreweightadapter__gamma': [ 25 | 0.0001, 0.001, 0.01, 0.1, 1., 10., 100., 1000., 'auto', 'scale' 26 | ], 27 | 'kliepreweightadapter__n_centers': [100], 28 | 'kliepreweightadapter__cv': [5], 29 | 'kliepreweightadapter__tol': [1e-6], 30 | 'kliepreweightadapter__max_iter': [1000], 31 | 'kliepreweightadapter__random_state': [0], 32 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 33 | } 34 | 35 | def get_estimator(self, **kwargs): 36 | return make_da_pipeline( 37 | KLIEPReweightAdapter(gamma=None), 38 | FinalEstimator(), 39 | ) 40 | -------------------------------------------------------------------------------- /solvers/linear_ot_mapping.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import LinearOTMappingAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'linear_ot_mapping' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'linearotmappingadapter__reg': [1e-08, 1e-06, 0.1, 1, 10], 27 | 'linearotmappingadapter__bias': [True, False], 28 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 29 | } 30 | 31 | def get_estimator(self, **kwargs): 32 | # The estimator passed should have a 'predict_proba' method. 33 | return make_da_pipeline( 34 | LinearOTMappingAdapter(), 35 | FinalEstimator(), 36 | ) 37 | -------------------------------------------------------------------------------- /solvers/ot_mapping.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import OTMappingAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'ot_mapping' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'otmappingadapter__metric': ['sqeuclidean', 'cosine', 'cityblock'], 27 | 'otmappingadapter__norm': ['median'], 28 | 'otmappingadapter__max_iter': [1_000_000], 29 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 30 | } 31 | 32 | def get_estimator(self, **kwargs): 33 | # The estimator passed should have a 'predict_proba' method. 34 | return make_da_pipeline( 35 | OTMappingAdapter(), 36 | FinalEstimator(), 37 | ) 38 | -------------------------------------------------------------------------------- /solvers/no_da_source_only.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.base_solver import DASolver, FinalEstimator 8 | from skada.base import SelectSource 9 | from skada import make_da_pipeline 10 | 11 | from benchmark_utils.scorers import SupervisedScorer 12 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'NO_DA_SOURCE_ONLY' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = { 28 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"] 29 | } 30 | 31 | def get_estimator(self, **kwargs): 32 | self.criterions = { 33 | 'supervised': SupervisedScorer(), 34 | } 35 | # The estimator passed should have a 'predict_proba' method. 36 | return make_da_pipeline( 37 | ('finalestimator', SelectSource(FinalEstimator())), 38 | ) 39 | -------------------------------------------------------------------------------- /solvers/no_da_target_only.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.base_solver import DASolver, FinalEstimator 8 | from skada.base import SelectTarget 9 | from skada import make_da_pipeline 10 | 11 | from benchmark_utils.scorers import SupervisedScorer 12 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'NO_DA_TARGET_ONLY' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = { 28 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"] 29 | } 30 | 31 | def get_estimator(self, **kwargs): 32 | self.criterions = { 33 | 'supervised': SupervisedScorer(), 34 | } 35 | # The estimator passed should have a 'predict_proba' method. 36 | return make_da_pipeline( 37 | ('finalestimator', SelectTarget(FinalEstimator())), 38 | ) 39 | -------------------------------------------------------------------------------- /solvers/nearest_neighbor_reweight.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import NearestNeighborReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'nearest_neighbor_reweight' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'nearestneighborreweightadapter__n_neighbors': [1], 27 | 'nearestneighborreweightadapter__laplace_smoothing': [True, False], 28 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 29 | } 30 | 31 | def get_estimator(self, **kwargs): 32 | # The estimator passed should have a 'predict_proba' method. 33 | return make_da_pipeline( 34 | NearestNeighborReweightAdapter(), 35 | FinalEstimator(), 36 | ) 37 | -------------------------------------------------------------------------------- /benchmark_utils/scorers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script imports and initializes various scoring metrics used for 3 | evaluating domain adaptation techniques. 4 | 5 | Adding a New Scorer 6 | ------------------- 7 | 8 | To add a new scorer to the script: 9 | 10 | 1. **Import the Scorer**: Add the import statement for the new scorer class. 11 | 2. **Update `CRITERIONS` Dictionary**: Add a new entry in the `CRITERIONS` 12 | dictionary with the scorer's name and its initialized instance. 
13 | 14 | This method ensures that all scorers are organized and easily accessible 15 | for evaluating domain adaptation methods. 16 | """ 17 | from benchopt import safe_import_context 18 | 19 | # Protect the import with `safe_import_context()`. This allows: 20 | # - skipping import to speed up autocompletion in CLI. 21 | # - getting requirements info when all dependencies are not installed. 22 | with safe_import_context() as import_ctx: 23 | from skada.metrics import ( 24 | SupervisedScorer, 25 | PredictionEntropyScorer, 26 | ImportanceWeightedScorer, 27 | SoftNeighborhoodDensity, 28 | DeepEmbeddedValidation, 29 | CircularValidation, 30 | MixValScorer, 31 | ) 32 | 33 | 34 | CRITERIONS = { 35 | 'supervised': SupervisedScorer(), 36 | 'prediction_entropy': PredictionEntropyScorer(), 37 | 'importance_weighted': ImportanceWeightedScorer(), 38 | 'soft_neighborhood_density': SoftNeighborhoodDensity(), 39 | 'deep_embedded_validation': DeepEmbeddedValidation(), 40 | 'circular_validation': CircularValidation(), 41 | 'mix_val_both': MixValScorer(ice_type='both'), 42 | 'mix_val_inter': MixValScorer(ice_type='inter'), 43 | 'mix_val_intra': MixValScorer(ice_type='intra'), 44 | } 45 | -------------------------------------------------------------------------------- /solvers/discriminator_reweight.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import DiscriminatorReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'discriminator_reweight' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | 26 | default_param_grid = { 27 | 'discriminatorreweightadapter__domain_classifier__estimator_name': [ 28 | "LR", "SVC", "KNN", "XGB" 29 | ], 30 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 31 | } 32 | 33 | def get_estimator(self, **kwargs): 34 | # The estimator passed should have a 'predict_proba' method. 35 | return make_da_pipeline( 36 | DiscriminatorReweightAdapter( 37 | domain_classifier=FinalEstimator() 38 | ), 39 | FinalEstimator(), 40 | ) 41 | -------------------------------------------------------------------------------- /solvers/pca.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import make_da_pipeline 8 | from skada.base import SelectSource 9 | from sklearn.decomposition import PCA 10 | from benchmark_utils.base_solver import DASolver, FinalEstimator 11 | 12 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'PCA' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = { 28 | 'pca__n_components': [1, 2, 5, 10, 20, 50, 100], 29 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 30 | } 31 | # Raise an error if n_components > min(n_samples, n_features) 32 | # and doesnt save the result in the benchmark results 33 | 34 | def get_estimator(self, **kwargs): 35 | # The estimator passed should have a 'predict_proba' method. 36 | return make_da_pipeline( 37 | SelectSource(PCA()), 38 | SelectSource(FinalEstimator()), 39 | ) 40 | -------------------------------------------------------------------------------- /solvers/deep_dan.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from skada.deep import DAN 10 | 11 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DeepDASolver): 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'deep_dan' 22 | 23 | # List of parameters for the solver. The benchmark will consider 24 | # the cross product for each key in the dictionary. 25 | # All parameters 'p' defined here are available as 'self.p'. 26 | default_param_grid = { 27 | 'criterion__reg': [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3], 28 | } 29 | 30 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 31 | dataset_name = dataset_name.split("[")[0].lower() 32 | 33 | params = get_params_per_dataset( 34 | dataset_name, n_classes, 35 | ) 36 | 37 | net = DAN( 38 | **params, 39 | layer_name="feature_layer", 40 | train_split=None, 41 | device=device, 42 | warm_start=True, 43 | ) 44 | 45 | return net 46 | -------------------------------------------------------------------------------- /solvers/tars.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import MMDTarSReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'TarS' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'mmdtarsreweightadapter__gamma': [ 27 | 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100., 1000., None 28 | ], 29 | 'mmdtarsreweightadapter__reg': [1e-6], 30 | 'mmdtarsreweightadapter__tol': [1e-6], 31 | 'mmdtarsreweightadapter__max_iter': [1000], 32 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 33 | } 34 | 35 | def get_estimator(self, **kwargs): 36 | # The estimator passed should have a 'predict_proba' method. 37 | return make_da_pipeline( 38 | MMDTarSReweightAdapter(gamma=0.1), 39 | FinalEstimator(), 40 | ) 41 | -------------------------------------------------------------------------------- /solvers/density_reweight.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import DensityReweightAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | from sklearn.neighbors import KernelDensity 10 | 11 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DASolver): 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'density_reweight' 22 | 23 | # List of parameters for the solver. The benchmark will consider 24 | # the cross product for each key in the dictionary. 25 | # All parameters 'p' defined here are available as 'self.p'. 26 | default_param_grid = { 27 | 'densityreweightadapter__weight_estimator__bandwidth': [ 28 | 0.01, 0.1, 1., 10., 100., "scott", "silverman" 29 | ], 30 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 31 | } 32 | 33 | def get_estimator(self, **kwargs): 34 | # The estimator passed should have a 'predict_proba' method. 35 | return make_da_pipeline( 36 | DensityReweightAdapter( 37 | weight_estimator=KernelDensity(bandwidth=1.) 38 | ), 39 | FinalEstimator(), 40 | ) 41 | -------------------------------------------------------------------------------- /solvers/deep_coral.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from skada.deep import DeepCoral 10 | 11 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DeepDASolver): 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'deep_coral' 22 | 23 | # List of parameters for the solver. The benchmark will consider 24 | # the cross product for each key in the dictionary. 25 | # All parameters 'p' defined here are available as 'self.p'. 26 | default_param_grid = { 27 | # 'criterion__reg': np.logspace(-5, 3, 9), 28 | 'criterion__reg': [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3], 29 | } 30 | 31 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 32 | 33 | dataset_name = dataset_name.split("[")[0].lower() 34 | 35 | params = get_params_per_dataset( 36 | dataset_name, n_classes, 37 | ) 38 | 39 | net = DeepCoral( 40 | **params, 41 | layer_name="feature_layer", 42 | train_split=None, 43 | device=device, 44 | warm_start=True, 45 | ) 46 | 47 | return net 48 | -------------------------------------------------------------------------------- /solvers/entropic_ot_mapping.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import EntropicOTMappingAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | # Name to select the solver in the CLI and to display the results. 20 | name = 'entropic_ot_mapping' 21 | 22 | # List of parameters for the solver. The benchmark will consider 23 | # the cross product for each key in the dictionary. 24 | # All parameters 'p' defined here are available as 'self.p'. 25 | default_param_grid = { 26 | 'entropicotmappingadapter__reg_e': [0.1, 0.5, 1.], 27 | 'entropicotmappingadapter__metric': [ 28 | 'sqeuclidean', 'cosine', 'cityblock' 29 | ], 30 | 'entropicotmappingadapter__norm': ['median'], 31 | 'entropicotmappingadapter__max_iter': [1000], 32 | 'entropicotmappingadapter__tol': [1e-6], 33 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 34 | } 35 | 36 | def get_estimator(self, **kwargs): 37 | # The estimator passed should have a 'predict_proba' method. 38 | return make_da_pipeline( 39 | EntropicOTMappingAdapter(), 40 | FinalEstimator(), 41 | ) 42 | -------------------------------------------------------------------------------- /solvers/deep_no_da_source_only.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from skada.deep import SourceOnly 10 | from skada.metrics import SupervisedScorer 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DeepDASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'deep_no_da_source_only' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = {} 28 | 29 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 30 | self.criterions = { 31 | 'supervised': SupervisedScorer(), 32 | } 33 | 34 | dataset_name = dataset_name.split("[")[0].lower() 35 | 36 | params = get_params_per_dataset( 37 | dataset_name, n_classes, 38 | ) 39 | 40 | net = SourceOnly( 41 | **params, 42 | layer_name="feature_layer", 43 | train_split=None, 44 | device=device, 45 | warm_start=True, 46 | ) 47 | 48 | return net 49 | -------------------------------------------------------------------------------- /solvers/deep_no_da_target_only.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from skada.deep import TargetOnly 10 | from skada.metrics import SupervisedScorer 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DeepDASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'deep_no_da_target_only' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = {} 28 | 29 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 30 | self.criterions = { 31 | 'supervised': SupervisedScorer(), 32 | } 33 | 34 | dataset_name = dataset_name.split("[")[0].lower() 35 | 36 | params = get_params_per_dataset( 37 | dataset_name, n_classes, 38 | ) 39 | 40 | net = TargetOnly( 41 | **params, 42 | layer_name="feature_layer", 43 | train_split=None, 44 | device=device, 45 | warm_start=True, 46 | ) 47 | 48 | return net 49 | -------------------------------------------------------------------------------- /benchmark_utils/generate_config/generate_base_estim_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from os import walk 4 | import importlib.util 5 | import sys 6 | from pathlib import Path 7 | 8 | PATH_benchmark_utils = Path(__file__).resolve().parents[1] 9 | PATH_skada_bench = Path(__file__).resolve().parents[2] 10 | 11 | sys.path.extend([str(PATH_benchmark_utils), str(PATH_skada_bench)]) 12 | 13 | from 
base_solver import get_estimator_grid # noqa: E402 14 | 15 | 16 | if __name__ == "__main__": 17 | param_list = [ 18 | {"finalestimator__estimator_name": [k]} for k in get_estimator_grid() 19 | if k != "test" 20 | ] 21 | param_dict = { 22 | "solver": {"NO_DA_SOURCE_ONLY_BASE_ESTIM": {"param_grid": param_list}} 23 | } 24 | 25 | dataset_list = [] 26 | 27 | filenames_dataset = next( 28 | walk(os.path.join(PATH_skada_bench, "datasets")), 29 | (None, None, []) 30 | )[2] 31 | 32 | for name in filenames_dataset: 33 | if not name.endswith('.py') or name.startswith('deep'): 34 | # To skip non-Python files like .DS_Store 35 | # + to skip deep datasets 36 | continue 37 | 38 | spec = importlib.util.spec_from_file_location( 39 | name, os.path.join(PATH_skada_bench, "datasets", name) 40 | ) 41 | 42 | if spec is None: 43 | # Safety check in case spec creation fails 44 | continue 45 | 46 | foo = importlib.util.module_from_spec(spec) 47 | sys.modules[name] = foo 48 | spec.loader.exec_module(foo) 49 | dataset_list.append(foo.Dataset.name) 50 | 51 | print(dataset_list) 52 | 53 | param_dict["dataset"] = dataset_list 54 | cfg_file = ( 55 | PATH_skada_bench / 'config' / 56 | 'find_best_base_estimators_per_dataset.yml' 57 | ) 58 | with open(cfg_file, 'w+') as ff: 59 | yaml.dump(param_dict, ff) 60 | -------------------------------------------------------------------------------- /solvers/mmdscons.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.base_solver import DASolver, FinalEstimator 8 | from skada import MMDLSConSMappingAdapter, make_da_pipeline 9 | import torch # noqa: F401 10 | 11 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DASolver): 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'MMDSConS' 22 | 23 | # MMDSConS requires torch 24 | requirements = DASolver.requirements + ['pip:torch'] 25 | 26 | # List of parameters for the solver. The benchmark will consider 27 | # the cross product for each key in the dictionary. 28 | # All parameters 'p' defined here are available as 'self.p'. 29 | default_param_grid = { 30 | 'mmdlsconsmappingadapter__gamma': [0.01, 0.1, 1, 10, 100], 31 | 'mmdlsconsmappingadapter__reg_k': [1e-8], 32 | 'mmdlsconsmappingadapter__reg_m': [1e-8], 33 | 'mmdlsconsmappingadapter__tol': [1e-5], 34 | 'mmdlsconsmappingadapter__max_iter': [20], 35 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 36 | } 37 | 38 | def get_estimator(self, **kwargs): 39 | # The estimator passed should have a 'predict_proba' method. 40 | return make_da_pipeline( 41 | MMDLSConSMappingAdapter(gamma=0.1), 42 | FinalEstimator(), 43 | ) 44 | -------------------------------------------------------------------------------- /solvers/transfer_component_analysis.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from skada import TransferComponentAnalysisAdapter, make_da_pipeline 8 | from skada.transformers import StratifiedDomainSubsampler 9 | from benchmark_utils.base_solver import DASolver, FinalEstimator 10 | 11 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DASolver): 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'transfer_component_analysis' 22 | 23 | # List of parameters for the solver. The benchmark will consider 24 | # the cross product for each key in the dictionary. 25 | # All parameters 'p' defined here are available as 'self.p'. 26 | default_param_grid = { 27 | 'transfercomponentanalysisadapter__kernel': ['rbf'], 28 | 'transfercomponentanalysisadapter__n_components': [ 29 | 1, 2, 5, 10, 20, 50, 100 30 | ], 31 | 'transfercomponentanalysisadapter__mu': [10, 100], 32 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 33 | } 34 | 35 | def get_estimator(self, **kwargs): 36 | # The estimator passed should have a 'predict_proba' method. 37 | subsampler = StratifiedDomainSubsampler( 38 | train_size=1000 39 | ) 40 | 41 | return make_da_pipeline( 42 | subsampler, 43 | TransferComponentAnalysisAdapter(), 44 | FinalEstimator(), 45 | ) 46 | -------------------------------------------------------------------------------- /solvers/deep_mcc.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from skada.deep import MCC, MCCLoss 10 | 11 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | if import_ctx.failed_import: 17 | class MCCLoss: # noqa: F811 18 | def __init__(self, T): pass 19 | 20 | 21 | # The benchmark solvers must be named `Solver` and 22 | # inherit from `BaseSolver` for `benchopt` to work properly. 23 | class Solver(DeepDASolver): 24 | # Name to select the solver in the CLI and to display the results. 25 | name = 'deep_mcc' 26 | 27 | # List of parameters for the solver. The benchmark will consider 28 | # the cross product for each key in the dictionary. 29 | # All parameters 'p' defined here are available as 'self.p'. 30 | default_param_grid = { 31 | 'criterion__reg': [1e-2, 1e-1, 1], 32 | 'criterion__adapt_criterion': [ 33 | MCCLoss(T=T) 34 | for T in [1, 2, 3] 35 | ], 36 | } 37 | 38 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 39 | dataset_name = dataset_name.split("[")[0].lower() 40 | 41 | params = get_params_per_dataset( 42 | dataset_name, n_classes, 43 | ) 44 | 45 | net = MCC( 46 | **params, 47 | layer_name="feature_layer", 48 | train_split=None, 49 | device=device, 50 | warm_start=True, 51 | ) 52 | 53 | return net 54 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/BCI.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - BCI 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | 
classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - LR_C2.0 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - LR_C2.0 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - LR_C2.0 60 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/Mushrooms.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Mushrooms 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | 
classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - LR 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - LR 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - LR 60 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/Simulated.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Simulated 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 
| finalestimator__estimator_name: 23 | - SVC 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - SVC 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - SVC 60 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/Office31.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Office31 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - LR_C0.01 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 
| - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - LR_C0.01 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - LR_C0.01 60 | -------------------------------------------------------------------------------- /benchmark_utils/deep_base_solver.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | from benchmark_utils.base_solver import DASolver 4 | 5 | with safe_import_context() as import_ctx: 6 | import torch 7 | from skada.metrics import ( 8 | SupervisedScorer, DeepEmbeddedValidation, 9 | PredictionEntropyScorer, ImportanceWeightedScorer, 10 | SoftNeighborhoodDensity, MixValScorer, 11 | ) 12 | 13 | 14 | class DeepDASolver(DASolver): 15 | n_jobs = 1 16 | 17 | requirements = ['pip:skada[deep]==0.4.0'] 18 | 19 | # For DeepDA solvers, empty test_param_grid 20 | test_param_grid = {} 21 | 22 | def __init__(self, **kwargs): 23 | super().__init__(print_infos=False, **kwargs) 24 | 25 | # Set device depending on the gpu/cpu available 26 | if torch.cuda.is_available(): 27 | self.device = torch.device("cuda") 28 | else: 29 | self.device = torch.device("cpu") 30 | 31 | print(f"n_jobs: {self.n_jobs}") 32 | print(f"device: 
{self.device}") 33 | 34 | self.criterions = { 35 | 'supervised': SupervisedScorer(), 36 | 'prediction_entropy': PredictionEntropyScorer(), 37 | 'importance_weighted': ImportanceWeightedScorer(), 38 | 'soft_neighborhood_density': SoftNeighborhoodDensity(), 39 | 'deep_embedded_validation': DeepEmbeddedValidation(), 40 | 'mix_val_both': MixValScorer(ice_type='both'), 41 | 'mix_val_inter': MixValScorer(ice_type='inter'), 42 | 'mix_val_intra': MixValScorer(ice_type='intra'), 43 | } 44 | 45 | # Override the DASolver skip method 46 | def skip(self, X, y, sample_domain, unmasked_y_train, dataset): 47 | # Check if the dataset name does not start 48 | # with 'deep' and is not 'Simulated' 49 | if not ( 50 | dataset.name.startswith('deep') or 51 | dataset.name == 'Simulated' 52 | ): 53 | return True, f"solver does not support the dataset {dataset.name}." 54 | 55 | return False, None 56 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/mnist_usps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - mnist_usps 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - SVC_C10.0_Gamma0.01 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | 
- lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - SVC_C10.0_Gamma0.01 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - SVC_C10.0_Gamma0.01 60 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/20NewsGroups.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - 20NewsGroups 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - SVC_C10.0_Gamma10.0 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | 
- 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - SVC_C10.0_Gamma10.0 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - SVC_C10.0_Gamma10.0 60 | -------------------------------------------------------------------------------- /benchmark_utils/extract_best_base_estim.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import yaml 3 | import os 4 | 5 | PATH = os.path.dirname(os.path.dirname(__file__)) 6 | CONFIG_FILE = os.path.join(PATH, "config", "best_base_estimators.yml") 7 | RESULT_FILE = os.path.join( 8 | PATH, "results_base_estimators", "results_base_estim_experiments.csv" 9 | ) 10 | 11 | SCORE_COL = "source_accuracy-test-mean" 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | with open(CONFIG_FILE) as stream: 17 | best_base_estimators = yaml.safe_load(stream) 18 | 19 | df = pd.read_csv(RESULT_FILE) 20 | 21 | def rename_params(x): 22 | return x.split("['")[-1].split("']")[0] 23 | 24 | def extract_svc(x): 25 | if "SVC" in x: 26 | return True 27 | else: 28 | return False 29 | 30 | df["params"] = df["params"].apply(rename_params) 31 | df = df.loc[df["scorer"] == "supervised"] 32 | df = df.loc[df["estimator"] == "NO_DA_SOURCE_ONLY_BASE_ESTIM"] 33 | 34 | # Find best Estim 35 | for dataset in df.dataset.unique(): 36 | 37 | df_best = df.loc[df.dataset == dataset] 38 | df_best = ( 39 | df_best.groupby(["dataset", "params"]) 40 | .mean(numeric_only=True).reset_index() 41 | ) 42 | mask_svc = 
df_best.params.apply(extract_svc) 43 | df_best_svc = df_best.loc[mask_svc] 44 | 45 | best_estim = df_best.iloc[df_best[SCORE_COL].argmax()].params 46 | best_acc = df_best.iloc[df_best[SCORE_COL].argmax()][SCORE_COL] 47 | 48 | id_best = df_best_svc[SCORE_COL].argmax() 49 | best_estim_svc = df_best_svc.iloc[id_best].params 50 | best_acc_svc = df_best_svc.iloc[id_best][SCORE_COL] 51 | 52 | print(dataset, "Best:", best_estim, best_acc) 53 | print(dataset, "Best SVC:", best_estim_svc, best_acc_svc) 54 | 55 | best_base_estimators[dataset] = dict( 56 | Best=best_estim, BestSVC=best_estim_svc 57 | ) 58 | 59 | with open(CONFIG_FILE, 'w+') as ff: 60 | yaml.dump(best_base_estimators, ff, default_flow_style=False) 61 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/AmazonReview.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - AmazonReview 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - SVC_C1000.0_Gamma0.001 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 
39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - SVC_C1000.0_Gamma0.001 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - SVC_C1000.0_Gamma0.001 60 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/OfficeHomeResnet.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - OfficeHomeResnet 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 1000 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - SVC_C10.0_Gamma0.001 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 1000 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 | classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - SVC_C10.0_Gamma0.001 42 | - 
classregularizerotmappingadapter__max_inner_iter: 43 | - 1000 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - SVC_C10.0_Gamma0.001 60 | -------------------------------------------------------------------------------- /solvers/otlabelprop.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 6 | with safe_import_context() as import_ctx: 7 | from skada import OTLabelPropAdapter, make_da_pipeline 8 | from benchmark_utils.base_solver import DASolver, FinalEstimator 9 | 10 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 11 | if base_import_ctx.failed_import: 12 | exc, val, tb = base_import_ctx.import_error 13 | raise exc(val).with_traceback(tb) 14 | 15 | 16 | # The benchmark solvers must be named `Solver` and 17 | # inherit from `BaseSolver` for `benchopt` to work properly. 18 | class Solver(DASolver): 19 | 20 | # Name to select the solver in the CLI and to display the results. 21 | name = 'OTLabelProp' 22 | 23 | # List of parameters for the solver. The benchmark will consider 24 | # the cross product for each key in the dictionary. 25 | # All parameters 'p' defined here are available as 'self.p'.
26 | default_param_grid = [ 27 | { 28 | 'otlabelpropadapter__metric': [ 29 | 'sqeuclidean', 'cosine', 'cityblock' 30 | ], 31 | 'otlabelpropadapter__reg': [None], 32 | 'otlabelpropadapter__n_iter_max': [10000], 33 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 34 | }, 35 | { 36 | 'otlabelpropadapter__metric': [ 37 | 'sqeuclidean', 'cosine', 'cityblock' 38 | ], 39 | 'otlabelpropadapter__reg': [0.1, 1], 40 | 'otlabelpropadapter__n_iter_max': [100], 41 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 42 | }, 43 | ] 44 | # Grid values are presumably applied by DASolver (kwargs unused here). 45 | def get_estimator(self, **kwargs): 46 | # The estimator passed should have a 'predict_proba' method. 47 | return make_da_pipeline( 48 | OTLabelPropAdapter(), 49 | FinalEstimator(), 50 | ) 51 | -------------------------------------------------------------------------------- /config/solvers/class_regularizer_ot_mapping/Phishing.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | - Phishing 3 | solver: 4 | - class_regularizer_ot_mapping: 5 | param_grid: 6 | - - classregularizerotmappingadapter__max_inner_iter: 7 | - 50 8 | classregularizerotmappingadapter__max_iter: 9 | - 10 10 | classregularizerotmappingadapter__metric: 11 | - sqeuclidean 12 | - cosine 13 | - cityblock 14 | classregularizerotmappingadapter__norm: 15 | - lpl1 16 | classregularizerotmappingadapter__reg_cl: 17 | - 0.1 18 | classregularizerotmappingadapter__reg_e: 19 | - 0.1 20 | classregularizerotmappingadapter__tol: 21 | - 1.0e-06 22 | finalestimator__estimator_name: 23 | - XGB_subsample0.8_colsample0.65_maxdepth20 24 | - classregularizerotmappingadapter__max_inner_iter: 25 | - 50 26 | classregularizerotmappingadapter__max_iter: 27 | - 10 28 | classregularizerotmappingadapter__metric: 29 | - sqeuclidean 30 | - cosine 31 | - cityblock 32 | classregularizerotmappingadapter__norm: 33 | - lpl1 34 | classregularizerotmappingadapter__reg_cl: 35 | - 0.5 36 | 
classregularizerotmappingadapter__reg_e: 37 | - 0.5 38 |
classregularizerotmappingadapter__tol: 39 | - 1.0e-06 40 | finalestimator__estimator_name: 41 | - XGB_subsample0.8_colsample0.65_maxdepth20 42 | - classregularizerotmappingadapter__max_inner_iter: 43 | - 50 44 | classregularizerotmappingadapter__max_iter: 45 | - 10 46 | classregularizerotmappingadapter__metric: 47 | - sqeuclidean 48 | - cosine 49 | - cityblock 50 | classregularizerotmappingadapter__norm: 51 | - lpl1 52 | classregularizerotmappingadapter__reg_cl: 53 | - 1.0 54 | classregularizerotmappingadapter__reg_e: 55 | - 1.0 56 | classregularizerotmappingadapter__tol: 57 | - 1.0e-06 58 | finalestimator__estimator_name: 59 | - XGB_subsample0.8_colsample0.65_maxdepth20 60 | -------------------------------------------------------------------------------- /solvers/deep_jdot.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from benchopt import safe_import_context 3 | 4 | # Protect the import with `safe_import_context()`. This allows: 5 | # - skipping import to speed up autocompletion in CLI. 6 | # - getting requirements info when all dependencies are not installed. 7 | with safe_import_context() as import_ctx: 8 | from benchmark_utils.deep_base_solver import DeepDASolver 9 | from benchmark_utils.utils import get_params_per_dataset 10 | from skada.deep import DeepJDOT, DeepJDOTLoss 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | # Fallback stub: the grid below calls DeepJDOTLoss(...) at class creation. 17 | if import_ctx.failed_import: 18 | class DeepJDOTLoss: # noqa: F811 19 | def __init__(self, reg_cl, reg_dist): pass 20 | 21 | 22 | # The benchmark solvers must be named `Solver` and 23 | # inherit from `BaseSolver` for `benchopt` to work properly. 24 | class Solver(DeepDASolver): 25 | # Name to select the solver in the CLI and to display the results.
26 | name = 'deep_jdot' 27 | 28 | # List of parameters for the solver. The benchmark will consider 29 | # the cross product for each key in the dictionary. 30 | # All parameters 'p' defined here are available as 'self.p'. 31 | default_param_grid = { 32 | 'criterion__adapt_criterion': [ 33 | DeepJDOTLoss(reg_cl=r_cl, reg_dist=r_dist) 34 | for r_cl, r_dist in itertools.product( 35 | [1e-4, 1e-3, 1e-2], [1e-4, 1e-3, 1e-2] 36 | ) 37 | ], 38 | } 39 | 40 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 41 | # Keep only the base dataset name (drop any "[...]" suffix). 42 | dataset_name = dataset_name.split("[")[0].lower() 43 | 44 | params = get_params_per_dataset( 45 | dataset_name, n_classes, 46 | ) 47 | 48 | net = DeepJDOT( 49 | **params, 50 | layer_name="feature_layer", 51 | train_split=None, 52 | device=device, 53 | warm_start=True, 54 | ) 55 | 56 | return net 57 | -------------------------------------------------------------------------------- /benchmark_utils/generate_config/generate_config_simulated.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib.util 3 | import sys 4 | import yaml 5 | # Writes config/Simulated_{LR,SVC,XGB}.yml, one base estimator per file. 6 | PATH_benchmark_utils = Path(__file__).resolve().parents[1] 7 | PATH_skada_bench = Path(__file__).resolve().parents[2] 8 | 9 | sys.path.extend([str(PATH_benchmark_utils), str(PATH_skada_bench)]) 10 | 11 | if __name__ == "__main__": 12 | solvers_path = PATH_skada_bench / "solvers" 13 | datasets_path = PATH_skada_bench / "datasets" 14 | config_path = PATH_skada_bench / "config" 15 | 16 | filenames = [ 17 | f for f in solvers_path.iterdir() 18 | if f.is_file() and not f.name.startswith('.') and f.suffix == '.py' 19 | ] 20 | 21 | with open(config_path / "best_base_estimators.yml") as stream: 22 | best_base_estimators = yaml.safe_load(stream) 23 | # NOTE(review): best_base_estimators and datasets_path are unused below. 
24 | for best in ["LR", "SVC", "XGB"]: 25 | for dataset in ["Simulated"]: 26 | 27 | DD = {} 28 | DD["dataset"] = [dataset] 29 | DD["solver"] = [] 30 | 31 | for filepath in filenames: 32 | name =
filepath.stem # Remove the .py suffix 33 | print(name) 34 | spec = importlib.util.spec_from_file_location(name, filepath) 35 | foo = importlib.util.module_from_spec(spec) 36 | sys.modules[name] = foo 37 | spec.loader.exec_module(foo) 38 | # JDOT_SVC and DASVM are left out of the config. 39 | if foo.Solver.name not in ["JDOT_SVC", "DASVM"]: 40 | param_grid = foo.Solver.default_param_grid 41 | if isinstance(param_grid, list): 42 | for i in range(len(param_grid)): 43 | param_grid[i]['finalestimator__estimator_name'] = [best] # noqa: E501 44 | else: 45 | param_grid['finalestimator__estimator_name'] = [best] 46 | # List grids end up nested: [param_grid] is a list of lists. 47 | DD["solver"].append({ 48 | foo.Solver.name: {"param_grid": [param_grid]} 49 | }) 50 | 51 | with open(config_path / f"{dataset}_{best}.yml", 'w+') as ff: 52 | yaml.dump(DD, ff, default_flow_style=False) 53 | -------------------------------------------------------------------------------- /datasets/simulated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from benchopt import BaseDataset, safe_import_context 3 | 4 | with safe_import_context() as import_ctx: 5 | from skada.datasets import make_shifted_datasets 6 | 7 | 8 | # All datasets must be named `Dataset` and inherit from `BaseDataset` 9 | class Dataset(BaseDataset): 10 | 11 | # Name to select the dataset in the CLI and to display the results. 12 | name = "Simulated" 13 | 14 | # List of parameters to generate the datasets. The benchmark will consider 15 | # the cross product for each key in the dictionary. 16 | # Any parameters 'param' defined here is available as `self.param`.
17 | parameters = { 18 | 'n_samples_source, n_samples_target': [(100, 100)], 19 | 'shift, label': [ 20 | ('covariate_shift', 'binary'), 21 | ('target_shift', 'binary'), 22 | ('concept_drift', 'binary'), 23 | ('subspace', 'binary'), 24 | ], 25 | 'random_state': list(range(5)) 26 | } 27 | 28 | def get_data(self): 29 | # The return arguments of this function are passed as keyword arguments 30 | # to `Objective.set_data`. This defines the benchmark's 31 | # API to pass data. It is customizable for each benchmark. 32 | 33 | # Per-shift settings for make_shifted_datasets (m scales sample counts). 34 | if self.shift == "subspace": 35 | m = 3 36 | noise = 0.4 37 | elif self.shift == "covariate_shift": 38 | m = 1 39 | noise = 0.6 40 | else: 41 | m = 1 42 | noise = 0.8 43 | X, y, sample_domain = make_shifted_datasets( 44 | n_samples_source=m*self.n_samples_source, 45 | n_samples_target=m*self.n_samples_target, 46 | shift=self.shift, 47 | noise=noise, 48 | label=self.label, 49 | center_cov_shift=(-0.4, 3), 50 | random_state=self.random_state, 51 | ) 52 | # NOTE(review): sample_domain is cast to float32, not int -- confirm. 53 | X = X.astype(np.float32) 54 | y = y.astype(np.int64) 55 | sample_domain = sample_domain.astype(np.float32) 56 | 57 | return dict( 58 | X=X, 59 | y=y, 60 | sample_domain=sample_domain, 61 | ) 62 | -------------------------------------------------------------------------------- /solvers/jdot_svc.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed.
6 | with safe_import_context() as import_ctx: 7 | from skada import JDOTClassifier, make_da_pipeline 8 | from skada.transformers import StratifiedDomainSubsampler 9 | from benchmark_utils.base_solver import DASolver, FinalEstimator 10 | 11 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 12 | if base_import_ctx.failed_import: 13 | exc, val, tb = base_import_ctx.import_error 14 | raise exc(val).with_traceback(tb) 15 | 16 | 17 | # The benchmark solvers must be named `Solver` and 18 | # inherit from `BaseSolver` for `benchopt` to work properly. 19 | class Solver(DASolver): 20 | 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'JDOT_SVC' 23 | # Extra pip requirement: POT (Python Optimal Transport). 24 | requirements = [ 25 | "pip:POT", 26 | ] 27 | 28 | # List of parameters for the solver. The benchmark will consider 29 | # the cross product for each key in the dictionary. 30 | # All parameters 'p' defined here are available as 'self.p'. 31 | default_param_grid = { 32 | 'jdotclassifier__alpha': [0.1, 0.3, 0.5, 0.7, 0.9], 33 | 'jdotclassifier__n_iter_max': [100], 34 | 'jdotclassifier__tol': [1e-6], 35 | 'jdotclassifier__thr_weights': [1e-7], 36 | 'jdotclassifier__base_estimator__estimator_name': ["SVC"], 37 | } 38 | # Reduced grid (10 JDOT iterations), presumably for benchopt tests. 39 | test_param_grid = { 40 | "jdotclassifier__base_estimator__estimator_name": ["SVC"], 41 | "jdotclassifier__n_iter_max": [10] 42 | } 43 | 44 | def get_estimator(self, **kwargs): 45 | # The estimator passed should have a 'predict_proba' method.
46 | subsampler = StratifiedDomainSubsampler( 47 | train_size=1000 48 | ) 49 | # Route sample_weight to JDOTClassifier fit/score (metadata routing). 50 | return make_da_pipeline( 51 | subsampler, 52 | JDOTClassifier(base_estimator=FinalEstimator(), 53 | metric='hinge') 54 | .set_fit_request(sample_weight=True) 55 | .set_score_request(sample_weight=True), 56 | ) 57 | -------------------------------------------------------------------------------- /solvers/deep_spa.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from benchmark_utils.backbones_architecture import DomainClassifier 10 | from skada.deep import SPA, SPALoss 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | # NOTE(review): SPALoss (real or stub) is not used in this file. 17 | if import_ctx.failed_import: 18 | class SPALoss: # noqa: F811 19 | def __init__(self, reg_adv, reg_nap): pass 20 | 21 | 22 | # The benchmark solvers must be named `Solver` and 23 | # inherit from `BaseSolver` for `benchopt` to work properly. 24 | class Solver(DeepDASolver): 25 | # Name to select the solver in the CLI and to display the results. 26 | name = 'deep_spa' 27 | 28 | # List of parameters for the solver. The benchmark will consider 29 | # the cross product for each key in the dictionary. 30 | # All parameters 'p' defined here are available as 'self.p'.
31 | default_param_grid = { 32 | 'criterion__reg': [1e-3, 1e-2, 1e-1, 1], 33 | } 34 | 35 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 36 | dataset_name = dataset_name.split("[")[0].lower() 37 | 38 | params = get_params_per_dataset( 39 | dataset_name, n_classes, 40 | ) 41 | # Scale the lr by 0.1; set momentum to 0.9 if the config defines one 42 | params['lr'] = params['lr'] * 0.1 43 | if 'optimizer__momentum' in params: 44 | params['optimizer__momentum'] = 0.9 45 | # NOTE(review): reg_nap=0 presumably turns the NAP term off. 46 | net = SPA( 47 | **params, 48 | reg_adv=1, 49 | reg_gsa=1, 50 | reg_nap=0, 51 | layer_name="feature_layer", 52 | train_split=None, 53 | device=device, 54 | warm_start=True, 55 | domain_classifier=DomainClassifier( 56 | num_features=params['module'].n_features 57 | ), 58 | ) 59 | 60 | return net 61 | -------------------------------------------------------------------------------- /solvers/transfer_subspace_learning.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.base_solver import DASolver, FinalEstimator 8 | from skada import TransferSubspaceLearningAdapter, make_da_pipeline 9 | from skada.transformers import StratifiedDomainSubsampler 10 | import torch # noqa: F401 11 | # torch is unused here; presumably imported to check it is installed. 12 | from benchmark_utils.base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 
20 | class Solver(DASolver): 21 | # Name to select the solver in the CLI and to display the results.
22 | name = 'transfer_subspace_learning' 23 | 24 | # TransferSubspaceLearningAdapter requires torch 25 | requirements = DASolver.requirements + ['pip:torch'] 26 | 27 | # List of parameters for the solver. The benchmark will consider 28 | # the cross product for each key in the dictionary. 29 | # All parameters 'p' defined here are available as 'self.p'. 30 | default_param_grid = { 31 | 'transfersubspacelearningadapter__n_components': [ 32 | 1, 2, 5, 10, 20, 50, 100 33 | ], 34 | 'transfersubspacelearningadapter__base_method': ['flda'], 35 | 'transfersubspacelearningadapter__length_scale': [2], 36 | 'transfersubspacelearningadapter__mu': [0.1, 1, 10], 37 | 'transfersubspacelearningadapter__reg': [1e-4], 38 | 'transfersubspacelearningadapter__max_iter': [300], 39 | 'transfersubspacelearningadapter__tol': [1e-4], 40 | 'finalestimator__estimator_name': ["LR", "SVC", "XGB"], 41 | } 42 | 43 | def get_estimator(self, **kwargs): 44 | # The estimator passed should have a 'predict_proba' method. 45 | subsampler = StratifiedDomainSubsampler( 46 | train_size=1000 47 | ) 48 | # Domain-stratified subsampling (train_size=1000) before adaptation. 49 | return make_da_pipeline( 50 | subsampler, 51 | TransferSubspaceLearningAdapter(), 52 | FinalEstimator(), 53 | ) 54 | -------------------------------------------------------------------------------- /benchmark_utils/preprocessing/preprocess_twentynewsgroups.py: -------------------------------------------------------------------------------- 1 | """ File to preprocess 20NewsGroups Dataset. 2 | - Download the dataset 3 | - Vectorize the text data using MinHashEncoder and sentence_transformers 4 | - Store the preprocessed data in a dictionary and save it in a pickle file. 5 | - The pickle file is written to data/ (relative to the working directory).
6 | """ 7 | 8 | from pathlib import Path 9 | import pickle 10 | from sentence_transformers import SentenceTransformer 11 | from sklearn.datasets import fetch_20newsgroups 12 | from sklearn.decomposition import PCA 13 | from skrub import MinHashEncoder 14 | 15 | 16 | if __name__ == "__main__": 17 | # MinHashEncoder hyperparameters (N_COMPONENTS is also the PCA size) 18 | N_COMPONENTS = 50 19 | NGRAM_RANGE = (2, 10) 20 | 21 | # Sentence Transformers hyperparameters 22 | MODEL_NAME = "BAAI/bge-large-en-v1.5" 23 | BATCH_SIZE = 512 24 | # MODEL_NAME = "BAAI/bge-small-en-v1.5" 25 | # BATCH_SIZE = 1024 26 | DEVICE = "cuda" 27 | # NOTE(review): DEVICE="cuda" requires a GPU; use "cpu" otherwise. 28 | # Fetch all documents (train + test); downloads them if not cached. 29 | data = fetch_20newsgroups(download_if_missing=True, subset="all") 30 | 31 | # MinHashEncoder (each document is wrapped as a one-column row) 32 | print("Encoding text data using MinHashEncoder...") 33 | vectorizer = MinHashEncoder( 34 | n_components=N_COMPONENTS, ngram_range=NGRAM_RANGE, n_jobs=-1 35 | ) 36 | X_min_hash = vectorizer.fit_transform([[d] for d in data.data]) 37 | 38 | # Sentence Transformers 39 | print("Encoding text data using Sentence Transformers...") 40 | model = SentenceTransformer(MODEL_NAME, device=DEVICE) 41 | X_sentence_transformers = model.encode( 42 | data.data, batch_size=BATCH_SIZE, show_progress_bar=True 43 | ) 44 | X_sentence_transformers = X_sentence_transformers.astype("float64") 45 | 46 | # Apply PCA to reduce the dimensionality of the embeddings 47 | print("Applying PCA...") 48 | pca = PCA(n_components=N_COMPONENTS) 49 | X_sentence_transformers = pca.fit_transform(X_sentence_transformers) 50 | 51 | # Save the preprocessed data 52 | print("Saving preprocessed data...") 53 | preprocessed_data = { 54 | "raw_data": data.data, 55 | "min_hash": X_min_hash, 56 | "sentence_transformers": X_sentence_transformers, 57 | } 58 | 59 | # Save the preprocessed data in a pickle file 60 | path = Path('data') 61 | 
path.mkdir(exist_ok=True) 62 | with open(path / "20newsgroups_preprocessed.pkl", "wb") as f: 63 | pickle.dump(preprocessed_data, f) 64 |
-------------------------------------------------------------------------------- /solvers/deep_mdd.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed. 6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from benchmark_utils.backbones_architecture import DiscrepancyClassifier 10 | from skada.deep import MDD, MDDLoss 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | # Fallback stub: the grid below calls MDDLoss(...) at class creation. 18 | if import_ctx.failed_import: 19 | class MDDLoss: # noqa: F811 20 | def __init__(self, gamma): pass 21 | 22 | 23 | # The benchmark solvers must be named `Solver` and 24 | # inherit from `BaseSolver` for `benchopt` to work properly. 25 | class Solver(DeepDASolver): 26 | # Name to select the solver in the CLI and to display the results. 27 | name = 'deep_mdd' 28 | 29 | # List of parameters for the solver. The benchmark will consider 30 | # the cross product for each key in the dictionary. 31 | # All parameters 'p' defined here are available as 'self.p'. 32 | default_param_grid = { 33 | 'criterion__reg': [1e-3, 1e-2, 1e-1], 34 | 'criterion__adapt_criterion': [ 35 | MDDLoss(gamma=gamma) 36 | for gamma in [1., 3.]
37 | ], 38 | } 39 | 40 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 41 | dataset_name = dataset_name.split("[")[0].lower() 42 | 43 | params = get_params_per_dataset( 44 | dataset_name, n_classes, 45 | ) 46 | # Scale the lr by 0.1; set momentum to 0.9 if the config defines one 47 | params['lr'] = params['lr'] * 0.1 48 | if 'optimizer__momentum' in params: 49 | params['optimizer__momentum'] = 0.9 50 | 51 | net = MDD( 52 | **params, 53 | layer_name="feature_layer", 54 | train_split=None, 55 | device=device, 56 | warm_start=True, 57 | disc_classifier=DiscrepancyClassifier( 58 | num_features=params['module'].n_features, 59 | n_classes=n_classes, 60 | ), 61 | ) 62 | 63 | return net 64 | -------------------------------------------------------------------------------- /solvers/deep_can.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from benchopt import safe_import_context 3 | 4 | # Protect the import with `safe_import_context()`. This allows: 5 | # - skipping import to speed up autocompletion in CLI. 6 | # - getting requirements info when all dependencies are not installed. 7 | with safe_import_context() as import_ctx: 8 | from benchmark_utils.deep_base_solver import DeepDASolver 9 | from benchmark_utils.utils import get_params_per_dataset 10 | from skada.deep import CAN, CANLoss 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | # Fallback stub: the grid below calls CANLoss(...) at class creation. 17 | if import_ctx.failed_import: 18 | class CANLoss: # noqa: F811 19 | def __init__(self, distance_threshold, class_threshold): pass 20 | 21 | 22 | # The benchmark solvers must be named `Solver` and 23 | # inherit from `BaseSolver` for `benchopt` to work properly. 
24 | class Solver(DeepDASolver): 25 | # Name to select the solver in the CLI and to display the results.
26 | name = 'deep_can' 27 | 28 | # List of parameters for the solver. The benchmark will consider 29 | # the cross product for each key in the dictionary. 30 | # All parameters 'p' defined here are available as 'self.p'. 31 | default_param_grid = { 32 | 'criterion__reg': [3e-1], 33 | 'criterion__adapt_criterion': [ 34 | CANLoss( 35 | distance_threshold=distance_threshold, 36 | class_threshold=class_threshold, 37 | ) 38 | for distance_threshold, class_threshold in itertools.product( 39 | [5e-3, 5e-2, 5e-1], 40 | [1, 3, 5], 41 | ) 42 | ] 43 | } 44 | 45 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 46 | dataset_name = dataset_name.split("[")[0].lower() 47 | 48 | params = get_params_per_dataset( 49 | dataset_name, n_classes, 50 | ) 51 | # Scale the lr by 0.1; set momentum to 0.9 if the config defines one 52 | params['lr'] = params['lr'] * 0.1 53 | if 'optimizer__momentum' in params: 54 | params['optimizer__momentum'] = 0.9 55 | 56 | net = CAN( 57 | **params, 58 | layer_name="feature_layer", 59 | train_split=None, 60 | device=device, 61 | warm_start=True, 62 | ) 63 | 64 | return net 65 | -------------------------------------------------------------------------------- /solvers/deep_dann.py: -------------------------------------------------------------------------------- 1 | from benchopt import safe_import_context 2 | 3 | # Protect the import with `safe_import_context()`. This allows: 4 | # - skipping import to speed up autocompletion in CLI. 5 | # - getting requirements info when all dependencies are not installed.
6 | with safe_import_context() as import_ctx: 7 | from benchmark_utils.deep_base_solver import DeepDASolver 8 | from benchmark_utils.utils import get_params_per_dataset 9 | from benchmark_utils.backbones_architecture import DomainClassifier 10 | from skada.deep import DANN 11 | 12 | from benchmark_utils.deep_base_solver import import_ctx as base_import_ctx 13 | if base_import_ctx.failed_import: 14 | exc, val, tb = base_import_ctx.import_error 15 | raise exc(val).with_traceback(tb) 16 | 17 | 18 | # The benchmark solvers must be named `Solver` and 19 | # inherit from `BaseSolver` for `benchopt` to work properly. 20 | class Solver(DeepDASolver): 21 | # Name to select the solver in the CLI and to display the results. 22 | name = 'deep_dann' 23 | 24 | # List of parameters for the solver. The benchmark will consider 25 | # the cross product for each key in the dictionary. 26 | # All parameters 'p' defined here are available as 'self.p'. 27 | default_param_grid = { 28 | # Same values as np.logspace(-3, 0, 4), written out explicitly. 29 | 'criterion__reg': [1e-3, 1e-2, 1e-1, 1], 30 | } 31 | 32 | def get_estimator(self, n_classes, device, dataset_name, **kwargs): 33 | 34 | dataset_name = dataset_name.split("[")[0].lower() 35 | 36 | params = get_params_per_dataset( 37 | dataset_name, n_classes, 38 | ) 39 | # Scale the lr by 0.1; set momentum to 0.9 if the config defines one 40 | params['lr'] = params['lr'] * 0.1 41 | if 'optimizer__momentum' in params: 42 | params['optimizer__momentum'] = 0.9 43 | 44 | net = DANN( 45 | **params, 46 | layer_name="feature_layer", 47 | train_split=None, 48 | device=device, 49 | domain_classifier=DomainClassifier( 50 | num_features=params['module'].n_features 51 | ), 52 | warm_start=True, 53 | # optimizer__param_groups=[ 54 | # ('base_module_.feature_layer*', {'lr': params['lr']}), 55 | # ('base_module_.final_layer*', {'lr': params['lr']}), 56 | # ('domain_classifier_*', {'lr': params['lr'],}), 57 | # ] 58 | ) 59 | 60 | return net 61 |
-------------------------------------------------------------------------------- /datasets/office31_decaf.py: -------------------------------------------------------------------------------- 1 | from benchopt import BaseDataset, safe_import_context 2 | 3 | with safe_import_context() as import_ctx: 4 | import numpy as np 5 | from sklearn.decomposition import PCA 6 | from sklearn.preprocessing import LabelEncoder 7 | from skada.utils import source_target_merge 8 | from skada.datasets import fetch_office31_decaf_all 9 | 10 | 11 | # All datasets must be named `Dataset` and inherit from `BaseDataset` 12 | class Dataset(BaseDataset): 13 | # Name to select the dataset in the CLI and to display the results. 14 | name = "Office31Decaf" 15 | 16 | # List of parameters to generate the datasets. The benchmark will consider 17 | # the cross product for each key in the dictionary. 18 | # Any parameters 'param' defined here is available as `self.param`. 19 | parameters = { 20 | "source_target": [ 21 | ("dslr", "webcam"), 22 | ("dslr", "amazon"), 23 | ("webcam", "dslr"), 24 | ("webcam", "amazon"), 25 | ("amazon", "dslr"), 26 | ("amazon", "webcam"), 27 | ], 28 | "n_components": [100], 29 | } 30 | 31 | def get_data(self): 32 | # The return arguments of this function are passed as keyword arguments 33 | # to `Objective.set_data`. This defines the benchmark's 34 | # API to pass data. It is customizable for each benchmark.
35 | # data_home passed to fetch_office31_decaf_all (relative to the cwd). 36 | tmp_folder = "./data/OFFICE_31_DECAF_DATASET/" 37 | dataset = fetch_office31_decaf_all( 38 | # disabled: categories=Office31CategoriesPreset.CALTECH256, 39 | data_home=tmp_folder 40 | ) 41 | 42 | # Fit PCA on all domains (unlabeled features only, incl. target) 43 | domains = dataset.domain_names_.keys() 44 | X_total = np.concatenate([dataset.get_domain(d)[0] for d in domains]) 45 | pca = PCA(n_components=self.n_components).fit(X_total) 46 | 47 | # Get source and target data and apply PCA 48 | source = self.source_target[0] 49 | target = self.source_target[1] 50 | 51 | X_source, y_source = dataset.get_domain(source) 52 | X_target, y_target = dataset.get_domain(target) 53 | X_source = pca.transform(X_source) 54 | X_target = pca.transform(X_target) 55 | 56 | # XGBoost only supports labels in [0, num_classes-1] 57 | le = LabelEncoder() 58 | le.fit(np.concatenate([y_source, y_target])) 59 | y_source = le.transform(y_source) 60 | y_target = le.transform(y_target) 61 | # Merge source and target; sample_domain marks each row's domain. 62 | X, y, sample_domain = source_target_merge( 63 | X_source, X_target, y_source, y_target 64 | ) 65 | 66 | return dict( 67 | X=X, 68 | y=y, 69 | sample_domain=sample_domain, 70 | ) 71 | --------------------------------------------------------------------------------