├── .gitignore ├── LICENSE ├── README.md ├── data └── .gitignore ├── figures ├── across_subjects.png ├── best_middle_worst.png ├── copy_content.png ├── copy_content2.png ├── create_thingseeg2_metadata.png ├── create_thingseeg2_preproc.png ├── examples.png ├── extract_here.png ├── extract_here2.png ├── in_context.png ├── move_weights.png ├── move_weights2.png ├── overview.png ├── swapping.png ├── three_patterns_abc.png ├── time_course.png ├── umap.png └── unclip_pipeline.png ├── requirements-vd.txt ├── requirements.txt ├── results ├── .gitignore ├── nsd_preproc │ └── avg-1-2-5-7 │ │ ├── clip_patterns │ │ └── mni │ │ │ ├── animals_pattern.png │ │ │ ├── food_pattern.png │ │ │ ├── human-closeup_pattern.png │ │ │ ├── human-distant_pattern.png │ │ │ ├── interiors_pattern.png │ │ │ └── urban_pattern.png │ │ ├── ica-color_patterns │ │ └── mni │ │ │ ├── blue_pattern.png │ │ │ ├── green_pattern.png │ │ │ └── red_pattern.png │ │ ├── pca-brightness_patterns │ │ └── mni │ │ │ ├── bright_pattern.png │ │ │ └── dark_pattern.png │ │ └── vdvae-texture_patterns │ │ └── mni │ │ ├── smooth_pattern.png │ │ └── textured_pattern.png └── thingseeg2_preproc │ ├── .gitignore │ ├── blue_pattern_negative_avg_120ms.png │ ├── fig_CLIP_across_size_num_avg.png │ ├── fig_ablations.png │ ├── fig_across_duration.png │ ├── fig_performance.png │ ├── pca-bright_patterns_avg_120-220ms.png │ ├── pca-dark_patterns_avg_120-220ms.png │ ├── red_pattern_negative_avg_120ms.png │ ├── smooth_pattern_negative_avg_120ms.png │ ├── sub-01 │ ├── .gitignore │ ├── umap_CLIP.png │ └── versatile_diffusion_ordered_by_performance.png │ ├── textured_pattern_negative_avg_120ms.png │ └── three_patterns_negative_avg_180ms.png ├── scripts-nsd ├── evaluate_reconstruction.py ├── reconstruct_from_embeddings.py ├── reconstruct_from_embeddings_ica.py ├── reconstruct_from_embeddings_pca.py ├── reconstruct_from_embeddings_vdvae.py ├── train_regression-encode_clip.py ├── train_regression-encode_ica.py ├── train_regression-encode_pca.py ├── train_regression-encode_vdvae.py ├── train_regression_clip.py ├── train_regression_ica.py ├── train_regression_pca.py ├── train_regression_vae.py └── train_regression_vdvae.py ├── scripts-nsd_dataprep ├── download_nsd_data.py ├── evaluation_extract_features_from_test_images.py ├── extract_features-clip.py ├── extract_features-ica.py ├── extract_features-pca.py ├── extract_features-vae.py ├── extract_features-vdvae.py ├── prepare_nsd_data.py └── save_test_images.py ├── scripts-nsd_figures ├── freesurfer_import_subj.py ├── make_clip_patterns_func1pt8mm.py ├── make_ica-color_patterns_func1pt8mm.py ├── make_pca-brightness_patterns_func1pt8mm.py ├── make_vdvae-texture_patterns_func1pt8mm.py ├── nsd_mapdata.py ├── plot_clip_patterns_mni.py ├── plot_clip_patterns_mni_avg.py ├── plot_ica-color_patterns_mni.py ├── plot_ica-color_patterns_mni_avg.py ├── plot_pca-brightness_patterns_mni.py ├── plot_pca-brightness_patterns_mni_avg.py ├── plot_vdvae-texture_patterns_mni.py ├── plot_vdvae-texture_patterns_mni_avg.py └── to_mni.py ├── scripts-thingseeg2 ├── data_segment_replacement.ipynb ├── evaluate_reconstruction.py ├── make_video.ipynb ├── map_embeddings.ipynb ├── plot_reconstructions.py ├── plot_umap_CLIP.py ├── reconstruct_from_embeddings.py ├── reconstruct_from_embeddings_clip-only.py ├── reconstruct_from_embeddings_ica.py ├── reconstruct_from_embeddings_pca.py ├── reconstruct_from_embeddings_vd.py ├── reconstruct_from_embeddings_vdvae.py ├── train_regression-encode_clip.py ├── train_regression-encode_clip_grayscale.py ├── 
train_regression-encode_ica.py ├── train_regression-encode_pca.py ├── train_regression-encode_vdvae.py ├── train_regression-for-saliency-encode_clip.py ├── train_regression-for-saliency_clip.py ├── train_regression.py ├── train_regression_clip.py ├── train_regression_clip_grayscale.py ├── train_regression_ica.py ├── train_regression_pca.py ├── train_regression_vae.py └── train_regression_vdvae.py ├── scripts-thingseeg2_dataprep ├── evaluation_extract_features_from_test_images.py ├── extract_features-clip.py ├── extract_features-clip_grayscale.py ├── extract_features-cliptext.py ├── extract_features-clipvision.py ├── extract_features-ica.py ├── extract_features-pca.py ├── extract_features-vae.py ├── extract_features-vdvae.py ├── grayscale_images.py ├── ica.py ├── pca.py ├── pipeline_stable_unclip_img2img_modified.py ├── prepare_all_subjects_data.sh ├── prepare_thingseeg2_data.py ├── save_thingseeg2_concepts.py └── save_thingseeg2_images.py ├── scripts-thingseeg2_figures ├── evaluate_ablation.sh ├── evaluate_across_duration.sh ├── evaluate_across_size_num_avg.sh ├── evaluate_all_subjects.sh ├── fig_ablations.py ├── fig_across_durations.py ├── fig_across_size_num_avg.py ├── fig_performance.py ├── plot_ablations_recon.sh ├── plot_all_patterns.py ├── plot_clip_patterns.py ├── plot_clip_patterns_negative.py ├── plot_clip_patterns_negative_avg.py ├── plot_ica-color_patterns.py ├── plot_ica-color_patterns_negative.py ├── plot_ica-color_patterns_negative_avg.py ├── plot_pca-brightness_patterns.py ├── plot_pca-brightness_patterns_avg.py ├── plot_vdvae-texture_patterns.py ├── plot_vdvae-texture_patterns_negative.py ├── plot_vdvae-texture_patterns_negative_avg.py ├── reconstruct_ablation.sh ├── reconstruct_across_duration.sh ├── reconstruct_across_size_num_avg.sh ├── reconstruct_all_subjects.sh ├── train_across_duration.sh ├── train_across_size_num_avg.sh ├── train_all_subjects.sh ├── visualize_clip_RSA.ipynb └── visualize_clip_patterns.ipynb ├── scripts-thingseeg2_transfer_learning ├── average_thingseeg2_subjects.py ├── evaluate_reconstruction.py ├── evaluation_extract_features.py ├── plot_performance.ipynb ├── plot_reconstructions.py ├── reconstruct_from_embeddings.py ├── train_regression.py └── transfer_learning.py ├── scripts-thingsmeg ├── avg1b_regression_prediction.ipynb ├── avg1b_regression_prediction.py ├── avg1b_regression_prediction_.py ├── avg1b_vdvae_reconstruct_images1b.py ├── avg1b_versatilediffusion_reconstruct_images1b.py ├── cliptext1b_gridsearch.py ├── cliptext1b_module_prediction.py ├── cliptext1b_regression.py ├── cliptext1b_regression_alltokens.py ├── cliptext_extract_features.py ├── cliptext_regression.py ├── clipvision1b_gridsearch.py ├── clipvision1b_module_prediction.py ├── clipvision1b_regression.py ├── clipvision_extract_features.py ├── clipvision_regression.py ├── explore_meg_epochs.ipynb ├── generate_captions1b.ipynb ├── generate_captions1b.py ├── get_stims.py ├── get_stims1b.py ├── make_video.ipynb ├── preprocess_meg.ipynb ├── preprocess_meg.py ├── preprocess_meg_epoching.py ├── run_brain_module.ipynb ├── save_test_images.py ├── save_test_images1b.py ├── save_things_categories.py ├── save_things_images.py ├── test_brain_module.ipynb ├── train_brain_module.py ├── vdvae1b_gridsearch.py ├── vdvae_extract_features.py ├── vdvae_extract_features1b.py ├── vdvae_reconstruct_images.py ├── vdvae_reconstruct_images1b.py ├── vdvae_regression.py ├── vdvae_regression1b.py ├── versatilediffusion_reconstruct_images.py ├── versatilediffusion_reconstruct_images1b.py ├── 
versatilediffusion_reconstruct_images1b_brainmodule.py ├── versatilediffusion_reconstruct_images1b_brainregress.py ├── versatilediffusion_reconstruct_images1b_dummymodule.py ├── versatilediffusion_reconstruct_images1b_refclip.py └── versatilediffusion_reconstruct_images1b_scrambled.py ├── vdvae ├── LICENSE.md ├── README.md ├── data.py ├── files_to_npy.py ├── header-image.png ├── hps.py ├── image_utils.py ├── model │ └── .gitignore ├── model_utils.py ├── setup_cifar10.sh ├── setup_ffhq1024.sh ├── setup_ffhq256.sh ├── setup_imagenet.sh ├── train.py ├── train_helpers.py ├── utils.py ├── vae.py └── vae_helpers.py └── versatile_diffusion ├── LICENSE ├── README.md ├── README_extra.md ├── configs ├── dataset │ └── laion2b.yaml ├── experiment │ ├── sd_eval.yaml │ ├── sd_variation_eval.yaml │ ├── vd_dc_eval.yaml │ └── vd_official_eval.yaml └── model │ ├── clip.yaml │ ├── openai_unet.yaml │ ├── optimus.yaml │ ├── sd.yaml │ └── vd.yaml ├── inference.py ├── lib ├── __init__.py ├── cfg_helper.py ├── cfg_holder.py ├── data_factory │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── ds_base.py │ │ ├── ds_estimator.py │ │ ├── ds_formatter.py │ │ ├── ds_loader.py │ │ ├── ds_sampler.py │ │ └── ds_transform.py │ └── ds_laion2b_webdataset.py ├── evaluator │ ├── __init__.py │ ├── eva_base.py │ └── eva_null.py ├── experiments │ ├── __init__.py │ ├── sd_default.py │ └── vd_default.py ├── log_service.py ├── model_zoo │ ├── __init__.py │ ├── attention.py │ ├── autoencoder.py │ ├── bert.py │ ├── clip.py │ ├── clip_justin │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── common │ │ ├── get_model.py │ │ ├── get_optimizer.py │ │ ├── get_scheduler.py │ │ └── utils.py │ ├── ddim.py │ ├── ddim_dualcontext.py │ ├── ddim_dualmodel.py │ ├── ddim_vd.py │ ├── ddim_vd_old.py │ ├── diffusion_modules.py │ ├── diffusion_utils.py │ ├── distributions.py │ ├── ema.py │ ├── openaimodel.py │ ├── optimus.py │ ├── optimus_models │ │ ├── configuration_bert.py │ │ ├── configuration_gpt2.py │ │ ├── configuration_utils.py │ │ ├── file_utils.py │ │ ├── modeling_utils.py │ │ ├── optimus_bert.py │ │ ├── optimus_gpt2.py │ │ ├── tokenization_bert.py │ │ ├── tokenization_gpt2.py │ │ ├── tokenization_utils.py │ │ └── vocab │ │ │ ├── bert-base-cased-vocab.txt │ │ │ ├── bert_vocab_download_info.json │ │ │ ├── gpt2-merges.txt │ │ │ ├── gpt2-vocab.json │ │ │ └── gpt2_vocab_merge_download_info.json │ ├── sd.py │ └── vd.py ├── sync.py └── utils.py ├── log └── sd_nodataset │ └── 99999_evalonly │ └── sd_variation │ ├── code │ ├── configs │ │ ├── dataset │ │ │ └── laion2b.yaml │ │ ├── experiment │ │ │ ├── sd_eval.yaml │ │ │ ├── sd_variation_eval.yaml │ │ │ ├── vd_dc_eval.yaml │ │ │ └── vd_official_eval.yaml │ │ └── model │ │ │ ├── clip.yaml │ │ │ ├── openai_unet.yaml │ │ │ ├── optimus.yaml │ │ │ ├── sd.yaml │ │ │ └── vd.yaml │ └── lib │ │ ├── __init__.py │ │ ├── cfg_helper.py │ │ ├── cfg_holder.py │ │ ├── data_factory │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── ds_base.py │ │ │ ├── ds_estimator.py │ │ │ ├── ds_formatter.py │ │ │ ├── ds_loader.py │ │ │ ├── ds_sampler.py │ │ │ └── ds_transform.py │ │ └── ds_laion2b_webdataset.py │ │ ├── evaluator │ │ ├── __init__.py │ │ ├── eva_base.py │ │ └── eva_null.py │ │ ├── experiments │ │ ├── __init__.py │ │ ├── sd_default.py │ │ └── vd_default.py │ │ ├── log_service.py │ │ ├── model_zoo │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoder.py │ │ ├── bert.py │ │ ├── clip.py │ │ ├── clip_justin │ │ │ ├── __init__.py │ │ │ ├── clip.py │ │ │ ├── 
model.py │ │ │ └── simple_tokenizer.py │ │ ├── common │ │ │ ├── get_model.py │ │ │ ├── get_optimizer.py │ │ │ ├── get_scheduler.py │ │ │ └── utils.py │ │ ├── ddim.py │ │ ├── ddim_dualcontext.py │ │ ├── ddim_dualmodel.py │ │ ├── ddim_vd.py │ │ ├── diffusion_modules.py │ │ ├── diffusion_utils.py │ │ ├── distributions.py │ │ ├── ema.py │ │ ├── openaimodel.py │ │ ├── optimus.py │ │ ├── optimus_models │ │ │ ├── configuration_bert.py │ │ │ ├── configuration_gpt2.py │ │ │ ├── configuration_utils.py │ │ │ ├── file_utils.py │ │ │ ├── modeling_utils.py │ │ │ ├── optimus_bert.py │ │ │ ├── optimus_gpt2.py │ │ │ ├── tokenization_bert.py │ │ │ ├── tokenization_gpt2.py │ │ │ ├── tokenization_utils.py │ │ │ └── vocab │ │ │ │ ├── bert-base-cased-vocab.txt │ │ │ │ ├── bert_vocab_download_info.json │ │ │ │ ├── gpt2-merges.txt │ │ │ │ ├── gpt2-vocab.json │ │ │ │ └── gpt2_vocab_merge_download_info.json │ │ ├── sd.py │ │ └── vd.py │ │ ├── sync.py │ │ └── utils.py │ ├── config.yaml │ └── eval.log ├── main.py ├── pretrained └── .gitignore ├── reconstruct_images.py ├── reconstruct_txt2im.py ├── requirement.txt └── requirement_colab.txt /.gitignore: -------------------------------------------------------------------------------- 1 | cache/ 2 | misc/ 3 | __pycache__/ 4 | pyenv/ -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /figures/across_subjects.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/across_subjects.png -------------------------------------------------------------------------------- /figures/best_middle_worst.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/best_middle_worst.png -------------------------------------------------------------------------------- /figures/copy_content.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/copy_content.png -------------------------------------------------------------------------------- /figures/copy_content2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/copy_content2.png -------------------------------------------------------------------------------- /figures/create_thingseeg2_metadata.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/create_thingseeg2_metadata.png -------------------------------------------------------------------------------- /figures/create_thingseeg2_preproc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/create_thingseeg2_preproc.png 
-------------------------------------------------------------------------------- /figures/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/examples.png -------------------------------------------------------------------------------- /figures/extract_here.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/extract_here.png -------------------------------------------------------------------------------- /figures/extract_here2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/extract_here2.png -------------------------------------------------------------------------------- /figures/in_context.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/in_context.png -------------------------------------------------------------------------------- /figures/move_weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/move_weights.png -------------------------------------------------------------------------------- /figures/move_weights2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/move_weights2.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/overview.png -------------------------------------------------------------------------------- /figures/swapping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/swapping.png -------------------------------------------------------------------------------- /figures/three_patterns_abc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/three_patterns_abc.png -------------------------------------------------------------------------------- /figures/time_course.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/time_course.png -------------------------------------------------------------------------------- /figures/umap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/umap.png -------------------------------------------------------------------------------- /figures/unclip_pipeline.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/figures/unclip_pipeline.png
--------------------------------------------------------------------------------
/requirements-vd.txt:
--------------------------------------------------------------------------------
easydict==1.11
einops==0.7.0
ipython==8.19.0
matplotlib==3.8.2
openai-clip==1.0.1
opencv-python==4.8.1.78
scikit-image==0.22.0
scikit-learn==1.3.2
setuptools==69.5.1
tokenizers==0.12.1
transformers==4.19.2
umap-learn==0.5.6
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.1
torchaudio==2.1.1
torchvision==0.16.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
albumentations==0.4.3
opencv-python==4.10.0.84
pudb==2019.2
h5py==3.13.0
imageio==2.9.0
imageio-ffmpeg==0.4.2
nibabel==5.3.2
pytorch-lightning==1.4.2
torchmetrics==0.6
omegaconf==2.1.1
test-tube==0.7.5
streamlit==1.37.1
einops==0.3.0
transformers==4.44.0
webdataset==0.2.5
open-clip-torch==2.7.0
gradio==3.13.2
kornia==0.6
invisible-watermark==0.2.0
streamlit-drawable-canvas==0.8.0
wandb==0.19.1
diffusers==0.32.2
openai-clip==1.0.1
scikit-learn==1.6.1
nilearn==0.11.0
pycortex==1.2.10
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.4.0
accelerate==0.33.0
torchvision==0.19.0
torchaudio==2.4.0
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
# Ignore everything in this directory
*
# Except this file
!.gitignore
!thingseeg2_preproc/
!thingseeg2_preproc/sub-01/
!thingseeg2_preproc/fig_performance.png
!thingseeg2_preproc/fig_ablations.png
!thingseeg2_preproc/fig_across_duration.png
!thingseeg2_preproc/fig_CLIP_across_size_num_avg.png
!nsd_preproc/
!nsd_preproc/avg-1-2-5-7/
!nsd_preproc/avg-1-2-5-7/*
!nsd_preproc/avg-1-2-5-7/clip_patterns/
!nsd_preproc/avg-1-2-5-7/clip_patterns/mni/
!nsd_preproc/avg-1-2-5-7/clip_patterns/mni/*
!nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/
!nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/*
!nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/
!nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/*
!nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/
!nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/*
--------------------------------------------------------------------------------
/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/animals_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/animals_pattern.png
--------------------------------------------------------------------------------
/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/food_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/food_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/human-closeup_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/human-closeup_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/human-distant_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/human-distant_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/interiors_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/interiors_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/urban_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/urban_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/blue_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/blue_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/green_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/green_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/red_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/red_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/bright_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/bright_pattern.png -------------------------------------------------------------------------------- 
/results/nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/dark_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/dark_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/smooth_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/smooth_pattern.png -------------------------------------------------------------------------------- /results/nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/textured_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/textured_pattern.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | !sub-01/ 6 | !fig_performance.png 7 | !fig_ablations.png 8 | !fig_across_duration.png 9 | !fig_CLIP_across_size_num_avg.png 10 | !three_patterns_negative_avg_180ms.png 11 | !blue_pattern_negative_avg_120ms.png 12 | !red_pattern_negative_avg_120ms.png 13 | !smooth_pattern_negative_avg_120ms.png 14 | !textured_pattern_negative_avg_120ms.png 15 | !pca-bright_patterns_avg_120-220ms.png 16 | !pca-dark_patterns_avg_120-220ms.png 17 | -------------------------------------------------------------------------------- /results/thingseeg2_preproc/blue_pattern_negative_avg_120ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/blue_pattern_negative_avg_120ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/fig_CLIP_across_size_num_avg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/fig_CLIP_across_size_num_avg.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/fig_ablations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/fig_ablations.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/fig_across_duration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/fig_across_duration.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/fig_performance.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/fig_performance.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/pca-bright_patterns_avg_120-220ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/pca-bright_patterns_avg_120-220ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/pca-dark_patterns_avg_120-220ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/pca-dark_patterns_avg_120-220ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/red_pattern_negative_avg_120ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/red_pattern_negative_avg_120ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/smooth_pattern_negative_avg_120ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/smooth_pattern_negative_avg_120ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/sub-01/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | !versatile_diffusion_ordered_by_performance.png 6 | !umap_CLIP.png 7 | -------------------------------------------------------------------------------- /results/thingseeg2_preproc/sub-01/umap_CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/sub-01/umap_CLIP.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/sub-01/versatile_diffusion_ordered_by_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/sub-01/versatile_diffusion_ordered_by_performance.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/textured_pattern_negative_avg_120ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/textured_pattern_negative_avg_120ms.png -------------------------------------------------------------------------------- /results/thingseeg2_preproc/three_patterns_negative_avg_180ms.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/results/thingseeg2_preproc/three_patterns_negative_avg_180ms.png
--------------------------------------------------------------------------------
/scripts-nsd/reconstruct_from_embeddings.py:
--------------------------------------------------------------------------------
import requests
import torch
from PIL import Image
from io import BytesIO
import numpy as np
import os

from diffusers import StableUnCLIPImg2ImgPipeline

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
args = parser.parse_args()
sub=int(args.sub)

pred_clipvision = np.load(f"cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/regress_clip.npy", mmap_mode='r') # Load the embeddings
pred_vae = np.load(f"cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/regress_vae.npy", mmap_mode='r')
recon_dir = f"results/nsd_preproc/sub-{sub:02d}/unclip/" # Directory to save the reconstructed images
os.makedirs(recon_dir, exist_ok=True)

#Start the StableUnCLIP Image variations pipeline
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16"
)
pipe = pipe.to("cuda")
device = pipe._execution_device
torch_ones = torch.ones(512, dtype=torch.float16, device=device)
torch_zeros = torch.zeros(512, dtype=torch.float16, device=device)
extra_portion = torch.cat([torch_ones, torch_zeros])

for i, embedding in enumerate(pred_clipvision):
    print(i)
    vae_latent = pred_vae[i].reshape((1, 4, 96, 96))
    vae_latent = torch.from_numpy(vae_latent).to(device).half()
    torch.manual_seed(0)
    noise_latent=torch.randn(vae_latent.shape, device=device).half()
    vae_latent = vae_latent*0.02 + noise_latent
    embedding = torch.tensor(embedding, device=device, dtype=torch.float16)
    embedding = torch.cat([embedding, extra_portion]).unsqueeze(0)
    negative_prompt_embeds = torch.zeros_like(embedding)
    embedding = torch.cat([negative_prompt_embeds, embedding])
    torch.manual_seed(0)
    image = pipe.decode(embedding, latents=vae_latent, guidance_scale=7.5).images[0]
    image.save(recon_dir + f"{i}.png")
--------------------------------------------------------------------------------
/scripts-nsd/reconstruct_from_embeddings_ica.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from tqdm import tqdm
from PIL import Image

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-size', '--size', help='Size', default=8859)
args = parser.parse_args()
sub = int(args.sub)
param = ''

pred_latents = np.load(f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/regress_ica1k.npy', mmap_mode='r')

recon_dir = f'results/nsd_preproc/sub-{sub:02d}/ica1k{param}/'
os.makedirs(recon_dir, exist_ok=True)

ica = np.load("cache/ica.npz")
encoder = ica["encoder"]
decoder = ica["decoder"]
train_mean = ica["mean"]
latent_dim = 1000

print('Reconstructing images...')
images = np.clip(decoder[:, :latent_dim] @ pred_latents.T + train_mean[:, np.newaxis], 0, 1).T.reshape((len(pred_latents), 64, 64, 3), order="F")
images = (images * 255).astype(np.uint8)

print('Saving images...')
for iter in tqdm(range(len(pred_latents)), total=len(pred_latents)):
    img = Image.fromarray(images[iter])
    img.save(f'{recon_dir}{iter:03d}.png')
--------------------------------------------------------------------------------
/scripts-nsd/reconstruct_from_embeddings_pca.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from tqdm import tqdm
from PIL import Image

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-size', '--size', help='Size', default=8859)
args = parser.parse_args()
sub = int(args.sub)
param = ''

pred_latents = np.load(f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/regress_pca1k.npy', mmap_mode='r')

recon_dir = f'results/nsd_preproc/sub-{sub:02d}/pca1k{param}/'
os.makedirs(recon_dir, exist_ok=True)

pca = np.load("cache/pca.npz")
eigenvectors = pca["eigenvectors"]
eigenvalues = pca["eigenvalues"]
latent_dim = 1000

print('Reconstructing images...')
images = np.clip(eigenvectors[:latent_dim].T @ pred_latents.T, 0, 1).T.reshape((len(pred_latents), 64, 64, 3), order="F")
images = (images * 255).astype(np.uint8)

print('Saving images...')
for iter in tqdm(range(len(pred_latents)), total=len(pred_latents)):
    img = Image.fromarray(images[iter])
    img.save(f'{recon_dir}{iter:03d}.png')
--------------------------------------------------------------------------------
/scripts-nsd/train_regression-encode_clip.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy
from scipy.spatial.distance import correlation
import random
import sklearn.linear_model as skl
import os
import pickle

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-weights', '--saving_weights',help="Saving the weights", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument('-size', '--size', help='Size', default=8859)
parser.add_argument('-alpha', '--alpha', help='Alpha for regression strength', default=10000)
args = parser.parse_args()
sub = int(args.sub)
saving_weights=args.saving_weights
alpha=int(args.alpha)
param = ''

# Load fMRI data
fmri_train = np.load(f'data/nsd_preproc/sub-{sub:02d}/train_fmriavg_nsdgeneral.npy')
fmri_test = np.load(f'data/nsd_preproc/sub-{sub:02d}/test_fmriavg_nsdgeneral.npy')
fmri_train = fmri_train / 300
fmri_test = fmri_test / 300
norm_mean_train = np.mean(fmri_train, axis=0)
norm_scale_train = np.std(fmri_train, axis=0, ddof=1)
norm_mean_test = np.mean(fmri_test, axis=0)
norm_scale_test = np.std(fmri_test, axis=0, ddof=1)
fmri_train = (fmri_train - norm_mean_train) / norm_scale_train
fmri_test = (fmri_test - norm_mean_test) / norm_scale_test
print(fmri_train.shape, fmri_test.shape)

# Save Directory
weights_save_dir = f'cache/nsd_preproc/regression_weights/sub-{sub:02d}/'
os.makedirs(weights_save_dir, exist_ok=True)
weights_filename = f'regress-encode_clip_weights{param}.pkl'
save_dir = f'cache/nsd_preproc/predicted_fmri/sub-{sub:02d}/'
os.makedirs(save_dir, exist_ok=True)
pred_filename = f'regress-encode_clip{param}.npy'

# Regression
train_latents= np.load(f'cache/nsd_extracted_embeddings/train_clip_sub-{sub:02d}.npy', mmap_mode='r')
test_latents = np.load(f'cache/nsd_extracted_embeddings/test_clip.npy', mmap_mode='r')
print(train_latents.shape, test_latents.shape)
train_latents_mean = np.mean(train_latents,axis=0)
train_latents_std = np.std(train_latents,axis=0)
train_latents = (train_latents - train_latents_mean) / train_latents_std
test_latents = (test_latents - train_latents_mean) / train_latents_std

print("Training Regression")
reg = skl.Ridge(alpha=alpha, max_iter=50000, fit_intercept=True) # alpha=50000
reg.fit(train_latents, fmri_train)
print('Training complete')

if saving_weights:
    datadict = {
        'weight' : reg.coef_,
        'bias' : reg.intercept_,
    }
    with open(weights_save_dir + weights_filename, "wb") as f:
        pickle.dump(datadict,f)

pred_fmri = reg.predict(test_latents)

np.save(save_dir + pred_filename, pred_fmri)

# Compute the Euclidean distances
euclidean_distances = np.array([np.linalg.norm(u - v) for u, v in zip(pred_fmri, fmri_test)])
correlation_distances = np.array([correlation(u, v) for u, v in zip(pred_fmri, fmri_test)])
# Compute the average Euclidean distance
average_euclidean_distance = euclidean_distances.mean()
correlations = (1 - correlation_distances).mean()
print(reg.score(test_latents,fmri_test), average_euclidean_distance, correlations)

# 0.08149915683891203 118.8716259728687 0.29988410403096066 for 1000
# 0.1022317182004 117.47897866038834 0.31772801990618493 for 10000
# 0.08073052247217462 118.89326785685333 0.30945139026879914 for 100000
--------------------------------------------------------------------------------
/scripts-nsd/train_regression_clip.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy
from scipy.spatial.distance import correlation
import random
import sklearn.linear_model as skl
import os
import pickle

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-weights', '--saving_weights',help="Saving the weights", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument('-size', '--size', help='Size', default=8859)
parser.add_argument('-alpha', '--alpha', help='Alpha for regression strength', default=100000)
args = parser.parse_args()
sub = int(args.sub)
saving_weights=args.saving_weights
alpha = int(args.alpha)
param = ''

# Load fMRI data
fmri_train = np.load(f'data/nsd_preproc/sub-{sub:02d}/train_fmriavg_nsdgeneral.npy')
fmri_test = np.load(f'data/nsd_preproc/sub-{sub:02d}/test_fmriavg_nsdgeneral.npy')
fmri_train = fmri_train / 300
fmri_test = fmri_test / 300
norm_mean_train = np.mean(fmri_train, axis=0)
norm_scale_train = np.std(fmri_train, axis=0, ddof=1)
fmri_train = (fmri_train - norm_mean_train) / norm_scale_train
fmri_test = (fmri_test - norm_mean_train) / norm_scale_train
print(fmri_train.shape, fmri_test.shape)

# Save Directory
weights_save_dir = f'cache/nsd_preproc/regression_weights/sub-{sub:02d}/'
os.makedirs(weights_save_dir, exist_ok=True)
weights_filename = f'regress_clip_weights{param}.pkl'
save_dir = f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/'
os.makedirs(save_dir, exist_ok=True)
latent_filename = f'regress_clip{param}.npy'

# Regression
train_latents= np.load(f'cache/nsd_extracted_embeddings/train_clip_sub-{sub:02d}.npy', mmap_mode='r')
test_latents = np.load(f'cache/nsd_extracted_embeddings/test_clip.npy', mmap_mode='r')
print(train_latents.shape, test_latents.shape)

print("Training Regression")
reg = skl.Ridge(alpha=alpha, max_iter=50000, fit_intercept=True)
reg.fit(fmri_train, train_latents)
print('Training complete')

if saving_weights:
    datadict = {
        'weight' : reg.coef_,
        'bias' : reg.intercept_,
    }
    with open(weights_save_dir + weights_filename, "wb") as f:
        pickle.dump(datadict,f)

pred_latent = reg.predict(fmri_test)
pred_latent_mean = np.mean(pred_latent,axis=0)
pred_latent_std = np.std(pred_latent,axis=0)
std_norm_pred_latent = (pred_latent - pred_latent_mean) / pred_latent_std
train_latents_mean = np.mean(train_latents,axis=0)
train_latents_std = np.std(train_latents,axis=0)
pred_latents = std_norm_pred_latent * train_latents_std + train_latents_mean

np.save(save_dir + latent_filename, pred_latents)

# Compute the Euclidean distances
euclidean_distances = np.array([np.linalg.norm(u - v) for u, v in zip(pred_latents, test_latents)])
correlation_distances = np.array([correlation(u, v) for u, v in zip(pred_latents, test_latents)])
# Compute the average Euclidean distance
average_euclidean_distance = euclidean_distances.mean()
correlations = (1 - correlation_distances).mean()
print(reg.score(fmri_test,test_latents), average_euclidean_distance, correlations)

# -0.27808989600960654 20.93133171545883 0.5027030380413428 for 1000
# 0.10574885015099406 19.37551002838122 0.5721359892257075 for 10000
# 0.1311441729408978 19.142296681837042 0.5835457920504794 for 100000
# 0.07315393137818146 20.131488509833037 0.5436478701725378 for 1000000
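The train_regression_*.py scripts above all store their fitted ridge models as a pickled dict with keys 'weight' (reg.coef_) and 'bias' (reg.intercept_). The sketch below is an editor's illustration, not a file in the repository: it shows how such a saved file could be reloaded and applied without refitting. The subject number and paths are placeholders, and the fMRI array is assumed to be normalized the same way as in the scripts (divided by 300, then z-scored; the scripts reuse the training mean and standard deviation).

import pickle
import numpy as np

# Placeholder paths; point these at an existing weights file and fMRI array.
weights_path = "cache/nsd_preproc/regression_weights/sub-01/regress_clip_weights.pkl"
fmri_path = "data/nsd_preproc/sub-01/test_fmriavg_nsdgeneral.npy"

with open(weights_path, "rb") as f:
    datadict = pickle.load(f)
weight = datadict["weight"]  # shape (n_latent_dims, n_voxels), from reg.coef_
bias = datadict["bias"]      # shape (n_latent_dims,), from reg.intercept_

# Normalize as in the training script; this z-scoring is a stand-in for
# reusing the stored training statistics.
fmri = np.load(fmri_path) / 300
fmri = (fmri - fmri.mean(axis=0)) / fmri.std(axis=0, ddof=1)

# Equivalent to reg.predict(fmri) for a fitted sklearn Ridge model.
pred_latents = fmri @ weight.T + bias
print(pred_latents.shape)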
--------------------------------------------------------------------------------
/scripts-nsd/train_regression_ica.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy
from scipy.spatial.distance import correlation
import random
import sklearn.linear_model as skl
import os
import pickle

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-weights', '--saving_weights',help="Saving the weights", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument('-size', '--size', help='Size', default=8859)
parser.add_argument('-alpha', '--alpha', help='Alpha for regression strength', default=100000)
args = parser.parse_args()
sub = int(args.sub)
saving_weights=args.saving_weights
alpha=int(args.alpha)
param = ''

# Load fMRI data
fmri_train = np.load(f'data/nsd_preproc/sub-{sub:02d}/train_fmriavg_nsdgeneral.npy')
fmri_test = np.load(f'data/nsd_preproc/sub-{sub:02d}/test_fmriavg_nsdgeneral.npy')
fmri_train = fmri_train / 300
fmri_test = fmri_test / 300
norm_mean_train = np.mean(fmri_train, axis=0)
norm_scale_train = np.std(fmri_train, axis=0, ddof=1)
fmri_train = (fmri_train - norm_mean_train) / norm_scale_train
fmri_test = (fmri_test - norm_mean_train) / norm_scale_train
print(fmri_train.shape, fmri_test.shape)

# Save Directory
weights_save_dir = f'cache/nsd_preproc/regression_weights/sub-{sub:02d}/'
os.makedirs(weights_save_dir, exist_ok=True)
weights_filename = f'regress_ica1k_weights{param}.pkl'
save_dir = f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/'
os.makedirs(save_dir, exist_ok=True)
latent_filename = f'regress_ica1k{param}.npy'

# Regression
train_latents= np.load(f'cache/nsd_extracted_embeddings/train_ica1k_sub-{sub:02d}.npy', mmap_mode='r')
test_latents = np.load(f'cache/nsd_extracted_embeddings/test_ica1k.npy', mmap_mode='r')
print(train_latents.shape, test_latents.shape)

print("Training Regression")
reg = skl.Ridge(alpha=alpha, max_iter=50000, fit_intercept=True) # alpha=50000
reg.fit(fmri_train, train_latents)
print('Training complete')

if saving_weights:
    datadict = {
        'weight' : reg.coef_,
        'bias' : reg.intercept_,
    }
    with open(weights_save_dir + weights_filename, "wb") as f:
        pickle.dump(datadict,f)

pred_latent = reg.predict(fmri_test)
pred_latent_mean = np.mean(pred_latent,axis=0)
pred_latent_std = np.std(pred_latent,axis=0)
std_norm_pred_latent = (pred_latent - pred_latent_mean) / pred_latent_std
train_latents_mean = np.mean(train_latents,axis=0)
train_latents_std = np.std(train_latents,axis=0)
pred_latents = std_norm_pred_latent * train_latents_std + train_latents_mean

np.save(save_dir + latent_filename, pred_latents)

# Compute the Euclidean distances
euclidean_distances = np.array([np.linalg.norm(u - v) for u, v in zip(pred_latents, test_latents)])
correlation_distances = np.array([correlation(u, v) for u, v in zip(pred_latents, test_latents)])
# Compute the average Euclidean distance
average_euclidean_distance = euclidean_distances.mean()
correlations = (1 - correlation_distances).mean()
print(reg.score(fmri_test,test_latents), average_euclidean_distance, correlations)

# -0.10514682935676765 41.80231167190395 0.04561628546156888 for 10000
# -0.00597656134904789 41.62013437480925 0.05531842966118934 for 100000
# 0.0010146477466907063 41.731707842050014 0.04857243051499194 for 1000000
--------------------------------------------------------------------------------
/scripts-nsd/train_regression_pca.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy
from scipy.spatial.distance import correlation
import random
import sklearn.linear_model as skl
import os
import pickle

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-weights', '--saving_weights',help="Saving the weights", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument('-size', '--size', help='Size', default=8859)
parser.add_argument('-alpha', '--alpha', help='Alpha for regression strength', default=100000)
args = parser.parse_args()
sub = int(args.sub)
saving_weights=args.saving_weights
alpha=int(args.alpha)
param = ''

# Load fMRI data
fmri_train = np.load(f'data/nsd_preproc/sub-{sub:02d}/train_fmriavg_nsdgeneral.npy')
fmri_test = np.load(f'data/nsd_preproc/sub-{sub:02d}/test_fmriavg_nsdgeneral.npy')
fmri_train = fmri_train / 300
fmri_test = fmri_test / 300
norm_mean_train = np.mean(fmri_train, axis=0)
norm_scale_train = np.std(fmri_train, axis=0, ddof=1)
fmri_train = (fmri_train - norm_mean_train) / norm_scale_train
fmri_test = (fmri_test - norm_mean_train) / norm_scale_train
print(fmri_train.shape, fmri_test.shape)

# Save Directory
weights_save_dir = f'cache/nsd_preproc/regression_weights/sub-{sub:02d}/'
os.makedirs(weights_save_dir, exist_ok=True)
weights_filename = f'regress_pca1k_weights{param}.pkl'
save_dir = f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/'
os.makedirs(save_dir, exist_ok=True)
latent_filename = f'regress_pca1k{param}.npy'

# Regression
train_latents= np.load(f'cache/nsd_extracted_embeddings/train_pca1k_sub-{sub:02d}.npy', mmap_mode='r')
test_latents = np.load(f'cache/nsd_extracted_embeddings/test_pca1k.npy', mmap_mode='r')
print(train_latents.shape, test_latents.shape)

print("Training Regression")
reg = skl.Ridge(alpha=alpha, max_iter=50000, fit_intercept=True) # alpha=50000
reg.fit(fmri_train, train_latents)
print('Training complete')

if saving_weights:
    datadict = {
        'weight' : reg.coef_,
        'bias' : reg.intercept_,
    }
    with open(weights_save_dir + weights_filename, "wb") as f:
        pickle.dump(datadict,f)

pred_latent = reg.predict(fmri_test)
pred_latent_mean = np.mean(pred_latent,axis=0)
pred_latent_std = np.std(pred_latent,axis=0)
std_norm_pred_latent = (pred_latent - pred_latent_mean) / pred_latent_std
train_latents_mean = np.mean(train_latents,axis=0)
train_latents_std = np.std(train_latents,axis=0)
pred_latents = std_norm_pred_latent * train_latents_std + train_latents_mean

np.save(save_dir + latent_filename, pred_latents)

# Compute the Euclidean distances
euclidean_distances = np.array([np.linalg.norm(u - v) for u, v in zip(pred_latents, test_latents)])
correlation_distances = np.array([correlation(u, v) for u, v in zip(pred_latents, test_latents)])
# Compute the average Euclidean distance
average_euclidean_distance = euclidean_distances.mean()
correlations = (1 - correlation_distances).mean()
print(reg.score(fmri_test,test_latents), average_euclidean_distance, correlations)

# -0.10386578687656549 28.45740907680061 0.8683454438679088 for 10000
# -0.00617033229942896 29.179730019576592 0.8638009005384237 for 100000
# 0.0007749183300167393 31.230264950791167 0.8480704738186576 for 1000000
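The commented score lines at the end of each regression script record how the test metrics vary with the ridge penalty passed via --alpha. A minimal driver for reproducing such a sweep is sketched below (an editor's illustration, not part of the repository); the script name, subject, and alpha grid are examples only, and --no-saving_weights relies on the BooleanOptionalAction flag defined in the scripts.

import subprocess

# Example sweep over the ridge penalty for one subject and one embedding type.
for alpha in [1000, 10000, 100000, 1000000]:
    subprocess.run(
        ["python", "scripts-nsd/train_regression_pca.py",
         "-sub", "1",
         "-alpha", str(alpha),
         "--no-saving_weights"],  # skip writing the large weight pickles during the sweep
        check=True,
    )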
--------------------------------------------------------------------------------
/scripts-nsd/train_regression_vdvae.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy
from scipy.spatial.distance import correlation
import random
import sklearn.linear_model as skl
import os
import pickle

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
parser.add_argument('-weights', '--saving_weights',help="Saving the weights", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument('-size', '--size', help='Size', default=8859)
parser.add_argument('-alpha', '--alpha', help='Alpha for regression strength', default=100000)
args = parser.parse_args()
sub = int(args.sub)
saving_weights=args.saving_weights
alpha=int(args.alpha)
param = ''

# Load fMRI data
fmri_train = np.load(f'data/nsd_preproc/sub-{sub:02d}/train_fmriavg_nsdgeneral.npy')
fmri_test = np.load(f'data/nsd_preproc/sub-{sub:02d}/test_fmriavg_nsdgeneral.npy')
fmri_train = fmri_train / 300
fmri_test = fmri_test / 300
norm_mean_train = np.mean(fmri_train, axis=0)
norm_scale_train = np.std(fmri_train, axis=0, ddof=1)
fmri_train = (fmri_train - norm_mean_train) / norm_scale_train
fmri_test = (fmri_test - norm_mean_train) / norm_scale_train
print(fmri_train.shape, fmri_test.shape)

# Save Directory
weights_save_dir = f'cache/nsd_preproc/regression_weights/sub-{sub:02d}/'
os.makedirs(weights_save_dir, exist_ok=True)
weights_filename = f'regress_vdvae_weights{param}.pkl'
save_dir = f'cache/nsd_preproc/predicted_embeddings/sub-{sub:02d}/'
os.makedirs(save_dir, exist_ok=True)
latent_filename = f'regress_vdvae{param}.npy'

# Regression
train_latents= np.load(f'cache/nsd_extracted_embeddings/train_vdvae_sub-{sub:02d}.npy', mmap_mode='r')
test_latents = np.load(f'cache/nsd_extracted_embeddings/test_vdvae.npy', mmap_mode='r')
print(train_latents.shape, test_latents.shape)

print("Training Regression")
reg = skl.Ridge(alpha=alpha, max_iter=50000, fit_intercept=True)
reg.fit(fmri_train, train_latents)
print('Training complete')

if saving_weights:
    datadict = {
        'weight' : reg.coef_,
        'bias' : reg.intercept_,
    }
    with open(weights_save_dir + weights_filename, "wb") as f:
        pickle.dump(datadict,f)

pred_latent = reg.predict(fmri_test)
pred_latent_mean = np.mean(pred_latent,axis=0)
pred_latent_std = np.std(pred_latent,axis=0)
std_norm_pred_latent = (pred_latent - pred_latent_mean) / pred_latent_std
train_latents_mean = np.mean(train_latents,axis=0)
train_latents_std = np.std(train_latents,axis=0)
pred_latents = std_norm_pred_latent * train_latents_std + train_latents_mean

np.save(save_dir + latent_filename, pred_latents)

# Compute the Euclidean distances
euclidean_distances = np.array([np.linalg.norm(u - v) for u, v in zip(pred_latents, test_latents)])
correlation_distances = np.array([correlation(u, v) for u, v in zip(pred_latents, test_latents)])
# Compute the average Euclidean distance
average_euclidean_distance = euclidean_distances.mean()
correlations = (1 - correlation_distances).mean()
print(reg.score(fmri_test,test_latents), average_euclidean_distance, correlations)

# -0.11076795699774779 113.73543042630806 0.018716912714999166 for 10000
# -0.013377966412249741 113.55331471782537 0.02185260768472681 for 100000
# -0.002320856468348707 113.52458604838117 0.02071305333334158 for 1000000
--------------------------------------------------------------------------------
/scripts-nsd_dataprep/download_nsd_data.py:
--------------------------------------------------------------------------------
import os

# Download Experiment Infos
os.system('aws s3 cp s3://natural-scenes-dataset/nsddata/experiments/nsd/nsd_expdesign.mat data/nsd_metadata/experiments/nsd/ --no-sign-request')
os.system('aws s3 cp s3://natural-scenes-dataset/nsddata/experiments/nsd/nsd_stim_info_merged.pkl data/nsd_metadata/experiments/nsd/ --no-sign-request')

# Download Stimuli
os.system('aws s3 cp s3://natural-scenes-dataset/nsddata_stimuli/stimuli/nsd/nsd_stimuli.hdf5 data/nsd_metadata/stimuli/nsd/ --no-sign-request')

# Download Betas
for sub in [1,2,5,7]:
    for sess in range(1,38):
        os.system('aws s3 cp s3://natural-scenes-dataset/nsddata_betas/ppdata/subj{:02d}/func1pt8mm/betas_fithrf_GLMdenoise_RR/betas_session{:02d}.nii.gz data/nsd_preproc/subj{:02d}/func1pt8mm/betas_fithrf_GLMdenoise_RR/ --no-sign-request'.format(sub,sess,sub))

# Download ROIs
for sub in [1,2,5,7]:
    os.system('aws s3 cp s3://natural-scenes-dataset/nsddata/ppdata/subj{:02d}/func1pt8mm/roi/ data/nsd_preproc/subj{:02d}/func1pt8mm/roi/ --no-sign-request --recursive'.format(sub,sub))

# Download Freesurfer
os.system('aws s3 cp s3://natural-scenes-dataset/nsddata/freesurfer/ data/nsd_preproc/freesurfer/ --no-sign-request --recursive')
# for sub in [1,2,5,7]:
#     os.system('aws s3 cp s3://natural-scenes-dataset/nsddata/freesurfer/subj{:02d}/ data/nsd_preproc/subj{:02d}/freesurfer/ --no-sign-request --recursive'.format(sub,sub))

# Download MNI Transforms
for sub in [1,2,5,7]:
    os.system(f'aws s3 cp s3://natural-scenes-dataset/nsddata/ppdata/subj{sub:02d}/transforms/func1pt8-to-MNI.nii.gz data/nsd_preproc/nsddata/ppdata/subj{sub:02d}/transforms/func1pt8-to-MNI.nii.gz --no-sign-request')
np.zeros((len(images), 1024)) 41 | device = pipe._execution_device 42 | noise_level = torch.tensor([0], device=device) 43 | torch.manual_seed(0) 44 | for i_image, image in tqdm(enumerate(images), total=len(images)): 45 | embedding = pipe._encode_image(image, device=device, batch_size=1, num_images_per_prompt=1, do_classifier_free_guidance=True,noise_level=noise_level,generator=None, image_embeds = None) 46 | embedding = embedding[1] 47 | embeddings[i_image] = embedding.detach().cpu().numpy()[:1024] 48 | 49 | os.makedirs('cache/nsd_extracted_embeddings', exist_ok=True) 50 | np.save(f'cache/nsd_extracted_embeddings/train_clip_sub-{sub:02d}.npy', embeddings) -------------------------------------------------------------------------------- /scripts-nsd_dataprep/extract_features-ica.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from skimage.transform import resize 3 | import numpy as np 4 | 5 | import argparse 6 | parser = argparse.ArgumentParser(description='Argument Parser') 7 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 8 | args = parser.parse_args() 9 | sub = int(args.sub) 10 | param = '' 11 | 12 | ica = np.load("cache/ica.npz") 13 | encoder = ica["encoder"] 14 | decoder = ica["decoder"] 15 | train_mean = ica["mean"] 16 | 17 | print('loading data...') 18 | test_images = np.load("data/nsd_metadata/test_images.npy", mmap_mode="r") / 255 19 | test_images = np.array([resize(image, (64, 64)) for image in test_images]) 20 | train_images = np.load(f"data/nsd_metadata/train_images_sub-{sub:02d}.npy", mmap_mode="r") / 255 21 | train_images = np.array([resize(image, (64, 64)) for image in train_images]) 22 | 23 | print('extracting features...') 24 | 25 | image_dim = 64 * 64 * 3 26 | test_data = np.zeros((len(test_images), image_dim)) 27 | for i, image in enumerate(test_images): 28 | test_data[i, :] = image.flatten() 29 | train_data = np.zeros((len(train_images), image_dim)) 30 | for i, image in enumerate(train_images): 31 | train_data[i, :] = image.flatten() 32 | 33 | latent_dim = 1000 34 | test_latents = np.zeros((len(test_data), latent_dim)) 35 | for i, image in tqdm(enumerate(test_data)): 36 | test_latents[i] = encoder[:latent_dim] @ (image - train_mean).reshape(64, 64, 3).T.flatten() 37 | np.save("cache/nsd_extracted_embeddings/test_ica1k.npy", test_latents) 38 | 39 | train_latents = np.zeros((len(train_data), latent_dim)) 40 | for i, image in tqdm(enumerate(train_data)): 41 | train_latents[i] = encoder[:latent_dim] @ (image - train_mean).reshape(64, 64, 3).T.flatten() 42 | np.save(f"cache/nsd_extracted_embeddings/train_ica1k_sub-{sub:02d}.npy", train_latents) -------------------------------------------------------------------------------- /scripts-nsd_dataprep/extract_features-pca.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from skimage.transform import resize 3 | import numpy as np 4 | 5 | import argparse 6 | parser = argparse.ArgumentParser(description='Argument Parser') 7 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 8 | args = parser.parse_args() 9 | sub = int(args.sub) 10 | param = '' 11 | 12 | pca = np.load("cache/pca.npz") 13 | eigenvectors = pca["eigenvectors"] 14 | eigenvalues = pca["eigenvalues"] 15 | 16 | print('loading data...') 17 | test_images = np.load("data/nsd_metadata/test_images.npy", mmap_mode="r") / 255 18 | test_images = np.array([resize(image, (64, 64)) for image in test_images]) 19 | 
train_images = np.load(f"data/nsd_metadata/train_images_sub-{sub:02d}.npy", mmap_mode="r") / 255 20 | train_images = np.array([resize(image, (64, 64)) for image in train_images]) 21 | 22 | print('extracting features...') 23 | 24 | image_dim = 64 * 64 * 3 25 | test_data = np.zeros((len(test_images), image_dim)) 26 | for i, image in enumerate(test_images): 27 | test_data[i, :] = image.flatten() 28 | train_data = np.zeros((len(train_images), image_dim)) 29 | for i, image in enumerate(train_images): 30 | train_data[i, :] = image.flatten() 31 | 32 | latent_dim = 1000 33 | test_latents = np.zeros((len(test_data), latent_dim)) 34 | for i, image in tqdm(enumerate(test_data)): 35 | test_latents[i] = eigenvectors[:latent_dim] @ image.reshape(64, 64, 3).T.flatten() 36 | np.save("cache/nsd_extracted_embeddings/test_pca1k.npy", test_latents) 37 | 38 | train_latents = np.zeros((len(train_data), latent_dim)) 39 | for i, image in tqdm(enumerate(train_data)): 40 | train_latents[i] = eigenvectors[:latent_dim] @ image.reshape(64, 64, 3).T.flatten() 41 | np.save(f"cache/nsd_extracted_embeddings/train_pca1k_sub-{sub:02d}.npy", train_latents) -------------------------------------------------------------------------------- /scripts-nsd_dataprep/save_test_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser(description='Argument Parser') 7 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 8 | args = parser.parse_args() 9 | sub = int(args.sub) 10 | 11 | images = np.load(f'data/nsd_metadata/test_images.npy', mmap_mode='r') 12 | test_images_dir = f'data/nsd_metadata/test_images_direct/' 13 | 14 | if not os.path.exists(test_images_dir): 15 | os.makedirs(test_images_dir) 16 | for i in range(len(images)): 17 | im = Image.fromarray(images[i].astype(np.uint8)) 18 | im.save(os.path.join(test_images_dir, f"{i}.png")) 19 | 20 | 21 | -------------------------------------------------------------------------------- /scripts-nsd_figures/freesurfer_import_subj.py: -------------------------------------------------------------------------------- 1 | import cortex 2 | import os 3 | import argparse 4 | parser = argparse.ArgumentParser(description='Argument Parser') 5 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 6 | args = parser.parse_args() 7 | sub = int(args.sub) 8 | os.environ['FREESURFER_HOME'] = '/usr/local/freesurfer/7.4.1' 9 | os.environ['PATH'] = os.pathsep.join([os.path.join('/usr/local/freesurfer/7.4.1', 'bin'), os.environ['PATH']]) 10 | os.environ['SUBJECTS_DIR'] = 'data/nsd_preproc/freesurfer/' 11 | freesurfer_path = 'data/nsd_preproc/freesurfer/' 12 | 13 | subject = cortex.freesurfer.import_subj(f"subj{sub:02d}",freesurfer_subject_dir=freesurfer_path) 14 | cortex.freesurfer.import_flat(f"subj{sub:02d}",'full',freesurfer_subject_dir=freesurfer_path,auto_overwrite=True) 15 | pts_lh,polys_lh,_=cortex.freesurfer.get_surf(f'subj{sub:02d}','lh','patch','full'+'.flat',freesurfer_subject_dir=freesurfer_path) 16 | pts_rh,polys_rh,_=cortex.freesurfer.get_surf(f'subj{sub:02d}','rh','patch','full'+'.flat',freesurfer_subject_dir=freesurfer_path) 17 | ref_path = f'data/nsd_preproc/subj{sub:02d}/func1pt8mm/betas_fithrf_GLMdenoise_RR/betas_session01.nii.gz' 18 | cortex.align.automatic(f'subj{sub:02d}','full',reference=ref_path) -------------------------------------------------------------------------------- 
/scripts-nsd_figures/plot_clip_patterns_mni.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser(description='Argument Parser') 9 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 10 | args = parser.parse_args() 11 | sub = int(args.sub) 12 | 13 | vol_dir = f'cache/nsd_preproc/predicted_patterns/clip_patterns/sub-{sub:02d}/mni/' 14 | 15 | output_folder = f'results/nsd_preproc/sub-{sub:02d}/clip_patterns/mni/' 16 | os.makedirs(output_folder, exist_ok=True) 17 | 18 | cortex.download_subject('fsaverage') 19 | 20 | vol_filenames = os.listdir(vol_dir) 21 | for vol_filename in vol_filenames: 22 | vol_data = nib.load(vol_dir+vol_filename) 23 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(vol_data.get_fdata(),0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-1.5, vmax=1.5) 24 | fig = plt.figure() # 100 25 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 26 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) 27 | plt.close() 28 | 29 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_clip_patterns_mni_avg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | vol_dir = f'cache/nsd_preproc/predicted_patterns/clip_patterns/sub-01/mni/' 8 | vol_filenames = os.listdir(vol_dir) 9 | pattern_names = [vol_filename.split('_')[0] for vol_filename in vol_filenames] 10 | 11 | output_folder = f'results/nsd_preproc/avg-1-2-5-7/clip_patterns/mni/' 12 | os.makedirs(output_folder, exist_ok=True) 13 | os.makedirs('cache/nsd_preproc/predicted_patterns/clip_patterns/avg-1-2-5-7/mni/', exist_ok=True) 14 | 15 | for i, vol_filename in enumerate(vol_filenames): 16 | subs = [1,2,5,7] 17 | volumes = [] 18 | for sub in subs: 19 | vol_dir = f'cache/nsd_preproc/predicted_patterns/clip_patterns/sub-{sub:02d}/mni/' 20 | vol_data = nib.load(vol_dir+vol_filename) 21 | volumes.append(vol_data.get_fdata()) 22 | volumes = np.array(volumes) 23 | avg_vol = np.mean(volumes, axis=0) 24 | pattern_image = nib.Nifti1Image(avg_vol, affine=vol_data.affine, header=vol_data.header) 25 | nib.save(pattern_image, f'cache/nsd_preproc/predicted_patterns/clip_patterns/avg-1-2-5-7/mni/{vol_filename}') 26 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(avg_vol,0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-1.5, vmax=1.5) 27 | fig = plt.figure() # 100 28 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 29 | plt.title(f'{pattern_names[i].capitalize()} Pattern - NSD 4-Subject Average (1, 2, 5, 7) MNI') 30 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) # dpi=100 if uploading to github 31 | plt.close() 32 | 33 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_ica-color_patterns_mni.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser(description='Argument Parser') 9 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 
10 | args = parser.parse_args() 11 | sub = int(args.sub) 12 | 13 | vol_dir = f'cache/nsd_preproc/predicted_patterns/ica-color_patterns/sub-{sub:02d}/mni/' 14 | 15 | output_folder = f'results/nsd_preproc/sub-{sub:02d}/ica-color_patterns/mni/' 16 | os.makedirs(output_folder, exist_ok=True) 17 | 18 | cortex.download_subject('fsaverage') 19 | 20 | vol_filenames = os.listdir(vol_dir) 21 | for vol_filename in vol_filenames: 22 | vol_data = nib.load(vol_dir+vol_filename) 23 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(vol_data.get_fdata(),0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.1, vmax=0.1) 24 | fig = plt.figure() # 100 25 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 26 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) 27 | plt.close() 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_ica-color_patterns_mni_avg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | vol_dir = f'cache/nsd_preproc/predicted_patterns/ica-color_patterns/sub-01/mni/' 8 | vol_filenames = os.listdir(vol_dir) 9 | pattern_names = [vol_filename.split('_')[0] for vol_filename in vol_filenames] 10 | 11 | output_folder = f'results/nsd_preproc/avg-1-2-5-7/ica-color_patterns/mni/' 12 | os.makedirs(output_folder, exist_ok=True) 13 | os.makedirs('cache/nsd_preproc/predicted_patterns/ica-color_patterns/avg-1-2-5-7/mni/', exist_ok=True) 14 | 15 | for i, vol_filename in enumerate(vol_filenames): 16 | subs = [1,2,5,7] 17 | volumes = [] 18 | for sub in subs: 19 | vol_dir = f'cache/nsd_preproc/predicted_patterns/ica-color_patterns/sub-{sub:02d}/mni/' 20 | vol_data = nib.load(vol_dir+vol_filename) 21 | volumes.append(vol_data.get_fdata()) 22 | volumes = np.array(volumes) 23 | avg_vol = np.mean(volumes, axis=0) 24 | pattern_image = nib.Nifti1Image(avg_vol, affine=vol_data.affine, header=vol_data.header) 25 | nib.save(pattern_image, f'cache/nsd_preproc/predicted_patterns/ica-color_patterns/avg-1-2-5-7/mni/{vol_filename}') 26 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(avg_vol,0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.1, vmax=0.1) 27 | fig = plt.figure() # 100 28 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 29 | plt.title(f'{pattern_names[i].capitalize()} Pattern - NSD 4-Subject Average (1, 2, 5, 7) MNI') 30 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) # dpi=100 if uploading to github 31 | plt.close() 32 | 33 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_pca-brightness_patterns_mni.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser(description='Argument Parser') 9 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 10 | args = parser.parse_args() 11 | sub = int(args.sub) 12 | 13 | vol_dir = f'cache/nsd_preproc/predicted_patterns/pca-brightness_patterns/sub-{sub:02d}/mni/' 14 | 15 | output_folder = f'results/nsd_preproc/sub-{sub:02d}/pca-brightness_patterns/mni/' 16 | os.makedirs(output_folder, exist_ok=True) 17 | 18 | cortex.download_subject('fsaverage') 19 | 20 | 
vol_filenames = os.listdir(vol_dir) 21 | for vol_filename in vol_filenames: 22 | vol_data = nib.load(vol_dir+vol_filename) 23 | # volumn = cortex.Volume(np.moveaxis(np.moveaxis(vol_data.get_fdata(),0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.1, vmax=0.1) 24 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(vol_data.get_fdata(),0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-1., vmax=1.) 25 | fig = plt.figure() # 100 26 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 27 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) 28 | plt.close() 29 | 30 | 31 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_pca-brightness_patterns_mni_avg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | vol_dir = f'cache/nsd_preproc/predicted_patterns/pca-brightness_patterns/sub-01/mni/' 8 | vol_filenames = os.listdir(vol_dir) 9 | pattern_names = [vol_filename.split('_')[0] for vol_filename in vol_filenames] 10 | 11 | output_folder = f'results/nsd_preproc/avg-1-2-5-7/pca-brightness_patterns/mni/' 12 | os.makedirs(output_folder, exist_ok=True) 13 | os.makedirs('cache/nsd_preproc/predicted_patterns/pca-brightness_patterns/avg-1-2-5-7/mni/', exist_ok=True) 14 | 15 | for i, vol_filename in enumerate(vol_filenames): 16 | subs = [1,2,5,7] 17 | volumes = [] 18 | for sub in subs: 19 | vol_dir = f'cache/nsd_preproc/predicted_patterns/pca-brightness_patterns/sub-{sub:02d}/mni/' 20 | vol_data = nib.load(vol_dir+vol_filename) 21 | volumes.append(vol_data.get_fdata()) 22 | volumes = np.array(volumes) 23 | avg_vol = np.mean(volumes, axis=0) 24 | pattern_image = nib.Nifti1Image(avg_vol, affine=vol_data.affine, header=vol_data.header) 25 | nib.save(pattern_image, f'cache/nsd_preproc/predicted_patterns/pca-brightness_patterns/avg-1-2-5-7/mni/{vol_filename}') 26 | # volumn = cortex.Volume(np.moveaxis(np.moveaxis(avg_vol,0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.1, vmax=0.1) 27 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(avg_vol,0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-1., vmax=1.) 
28 | fig = plt.figure() # 100 29 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 30 | plt.title(f'{pattern_names[i].capitalize()} Pattern - NSD 4-Subject Average (1, 2, 5, 7) MNI') 31 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) # dpi=100 if uploading to github 32 | plt.close() 33 | 34 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_vdvae-texture_patterns_mni.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser(description='Argument Parser') 9 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 10 | args = parser.parse_args() 11 | sub = int(args.sub) 12 | 13 | vol_dir = f'cache/nsd_preproc/predicted_patterns/vdvae-texture_patterns/sub-{sub:02d}/mni/' 14 | 15 | output_folder = f'results/nsd_preproc/sub-{sub:02d}/vdvae-texture_patterns/mni/' 16 | os.makedirs(output_folder, exist_ok=True) 17 | 18 | cortex.download_subject('fsaverage') 19 | 20 | vol_filenames = os.listdir(vol_dir) 21 | for vol_filename in vol_filenames: 22 | vol_data = nib.load(vol_dir+vol_filename) 23 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(vol_data.get_fdata(),0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.4, vmax=0.4) 24 | fig = plt.figure() # 100 25 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 26 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) 27 | plt.close() 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts-nsd_figures/plot_vdvae-texture_patterns_mni_avg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import nibabel as nib 4 | import cortex 5 | import os 6 | 7 | vol_dir = f'cache/nsd_preproc/predicted_patterns/vdvae-texture_patterns/sub-01/mni/' 8 | vol_filenames = os.listdir(vol_dir) 9 | pattern_names = [vol_filename.split('_')[0] for vol_filename in vol_filenames] 10 | 11 | output_folder = f'results/nsd_preproc/avg-1-2-5-7/vdvae-texture_patterns/mni/' 12 | os.makedirs(output_folder, exist_ok=True) 13 | os.makedirs('cache/nsd_preproc/predicted_patterns/vdvae-texture_patterns/avg-1-2-5-7/mni/', exist_ok=True) 14 | 15 | for i, vol_filename in enumerate(vol_filenames): 16 | subs = [1,2,5,7] 17 | volumes = [] 18 | for sub in subs: 19 | vol_dir = f'cache/nsd_preproc/predicted_patterns/vdvae-texture_patterns/sub-{sub:02d}/mni/' 20 | vol_data = nib.load(vol_dir+vol_filename) 21 | volumes.append(vol_data.get_fdata()) 22 | volumes = np.array(volumes) 23 | avg_vol = np.mean(volumes, axis=0) 24 | pattern_image = nib.Nifti1Image(avg_vol, affine=vol_data.affine, header=vol_data.header) 25 | nib.save(pattern_image, f'cache/nsd_preproc/predicted_patterns/vdvae-texture_patterns/avg-1-2-5-7/mni/{vol_filename}') 26 | volumn = cortex.Volume(np.moveaxis(np.moveaxis(avg_vol,0,-1),0,1), subject='fsaverage', xfmname='atlas', cmap='twilight', vmin=-0.4, vmax=0.4) 27 | fig = plt.figure() # 100 28 | cortex.quickflat.make_figure(volumn,recache=1, fig=fig) 29 | plt.title(f'{pattern_names[i].capitalize()} Pattern - NSD 4-Subject Average (1, 2, 5, 7) MNI') 30 | plt.savefig(output_folder+vol_filename.replace('.nii.gz','.png'), dpi=300) # dpi=100 if uploading to github 31 | plt.close() 32 | 33 | 
-------------------------------------------------------------------------------- /scripts-nsd_figures/to_mni.py: -------------------------------------------------------------------------------- 1 | import cortex 2 | from nsd_mapdata import NSDmapdata 3 | 4 | import os 5 | os.environ['FREESURFER_HOME'] = '/usr/local/freesurfer/7.4.1' 6 | os.environ['PATH'] = os.pathsep.join([os.path.join('/usr/local/freesurfer/7.4.1', 'bin'), os.environ['PATH']]) 7 | os.environ['SUBJECTS_DIR'] = 'data/nsd_preproc/freesurfer' 8 | 9 | print(cortex.database.default_filestore) 10 | 11 | import argparse 12 | parser = argparse.ArgumentParser(description='Argument Parser') 13 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 14 | parser.add_argument('-pattern', '--pattern-type', help='Pattern Type', default='clip') 15 | args = parser.parse_args() 16 | sub = int(args.sub) 17 | pattern_type = args.pattern_type 18 | assert pattern_type in ['clip', 'pca-brightness', 'ica-color', 'vdvae-texture'] 19 | 20 | base_path = os.path.join('data/nsd_preproc') 21 | nsd = NSDmapdata(base_path) 22 | sourcespace = 'func1pt8' 23 | sourcefolder = f'cache/nsd_preproc/predicted_patterns/{pattern_type}_patterns/sub-{sub:02d}/func1pt8mm/' 24 | outputfolder = f'cache/nsd_preproc/predicted_patterns/{pattern_type}_patterns/sub-{sub:02d}/mni/' 25 | 26 | os.makedirs(outputfolder, exist_ok=True) 27 | for filename in os.listdir(sourcefolder): 28 | sourcedata = os.path.join(sourcefolder, filename) 29 | targetspace = 'MNI' 30 | targetdata = nsd.fit( 31 | sub, 32 | sourcespace, 33 | targetspace, 34 | sourcedata, 35 | interptype='cubic', 36 | badval=0, 37 | outputfile=os.path.join(outputfolder, filename)) -------------------------------------------------------------------------------- /scripts-thingseeg2/reconstruct_from_embeddings.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | from io import BytesIO 5 | import numpy as np 6 | import os 7 | 8 | from diffusers import StableUnCLIPImg2ImgPipeline 9 | 10 | import argparse 11 | parser = argparse.ArgumentParser(description='Argument Parser') 12 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 13 | args = parser.parse_args() 14 | sub=int(args.sub) 15 | 16 | pred_clip = np.load(f"cache/thingseeg2_preproc/predicted_embeddings/sub-{sub:02d}/regress_clip.npy", mmap_mode='r') # Load the embeddings 17 | pred_vae = np.load(f"cache/thingseeg2_preproc/predicted_embeddings/sub-{sub:02d}/regress_vae.npy", mmap_mode='r') 18 | recon_dir = f"results/thingseeg2_preproc/sub-{sub:02d}/unclip/" # Directory to save the reconstructed images 19 | os.makedirs(recon_dir, exist_ok=True) 20 | 21 | #Start the StableUnCLIP Image variations pipeline 22 | pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( 23 | "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" 24 | ) 25 | pipe = pipe.to("cuda") 26 | device = pipe._execution_device 27 | torch_ones = torch.ones(512, dtype=torch.float16, device=device) 28 | torch_zeros = torch.zeros(512, dtype=torch.float16, device=device) 29 | extra_portion = torch.cat([torch_ones, torch_zeros]) 30 | 31 | for i, embedding in enumerate(pred_clip): 32 | print(i) 33 | vae_latent = pred_vae[i].reshape((1, 4, 96, 96)) 34 | vae_latent = torch.from_numpy(vae_latent).to(device).half() 35 | torch.manual_seed(0) 36 | noise_latent=torch.randn(vae_latent.shape, device=device).half() 37 | vae_latent = vae_latent*0.02 + noise_latent 
38 | embedding = torch.tensor(embedding, device=device, dtype=torch.float16) 39 | embedding = torch.cat([embedding, extra_portion]).unsqueeze(0) 40 | negative_prompt_embeds = torch.zeros_like(embedding) 41 | embedding = torch.cat([negative_prompt_embeds, embedding]) 42 | torch.manual_seed(0) 43 | image = pipe.decode(embedding, latents=vae_latent, guidance_scale=7.5).images[0] 44 | image.save(recon_dir + f"{i}.png") -------------------------------------------------------------------------------- /scripts-thingseeg2/reconstruct_from_embeddings_clip-only.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | from io import BytesIO 5 | import numpy as np 6 | import os 7 | 8 | from diffusers import StableUnCLIPImg2ImgPipeline 9 | 10 | import argparse 11 | parser = argparse.ArgumentParser(description='Argument Parser') 12 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 13 | args = parser.parse_args() 14 | sub=int(args.sub) 15 | 16 | pred_clip = np.load(f"cache/thingseeg2_preproc/predicted_embeddings/sub-{sub:02d}/regress_clip.npy", mmap_mode='r') # Load the embeddings 17 | recon_dir = f"results/thingseeg2_preproc/sub-{sub:02d}/unclip/" # Directory to save the reconstructed images 18 | os.makedirs(recon_dir, exist_ok=True) 19 | 20 | #Start the StableUnCLIP Image variations pipeline 21 | pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( 22 | "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" 23 | ) 24 | pipe = pipe.to("cuda") 25 | device = pipe._execution_device 26 | torch_ones = torch.ones(512, dtype=torch.float16, device=device) 27 | torch_zeros = torch.zeros(512, dtype=torch.float16, device=device) 28 | extra_portion = torch.cat([torch_ones, torch_zeros]) 29 | 30 | for i, embedding in enumerate(pred_clip): 31 | print(i) 32 | embedding = torch.tensor(embedding, device=device, dtype=torch.float16) 33 | embedding = torch.cat([embedding, extra_portion]).unsqueeze(0) 34 | negative_prompt_embeds = torch.zeros_like(embedding) 35 | embedding = torch.cat([negative_prompt_embeds, embedding]) 36 | torch.manual_seed(0) 37 | image = pipe.decode(embedding).images[0] 38 | image.save(recon_dir + f"{i}.png") -------------------------------------------------------------------------------- /scripts-thingseeg2/reconstruct_from_embeddings_ica.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tqdm import tqdm 4 | from PIL import Image 5 | 6 | import argparse 7 | parser = argparse.ArgumentParser(description='Argument Parser') 8 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 9 | parser.add_argument('-size', '--size', help='Size', default=8859) 10 | args = parser.parse_args() 11 | sub = int(args.sub) 12 | param = '' 13 | 14 | 15 | pred_latents = np.load(f'cache/thingseeg2_preproc/predicted_embeddings/sub-{sub:02d}/regress_ica1k.npy', mmap_mode='r') 16 | 17 | recon_dir = f'results/thingseeg2_preproc/sub-{sub:02d}/ica1k{param}/' 18 | os.makedirs(recon_dir, exist_ok=True) 19 | 20 | ica = np.load("cache/ica.npz") 21 | encoder = ica["encoder"] 22 | decoder = ica["decoder"] 23 | train_mean = ica["mean"] 24 | latent_dim = 1000 25 | 26 | print('Reconstructing images...') 27 | images = np.clip(decoder[:, :latent_dim] @ pred_latents.T + train_mean[:, np.newaxis], 0, 1).T.reshape((len(pred_latents), 64, 64, 3), order="F") 28 | images = (images * 255).astype(np.uint8) 29 | 
30 | print('Saving images...') 31 | for iter in tqdm(range(len(pred_latents)), total=len(pred_latents)): 32 | img = Image.fromarray(images[iter]) 33 | img.save(f'{recon_dir}{iter:03d}.png') 34 | 35 | 36 | -------------------------------------------------------------------------------- /scripts-thingseeg2/reconstruct_from_embeddings_pca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tqdm import tqdm 4 | from PIL import Image 5 | 6 | import argparse 7 | parser = argparse.ArgumentParser(description='Argument Parser') 8 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 9 | args = parser.parse_args() 10 | sub = int(args.sub) 11 | param = '' 12 | 13 | 14 | pred_latents = np.load(f'cache/thingseeg2_preproc/predicted_embeddings/sub-{sub:02d}/regress_pca1k.npy', mmap_mode='r') 15 | 16 | recon_dir = f'results/thingseeg2_preproc/sub-{sub:02d}/pca1k{param}/' 17 | os.makedirs(recon_dir, exist_ok=True) 18 | 19 | pca = np.load("cache/pca.npz") 20 | eigenvectors = pca["eigenvectors"] 21 | eigenvalues = pca["eigenvalues"] 22 | latent_dim = 1000 23 | 24 | print('Reconstructing images...') 25 | images = np.clip(eigenvectors[:latent_dim].T @ pred_latents.T, 0, 1).T.reshape((len(pred_latents), 64, 64, 3), order="F") 26 | images = (images * 255).astype(np.uint8) 27 | 28 | print('Saving images...') 29 | for iter in tqdm(range(len(pred_latents)), total=len(pred_latents)): 30 | img = Image.fromarray(images[iter]) 31 | img.save(f'{recon_dir}{iter:03d}.png') 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-clip.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | from io import BytesIO 5 | import numpy as np 6 | from tqdm import tqdm 7 | import os 8 | 9 | from diffusers import StableUnCLIPImg2ImgPipeline 10 | 11 | pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( 12 | "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" 13 | ) 14 | pipe = pipe.to("cuda") 15 | 16 | # Load the test_images NumPy array 17 | images = np.load("data/thingseeg2_metadata/test_images.npy", mmap_mode='r') 18 | # Convert each image to a PIL image 19 | images = [Image.fromarray(image).convert("RGB") for image in images] 20 | 21 | embeddings = np.zeros((len(images), 1024)) 22 | device = pipe._execution_device 23 | noise_level = torch.tensor([0], device=device) 24 | torch.manual_seed(0) 25 | for i_image, image in tqdm(enumerate(images), total=len(images)): 26 | embedding = pipe._encode_image(image, device=device, batch_size=1, num_images_per_prompt=1, do_classifier_free_guidance=True,noise_level=noise_level,generator=None, image_embeds = None) 27 | embedding = embedding[1] 28 | embeddings[i_image] = embedding.detach().cpu().numpy()[:1024] 29 | 30 | os.makedirs('cache/thingseeg2_extracted_embeddings', exist_ok=True) 31 | np.save('cache/thingseeg2_extracted_embeddings/test_clip.npy', embeddings) 32 | 33 | # Load the train_images NumPy array 34 | images = np.load("data/thingseeg2_metadata/train_images.npy", mmap_mode='r') 35 | # Convert each image to a PIL image 36 | images = [Image.fromarray(image).convert("RGB") for image in images] 37 | 38 | embeddings = np.zeros((len(images), 1024)) 39 | device = pipe._execution_device 40 | noise_level = torch.tensor([0], device=device) 41 | torch.manual_seed(0) 42 | for i_image, image 
in tqdm(enumerate(images), total=len(images)): 43 | embedding = pipe._encode_image(image, device=device, batch_size=1, num_images_per_prompt=1, do_classifier_free_guidance=True,noise_level=noise_level,generator=None, image_embeds = None) 44 | embedding = embedding[1] 45 | embeddings[i_image] = embedding.detach().cpu().numpy()[:1024] 46 | 47 | os.makedirs('cache/thingseeg2_extracted_embeddings', exist_ok=True) 48 | np.save('cache/thingseeg2_extracted_embeddings/train_clip.npy', embeddings) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-clip_grayscale.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | from io import BytesIO 5 | import numpy as np 6 | from tqdm import tqdm 7 | import os 8 | 9 | from diffusers import StableUnCLIPImg2ImgPipeline 10 | 11 | pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( 12 | "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" 13 | ) 14 | pipe = pipe.to("cuda") 15 | 16 | # Load the test_images NumPy array 17 | images = np.load("data/thingseeg2_metadata/test_images_grayscale.npy", mmap_mode='r') 18 | # Convert each image to a PIL image 19 | images = [Image.fromarray(image).convert("RGB") for image in images] 20 | 21 | embeddings = np.zeros((len(images), 1024)) 22 | device = pipe._execution_device 23 | noise_level = torch.tensor([0], device=device) 24 | torch.manual_seed(0) 25 | for i_image, image in tqdm(enumerate(images), total=len(images)): 26 | embedding = pipe._encode_image(image, device=device, batch_size=1, num_images_per_prompt=1, do_classifier_free_guidance=True,noise_level=noise_level,generator=None, image_embeds = None) 27 | embedding = embedding[1] 28 | embeddings[i_image] = embedding.detach().cpu().numpy()[:1024] 29 | 30 | os.makedirs('cache/thingseeg2_extracted_embeddings', exist_ok=True) 31 | np.save('cache/thingseeg2_extracted_embeddings/test_clip_grayscale.npy', embeddings) 32 | 33 | # Load the train_images NumPy array 34 | images = np.load("data/thingseeg2_metadata/train_images_grayscale.npy", mmap_mode='r') 35 | # Convert each image to a PIL image 36 | images = [Image.fromarray(image).convert("RGB") for image in images] 37 | 38 | embeddings = np.zeros((len(images), 1024)) 39 | device = pipe._execution_device 40 | noise_level = torch.tensor([0], device=device) 41 | torch.manual_seed(0) 42 | for i_image, image in tqdm(enumerate(images), total=len(images)): 43 | embedding = pipe._encode_image(image, device=device, batch_size=1, num_images_per_prompt=1, do_classifier_free_guidance=True,noise_level=noise_level,generator=None, image_embeds = None) 44 | embedding = embedding[1] 45 | embeddings[i_image] = embedding.detach().cpu().numpy()[:1024] 46 | 47 | os.makedirs('cache/thingseeg2_extracted_embeddings', exist_ok=True) 48 | np.save('cache/thingseeg2_extracted_embeddings/train_clip_grayscale.npy', embeddings) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-cliptext.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('versatile_diffusion') 3 | import os 4 | import numpy as np 5 | 6 | import torch 7 | from lib.cfg_helper import model_cfg_bank 8 | from lib.model_zoo import get_model 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from lib.model_zoo.vd import VD 12 | from 
lib.cfg_holder import cfg_unique_holder as cfguh 13 | from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | from tqdm import tqdm 17 | 18 | # import argparse 19 | # parser = argparse.ArgumentParser(description='Argument Parser') 20 | # parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 21 | # args = parser.parse_args() 22 | # sub=int(args.sub) 23 | # assert sub in [1,2,5,7] 24 | 25 | cfgm_name = 'vd_noema' 26 | pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' 27 | cfgm = model_cfg_bank()(cfgm_name) 28 | net = get_model()(cfgm) 29 | sd = torch.load(pth, map_location='cpu') 30 | net.load_state_dict(sd, strict=False) 31 | 32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | net.clip = net.clip.to(device) 34 | 35 | train_caps = np.load('data/thingseeg2_metadata/train_concepts.npy', mmap_mode='r') 36 | test_caps = np.load('data/thingseeg2_metadata/test_concepts.npy', mmap_mode='r') 37 | print(train_caps.shape, test_caps.shape) 38 | 39 | num_embed, num_features, num_test, num_train = 77, 768, len(test_caps), len(train_caps) 40 | 41 | train_clip = np.zeros((num_train,num_embed, num_features)) 42 | test_clip = np.zeros((num_test,num_embed, num_features)) 43 | if not os.path.exists('cache/thingseeg2_extracted_embeddings'): 44 | os.makedirs('cache/thingseeg2_extracted_embeddings') 45 | with torch.no_grad(): 46 | for i,annots in tqdm(enumerate(test_caps), total=len(test_caps), desc='Extracting test CLIP text'): 47 | # cin = list(annots[annots!='']) 48 | cin = [annots] 49 | # print(i) 50 | # print(i, cin) 51 | c = net.clip_encode_text(cin) 52 | test_clip[i] = c.to('cpu').numpy().mean(0) 53 | 54 | np.save('cache/thingseeg2_extracted_embeddings/test_cliptext.npy',test_clip) 55 | 56 | for i,annots in tqdm(enumerate(train_caps), total=len(train_caps), desc='Extracting train CLIP text'): 57 | # cin = list(annots[annots!='']) 58 | cin = [annots] 59 | # print(i) 60 | # print(i, cin) 61 | c = net.clip_encode_text(cin) 62 | train_clip[i] = c.to('cpu').numpy().mean(0) 63 | np.save('cache/thingseeg2_extracted_embeddings/train_cliptext.npy',train_clip) 64 | 65 | 66 | -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-clipvision.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('versatile_diffusion') 3 | import os 4 | from PIL import Image 5 | import numpy as np 6 | 7 | import torch 8 | from lib.cfg_helper import model_cfg_bank 9 | from lib.model_zoo import get_model 10 | from lib.experiments.sd_default import color_adjust, auto_merge_imlist 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from lib.model_zoo.vd import VD 14 | from lib.cfg_holder import cfg_unique_holder as cfguh 15 | from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml 16 | import torchvision.transforms as T 17 | from tqdm import tqdm 18 | 19 | # import argparse 20 | # parser = argparse.ArgumentParser(description='Argument Parser') 21 | # parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 22 | # args = parser.parse_args() 23 | # sub=int(args.sub) 24 | # assert sub in [1,2,5,7] 25 | 26 | cfgm_name = 'vd_noema' 27 | 28 | pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' 29 | cfgm = model_cfg_bank()(cfgm_name) 30 | net = get_model()(cfgm) 31 | sd 
= torch.load(pth, map_location='cpu') 32 | net.load_state_dict(sd, strict=False) 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | net.clip = net.clip.to(device) 36 | 37 | class batch_generator_external_images(Dataset): 38 | 39 | def __init__(self, data_path): 40 | self.data_path = data_path 41 | self.im = np.load(data_path).astype(np.uint8) 42 | 43 | 44 | def __getitem__(self,idx): 45 | img = Image.fromarray(self.im[idx]) 46 | img = T.functional.resize(img,(512,512)) 47 | img = T.functional.to_tensor(img).float() 48 | #img = img/255 49 | img = img*2 - 1 50 | return img 51 | 52 | def __len__(self): 53 | return len(self.im) 54 | 55 | batch_size=1 56 | image_path = 'data/thingseeg2_metadata/train_images.npy' 57 | train_images = batch_generator_external_images(data_path = image_path) 58 | 59 | image_path = 'data/thingseeg2_metadata/test_images.npy' 60 | test_images = batch_generator_external_images(data_path = image_path) 61 | 62 | trainloader = DataLoader(train_images,batch_size,shuffle=False) 63 | testloader = DataLoader(test_images,batch_size,shuffle=False) 64 | 65 | num_embed, num_features, num_test, num_train = 257, 768, len(test_images), len(train_images) 66 | 67 | train_clip = np.zeros((num_train,num_embed,num_features)) 68 | test_clip = np.zeros((num_test,num_embed,num_features)) 69 | 70 | if not os.path.exists('cache/thingseeg2_extracted_embeddings'): 71 | os.makedirs('cache/thingseeg2_extracted_embeddings') 72 | with torch.no_grad(): 73 | for i,cin in tqdm(enumerate(testloader), total=len(testloader), desc='Extracting test CLIP vision'): 74 | # print(i) 75 | #ctemp = cin*2 - 1 76 | c = net.clip_encode_vision(cin) 77 | test_clip[i] = c[0].cpu().numpy() 78 | 79 | np.save('cache/thingseeg2_extracted_embeddings/test_clipvision.npy',test_clip) 80 | 81 | for i,cin in tqdm(enumerate(trainloader), total=len(trainloader), desc='Extracting train CLIP vision'): 82 | # print(i) 83 | #ctemp = cin*2 - 1 84 | c = net.clip_encode_vision(cin) 85 | train_clip[i] = c[0].cpu().numpy() 86 | np.save('cache/thingseeg2_extracted_embeddings/train_clipvision.npy',train_clip) 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-ica.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from skimage.transform import resize 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | ica = np.load("cache/ica.npz") 7 | encoder = ica["encoder"] 8 | decoder = ica["decoder"] 9 | train_mean = ica["mean"] 10 | 11 | print('loading data...') 12 | test_images = np.load("data/thingseeg2_metadata/test_images.npy", mmap_mode="r") 13 | test_images = np.array([resize(image, (64, 64)) for image in test_images]) 14 | train_images = np.load("data/thingseeg2_metadata/train_images.npy", mmap_mode="r") 15 | train_images = np.array([resize(image, (64, 64)) for image in train_images]) 16 | 17 | print('extracting features...') 18 | 19 | image_dim = 64 * 64 * 3 20 | test_data = np.zeros((len(test_images), image_dim)) 21 | for i, image in enumerate(test_images): 22 | test_data[i, :] = image.flatten() 23 | train_data = np.zeros((len(train_images), image_dim)) 24 | for i, image in enumerate(train_images): 25 | train_data[i, :] = image.flatten() 26 | 27 | latent_dim = 1000 28 | test_latents = np.zeros((len(test_data), latent_dim)) 29 | for i, image in tqdm(enumerate(test_data)): 30 | test_latents[i] = encoder[:latent_dim] @ (image - 
train_mean).reshape(64, 64, 3).T.flatten() 31 | np.save("cache/thingseeg2_extracted_embeddings/test_ica1k.npy", test_latents) 32 | 33 | train_latents = np.zeros((len(train_data), latent_dim)) 34 | for i, image in tqdm(enumerate(train_data)): 35 | train_latents[i] = encoder[:latent_dim] @ (image - train_mean).reshape(64, 64, 3).T.flatten() 36 | np.save("cache/thingseeg2_extracted_embeddings/train_ica1k.npy", train_latents) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/extract_features-pca.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from skimage.transform import resize 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | pca = np.load("cache/pca.npz") 7 | eigenvectors = pca["eigenvectors"] 8 | eigenvalues = pca["eigenvalues"] 9 | 10 | print('loading data...') 11 | test_images = np.load("data/thingseeg2_metadata/test_images.npy", mmap_mode="r") 12 | test_images = np.array([resize(image, (64, 64)) for image in test_images]) 13 | train_images = np.load("data/thingseeg2_metadata/train_images.npy", mmap_mode="r") 14 | train_images = np.array([resize(image, (64, 64)) for image in train_images]) 15 | 16 | print('extracting features...') 17 | 18 | image_dim = 64 * 64 * 3 19 | test_data = np.zeros((len(test_images), image_dim)) 20 | for i, image in enumerate(test_images): 21 | test_data[i, :] = image.flatten() 22 | train_data = np.zeros((len(train_images), image_dim)) 23 | for i, image in enumerate(train_images): 24 | train_data[i, :] = image.flatten() 25 | 26 | latent_dim = 1000 27 | test_latents = np.zeros((len(test_data), latent_dim)) 28 | for i, image in tqdm(enumerate(test_data)): 29 | test_latents[i] = eigenvectors[:latent_dim] @ image.reshape(64, 64, 3).T.flatten() 30 | np.save("cache/thingseeg2_extracted_embeddings/test_pca1k.npy", test_latents) 31 | 32 | train_latents = np.zeros((len(train_data), latent_dim)) 33 | for i, image in tqdm(enumerate(train_data)): 34 | train_latents[i] = eigenvectors[:latent_dim] @ image.reshape(64, 64, 3).T.flatten() 35 | np.save("cache/thingseeg2_extracted_embeddings/train_pca1k.npy", train_latents) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/grayscale_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from tqdm import tqdm 4 | 5 | test_images = np.load(f'data/thingseeg2_metadata/test_images.npy').astype(np.uint8) 6 | train_images = np.load(f'data/thingseeg2_metadata/train_images.npy').astype(np.uint8) 7 | 8 | print(train_images.shape, test_images.shape) 9 | 10 | # Initialize a list to store the processed images 11 | processed_images = [] 12 | 13 | # Loop through each image 14 | for img in tqdm(train_images, total=len(train_images)): 15 | # Convert the image to grayscale 16 | grayscale_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 17 | 18 | # Convert the grayscale image to a 3-channel image 19 | grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB) 20 | 21 | # Append the processed image to the list 22 | processed_images.append(grayscale_img) 23 | 24 | # Convert the list back to a NumPy array 25 | processed_images = np.array(processed_images) 26 | 27 | # Save the processed images 28 | np.save('data/thingseeg2_metadata/train_images_grayscale.npy', processed_images) 29 | 30 | # Initialize a list to store the processed images 31 | processed_images = [] 32 | 33 | # 
Loop through each image 34 | for img in tqdm(test_images, total=len(test_images)): 35 | # Convert the image to grayscale 36 | grayscale_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 37 | 38 | # Convert the grayscale image to a 3-channel image 39 | grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB) 40 | 41 | # Append the processed image to the list 42 | processed_images.append(grayscale_img) 43 | 44 | # Convert the list back to a NumPy array 45 | processed_images = np.array(processed_images) 46 | 47 | # Save the processed images 48 | np.save('data/thingseeg2_metadata/test_images_grayscale.npy', processed_images) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/ica.py: -------------------------------------------------------------------------------- 1 | from sklearn.decomposition import FastICA 2 | import numpy as np 3 | import time, os 4 | from tqdm import tqdm 5 | 6 | pca = np.load("cache/pca.npz") 7 | eigenvectors = pca["eigenvectors"].T 8 | eigenvalues = pca["eigenvalues"] 9 | D = np.diag(1.0 / np.sqrt(eigenvalues[:1000])) 10 | whitening_matrix = D @ eigenvectors[:, :1000].T 11 | 12 | # constants 13 | num_train_samples = 1281167 14 | num_val_samples = 50000 15 | num_test_samples = 16740 16 | image_dim = 64 * 64 * 3 17 | num_components = 1000 18 | 19 | # reshape and rotate 90 degrees counter-clockwise 20 | def format(image): 21 | return np.rot90(np.reshape(image, (64, 64, 3), order="F"), k=-1) 22 | 23 | # load images 24 | start_time = time.time() # Start timing 25 | train_data = np.zeros((num_train_samples, image_dim)) 26 | for i in tqdm(range(9)): 27 | batch = np.load(f"data/imagenet64/train_data_batch_{i + 1}.npz") 28 | train_data[i * 128116:(i + 1) * 128116, :] = batch["data"] / 255 29 | batch = np.load("data/imagenet64/train_data_batch_10.npz") 30 | train_data[9 * 128116:, :] = batch["data"] / 255 31 | end_time = time.time() # End timing 32 | print(f"Time taken to load images: {end_time - start_time} seconds") # 190 seconds 33 | 34 | train_mean = np.mean(train_data, axis=0) 35 | np.save("data/imagenet64/train_mean.npy", train_mean) 36 | 37 | train_mean = np.load("data/imagenet64/train_mean.npy") 38 | # load images 39 | start_time = time.time() # Start timing 40 | for i in tqdm(range(9)): 41 | batch = np.load(f"data/imagenet64/train_data_batch_{i + 1}.npz") 42 | # pre-whiten to save memory 43 | whitened_batch = (batch["data"] / 255 - train_mean) @ whitening_matrix.T 44 | np.save(f"data/imagenet64/whitened_batch_{i + 1}.npy", whitened_batch) 45 | batch = np.load("data/imagenet64/train_data_batch_10.npz") 46 | whitened_batch = (batch["data"] / 255 - train_mean) @ whitening_matrix.T 47 | np.save(f"data/imagenet64/whitened_batch_{10}.npy", whitened_batch) 48 | end_time = time.time() # End timing 49 | print(f"Time taken to whiten images: {end_time - start_time} seconds") # 3304 seconds 50 | 51 | # loading whitened data 52 | start_time = time.time() # Start timing 53 | whitened_data = np.zeros((num_train_samples, num_components), dtype=np.float32) # saving memory by not using float64 54 | for i in range(9): 55 | print(i) 56 | whitened_data[i * 128116:(i + 1) * 128116, :] = np.load(f"data/imagenet64/whitened_batch_{i + 1}.npy") 57 | whitened_data[9 * 128116:, :] = np.load(f"data/imagenet64/whitened_batch_10.npy") 58 | end_time = time.time() # End timing 59 | print(f"Time taken to load whitened images: {end_time - start_time} seconds") # 38 seconds 60 | 61 | start_time = time.time() # Start timing 62 | ica1000 = 
FastICA(random_state=0, whiten=False, tol=0.001) 63 | # ica = FastICA(n_components=1000, random_state=0, whiten=False, tol=0.001) 64 | ica1000.fit(whitened_data) 65 | np.save("cache/ica1k_components.npy", ica1000.components_) 66 | end_time = time.time() # End timing 67 | print(f"Time taken to compute ICA: {end_time - start_time} seconds") # 3707 seconds 68 | 69 | start_time = time.time() # Start timing 70 | components_1000 = np.load("cache/ica1k_components.npy") 71 | encoder = components_1000 @ whitening_matrix 72 | decoder = np.linalg.pinv(encoder) 73 | np.savez("cache/ica.npz", encoder=encoder, decoder=decoder, mean=train_mean) 74 | end_time = time.time() # End timing 75 | print(f"Time taken to save encoder and decoder: {end_time - start_time} seconds") # 7 seconds -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/pca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time, os 3 | from tqdm import tqdm 4 | 5 | # constants 6 | num_train_samples = 1281167 7 | num_val_samples = 50000 8 | num_test_samples = 16740 9 | image_dim = 64 * 64 * 3 10 | 11 | # reshape and rotate 90 degrees counter-clockwise 12 | def format(image): 13 | return np.rot90(np.reshape(image, (64, 64, 3), order="F"), k=-1) 14 | 15 | # load images 16 | start_time = time.time() # Start timing 17 | train_data = np.zeros((num_train_samples, image_dim)) 18 | for i in tqdm(range(9)): 19 | batch = np.load(f"data/imagenet64/train_data_batch_{i + 1}.npz") 20 | # print(batch["data"].shape) 21 | train_data[i * 128116:(i + 1) * 128116, :] = batch["data"] / 255 22 | batch = np.load("data/imagenet64/train_data_batch_10.npz") 23 | train_data[9 * 128116:, :] = batch["data"] / 255 24 | val_data = np.load("data/imagenet64/val_data.npz")["data"] / 255 25 | end_time = time.time() # End timing 26 | print(f"Time taken to load images: {end_time - start_time} seconds") # 352 seconds 27 | 28 | start_time = time.time() # Start timing 29 | train_mean = np.mean(train_data, axis=0) 30 | np.save("data/imagenet64/train_mean.npy", train_mean) 31 | end_time = time.time() # End timing 32 | print(f"Time taken to compute mean: {end_time - start_time} seconds") # 45 seconds 33 | 34 | # subtract mean and save 35 | start_time = time.time() # Start timing 36 | train_mean = np.load("data/imagenet64/train_mean.npy") 37 | for i in tqdm(range(10)): 38 | batch = np.load(f"data/imagenet64/train_data_batch_{i + 1}.npz") 39 | # pre-whiten to save memory 40 | mean_subtracted_batch = (batch["data"] / 255 - train_mean) 41 | np.save(f"data/imagenet64/train_data_batch_mean_subtracted_{i + 1}.npy", mean_subtracted_batch) 42 | end_time = time.time() # End timing 43 | print(f"Time taken to subtract mean: {end_time - start_time} seconds") # 3304 seconds 44 | 45 | # load mean-subtracted images 46 | start_time = time.time() # Start timing 47 | train_data = np.zeros((num_train_samples, image_dim)) 48 | for i in tqdm(range(9)): 49 | train_data[i * 128116:(i + 1) * 128116, :] = np.load(f"data/imagenet64/train_data_batch_mean_subtracted_{i + 1}.npy") 50 | train_data[9 * 128116:, :] = np.load(f"data/imagenet64/train_data_batch_mean_subtracted_{10}.npy") 51 | end_time = time.time() # End timing 52 | print(f"Time taken to load mean-subtracted images: {end_time - start_time} seconds") # 163 seconds 53 | 54 | start_time = time.time() # Start timing 55 | cov = np.cov(train_data, rowvar=False) 56 | os.makedirs('cache/imagenet64', exist_ok=True) 57 | 
np.save("cache/imagenet64/cov.npy", cov) 58 | end_time = time.time() # End timing 59 | print(f"Time taken to compute covariance: {end_time - start_time} seconds") # 718 seconds 60 | # cov = np.load("cache/imagenet64/cov.npy") 61 | 62 | 63 | # compute PCA 64 | start_time = time.time() # Start timing 65 | eigenvalues, eigenvectors = np.linalg.eigh(cov) 66 | eigenvectors = np.fliplr(eigenvectors).T 67 | eigenvalues = np.flip(eigenvalues) 68 | np.savez('cache/pca.npz', eigenvectors=eigenvectors, eigenvalues=eigenvalues) 69 | end_time = time.time() # End timing 70 | print(f"Time taken to compute PCA: {end_time - start_time} seconds") # 117 seconds -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/prepare_all_subjects_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 5 3 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 10 4 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 20 5 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 30 6 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 40 7 | python scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py -avg 60 -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/prepare_thingseeg2_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import argparse 4 | parser = argparse.ArgumentParser(description='Argument Parser') 5 | parser.add_argument('-avg', '--average', help='Number of averages', default=80) 6 | args = parser.parse_args() 7 | average=int(args.average) 8 | if average != 80: 9 | param = f'{average}' 10 | else: 11 | param = '' 12 | 13 | for sub in range(1, 11): 14 | data_dir = f'data/thingseeg2_preproc/sub-{sub:02d}/' 15 | 16 | if average == 80: 17 | eeg_data_train = np.load(data_dir + 'preprocessed_eeg_training.npy', allow_pickle=True).item() 18 | print(f'\nTraining EEG data shape for sub-{sub:02d}:') 19 | print(eeg_data_train['preprocessed_eeg_data'].shape) 20 | print('(Training image conditions × Training EEG repetitions × EEG channels × ' 21 | 'EEG time points)') 22 | train_thingseeg2 = eeg_data_train['preprocessed_eeg_data'][:,:,:,20:] 23 | train_thingseeg2_avg = eeg_data_train['preprocessed_eeg_data'].mean(1)[:,:,20:] 24 | train_thingseeg2_avg_null = eeg_data_train['preprocessed_eeg_data'].mean(1)[:,:,:20] 25 | np.save(data_dir + 'train_thingseeg2.npy', train_thingseeg2) 26 | np.save(data_dir + 'train_thingseeg2_avg.npy', train_thingseeg2_avg) 27 | np.save(data_dir + 'train_thingseeg2_avg_null.npy', train_thingseeg2_avg_null) 28 | 29 | eeg_data_test = np.load(data_dir + 'preprocessed_eeg_test.npy', allow_pickle=True).item() 30 | print(f'\nTest EEG data shape for sub-{sub:02d}:') 31 | print(eeg_data_test['preprocessed_eeg_data'].shape) 32 | print('(Test image conditions × Test EEG repetitions × EEG channels × ' 33 | 'EEG time points)') 34 | test_thingseeg2 = eeg_data_test['preprocessed_eeg_data'][:,:,:,20:] 35 | test_thingseeg2_avg = eeg_data_test['preprocessed_eeg_data'][:,:average].mean(1)[:,:,20:] 36 | test_thingseeg2_avg_null = eeg_data_test['preprocessed_eeg_data'][:,:average].mean(1)[:,:,:20] 37 | np.save(data_dir + 'test_thingseeg2.npy', test_thingseeg2) 38 | np.save(data_dir + 
f'test_thingseeg2_avg{param}.npy', test_thingseeg2_avg) 39 | np.save(data_dir + f'test_thingseeg2_avg{param}_null.npy', test_thingseeg2_avg_null) 40 | -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/save_thingseeg2_concepts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | img_metadata = np.load('data/thingseeg2_metadata/image_metadata.npy',allow_pickle=True).item() 5 | n_train_img = len(img_metadata['train_img_concepts']) 6 | n_test_img = len(img_metadata['test_img_concepts']) 7 | 8 | train_concepts = [] 9 | test_concepts = [] 10 | 11 | for train_img_idx in tqdm(range(n_train_img), total=n_train_img, desc='Loading train images'): 12 | train_concepts.append(' '.join(img_metadata['train_img_concepts'][train_img_idx].split('_')[1:])) 13 | train_concepts = np.array(train_concepts) 14 | 15 | np.save('data/thingseeg2_metadata/train_concepts.npy', train_concepts) 16 | 17 | for test_img_idx in tqdm(range(n_test_img), total=n_test_img, desc='Loading test images'): 18 | test_concepts.append(' '.join(img_metadata['test_img_concepts'][test_img_idx].split('_')[1:])) 19 | test_concepts = np.array(test_concepts) 20 | 21 | np.save('data/thingseeg2_metadata/test_concepts.npy', test_concepts) -------------------------------------------------------------------------------- /scripts-thingseeg2_dataprep/save_thingseeg2_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | img_metadata = np.load('data/thingseeg2_metadata/image_metadata.npy',allow_pickle=True).item() 7 | n_train_img = len(img_metadata['train_img_concepts']) 8 | n_test_img = len(img_metadata['test_img_concepts']) 9 | 10 | train_img = np.zeros((n_train_img, 500, 500, 3), dtype=np.uint8) 11 | test_img = np.zeros((n_test_img, 500, 500, 3), dtype=np.uint8) 12 | 13 | for train_img_idx in tqdm(range(n_train_img), total=n_train_img, desc='Loading train images'): 14 | train_img_dir = os.path.join('data/thingseeg2_metadata', 'training_images', 15 | img_metadata['train_img_concepts'][train_img_idx], 16 | img_metadata['train_img_files'][train_img_idx]) 17 | train_img[train_img_idx] = np.array(Image.open(train_img_dir).convert('RGB')) 18 | 19 | np.save('data/thingseeg2_metadata/train_images.npy', train_img) 20 | 21 | for test_img_idx in tqdm(range(n_test_img), total=n_test_img, desc='Loading test images'): 22 | test_img_dir = os.path.join('data/thingseeg2_metadata', 'test_images', 23 | img_metadata['test_img_concepts'][test_img_idx], 24 | img_metadata['test_img_files'][test_img_idx]) 25 | test_img[test_img_idx] = np.array(Image.open(test_img_dir).convert('RGB')) 26 | 27 | np.save('data/thingseeg2_metadata/test_images.npy', test_img) 28 | 29 | test_images_dir = 'data/thingseeg2_metadata/test_images_direct/' 30 | 31 | if not os.path.exists(test_images_dir): 32 | os.makedirs(test_images_dir) 33 | for i in tqdm(range(len(test_img)), total=len(test_img), desc='Saving direct test images'): 34 | im = Image.fromarray(test_img[i].astype(np.uint8)) 35 | im.save('{}/{}.png'.format(test_images_dir,i)) -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/evaluate_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 2 3 4 3 | do 4 | echo "Evaluating ablation for 
subject $sub" 5 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-clipvision --no-cliptext 6 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-vdvae --no-clipvision 7 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-clipvision 8 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-vdvae --no-cliptext 9 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-cliptext 10 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub --no-vdvae 11 | done -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/evaluate_across_duration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 2 3 4 3 | do 4 | for duration in 0 5 10 20 40 60 80 5 | do 6 | echo "sub $sub duration $duration" 7 | python thingseeg2_scripts/evaluate_reconstruction.py -sub $sub -duration $duration 8 | done 9 | done 10 | 11 | -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/evaluate_across_size_num_avg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for size in 500 1000 2000 4000 6000 10000 14000 3 | # for size in 2000 4000 6000 10000 14000 4 | do 5 | for avg in 5 10 20 30 40 60 6 | do 7 | python thingseeg2_scripts/evaluate_reconstruction.py -sub 1 -size $size -avg $avg 8 | done 9 | python thingseeg2_scripts/evaluate_reconstruction.py -sub 1 -size $size 10 | done 11 | for avg in 5 10 20 30 40 60 12 | do 13 | python thingseeg2_scripts/evaluate_reconstruction.py -sub 1 -avg $avg 14 | done 15 | -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/fig_ablations.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import matplotlib.cm as cm 4 | 5 | subs = [1,2,3,4] 6 | ablations = ['vdvae', 'novdvae_noclipvision', 'noclipvision', 'novdvae_nocliptext', 'nocliptext', 'novdvae', 'versatile_diffusion'] 7 | performance = np.zeros((len(ablations), len(subs), 8)) 8 | for i, ablation in enumerate(ablations): 9 | for j, sub in enumerate(subs): 10 | if ablation == 'vdvae': 11 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_vdvae.npy') 12 | elif ablation == 'versatile_diffusion': 13 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion.npy') 14 | else: 15 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion_{ablation}.npy') 16 | performance_mean = np.mean(performance, axis=1) 17 | performance_std = np.std(performance, axis=1) 18 | 19 | # Set the style of the plot 20 | plt.style.use("ggplot") 21 | 22 | # Define the figure size 23 | plt.figure(figsize=(9, 6)) 24 | 25 | # Define the bar width 26 | bar_width = 0.12 27 | 28 | # Define the x values for the bar plots 29 | x = np.arange(len(performance_mean[0])) 30 | 31 | # Define the colors for each duration using matplotlib's colormap 32 | colors = cm.get_cmap('tab20c', 7) 33 | 34 | ablate = ['VDVAE only', 'CLIP-Text only', 'no CLIP-Vision', 'CLIP-Vision only', 'no CLIP-Text', 'no VDVAE', 'full model'] 35 | # Plot the bar plots with error bars 36 | for i in range(7): 37 | plt.bar(x + i*bar_width, performance_mean[i], color=colors(i), width=bar_width, label=f'{ablate[i]}', 
yerr=performance_std[i], capsize=4) 38 | 39 | # Define the labels for the x-axis 40 | labels = ['PixCorr↑', 'SSIM↑', 'Alex(2)↑', 'Alex(5)↑', 'Incep↑', 'CLIP↑', 'Eff↓', 'SwAV↓'] # replace with your actual labels 41 | 42 | # Set the current tick locations and labels of the x-axis 43 | plt.xticks(x + bar_width*3.5, labels, fontsize=12) # Adjusted to center the labels 44 | plt.yticks(fontsize=12) 45 | plt.ylim(0, 1) 46 | 47 | # Add a legend 48 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12) 49 | 50 | plt.title('Model Ablations', fontsize=16) 51 | plt.ylabel('Performance', fontsize=14) 52 | plt.xlabel('Metric', fontsize=14) 53 | 54 | # Display the plot 55 | plt.tight_layout() 56 | 57 | plt.savefig('results/thingseeg2_preproc/fig_ablations.png') 58 | plt.savefig('results/thingseeg2_preproc/fig_ablations.svg', format='svg') -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/fig_across_durations.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.cm as cm 3 | import numpy as np 4 | 5 | subs = [1,2,3,4] 6 | duration = [0, 5, 10, 20, 40, 60, 80] 7 | performance = np.zeros((len(duration), len(subs), 8)) 8 | for i, dur in enumerate(duration): 9 | for j, sub in enumerate(subs): 10 | if dur == 80: 11 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion.npy') 12 | else: 13 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion_16540avg_dur{dur}.npy') 14 | performance_mean = np.mean(performance, axis=1) 15 | performance_std = np.std(performance, axis=1) 16 | 17 | import matplotlib.cm as cm 18 | 19 | # Set the style of the plot 20 | plt.style.use("ggplot") 21 | 22 | # Define the figure size 23 | plt.figure(figsize=(9, 6)) 24 | 25 | # Define the bar width 26 | bar_width = 0.12 27 | 28 | # Define the x values for the bar plots 29 | x = np.arange(len(performance_mean[0])) 30 | 31 | # Define the colors for each duration using matplotlib's colormap 32 | colors = cm.get_cmap('tab20c', 7) 33 | 34 | # Plot the bar plots with error bars 35 | for i in range(7): 36 | plt.bar(x + i*bar_width, performance_mean[i], color=colors(i), width=bar_width, label=f'{duration[i]*10}ms', yerr=performance_std[i], capsize=4) 37 | 38 | # Define the labels for the x-axis 39 | labels = ['PixCorr↑', 'SSIM↑', 'Alex(2)↑', 'Alex(5)↑', 'Incep↑', 'CLIP↑', 'Eff↓', 'SwAV↓'] # replace with your actual labels 40 | 41 | # Set the current tick locations and labels of the x-axis 42 | plt.xticks(x + bar_width*3.5, labels, fontsize=12) # Adjusted to center the labels 43 | plt.yticks(fontsize=12) 44 | plt.ylim(0, 1) 45 | 46 | # Add a legend 47 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12) 48 | 49 | plt.title('Performance Across Durations', fontsize=16) 50 | plt.ylabel('Performance', fontsize=14) 51 | plt.xlabel('Metric', fontsize=14) 52 | 53 | # Display the plot 54 | plt.tight_layout() 55 | 56 | plt.savefig('results/thingseeg2_preproc/fig_across_duration.png') 57 | plt.savefig('results/thingseeg2_preproc/fig_across_duration.svg', format='svg') -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/fig_across_size_num_avg.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | sizes = [500, 1000, 2000, 4000, 
6000, 10000, 14000, 16540] 5 | avgs = [5, 10, 20, 30, 40, 60, 80] 6 | performance = np.zeros((len(sizes), len(avgs), 8)) 7 | sub = 1 8 | for i, size in enumerate(sizes): 9 | for j, avg in enumerate(avgs): 10 | if avg == 80 and size == 16540: 11 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion.npy') 12 | elif avg == 80: 13 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion_{size}avg.npy') 14 | else: 15 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion_{size}avg{avg}.npy') 16 | 17 | 18 | # Create a figure 19 | fig, ax = plt.subplots(figsize=(10*0.7, 8*0.7)) 20 | 21 | # Calculate the edges of the cells for the x and y coordinates 22 | avgs_edges = np.concatenate([[0], avgs]) 23 | 24 | # Calculate the edges for the sizes 25 | sizes_edges = np.concatenate([[0], sizes]) 26 | 27 | # Calculate the centers of the cells for the x and y coordinates 28 | avgs_centers = (avgs_edges[:-1] + avgs_edges[1:]) / 2 29 | sizes_centers = (sizes_edges[:-1] + sizes_edges[1:]) / 2 30 | 31 | # Create a meshgrid for the x and y coordinates 32 | X, Y = np.meshgrid(avgs_edges, sizes_edges) 33 | 34 | # Create the heatmap using pcolormesh 35 | c = ax.pcolormesh(X, Y, performance[:,:,5], cmap='viridis') 36 | 37 | # Create colorbar 38 | fig.colorbar(c, ax=ax, label='Color scale') 39 | 40 | # Set the labels for the x ticks 41 | ax.set_xticks(avgs_edges) 42 | # ax.set_xticklabels(avgs) 43 | ax.set_xlabel('Number of averaged test trials per image') 44 | 45 | # Set the labels for the y ticks 46 | ax.set_yticks(sizes_edges) 47 | ax.set_ylabel('Number of training images') 48 | 49 | # Set the title 50 | ax.set_title('CLIP Performance of Subject 1') 51 | 52 | # Loop over data dimensions and create text annotations. 
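# The last axis of `performance` follows the metric order used in the other figure scripts
# ('PixCorr', 'SSIM', 'Alex(2)', 'Alex(5)', 'Incep', 'CLIP', 'Eff', 'SwAV'), so index 5 in the
# pcolormesh above and in the annotations below selects the CLIP score shown in this heatmap.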
53 | for i in range(performance.shape[0]): 54 | for j in range(performance.shape[1]): 55 | ax.text(avgs_centers[j], sizes_centers[i], f'{performance[i, j, 5]:.2g}', 56 | ha="center", va="center", color="w", fontsize=7) 57 | 58 | plt.savefig('results/thingseeg2_preproc/fig_CLIP_across_size_num_avg.png') 59 | plt.savefig('results/thingseeg2_preproc/fig_CLIP_across_size_num_avg.svg', format='svg') -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/fig_performance.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import matplotlib.cm as cm 4 | 5 | subs = list(range(1,11)) 6 | seeds = [1, 2, 4, 5, 6, 7, 8] 7 | performance = np.zeros((len(seeds), len(subs), 8)) 8 | for i, seed in enumerate(seeds): 9 | for j, sub in enumerate(subs): 10 | performance[i, j] = np.load(f'results/thingseeg2_preproc/sub-{sub:02d}/performances_versatile_diffusion_seed{seed}.npy') 11 | performance_mean = np.mean(performance, axis=0) 12 | performance_std = np.std(performance, axis=0) 13 | 14 | 15 | # Set the style of the plot 16 | plt.style.use("ggplot") 17 | 18 | # Define the figure size 19 | plt.figure(figsize=(9, 6)) 20 | 21 | # Define the bar width 22 | bar_width = 0.08 23 | 24 | # Define the x values for the bar plots 25 | x = np.arange(len(performance_mean[0])) 26 | 27 | # Define the colors for each subject using matplotlib's colormap 28 | colors = cm.get_cmap('tab20c', 10) 29 | 30 | # Plot the bar plots with error bars 31 | for i in range(10): 32 | plt.bar(x + i*bar_width, performance_mean[i], color=colors(i), width=bar_width, label=f'subject {i+1}', yerr=performance_std[i], capsize=3) 33 | 34 | # Define the labels for the x-axis 35 | labels = ['PixCorr↑', 'SSIM↑', 'Alex(2)↑', 'Alex(5)↑', 'Incep↑', 'CLIP↑', 'Eff↓', 'SwAV↓'] # replace with your actual labels 36 | 37 | # Set the current tick locations and labels of the x-axis 38 | plt.xticks(x + bar_width*4.5, labels, fontsize=12) # Adjusted to center the labels 39 | plt.yticks(fontsize=12) 40 | plt.ylim(0, 1) 41 | 42 | # Add a legend 43 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12) 44 | 45 | plt.title('Performance Across Subjects', fontsize=16) 46 | plt.ylabel('Performance', fontsize=14) 47 | plt.xlabel('Metric', fontsize=14) 48 | 49 | # Display the plot 50 | plt.tight_layout() 51 | plt.savefig('results/thingseeg2_preproc/fig_performance.png') 52 | plt.savefig('results/thingseeg2_preproc/fig_performance.svg', format='svg') -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/plot_ablations_recon.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 3 | do 4 | echo "Plotting ablation for subject $sub" 5 | # python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-clipvision --no-cliptext 6 | python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-vdvae --no-clipvision 7 | python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-clipvision 8 | python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-vdvae --no-cliptext 9 | python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-cliptext 10 | python thingseeg2_scripts/plot_reconstructions.py -sub $sub --no-vdvae 11 | done -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/reconstruct_ablation.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 2 3 4 3 | do 4 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 0 -gpu2 0 --no-vdvae --no-clipvision & 5 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 1 -gpu2 1 --no-clipvision & 6 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 3 -gpu2 3 --no-vdvae --no-cliptext & 7 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 4 -gpu2 4 --no-cliptext & 8 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 5 -gpu2 5 --no-vdvae & 9 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 6 -gpu2 6 10 | wait 11 | done -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/reconstruct_across_duration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 2 3 4 3 | do 4 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 0 -gpu2 0 -duration 0 & 5 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 1 -gpu2 1 -duration 5 & 6 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 3 -gpu2 3 -duration 10 & 7 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 4 -gpu2 4 -duration 20 & 8 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 5 -gpu2 5 -duration 40 & 9 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 6 -gpu2 6 -duration 60 & 10 | python thingseeg2_scripts/reconstruct_from_embeddings.py -sub $sub -gpu1 7 -gpu2 7 -duration 80 & 11 | wait 12 | done -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/train_across_duration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for sub in 1 2 3 4 3 | do 4 | for duration in 0 5 10 20 40 60 80 5 | do 6 | echo "sub $sub duration $duration" 7 | python thingseeg2_scripts/train_regression.py -sub $sub -duration $duration 8 | done 9 | done 10 | 11 | -------------------------------------------------------------------------------- /scripts-thingseeg2_figures/train_all_subjects.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python thingseeg2_scripts/train_regression.py -sub 1 3 | python thingseeg2_scripts/train_regression.py -sub 2 4 | python thingseeg2_scripts/train_regression.py -sub 3 5 | python thingseeg2_scripts/train_regression.py -sub 4 6 | python thingseeg2_scripts/train_regression.py -sub 5 7 | python thingseeg2_scripts/train_regression.py -sub 6 8 | python thingseeg2_scripts/train_regression.py -sub 7 9 | python thingseeg2_scripts/train_regression.py -sub 8 10 | python thingseeg2_scripts/train_regression.py -sub 9 11 | python thingseeg2_scripts/train_regression.py -sub 10 -------------------------------------------------------------------------------- /scripts-thingseeg2_transfer_learning/average_thingseeg2_subjects.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | for sub in range(1, 11): 5 | other_train_thingseeg2_avg = [] 6 | other_subjects = [i for i in range(1, 11) if i != sub] 7 | print(other_subjects) 8 | for other_sub in other_subjects: 9 | data_dir = f'data/thingseeg2_preproc/sub-{other_sub:02d}/' 10 | 
train_thingseeg2_avg = np.load(data_dir + f'train_thingseeg2_avg.npy') 11 | other_train_thingseeg2_avg.append(train_thingseeg2_avg) 12 | other_train_thingseeg2_avg = np.stack(other_train_thingseeg2_avg) 13 | other_train_thingseeg2_avg = other_train_thingseeg2_avg.mean(0) 14 | data_dir = f'cache/thingseeg2_preproc/transfer/sub-{sub:02d}/' 15 | if not os.path.exists(data_dir): 16 | os.makedirs(data_dir) 17 | np.save(data_dir + f'train_thingseeg2_avg_other.npy', other_train_thingseeg2_avg) -------------------------------------------------------------------------------- /scripts-thingsmeg/avg1b_regression_prediction_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | with open('cache/regression_weights/BIGMEG1/thingsmeg_regress_autokl1b_weights_sub-BIGMEG1.pkl',"rb") as f: 5 | datadict = pickle.load(f) 6 | reg_w = datadict['weight'] 7 | reg_b = datadict['bias'] -------------------------------------------------------------------------------- /scripts-thingsmeg/cliptext_extract_features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('versatile_diffusion') 3 | import os 4 | import numpy as np 5 | 6 | import torch 7 | from lib.cfg_helper import model_cfg_bank 8 | from lib.model_zoo import get_model 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from lib.model_zoo.vd import VD 12 | from lib.cfg_holder import cfg_unique_holder as cfguh 13 | from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | 17 | import argparse 18 | parser = argparse.ArgumentParser(description='Argument Parser') 19 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 20 | args = parser.parse_args() 21 | sub=int(args.sub) 22 | assert sub in [1,2,5,7] 23 | 24 | cfgm_name = 'vd_noema' 25 | pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' 26 | cfgm = model_cfg_bank()(cfgm_name) 27 | net = get_model()(cfgm) 28 | sd = torch.load(pth, map_location='cpu') 29 | net.load_state_dict(sd, strict=False) 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | net.clip = net.clip.to(device) 33 | 34 | train_caps = np.load('cache/thingsmeg/processed_data/BIGMEG1/train_captions1b_sub-BIGMEG1.npy', mmap_mode='r') 35 | test_caps = np.load('cache/thingsmeg/processed_data/BIGMEG1/test_captions1b_sub-BIGMEG1.npy', mmap_mode='r') 36 | print(train_caps.shape, test_caps.shape) 37 | 38 | num_embed, num_features, num_test, num_train = 77, 768, len(test_caps), len(train_caps) 39 | 40 | train_clip = np.zeros((num_train,num_embed, num_features)) 41 | test_clip = np.zeros((num_test,num_embed, num_features)) 42 | with torch.no_grad(): 43 | for i,annots in enumerate(test_caps): 44 | # cin = list(annots[annots!='']) 45 | cin = [annots] 46 | # print(i) 47 | print(i, cin) 48 | c = net.clip_encode_text(cin) 49 | test_clip[i] = c.to('cpu').numpy().mean(0) 50 | 51 | np.save('cache/thingsmeg/extracted_embeddings/BIGMEG1/test_cliptext1b_sub-BIGMEG1.npy',test_clip) 52 | 53 | for i,annots in enumerate(train_caps): 54 | # cin = list(annots[annots!='']) 55 | cin = [annots] 56 | # print(i) 57 | print(i, cin) 58 | c = net.clip_encode_text(cin) 59 | train_clip[i] = c.to('cpu').numpy().mean(0) 60 | np.save('cache/thingsmeg/extracted_embeddings/BIGMEG1/train_cliptext1b_sub-BIGMEG1.npy',train_clip) 61 | 62 | 63 | 
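# Example invocation (assumed, for illustration): run from the repository root so that the
# relative 'versatile_diffusion' and 'cache/...' paths above resolve, e.g.
#   python scripts-thingsmeg/cliptext_extract_features.py -sub 1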
-------------------------------------------------------------------------------- /scripts-thingsmeg/clipvision_extract_features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('versatile_diffusion') 3 | import os 4 | import PIL 5 | from PIL import Image 6 | import numpy as np 7 | 8 | import torch 9 | from lib.cfg_helper import model_cfg_bank 10 | from lib.model_zoo import get_model 11 | from lib.experiments.sd_default import color_adjust, auto_merge_imlist 12 | from torch.utils.data import DataLoader, Dataset 13 | 14 | from lib.model_zoo.vd import VD 15 | from lib.cfg_holder import cfg_unique_holder as cfguh 16 | from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml 17 | import torchvision.transforms as T 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser(description='Argument Parser') 21 | parser.add_argument("-sub", "--sub",help="Subject Number",default=1) 22 | args = parser.parse_args() 23 | sub=int(args.sub) 24 | assert sub in [1,2,5,7] 25 | 26 | cfgm_name = 'vd_noema' 27 | 28 | pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth' 29 | cfgm = model_cfg_bank()(cfgm_name) 30 | net = get_model()(cfgm) 31 | sd = torch.load(pth, map_location='cpu') 32 | net.load_state_dict(sd, strict=False) 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | net.clip = net.clip.to(device) 36 | 37 | class batch_generator_external_images(Dataset): 38 | 39 | def __init__(self, data_path): 40 | self.data_path = data_path 41 | self.im = np.load(data_path).astype(np.uint8) 42 | 43 | 44 | def __getitem__(self,idx): 45 | img = Image.fromarray(self.im[idx]) 46 | img = T.functional.resize(img,(512,512)) 47 | img = T.functional.to_tensor(img).float() 48 | #img = img/255 49 | img = img*2 - 1 50 | return img 51 | 52 | def __len__(self): 53 | return len(self.im) 54 | 55 | batch_size=1 56 | image_path = 'cache/thingsmeg/processed_data/BIGMEG1/train_images1b_sub-BIGMEG1.npy' 57 | train_images = batch_generator_external_images(data_path = image_path) 58 | 59 | image_path = 'cache/thingsmeg/processed_data/BIGMEG1/test_images1b_sub-BIGMEG1.npy' 60 | test_images = batch_generator_external_images(data_path = image_path) 61 | 62 | trainloader = DataLoader(train_images,batch_size,shuffle=False) 63 | testloader = DataLoader(test_images,batch_size,shuffle=False) 64 | 65 | num_embed, num_features, num_test, num_train = 257, 768, len(test_images), len(train_images) 66 | 67 | train_clip = np.zeros((num_train,num_embed,num_features)) 68 | test_clip = np.zeros((num_test,num_embed,num_features)) 69 | 70 | with torch.no_grad(): 71 | for i,cin in enumerate(testloader): 72 | print(i) 73 | #ctemp = cin*2 - 1 74 | c = net.clip_encode_vision(cin) 75 | test_clip[i] = c[0].cpu().numpy() 76 | 77 | np.save('cache/thingsmeg/extracted_embeddings/BIGMEG1/test_clipvision1b_sub-BIGMEG1_temp.npy',test_clip) 78 | 79 | for i,cin in enumerate(trainloader): 80 | print(i) 81 | #ctemp = cin*2 - 1 82 | c = net.clip_encode_vision(cin) 83 | train_clip[i] = c[0].cpu().numpy() 84 | np.save('cache/thingsmeg/extracted_embeddings/BIGMEG1/train_clipvision1b_sub-BIGMEG1_temp.npy',train_clip) 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /scripts-thingsmeg/generate_captions1b.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | from PIL import Image 4 | import requests 5 | from 
lavis.models import load_model_and_preprocess
6 | import numpy as np
7 | from tqdm import tqdm
8 | 
9 | # %%
10 | # setup device to use
11 | device = torch.device("cuda:6") if torch.cuda.is_available() else "cpu"
12 | 
13 | # %%
14 | # we associate a model with its preprocessors to make it easier for inference.
15 | # model, vis_processors, _ = load_model_and_preprocess(
16 | #     name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device
17 | # )
18 | 
19 | # Other available models:
20 | # 
21 | model, vis_processors, _ = load_model_and_preprocess(
22 |     name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
23 | )
24 | # model, vis_processors, _ = load_model_and_preprocess(
25 | #     name="blip2_opt", model_type="pretrain_opt6.7b", is_eval=True, device=device
26 | # )
27 | # model, vis_processors, _ = load_model_and_preprocess(
28 | #     name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=device
29 | # )
30 | # model, vis_processors, _ = load_model_and_preprocess(
31 | #     name="blip2_opt", model_type="caption_coco_opt6.7b", is_eval=True, device=device
32 | # )
33 | # 
34 | # model, vis_processors, _ = load_model_and_preprocess(
35 | #     name="blip2_t5", model_type="pretrain_flant5xl", is_eval=True, device=device
36 | # )
37 | # 
38 | # model, vis_processors, _ = load_model_and_preprocess(
39 | #     name="blip2_t5", model_type="caption_coco_flant5xl", is_eval=True, device=device
40 | # )
41 | 
42 | vis_processors.keys()
43 | 
44 | # %% [markdown]
45 | # ## test set
46 | 
47 | # %%
48 | images = np.load('cache/thingsmeg/processed_data/BIGMEG1/test_images1b_sub-BIGMEG1.npy', mmap_mode='r')
49 | captions = []
50 | for i in tqdm(range(len(images)), total=len(images), desc="test captions"):
51 |     image_pil = Image.fromarray(images[i].astype(np.uint8))
52 |     image = vis_processors["eval"](image_pil).unsqueeze(0).to(device)
53 |     captions.append(model.generate({"image": image})[0])
54 | np.save('cache/thingsmeg/processed_data/BIGMEG1/test_captions1b_sub-BIGMEG1.npy', captions)
55 | 
56 | # %% [markdown]
57 | # ## train set
58 | 
59 | # %%
60 | images = np.load('cache/thingsmeg/processed_data/BIGMEG1/train_images1b_sub-BIGMEG1.npy', mmap_mode='r')
61 | captions = []
62 | for i in tqdm(range(len(images)), total=len(images), desc="train captions"):
63 |     image_pil = Image.fromarray(images[i].astype(np.uint8))
64 |     image = vis_processors["eval"](image_pil).unsqueeze(0).to(device)
65 |     captions.append(model.generate({"image": image})[0])
66 | np.save('cache/thingsmeg/processed_data/BIGMEG1/train_captions1b_sub-BIGMEG1.npy', captions)
67 | 
68 | 
69 | 
--------------------------------------------------------------------------------
/scripts-thingsmeg/save_test_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | 
5 | #The same for all subjects
6 | # images = np.load('data/processed_data/subj01/nsd_test_stim_sub1.npy')
7 | images = np.load('cache/processed_data/BIGMEG1/test_images_sub-BIGMEG1.npy', mmap_mode='r')
8 | 
9 | test_images_dir = 'cache/thingsmeg_stimuli/test_images/'
10 | 
11 | if not os.path.exists(test_images_dir):
12 |     os.makedirs(test_images_dir)
13 | for i in range(len(images)):
14 |     im = Image.fromarray(images[i].astype(np.uint8))
15 |     im.save('{}/{}.png'.format(test_images_dir,i))
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
/scripts-thingsmeg/save_test_images1b.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | 
5 | #The same for all subjects
6 | # images = np.load('data/processed_data/subj01/nsd_test_stim_sub1.npy')
7 | 
8 | # images = np.load('cache/processed_data/BIGMEG1/test_images1b_sub-BIGMEG1.npy', mmap_mode='r')
9 | # test_images_dir = 'cache/thingsmeg_stimuli/test_images1b/'
10 | 
11 | images = np.load('cache/thingsmeg/processed_data/BIGMEG1/test_avg_images1b_sub-BIGMEG1.npy', mmap_mode='r')
12 | test_images_dir = 'cache/thingsmeg_stimuli/avg_test_images1b/'
13 | 
14 | if not os.path.exists(test_images_dir):
15 |     os.makedirs(test_images_dir)
16 | for i in range(len(images)):
17 |     im = Image.fromarray(images[i].astype(np.uint8))
18 |     im.save('{}/{}.png'.format(test_images_dir,i))
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/scripts-thingsmeg/save_things_categories.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | # df = pd.read_csv('THINGS-images/things_concepts.tsv',sep = '\t')
5 | # print(df.head())
6 | 
7 | ## get text for each image
8 | 
9 | image_concept_index = pd.read_csv('THINGS-images/Metadata/Concept-specific/image_concept_index.csv', header=None).to_numpy()[:,0] - 1
10 | print(image_concept_index.shape)
11 | print(image_concept_index[:20])
12 | 
13 | words = pd.read_csv('THINGS-images/Metadata/Concept-specific/words.csv', header=None).to_numpy()[:,0]
14 | print(words.shape)
15 | print(words[:20])
16 | 
17 | text_np = np.array([words[id] for id in image_concept_index])
18 | print(text_np[:20])
19 | if not os.path.exists('data'):
20 |     os.makedirs('data')
21 | np.save('data/things_text_labels.npy', text_np)
22 | 
23 | 
24 | # ## extract clip text
25 | # import sys
26 | # sys.path.append('versatile_diffusion')
27 | # import torch
28 | # from lib.cfg_helper import model_cfg_bank
29 | # from lib.model_zoo import get_model
30 | # from torch.utils.data import DataLoader, Dataset
31 | 
32 | # from lib.model_zoo.vd import VD
33 | # from lib.cfg_holder import cfg_unique_holder as cfguh
34 | # from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml
35 | # import torchvision.transforms as T
36 | # from tqdm import tqdm
37 | 
38 | # cfgm_name = 'vd_noema'
39 | # pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
40 | # cfgm = model_cfg_bank()(cfgm_name)
41 | # net = get_model()(cfgm)
42 | # sd = torch.load(pth, map_location='cpu')
43 | # net.load_state_dict(sd, strict=False)
44 | 
45 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46 | # net.clip = net.clip.to(device)
47 | 
48 | # num_embed, num_features, num_texts = 77, 768, len(text_np)
49 | # clip = np.zeros((num_texts,num_embed,num_features))
50 | 
51 | # with torch.no_grad():
52 | #     for i,cin in tqdm(enumerate(text_np), total=len(text_np)):
53 | #         cin = [cin]
54 | #         # print(i)
55 | #         #ctemp = cin*2 - 1
56 | #         c = net.clip_encode_text(cin)
57 | #         clip[i] = c[0].cpu().numpy().mean(0)
58 | 
59 | # np.save('data/extracted_features/things_cliptext.npy',clip)
60 | 
--------------------------------------------------------------------------------
/scripts-thingsmeg/save_things_images.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from tqdm import tqdm
4 | import pandas as pd
5 | 
6 | # image = 
Image.open('THINGS-images/Images/aardvark/aardvark_01b.jpg') 7 | # image = image.resize((425, 425)) 8 | 9 | # np_image = np.array(image) 10 | # print(np_image.shape) 11 | 12 | # list all items in a folder 13 | import os 14 | # path = 'THINGS-images/Images/aardvark/' 15 | path = 'THINGS-images/Images/' 16 | # items = os.listdir(path) 17 | # print(items) 18 | 19 | # list all items in a folder with a specific extension 20 | # import glob 21 | # items = glob.glob(path + '*.jpg') 22 | # print(items) 23 | 24 | # list all items in subfolders of a folder, in alphabetical order 25 | import glob 26 | items = sorted(glob.glob(path + '**/*.jpg', recursive=True)) 27 | print(len(items)) 28 | items1 = [item[21:] for item in items] 29 | print(items1[:20]) 30 | 31 | # im_dim, im_c = 425, 3 32 | # stim_images = np.zeros((len(items), im_dim, im_dim, im_c)) 33 | # print(stim_images.shape) 34 | # for i, item in tqdm.tqdm(enumerate(items)): 35 | # image = Image.open(item) 36 | # image = image.resize((im_dim, im_dim)) 37 | # stim_images[i] = np.array(image) 38 | 39 | # np.save('data/processed_data/things_stim_images.npy', stim_images) 40 | # print("THINGS images are saved.") 41 | 42 | items = pd.read_csv('THINGS-images/Metadata/Image-specific/image_paths.csv', header=None).to_numpy() 43 | items2 = [item[0][7:] for item in items] 44 | print(items.shape) 45 | print(items2[:20]) 46 | print(np.setdiff1d(items1, items2)) 47 | 48 | im_dim, im_c = 425, 3 49 | stim_images = np.zeros((len(items), im_dim, im_dim, im_c)) 50 | print(stim_images.shape) 51 | for i, item in tqdm(enumerate(items), total=len(items)): 52 | item = item[0] 53 | image = Image.open(path + item[7:]) 54 | image = image.resize((im_dim, im_dim)) 55 | stim_images[i] = np.array(image) 56 | 57 | if not os.path.exists('data'): 58 | os.makedirs('data') 59 | np.save('data/things_stim_images.npy', stim_images) 60 | print("THINGS images are saved.") -------------------------------------------------------------------------------- /vdvae/LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright © 2020 OpenAI 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /vdvae/files_to_npy.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import imageio 4 | import glob 5 | import os 6 | 7 | if __name__ == "__main__": 8 | print("moving images in", sys.argv[1], "to", sys.argv[2]) 9 | files = glob.glob(os.path.join(sys.argv[1], "*.png")) 10 | shape = imageio.imread(files[0]).shape 11 | data = np.zeros(shape=(len(files), *shape), dtype=np.uint8) 12 | for idx, f in enumerate(files): 13 | data[idx] = imageio.imread(f) 14 | np.save(sys.argv[2], data) 15 | -------------------------------------------------------------------------------- /vdvae/header-image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/vdvae/header-image.png -------------------------------------------------------------------------------- /vdvae/image_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | # import IPython.display 3 | import PIL.Image 4 | import os 5 | from pprint import pformat 6 | import numpy as np 7 | 8 | def imgrid(imarray, cols=4, pad=1, padval=255, row_major=True): 9 | """Lays out a [N, H, W, C] image array as a single image grid.""" 10 | pad = int(pad) 11 | if pad < 0: 12 | raise ValueError('pad must be non-negative') 13 | cols = int(cols) 14 | assert cols >= 1 15 | N, H, W, C = imarray.shape 16 | rows = N // cols + int(N % cols != 0) 17 | batch_pad = rows * cols - N 18 | assert batch_pad >= 0 19 | post_pad = [batch_pad, pad, pad, 0] 20 | pad_arg = [[0, p] for p in post_pad] 21 | imarray = np.pad(imarray, pad_arg, 'constant', constant_values=padval) 22 | H += pad 23 | W += pad 24 | grid = (imarray 25 | .reshape(rows, cols, H, W, C) 26 | .transpose(0, 2, 1, 3, 4) 27 | .reshape(rows*H, cols*W, C)) 28 | if pad: 29 | grid = grid[:-pad, :-pad] 30 | return grid 31 | 32 | def interleave(*args): 33 | """Interleaves input arrays of the same shape along the batch axis.""" 34 | if not args: 35 | raise ValueError('At least one argument is required.') 36 | a0 = args[0] 37 | if any(a.shape != a0.shape for a in args): 38 | raise ValueError('All inputs must have the same shape.') 39 | if not a0.shape: 40 | raise ValueError('Inputs must have at least one axis.') 41 | out = np.transpose(args, [1, 0] + list(range(2, len(a0.shape) + 1))) 42 | out = out.reshape(-1, *a0.shape[1:]) 43 | return out 44 | 45 | # def imshow(a, format='png', jpeg_fallback=True): 46 | # """Displays an image in the given format.""" 47 | # a = a.astype(np.uint8) 48 | # data = io.BytesIO() 49 | # PIL.Image.fromarray(a).save(data, format) 50 | # im_data = data.getvalue() 51 | # try: 52 | # disp = IPython.display.display(IPython.display.Image(im_data)) 53 | # except IOError: 54 | # if jpeg_fallback and format != 'jpeg': 55 | # print ('Warning: image was too large to display in format "{}"; ' 56 | # 'trying jpeg instead.').format(format) 57 | # return imshow(a, format='jpeg') 58 | # else: 59 | # raise 60 | # return disp 61 | 62 | def image_to_uint8(x): 63 | """Converts [-1, 1] float array to [0, 255] uint8.""" 64 | x = np.asarray(x) 65 | x = (256. / 2.) * (x + 1.) 
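    # x has now been mapped from [-1, 1] to [0, 256]; the clip below keeps it in the valid
    # uint8 range before casting.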
66 | x = np.clip(x, 0, 255) 67 | x = x.astype(np.uint8) 68 | return x 69 | -------------------------------------------------------------------------------- /vdvae/model/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /vdvae/setup_cifar10.sh: -------------------------------------------------------------------------------- 1 | wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 2 | tar -xf cifar-10-python.tar.gz 3 | -------------------------------------------------------------------------------- /vdvae/setup_ffhq1024.sh: -------------------------------------------------------------------------------- 1 | # the first argument to this script should be the path to the ffhq_images1024x1024 folder 2 | # the same path should be provided as the `data_root` argument to train.py 3 | cd $1 4 | mkdir train 5 | mkdir train/0 6 | mkdir valid 7 | mkdir valid/0 8 | for i in $(seq -f "%05g" 0 64999); do 9 | mv $i.png train/0 10 | done 11 | for i in $(seq -f "%05g" 65000 69999); do 12 | mv $i.png valid/0 13 | done 14 | -------------------------------------------------------------------------------- /vdvae/setup_ffhq256.sh: -------------------------------------------------------------------------------- 1 | # we provide a copy of ffhq-256 for convenience, downsampled using the same function as NVAE (https://github.com/NVlabs/NVAE) (personal communication with author) 2 | 3 | # Resizing function is this one, with the default second argument and size=256 4 | # https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Resize 5 | 6 | # 5-bit precision is calculated using the following lines, with num_bits=5 and for an x in [0, 1] 7 | # x = torch.floor(x * 255 / 2 ** (8 - num_bits)) 8 | # x /= (2 ** num_bits - 1) 9 | 10 | # the DMOL loss should also be adjusted to have 32 buckets instead of 256 (this code, or NVAE, can be used as reference) 11 | 12 | wget https://openaipublic.blob.core.windows.net/very-deep-vaes-assets/vdvae-assets/ffhq-256.npy -------------------------------------------------------------------------------- /vdvae/setup_imagenet.sh: -------------------------------------------------------------------------------- 1 | if [ "$1" == "imagenet32" ]; then 2 | 3 | echo "downloading imagenet32" 4 | wget http://www.image-net.org/small/train_32x32.tar 5 | wget http://www.image-net.org/small/valid_32x32.tar 6 | tar -xvf train_32x32.tar 7 | tar -xvf valid_32x32.tar 8 | python files_to_npy.py train_32x32/ imagenet32-train.npy 9 | python files_to_npy.py valid_32x32/ imagenet32-valid.npy 10 | 11 | elif [ "$1" == "imagenet64" ]; then 12 | 13 | echo "downloading imagenet64" 14 | wget http://www.image-net.org/small/train_64x64.tar 15 | wget http://www.image-net.org/small/valid_64x64.tar 16 | tar -xvf train_64x64.tar 17 | tar -xvf valid_64x64.tar 18 | python files_to_npy.py train_64x64/ imagenet64-train.npy 19 | python files_to_npy.py valid_64x64/ imagenet64-valid.npy 20 | 21 | else 22 | 23 | echo "please pass the string imagenet32 or imagenet64 as an argument" 24 | 25 | fi 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /versatile_diffusion/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 SHI Labs 4 | 5 | Permission is hereby granted, 
free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/versatile_diffusion/README_extra.md:
--------------------------------------------------------------------------------
1 | ## Baselines and VD-DC
2 | 
3 | Some non-core models can be downloaded from this [link](https://drive.google.com/drive/folders/1SloRnOO9UnonfvubPWfw0uFpLco_2JvH?usp=sharing).
4 | It contains two baseline models, ```sd-v1-4.pth``` and ```sd-variation.pth```, and our VD-DC model ```vd-dc.pth```.
5 | 
6 | All models should be copied to the ```pretrained``` folder.
7 | 
8 | To evaluate the baseline experiments:
9 | 
10 | ```
11 | python main.py --config sd_eval --gpu 0 1 2 3 4 5 6 7 --eval 99999
12 | python main.py --config sd_variation_eval --gpu 0 1 2 3 4 5 6 7 --eval 99999
13 | ```
14 | 
15 | You will need to create ```./log/sd_nodataset/99999_eval``` for these baseline evaluations to run.
16 | 
17 | To evaluate the VD-DC experiments:
18 | 
19 | ```
20 | python main.py --config vd_dc_eval --gpu 0 1 2 3 4 5 6 7 --eval 99999
21 | ```
22 | 
23 | Similarly, you will need to create ```./log/vd_nodataset/99999_eval``` for the evaluation to run.
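For example, assuming the default ```log``` root used by these configs and that the commands are run from the ```versatile_diffusion``` directory, the expected folders can be created with:

```
mkdir -p ./log/sd_nodataset/99999_eval
mkdir -p ./log/vd_nodataset/99999_eval
```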
24 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/dataset/laion2b.yaml: -------------------------------------------------------------------------------- 1 | laion2b: 2 | symbol: laion2b 3 | type: laion2b 4 | root_dir: data/laion2b/ 5 | pin_memory: true 6 | 7 | laion2b_dummy: 8 | super_cfg: laion2b 9 | type: laion2b_dummy 10 | 11 | ############################### 12 | # webdataset based dataloader # 13 | ############################### 14 | 15 | laion2b_webdataset: 16 | super_cfg: laion2b 17 | type: laion2b_webdataset 18 | 19 | laion2b_webdataset_256: 20 | super_cfg: laion2b_webdataset 21 | scale: 256 22 | min_size: 224 23 | 24 | laion2b_webdataset_256_debug: 25 | super_cfg: laion2b_webdataset_256 26 | root_dir: data/laion2b-debug/ 27 | 28 | laion2b_webdataset_512: 29 | super_cfg: laion2b_webdataset 30 | scale: 512 31 | min_size: 448 32 | 33 | laion2b_webdataset_512_debug: 34 | super_cfg: laion2b_webdataset_512 35 | root_dir: data/laion2b-debug/ 36 | 37 | laion2b_webdataset_512_sdofficial: 38 | super_cfg: laion2b 39 | type: laion2b_webdataset_sdofficial 40 | scale: 512 41 | min_size: 448 42 | 43 | laion2b_webdataset_512_sdofficial_debug: 44 | super_cfg: laion2b_webdataset_512_sdofficial 45 | root_dir: data/laion2b-debug/ 46 | shuffle: 0 47 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/experiment/sd_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 20 8 | 9 | model: MODEL(sd_t2i_fullclip_backward_compatible) 10 | 11 | eval: 12 | main: lib.experiments.sd_default.eval 13 | stage: lib.experiments.sd_default.eval_stage 14 | dataset: null 15 | 16 | conditioning: 17 | - a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ 18 | - a beautiful grand nebula in the universe 19 | - area of rocks that deep inside the forest, divine domain 20 | - heavy arms gundam penguin mech 21 | - realistic scenery of houston texas city view under a starry sky in hyperrealistic style and ultra HD, 8k 22 | - red maple on a hill in golden autumn 23 | - man standing on the beach near sea 24 | - blue and yellow balloons in the sky 25 | 26 | replicate: 1 27 | 28 | sample: 29 | output_dim: [512, 512] 30 | n_samples: 4 31 | ddim_steps: 50 32 | ddim_eta: 0.0 33 | scale: 5.0 34 | 35 | batch_size_per_gpu: 0 36 | batch_size: null 37 | dataset_num_workers_per_gpu: 0 38 | dataset_num_workers: null 39 | 40 | evaluator: null 41 | log_every: null 42 | 43 | pretrained_pth: pretrained/sd-v1-4.pth 44 | strict_sd: true 45 | 46 | fix_seed: true 47 | eval_subdir: sd_v1_4 48 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/experiment/sd_variation_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 200 8 | 9 | model: MODEL(sd_variation) 10 | eval: 11 | main: lib.experiments.sd_default.eval 12 | stage: lib.experiments.sd_default.eval_stage_variation 13 | dataset: null 14 | save_code: true 15 | 16 | conditioning: 17 | - assets/benz.jpg 18 | - assets/ghibli.jpg 19 | - assets/horse.png 20 | - assets/matisse.jpg 21 | - assets/penguin.png 22 | - assets/scream.jpg 23 | - assets/space.jpg 
24 | - assets/vermeer.jpg 25 | - assets/boy_and_girl.jpg 26 | - assets/church.jpg 27 | - assets/firework.jpg 28 | - assets/house_by_lake.jpg 29 | - assets/night_light.jpg 30 | - assets/san_diego.jpg 31 | - assets/tiger.jpg 32 | - assets/train.jpg 33 | 34 | replicate: 1 35 | 36 | sample: 37 | output_dim: [512, 512] 38 | n_samples: 4 39 | ddim_steps: 50 40 | ddim_eta: 0.0 41 | scale: 7.5 42 | 43 | color_adj: true 44 | color_adj_keep_ratio: 0.5 45 | color_adj_simple: true 46 | 47 | batch_size_per_gpu: 0 48 | batch_size: null 49 | dataset_num_workers_per_gpu: 0 50 | dataset_num_workers: null 51 | 52 | pretrained_pth_ema: pretrained/sd-variation-ema.pth 53 | strict_sd: true 54 | 55 | is_lite: true 56 | fix_seed: true 57 | eval_subdir: sd_variation 58 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/experiment/vd_dc_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 200 8 | 9 | model: MODEL(vd_dc_noema) 10 | 11 | eval: 12 | main: lib.experiments.vd_default.eval 13 | stage: lib.experiments.vd_default.eval_stage_dc 14 | dataset: null 15 | save_code: true 16 | 17 | conditioning: 18 | - a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ 19 | # - a beautiful grand nebula in the universe 20 | # - area of rocks that deep inside the forest, divine domain 21 | # - heavy arms gundam penguin mech 22 | # - realistic scenery of houston texas city view under a starry sky in hyperrealistic style and ultra HD, 8k 23 | # - red maple on a hill in golden autumn 24 | # - man standing on the beach near sea 25 | # - blue and yellow balloons in the sky 26 | - assets/benz.jpg 27 | # - assets/ghibli.jpg 28 | # - assets/horse.png 29 | # - assets/matisse.jpg 30 | # - assets/penguin.png 31 | # - assets/scream.jpg 32 | # - assets/space.jpg 33 | # - assets/vermeer.jpg 34 | 35 | replicate: 1 36 | 37 | sample: 38 | output_dim: [512, 512] 39 | n_samples: 4 40 | ddim_steps: 50 41 | ddim_eta: 0.0 42 | scale: 7.5 43 | 44 | color_adj: true 45 | color_adj_keep_ratio: 0.5 46 | color_adj_simple: true 47 | 48 | batch_size_per_gpu: 0 49 | batch_size: null 50 | dataset_num_workers_per_gpu: 0 51 | dataset_num_workers: null 52 | 53 | pretrained_pth: pretrained/vd-dc.pth 54 | strict_sd: true 55 | 56 | fix_seed: true 57 | eval_subdir: vd_dc 58 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/experiment/vd_official_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 20 8 | 9 | model: MODEL(vd_noema) 10 | 11 | eval: 12 | main: lib.experiments.vd_default.eval 13 | stage: lib.experiments.vd_default.eval_stage 14 | 15 | dataset: null 16 | 17 | conditioning: 18 | - ["a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ", "image"] 19 | - ["a beautiful grand nebula in the universe", "image"] 20 | - ["area of rocks that deep inside the forest, divine domain", "image"] 21 | - ["heavy arms gundam penguin mech", "image"] 22 | - ["assets/boy_and_girl.jpg", "image"] 23 | - ["assets/house_by_lake.jpg", "image"] 24 | - ["assets/ghibli.jpg", "image"] 25 | - ["assets/scream.jpg", "image"] 26 | 27 | - ["a dream of a 
village in china, by Caspar David Friedrich, matte painting trending on artstation HQ", "text"] 28 | - ["a beautiful grand nebula in the universe", "text"] 29 | - ["area of rocks that deep inside the forest, divine domain", "text"] 30 | - ["heavy arms gundam penguin mech", "text"] 31 | - ["assets/boy_and_girl.jpg", "text"] 32 | - ["assets/house_by_lake.jpg", "text"] 33 | - ["assets/ghibli.jpg", "text"] 34 | - ["assets/scream.jpg", "text"] 35 | 36 | replicate: 1 37 | 38 | sample: 39 | image_output_dim: [512, 512] 40 | text_latent_dim: 768 41 | n_samples: 4 42 | ddim_steps: 50 43 | ddim_eta: 0.0 44 | scale: 7.5 45 | 46 | # Some useful post processing 47 | prompt_temperature: 1.0 48 | prompt_merge_same_adj_word: true 49 | color_adj: true 50 | color_adj_keep_ratio: 0.5 51 | color_adj_simple: true 52 | 53 | batch_size_per_gpu: 0 54 | batch_size: null 55 | dataset_num_workers_per_gpu: 0 56 | dataset_num_workers: null 57 | 58 | pretrained_pth: pretrained/vd-four-flow-v1-0.pth 59 | strict_sd: true 60 | 61 | is_lite: true 62 | fix_seed: true 63 | eval_subdir: vd-four-flow-v1-0-results 64 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/model/clip.yaml: -------------------------------------------------------------------------------- 1 | 2 | clip: 3 | symbol: clip 4 | args: {} 5 | 6 | clip_frozen: 7 | super_cfg: clip 8 | type: clip_frozen 9 | args: {} 10 | 11 | clip_text_frozen: 12 | super_cfg: clip 13 | type: clip_text_frozen 14 | args: {} 15 | 16 | clip_vision_frozen: 17 | super_cfg: clip 18 | type: clip_vision_frozen 19 | args: {} 20 | 21 | ############################ 22 | # clip with focused encode # 23 | ############################ 24 | 25 | clip_frozen_encode_text: 26 | super_cfg: clip 27 | type: clip_frozen 28 | args: 29 | encode_type : encode_text 30 | 31 | clip_frozen_encode_vision: 32 | super_cfg: clip 33 | type: clip_frozen 34 | args: 35 | encode_type : encode_vision 36 | 37 | clip_frozen_encode_text_noproj: 38 | super_cfg: clip 39 | type: clip_frozen 40 | args: 41 | encode_type : encode_text_noproj 42 | 43 | ##################################### 44 | # clip vision forzen justin version # 45 | ##################################### 46 | 47 | clip_vision_frozen_justin: 48 | super_cfg: clip 49 | type: clip_vision_frozen_justin 50 | args: {} 51 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/model/openai_unet.yaml: -------------------------------------------------------------------------------- 1 | openai_unet_sd: 2 | type: openai_unet 3 | args: 4 | image_size: null # no use 5 | in_channels: 4 6 | out_channels: 4 7 | model_channels: 320 8 | attention_resolutions: [ 4, 2, 1 ] 9 | num_res_blocks: [ 2, 2, 2, 2 ] 10 | channel_mult: [ 1, 2, 4, 4 ] 11 | # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true 12 | num_heads: 8 13 | use_spatial_transformer: True 14 | transformer_depth: 1 15 | context_dim: 768 16 | use_checkpoint: True 17 | legacy: False 18 | 19 | openai_unet_dual_context: 20 | super_cfg: openai_unet_sd 21 | type: openai_unet_dual_context 22 | 23 | ######################## 24 | # Code cleaned version # 25 | ######################## 26 | 27 | openai_unet_2d: 28 | type: openai_unet_2d 29 | args: 30 | input_channels: 4 31 | model_channels: 320 32 | output_channels: 4 33 | num_noattn_blocks: [ 2, 2, 2, 2 ] 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | with_attn: [true, true, true, false] 36 | num_heads: 
8 37 | context_dim: 768 38 | use_checkpoint: True 39 | 40 | openai_unet_0d: 41 | type: openai_unet_0d 42 | args: 43 | input_channels: 768 44 | model_channels: 320 45 | output_channels: 768 46 | num_noattn_blocks: [ 2, 2, 2, 2 ] 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | with_attn: [true, true, true, false] 49 | num_heads: 8 50 | context_dim: 768 51 | use_checkpoint: True 52 | 53 | openai_unet_0dmd: 54 | type: openai_unet_0dmd 55 | args: 56 | input_channels: 768 57 | model_channels: 320 58 | output_channels: 768 59 | num_noattn_blocks: [ 2, 2, 2, 2 ] 60 | channel_mult: [ 1, 2, 4, 4 ] 61 | second_dim: [ 4, 4, 4, 4 ] 62 | with_attn: [true, true, true, false] 63 | num_heads: 8 64 | context_dim: 768 65 | use_checkpoint: True 66 | 67 | openai_unet_vd: 68 | type: openai_unet_vd 69 | args: 70 | unet_image_cfg: MODEL(openai_unet_2d) 71 | unet_text_cfg: MODEL(openai_unet_0dmd) 72 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/model/optimus.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimus: 3 | symbol: optimus 4 | find_unused_parameters: false 5 | args: {} 6 | 7 | optimus_bert_encoder: 8 | super_cfg: optimus 9 | type: optimus_bert_connector 10 | # pth: pretrained/optimus_bert_encoder.pth 11 | args: 12 | config: 13 | architectures: 14 | - BertForMaskedLM 15 | attention_probs_dropout_prob: 0.1 16 | finetuning_task: null 17 | hidden_act: gelu 18 | hidden_dropout_prob: 0.1 19 | hidden_size: 768 20 | initializer_range: 0.02 21 | intermediate_size: 3072 22 | layer_norm_eps: 1.e-12 23 | max_position_embeddings: 512 24 | num_attention_heads: 12 25 | num_hidden_layers: 12 26 | num_labels: 2 27 | output_attentions: false 28 | output_hidden_states: false 29 | pruned_heads: {} 30 | torchscript: false 31 | type_vocab_size: 2 32 | vocab_size: 28996 33 | latent_size: 768 34 | 35 | optimus_bert_tokenizer: 36 | super_cfg: optimus 37 | type: optimus_bert_tokenizer 38 | args: 39 | do_lower_case: false 40 | max_len: 512 41 | vocab_file: versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt 42 | 43 | optimus_gpt2_decoder: 44 | super_cfg: optimus 45 | type: optimus_gpt2_connector 46 | # pth: pretrained/optimus_gpt2_decoder.pth 47 | args: 48 | config: 49 | architectures: 50 | - GPT2LMHeadModel 51 | attn_pdrop: 0.1 52 | embd_pdrop: 0.1 53 | finetuning_task: null 54 | hidden_size: 768 55 | initializer_range: 0.02 56 | latent_size: 768 57 | layer_norm_epsilon: 1.e-05 58 | max_position_embeddings: 1024 59 | n_ctx: 1024 60 | n_embd: 768 61 | n_head: 12 62 | n_layer: 12 63 | n_positions: 1024 64 | num_attention_heads: 12 65 | num_hidden_layers: 12 66 | num_labels: 1 67 | output_attentions: false 68 | output_hidden_states: false 69 | pretrained_config_archive_map: 70 | gpt2 : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json 71 | gpt2-medium : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json 72 | gpt2-large : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json 73 | pruned_heads: {} 74 | resid_pdrop: 0.1 75 | summary_activation: null 76 | summary_first_dropout: 0.1 77 | summary_proj_to_labels: true 78 | summary_type: cls_index 79 | summary_use_proj: true 80 | torchscript: false 81 | vocab_size: 50260 82 | 83 | optimus_gpt2_tokenizer: 84 | super_cfg: optimus 85 | type: optimus_gpt2_tokenizer 86 | args: 87 | do_lower_case: false 88 | max_len: 1024 89 | vocab_file: 
versatile_diffusion/lib/model_zoo/optimus_models/vocab/gpt2-vocab.json 90 | merges_file: versatile_diffusion/lib/model_zoo/optimus_models/vocab/gpt2-merges.txt 91 | 92 | optimus_vae: 93 | super_cfg: optimus 94 | type: optimus_vae 95 | pth: versatile_diffusion/pretrained/optimus-vae.pth 96 | args: 97 | encoder: MODEL(optimus_bert_encoder) 98 | decoder: MODEL(optimus_gpt2_decoder) 99 | tokenizer_encoder: MODEL(optimus_bert_tokenizer) 100 | tokenizer_decoder: MODEL(optimus_gpt2_tokenizer) 101 | args: 102 | latent_size: 768 103 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/model/sd.yaml: -------------------------------------------------------------------------------- 1 | sd_base: 2 | symbol: sd 3 | find_unused_parameters: true 4 | 5 | sd_autoencoder: 6 | type: autoencoderkl 7 | args: 8 | embed_dim: 4 9 | monitor: val/rec_loss 10 | ddconfig: 11 | double_z: true 12 | z_channels: 4 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 4, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [] 20 | dropout: 0.0 21 | lossconfig: 22 | target: torch.nn.Identity 23 | pth: versatile_diffusion/pretrained/kl-f8.pth 24 | 25 | sd_t2i: 26 | super_cfg: sd_base 27 | type: sd_t2i 28 | args: 29 | first_stage_config: MODEL(sd_autoencoder) 30 | cond_stage_config: MODEL(clip_text_frozen) 31 | unet_config: MODEL(openai_unet_sd) 32 | beta_linear_start: 0.00085 33 | beta_linear_end: 0.012 34 | num_timesteps_cond: 1 35 | timesteps: 1000 36 | scale_factor: 0.18215 37 | use_ema: true 38 | 39 | sd_t2i_noema: 40 | super_cfg: sd 41 | args: 42 | use_ema: false 43 | 44 | ##################### 45 | # sd with full clip # 46 | ##################### 47 | 48 | sd_t2i_fullclip_backward_compatible: 49 | super_cfg: sd_t2i 50 | args: 51 | cond_stage_config: MODEL(clip_frozen_encode_text_noproj) 52 | 53 | sd_t2i_fullclip_backward_compatible_noema: 54 | super_cfg: sd_t2i_noema 55 | args: 56 | cond_stage_config: MODEL(clip_frozen_encode_text_noproj) 57 | 58 | sd_t2i_fullclip: 59 | super_cfg: sd_t2i 60 | args: 61 | cond_stage_config: MODEL(clip_frozen_encode_text) 62 | 63 | sd_variation: 64 | super_cfg: sd_t2i 65 | type: sd_variation 66 | args: 67 | cond_stage_config: MODEL(clip_vision_frozen_justin) 68 | 69 | -------------------------------------------------------------------------------- /versatile_diffusion/configs/model/vd.yaml: -------------------------------------------------------------------------------- 1 | # vd_base: 2 | # symbol: vd 3 | # find_unused_parameters: true 4 | 5 | ############ 6 | # vd basic # 7 | ############ 8 | 9 | vd_basic: 10 | super_cfg: sd_t2i 11 | type: vd_basic 12 | symbol: vd 13 | find_unused_parameters: true 14 | args: 15 | cond_stage_config: MODEL(clip_frozen_encode_vision) 16 | 17 | vd_basic_noema: 18 | super_cfg: vd_basic 19 | args: 20 | use_ema: false 21 | 22 | ################### 23 | # vd dual-context # 24 | ################### 25 | 26 | vd_dc: 27 | super_cfg: sd_t2i_fullclip 28 | type: vd_dc 29 | symbol: vd 30 | find_unused_parameters: true 31 | args: 32 | unet_config: MODEL(openai_unet_dual_context) 33 | 34 | vd_dc_noema: 35 | super_cfg: vd_dc 36 | args: 37 | use_ema: false 38 | 39 | ###### 40 | # vd # 41 | ###### 42 | 43 | vd: 44 | type: vd 45 | symbol: vd 46 | find_unused_parameters: true 47 | args: 48 | autokl_cfg: MODEL(sd_autoencoder) 49 | optimus_cfg: MODEL(optimus_vae) 50 | clip_cfg: MODEL(clip_frozen) 51 | unet_config: MODEL(openai_unet_vd) 52 | beta_linear_start: 0.00085 53 | beta_linear_end: 
0.012 54 | timesteps: 1000 55 | scale_factor: 0.18215 56 | use_ema: true 57 | 58 | vd_noema: 59 | super_cfg: vd 60 | args: 61 | use_ema: false 62 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/versatile_diffusion/lib/__init__.py -------------------------------------------------------------------------------- /versatile_diffusion/lib/cfg_holder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def singleton(class_): 4 | instances = {} 5 | def getinstance(*args, **kwargs): 6 | if class_ not in instances: 7 | instances[class_] = class_(*args, **kwargs) 8 | return instances[class_] 9 | return getinstance 10 | 11 | ############## 12 | # cfg_holder # 13 | ############## 14 | 15 | @singleton 16 | class cfg_unique_holder(object): 17 | def __init__(self): 18 | self.cfg = None 19 | # this is use to track the main codes. 20 | self.code = set() 21 | def save_cfg(self, cfg): 22 | self.cfg = copy.deepcopy(cfg) 23 | def add_code(self, code): 24 | """ 25 | A new main code is reached and 26 | its name is added. 27 | """ 28 | self.code.add(code) 29 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/data_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.ds_base import collate, get_dataset 2 | from .common.ds_loader import get_loader 3 | from .common.ds_transform import get_transform 4 | from .common.ds_estimator import get_estimator 5 | from .common.ds_formatter import get_formatter 6 | from .common.ds_sampler import get_sampler 7 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/data_factory/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .ds_base import ds_base, collate, register as regdataset 2 | from .ds_loader import pre_loader_checkings, register as regloader 3 | from .ds_transform import TBase, have, register as regtrans 4 | from .ds_estimator import register as regestmat 5 | from .ds_formatter import register as regformat 6 | from .ds_sampler import register as regsampler 7 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/data_factory/common/ds_estimator.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import numpy.random as npr 4 | import PIL 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import xml.etree.ElementTree as ET 10 | import json 11 | import copy 12 | import math 13 | 14 | def singleton(class_): 15 | instances = {} 16 | def getinstance(*args, **kwargs): 17 | if class_ not in instances: 18 | instances[class_] = class_(*args, **kwargs) 19 | return instances[class_] 20 | return getinstance 21 | 22 | @singleton 23 | class get_estimator(object): 24 | def __init__(self): 25 | self.estimator = {} 26 | 27 | def register(self, estimf): 28 | self.estimator[estimf.__name__] = estimf 29 | 30 | def __call__(self, cfg): 31 | if cfg is None: 32 | return None 33 | t = cfg.type 34 | return self.estimator[t](**cfg.args) 35 | 36 | def register(): 37 | def wrapper(class_): 38 | 
get_estimator().register(class_) 39 | return class_ 40 | return wrapper 41 | 42 | @register() 43 | class PickFileEstimator(object): 44 | """ 45 | This is an estimator that filter load_info 46 | using the provided filelist 47 | """ 48 | def __init__(self, 49 | filelist = None, 50 | repeat_n = 1): 51 | """ 52 | Args: 53 | filelist: a list of string gives the name of images 54 | we would like to visualize, evaluate or train. 55 | repeat_n: int, times these images will be repeated 56 | """ 57 | self.filelist = filelist 58 | self.repeat_n = repeat_n 59 | 60 | def __call__(self, load_info): 61 | load_info_new = [] 62 | for info in load_info: 63 | if os.path.basename(info['image_path']).split('.')[0] in self.filelist: 64 | load_info_new.append(info) 65 | return load_info_new * self.repeat_n 66 | 67 | @register() 68 | class PickIndexEstimator(object): 69 | """ 70 | This is an estimator that filter load_info 71 | using the provided indices 72 | """ 73 | def __init__(self, 74 | indexlist = None, 75 | **kwargs): 76 | """ 77 | Args: 78 | indexlist: [] of int. 79 | the indices to be filtered out. 80 | """ 81 | self.indexlist = indexlist 82 | 83 | def __call__(self, load_info): 84 | load_info_new = [load_info[i] for i in self.indexlist] 85 | return load_info_new 86 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/data_factory/common/ds_formatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import cv2 7 | import scipy.ndimage 8 | from PIL import Image 9 | import copy 10 | import gc 11 | import itertools 12 | 13 | def singleton(class_): 14 | instances = {} 15 | def getinstance(*args, **kwargs): 16 | if class_ not in instances: 17 | instances[class_] = class_(*args, **kwargs) 18 | return instances[class_] 19 | return getinstance 20 | 21 | @singleton 22 | class get_formatter(object): 23 | def __init__(self): 24 | self.formatter = {} 25 | 26 | def register(self, formatf): 27 | self.formatter[formatf.__name__] = formatf 28 | 29 | def __call__(self, cfg): 30 | if cfg is None: 31 | return None 32 | t = cfg.type 33 | return self.formatter[t](**cfg.args) 34 | 35 | def register(): 36 | def wrapper(class_): 37 | get_formatter().register(class_) 38 | return class_ 39 | return wrapper 40 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/data_factory/common/ds_loader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import numpy.random as npr 4 | import PIL 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import xml.etree.ElementTree as ET 10 | import json 11 | import copy 12 | 13 | from ...cfg_holder import cfg_unique_holder as cfguh 14 | 15 | def singleton(class_): 16 | instances = {} 17 | def getinstance(*args, **kwargs): 18 | if class_ not in instances: 19 | instances[class_] = class_(*args, **kwargs) 20 | return instances[class_] 21 | return getinstance 22 | 23 | @singleton 24 | class get_loader(object): 25 | def __init__(self): 26 | self.loader = {} 27 | 28 | def register(self, loadf): 29 | self.loader[loadf.__name__] = loadf 30 | 31 | def __call__(self, cfg): 32 | if cfg is None: 33 | return None 34 | if isinstance(cfg, list): 35 | loader = [] 36 | for ci in cfg: 37 | t = ci.type 38 | loader.append(self.loader[t](**ci.args)) 39 | return 
compose(loader) 40 | t = cfg.type 41 | return self.loader[t](**cfg.args) 42 | 43 | class compose(object): 44 | def __init__(self, loaders): 45 | self.loaders = loaders 46 | 47 | def __call__(self, element): 48 | for l in self.loaders: 49 | element = l(element) 50 | return element 51 | 52 | def __getitem__(self, idx): 53 | return self.loaders[idx] 54 | 55 | def register(): 56 | def wrapper(class_): 57 | get_loader().register(class_) 58 | return class_ 59 | return wrapper 60 | 61 | def pre_loader_checkings(ltype): 62 | lpath = ltype+'_path' 63 | # cache feature added on 20201021 64 | lcache = ltype+'_cache' 65 | def wrapper(func): 66 | def inner(self, element): 67 | if lcache in element: 68 | # cache feature added on 20201021 69 | data = element[lcache] 70 | else: 71 | if ltype in element: 72 | raise ValueError 73 | if lpath not in element: 74 | raise ValueError 75 | 76 | if element[lpath] is None: 77 | data = None 78 | else: 79 | data = func(self, element[lpath], element) 80 | element[ltype] = data 81 | 82 | if ltype == 'image': 83 | if isinstance(data, np.ndarray): 84 | imsize = data.shape[-2:] 85 | elif isinstance(data, PIL.Image.Image): 86 | imsize = data.size[::-1] 87 | elif isinstance(data, torch.Tensor): 88 | imsize = [data.size(-2), data.size(-1)] 89 | elif data is None: 90 | imsize = None 91 | else: 92 | raise ValueError 93 | element['imsize'] = imsize 94 | element['imsize_current'] = copy.deepcopy(imsize) 95 | return element 96 | return inner 97 | return wrapper 98 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .eva_base import get_evaluator -------------------------------------------------------------------------------- /versatile_diffusion/lib/evaluator/eva_null.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import lpips 4 | 5 | from .. 
import nputils 6 | from ..log_service import print_log 7 | 8 | from .eva_base import base_evaluator, register 9 | 10 | @register('null') 11 | class null_evaluator(base_evaluator): 12 | def __init__(self, **dummy): 13 | super().__init__() 14 | 15 | def add_batch(self, 16 | **dummy): 17 | pass 18 | 19 | def compute(self): 20 | return None 21 | 22 | def one_line_summary(self): 23 | print_log('Evaluator null') 24 | 25 | def clear_data(self): 26 | pass -------------------------------------------------------------------------------- /versatile_diffusion/lib/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/versatile_diffusion/lib/experiments/__init__.py -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.get_model import get_model 2 | from .common.get_optimizer import get_optimizer 3 | from .common.get_scheduler import get_scheduler 4 | from .common.utils import get_unit 5 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/clip_justin/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import load 2 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/common/get_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import itertools 5 | 6 | def singleton(class_): 7 | instances = {} 8 | def getinstance(*args, **kwargs): 9 | if class_ not in instances: 10 | instances[class_] = class_(*args, **kwargs) 11 | return instances[class_] 12 | return getinstance 13 | 14 | class get_optimizer(object): 15 | def __init__(self): 16 | self.optimizer = {} 17 | self.register(optim.SGD, 'sgd') 18 | self.register(optim.Adam, 'adam') 19 | self.register(optim.AdamW, 'adamw') 20 | 21 | def register(self, optim, name): 22 | self.optimizer[name] = optim 23 | 24 | def __call__(self, net, cfg): 25 | if cfg is None: 26 | return None 27 | t = cfg.type 28 | if isinstance(net, (torch.nn.DataParallel, 29 | torch.nn.parallel.DistributedDataParallel)): 30 | netm = net.module 31 | else: 32 | netm = net 33 | pg = getattr(netm, 'parameter_group', None) 34 | 35 | if pg is not None: 36 | params = [] 37 | for group_name, module_or_para in pg.items(): 38 | if not isinstance(module_or_para, list): 39 | module_or_para = [module_or_para] 40 | 41 | grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para] 42 | grouped_params = itertools.chain(*grouped_params) 43 | pg_dict = {'params':grouped_params, 'name':group_name} 44 | params.append(pg_dict) 45 | else: 46 | params = net.parameters() 47 | return self.optimizer[t](params, lr=0, **cfg.args) 48 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class 
DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class LitEma(nn.Module): 5 | def __init__(self, model, decay=0.9999, use_num_updates=True): 6 | super().__init__() 7 | if decay < 0.0 or decay > 1.0: 8 | raise ValueError('Decay must be between 0 and 1') 9 | 10 | self.m_name2s_name = {} 11 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 12 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates 13 | else torch.tensor(-1,dtype=torch.int)) 14 | 15 | for name, p in model.named_parameters(): 16 | if p.requires_grad: 17 | #remove as '.'-character is not allowed in buffers 18 | s_name = name.replace('.','') 19 | self.m_name2s_name.update({name:s_name}) 20 | self.register_buffer(s_name,p.clone().detach().data) 21 | 22 | self.collected_params = [] 23 | 24 | def forward(self, model): 25 | decay = self.decay 26 | 27 | if self.num_updates >= 0: 28 | self.num_updates += 1 29 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 30 | 31 | one_minus_decay = 1.0 - decay 32 | 33 | with torch.no_grad(): 34 | m_param = dict(model.named_parameters()) 35 | shadow_params = dict(self.named_buffers()) 36 | 37 | for key in m_param: 38 | if m_param[key].requires_grad: 39 | sname = self.m_name2s_name[key] 40 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 41 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 42 | else: 43 | assert not key in self.m_name2s_name 44 | 45 | def copy_to(self, model): 46 | m_param = dict(model.named_parameters()) 47 | shadow_params = dict(self.named_buffers()) 48 | for key in m_param: 49 | if m_param[key].requires_grad: 50 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 51 | else: 52 | assert not key in self.m_name2s_name 53 | 54 | def store(self, parameters): 55 | """ 56 | Save the current parameters for restoring later. 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | temporarily stored. 60 | """ 61 | self.collected_params = [param.clone() for param in parameters] 62 | 63 | def restore(self, parameters): 64 | """ 65 | Restore the parameters stored with the `store` method. 66 | Useful to validate the model with EMA parameters without affecting the 67 | original optimization process. Store the parameters before the 68 | `copy_to` method. After validation (or model saving), use this to 69 | restore the former parameters. 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | updated with the stored parameters. 
73 | """ 74 | for c_param, param in zip(self.collected_params, parameters): 75 | param.data.copy_(c_param.data) 76 | -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert_vocab_download_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 3 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 4 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 5 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 6 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 7 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 8 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 9 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 10 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 11 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 12 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 13 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 14 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt" 15 | } -------------------------------------------------------------------------------- /versatile_diffusion/lib/model_zoo/optimus_models/vocab/gpt2_vocab_merge_download_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 3 | "merges_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 4 | } -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/dataset/laion2b.yaml: -------------------------------------------------------------------------------- 1 | laion2b: 2 | symbol: laion2b 3 | type: laion2b 4 | root_dir: data/laion2b/ 5 | pin_memory: true 6 | 7 | laion2b_dummy: 8 | super_cfg: laion2b 9 | type: laion2b_dummy 10 | 11 | ############################### 12 | # webdataset based dataloader # 13 | ############################### 14 | 15 | laion2b_webdataset: 16 | super_cfg: laion2b 17 | type: laion2b_webdataset 18 | 19 | laion2b_webdataset_256: 20 | super_cfg: laion2b_webdataset 21 | scale: 256 22 | min_size: 224 23 | 24 | laion2b_webdataset_256_debug: 25 | super_cfg: laion2b_webdataset_256 26 | root_dir: data/laion2b-debug/ 27 | 28 | laion2b_webdataset_512: 29 | super_cfg: laion2b_webdataset 30 | scale: 512 31 | min_size: 448 32 | 33 | laion2b_webdataset_512_debug: 34 | 
super_cfg: laion2b_webdataset_512 35 | root_dir: data/laion2b-debug/ 36 | 37 | laion2b_webdataset_512_sdofficial: 38 | super_cfg: laion2b 39 | type: laion2b_webdataset_sdofficial 40 | scale: 512 41 | min_size: 448 42 | 43 | laion2b_webdataset_512_sdofficial_debug: 44 | super_cfg: laion2b_webdataset_512_sdofficial 45 | root_dir: data/laion2b-debug/ 46 | shuffle: 0 47 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/experiment/sd_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 20 8 | 9 | model: MODEL(sd_t2i_fullclip_backward_compatible) 10 | 11 | eval: 12 | main: lib.experiments.sd_default.eval 13 | stage: lib.experiments.sd_default.eval_stage 14 | dataset: null 15 | 16 | conditioning: 17 | - a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ 18 | - a beautiful grand nebula in the universe 19 | - area of rocks that deep inside the forest, divine domain 20 | - heavy arms gundam penguin mech 21 | - realistic scenery of houston texas city view under a starry sky in hyperrealistic style and ultra HD, 8k 22 | - red maple on a hill in golden autumn 23 | - man standing on the beach near sea 24 | - blue and yellow balloons in the sky 25 | 26 | replicate: 1 27 | 28 | sample: 29 | output_dim: [512, 512] 30 | n_samples: 4 31 | ddim_steps: 50 32 | ddim_eta: 0.0 33 | scale: 5.0 34 | 35 | batch_size_per_gpu: 0 36 | batch_size: null 37 | dataset_num_workers_per_gpu: 0 38 | dataset_num_workers: null 39 | 40 | evaluator: null 41 | log_every: null 42 | 43 | pretrained_pth: pretrained/sd-v1-4.pth 44 | strict_sd: true 45 | 46 | fix_seed: true 47 | eval_subdir: sd_v1_4 48 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/experiment/sd_variation_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 200 8 | 9 | model: MODEL(sd_variation) 10 | eval: 11 | main: lib.experiments.sd_default.eval 12 | stage: lib.experiments.sd_default.eval_stage_variation 13 | dataset: null 14 | save_code: true 15 | 16 | conditioning: 17 | - assets/benz.jpg 18 | - assets/ghibli.jpg 19 | - assets/horse.png 20 | - assets/matisse.jpg 21 | - assets/penguin.png 22 | - assets/scream.jpg 23 | - assets/space.jpg 24 | - assets/vermeer.jpg 25 | - assets/boy_and_girl.jpg 26 | - assets/church.jpg 27 | - assets/firework.jpg 28 | - assets/house_by_lake.jpg 29 | - assets/night_light.jpg 30 | - assets/san_diego.jpg 31 | - assets/tiger.jpg 32 | - assets/train.jpg 33 | 34 | replicate: 1 35 | 36 | sample: 37 | output_dim: [512, 512] 38 | n_samples: 1 39 | ddim_steps: 50 40 | ddim_eta: 0.0 41 | scale: 7.5 42 | 43 | color_adj: true 44 | color_adj_keep_ratio: 0.5 45 | color_adj_simple: true 46 | 47 | batch_size_per_gpu: 0 48 | batch_size: null 49 | dataset_num_workers_per_gpu: 0 50 | dataset_num_workers: null 51 | 52 | pretrained_pth_ema: pretrained/sd-variation-ema.pth 53 | strict_sd: true 54 | 55 | is_lite: true 56 | fix_seed: true 57 | eval_subdir: sd_variation 58 | -------------------------------------------------------------------------------- 
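The dataset and experiment YAMLs above lean on two conventions used throughout these configs: `super_cfg`, which makes an entry inherit and override its parent, and `MODEL(...)`, which points at another model entry by name. The repository resolves both in its own config helpers (not included in this dump); the standalone sketch below, with a hypothetical `resolve_super_cfg` helper, only illustrates how a `super_cfg` chain flattens and is not the project's actual resolver.

```python
# Illustrative sketch only: the repository's cfg_helper (not shown here) does the
# real resolution. This hypothetical helper just demonstrates the `super_cfg`
# inheritance convention used by the YAML entries above.
import yaml

def resolve_super_cfg(name, registry):
    """Flatten a config entry by recursively merging its `super_cfg` parent."""
    entry = dict(registry[name])
    parent_name = entry.pop('super_cfg', None)
    if parent_name is None:
        return entry
    merged = resolve_super_cfg(parent_name, registry)
    # Child keys override parent keys; nested `args` dicts are merged shallowly.
    child_args = entry.pop('args', {}) or {}
    merged_args = dict(merged.get('args', {}) or {})
    merged_args.update(child_args)
    merged.update(entry)
    if merged_args:
        merged['args'] = merged_args
    return merged

if __name__ == '__main__':
    # Requires a checkout of the repository so the config file exists on disk.
    with open('versatile_diffusion/configs/model/sd.yaml') as f:
        cfgs = yaml.safe_load(f)
    # sd_t2i_fullclip inherits everything from sd_t2i and only swaps cond_stage_config.
    print(resolve_super_cfg('sd_t2i_fullclip', cfgs))
```

The `MODEL(...)` strings survive this flattening untouched; in the repository they are substituted later, when the named sub-config is itself built into a module.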
/versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/experiment/vd_dc_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 200 8 | 9 | model: MODEL(vd_dc_noema) 10 | 11 | eval: 12 | main: lib.experiments.vd_default.eval 13 | stage: lib.experiments.vd_default.eval_stage_dc 14 | dataset: null 15 | save_code: true 16 | 17 | conditioning: 18 | - a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ 19 | # - a beautiful grand nebula in the universe 20 | # - area of rocks that deep inside the forest, divine domain 21 | # - heavy arms gundam penguin mech 22 | # - realistic scenery of houston texas city view under a starry sky in hyperrealistic style and ultra HD, 8k 23 | # - red maple on a hill in golden autumn 24 | # - man standing on the beach near sea 25 | # - blue and yellow balloons in the sky 26 | - assets/benz.jpg 27 | # - assets/ghibli.jpg 28 | # - assets/horse.png 29 | # - assets/matisse.jpg 30 | # - assets/penguin.png 31 | # - assets/scream.jpg 32 | # - assets/space.jpg 33 | # - assets/vermeer.jpg 34 | 35 | replicate: 1 36 | 37 | sample: 38 | output_dim: [512, 512] 39 | n_samples: 4 40 | ddim_steps: 50 41 | ddim_eta: 0.0 42 | scale: 7.5 43 | 44 | color_adj: true 45 | color_adj_keep_ratio: 0.5 46 | color_adj_simple: true 47 | 48 | batch_size_per_gpu: 0 49 | batch_size: null 50 | dataset_num_workers_per_gpu: 0 51 | dataset_num_workers: null 52 | 53 | pretrained_pth: pretrained/vd-dc.pth 54 | strict_sd: true 55 | 56 | fix_seed: true 57 | eval_subdir: vd_dc 58 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/experiment/vd_official_eval.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | debug: false 3 | cuda: true 4 | dist_backend: nccl 5 | matplotlib_mode: agg 6 | log_root_dir: log 7 | rnd_seed: 20 8 | 9 | model: MODEL(vd_noema) 10 | 11 | eval: 12 | main: lib.experiments.vd_default.eval 13 | stage: lib.experiments.vd_default.eval_stage 14 | 15 | dataset: null 16 | 17 | conditioning: 18 | - ["a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ", "image"] 19 | - ["a beautiful grand nebula in the universe", "image"] 20 | - ["area of rocks that deep inside the forest, divine domain", "image"] 21 | - ["heavy arms gundam penguin mech", "image"] 22 | - ["assets/boy_and_girl.jpg", "image"] 23 | - ["assets/house_by_lake.jpg", "image"] 24 | - ["assets/ghibli.jpg", "image"] 25 | - ["assets/scream.jpg", "image"] 26 | 27 | - ["a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ", "text"] 28 | - ["a beautiful grand nebula in the universe", "text"] 29 | - ["area of rocks that deep inside the forest, divine domain", "text"] 30 | - ["heavy arms gundam penguin mech", "text"] 31 | - ["assets/boy_and_girl.jpg", "text"] 32 | - ["assets/house_by_lake.jpg", "text"] 33 | - ["assets/ghibli.jpg", "text"] 34 | - ["assets/scream.jpg", "text"] 35 | 36 | replicate: 1 37 | 38 | sample: 39 | image_output_dim: [512, 512] 40 | text_latent_dim: 768 41 | n_samples: 4 42 | ddim_steps: 50 43 | ddim_eta: 0.0 44 | scale: 7.5 45 | 46 | # Some useful post processing 47 | prompt_temperature: 1.0 48 | prompt_merge_same_adj_word: true 49 | color_adj: 
true 50 | color_adj_keep_ratio: 0.5 51 | color_adj_simple: true 52 | 53 | batch_size_per_gpu: 0 54 | batch_size: null 55 | dataset_num_workers_per_gpu: 0 56 | dataset_num_workers: null 57 | 58 | pretrained_pth: pretrained/vd-official.pth 59 | strict_sd: true 60 | 61 | is_lite: true 62 | fix_seed: true 63 | eval_subdir: vd_official 64 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/clip.yaml: -------------------------------------------------------------------------------- 1 | 2 | clip: 3 | symbol: clip 4 | args: {} 5 | 6 | clip_frozen: 7 | super_cfg: clip 8 | type: clip_frozen 9 | args: {} 10 | 11 | clip_text_frozen: 12 | super_cfg: clip 13 | type: clip_text_frozen 14 | args: {} 15 | 16 | clip_vision_frozen: 17 | super_cfg: clip 18 | type: clip_vision_frozen 19 | args: {} 20 | 21 | ############################ 22 | # clip with focused encode # 23 | ############################ 24 | 25 | clip_frozen_encode_text: 26 | super_cfg: clip 27 | type: clip_frozen 28 | args: 29 | encode_type : encode_text 30 | 31 | clip_frozen_encode_vision: 32 | super_cfg: clip 33 | type: clip_frozen 34 | args: 35 | encode_type : encode_vision 36 | 37 | clip_frozen_encode_text_noproj: 38 | super_cfg: clip 39 | type: clip_frozen 40 | args: 41 | encode_type : encode_text_noproj 42 | 43 | ##################################### 44 | # clip vision forzen justin version # 45 | ##################################### 46 | 47 | clip_vision_frozen_justin: 48 | super_cfg: clip 49 | type: clip_vision_frozen_justin 50 | args: {} 51 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/openai_unet.yaml: -------------------------------------------------------------------------------- 1 | openai_unet_sd: 2 | type: openai_unet 3 | args: 4 | image_size: null # no use 5 | in_channels: 4 6 | out_channels: 4 7 | model_channels: 320 8 | attention_resolutions: [ 4, 2, 1 ] 9 | num_res_blocks: [ 2, 2, 2, 2 ] 10 | channel_mult: [ 1, 2, 4, 4 ] 11 | # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true 12 | num_heads: 8 13 | use_spatial_transformer: True 14 | transformer_depth: 1 15 | context_dim: 768 16 | use_checkpoint: True 17 | legacy: False 18 | 19 | openai_unet_dual_context: 20 | super_cfg: openai_unet_sd 21 | type: openai_unet_dual_context 22 | 23 | ######################## 24 | # Code cleaned version # 25 | ######################## 26 | 27 | openai_unet_2d: 28 | type: openai_unet_2d 29 | args: 30 | input_channels: 4 31 | model_channels: 320 32 | output_channels: 4 33 | num_noattn_blocks: [ 2, 2, 2, 2 ] 34 | channel_mult: [ 1, 2, 4, 4 ] 35 | with_attn: [true, true, true, false] 36 | num_heads: 8 37 | context_dim: 768 38 | use_checkpoint: True 39 | 40 | openai_unet_0d: 41 | type: openai_unet_0d 42 | args: 43 | input_channels: 768 44 | model_channels: 320 45 | output_channels: 768 46 | num_noattn_blocks: [ 2, 2, 2, 2 ] 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | with_attn: [true, true, true, false] 49 | num_heads: 8 50 | context_dim: 768 51 | use_checkpoint: True 52 | 53 | openai_unet_0dmd: 54 | type: openai_unet_0dmd 55 | args: 56 | input_channels: 768 57 | model_channels: 320 58 | output_channels: 768 59 | num_noattn_blocks: [ 2, 2, 2, 2 ] 60 | channel_mult: [ 1, 2, 4, 4 ] 61 | second_dim: [ 4, 4, 4, 4 ] 62 | with_attn: [true, true, true, false] 
63 | num_heads: 8 64 | context_dim: 768 65 | use_checkpoint: True 66 | 67 | openai_unet_vd: 68 | type: openai_unet_vd 69 | args: 70 | unet_image_cfg: MODEL(openai_unet_2d) 71 | unet_test_cfg: MODEL(openai_unet_0dmd) 72 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/optimus.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimus: 3 | symbol: optimus 4 | find_unused_parameters: false 5 | args: {} 6 | 7 | optimus_bert_encoder: 8 | super_cfg: optimus 9 | type: optimus_bert_connector 10 | # pth: pretrained/optimus_bert_encoder.pth 11 | args: 12 | config: 13 | architectures: 14 | - BertForMaskedLM 15 | attention_probs_dropout_prob: 0.1 16 | finetuning_task: null 17 | hidden_act: gelu 18 | hidden_dropout_prob: 0.1 19 | hidden_size: 768 20 | initializer_range: 0.02 21 | intermediate_size: 3072 22 | layer_norm_eps: 1.e-12 23 | max_position_embeddings: 512 24 | num_attention_heads: 12 25 | num_hidden_layers: 12 26 | num_labels: 2 27 | output_attentions: false 28 | output_hidden_states: false 29 | pruned_heads: {} 30 | torchscript: false 31 | type_vocab_size: 2 32 | vocab_size: 28996 33 | latent_size: 768 34 | 35 | optimus_bert_tokenizer: 36 | super_cfg: optimus 37 | type: optimus_bert_tokenizer 38 | args: 39 | do_lower_case: false 40 | max_len: 512 41 | vocab_file: lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt 42 | 43 | optimus_gpt2_decoder: 44 | super_cfg: optimus 45 | type: optimus_gpt2_connector 46 | # pth: pretrained/optimus_gpt2_decoder.pth 47 | args: 48 | config: 49 | architectures: 50 | - GPT2LMHeadModel 51 | attn_pdrop: 0.1 52 | embd_pdrop: 0.1 53 | finetuning_task: null 54 | hidden_size: 768 55 | initializer_range: 0.02 56 | latent_size: 768 57 | layer_norm_epsilon: 1.e-05 58 | max_position_embeddings: 1024 59 | n_ctx: 1024 60 | n_embd: 768 61 | n_head: 12 62 | n_layer: 12 63 | n_positions: 1024 64 | num_attention_heads: 12 65 | num_hidden_layers: 12 66 | num_labels: 1 67 | output_attentions: false 68 | output_hidden_states: false 69 | pretrained_config_archive_map: 70 | gpt2 : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json 71 | gpt2-medium : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json 72 | gpt2-large : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json 73 | pruned_heads: {} 74 | resid_pdrop: 0.1 75 | summary_activation: null 76 | summary_first_dropout: 0.1 77 | summary_proj_to_labels: true 78 | summary_type: cls_index 79 | summary_use_proj: true 80 | torchscript: false 81 | vocab_size: 50260 82 | 83 | optimus_gpt2_tokenizer: 84 | super_cfg: optimus 85 | type: optimus_gpt2_tokenizer 86 | args: 87 | do_lower_case: false 88 | max_len: 1024 89 | vocab_file: lib/model_zoo/optimus_models/vocab/gpt2-vocab.json 90 | merges_file: lib/model_zoo/optimus_models/vocab/gpt2-merges.txt 91 | 92 | optimus_vae: 93 | super_cfg: optimus 94 | type: optimus_vae 95 | pth: pretrained/optimus-vae.pth 96 | args: 97 | encoder: MODEL(optimus_bert_encoder) 98 | decoder: MODEL(optimus_gpt2_decoder) 99 | tokenizer_encoder: MODEL(optimus_bert_tokenizer) 100 | tokenizer_decoder: MODEL(optimus_gpt2_tokenizer) 101 | args: 102 | latent_size: 768 103 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/sd.yaml: 
-------------------------------------------------------------------------------- 1 | sd_base: 2 | symbol: sd 3 | find_unused_parameters: true 4 | 5 | sd_autoencoder: 6 | type: autoencoderkl 7 | args: 8 | embed_dim: 4 9 | monitor: val/rec_loss 10 | ddconfig: 11 | double_z: true 12 | z_channels: 4 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 4, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [] 20 | dropout: 0.0 21 | lossconfig: 22 | target: torch.nn.Identity 23 | pth: pretrained/kl-f8.pth 24 | 25 | sd_t2i: 26 | super_cfg: sd_base 27 | type: sd_t2i 28 | args: 29 | first_stage_config: MODEL(sd_autoencoder) 30 | cond_stage_config: MODEL(clip_text_frozen) 31 | unet_config: MODEL(openai_unet_sd) 32 | beta_linear_start: 0.00085 33 | beta_linear_end: 0.012 34 | num_timesteps_cond: 1 35 | timesteps: 1000 36 | scale_factor: 0.18215 37 | use_ema: true 38 | 39 | sd_t2i_noema: 40 | super_cfg: sd 41 | args: 42 | use_ema: false 43 | 44 | ##################### 45 | # sd with full clip # 46 | ##################### 47 | 48 | sd_t2i_fullclip_backward_compatible: 49 | super_cfg: sd_t2i 50 | args: 51 | cond_stage_config: MODEL(clip_frozen_encode_text_noproj) 52 | 53 | sd_t2i_fullclip_backward_compatible_noema: 54 | super_cfg: sd_t2i_noema 55 | args: 56 | cond_stage_config: MODEL(clip_frozen_encode_text_noproj) 57 | 58 | sd_t2i_fullclip: 59 | super_cfg: sd_t2i 60 | args: 61 | cond_stage_config: MODEL(clip_frozen_encode_text) 62 | 63 | sd_variation: 64 | super_cfg: sd_t2i 65 | type: sd_variation 66 | args: 67 | cond_stage_config: MODEL(clip_vision_frozen_justin) 68 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/vd.yaml: -------------------------------------------------------------------------------- 1 | # vd_base: 2 | # symbol: vd 3 | # find_unused_parameters: true 4 | 5 | ############ 6 | # vd basic # 7 | ############ 8 | 9 | vd_basic: 10 | super_cfg: sd_t2i 11 | type: vd_basic 12 | symbol: vd 13 | find_unused_parameters: true 14 | args: 15 | cond_stage_config: MODEL(clip_frozen_encode_vision) 16 | 17 | vd_basic_noema: 18 | super_cfg: vd_basic 19 | args: 20 | use_ema: false 21 | 22 | ################### 23 | # vd dual-context # 24 | ################### 25 | 26 | vd_dc: 27 | super_cfg: sd_t2i_fullclip 28 | type: vd_dc 29 | symbol: vd 30 | find_unused_parameters: true 31 | args: 32 | unet_config: MODEL(openai_unet_dual_context) 33 | 34 | vd_dc_noema: 35 | super_cfg: vd_dc 36 | args: 37 | use_ema: false 38 | 39 | ###### 40 | # vd # 41 | ###### 42 | 43 | vd: 44 | type: vd 45 | symbol: vd 46 | find_unused_parameters: true 47 | args: 48 | autokl_cfg: MODEL(sd_autoencoder) 49 | optimus_cfg: MODEL(optimus_vae) 50 | clip_cfg: MODEL(clip_frozen) 51 | unet_config: MODEL(openai_unet_vd) 52 | beta_linear_start: 0.00085 53 | beta_linear_end: 0.012 54 | timesteps: 1000 55 | scale_factor: 0.18215 56 | use_ema: true 57 | 58 | vd_noema: 59 | super_cfg: vd 60 | args: 61 | use_ema: false 62 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/__init__.py 
-------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/cfg_holder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def singleton(class_): 4 | instances = {} 5 | def getinstance(*args, **kwargs): 6 | if class_ not in instances: 7 | instances[class_] = class_(*args, **kwargs) 8 | return instances[class_] 9 | return getinstance 10 | 11 | ############## 12 | # cfg_holder # 13 | ############## 14 | 15 | @singleton 16 | class cfg_unique_holder(object): 17 | def __init__(self): 18 | self.cfg = None 19 | # this is use to track the main codes. 20 | self.code = set() 21 | def save_cfg(self, cfg): 22 | self.cfg = copy.deepcopy(cfg) 23 | def add_code(self, code): 24 | """ 25 | A new main code is reached and 26 | its name is added. 27 | """ 28 | self.code.add(code) 29 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.ds_base import collate, get_dataset 2 | from .common.ds_loader import get_loader 3 | from .common.ds_transform import get_transform 4 | from .common.ds_estimator import get_estimator 5 | from .common.ds_formatter import get_formatter 6 | from .common.ds_sampler import get_sampler 7 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .ds_base import ds_base, collate, register as regdataset 2 | from .ds_loader import pre_loader_checkings, register as regloader 3 | from .ds_transform import TBase, have, register as regtrans 4 | from .ds_estimator import register as regestmat 5 | from .ds_formatter import register as regformat 6 | from .ds_sampler import register as regsampler 7 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_estimator.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import numpy.random as npr 4 | import PIL 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import xml.etree.ElementTree as ET 10 | import json 11 | import copy 12 | import math 13 | 14 | def singleton(class_): 15 | instances = {} 16 | def getinstance(*args, **kwargs): 17 | if class_ not in instances: 18 | instances[class_] = class_(*args, **kwargs) 19 | return instances[class_] 20 | return getinstance 21 | 22 | @singleton 23 | class get_estimator(object): 24 | def __init__(self): 25 | self.estimator = {} 26 | 27 | def register(self, estimf): 28 | self.estimator[estimf.__name__] = estimf 29 | 30 | def __call__(self, cfg): 31 | if cfg is None: 32 | return None 33 | t = cfg.type 34 | return self.estimator[t](**cfg.args) 35 | 36 | def register(): 37 | def wrapper(class_): 38 | get_estimator().register(class_) 39 | return class_ 40 | return wrapper 41 | 42 | @register() 43 | class PickFileEstimator(object): 44 | """ 45 | This is an estimator that filter load_info 46 | using the provided filelist 47 | """ 48 | def __init__(self, 49 | filelist = None, 50 | repeat_n = 1): 51 | """ 
52 | Args: 53 | filelist: a list of string gives the name of images 54 | we would like to visualize, evaluate or train. 55 | repeat_n: int, times these images will be repeated 56 | """ 57 | self.filelist = filelist 58 | self.repeat_n = repeat_n 59 | 60 | def __call__(self, load_info): 61 | load_info_new = [] 62 | for info in load_info: 63 | if os.path.basename(info['image_path']).split('.')[0] in self.filelist: 64 | load_info_new.append(info) 65 | return load_info_new * self.repeat_n 66 | 67 | @register() 68 | class PickIndexEstimator(object): 69 | """ 70 | This is an estimator that filter load_info 71 | using the provided indices 72 | """ 73 | def __init__(self, 74 | indexlist = None, 75 | **kwargs): 76 | """ 77 | Args: 78 | indexlist: [] of int. 79 | the indices to be filtered out. 80 | """ 81 | self.indexlist = indexlist 82 | 83 | def __call__(self, load_info): 84 | load_info_new = [load_info[i] for i in self.indexlist] 85 | return load_info_new 86 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_formatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import cv2 7 | import scipy.ndimage 8 | from PIL import Image 9 | import copy 10 | import gc 11 | import itertools 12 | 13 | def singleton(class_): 14 | instances = {} 15 | def getinstance(*args, **kwargs): 16 | if class_ not in instances: 17 | instances[class_] = class_(*args, **kwargs) 18 | return instances[class_] 19 | return getinstance 20 | 21 | @singleton 22 | class get_formatter(object): 23 | def __init__(self): 24 | self.formatter = {} 25 | 26 | def register(self, formatf): 27 | self.formatter[formatf.__name__] = formatf 28 | 29 | def __call__(self, cfg): 30 | if cfg is None: 31 | return None 32 | t = cfg.type 33 | return self.formatter[t](**cfg.args) 34 | 35 | def register(): 36 | def wrapper(class_): 37 | get_formatter().register(class_) 38 | return class_ 39 | return wrapper 40 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_loader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import numpy.random as npr 4 | import PIL 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import xml.etree.ElementTree as ET 10 | import json 11 | import copy 12 | 13 | from ...cfg_holder import cfg_unique_holder as cfguh 14 | 15 | def singleton(class_): 16 | instances = {} 17 | def getinstance(*args, **kwargs): 18 | if class_ not in instances: 19 | instances[class_] = class_(*args, **kwargs) 20 | return instances[class_] 21 | return getinstance 22 | 23 | @singleton 24 | class get_loader(object): 25 | def __init__(self): 26 | self.loader = {} 27 | 28 | def register(self, loadf): 29 | self.loader[loadf.__name__] = loadf 30 | 31 | def __call__(self, cfg): 32 | if cfg is None: 33 | return None 34 | if isinstance(cfg, list): 35 | loader = [] 36 | for ci in cfg: 37 | t = ci.type 38 | loader.append(self.loader[t](**ci.args)) 39 | return compose(loader) 40 | t = cfg.type 41 | return self.loader[t](**cfg.args) 42 | 43 | class compose(object): 44 | def __init__(self, loaders): 45 | self.loaders = loaders 46 | 47 | def __call__(self, element): 48 | 
for l in self.loaders: 49 | element = l(element) 50 | return element 51 | 52 | def __getitem__(self, idx): 53 | return self.loaders[idx] 54 | 55 | def register(): 56 | def wrapper(class_): 57 | get_loader().register(class_) 58 | return class_ 59 | return wrapper 60 | 61 | def pre_loader_checkings(ltype): 62 | lpath = ltype+'_path' 63 | # cache feature added on 20201021 64 | lcache = ltype+'_cache' 65 | def wrapper(func): 66 | def inner(self, element): 67 | if lcache in element: 68 | # cache feature added on 20201021 69 | data = element[lcache] 70 | else: 71 | if ltype in element: 72 | raise ValueError 73 | if lpath not in element: 74 | raise ValueError 75 | 76 | if element[lpath] is None: 77 | data = None 78 | else: 79 | data = func(self, element[lpath], element) 80 | element[ltype] = data 81 | 82 | if ltype == 'image': 83 | if isinstance(data, np.ndarray): 84 | imsize = data.shape[-2:] 85 | elif isinstance(data, PIL.Image.Image): 86 | imsize = data.size[::-1] 87 | elif isinstance(data, torch.Tensor): 88 | imsize = [data.size(-2), data.size(-1)] 89 | elif data is None: 90 | imsize = None 91 | else: 92 | raise ValueError 93 | element['imsize'] = imsize 94 | element['imsize_current'] = copy.deepcopy(imsize) 95 | return element 96 | return inner 97 | return wrapper 98 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .eva_base import get_evaluator -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/eva_null.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import lpips 4 | 5 | from .. 
import nputils 6 | from ..log_service import print_log 7 | 8 | from .eva_base import base_evaluator, register 9 | 10 | @register('null') 11 | class null_evaluator(base_evaluator): 12 | def __init__(self, **dummy): 13 | super().__init__() 14 | 15 | def add_batch(self, 16 | **dummy): 17 | pass 18 | 19 | def compute(self): 20 | return None 21 | 22 | def one_line_summary(self): 23 | print_log('Evaluator null') 24 | 25 | def clear_data(self): 26 | pass -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desa-lab/Perceptogram/54db0981aa26d2b2f8b56e2d4ad04c81d1419d35/versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/__init__.py -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.get_model import get_model 2 | from .common.get_optimizer import get_optimizer 3 | from .common.get_scheduler import get_scheduler 4 | from .common.utils import get_unit 5 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import load 2 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import itertools 5 | 6 | def singleton(class_): 7 | instances = {} 8 | def getinstance(*args, **kwargs): 9 | if class_ not in instances: 10 | instances[class_] = class_(*args, **kwargs) 11 | return instances[class_] 12 | return getinstance 13 | 14 | class get_optimizer(object): 15 | def __init__(self): 16 | self.optimizer = {} 17 | self.register(optim.SGD, 'sgd') 18 | self.register(optim.Adam, 'adam') 19 | self.register(optim.AdamW, 'adamw') 20 | 21 | def register(self, optim, name): 22 | self.optimizer[name] = optim 23 | 24 | def __call__(self, net, cfg): 25 | if cfg is None: 26 | return None 27 | t = cfg.type 28 | if isinstance(net, (torch.nn.DataParallel, 29 | torch.nn.parallel.DistributedDataParallel)): 30 | netm = net.module 31 | else: 32 | netm = net 33 | pg = getattr(netm, 'parameter_group', None) 34 | 35 | if pg is not None: 36 | params = [] 37 | for group_name, module_or_para in pg.items(): 38 | if not isinstance(module_or_para, list): 39 | module_or_para = [module_or_para] 40 | 41 | grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para] 42 | grouped_params = itertools.chain(*grouped_params) 43 | pg_dict = {'params':grouped_params, 'name':group_name} 44 | params.append(pg_dict) 45 | else: 46 | params = net.parameters() 47 | return self.optimizer[t](params, lr=0, **cfg.args) 48 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/distributions.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class LitEma(nn.Module): 5 | def __init__(self, model, decay=0.9999, use_num_updates=True): 6 | super().__init__() 7 | if decay < 0.0 or decay > 1.0: 8 | raise ValueError('Decay must be between 0 and 1') 9 | 10 | self.m_name2s_name = {} 11 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 12 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates 13 | else torch.tensor(-1,dtype=torch.int)) 14 | 15 | for name, p in model.named_parameters(): 16 | if p.requires_grad: 17 | #remove as '.'-character is not allowed in buffers 18 | s_name = name.replace('.','') 19 | self.m_name2s_name.update({name:s_name}) 20 | self.register_buffer(s_name,p.clone().detach().data) 21 | 22 | self.collected_params = [] 23 | 24 | def forward(self, model): 25 | decay = self.decay 26 | 27 | if self.num_updates >= 0: 28 | self.num_updates += 1 29 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 30 | 31 | one_minus_decay = 1.0 - decay 32 | 33 | with torch.no_grad(): 34 | m_param = dict(model.named_parameters()) 35 | shadow_params = dict(self.named_buffers()) 36 | 37 | for key in m_param: 38 | if m_param[key].requires_grad: 39 | sname = self.m_name2s_name[key] 40 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 41 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 42 | else: 43 | assert not key in self.m_name2s_name 44 | 45 | def copy_to(self, model): 46 | m_param = dict(model.named_parameters()) 47 | shadow_params = dict(self.named_buffers()) 48 | for key in m_param: 49 | if m_param[key].requires_grad: 50 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 51 | else: 52 | assert not key in self.m_name2s_name 53 | 54 | def store(self, parameters): 55 | """ 56 | Save the current parameters for restoring later. 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | temporarily stored. 60 | """ 61 | self.collected_params = [param.clone() for param in parameters] 62 | 63 | def restore(self, parameters): 64 | """ 65 | Restore the parameters stored with the `store` method. 66 | Useful to validate the model with EMA parameters without affecting the 67 | original optimization process. Store the parameters before the 68 | `copy_to` method. After validation (or model saving), use this to 69 | restore the former parameters. 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | updated with the stored parameters. 
73 | """ 74 | for c_param, param in zip(self.collected_params, parameters): 75 | param.data.copy_(c_param.data) 76 | -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/optimus_models/vocab/bert_vocab_download_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 3 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 4 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 5 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 6 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 7 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 8 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 9 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 10 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 11 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 12 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 13 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 14 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt" 15 | } -------------------------------------------------------------------------------- /versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/optimus_models/vocab/gpt2_vocab_merge_download_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 3 | "merges_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 4 | } -------------------------------------------------------------------------------- /versatile_diffusion/main.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import torch.multiprocessing as mp 3 | 4 | import os 5 | import os.path as osp 6 | import sys 7 | import numpy as np 8 | import copy 9 | 10 | from lib.cfg_holder import cfg_unique_holder as cfguh 11 | from lib.cfg_helper import \ 12 | get_command_line_args, \ 13 | cfg_initiates 14 | 15 | from lib.model_zoo.sd import version 16 | from lib.utils import get_obj_from_str 17 | 18 | if __name__ == "__main__": 19 | cfg = get_command_line_args() 20 | cfg = cfg_initiates(cfg) 21 | 22 | if 'train' in cfg: 23 | trainer = get_obj_from_str(cfg.train.main)(cfg) 24 | tstage = get_obj_from_str(cfg.train.stage)() 25 | if 'eval' in cfg: 26 | tstage.nested_eval_stage = 
get_obj_from_str(cfg.eval.stage)() 27 | trainer.register_stage(tstage) 28 | if cfg.env.gpu_count == 1: 29 | trainer(0) 30 | else: 31 | mp.spawn(trainer, 32 | args=(), 33 | nprocs=cfg.env.gpu_count, 34 | join=True) 35 | trainer.destroy() 36 | else: 37 | evaler = get_obj_from_str(cfg.eval.main)(cfg) 38 | estage = get_obj_from_str(cfg.eval.stage)() 39 | evaler.register_stage(estage) 40 | if cfg.env.gpu_count == 1: 41 | evaler(0) 42 | else: 43 | mp.spawn(evaler, 44 | args=(), 45 | nprocs=cfg.env.gpu_count, 46 | join=True) 47 | evaler.destroy() 48 | -------------------------------------------------------------------------------- /versatile_diffusion/pretrained/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /versatile_diffusion/requirement.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.4.2 2 | pyyaml==5.4.1 3 | opencv-python==4.5.1.48 4 | easydict==1.9 5 | scikit-image==0.17.2 6 | tensorboardx==2.1 7 | tensorboard==1.15.0 8 | lpips==0.1.3 9 | 10 | tqdm==4.60.0 11 | transformers==4.24.0 12 | torchmetrics==0.7.3 13 | 14 | einops==0.3.0 15 | omegaconf==2.1.1 16 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 17 | -------------------------------------------------------------------------------- /versatile_diffusion/requirement_colab.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.4.2 2 | pyyaml==5.4.1 3 | opencv-python==4.5.1.48 4 | easydict==1.9 5 | scikit-image==0.17.2 6 | lpips==0.1.3 7 | 8 | tqdm==4.60.0 9 | transformers==4.24.0 10 | torchmetrics==0.7.3 11 | 12 | einops==0.3.0 13 | omegaconf==2.1.1 14 | --------------------------------------------------------------------------------
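A note on usage: the DiagonalGaussianDistribution class dumped earlier in this listing is the usual VAE-style posterior helper. The sketch below is illustrative only and is not a file from this repository; the tensor shapes, the KL weight, and the assumption that the class is importable in the current scope are all hypothetical.

import torch

# hypothetical encoder output: batch of 4, 8 channels = 4 means + 4 log-variances, 16x16 latent grid
moments = torch.randn(4, 8, 16, 16)
posterior = DiagonalGaussianDistribution(moments)   # chunks dim=1 into mean and logvar
z = posterior.sample()            # reparameterized draw, shape (4, 4, 16, 16)
z_det = posterior.mode()          # deterministic alternative: the posterior mean
kl = posterior.kl().mean()        # KL to N(0, I), summed over C, H, W, averaged over the batch
loss_kl_term = 1e-6 * kl          # a small KL weight of this order is typical for latent autoencoders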