├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── cifar10
│   │   ├── exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml
│   │   ├── exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
│   │   ├── exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
│   │   ├── exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml
│   │   ├── exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml
│   │   ├── exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml
│   │   ├── exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml
│   │   └── exp_cifar10_wrn2810_1net_standard_bar1.yaml
│   ├── cifar100
│   │   ├── exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml
│   │   ├── exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
│   │   ├── exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
│   │   ├── exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml
│   │   ├── exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml
│   │   ├── exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml
│   │   ├── exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml
│   │   └── exp_cifar100_wrn2810_1net_standard_bar1.yaml
│   └── tiny
│       ├── exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml
│       ├── exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml
│       ├── exp_tinyimagenet_res182_1net_standard_bar1.yaml
│       ├── exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml
│       ├── exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml
│       ├── exp_tinyimagenet_res183_1net_standard_bar1.yaml
│       ├── exp_tinyimagenet_res18_1net_msdacutmix_bar1.yaml
│       ├── exp_tinyimagenet_res18_1net_msdamixup_bar1.yaml
│       └── exp_tinyimagenet_res18_1net_standard_bar1.yaml
├── docs    (pre-built Sphinx documentation)
│   ├── .buildinfo
│   ├── .nojekyll
│   ├── _autosummary _autofunction    (generated .html reference pages, one per mixmo module)
│   ├── _autosummary    (generated .html reference pages, one per documented class or function)
│   ├── _modules    (generated highlighted-source .html pages)
│   ├── _sources    (generated .rst.txt sources of the above pages)
│   ├── _static    (theme CSS, JavaScript and font assets)
│   ├── genindex.html
│   ├── includeme.html
│   ├── index.html
│   ├── objects.inv
│   ├── py-modindex.html
│   ├── search.html
│   └── searchindex.js
├── mixmo
│   ├── __init__.py
│   ├── augmentations
│   │   ├── __init__.py
│   │   ├── augmix.py
│   │   ├── mixing_blocks.py
│   │   └── standard_augmentations.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── metrics_ensemble.py
│   │   ├── metrics_wrapper.py
│   │   ├── optimizer.py
│   │   ├── scheduler.py
│   │   └── temperature_scaling.py
│   ├── learners
│   │   ├── __init__.py
│   │   ├── abstract_learner.py
│   │   ├── learner.py
│   │   └── model_wrapper.py
│   ├── loaders
│   │   ├── __init__.py
│   │   ├── abstract_loader.py
│   │   ├── batch_repetition_sampler.py
│   │   ├── cifar_dataset.py
│   │   ├── dataset_wrapper.py
│   │   └── loader.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── wrn.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       ├── logger.py
│       ├── misc.py
│       ├── torchsummary.py
│       ├── torchutils.py
│       └── visualize.py
├── mixmo_intro.png
├── notebooks
│   ├── evaluate.ipynb
│   └── feature_separation.ipynb
├── requirements.txt
└── scripts
    ├── __init__.py
    ├── evaluate.py
    ├── exp_mixmo_template.yaml
    ├── script_load_tiny_data.py
    ├── templateutils_mixmo.py
    └── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.ipynb_checkpoints/
*ckpt
*data/*
--------------------------------------------------------------------------------
/config/cifar10/exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml:
--------------------------------------------------------------------------------
num_members: 2

data:
  dataset: cifar10
  num_classes: 10

training:
  nb_epochs: 300
  batch_size: 64
  dataset_wrapper:
    mixmo:
      mix_method:
        method_name: cutmix
        prob: 0.5
        replacement_method_name: mixup
      alpha: 2
      weight_root: 3
    msda:
      beta: 1
      prob: 0.5
      mix_method: cutmix
    da_method: null
  batch_sampler:
    batch_repetitions: 4
    proba_input_repetition: 0

model_wrapper:
  name: classifier
  network:
    name: wideresnetmixmo
    depth: 28
    widen_factor: 10
  loss:
    listloss:
      - name: soft_cross_entropy
        display_name: ce0
        coeff: 1
        input: logits_0
        target: target_0
      - name: soft_cross_entropy
        display_name: ce1
        coeff:
        [remainder of this file is not recoverable from the dump]
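The file name encodes the setup that the keys above spell out: a WideResNet-28-10 backbone with 2 members, Cut-MixMo applied with probability 0.5, CutMix as mixed-sample data augmentation, and batch repetition 4. Such configs are presumably consumed by scripts/train.py (and by mixmo.utils.misc.load_config_yaml inside the package); independently of that entry point, a quick way to inspect one is plain PyYAML. The path below assumes the repository root as working directory, and only the top-level keys shown above are accessed:

    import yaml

    config_path = "config/cifar10/exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml"
    with open(config_path) as f:
        config_args = yaml.safe_load(f)

    print(config_args["num_members"])             # 2: two subnetworks share one backbone
    print(config_args["data"]["dataset"])         # 'cifar10'
    print(config_args["training"]["nb_epochs"])   # 300
    print(config_args["training"]["batch_size"])  # 64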
--------------------------------------------------------------------------------
/mixmo/__init__.py:
--------------------------------------------------------------------------------
"""
Official MixMo implementation in pytorch
"""
--------------------------------------------------------------------------------
/mixmo/augmentations/__init__.py:
--------------------------------------------------------------------------------
"""
Data augmentation functions
"""
--------------------------------------------------------------------------------
/mixmo/augmentations/standard_augmentations.py:
--------------------------------------------------------------------------------
"""
Basic augmentation procedures
"""

from torchvision import transforms
from mixmo.utils.logger import get_logger

LOGGER = get_logger(__name__, level="DEBUG")


class CustomCompose(transforms.Compose):

    def __call__(self, img, apply_postprocessing=True):
        for i, t in enumerate(self.transforms):

            if i == len(self.transforms) - 1 and not apply_postprocessing:
                return {"pixels": img, "postprocessing": t}

            img = t(img)
        return img


cifar_mean = (0.4913725490196078, 0.4823529411764706, 0.4466666666666667)
cifar_std = (0.2023, 0.1994, 0.2010)


def get_default_composed_augmentations(dataset_name):
    if dataset_name.startswith("cifar"):
        normalize = transforms.Normalize(cifar_mean, cifar_std)
        # Transformer for train set: random crops and horizontal flip
        train_transformer = CustomCompose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),  # randomly flip image horizontally
                # postprocessing
                CustomCompose([transforms.ToTensor(), normalize])
            ]
        )
        test_transformer = CustomCompose([
            transforms.ToTensor(),
            normalize])

    elif dataset_name.startswith("tinyimagenet"):
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        train_transformer = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(64, padding=4),
                CustomCompose([transforms.ToTensor(), normalize])
            ]
        )
        test_transformer = CustomCompose([transforms.ToTensor(), normalize])

    else:
        raise ValueError(dataset_name)

    return train_transformer, test_transformer
--------------------------------------------------------------------------------
/mixmo/core/__init__.py:
--------------------------------------------------------------------------------
"""
Loss, metrics and schedulers
"""
--------------------------------------------------------------------------------
/mixmo/core/optimizer.py:
--------------------------------------------------------------------------------
"""
Optimizer factory
"""

import torch.optim as optim
from mixmo.utils.logger import get_logger


LOGGER = get_logger(__name__, level="DEBUG")


def get_optimizer(optimizer, list_param_groups):
    """
    Builds optimizer objects from config
    """
    optimizer_name = optimizer["name"]
    optimizer_params = optimizer["params"]

    LOGGER.info(f"Using optimizer {optimizer_name} with params {optimizer_params}")
    if optimizer_name == "sgd":
        optimizer = optim.SGD(list_param_groups, **optimizer_params)
    elif optimizer_name == "adam":
        optimizer = optim.Adam(list_param_groups, **optimizer_params)
    elif optimizer_name == "adadelta":
        optimizer = optim.Adadelta(list_param_groups, **optimizer_params)
    elif optimizer_name == "rmsprop":
        optimizer = optim.RMSprop(list_param_groups, **optimizer_params)
    else:
        raise KeyError("Bad optimizer name or not implemented (sgd, adam, adadelta).")

    return optimizer
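get_optimizer only needs the name/params dictionary taken from the experiment config plus a list of torch parameter groups. A minimal usage sketch; the toy model and the hyperparameter values are illustrative and not taken from the repository:

    import torch

    from mixmo.core.optimizer import get_optimizer

    # Toy model standing in for a WideResNetMixMo; values below are illustrative only.
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))

    optimizer = get_optimizer(
        optimizer={"name": "sgd",
                   "params": {"lr": 0.1, "momentum": 0.9, "weight_decay": 3e-4}},
        list_param_groups=[{"params": model.parameters()}],
    )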
--------------------------------------------------------------------------------
/mixmo/core/scheduler.py:
--------------------------------------------------------------------------------
"""
Scheduler definitions and factory
"""

from collections import Counter

from torch.optim.lr_scheduler import _LRScheduler
from mixmo.utils.logger import get_logger


LOGGER = get_logger(__name__, level="DEBUG")


class MultiGammaStepLR(_LRScheduler):
    """
    Multi step decay scheduler, with decay applied to the learning rate at every set milestone
    """

    def __init__(self, optimizer, dict_milestone_to_gamma, last_epoch=-1):
        self.milestones = Counter(dict_milestone_to_gamma.keys())
        self.dict_milestone_to_gamma = dict_milestone_to_gamma
        super(MultiGammaStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        gamma = self.dict_milestone_to_gamma[self.last_epoch]
        LOGGER.warning(f"Decrease lr by gamma: {gamma} at epoch: {self.last_epoch}")
        return [
            group['lr'] * gamma
            for group in self.optimizer.param_groups
        ]


SCHEDULERS = {
    "multigamma_step": MultiGammaStepLR,
}


def get_scheduler(lr_schedule, optimizer, start_epoch):
    """
    Build the scheduler object
    """
    scheduler_name = lr_schedule.pop("name")

    scheduler_params = lr_schedule["params"]
    # Add last epoch
    scheduler_params["last_epoch"] = start_epoch
    LOGGER.info(f"Using {scheduler_name} scheduler with {scheduler_params} params")
    base_scheduler = SCHEDULERS[scheduler_name](optimizer, **scheduler_params)
    return base_scheduler


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, the lr starts from 0 and ends up at the base_lr.
        total_steps: the target learning rate is reached gradually at total_steps
    """

    def __init__(self, optimizer, multiplier, total_steps):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_steps = int(total_steps)
        self.finished = False
        self.last_steps = 0
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr_warmup(self):
        if self.multiplier == 1.0:
            return [
                base_lr * (float(self.last_steps) / self.total_steps) for base_lr in self.base_lrs
            ]
        else:
            raise NotImplementedError

    def step(self, steps=None):
        if steps is None:
            steps = self.last_steps + 1
        self.last_steps = steps if steps != 0 else 1
        if self.last_steps <= self.total_steps:
            warmup_lr = self.get_lr_warmup()
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
            if self.last_steps == self.total_steps:
                LOGGER.warning(f"This is the end of warmup at lr: {warmup_lr}")


def get_warmup_scheduler(optimizer, warmup_period):
    """
    Build a Scheduler instance with warmup
    """
    return GradualWarmupScheduler(
        optimizer,
        multiplier=1,
        total_steps=warmup_period,
    )
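A sketch of how the two factories can be combined: the milestone epochs, gamma values and warmup period below are illustrative rather than the repository's defaults, and the bare SGD optimizer merely stands in for the one returned by get_optimizer:

    import torch

    from mixmo.core.scheduler import get_scheduler, get_warmup_scheduler

    params = [torch.nn.Parameter(torch.zeros(10, 10))]
    optimizer = torch.optim.SGD(params, lr=0.1)

    # Decay the learning rate by 0.1 at epochs 100 and 200 (illustrative values).
    scheduler = get_scheduler(
        lr_schedule={"name": "multigamma_step",
                     "params": {"dict_milestone_to_gamma": {100: 0.1, 200: 0.1}}},
        optimizer=optimizer,
        start_epoch=-1,
    )

    # Linear warmup over the first 500 optimization steps (illustrative value).
    warmup = get_warmup_scheduler(optimizer, warmup_period=500)

    for _ in range(500):
        warmup.step()        # called once per batch during warmup
    # Afterwards, scheduler.step() is called once per epoch.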
51 | class GradualWarmupScheduler(_LRScheduler):
52 |     """ Gradually warms up (increases) the learning rate in the optimizer.
53 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
54 |     Args:
55 |         optimizer (Optimizer): Wrapped optimizer.
56 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier == 1.0, the lr starts from 0 and linearly increases up to base_lr.
57 |         total_steps: number of steps over which the learning rate is gradually brought to its target.
58 |     """
59 | 
60 |     def __init__(self, optimizer, multiplier, total_steps):
61 |         self.multiplier = multiplier
62 |         if self.multiplier < 1.:
63 |             raise ValueError('multiplier should be greater than or equal to 1.')
64 |         self.total_steps = int(total_steps)
65 |         self.finished = False
66 |         self.last_steps = 0
67 |         super(GradualWarmupScheduler, self).__init__(optimizer)
68 | 
69 |     def get_lr_warmup(self):
70 |         if self.multiplier == 1.0:
71 |             return [
72 |                 base_lr * (float(self.last_steps) / self.total_steps) for base_lr in self.base_lrs
73 |             ]
74 |         else:
75 |             raise NotImplementedError
76 | 
77 |     def step(self, steps=None):
78 |         if steps is None:
79 |             steps = self.last_steps + 1
80 |         self.last_steps = steps if steps != 0 else 1
81 |         if self.last_steps <= self.total_steps:
82 |             warmup_lr = self.get_lr_warmup()
83 |             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
84 |                 param_group['lr'] = lr
85 |             if self.last_steps == self.total_steps:
86 |                 LOGGER.warning(f"This is the end of warmup at lr: {warmup_lr}")
87 | 
88 | 
89 | 
90 | def get_warmup_scheduler(optimizer, warmup_period):
91 |     """
92 |     Build a Scheduler instance with warmup
93 |     """
94 |     return GradualWarmupScheduler(
95 |         optimizer,
96 |         multiplier=1,
97 |         total_steps=warmup_period,
98 |     )
99 | 
--------------------------------------------------------------------------------
/mixmo/learners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Learner objects that wrap models, optimizers and datasets to train and evaluate
3 | """
4 | 
--------------------------------------------------------------------------------
/mixmo/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset wrappers and processing
3 | """
4 | 
5 | from mixmo.loaders import loader
6 | 
7 | 
8 | def get_loader(config_args, **kwargs):
9 |     """
10 |     Return a new instance of dataset loader
11 |     """
12 |     # Available datasets
13 |     data_loader_factory = {
14 |         "cifar10": loader.CIFAR10Loader,
15 |         "cifar100": loader.CIFAR100Loader,
16 |         "tinyimagenet200": loader.TinyImagenet200Loader,
17 |     }
18 | 
19 |     return data_loader_factory[config_args['data']['dataset']](config_args=config_args, **kwargs)
20 | 
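# Illustrative usage sketch (added comment, not part of the original file):
# get_loader dispatches on config_args['data']['dataset'], which must be one of the
# keys above ("cifar10", "cifar100", "tinyimagenet200"); the remaining entries of
# config_args follow the yaml experiment files and the AbstractDataLoader API and
# are only abbreviated placeholders here.
# >>> config_args = {"data": {"dataset": "cifar10", ...}, ...}  # abbreviated placeholder
# >>> dataloader = get_loader(config_args)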
16 | """ 17 | def __init__( 18 | self, 19 | sampler, 20 | batch_size, 21 | num_members, 22 | config_batch_sampler, 23 | drop_last=False, 24 | ): 25 | torch.utils.data.sampler.BatchSampler.__init__(self, sampler, batch_size, drop_last) 26 | 27 | self.num_members = num_members 28 | self._batch_repetitions = config_batch_sampler["batch_repetitions"] 29 | self._proba_input_repetition = config_batch_sampler["proba_input_repetition"] 30 | 31 | def __iter__(self): 32 | batch = [] 33 | for idx in self.sampler: 34 | for _ in range(self._batch_repetitions): 35 | batch.append(idx) 36 | if len(batch) >= self.batch_size: 37 | yield self.output_format(batch) 38 | batch = [] 39 | if len(batch) > 0 and not self.drop_last: 40 | yield self.output_format(batch) 41 | 42 | def output_format(self, std_batch): 43 | """ 44 | Transforms standards batches into batches of sample summaries 45 | """ 46 | # Create M shuffled batches, one for each input 47 | batch_size = len(std_batch) 48 | list_shuffled_index = [ 49 | torchutils.randperm_static(batch_size, proba_static=self._proba_input_repetition) 50 | for _ in range(self.num_members) 51 | ] 52 | 53 | shuffled_batch = [ 54 | std_batch[list_shuffled_index[0][count]] 55 | for count in range(batch_size)] 56 | 57 | # sample batch seed, shared among samples from the given batch 58 | batch_seed = random.randint(0, config.cfg.RANDOM.MAX_RANDOM) 59 | list_index = [ 60 | misc.clean_update( 61 | { 62 | "batch_seed": batch_seed, 63 | "index_" + str(0): shuffled_batch[count] 64 | }, { 65 | "index_" + str(num_member): 66 | shuffled_batch[list_shuffled_index[num_member][count]] 67 | for num_member in range(1, self.num_members) 68 | } 69 | ) 70 | for count in range(batch_size) 71 | ] 72 | 73 | return list_index 74 | 75 | def __len__(self): 76 | len_sampler = len(self.sampler) * self._batch_repetitions 77 | if self.drop_last: 78 | return len_sampler // self.batch_size 79 | else: 80 | return (len_sampler + self.batch_size - 1) // self.batch_size 81 | -------------------------------------------------------------------------------- /mixmo/loaders/cifar_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | CIFAR 10 and 100 dataset wrappers 3 | """ 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | from torchvision import datasets 9 | 10 | from mixmo.utils.logger import get_logger 11 | 12 | LOGGER = get_logger(__name__, level="DEBUG") 13 | 14 | 15 | class CustomCIFAR10(datasets.CIFAR10): 16 | """ 17 | Torchvision's CIFAR10 dataset class augmented with a custom __getitem__ method 18 | """ 19 | def __getitem__(self, index, apply_postprocessing=True): 20 | """ 21 | Args: 22 | index (int): Index 23 | 24 | Returns: 25 | tuple: (image, target) where target is index of the target class. 
26 | """ 27 | img, target = self.data[index], self.targets[index] 28 | 29 | # doing this so that it is consistent with all other datasets 30 | # to return a PIL Image 31 | img = Image.fromarray(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img, apply_postprocessing=apply_postprocessing) 35 | 36 | if self.target_transform is not None: 37 | target = self.target_transform(target) 38 | 39 | return img, target 40 | 41 | 42 | 43 | class CustomCIFAR100(datasets.CIFAR100): 44 | """ 45 | Torchvision's CIFAR100 dataset class augmented with a custom __getitem__ method 46 | """ 47 | def __getitem__(self, *args, **kwargs): 48 | return CustomCIFAR10.__getitem__(self, *args, **kwargs) 49 | 50 | 51 | 52 | CIFAR_ROBUSTNESS_FILENAMES = [ 53 | "speckle_noise.npy", "shot_noise.npy", "impulse_noise.npy", "defocus_blur.npy", 54 | "gaussian_blur.npy", "glass_blur.npy", "zoom_blur.npy", "fog.npy", "brightness.npy", 55 | "contrast.npy", "elastic_transform.npy", "pixelate.npy", "jpeg_compression.npy", 56 | "motion_blur.npy", "snow.npy", 'frost.npy', 'gaussian_noise.npy', 'saturate.npy', 'spatter.npy' 57 | ] 58 | 59 | 60 | class CIFARCorruptions(CustomCIFAR10): 61 | 62 | def __init__(self, root, transform=None, train=False): 63 | 64 | super(datasets.CIFAR10, self).__init__(root) 65 | self.transform = transform 66 | self.train = train # training set or test set 67 | self.data = [] 68 | self.targets = [] 69 | 70 | labels_path = os.path.join(root, "labels.npy") 71 | list_labels = np.load(labels_path).tolist() 72 | 73 | ## iterate over CIFAR_ROBUSTNESS_FILENAMES 74 | for i, filename in enumerate(CIFAR_ROBUSTNESS_FILENAMES): 75 | filepath = os.path.join(root, filename) 76 | data = np.load(filepath) 77 | 78 | if i == 0: 79 | ## initialize with first filename 80 | self.data = data 81 | self.targets = list_labels[:] 82 | else: 83 | self.data = np.concatenate( 84 | (self.data, data), 85 | axis=0) 86 | self.targets.extend(list_labels) 87 | 88 | nb_files = len(CIFAR_ROBUSTNESS_FILENAMES) 89 | LOGGER.warning(f"Robustness corruptions of len: {nb_files}") 90 | LOGGER.warning(f"Pixels of shape: {self.data.shape}") 91 | LOGGER.warning(f"Targets of len: {len(self.targets)}") 92 | -------------------------------------------------------------------------------- /mixmo/loaders/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Dataloaders for each of the considered datasets 3 | """ 4 | 5 | import os 6 | 7 | from torchvision import datasets 8 | 9 | from mixmo.augmentations.standard_augmentations import get_default_composed_augmentations 10 | from mixmo.loaders import cifar_dataset, abstract_loader 11 | from mixmo.utils.logger import get_logger 12 | 13 | LOGGER = get_logger(__name__, level="DEBUG") 14 | 15 | 16 | class CIFAR10Loader(abstract_loader.AbstractDataLoader): 17 | """ 18 | Loader for the CIFAR10 dataset that inherits the abstract_loader.AbstractDataLoader dataloading API 19 | and defines the proper augmentations and datasets 20 | """ 21 | 22 | def _init_dataaugmentations(self): 23 | (self.augmentations_train, self.augmentations_test) = get_default_composed_augmentations( 24 | dataset_name="cifar", 25 | ) 26 | 27 | def _init_dataset(self, corruptions=False): 28 | self.train_dataset = cifar_dataset.CustomCIFAR10( 29 | root=self.data_dir, train=True, download=True, transform=self.augmentations_train 30 | ) 31 | if not corruptions: 32 | self.test_dataset = cifar_dataset.CustomCIFAR10( 33 | root=self.data_dir, train=False, download=True, 
transform=self.augmentations_test 34 | ) 35 | else: 36 | self.test_dataset = cifar_dataset.CIFARCorruptions( 37 | root=self.corruptions_data_dir, train=False, transform=self.augmentations_test 38 | ) 39 | 40 | @property 41 | def data_dir(self): 42 | return os.path.join(self.dataplace, "cifar10-data") 43 | 44 | @property 45 | def corruptions_data_dir(self): 46 | return os.path.join(self.dataplace, "CIFAR-10-C") 47 | 48 | 49 | @staticmethod 50 | def properties(key): 51 | dict_key_to_values = { 52 | "conv1_input_size": (16, 32, 32), 53 | "conv1_is_half_size": False, 54 | "pixels_size": 32, 55 | } 56 | return dict_key_to_values[key] 57 | 58 | 59 | class CIFAR100Loader(CIFAR10Loader): 60 | """ 61 | Loader for the CIFAR100 dataset that inherits the abstract_loader.AbstractDataLoader dataloading API 62 | and defines the proper augmentations and datasets 63 | """ 64 | 65 | def _init_dataset(self, corruptions=False): 66 | self.train_dataset = cifar_dataset.CustomCIFAR100( 67 | root=self.data_dir, train=True, download=True, transform=self.augmentations_train 68 | ) 69 | if not corruptions: 70 | self.test_dataset = cifar_dataset.CustomCIFAR100( 71 | root=self.data_dir, train=False, download=True, transform=self.augmentations_test 72 | ) 73 | else: 74 | self.test_dataset = cifar_dataset.CIFARCorruptions( 75 | root=self.corruptions_data_dir, train=False, transform=self.augmentations_test 76 | ) 77 | 78 | @property 79 | def data_dir(self): 80 | return os.path.join(self.dataplace, "cifar100-data") 81 | 82 | @property 83 | def corruptions_data_dir(self): 84 | return os.path.join(self.dataplace, "CIFAR-100-C") 85 | 86 | 87 | class TinyImagenet200Loader(abstract_loader.AbstractDataLoader): 88 | """ 89 | Loader for the TinyImageNet dataset that inherits the abstract_loader.AbstractDataLoader dataloading API 90 | and defines the proper augmentations and datasets 91 | """ 92 | 93 | def _init_dataaugmentations(self): 94 | (self.augmentations_train, self.augmentations_test) = get_default_composed_augmentations( 95 | dataset_name="tinyimagenet", 96 | ) 97 | 98 | @property 99 | def data_dir(self): 100 | return os.path.join(self.dataplace, "tinyimagenet200-data") 101 | 102 | def _init_dataset(self, corruptions=False): 103 | traindir = os.path.join(self.data_dir, 'train') 104 | valdir = os.path.join(self.data_dir, 'val/images') 105 | self.train_dataset = datasets.ImageFolder(traindir, self.augmentations_train) 106 | self.test_dataset = datasets.ImageFolder(valdir, self.augmentations_test) 107 | 108 | @staticmethod 109 | def properties(key): 110 | dict_key_to_values = { 111 | "conv1_input_size": (64, 32, 32), 112 | "conv1_is_half_size": True, 113 | "pixels_size": 64, 114 | } 115 | return dict_key_to_values[key] 116 | -------------------------------------------------------------------------------- /mixmo/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Networks used in the main paper 3 | """ 4 | 5 | from mixmo.utils.logger import get_logger 6 | 7 | from mixmo.networks import resnet, wrn 8 | 9 | 10 | LOGGER = get_logger(__name__, level="DEBUG") 11 | 12 | 13 | def get_network(config_network, config_args): 14 | """ 15 | Return a new instance of network 16 | """ 17 | # Available networks for tiny 18 | if config_args["data"]["dataset"].startswith('tinyimagenet'): 19 | network_factory = resnet.resnet_network_factory 20 | elif config_args["data"]["dataset"].startswith('cifar'): 21 | network_factory = wrn.wrn_network_factory 22 | else: 23 | raise 
NotImplementedError
24 | 
25 |     LOGGER.warning(f"Loading network: {config_network['name']}")
26 |     return network_factory[config_network["name"]](
27 |         config_network=config_network,
28 |         config_args=config_args)
29 | 
--------------------------------------------------------------------------------
/mixmo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General utility functions
3 | """
4 | 
--------------------------------------------------------------------------------
/mixmo/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | General variable definitions, mostly for testing purposes
3 | """
4 | 
5 | from easydict import EasyDict
6 | 
7 | cfg = EasyDict()
8 | 
9 | cfg.DEBUG = 0
10 | cfg.SAVE_EVERY_X_EPOCH = 30
11 | cfg.RATIO_EPOCH_DECREASE = 11/12
12 | 
13 | # TEST CONFIGS
14 | cfg.CALIBRATION = EasyDict()
15 | cfg.CALIBRATION.LRS = [0.001, 0.0001, 0.00005, 0.00002, 0.00001, 0.000005]
16 | cfg.CALIBRATION.MAX_ITERS = [5000]
17 | 
18 | cfg.RANDOM = EasyDict()
19 | cfg.RANDOM.SEED = 1234
20 | cfg.RANDOM.MAX_RANDOM = 10_000_000
21 | cfg.RANDOM.SEED_OFFSET_MIXMO = 11
22 | cfg.RANDOM.SEED_DA = 21
23 | cfg.RANDOM.SEED_TESTVAL = 31
24 | 
--------------------------------------------------------------------------------
/mixmo/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger utilities. Adds extra logging levels and colored output for each level
3 | """
4 | import logging
5 | 
6 | import coloredlogs
7 | import verboselogs
8 | 
9 | 
10 | def get_logger(logger_name, level="SPAM"):
11 |     """
12 |     Get the logger and:
13 |     - extend the available logger levels:
14 |       spam, debug, verbose, info, notice, warning, success, error, critical
15 |     - add colors to each level, and print milliseconds in the timestamp
16 | 
17 |     Use:
18 |         As a global variable in each file:
19 |         >>> LOGGER = logger.get_logger('name_of_the_file', level='DEBUG')
20 |         The level argument controls which messages are printed
21 | 
22 |     Jupyter notebook:
23 |         Add argument "isatty=True" to coloredlogs.install()
24 |         Easier to read with 'fmt = "%(name)s[%(process)d] %(levelname)s %(message)s"'
25 |     """
26 |     verboselogs.install()
27 |     logger = logging.getLogger(logger_name)
28 | 
29 |     field_styles = {
30 |         "hostname": {"color": "magenta"},
31 |         "programname": {"color": "cyan"},
32 |         "name": {"color": "blue"},
33 |         "levelname": {"color": "black", "bold": True, "bright": True},
34 |         "asctime": {"color": "green"},
35 |     }
36 |     coloredlogs.install(
37 |         level=level,
38 |         logger=logger,
39 |         fmt="%(asctime)s,%(msecs)03d %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s",
40 |         field_styles=field_styles,
41 |     )
42 |     logger.propagate = False
43 |     return logger
44 | 
--------------------------------------------------------------------------------
/mixmo/utils/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Calibration and image printing utility functions
3 | """
4 | 
5 | import matplotlib.pyplot as plt
6 | import torchvision
7 | import numpy as np
8 | 
9 | __all__ = ['make_image', 'show_batch', 'write_calibration']
10 | 
11 | 
12 | def write_calibration(
13 |     avg_confs_in_bins,
14 |     acc_in_bin_list,
15 |     prop_bin,
16 |     min_bin,
17 |     max_bin,
18 |     min_pred=0,
19 |     write_file=None,
20 |     suffix=""
21 | ):
22 |     """
23 |     Utility function to show calibration through a histogram of classifier confidences
24 |     """
25 |     fig, ax1 = plt.subplots()
26 | 
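    # Added explanatory comment: the dotted diagonal plotted below is the reference for
    # a perfectly calibrated classifier, whose accuracy equals its mean confidence in
    # every bin, so the measured calibration curve should ideally lie on the y = x line.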
27 |     ax1.plot([min_pred, 1], [min_pred, 1], "k:", label="Perfectly calibrated")
28 | 
29 |     ax1.plot(avg_confs_in_bins,
30 |              acc_in_bin_list,
31 |              "s-",
32 |              label="%s" % ("Discriminator Calibration"))
33 | 
34 |     if write_file:
35 |         suffix = write_file.split("/")[-1].split(".")[0].split("_")[0] + "_" + suffix
36 | 
37 |     ax1.set_xlabel(f"Mean predicted value {suffix}")
38 |     ax1.set_ylabel("Accuracy")
39 |     ymin = min(acc_in_bin_list + [min_pred])
40 |     ax1.set_ylim([ymin, 1.0])
41 |     ax1.legend(loc="lower right")
42 | 
43 |     ax2 = ax1.twinx()
44 |     ax2.hlines(prop_bin, min_bin, max_bin, label="%s" % ("Proportion in each bin"), color="r")
45 |     ax2.set_ylabel("Proportion")
46 |     ax2.legend(loc="upper center")
47 | 
48 |     if not write_file:
49 |         plt.tight_layout()
50 |         plt.show()
51 |     else:
52 |         fig.savefig(
53 |             write_file
54 |         )
55 | 
56 | 
57 | def make_image(img, mean=None, std=None, normalize=True):
58 |     """
59 |     Transform a (normalized) CIFAR pytorch tensor image into a displayable numpy image (the channel dimension is moved last)
60 |     """
61 |     if mean is None and std is None:
62 |         from mixmo.augmentations.standard_augmentations import cifar_mean, cifar_std
63 |         mean = cifar_mean
64 |         std = cifar_std
65 |     npimg = img.numpy().copy()
66 |     if normalize:
67 |         for i in range(0, 3):
68 |             npimg[i] = npimg[i] * std[i] + mean[i]  # unnormalize
69 |     return np.transpose(npimg, (1, 2, 0))
70 | 
71 | 
72 | def show_batch(images, normalize=True):
73 |     """
74 |     Plot the images in a batch as a single image grid
75 |     """
76 |     images = make_image(torchvision.utils.make_grid(images), normalize=normalize)
77 |     plt.imshow(images)
78 |     plt.show()
79 |     return images
80 | 
--------------------------------------------------------------------------------
/mixmo_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexrame/mixmo-pytorch/5a2a1a090b7805212aaad8ebc36ef11ade75e8d2/mixmo_intro.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.0
2 | scikit-learn==0.19.1
3 | scipy==1.1.0
4 | torch==1.4.0
5 | torchsummary==1.5.1
6 | torchvision==0.5.0
7 | tqdm==4.31.1
8 | pyyaml
9 | verboselogs
10 | coloredlogs
11 | click
12 | easydict
13 | tensorboard==1.14.0
14 | pandas==1.0.5
15 | matplotlib==3.3.0
16 | 
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexrame/mixmo-pytorch/5a2a1a090b7805212aaad8ebc36ef11ade75e8d2/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/exp_mixmo_template.yaml:
--------------------------------------------------------------------------------
1 | num_members: %(num_members)s 2 | 3 | data: 4 | dataset: %(dataset_name)s%(num_classes)s 5 | num_classes: %(num_classes)s 6 | 7 | training: 8 | nb_epochs: %(nb_epochs)s 9 | batch_size: %(batch_size)s 10 | dataset_wrapper: 11 | mixmo: 12 | mix_method: 13 | method_name: %(mixmo_mix_method_name)s 14 | prob: %(mixmo_mix_prob)s 15 | replacement_method_name: mixup 16 | alpha: %(mixmo_alpha)s 17 | weight_root: %(mixmo_weight_root)s 18 | msda: 19 | beta: 1 20 | prob: 0.5 21 | mix_method: %(msda_mix_method)s 22 | da_method: %(da_method)s 23 | batch_sampler: 24 | batch_repetitions: %(batch_repetitions)s 25 | proba_input_repetition: 0 26 | 27 | model_wrapper: 28 | name: classifier 29 | network: 30 | 
name: %(classifier)s 31 | depth: %(depth)s 32 | widen_factor: %(widen_factor)s 33 | loss: 34 | listloss: 35 | - name: soft_cross_entropy 36 | display_name: ce0 37 | coeff: 1 38 | input: logits_0 39 | target: target_0 40 | - name: soft_cross_entropy 41 | display_name: ce1 42 | coeff: "