├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.md ├── README.md ├── evaluate_ddp.sh ├── evaluate_ddp_jz.sh ├── experiments_margins_jz.sh ├── experiments_ssps.sh ├── experiments_ssps_jz.sh ├── logo.png ├── models ├── margins │ └── voxceleb2 │ │ ├── moco │ │ ├── adaface_m-0.05 │ │ │ └── config.yml │ │ ├── adaface_m-0.1 │ │ │ └── config.yml │ │ ├── adaface_m-0.2 │ │ │ └── config.yml │ │ ├── adaface_m-0.3 │ │ │ └── config.yml │ │ ├── adaface_m-0.5 │ │ │ └── config.yml │ │ ├── arcface_m-0.005 │ │ │ └── config.yml │ │ ├── arcface_m-0.01 │ │ │ └── config.yml │ │ ├── arcface_m-0.05 │ │ │ └── config.yml │ │ ├── arcface_m-0.1 │ │ │ └── config.yml │ │ ├── arcface_m-0.2 │ │ │ └── config.yml │ │ ├── cosface_m-0.05 │ │ │ └── config.yml │ │ ├── cosface_m-0.1 │ │ │ └── config.yml │ │ ├── cosface_m-0.2 │ │ │ └── config.yml │ │ ├── cosface_m-0.3 │ │ │ └── config.yml │ │ ├── cosface_m-0.5 │ │ │ └── config.yml │ │ ├── curricularface_m-0.005 │ │ │ └── config.yml │ │ ├── curricularface_m-0.01 │ │ │ └── config.yml │ │ ├── curricularface_m-0.05 │ │ │ └── config.yml │ │ ├── curricularface_m-0.1 │ │ │ └── config.yml │ │ ├── curricularface_m-0.2 │ │ │ └── config.yml │ │ ├── magface_m-0.01-0.05 │ │ │ └── config.yml │ │ ├── magface_m-0.05-0.1 │ │ │ └── config.yml │ │ ├── sphereface_m-0.01 │ │ │ └── config.yml │ │ ├── sphereface_m-0.1 │ │ │ └── config.yml │ │ ├── sphereface_m-2 │ │ │ └── config.yml │ │ ├── sphereface_m-3 │ │ │ └── config.yml │ │ └── sphereface_m-4 │ │ │ └── config.yml │ │ └── simclr │ │ ├── adaface_m-0.05 │ │ └── config.yml │ │ ├── adaface_m-0.1 │ │ └── config.yml │ │ ├── adaface_m-0.2 │ │ └── config.yml │ │ ├── adaface_m-0.3 │ │ └── config.yml │ │ ├── adaface_m-0.5 │ │ └── config.yml │ │ ├── arcface_m-0.005 │ │ └── config.yml │ │ ├── arcface_m-0.01 │ │ └── config.yml │ │ ├── arcface_m-0.05 │ │ └── config.yml │ │ ├── arcface_m-0.1 │ │ └── config.yml │ │ ├── arcface_m-0.2 │ │ └── config.yml │ │ ├── cosface_m-0.05 │ │ └── config.yml │ │ ├── cosface_m-0.1 │ │ └── 
config.yml │ │ ├── cosface_m-0.2 │ │ └── config.yml │ │ ├── cosface_m-0.3 │ │ └── config.yml │ │ ├── cosface_m-0.5 │ │ └── config.yml │ │ ├── curricularface_m-0.005 │ │ └── config.yml │ │ ├── curricularface_m-0.01 │ │ └── config.yml │ │ ├── curricularface_m-0.05 │ │ └── config.yml │ │ ├── curricularface_m-0.1 │ │ └── config.yml │ │ ├── curricularface_m-0.2 │ │ └── config.yml │ │ ├── magface_m-0.01-0.05 │ │ └── config.yml │ │ ├── magface_m-0.05-0.1 │ │ └── config.yml │ │ ├── sphereface_m-0.01 │ │ └── config.yml │ │ ├── sphereface_m-0.1 │ │ └── config.yml │ │ ├── sphereface_m-2 │ │ └── config.yml │ │ ├── sphereface_m-3 │ │ └── config.yml │ │ └── sphereface_m-4 │ │ └── config.yml ├── ssl │ └── voxceleb2 │ │ ├── barlowtwins │ │ ├── barlowtwins │ │ │ └── config.yml │ │ └── barlowtwins_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ ├── byol │ │ ├── byol │ │ │ └── config.yml │ │ └── byol_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ ├── cpc │ │ ├── cpc │ │ │ └── config.yml │ │ ├── cpc_aggdim128 │ │ │ └── config.yml │ │ ├── cpc_aggdim512 │ │ │ └── config.yml │ │ ├── cpc_bidirectional │ │ │ └── config.yml │ │ ├── cpc_t2 │ │ │ └── config.yml │ │ ├── cpc_t6 │ │ │ └── config.yml │ │ └── cpc_t8 │ │ │ └── config.yml │ │ ├── deepcluster │ │ ├── deepcluster │ │ │ └── config.yml │ │ └── deepcluster_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ ├── dino │ │ ├── dino+ │ │ │ └── config.yml │ │ ├── dino+_e-ecapa-1024 │ │ │ └── config.yml │ │ └── dino │ │ │ └── config.yml │ │ ├── lim │ │ ├── lim │ │ │ └── config.yml │ │ ├── lim_mine │ │ │ └── config.yml │ │ └── lim_nce │ │ │ └── config.yml │ │ ├── moco │ │ ├── moco │ │ │ └── config.yml │ │ ├── moco_e-ecapa-1024 │ │ │ └── config.yml │ │ ├── moco_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ ├── moco_p-2048-BN-R-2048-BN-R-512_temp-0.03 │ │ │ └── config.yml │ │ ├── moco_p-none │ │ │ └── config.yml │ │ └── moco_p-none_temp-0.03 │ │ │ └── config.yml │ │ ├── simclr │ │ ├── simclr │ │ │ └── config.yml │ │ ├── 
simclr_e-ecapa-1024 │ │ │ └── config.yml │ │ ├── simclr_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ ├── simclr_p-none │ │ │ └── config.yml │ │ ├── simclr_p-none_temp-0.03 │ │ │ └── config.yml │ │ ├── simclr_p-none_temp-0.03333333333333 │ │ │ └── config.yml │ │ ├── simclr_temp-0.01 │ │ │ └── config.yml │ │ ├── simclr_temp-0.03 │ │ │ └── config.yml │ │ ├── simclr_temp-0.05 │ │ │ └── config.yml │ │ └── simclr_temp-0.07 │ │ │ └── config.yml │ │ ├── simsiam │ │ └── simsiam │ │ │ └── config.yml │ │ ├── supervised │ │ ├── supervised │ │ │ └── config.yml │ │ ├── supervised_e-ecapa-1024 │ │ │ └── config.yml │ │ └── supervised_e-ecapa-512 │ │ │ └── config.yml │ │ ├── swav │ │ ├── swav │ │ │ └── config.yml │ │ ├── swav_e-ecapa-1024 │ │ │ └── config.yml │ │ ├── swav_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ └── swav_p-none │ │ │ └── config.yml │ │ ├── vicreg │ │ ├── vicreg │ │ │ └── config.yml │ │ ├── vicreg_e-ecapa-1024 │ │ │ └── config.yml │ │ ├── vicreg_p-2048-BN-R-2048-BN-R-512 │ │ │ └── config.yml │ │ └── vicreg_p-none │ │ │ └── config.yml │ │ └── wmse │ │ └── wmse │ │ └── config.yml ├── ssps │ └── voxceleb2 │ │ ├── dino_e-ecapa │ │ ├── _dino │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── moco_e-ecapa │ │ ├── _moco │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── simclr │ │ ├── _simclr │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── exps │ │ │ ├── ssps_frame-2s-aug │ │ │ │ └── config.yml │ │ │ ├── ssps_frame-2s-clean │ │ │ │ └── config.yml │ │ │ ├── ssps_frame-3s-clean │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans-centroid_25k_uni-1 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans-centroid_6k │ 
│ │ │ └── config.yml │ │ │ ├── ssps_kmeans_100k │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_100k_uni-1 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_100k_uni-2 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_100k_uni-3 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_100k_uni-5 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_10k │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_150k │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_150k_uni-1 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_150k_uni-2 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_150k_uni-3 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_150k_uni-5 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_25k │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_25k_uni-2 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_25k_uni-3 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_25k_uni-5 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_50k │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_50k_uni-1 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_50k_uni-2 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_50k_uni-3 │ │ │ │ └── config.yml │ │ │ ├── ssps_kmeans_50k_uni-5 │ │ │ │ └── config.yml │ │ │ ├── ssps_knn_uni-1 │ │ │ │ └── config.yml │ │ │ ├── ssps_knn_uni-10 │ │ │ │ └── config.yml │ │ │ ├── ssps_knn_uni-100 │ │ │ │ └── config.yml │ │ │ ├── ssps_knn_uni-25 │ │ │ │ └── config.yml │ │ │ └── ssps_knn_uni-50 │ │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1_nofn │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── simclr_e-ecapa │ │ ├── _simclr │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── exps │ │ │ ├── baseline_aug-none │ │ │ │ └── config.yml │ │ │ ├── baseline_nmi │ │ │ │ └── config.yml │ │ │ ├── ssps_aug-none │ │ │ │ └── config.yml │ │ │ └── ssps_nmi │ │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1_nofn │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── 
simclr_margins │ │ ├── ssps_adaface_m-0.05 │ │ │ └── config.yml │ │ ├── ssps_adaface_m-0.1 │ │ │ └── config.yml │ │ ├── ssps_adaface_m-0.2 │ │ │ └── config.yml │ │ ├── ssps_arcface_m-0.05 │ │ │ └── config.yml │ │ ├── ssps_arcface_m-0.1 │ │ │ └── config.yml │ │ ├── ssps_arcface_m-0.2 │ │ │ └── config.yml │ │ ├── ssps_cosface_m-0.05 │ │ │ └── config.yml │ │ ├── ssps_cosface_m-0.1 │ │ │ └── config.yml │ │ ├── ssps_cosface_m-0.2 │ │ │ └── config.yml │ │ ├── ssps_curricularface_m-0.05 │ │ │ └── config.yml │ │ ├── ssps_curricularface_m-0.1 │ │ │ └── config.yml │ │ ├── ssps_curricularface_m-0.1_nofn │ │ │ └── config.yml │ │ ├── ssps_curricularface_m-0.2 │ │ │ └── config.yml │ │ ├── ssps_magface │ │ │ └── config.yml │ │ └── ssps_sphereface_m-0.1 │ │ │ └── config.yml │ │ ├── simclr_margins_e-ecapa │ │ └── ssps_curricularface_m-0.1 │ │ │ └── config.yml │ │ ├── swav │ │ ├── _swav │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── swav_e-ecapa │ │ ├── _swav │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ ├── vicreg │ │ ├── _vicreg │ │ │ └── config.yml │ │ ├── baseline │ │ │ └── config.yml │ │ ├── baseline_sup │ │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ │ └── config.yml │ │ └── vicreg_e-ecapa │ │ ├── _vicreg │ │ └── config.yml │ │ ├── baseline │ │ └── config.yml │ │ ├── baseline_sup │ │ └── config.yml │ │ ├── ssps_kmeans_25k_uni-1 │ │ └── config.yml │ │ └── ssps_kmeans_6k │ │ └── config.yml └── supervised │ ├── cremad │ └── emotion │ │ └── config.yml │ ├── voxceleb1 │ └── gender │ │ └── config.yml │ ├── voxceleb2 │ └── sv │ │ └── config.yml │ └── voxlingua107 │ └── language │ └── config.yml ├── notebooks ├── datasets │ ├── 
VAD.py │ ├── listen_to_augmented_audio_samples.ipynb │ ├── test_vad.ipynb │ ├── voxceleb1_listen_to_different_videos_from_speaker.ipynb │ ├── voxceleb1_stats.ipynb │ └── voxceleb2_stats.ipynb ├── evaluation │ ├── ScoreCalibration.py │ ├── evaluate_speaker_verification.ipynb │ ├── plot_speaker_verification_label_efficient.ipynb │ ├── plot_tsne_speaker_embeddings_over_ssl_training.ipynb │ ├── sv_visualization.py │ └── test_sv_score_fusion_calibration.ipynb ├── experiments │ ├── evaluate_speaker_verification_with_pca.ipynb │ └── predict_voxceleb_utt_info_from_ssl_representations.ipynb ├── notebooks_utils.py ├── requirements.txt └── ssps │ ├── ssps_simulation.ipynb │ ├── ssps_study new.ipynb │ ├── ssps_study.ipynb │ └── utils.py ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── requirements_strict.txt ├── sslsv ├── Config.py ├── __init__.py ├── bin │ ├── __init__.py │ ├── average_model.py │ ├── create_ssps_buffers_distributed.py │ ├── create_ssps_buffers_distributed_jz.py │ ├── evaluate.py │ ├── evaluate_distributed.py │ ├── evaluate_distributed_jz.py │ ├── evaluate_label_efficient.py │ ├── inference.py │ ├── inference_distributed.py │ ├── inference_distributed_jz.py │ ├── profiling.py │ ├── train.py │ ├── train_distributed.py │ └── train_distributed_jz.py ├── datasets │ ├── DataAugmentation.py │ ├── Dataset.py │ ├── DistributedSamplerWrapper.py │ ├── SSLDataset.py │ ├── Sampler.py │ ├── __init__.py │ └── utils.py ├── encoders │ ├── ECAPATDNN.py │ ├── ResNet34.py │ ├── SimpleAudioCNN.py │ ├── TDNN.py │ ├── _BaseEncoder.py │ └── __init__.py ├── evaluations │ ├── ClassificationEvaluation.py │ ├── CosineSVEvaluation.py │ ├── PLDASVEvaluation.py │ ├── _BaseEvaluation.py │ ├── _SpeakerVerificationEvaluation.py │ └── __init__.py ├── methods │ ├── BYOL │ │ ├── BYOL.py │ │ ├── BYOLLoss.py │ │ └── __init__.py │ ├── BarlowTwins │ │ ├── BarlowTwins.py │ │ ├── BarlowTwinsLoss.py │ │ └── __init__.py │ ├── CPC │ │ ├── CPC.py │ │ ├── CPCLoss.py │ │ ├── InfoNCELoss.py │ │ └── 
__init__.py │ ├── Combiner │ │ ├── Combiner.py │ │ └── __init__.py │ ├── DINO │ │ ├── DINO.py │ │ ├── DINOLoss.py │ │ └── __init__.py │ ├── DeepCluster │ │ ├── DeepCluster.py │ │ ├── DeepClusterLoss.py │ │ ├── KMeans.py │ │ └── __init__.py │ ├── LIM │ │ ├── LIM.py │ │ ├── LIMLoss.py │ │ └── __init__.py │ ├── MoCo │ │ ├── MoCo.py │ │ ├── MoCoLoss.py │ │ └── __init__.py │ ├── MoCoMargins │ │ ├── MoCoMargins.py │ │ ├── MoCoMarginsLoss.py │ │ └── __init__.py │ ├── SimCLR │ │ ├── SimCLR.py │ │ ├── SimCLRLoss.py │ │ └── __init__.py │ ├── SimCLRMargins │ │ ├── SimCLRMargins.py │ │ ├── SimCLRMarginsLoss.py │ │ └── __init__.py │ ├── SimCLRMultiViews │ │ ├── SimCLRMultiViews.py │ │ ├── SimCLRMultiViewsLoss.py │ │ └── __init__.py │ ├── SimSiam │ │ ├── SimSiam.py │ │ ├── SimSiamLoss.py │ │ └── __init__.py │ ├── Supervised │ │ ├── AAMSoftmaxLoss.py │ │ ├── ARPLLoss.py │ │ ├── Supervised.py │ │ └── __init__.py │ ├── SwAV │ │ ├── SinkhornKnopp.py │ │ ├── SwAV.py │ │ ├── SwAVLoss.py │ │ └── __init__.py │ ├── VICReg │ │ ├── VICReg.py │ │ ├── VICRegLoss.py │ │ └── __init__.py │ ├── VIbCReg │ │ ├── IterNorm.py │ │ ├── VIbCReg.py │ │ ├── VIbCRegLoss.py │ │ └── __init__.py │ ├── WMSE │ │ ├── WMSE.py │ │ ├── WMSELoss.py │ │ ├── Whitening2d.py │ │ └── __init__.py │ ├── _BaseMethod.py │ ├── _BaseMomentumMethod.py │ ├── _BaseSiameseMethod.py │ ├── _SSPS │ │ ├── KMeans.py │ │ ├── SSPS.py │ │ ├── SSPSConfig.py │ │ ├── SSPSSamplingMethods.py │ │ └── __init__.py │ └── __init__.py ├── trainer │ ├── EpochLogger.py │ ├── Trainer.py │ └── __init__.py └── utils │ ├── __init__.py │ ├── distributed.py │ └── helpers.py ├── tests ├── __init__.py ├── encoders │ ├── __init__.py │ ├── test_ecapatdnn.py │ ├── test_resnet34.py │ ├── test_simpleaudiocnn.py │ └── test_tdnn.py ├── methods │ ├── __init__.py │ ├── test_barlowtwins.py │ ├── test_byol.py │ ├── test_combiner.py │ ├── test_cpc.py │ ├── test_dino.py │ ├── test_lim.py │ ├── test_moco.py │ ├── test_simclr.py │ ├── test_simclr_margins.py │ ├── 
test_simsiam.py │ ├── test_supervised.py │ ├── test_vibcreg.py │ ├── test_vicreg.py │ └── test_wmse.py ├── resources │ ├── empty.yml │ ├── no_encoder.yml │ ├── no_method.yml │ └── simple │ │ └── config.yml ├── test_config.py ├── test_dataset.py ├── test_reproductibility.py ├── test_sampler.py └── test_trainer.py ├── tools ├── download_models.sh ├── export_model_metrics.py ├── prepare_data │ ├── create_cremad_train_csv.py │ ├── create_sitw_trials.py │ ├── create_voices_trials.py │ ├── create_voxlingua107_train_csv.py │ ├── prepare_augmentation.py │ ├── prepare_voxceleb.py │ └── utils.py ├── rsync_jz.sh └── slurm │ ├── clean_ckpts.py │ └── jobs_eta.py ├── train_ddp.sh ├── train_ddp_jz.sh ├── train_ddp_ssps_jz_x2.sh ├── train_ddp_ssps_jz_x2_exp.sh ├── train_ddp_ssps_jz_x4.sh └── training_framework.svg /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints 3 | 4 | .vscode 5 | 6 | .pypirc 7 | 8 | *.egg-info 9 | build 10 | dist 11 | 12 | data 13 | env 14 | wandb 15 | 16 | models/**/tensorboard/ 17 | models/**/wandb/ 18 | models/**/*.json 19 | models/**/*.pt 20 | 21 | tests/resources/**/tensorboard/ 22 | tests/resources/**/wandb/ 23 | tests/resources/**/*.json 24 | tests/resources/**/*.pt 25 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | exclude: '\.md$' 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-json 10 | - id: check-toml 11 | - id: check-case-conflict 12 | - id: check-added-large-files 13 | args: ['--maxkb', "750"] 14 | - id: check-docstring-first 15 | - id: detect-private-key 16 | - repo: https://github.com/kynan/nbstripout 17 | rev: 0.6.1 18 | hooks: 19 | - id: nbstripout 20 | - repo: 
https://github.com/psf/black 21 | rev: 24.3.0 22 | hooks: 23 | - id: black 24 | args: [--line-length, "88"] 25 | language_version: python3 26 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Theo Lepage 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /evaluate_ddp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -eq 0 ]; then 4 | echo "Usage: $0 [args ...]" 5 | exit 1 6 | fi 7 | 8 | num_gpus=$1 9 | 10 | shift 11 | 12 | # torchrun --nproc_per_node=$num_gpus sslsv/bin/evaluate_distributed.py "$@" 13 | 14 | python sslsv/bin/average_model.py "$1" --silent 15 | torchrun --nproc_per_node=$num_gpus sslsv/bin/evaluate_distributed.py "$@" --model_suffix avg 16 | -------------------------------------------------------------------------------- /evaluate_ddp_jz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | models="$@" 4 | 5 | commands="" 6 | 7 | for model in $models; do 8 | commands+=" 9 | python sslsv/bin/average_model.py $model/config.yml --silent 10 | srun python -u sslsv/bin/evaluate_distributed_jz.py $model/config.yml --model_suffix avg --silent 11 | " 12 | done 13 | 14 | sbatch <=1.11.0 2 | torchaudio>=0.11.0 3 | 4 | numpy 5 | pandas 6 | soundfile 7 | scikit-learn 8 | speechbrain 9 | tensorboard 10 | wandb 11 | 12 | ruamel.yaml 13 | dacite 14 | prettyprinter 15 | tqdm 16 | -------------------------------------------------------------------------------- /requirements_strict.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchaudio==0.11.0 3 | 4 | numpy==1.23.1 5 | pandas==1.4.3 6 | soundfile==0.11.0 7 | scikit-learn==1.0.2 8 | speechbrain==0.5.13 9 | tensorboard==2.10.0 10 | wandb==0.13.3 11 | 12 | ruamel.yaml==0.17.21 13 | dacite==1.8.1 14 | prettyprinter==0.18.0 15 | tqdm==4.64.0 16 | 17 | plotnine==0.9.0 18 | plotly==5.10.0 19 | seaborn==0.12.0 20 | -------------------------------------------------------------------------------- /sslsv/Config.py: -------------------------------------------------------------------------------- 1 | from 
dataclasses import dataclass 2 | 3 | from pathlib import Path 4 | 5 | from sslsv.encoders._BaseEncoder import BaseEncoderConfig 6 | from sslsv.methods._BaseMethod import BaseMethodConfig 7 | from sslsv.evaluations._BaseEvaluation import BaseEvaluationConfig 8 | from sslsv.trainer.Trainer import TrainerConfig 9 | from sslsv.datasets.Dataset import DatasetConfig 10 | 11 | 12 | @dataclass 13 | class Config: 14 | """ 15 | Global configuration. 16 | 17 | Attributes: 18 | model_name (str): Name of the model. 19 | model_path (Path): Path to the model directory. 20 | seed (int): Seed for reproducibility. 21 | reproducibility (bool): Whether or not to enable reproducibility mode. 22 | encoder (BaseEncoderConfig): Encoder configuration. 23 | method (BaseMethodConfig): Method configuration. 24 | trainer (TrainerConfig): Trainer configuration. 25 | dataset (DatasetConfig): Dataset configuration. 26 | evaluation (BaseEvaluationConfig): Evaluation configuration. 27 | """ 28 | 29 | model_name: str = "default" 30 | model_path: Path = Path("default") 31 | 32 | seed: int = 1717 33 | reproducibility: bool = False 34 | 35 | encoder: BaseEncoderConfig = None 36 | method: BaseMethodConfig = None 37 | trainer: TrainerConfig = TrainerConfig() 38 | dataset: DatasetConfig = DatasetConfig() 39 | evaluation: BaseEvaluationConfig = BaseEvaluationConfig() 40 | -------------------------------------------------------------------------------- /sslsv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/__init__.py -------------------------------------------------------------------------------- /sslsv/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/bin/__init__.py 
-------------------------------------------------------------------------------- /sslsv/bin/inference_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 5 | 6 | import argparse 7 | 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from sslsv.bin.inference import inference_parser, inference_ 12 | from sslsv.utils.helpers import load_config, load_model 13 | 14 | 15 | def inference(args: argparse.Namespace): 16 | """ 17 | Perform model inference from the CLI (using DistributedDataParallel). 18 | 19 | Args: 20 | args (argparse.Namespace): Arguments parsed from the command line. 21 | 22 | Returns: 23 | None 24 | """ 25 | world_size = int(os.environ["WORLD_SIZE"]) # idr_torch.size 26 | rank = int(os.environ["LOCAL_RANK"]) # idr_torch.rank 27 | 28 | torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) 29 | torch.cuda.set_device(rank) 30 | 31 | config = load_config(args.config, verbose=not args.silent) 32 | 33 | model = load_model(config).to(rank) 34 | checkpoint = torch.load(config.model_ckpt_path / f"model_{args.model_suffix}.pt") 35 | model.load_state_dict(checkpoint["model"], strict=False) 36 | model.eval() 37 | model = DistributedDataParallel(model, device_ids=[rank]) 38 | 39 | inference_( 40 | config, 41 | model, 42 | torch.device("cuda", rank), 43 | args.input, 44 | args.output, 45 | batch_size=args.batch_size, 46 | frame_length=args.frame_length, 47 | num_frames=args.num_frames, 48 | verbose=not args.silent, 49 | ) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = inference_parser() 54 | args = parser.parse_args() 55 | inference(args) 56 | -------------------------------------------------------------------------------- /sslsv/bin/inference_distributed_jz.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 5 | 6 | import argparse 7 | 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from sslsv.bin.inference import inference_parser, inference_ 12 | from sslsv.utils.helpers import load_config, load_model 13 | 14 | import idr_torch 15 | 16 | 17 | def inference(args: argparse.Namespace): 18 | """ 19 | Perform model inference from the CLI (using DistributedDataParallel). 20 | 21 | Args: 22 | args (argparse.Namespace): Arguments parsed from the command line. 23 | 24 | Returns: 25 | None 26 | """ 27 | world_size = idr_torch.size 28 | rank = idr_torch.rank 29 | 30 | torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) 31 | torch.cuda.set_device(rank) 32 | 33 | config = load_config(args.config, verbose=not args.silent) 34 | 35 | model = load_model(config).to(rank) 36 | checkpoint = torch.load(config.model_ckpt_path / f"model_{args.model_suffix}.pt") 37 | model.load_state_dict(checkpoint["model"], strict=False) 38 | model.eval() 39 | model = DistributedDataParallel(model, device_ids=[rank]) 40 | 41 | inference_( 42 | config, 43 | model, 44 | torch.device("cuda", rank), 45 | args.input, 46 | args.output, 47 | batch_size=args.batch_size, 48 | frame_length=args.frame_length, 49 | num_frames=args.num_frames, 50 | verbose=not args.silent, 51 | ) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = inference_parser() 56 | args = parser.parse_args() 57 | inference(args) 58 | -------------------------------------------------------------------------------- /sslsv/bin/train_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 5 | 6 | import argparse 7 | 8 | import 
torch 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from sslsv.trainer.Trainer import Trainer 12 | from sslsv.utils.helpers import load_config, load_train_dataloader, load_model, evaluate 13 | 14 | 15 | def train(args: argparse.Namespace): 16 | """ 17 | Train a model from the CLI (using DistributedDataParallel). 18 | 19 | Args: 20 | args (argparse.Namespace): Arguments parsed from the command line. 21 | 22 | Returns: 23 | None 24 | """ 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | rank = int(os.environ["LOCAL_RANK"]) 27 | 28 | torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) 29 | torch.cuda.set_device(rank) 30 | 31 | config = load_config(args.config) 32 | train_dataloader = load_train_dataloader(config) 33 | 34 | model = load_model(config).to(rank) 35 | model = DistributedDataParallel(model, device_ids=[rank]) 36 | 37 | trainer = Trainer( 38 | model=model, 39 | train_dataloader=train_dataloader, 40 | config=config, 41 | evaluate=evaluate, 42 | device=torch.device("cuda", rank), 43 | ) 44 | trainer.start() 45 | 46 | torch.distributed.destroy_process_group() 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("config", type=str, help="Path to model config file.") 52 | args = parser.parse_args() 53 | 54 | train(args) 55 | -------------------------------------------------------------------------------- /sslsv/bin/train_distributed_jz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 5 | 6 | import argparse 7 | 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from sslsv.trainer.Trainer import Trainer 12 | from sslsv.utils.helpers import load_config, load_train_dataloader, load_model, evaluate 13 | 14 | import idr_torch 15 | 16 | 17 | def train(args: argparse.Namespace): 
18 | """ 19 | Train a model on Jean Zay from the CLI (using DistributedDataParallel). 20 | 21 | Args: 22 | args (argparse.Namespace): Arguments parsed from the command line. 23 | 24 | Returns: 25 | None 26 | """ 27 | world_size = idr_torch.size 28 | rank = idr_torch.rank 29 | 30 | torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) 31 | torch.cuda.set_device(rank) 32 | 33 | config = load_config(args.config) 34 | train_dataloader = load_train_dataloader(config) 35 | 36 | model = load_model(config).to(rank) 37 | model = DistributedDataParallel(model, device_ids=[rank]) 38 | 39 | trainer = Trainer( 40 | model=model, 41 | train_dataloader=train_dataloader, 42 | config=config, 43 | evaluate=evaluate, 44 | device=torch.device("cuda", rank), 45 | ) 46 | trainer.start() 47 | 48 | torch.distributed.destroy_process_group() 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("config", type=str, help="Path to model config file.") 54 | args = parser.parse_args() 55 | 56 | train(args) 57 | -------------------------------------------------------------------------------- /sslsv/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/datasets/__init__.py -------------------------------------------------------------------------------- /sslsv/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/encoders/__init__.py -------------------------------------------------------------------------------- /sslsv/evaluations/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/evaluations/__init__.py -------------------------------------------------------------------------------- /sslsv/methods/BYOL/BYOLLoss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor as T 3 | import torch.nn.functional as F 4 | 5 | 6 | class BYOLLoss(nn.Module): 7 | """ 8 | BYOL loss. 9 | """ 10 | 11 | def __init__(self): 12 | """ 13 | Initialize a BYOL loss. 14 | 15 | Returns: 16 | None 17 | """ 18 | super().__init__() 19 | 20 | def forward(self, P: T, Z: T) -> T: 21 | """ 22 | Compute loss. 23 | 24 | Args: 25 | P (T): Embeddings tensor of predictor. 26 | Z (T): Embeddings tensor of projector. 27 | 28 | Returns: 29 | T: Loss tensor. 30 | """ 31 | return 2 - 2 * F.cosine_similarity(P, Z.detach(), dim=-1).mean() 32 | -------------------------------------------------------------------------------- /sslsv/methods/BYOL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theolepage/sslsv/6f062138724365f3a8b1813664e0a1007d5b3ce4/sslsv/methods/BYOL/__init__.py -------------------------------------------------------------------------------- /sslsv/methods/BarlowTwins/BarlowTwins.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable 3 | 4 | from sslsv.encoders._BaseEncoder import BaseEncoder 5 | from sslsv.methods._BaseSiameseMethod import BaseSiameseMethod, BaseSiameseMethodConfig 6 | 7 | from .BarlowTwinsLoss import BarlowTwinsLoss 8 | 9 | 10 | @dataclass 11 | class BarlowTwinsConfig(BaseSiameseMethodConfig): 12 | """ 13 | Barlow Twins method configuration. 14 | 15 | Attributes: 16 | lamda (float): Redundancy reduction weight. Defaults to 0.005. 
import torch
from torch import nn
from torch import Tensor as T

import torch.distributed as dist


class BarlowTwinsLoss(nn.Module):
    """
    Barlow Twins loss.

    Attributes:
        lamda (float): Redundancy reduction weight.
        scale (float): Loss scaling factor.
    """

    def __init__(self, lamda: float = 0.05, scale: float = 0.025):
        """
        Initialize a Barlow Twins loss.

        Args:
            lamda (float): Redundancy reduction weight. Defaults to 0.05.
            scale (float): Loss scaling factor. Defaults to 0.025.
        """
        super().__init__()

        self.lamda = lamda
        self.scale = scale

    def forward(self, Z_a: T, Z_b: T) -> T:
        """
        Compute loss.

        Args:
            Z_a (T): Embeddings tensor of view A. Shape: (N, D).
            Z_b (T): Embeddings tensor of view B. Shape: (N, D).

        Returns:
            T: Loss tensor (scalar).
        """
        N, D = Z_a.size()

        # Standardize each dimension over the batch. The functional call
        # (batch statistics, no affine transform) is equivalent to the
        # freshly-built nn.BatchNorm1d(D, affine=False) module previously
        # created here, but avoids allocating a new module on every call.
        Z_a = nn.functional.batch_norm(Z_a, None, None, training=True)
        Z_b = nn.functional.batch_norm(Z_b, None, None, training=True)

        # Cross-correlation matrix of the two standardized views.
        c = (Z_a.T @ Z_b) / N

        # Average the correlation matrix across distributed workers.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(c)
            c /= dist.get_world_size()

        diag = torch.eye(D, device=Z_a.device)

        # Invariance term on the diagonal; redundancy-reduction term
        # (down-weighted by lamda) off the diagonal.
        loss = (c - diag).pow(2)
        loss[~diag.bool()] *= self.lamda
        return loss.sum() * self.scale
from torch import nn
import torch.nn.functional as F
from torch import Tensor as T


class DeepClusterLoss(nn.Module):
    """
    DeepCluster loss: cross-entropy between prototype scores and cluster
    assignments, averaged over prototype heads.

    Attributes:
        temperature (float): Temperature value.
    """

    def __init__(self, temperature: float = 0.1):
        """
        Initialize a DeepCluster loss.

        Args:
            temperature (float): Temperature value. Defaults to 0.1.

        Returns:
            None
        """
        super().__init__()

        self.temperature = temperature

    def forward(self, preds: T, assignments: T) -> T:
        """
        Compute loss.

        Args:
            preds (T): Predictions tensor. Shape: (P, V, N, K).
            assignments (T): Assignment tensor. Shape: (P, N).

        Returns:
            T: Loss tensor.
        """
        nb_heads, nb_views, _, nb_clusters = preds.size()

        losses = []
        for scores, assign in zip(preds, assignments):
            # Flatten views into the batch dimension: (V*N, K).
            logits = scores.reshape(-1, nb_clusters) / self.temperature

            # Every view of a sample shares the same assignment: (V*N,).
            targets = assign.repeat(nb_views)
            targets = targets.to(preds.device, non_blocking=True)

            # Samples assigned -1 carry no cluster and are ignored.
            losses.append(F.cross_entropy(logits, targets, ignore_index=-1))

        return sum(losses) / nb_heads
    def forward(
        self,
        query: T,
        key: T,
        queue: T,
        current_labels: Optional[T] = None,
        queue_labels: Optional[T] = None,
    ) -> T:
        """
        Compute loss.

        Args:
            query (T): Query tensor. Shape: (N, C).
            key (T): Key tensor. Shape: (N, C).
            queue (T): Queue tensor. Shape: (C, K) — note the transposed
                layout relative to query/key (see the einsum below).
            current_labels (Optional[T]): Labels tensor from the query/key.
            queue_labels (Optional[T]): Labels tensor from the queue.

        Returns:
            T: Loss tensor.
        """
        N, _ = query.size()

        # Positive logit: dot product of each query with its own key.
        # Negative logits: dot products of each query with every queue entry.
        # NOTE(review): assumes query/key/queue columns are L2-normalized so
        # these dot products are cosine similarities — TODO confirm upstream.
        pos = torch.einsum("nc,nc->n", (query, key)).unsqueeze(-1)
        neg = torch.einsum("nc,ck->nk", (query, queue))

        # Prevent class collisions using labels
        # NOTE(review): colliding negatives are set to similarity 0, not
        # removed — they still contribute exp(0) to the softmax denominator.
        if current_labels is not None and queue_labels is not None:
            mask = current_labels.unsqueeze(1) == queue_labels.unsqueeze(0)
            neg[mask] = 0

        logits = torch.cat((pos, neg), dim=1) / self.temperature

        # The positive is column 0 after concatenation, so the target class
        # index is 0 for every sample.
        labels = torch.zeros(N, device=query.device, dtype=torch.long)

        return F.cross_entropy(logits, labels)
from torch import Tensor as T

from sslsv.methods.SimCLRMargins.SimCLRMarginsLoss import (
    SimCLRMarginsLoss,
    SimCLRMarginsLossConfig,
)


class SimCLRMultiViewsLoss(SimCLRMarginsLoss):
    """
    SimCLR MultiViews loss.

    Attributes:
        config (SimCLRMarginsLossConfig): Loss configuration.
        loss_fn (LossFunction): Loss function.
    """

    def __init__(self, config: SimCLRMarginsLossConfig):
        """
        Initialize a SimCLR MultiViews loss.

        Args:
            config (SimCLRMarginsLossConfig): Loss configuration.

        Returns:
            None
        """
        super().__init__(config)

    def forward(self, Z: T) -> T:
        """
        Compute loss.

        Args:
            Z (T): Embeddings tensor. Shape: (N, V, D).

        Returns:
            T: Loss tensor.
        """
        # The first two views are global; all remaining views are local.
        global_embeddings = Z[:, :2]
        local_embeddings = Z[:, 2:]

        # Identity pairs are kept since local and global views differ.
        return self.loss_fn(local_embeddings, global_embeddings, discard_identity=False)
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor as T


class SwAVLoss(nn.Module):
    """
    SwAV loss: swapped prediction — each view's assignments supervise the
    softmax predictions of every other view.

    Attributes:
        temperature (float): Temperature value.
    """

    def __init__(self, temperature: float = 0.1):
        """
        Initialize a SwAV loss.

        Args:
            temperature (float): Temperature value. Defaults to 0.1.

        Returns:
            None
        """
        super().__init__()

        self.temperature = temperature

    def forward(self, preds: T, assignments: T) -> T:
        """
        Compute loss.

        Args:
            preds (T): Predictions tensor.
            assignments (T): Assignments tensor.

        Returns:
            T: Loss tensor.
        """
        losses = []
        for i, assignment in enumerate(assignments):
            for j, pred in enumerate(preds):
                # Skip the non-swapped pair: a view never predicts its
                # own assignment.
                if i == j:
                    continue

                log_probs = F.log_softmax(pred / self.temperature, dim=1)
                cross_entropy = -(assignment * log_probs).sum(dim=1).mean()
                losses.append(cross_entropy)

        return sum(losses) / len(losses)
from torch import nn
import torch.nn.functional as F
from torch import Tensor as T


class WMSELoss(nn.Module):
    """
    W-MSE loss: mean squared error between L2-normalized embeddings,
    expressed via their cosine similarity.
    """

    def __init__(self):
        """
        Initialize a W-MSE loss.

        Returns:
            None
        """
        super().__init__()

    def forward(self, Z_a: T, Z_b: T) -> T:
        """
        Compute loss.

        Args:
            Z_a (T): Embeddings tensor of view A.
            Z_b (T): Embeddings tensor of view B.

        Returns:
            T: Loss tensor.
        """
        # ||a - b||^2 = 2 - 2 * <a, b> for unit vectors a, b.
        cos_sim = (F.normalize(Z_a) * F.normalize(Z_b)).sum(dim=-1)
        return (2.0 - 2.0 * cos_sim).mean()
import torch
from torch import nn

from sslsv.encoders.ECAPATDNN import ECAPATDNN, ECAPATDNNConfig


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters of *model*."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def test_default():
    """ECAPA-TDNN with default config: parameter count and output contract."""
    encoder = ECAPATDNN(ECAPATDNNConfig())

    assert count_parameters(encoder) == 7075008

    output = encoder(torch.randn(64, 32000))

    assert isinstance(output, torch.Tensor)
    assert output.dtype == torch.float32
    assert output.size() == (64, 512)
1682582 29 | 30 | Y = encoder(torch.randn(64, 32000)) 31 | 32 | assert isinstance(Y, torch.Tensor) 33 | assert Y.dtype == torch.float32 34 | assert Y.size() == (64, 512, 51) 35 | -------------------------------------------------------------------------------- /tests/encoders/test_simpleaudiocnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from sslsv.encoders.SimpleAudioCNN import SimpleAudioCNN, SimpleAudioCNNConfig 5 | 6 | 7 | def count_parameters(model: nn.Module) -> int: 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | 10 | 11 | def test_default(): 12 | config = SimpleAudioCNNConfig() 13 | encoder = SimpleAudioCNN(config) 14 | 15 | assert count_parameters(encoder) == 5253120 16 | 17 | Y = encoder(torch.randn(64, 32000)) 18 | 19 | assert isinstance(Y, torch.Tensor) 20 | assert Y.dtype == torch.float32 21 | assert Y.size() == (64, 512, 200) 22 | -------------------------------------------------------------------------------- /tests/encoders/test_tdnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from sslsv.encoders.TDNN import TDNN, TDNNConfig 5 | 6 | 7 | def count_parameters(model: nn.Module) -> int: 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | 10 | 11 | def test_default(): 12 | config = TDNNConfig() 13 | encoder = TDNN(config) 14 | 15 | assert count_parameters(encoder) == 4252564 16 | 17 | Y = encoder(torch.randn(64, 32000)) 18 | 19 | assert isinstance(Y, torch.Tensor) 20 | assert Y.dtype == torch.float32 21 | assert Y.size() == (64, 512) 22 | -------------------------------------------------------------------------------- /tests/methods/__init__.py: -------------------------------------------------------------------------------- 
import torch
from torch import nn

from sslsv.encoders.ResNet34 import ResNet34, ResNet34Config
from sslsv.methods.BarlowTwins.BarlowTwins import BarlowTwins, BarlowTwinsConfig


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters of *model*."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def test_default():
    """Barlow Twins with a ResNet34 encoder: inference, training, train step."""
    method = BarlowTwins(
        BarlowTwinsConfig(),
        create_encoder_fn=lambda: ResNet34(ResNet34Config()),
    )

    assert count_parameters(method) == 10888598

    # Inference
    Z = method(torch.randn(64, 32000))
    assert isinstance(Z, torch.Tensor)
    assert Z.dtype == torch.float32
    assert Z.size() == (64, 512)

    # Training
    Z = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(Z, tuple)
    assert len(Z) == 2
    for embeddings in Z:
        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dtype == torch.float32
        assert embeddings.size() == (64, 2048)

    # Train step
    loss = method.train_step(Z, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
return sum(p.numel() for p in model.parameters() if p.requires_grad) 10 | 11 | 12 | def test_default(): 13 | config = BYOLConfig() 14 | method = BYOL(config, create_encoder_fn=lambda: ResNet34(ResNet34Config())) 15 | 16 | assert count_parameters(method) == 6705046 17 | 18 | # Inference 19 | Z = method(torch.randn(64, 32000)) 20 | assert isinstance(Z, torch.Tensor) 21 | assert Z.dtype == torch.float32 22 | assert Z.size() == (64, 512) 23 | 24 | # Training 25 | Z = method(torch.randn(64, 2, 32000), training=True) 26 | assert isinstance(Z, tuple) 27 | assert len(Z) == 4 28 | for z in Z: 29 | assert isinstance(z, torch.Tensor) 30 | assert z.dtype == torch.float32 31 | assert z.size() == (64, 256) 32 | 33 | # Train step 34 | loss = method.train_step(Z, step=0) 35 | metrics = method.step_metrics 36 | assert isinstance(loss, torch.Tensor) 37 | assert loss.dtype == torch.float32 38 | assert "train/loss" in metrics 39 | assert "train/tau" in metrics 40 | assert isinstance(metrics["train/loss"], torch.Tensor) 41 | -------------------------------------------------------------------------------- /tests/methods/test_combiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from sslsv.encoders.ResNet34 import ResNet34, ResNet34Config 5 | from sslsv.methods.Combiner.Combiner import ( 6 | Combiner, 7 | CombinerConfig, 8 | LossItemCombinerConfig, 9 | LossTypeCombinerEnum, 10 | ) 11 | 12 | 13 | def count_parameters(model: nn.Module) -> int: 14 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 15 | 16 | 17 | def test_default(): 18 | config = CombinerConfig( 19 | Y_losses=[ 20 | LossItemCombinerConfig(LossTypeCombinerEnum.INFONCE, 1.0), 21 | LossItemCombinerConfig(LossTypeCombinerEnum.VICREG, 1.0), 22 | ], 23 | Z_losses=[LossItemCombinerConfig(LossTypeCombinerEnum.BARLOWTWINS, 1.0)], 24 | ) 25 | method = Combiner(config, create_encoder_fn=lambda: ResNet34(ResNet34Config())) 26 | 
27 | assert count_parameters(method) == 10888598 28 | 29 | # Inference 30 | Z = method(torch.randn(64, 32000)) 31 | assert isinstance(Z, torch.Tensor) 32 | assert Z.dtype == torch.float32 33 | assert Z.size() == (64, 512) 34 | 35 | # Training 36 | Z = method(torch.randn(64, 2, 32000), training=True) 37 | assert isinstance(Z, tuple) 38 | assert len(Z) == 4 39 | for i, z in enumerate(Z): 40 | assert isinstance(z, torch.Tensor) 41 | assert z.dtype == torch.float32 42 | if i < 2: 43 | assert z.size() == (64, 512) 44 | else: 45 | assert z.size() == (64, 2048) 46 | 47 | # Train step 48 | loss = method.train_step(Z, step=0) 49 | metrics = method.step_metrics 50 | assert isinstance(loss, torch.Tensor) 51 | assert loss.dtype == torch.float32 52 | assert "train/loss" in metrics 53 | assert "train/Y_loss" in metrics 54 | assert "train/Z_loss" in metrics 55 | assert "train/Y_accuracy" in metrics 56 | assert "train/Z_accuracy" in metrics 57 | assert isinstance(metrics["train/loss"], torch.Tensor) 58 | -------------------------------------------------------------------------------- /tests/methods/test_cpc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from sslsv.encoders.ResNet34 import ResNet34, ResNet34Config 5 | from sslsv.methods.CPC.CPC import CPC, CPCConfig 6 | 7 | 8 | def count_parameters(model: nn.Module) -> int: 9 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 10 | 11 | 12 | def test_default(): 13 | config = CPCConfig() 14 | method = CPC( 15 | config, create_encoder_fn=lambda: ResNet34(ResNet34Config(pooling=False)) 16 | ) 17 | 18 | assert count_parameters(method) == 2800278 19 | 20 | # Inference 21 | Z = method(torch.randn(64, 32000)) 22 | assert isinstance(Z, torch.Tensor) 23 | assert Z.dtype == torch.float32 24 | assert Z.size() == (64, 256) 25 | 26 | # Training 27 | Z = method(torch.randn(64, 2, 32000), training=True) 28 | assert isinstance(Z, tuple) 29 | 
assert len(Z) == 4 30 | for i, z in enumerate(Z): 31 | if i < 2: 32 | assert isinstance(z, torch.Tensor) 33 | assert z.dtype == torch.float32 34 | assert z.size() == (64, 512, 51) 35 | else: 36 | assert z is None 37 | 38 | # Train step 39 | loss = method.train_step(Z, step=0) 40 | metrics = method.step_metrics 41 | assert isinstance(loss, torch.Tensor) 42 | assert loss.dtype == torch.float32 43 | assert "train/loss" in metrics 44 | assert isinstance(metrics["train/loss"], torch.Tensor) 45 | -------------------------------------------------------------------------------- /tests/methods/test_dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from sslsv.encoders.ResNet34 import ResNet34, ResNet34Config 5 | from sslsv.methods.DINO.DINO import DINO, DINOConfig 6 | 7 | 8 | def count_parameters(model: nn.Module) -> int: 9 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 10 | 11 | 12 | def test_default(): 13 | config = DINOConfig() 14 | method = DINO(config, create_encoder_fn=lambda: ResNet34(ResNet34Config())) 15 | 16 | assert count_parameters(method) == 23994006 17 | 18 | # Inference 19 | Z = method(torch.randn(64, 32000)) 20 | assert isinstance(Z, torch.Tensor) 21 | assert Z.dtype == torch.float32 22 | assert Z.size() == (64, 512) 23 | 24 | # Training 25 | Z = method(torch.randn(64, 6, 32000), training=True) 26 | assert isinstance(Z, tuple) 27 | assert isinstance(Z[0], torch.Tensor) 28 | assert isinstance(Z[1], torch.Tensor) 29 | assert Z[0].dtype == torch.float32 30 | assert Z[1].dtype == torch.float32 31 | assert Z[0].size() == (64 * 6, 65536) 32 | assert Z[1].size() == (64 * 2, 65536) 33 | 34 | # Train step 35 | loss = method.train_step(Z, step=0) 36 | metrics = method.step_metrics 37 | assert isinstance(loss, torch.Tensor) 38 | assert loss.dtype == torch.float32 39 | assert "train/loss" in metrics 40 | assert isinstance(metrics["train/loss"], torch.Tensor) 
import torch
from torch import nn

from sslsv.encoders.ResNet34 import ResNet34, ResNet34Config
from sslsv.methods.LIM.LIM import LIM, LIMConfig


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters of *model*."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def test_default():
    """LIM with a ResNet34 encoder: inference, training, train step."""
    method = LIM(LIMConfig(), create_encoder_fn=lambda: ResNet34(ResNet34Config()))

    assert count_parameters(method) == 1437078

    # Inference
    Z = method(torch.randn(64, 32000))
    assert isinstance(Z, torch.Tensor)
    assert Z.dtype == torch.float32
    assert Z.size() == (64, 512)

    # Training
    Z = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(Z, tuple)
    assert len(Z) == 2
    for embeddings in Z:
        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dtype == torch.float32
        assert embeddings.size() == (64, 512)

    # Train step
    loss = method.train_step(Z, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in *model*."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)


def test_default():
    """SimCLR + ResNet34: parameter count, inference output and training step."""
    method = SimCLR(
        SimCLRConfig(),
        create_encoder_fn=lambda: ResNet34(ResNet34Config()),
    )

    assert count_parameters(method) == 3012246

    # Inference: a batch of raw waveforms maps to one 512-d embedding each.
    embeddings = method(torch.randn(64, 32000))
    assert isinstance(embeddings, torch.Tensor)
    assert embeddings.dtype == torch.float32
    assert embeddings.size() == (64, 512)

    # Training: input carries 2 views per sample; output is a pair of
    # (64, 256) float tensors.
    outputs = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(outputs, tuple)
    assert len(outputs) == 2
    for tensor in outputs:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.size() == (64, 256)

    # Train step: returns a float loss tensor and records it in step_metrics.
    loss = method.train_step(outputs, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def test_default():
    """SimSiam + ResNet34: parameter count, inference output and training step."""
    method = SimSiam(
        SimSiamConfig(),
        create_encoder_fn=lambda: ResNet34(ResNet34Config()),
    )

    assert count_parameters(method) == 12982678

    # Inference: raw waveforms -> one 512-d embedding per utterance.
    embeddings = method(torch.randn(64, 32000))
    assert isinstance(embeddings, torch.Tensor)
    assert embeddings.dtype == torch.float32
    assert embeddings.size() == (64, 512)

    # Training: input carries 2 views per sample; output is a tuple of
    # four (64, 2048) float tensors.
    outputs = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(outputs, tuple)
    assert len(outputs) == 4
    for tensor in outputs:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.size() == (64, 2048)

    # Train step: returns a float loss tensor and records it in step_metrics.
    loss = method.train_step(outputs, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in *model*."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def test_default():
    """VIbCReg + ResNet34: parameter count, inference output and training step."""
    method = VIbCReg(
        VIbCRegConfig(),
        create_encoder_fn=lambda: ResNet34(ResNet34Config()),
    )

    assert count_parameters(method) == 10892694

    # Inference: raw waveforms -> one 512-d embedding per utterance.
    embeddings = method(torch.randn(64, 32000))
    assert isinstance(embeddings, torch.Tensor)
    assert embeddings.dtype == torch.float32
    assert embeddings.size() == (64, 512)

    # Training: input carries 2 views per sample; output is a pair of
    # (64, 2048) float tensors.
    outputs = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(outputs, tuple)
    assert len(outputs) == 2
    for tensor in outputs:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.size() == (64, 2048)

    # Train step: returns a float loss tensor and records it in step_metrics.
    loss = method.train_step(outputs, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in *model*."""
    trainable = [p.numel() for p in model.parameters() if p.requires_grad]
    return sum(trainable)


def test_default():
    """VICReg + ResNet34: parameter count, inference output and training step."""
    method = VICReg(
        VICRegConfig(),
        create_encoder_fn=lambda: ResNet34(ResNet34Config()),
    )

    assert count_parameters(method) == 10888598

    # Inference: raw waveforms -> one 512-d embedding per utterance.
    embeddings = method(torch.randn(64, 32000))
    assert isinstance(embeddings, torch.Tensor)
    assert embeddings.dtype == torch.float32
    assert embeddings.size() == (64, 512)

    # Training: input carries 2 views per sample; output is a pair of
    # (64, 2048) float tensors.
    outputs = method(torch.randn(64, 2, 32000), training=True)
    assert isinstance(outputs, tuple)
    assert len(outputs) == 2
    for tensor in outputs:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.size() == (64, 2048)

    # Train step: returns a float loss tensor and records it in step_metrics.
    loss = method.train_step(outputs, step=0)
    metrics = method.step_metrics
    assert isinstance(loss, torch.Tensor)
    assert loss.dtype == torch.float32
    assert "train/loss" in metrics
    assert isinstance(metrics["train/loss"], torch.Tensor)
WMSE(config, create_encoder_fn=lambda: ResNet34(ResNet34Config())) 15 | 16 | assert count_parameters(method) == 2030038 17 | 18 | # Inference 19 | Z = method(torch.randn(128, 32000)) 20 | assert isinstance(Z, torch.Tensor) 21 | assert Z.dtype == torch.float32 22 | assert Z.size() == (128, 512) 23 | 24 | # Training 25 | Z = method(torch.randn(128, 2, 32000), training=True) 26 | assert isinstance(Z, tuple) 27 | assert len(Z) == 2 28 | for z in Z: 29 | assert isinstance(z, torch.Tensor) 30 | assert z.dtype == torch.float32 31 | assert z.size() == (128, 64) 32 | 33 | # Train step 34 | loss = method.train_step(Z, step=0) 35 | metrics = method.step_metrics 36 | assert isinstance(loss, torch.Tensor) 37 | assert loss.dtype == torch.float32 38 | assert "train/loss" in metrics 39 | assert isinstance(metrics["train/loss"], torch.Tensor) 40 | -------------------------------------------------------------------------------- /tests/resources/empty.yml: -------------------------------------------------------------------------------- 1 | hello: 2 | -------------------------------------------------------------------------------- /tests/resources/no_encoder.yml: -------------------------------------------------------------------------------- 1 | method: 2 | type: 'simclr' 3 | -------------------------------------------------------------------------------- /tests/resources/no_method.yml: -------------------------------------------------------------------------------- 1 | encoder: 2 | type: 'resnet34' 3 | -------------------------------------------------------------------------------- /tests/resources/simple/config.yml: -------------------------------------------------------------------------------- 1 | encoder: 2 | type: 'resnet34' 3 | method: 4 | type: 'simclr' 5 | trainer: 6 | epochs: 3 7 | dataset: 8 | ssl: true 9 | max_samples: 1024 10 | evaluation: 11 | validation: 12 | - type: 'sv_cosine' 13 | frame_length: 56240 14 | test: 15 | - type: 'sv_cosine' 16 | frame_length: 56240 17 | 
@pytest.fixture
def default_config() -> Config:
    """Provide a freshly constructed default ``Config`` to each test."""
    return Config()


def test_default(default_config: Config):
    """A default ``Config`` exposes the documented default for every field."""
    assert default_config.model_name == "default"
    assert default_config.model_path == Path("default")
    assert default_config.seed == 1717
    assert default_config.reproducibility is False
    # No encoder/method is selected until a config file provides one.
    assert default_config.encoder is None
    assert default_config.method is None
    # Sub-configs fall back to their own default-constructed values.
    assert default_config.trainer == TrainerConfig()
    assert default_config.dataset == DatasetConfig()
    assert default_config.evaluation == BaseEvaluationConfig()


def test_empty():
    """A YAML file with none of the required sections is rejected with KeyError."""
    with pytest.raises(KeyError):
        load_config("tests/resources/empty.yml", verbose=False)


def test_no_encoder():
    """A config missing its 'encoder' section is rejected with KeyError."""
    with pytest.raises(KeyError):
        load_config("tests/resources/no_encoder.yml", verbose=False)


def test_no_method():
    """A config missing its 'method' section is rejected with KeyError."""
    with pytest.raises(KeyError):
        load_config("tests/resources/no_method.yml", verbose=False)


def test_simple():
    """A minimal valid config resolves model name/path and typed sub-configs."""
    config = load_config("tests/resources/simple/config.yml", verbose=False)

    # model_name is derived from the config file's directory, model_path
    # from its full relative location.
    assert config.model_name == "resources/simple"
    assert config.model_path == Path("tests/resources/simple")
    assert config.encoder == ResNet34Config()
    assert config.method == SimCLRConfig()
def test_basic():
    """Reproducibility check: one no-op training step (lr=0) followed by a
    forward pass must yield a known reference output sum."""
    config = load_config("tests/resources/simple/config.yml")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_model(config).to(device)
    model = torch.nn.DataParallel(model)

    # lr=0 and weight_decay=0: the optimizer step must not change weights.
    optimizer = Adam(
        model.module.get_learnable_params(),
        lr=0,
        weight_decay=0,
    )

    # Bug fix: inputs previously called .cuda() unconditionally, crashing on
    # CPU-only machines even though `device` is selected conditionally above.
    # Keep the inputs on the same device as the model.
    X1 = torch.randn(256, 2, 32000).to(device)
    X2 = torch.randn(256, 32000).to(device)

    Z = model(X1, training=True)
    loss = model.module.train_step(Z, step=0)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    Z = model(X2)

    # NOTE(review): the reference value presumably assumes the RNG seed set
    # by load_config/load_model -- confirm when running on a new platform.
    assert pytest.approx(Z.sum().item()) == 2419.561279296875
def create_cremad_train_csv(test_split: float = 0.9):
    """Create ``cremad_train.csv`` listing CREMA-D wav files with emotion labels
    and a random train/test assignment.

    Args:
        test_split: Fraction of files randomly assigned to the "train" set.
            NOTE(review): despite its name this acts as the *train* fraction
            (0.9 -> 90% train / 10% test) -- confirm the intended semantics.
    """
    files = glob("cremad/AudioWAV/*.wav")

    # CREMA-D filename code (3rd underscore-separated token) -> readable label.
    LABELS = {
        "ANG": "Anger",
        "DIS": "Disgust",
        "FEA": "Fear",
        "HAP": "Joy",
        "NEU": "Neutral",
        "SAD": "Sad",
    }

    df = pd.DataFrame(
        {
            "File": files,
            "Emotion": [LABELS[f.split("/")[-1].split("_")[2]] for f in files],
        }
    )

    # "Disgust" samples are excluded from the task.
    df.drop(df[df["Emotion"] == "Disgust"].index, inplace=True)

    # Add set column. Train files are drawn from *all* files (including the
    # dropped "Disgust" ones), matching the original sampling behavior.
    train_files = np.random.choice(
        files, size=int(test_split * len(files)), replace=False
    )
    # Perf fix: use a set for O(1) membership tests instead of scanning a
    # NumPy array once per remaining file (previously O(n*m)).
    train_set = set(train_files)
    df["Set"] = ["train" if f in train_set else "test" for f in df["File"]]

    df.to_csv("cremad_train.csv", index=False)
# (source key list, output trial file) pairs for the SITW dev/eval subsets.
SITW_TRIALS = [
    ("sitw/dev/keys/core-core.lst", "sitw_dev_core-core"),
    ("sitw/dev/keys/core-multi.lst", "sitw_dev_core-multi"),
    ("sitw/eval/keys/core-core.lst", "sitw_eval_core-core"),
    ("sitw/eval/keys/core-multi.lst", "sitw_eval_core-multi"),
]


def create_sitw_trials():
    """Convert the official SITW key lists into `label enroll_path test_path`
    trial files ("imp" -> 0, anything else -> 1)."""
    for src, dst in SITW_TRIALS:
        subset = src.split("/")[1]  # "dev" or "eval"

        # Map each enrollment speaker id to its audio file (last column).
        with open(f"sitw/{subset}/lists/enroll-core.lst") as f:
            spk_to_file = {
                line.split()[0]: line.strip().split()[-1] for line in f.readlines()
            }

        trials = []
        with open(src) as f:
            for raw in f.readlines():
                enroll_id, test_file, label = raw.split()

                enroll_path = f"sitw/{subset}/{spk_to_file[enroll_id]}"
                test_path = f"sitw/{subset}/{test_file}"
                target = "0" if label == "imp" else "1"

                trials.append(f"{target} {enroll_path} {test_path}")

        with open(dst, "w") as f:
            f.write("\n".join(trials))
def fix_aug_structure():
    """Flatten the downloaded RIRs layout and remove the fetched archives."""
    cleanup_commands = (
        "mv RIRS_NOISES/simulated_rirs .",
        "rm -r RIRS_NOISES",
        "rm -r rirs_noises.zip",
        "rm -r musan.tar.gz",
    )
    for command in cleanup_commands:
        subprocess.call(command, shell=True)
def get_md5(path: str) -> str:
    """Return the hex MD5 digest of the file at *path*, read in 4 KiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        chunk = f.read(4096)
        while chunk:
            digest.update(chunk)
            chunk = f.read(4096)
    return digest.hexdigest()
#!/bin/bash
# Mirror the local repository to Jean Zay: sync once, then re-sync on every
# filesystem change reported by inotifywait.

source_path="."
target_path="jeanzay:~/sslsv"

# Single definition of the sync command so the exclude list cannot drift
# between the initial sync and the watch loop (it was duplicated verbatim).
sync_repo() {
    rsync -azh $source_path $target_path \
        --progress \
        --force \
        --delete \
        --exclude="slurm_*" \
        --exclude="data" \
        --exclude="wandb" \
        --exclude="tensorboard" \
        --exclude="*.pt" \
        --exclude="checkpoints" \
        --exclude="*.json" \
        --keep-dirlinks
}

sync_repo

while inotifywait -r -e modify,create,delete $source_path
do
    sync_repo
done
#!/bin/bash
# Launch distributed training with torchrun.
# First argument: number of GPUs; remaining arguments are forwarded to
# sslsv/bin/train_distributed.py.

if [ $# -eq 0 ]; then
    echo "Usage: $0 [args ...]"
    exit 1
fi

num_gpus=$1
shift

torchrun --nproc_per_node=$num_gpus sslsv/bin/train_distributed.py "$@"