├── .gitignore
├── configs
│   ├── model
│   │   ├── cnn.json
│   │   ├── cnn_3ch.json
│   │   ├── ddim_32.json
│   │   ├── ddim_32_3ch.json
│   │   ├── ddim_medium.json
│   │   ├── ddim_medium_3ch.json
│   │   ├── ddim_minimal.json
│   │   ├── diffusion.json
│   │   ├── mlp.json
│   │   ├── resnet.json
│   │   └── vae.json
│   └── strategy
│       ├── cifar
│       │   ├── cnn_w_cumulative.json
│       │   ├── cnn_w_diffusion.json
│       │   ├── cnn_w_naive.json
│       │   ├── diffusion_cumulative.json
│       │   ├── diffusion_full_gen_distill.json
│       │   ├── diffusion_full_gen_distill_10.json
│       │   ├── diffusion_full_gen_distill_100.json
│       │   ├── diffusion_gaussian_distill.json
│       │   ├── diffusion_lwf_distill.json
│       │   ├── diffusion_naive.json
│       │   ├── diffusion_no_distill_preliminary_10.json
│       │   ├── diffusion_no_distill_preliminary_100.json
│       │   └── diffusion_no_distill_preliminary_2.json
│       ├── cnn_w_cumulative.json
│       ├── cnn_w_diffusion.json
│       ├── cnn_w_diffusion_debug.json
│       ├── cnn_w_er.json
│       ├── cnn_w_ewc.json
│       ├── cnn_w_lwf.json
│       ├── cnn_w_naive.json
│       ├── cnn_w_si.json
│       ├── diffusion_cumulative.json
│       ├── diffusion_debug.json
│       ├── diffusion_er.json
│       ├── diffusion_ewc.json
│       ├── diffusion_full_gen_distill.json
│       ├── diffusion_full_gen_distill_10.json
│       ├── diffusion_full_gen_distill_100.json
│       ├── diffusion_gaussian_distill.json
│       ├── diffusion_gaussian_symmetry_distill.json
│       ├── diffusion_lwf_distill.json
│       ├── diffusion_naive.json
│       ├── diffusion_no_distill.json
│       ├── diffusion_no_distill_preliminary_10.json
│       ├── diffusion_no_distill_preliminary_100.json
│       ├── diffusion_no_distill_preliminary_2.json
│       ├── diffusion_no_distill_preliminary_20.json
│       ├── diffusion_no_distill_preliminary_5.json
│       ├── diffusion_partial_gen_distill.json
│       ├── diffusion_si.json
│       ├── mlp.json
│       ├── mlp_w_diffusion.json
│       └── vae.json
├── readme.md
├── requirements.txt
├── scripts
│   ├── condor
│   │   ├── launch_cl_cifar10_experiments.sh
│   │   ├── launch_cl_cifar10_targets.sh
│   │   ├── launch_cl_fmnist_experiments.sh
│   │   ├── launch_fid_vs_samples.sh
│   │   ├── launch_fid_vs_teachersteps.sh
│   │   ├── launch_fid_vs_time.sh
│   │   ├── launch_generative_replay_fullgen_gensteps_experiments.sh
│   │   ├── launch_generative_replay_fullgen_lambd_experiments.sh
│   │   ├── launch_generative_replay_gaussian_lambd_experiments.sh
│   │   ├── launch_generative_replay_partialgen_lambd_experiments.sh
│   │   ├── launch_generative_replay_preliminary_experiments.sh
│   │   ├── launch_generative_replay_targets.sh
│   │   ├── launch_generative_replay_thesis_experiments.sh
│   │   ├── launch_test_speed.sh
│   │   ├── launch_train_classifier_cifar10.sh
│   │   ├── launch_train_classifier_fmnist.sh
│   │   ├── launch_train_diffusion_cifar10.sh
│   │   ├── launch_train_diffusion_fmnist.sh
│   │   ├── launch_train_gaussian_experiments.sh
│   │   ├── launch_train_generation_best_experiments.sh
│   │   ├── launch_train_generation_experiments.sh
│   │   └── launch_train_generation_preliminary_experiments.sh
│   ├── launch_cl_cifar10_experiments.sh
│   └── launch_cl_fmnist_experiments.sh
├── src
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── utils.py
│   │   └── visual.py
│   ├── continual_learning
│   │   ├── __init__.py
│   │   ├── loggers.py
│   │   ├── metrics
│   │   │   ├── diffusion_metrics.py
│   │   │   └── loss.py
│   │   ├── plugins.py
│   │   └── strategies.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── cifar100.py
│   │   ├── fashion_mnist.py
│   │   └── mnist.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── simple_cnn.py
│   │   └── vae.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   └── pipeline_ddim.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── scheduler_ddim.py
│   └── standard_training
│       ├── __init__.py
│       ├── evaluators
│       │   ├── __init__.py
│       │   ├── base_evaluator.py
│       │   └── generative_evaluator.py
│       ├── losses
│       │   ├── __init__.py
│       │   └── diffusion_losses.py
│       ├── trackers
│       │   ├── __init__.py
│       │   ├── base_tracker.py
│       │   ├── csv_tracker.py
│       │   └── wandb_tracker.py
│       └── trainers
│           ├── __init__.py
│           ├── base_trainer.py
│           ├── diffusion_distillation.py
│           ├── diffusion_training.py
│           └── generative_training.py
├── tests
│   ├── __init__.py
│   ├── continual_learning
│   │   ├── __init__.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   └── test_diffusion_metrics.py
│   │   ├── test_loggers.py
│   │   └── test_strategies.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   └── test_pipeline_ddim.py
│   └── schedulers
│       ├── __init__.py
│       └── test_scheduler_ddim.py
├── train_cl.py
├── train_iid.py
├── train_kld_classifier.py
├── utils
│   ├── compute_mnist_statistics.py
│   ├── generate_auc_vs_teachersteps.py
│   ├── generate_fid_accuracy.py
│   ├── generate_fid_vs_samples.py
│   ├── generate_fid_vs_time.py
│   ├── generate_report_cl.py
│   ├── generate_report_iid.py
│   ├── save_cifar10_examples.py
│   ├── save_fmnist_examples.py
│   └── test_speed.py
└── weights
    ├── cnn_cifar10
    │   └── resnet.pth
    └── cnn_fmnist
        └── resnet.pth

/.gitignore:
--------------------------------------------------------------------------------
## PyTorch

.coverage
coverage.xml
.dmypy.json
.gradle
.hypothesis
.mypy_cache
/.extracted_scripts/
**/.pytorch_specified_test_cases.csv
**/.pytorch-disabled-tests.json
**/.pytorch-slow-tests.json
**/.pytorch-test-times.json
*/*.pyc
*/*.so*
*/**/__pycache__
*/**/*.dylib*
*/**/*.pyc
*/**/*.pyd
*/**/*.so*
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
aten/build/
aten/src/ATen/Config.h
aten/src/ATen/cuda/CUDAConfig.h
benchmarks/.data
caffe2/cpp_test/
dist/
docs/build/
docs/cpp/src
docs/src/**/*
docs/cpp/build
docs/cpp/source/api
docs/cpp/source/html/
docs/cpp/source/latex/
docs/source/generated/
log
usage_log.txt
test-reports/
test/*.bak
test/**/*.bak
test/.coverage
test/.hypothesis/
test/cpp/api/mnist
test/custom_operator/model.pt
test/jit_hooks/*.pt
test/data/legacy_modules.t7
test/data/*.pt
test/forward_backward_compatibility/nightly_schemas.txt
dropout_model.pt
test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
third_party/build/
tools/coverage_plugins_package/pip-wheel-metadata/
tools/shared/_utils_internal.py
tools/fast_nvcc/wrap_nvcc.sh
tools/fast_nvcc/wrap_nvcc.bat
tools/fast_nvcc/tmp/
torch.egg-info/
torch/_C/__init__.pyi
torch/_C/_nn.pyi
torch/_C/_VariableFunctions.pyi
torch/_VF.pyi
torch/return_types.pyi
torch/nn/functional.pyi
torch/utils/data/datapipes/datapipe.pyi
torch/csrc/autograd/generated/*
torch/csrc/lazy/generated/*.[!m]*
torch_compile_debug/
# Listed manually because some files in this directory are not generated
torch/testing/_internal/generated/annotated_fn_args.py
torch/testing/_internal/data/*.pt
torch/csrc/api/include/torch/version.h
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/generated
torch/csrc/generic/TensorMethods.cpp
torch/csrc/jit/generated/*
torch/csrc/jit/fuser/config.h
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/bin/
torch/cmake/
torch/lib/*.a*
torch/lib/*.dll*
torch/lib/*.exe*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/*.lib
torch/lib/*.pdb
torch/lib/*.so*
torch/lib/protobuf*.pc
torch/lib/build
torch/lib/caffe2/
torch/lib/cmake
torch/lib/include
torch/lib/pkgconfig
torch/lib/protoc
torch/lib/protobuf/
torch/lib/tmp_install
torch/lib/torch_shm_manager
torch/lib/site-packages/
torch/lib/python*
torch/lib64
torch/include/
torch/share/
torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
minifier_launcher.py
# Root level file used in CI to specify certain env configs.
# E.g., see .circleci/config.yaml
env
.circleci/scripts/COMMIT_MSG
scripts/release_notes/*.json
sccache-stats*.json

# These files get copied over on invoking setup.py
torchgen/packaged/*
!torchgen/packaged/README.md

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swa
*.swb
*.swc
*.swd
*.swe
*.swf
*.swg
*.swh
*.swi
*.swj
*.swk
*.swl
*.swm
*.swn
*.swo
*.swp
*~
.~lock.*

# macOS dir files
.DS_Store

# Ninja files
.ninja_deps
.ninja_log
compile_commands.json
*.egg-info/
docs/source/scripts/activation_images/
docs/source/scripts/quantization_backend_configs/

## General

# Compiled Object files
*.slo
*.lo
*.o
*.cuo
*.obj

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Compiled protocol buffers
*.pb.h
*.pb.cc
*_pb2.py

# Compiled python
*.pyc
*.pyd

# Compiled MATLAB
*.mex*

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swn
*.swo
*.swp
*~

# NFS handle files
**/.nfs*

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# Eclipse Project settings
*.*project
.settings

# QtCreator files
*.user

# PyCharm files
.idea

# GDB history
.gdb_history

## Caffe2

# build, distribute, and bins (+ python proto bindings)
build
build_host_protoc
build_android
build_ios
.build_debug/*
.build_release/*
.build_profile/*
distribute/*
*.testbin
*.bin
cmake_build
.cmake_build
gen
.setuptools-cmake-build
.pytest_cache
aten/build/*

# Bram
plsdontbreak

# Generated documentation
docs/_site
docs/gathered
_site
doxygen
docs/dev

# LevelDB files
*.sst
*.ldb
LOCK
CURRENT
MANIFEST-*

# generated version file
caffe2/version.py

# setup.py intermediates
.eggs
caffe2.egg-info
MANIFEST

# Atom/Watchman required file
.watchmanconfig

# Files generated by CLion
cmake-build-debug

# BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
#
# Below files are not deleted by "setup.py clean".

# Downloaded bazel
tools/bazel

# Visual Studio Code files
.vs
/.vscode/*
!/.vscode/extensions.json
!/.vscode/settings_recommended.json

# YouCompleteMe config file
.ycm_extra_conf.py

# Files generated when a patch is rejected
*.orig
*.rej

# Files generated by ctags
CTAGS
GTAGS
GRTAGS
GSYMS
GPATH
tags
TAGS


# ccls file
.ccls-cache/

# clang tooling storage location
.clang-format-bin
.clang-tidy-bin
.lintbin

# clangd background index
.clangd/
.cache/

# bazel symlinks
bazel-*

# xla repo
xla/

# direnv, posh-direnv
.env
.envrc
.psenvrc

# generated shellcheck directories
.shellcheck_generated*/

# zip archives
*.zip

# core dump files
**/core.[1-9]*

# Generated if you use the pre-commit script for clang-tidy
pr.diff

# coverage files
*/**/.coverage.*

# buck generated files
.buckd/
.lsp-buck-out/
.lsp.buckd/
buck-out/

# Downloaded libraries
third_party/ruy/
third_party/glog/

# Virtualenv
venv/

# Log files
*.log
sweep/

# Custom
/*.png
/results*
/condor*
/wandb
/data
--------------------------------------------------------------------------------
/configs/model/cnn.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "cnn",
        "n_classes": 10,
        "channels": 1
    },
    "optimizer": {
        "name": "adam",
        "lr": 0.001
    }
}
--------------------------------------------------------------------------------
/configs/model/cnn_3ch.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "cnn",
        "n_classes": 10,
        "channels": 3
    },
    "optimizer": {
        "name": "adam",
        "lr": 0.001
    }
}
--------------------------------------------------------------------------------
/configs/model/ddim_32.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "unet2d",
        "input_size": 32,
        "in_channels": 1,
        "out_channels": 1,
        "layers_per_block": 2,
        "block_out_channels": [128, 256, 512, 1024],
        "norm_num_groups": 32,
        "down_block_types": ["DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"],
        "up_block_types": ["AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"]
    },
    "scheduler": {
        "name": "DDIM",
        "train_timesteps": 1000
    },
    "optimizer": {
        "name": "adam",
        "lr": 0.0002
    }
}
--------------------------------------------------------------------------------
/configs/model/ddim_32_3ch.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "unet2d",
        "input_size": 32,
        "in_channels": 3,
        "out_channels": 3,
        "layers_per_block": 2,
        "block_out_channels": [128, 256, 512, 1024],
        "norm_num_groups": 32,
        "down_block_types": ["DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"],
        "up_block_types": ["AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"]
    },
    "scheduler": {
        "name": "DDIM",
"train_timesteps": 1000 16 | }, 17 | "optimizer": { 18 | "name": "adam", 19 | "lr": 0.0002 20 | } 21 | } -------------------------------------------------------------------------------- /configs/model/ddim_medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "unet2d", 4 | "input_size": 32, 5 | "in_channels": 1, 6 | "out_channels": 1, 7 | "layers_per_block": 1, 8 | "block_out_channels": [64, 128, 256, 512], 9 | "norm_num_groups": 32, 10 | "down_block_types": ["DownBlock2D","DownBlock2D","DownBlock2D","AttnDownBlock2D"], 11 | "up_block_types": ["AttnUpBlock2D","UpBlock2D","UpBlock2D","UpBlock2D"] 12 | }, 13 | "scheduler": { 14 | "name": "DDIM", 15 | "train_timesteps": 1000 16 | }, 17 | "optimizer": { 18 | "name": "adam", 19 | "lr": 0.0002 20 | } 21 | } -------------------------------------------------------------------------------- /configs/model/ddim_medium_3ch.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "unet2d", 4 | "input_size": 32, 5 | "in_channels": 3, 6 | "out_channels": 3, 7 | "layers_per_block": 2, 8 | "block_out_channels": [128, 256, 256, 256], 9 | "norm_num_groups": 32, 10 | "norm_eps": 0.000001, 11 | "freq_shift": 1, 12 | "attention_head_dim": null, 13 | "flip_sin_to_cos": false, 14 | "down_block_types": ["DownBlock2D","AttnDownBlock2D","DownBlock2D","DownBlock2D"], 15 | "up_block_types": ["UpBlock2D","UpBlock2D","AttnUpBlock2D","UpBlock2D"] 16 | }, 17 | "scheduler": { 18 | "name": "DDIM", 19 | "train_timesteps": 1000 20 | }, 21 | "optimizer": { 22 | "name": "adam", 23 | "lr": 0.0002 24 | } 25 | } -------------------------------------------------------------------------------- /configs/model/ddim_minimal.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "unet2d", 4 | "input_size": 32, 5 | "in_channels": 1, 6 | "out_channels": 1, 7 | "layers_per_block": 1, 8 | "block_out_channels": [128, 256], 9 | "norm_num_groups": 32, 10 | "down_block_types": ["DownBlock2D","DownBlock2D"], 11 | "up_block_types": ["UpBlock2D","UpBlock2D"] 12 | }, 13 | "scheduler": { 14 | "name": "DDIM", 15 | "train_timesteps": 1000 16 | }, 17 | "optimizer": { 18 | "name": "adam", 19 | "lr": 0.001 20 | } 21 | } -------------------------------------------------------------------------------- /configs/model/diffusion.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "unet2d", 4 | "input_size": 32, 5 | "in_channels": 1, 6 | "out_channels": 1, 7 | "layers_per_block": 1, 8 | "block_out_channels": [16, 32, 32, 64], 9 | "norm_num_groups": 16, 10 | "down_block_types": ["DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"], 11 | "up_block_types": ["AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"] 12 | }, 13 | "scheduler": { 14 | "name": "DDIM", 15 | "train_timesteps": 1000 16 | }, 17 | "optimizer": { 18 | "name": "adam", 19 | "lr": 0.003 20 | } 21 | } -------------------------------------------------------------------------------- /configs/model/mlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "mlp", 4 | "input_size": 28, 5 | "channels": 1, 6 | "n_classes": 10 7 | }, 8 | "optimizer": { 9 | "name": "adam", 10 | "lr": 0.001 11 | } 12 | } -------------------------------------------------------------------------------- 
--------------------------------------------------------------------------------
/configs/model/resnet.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "resnet18",
        "n_classes": 10
    },
    "optimizer": {
        "name": "adam",
        "lr": 0.001
    }
}
--------------------------------------------------------------------------------
/configs/model/vae.json:
--------------------------------------------------------------------------------
{
    "model": {
        "name": "mlpvae",
        "channels": 1,
        "input_size": 28,
        "encoder_dims": [400, 400],
        "decoder_dims": [400, 400],
        "latent_dim": 100,
        "n_classes": 10
    },
    "optimizer": {
        "name": "adam",
        "lr": 0.001
    }
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/cnn_w_cumulative.json:
--------------------------------------------------------------------------------
{
    "strategy": "cumulative",
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 20
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/cnn_w_diffusion.json:
--------------------------------------------------------------------------------
{
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 20,
    "increasing_replay_size": false,
    "replay_size": null
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/cnn_w_naive.json:
--------------------------------------------------------------------------------
{
    "strategy": "naive",
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 20
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_cumulative.json:
--------------------------------------------------------------------------------
{
    "strategy": "cumulative",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_full_gen_distill.json:
--------------------------------------------------------------------------------
{
    "strategy": "full_generation_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 2,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_full_gen_distill_10.json:
--------------------------------------------------------------------------------
{
    "strategy": "full_generation_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 10,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_full_gen_distill_100.json:
--------------------------------------------------------------------------------
{
    "strategy": "full_generation_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 100,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_gaussian_distill.json:
--------------------------------------------------------------------------------
{
    "strategy": "gaussian_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_lwf_distill.json:
--------------------------------------------------------------------------------
{
    "strategy": "lwf_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_naive.json:
--------------------------------------------------------------------------------
{
    "strategy": "naive",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_no_distill_preliminary_10.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 10,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_no_distill_preliminary_100.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 100,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cifar/diffusion_no_distill_preliminary_2.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 64,
    "eval_batch_size": 64,
    "epochs": 200,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 2,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/cnn_w_cumulative.json:
--------------------------------------------------------------------------------
{
    "strategy": "cumulative",
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 20
}
--------------------------------------------------------------------------------
/configs/strategy/cnn_w_diffusion.json:
--------------------------------------------------------------------------------
{
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 20,
    "increasing_replay_size": false,
6 | "replay_size": null 7 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_diffusion_debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 128, 3 | "eval_batch_size": 256, 4 | "epochs": 1, 5 | "increasing_replay_size": false, 6 | "replay_size": null 7 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_er.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "er", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 256, 5 | "epochs": 20, 6 | "buffer_size": 300, 7 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_ewc.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "ewc", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 256, 5 | "epochs": 20, 6 | "ewc_lambda": 1.0, 7 | "mode": "separate", 8 | "decay_factor": null, 9 | "keep_importance_data": false 10 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_lwf.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "lwf", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 256, 5 | "epochs": 20, 6 | "lwf_alpha": [0, 0.5, 1.33333, 2.25, 3.2], 7 | "lwf_temperature": 2.0 8 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_naive.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "naive", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 256, 5 | "epochs": 20 6 | } -------------------------------------------------------------------------------- /configs/strategy/cnn_w_si.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "si", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 256, 5 | "epochs": 20, 6 | "si_lambda": 1.0, 7 | "eps": 0.0000001 8 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_cumulative.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "cumulative", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100 6 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "full_generation_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 1, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "lambd": 0.5, 9 | "replay_start_timestep": 0, 10 | "weight_replay_loss": true, 11 | "teacher_steps": 2, 12 | "teacher_eta": 0.0 13 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_er.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "er", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100 6 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_ewc.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "ewc", 
3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "ewc_lambda": 1.0, 7 | "mode": "separate", 8 | "decay_factor": null, 9 | "keep_importance_data": false 10 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_full_gen_distill.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "full_generation_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true, 10 | "teacher_steps": 2, 11 | "teacher_eta": 0.0 12 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_full_gen_distill_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "full_generation_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true, 10 | "teacher_steps": 10, 11 | "teacher_eta": 0.0 12 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_full_gen_distill_100.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "full_generation_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true, 10 | "teacher_steps": 100, 11 | "teacher_eta": 0.0 12 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_gaussian_distill.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "gaussian_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true 10 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_gaussian_symmetry_distill.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "gaussian_symmetry_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true 10 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_lwf_distill.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "lwf_distillation", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100, 6 | "increasing_replay_size": false, 7 | "replay_size": null, 8 | "replay_start_timestep": 0, 9 | "weight_replay_loss": true 10 | } -------------------------------------------------------------------------------- /configs/strategy/diffusion_naive.json: -------------------------------------------------------------------------------- 1 | { 2 | "strategy": "naive", 3 | "train_batch_size": 128, 4 | "eval_batch_size": 128, 5 | "epochs": 100 6 | } 
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 2,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill_preliminary_10.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 10,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill_preliminary_100.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 100,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill_preliminary_2.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 2,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill_preliminary_20.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 20,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_no_distill_preliminary_5.json:
--------------------------------------------------------------------------------
{
    "strategy": "no_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 5,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_partial_gen_distill.json:
--------------------------------------------------------------------------------
{
    "strategy": "partial_generation_distillation",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "increasing_replay_size": false,
    "replay_size": null,
    "replay_start_timestep": 0,
    "weight_replay_loss": true,
    "teacher_steps": 2,
    "teacher_eta": 0.0
}
--------------------------------------------------------------------------------
/configs/strategy/diffusion_si.json:
--------------------------------------------------------------------------------
{
    "strategy": "si",
    "train_batch_size": 128,
    "eval_batch_size": 128,
    "epochs": 100,
    "si_lambda": 1.0,
    "eps": 0.0000001
}
--------------------------------------------------------------------------------
/configs/strategy/mlp.json:
--------------------------------------------------------------------------------
{
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 21,
    "increasing_replay_size": false,
    "replay_size": null
}
--------------------------------------------------------------------------------
/configs/strategy/mlp_w_diffusion.json:
--------------------------------------------------------------------------------
{
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 4,
    "increasing_replay_size": false,
    "replay_size": null
}
--------------------------------------------------------------------------------
/configs/strategy/vae.json:
--------------------------------------------------------------------------------
{
    "train_batch_size": 128,
    "eval_batch_size": 256,
    "epochs": 21,
    "increasing_replay_size": false,
    "replay_size": null
}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==1.4.0
accelerate==0.19.0
aiohttp==3.8.4
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.2
attrs==23.1.0
avalanche-lib==0.3.1
beautifulsoup4==4.12.2
brotlipy==0.7.0
cachetools==5.3.0
click==8.1.3
contourpy==1.0.7
cycler==0.11.0
datasets==2.12.0
diffusers==0.16.1
dill==0.3.6
docker-pycreds==0.4.0
einops==0.6.1
fonttools==4.39.4
frozenlist==1.3.3
fsspec==2023.5.0
gdown==4.7.1
gitdb==4.0.10
GitPython==3.1.31
google-auth==2.18.1
google-auth-oauthlib==1.0.0
GPUtil==1.4.0
grpcio==1.54.2
huggingface-hub==0.14.1
importlib-metadata==6.6.0
importlib-resources==5.12.0
joblib==1.2.0
kiwisolver==1.4.4
Markdown==3.4.3
matplotlib==3.7.1
mkl-fft==1.3.6
mkl-service==2.4.0
mpmath==1.2.1
multidict==6.0.4
multiprocess==0.70.14
oauthlib==3.2.2
pandas==2.0.1
pathtools==0.1.2
Pillow==9.4.0
protobuf==4.23.1
psutil==5.9.5
pyarrow==12.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pyparsing==3.0.9
pytorchcv==0.0.67
pytz==2023.3
PyYAML==6.0
pyzmq==19.0.2
quadprog==0.1.11
regex==2023.5.5
requests-oauthlib==1.3.1
responses==0.18.0
rsa==4.9
scikit-learn==1.2.2
scipy==1.10.1
sentry-sdk==1.23.1
setproctitle==1.3.2
smmap==5.0.0
soupsieve==2.4.1
tensorboard==2.13.0
tensorboard-data-server==0.7.0
threadpoolctl==3.1.0
torch==2.0.1
torchaudio==2.0.2
torchmetrics==0.11.4
torchvision==0.15.2
tqdm==4.65.0
triton==2.0.0
tzdata==2023.3
wandb==0.15.3
Werkzeug==2.3.4
xxhash==3.2.0
yarl==1.9.2
zipp==3.15.0
torch-fidelity==0.3.0
munch==3.0.0
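Note on the launch scripts that follow: each one wraps a train_cl.py invocation in condor_send, and the prefix CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} uses bash prefix removal to strip a literal "CUDA" from the variable (e.g. "CUDA0" becomes "0"), presumably because the cluster exports device IDs in that form. A hypothetical reconstruction of the CLI these scripts assume, inferred solely from the flags they pass (the real parser in train_cl.py may define more options, and the defaults below are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--generator_config_path", type=str)
parser.add_argument("--generator_strategy_config_path", type=str)
parser.add_argument("--generator_type", type=str, default="diffusion")  # "None" disables the generator
parser.add_argument("--solver_config_path", type=str)
parser.add_argument("--solver_strategy_config_path", type=str)
parser.add_argument("--solver_type", type=str, default="cnn")           # "None" disables the solver
parser.add_argument("--generation_steps", type=int, default=20)
parser.add_argument("--lambd", type=float, default=1.0)                 # replay-loss weight
parser.add_argument("--seed", type=int, default=42)                     # -1 appears to mean "run all seeds"
parser.add_argument("--dataset", type=str, default="split_fmnist")
parser.add_argument("--output_dir", type=str, default="results/")
parser.add_argument("--wandb", action="store_true")
args = parser.parse_args()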
--------------------------------------------------------------------------------
/scripts/condor/launch_cl_cifar10_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results_fuji/smasipca/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results_fuji/smasipca/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_gaussian_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --lambd 8 --generation_steps 20 --seed -1 --output_dir results_fuji/smasipca/continual_learning/ --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_lwf_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 1.25 --seed -1 --output_dir results_fuji/smasipca/continual_learning/ --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
--------------------------------------------------------------------------------
/scripts/condor/launch_cl_cifar10_targets.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_config_path "configs/model/ddim_medium_3ch.json" --generator_strategy_config_path "configs/strategy/cifar/diffusion_cumulative.json" --generation_steps 20 --seed 69 --output_dir "results_fuji/smasipca/continual_learning/targets/" --solver_type None --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_config_path "configs/model/ddim_medium_3ch.json" --generator_strategy_config_path "configs/strategy/cifar/diffusion_naive.json" --generation_steps 20 --seed 69 --output_dir "results_fuji/smasipca/continual_learning/targets/" --solver_type None --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_config_path "configs/model/ddim_medium_3ch.json" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_cumulative.json" --seed 69 --output_dir "results_fuji/smasipca/continual_learning/targets/" --generator_type None --solver_type cnn --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_config_path "configs/model/ddim_medium_3ch.json" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_naive.json" --seed 69 --output_dir "results_fuji/smasipca/continual_learning/targets/" --generator_type None --solver_type cnn --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --wandb'
--------------------------------------------------------------------------------
/scripts/condor/launch_fid_vs_samples.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python fid_vs_samples.py'
--------------------------------------------------------------------------------
/scripts/condor/launch_fid_vs_teachersteps.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python fid_vs_teachersteps.py --distillation_type generation'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python fid_vs_teachersteps.py --distillation_type partial_generation'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python fid_vs_teachersteps.py --distillation_type no_distillation'
--------------------------------------------------------------------------------
/scripts/condor/launch_fid_vs_time.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python fid_vs_time.py'
--------------------------------------------------------------------------------
/scripts/condor/launch_generative_replay_fullgen_gensteps_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 2 --lambd 3.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 5 --lambd 3.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 3.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 20 --lambd 3.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 2 --lambd 4.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 5 --lambd 4.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 4.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 20 --lambd 4.0 --seed -1 --output_dir results_fuji/smasipca/generative_replay_gensteps/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
--------------------------------------------------------------------------------
/scripts/condor/launch_generative_replay_fullgen_lambd_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 1.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 2.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 3.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 4.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 5.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 6.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 7.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 8.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
--------------------------------------------------------------------------------
/scripts/condor/launch_generative_replay_gaussian_lambd_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 1 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 4 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 8 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 12 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 16 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 20 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 1 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 4 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 8 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 12 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 16 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_symmetry_distill.json" --lambd 20 --generation_steps 10 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
--------------------------------------------------------------------------------
/scripts/condor/launch_generative_replay_partialgen_lambd_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 1.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 2.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 3.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 4.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 5.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 6.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 7.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_partial_gen_distill.json" --generation_steps 10 --lambd 8.0 --seed 42 --wandb --output_dir results_fuji/smasipca/generative_replay_single/ --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
--------------------------------------------------------------------------------
/scripts/condor/launch_generative_replay_preliminary_experiments.sh:
--------------------------------------------------------------------------------
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_20.json" --generation_steps 20 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_preliminary_acc/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"'
condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_10.json" --generation_steps 10 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_preliminary_acc/" --solver_strategy_config_path
"configs/strategy/cnn_w_diffusion.json"' 3 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_5.json" --generation_steps 5 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_preliminary_acc/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"' 4 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_2.json" --generation_steps 2 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_preliminary_acc/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json"' 5 | -------------------------------------------------------------------------------- /scripts/condor/launch_generative_replay_targets.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_cumulative.json" --generation_steps 10 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_targets/" --solver_type None' 2 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_naive.json" --generation_steps 10 --seed -1 --output_dir "results_fuji/smasipca/generative_replay_targets/" --solver_type None' 3 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --solver_strategy_config_path "configs/strategy/cnn_w_cumulative.json" --seed -1 --output_dir "results_fuji/smasipca/generative_replay_targets/" --generator_type None --solver_type cnn' 4 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_cl.py --solver_strategy_config_path "configs/strategy/cnn_w_naive.json" --seed -1 --output_dir "results_fuji/smasipca/generative_replay_targets/" --generator_type None --solver_type cnn' 5 | -------------------------------------------------------------------------------- /scripts/condor/launch_test_speed.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python test_speed.py' -------------------------------------------------------------------------------- /scripts/condor/launch_train_classifier_cifar10.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_kld_classifier.py --dataset "CIFAR10" --output_path "./results/cnn_cifar10"' -------------------------------------------------------------------------------- /scripts/condor/launch_train_classifier_fmnist.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_kld_classifier.py' -------------------------------------------------------------------------------- /scripts/condor/launch_train_diffusion_cifar10.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --dataset "cifar10" --model_config_path "configs/model/ddim_32_3ch.json" --num_epochs 500 --save_every 50 --seed 42 --use_wandb --training_type diffusion' 
-------------------------------------------------------------------------------- /scripts/condor/launch_train_diffusion_fmnist.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --dataset "fashion_mnist" --model_config_path "configs/model/ddim_medium.json" --num_epochs 100' -------------------------------------------------------------------------------- /scripts/condor/launch_train_gaussian_experiments.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type gaussian --save_every 1000 --num_epochs 20000' 2 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type gaussian_symmetry --save_every 1000 --num_epochs 20000' -------------------------------------------------------------------------------- /scripts/condor/launch_train_generation_best_experiments.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type generation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 10' 2 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type partial_generation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 5' 3 | -------------------------------------------------------------------------------- /scripts/condor/launch_train_generation_experiments.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type generation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 2 --teacher_eta 0.0 --training_type diffusion' 2 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type partial_generation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 2 --teacher_eta 0.0 --training_type diffusion' 3 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type no_distillation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 2 --teacher_eta 0.0 --training_type diffusion' -------------------------------------------------------------------------------- /scripts/condor/launch_train_generation_preliminary_experiments.sh: -------------------------------------------------------------------------------- 1 | condor_send -c 'CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES#CUDA} python train_iid.py --model_config_path "configs/model/ddim_medium.json" --distillation_type no_distillation --save_every 1000 --num_epochs 20000 --teacher_generation_steps 20 --teacher_eta 0.0' -------------------------------------------------------------------------------- /scripts/launch_cl_cifar10_experiments.sh: -------------------------------------------------------------------------------- 1 |
python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 2 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 3 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_2.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 4 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 5 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 6 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 7 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 42 --output_dir "results/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 8 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 69 --output_dir "results/continual_learning/teacher_steps_100" --solver_strategy_config_path 
"configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 9 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_no_distill_preliminary_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --seed 1714 --output_dir "results/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 10 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 11 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 12 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results/continual_learning/teacher_steps_2 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 13 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 14 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 15 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_10.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results/continual_learning/teacher_steps_10 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 16 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" 
--generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 42 --output_dir results/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 17 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 69 --output_dir results/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 18 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_full_gen_distill_100.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 0.75 --seed 1714 --output_dir results/continual_learning/teacher_steps_100 --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 19 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_gaussian_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --lambd 8 --generation_steps 20 --seed -1 --output_dir results/continual_learning/ --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" 20 | python train_cl.py --generator_strategy_config_path "configs/strategy/cifar/diffusion_lwf_distill.json" --generator_config_path "configs/model/ddim_medium_3ch.json" --generation_steps 20 --lambd 1.25 --seed -1 --output_dir results/continual_learning/ --solver_strategy_config_path "configs/strategy/cifar/cnn_w_diffusion.json" --solver_config_path "configs/model/cnn_3ch.json" --dataset split_cifar10 --kld_clf_path "weights/cnn_cifar10/" -------------------------------------------------------------------------------- /scripts/launch_cl_fmnist_experiments.sh: -------------------------------------------------------------------------------- 1 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 0.1 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 2 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 0.25 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 3 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 0.5 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 4 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 0.75 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 5 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" 
--generation_steps 10 --lambd 1.0 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 6 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 1.25 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 7 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 1.5 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 8 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 1.75 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 9 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 2.0 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 10 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 2.25 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 11 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 2.5 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 12 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 2.75 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 13 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill.json" --generation_steps 10 --lambd 3.0 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 14 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 1 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 15 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 4 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 16 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 8 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 17 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 12 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 18 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 16 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path 
"configs/strategy/cnn_w_diffusion.json" 19 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 20 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 20 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 24 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 21 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_gaussian_distill.json" --lambd 28 --generation_steps 10 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 22 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 0.1 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 23 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 0.25 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 24 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 0.5 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 25 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 0.75 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 26 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 1.0 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 27 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 1.25 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 28 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 1.5 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 29 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_lwf_distill.json" --generation_steps 10 --lambd 1.75 --seed -1 --output_dir "results/continual_learning/" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 30 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_2.json" --generation_steps 10 --seed -1 --output_dir "results/continual_learning/teacher_steps_2" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 31 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_no_distill_preliminary_10.json" --generation_steps 10 --seed -1 --output_dir "results/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 32 | python train_cl.py --generator_strategy_config_path 
"configs/strategy/diffusion_no_distill_preliminary_100.json" --generation_steps 10 --seed -1 --output_dir "results/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 33 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill_10.json" --generation_steps 10 --lambd 0.75 --seed -1 --output_dir "results/continual_learning/teacher_steps_10" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 34 | python train_cl.py --generator_strategy_config_path "configs/strategy/diffusion_full_gen_distill_100.json" --generation_steps 10 --lambd 0.75 --seed -1 --output_dir "results/continual_learning/teacher_steps_100" --solver_strategy_config_path "configs/strategy/cnn_w_diffusion.json" 35 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/__init__.py -------------------------------------------------------------------------------- /src/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/common/__init__.py -------------------------------------------------------------------------------- /src/common/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import torch 4 | import torch.optim.lr_scheduler 5 | 6 | from typing import List, Union, Optional 7 | from PIL import Image 8 | 9 | 10 | def wrap_in_pipeline(model, scheduler, pipeline_class, num_inference_steps: int, default_eta: float = 0.0, def_output_type: str = "torch"): 11 | """ 12 | Wrap a model in a pipeline for sampling. 13 | 14 | Args: 15 | model: The model to wrap. 16 | scheduler: The scheduler to use for sampling. 17 | pipeline_class: The pipeline class to use. 18 | num_inference_steps: The number of inference steps to use. 19 | eta: The eta value to use. 20 | def_output_type: The output type to use. Options are "torch", "torch_raw", and "pil". Defaults to "torch". 
21 | """ 22 | assert def_output_type in ["torch", "torch_raw", "pil"], f"Invalid output type {def_output_type}" 23 | 24 | def generate(self, batch_size: int, target_steps: Union[List[int], int] = 0, generation_steps: int = num_inference_steps, eta: float = default_eta, output_type: str = def_output_type, seed: Optional[int] = None) -> torch.Tensor: 25 | if seed is not None: 26 | generator = torch.Generator(device=self.device) 27 | generator.manual_seed(seed) 28 | else: 29 | generator = None 30 | pipeline = pipeline_class(unet=self, scheduler=scheduler) 31 | pipeline.set_progress_bar_config(disable=True) 32 | samples = pipeline( 33 | batch_size, 34 | generator=generator, 35 | num_inference_steps=generation_steps, 36 | eta=eta, 37 | output_type=output_type, 38 | target_steps=target_steps, 39 | ) 40 | return samples 41 | 42 | model.generate = types.MethodType(generate, model) 43 | 44 | 45 | def make_grid(images, rows, cols): 46 | w, h = images[0].size 47 | grid = Image.new("RGB", size=(cols * w, rows * h)) 48 | for i, image in enumerate(images): 49 | grid.paste(image, box=(i % cols * w, i // cols * h)) 50 | return grid 51 | 52 | 53 | def generate_diffusion_samples(output_dir, eval_batch_size, epoch, generator, generation_steps: int = 20, eta: float = 0.0, seed: Optional[int] = None): 54 | generated_images = generator.generate(eval_batch_size, generation_steps=generation_steps, output_type="torch", eta=eta, seed=seed) 55 | # To PIL image 56 | generated_images = generated_images.mul(255).to(torch.uint8) 57 | generated_images = generated_images.permute(0, 2, 3, 1).cpu().numpy() 58 | generated_images = [Image.fromarray(img.squeeze()) for img in generated_images] 59 | nrows = int(eval_batch_size**0.5) 60 | ncols = eval_batch_size // nrows + eval_batch_size % nrows 61 | generated_images = make_grid(generated_images, rows=nrows, cols=ncols) 62 | 63 | # Save the images 64 | samples_dir = os.path.join(output_dir, "samples") 65 | os.makedirs(samples_dir, exist_ok=True) 66 | generated_images.save(f"{samples_dir}/{epoch:04d}.png", quality=100, subsampling=0) -------------------------------------------------------------------------------- /src/common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Any 4 | from munch import DefaultMunch 5 | 6 | 7 | def get_configuration(yaml_path: str) -> Any: 8 | with open(yaml_path, "r") as f: 9 | yaml_dict = yaml.safe_load(f) 10 | return DefaultMunch.fromDict(yaml_dict) 11 | 12 | 13 | def extract_into_tensor(arr, timesteps, broadcast_shape): 14 | """ 15 | Extract values from a 1-D numpy array for a batch of indices. 16 | 17 | :param arr: the 1-D numpy array. 18 | :param timesteps: a tensor of indices into the array to extract. 19 | :param broadcast_shape: a larger shape of K dimensions with the batch 20 | dimension equal to the length of timesteps. 21 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
22 | """ 23 | res = arr.to(timesteps.device)[timesteps].float() 24 | while len(res.shape) < len(broadcast_shape): 25 | res = res[..., None] 26 | return res.expand(broadcast_shape) -------------------------------------------------------------------------------- /src/common/visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_bar(x, y, x_label, y_label, title, save_path, color='skyblue', y_labels=None, size=(14, 8)): 6 | # Set up the figure and axis 7 | fig, ax = plt.subplots() 8 | fig.set_size_inches(size[0], size[1]) 9 | colors = ['black', 'skyblue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'yellow', 'magenta'] 10 | width = 1.0 11 | num_sublists = 1 12 | 13 | if isinstance(y[0], list) or isinstance(y[0], np.ndarray): 14 | # Plotting multiple bars for each sublist 15 | num_sublists = len(y) 16 | width = 1.0 / (num_sublists+2) # Adjust the width based on the number of sublists 17 | x_positions = np.arange(len(x)) 18 | 19 | for i, sublist_y in enumerate(y): 20 | lbl = f"{y_label} {i+1}" 21 | if y_labels is not None: 22 | lbl = y_labels[i] 23 | ax.bar(x_positions + i * width, sublist_y, width=width, label=lbl, color=colors[i]) 24 | 25 | ax.legend() 26 | 27 | else: 28 | # Plotting a single bar graph 29 | ax.bar(x, y, color=color) 30 | 31 | # Adding labels and title 32 | ax.set_xlabel(x_label) 33 | ax.set_ylabel(y_label) 34 | ax.set_title(title) 35 | 36 | # Adjusting the appearance 37 | ax.spines['top'].set_visible(False) 38 | ax.spines['right'].set_visible(False) 39 | 40 | # Rotating the x-axis labels if necessary 41 | # plt.xticks(rotation=45) 42 | 43 | # Setting the x-axis tick positions and labels 44 | ax.set_xticks([i + width * (num_sublists-1) / 2 for i in range(len(x))]) 45 | ax.set_xticklabels(x) 46 | 47 | # Save the graph to disk 48 | plt.tight_layout() 49 | plt.savefig(save_path) 50 | plt.close() 51 | 52 | 53 | def plot_line_graph(x, y, x_label, y_label, title, save_path, color='skyblue', log_x=False, x_ticks=None, second_x=None, second_x_label=None, y_lim=None, size=(8, 6)): 54 | # Set up the figure and axis 55 | fig, ax = plt.subplots() 56 | fig.set_size_inches(size[0], size[1]) 57 | 58 | # Plotting the line graph with small markers 59 | ax.plot(x, y, marker='o', linestyle='-', color=color, markersize=3) 60 | 61 | # Setting the x-axis scale 62 | if log_x: 63 | ax.set_xscale('log') 64 | ax.set_xticks(x) 65 | ax.get_xaxis().set_major_formatter(plt.ScalarFormatter()) 66 | 67 | # Setting the x-axis tick positions and labels 68 | if x_ticks is not None: 69 | ax.set_xticks(x_ticks) 70 | ax.set_xticklabels(x_ticks) 71 | 72 | # Setting the y-axis scale 73 | if y_lim is not None: 74 | ax.set_ylim(y_lim[0], y_lim[1]) 75 | 76 | # Adding labels and title 77 | ax.set_xlabel(x_label) 78 | ax.set_ylabel(y_label) 79 | ax.set_title(title) 80 | 81 | # Adjusting the appearance 82 | ax.spines['top'].set_visible(False) 83 | ax.spines['right'].set_visible(False) 84 | 85 | if second_x is not None and second_x_label is not None: 86 | ax2 = ax.twiny() 87 | ax2.set_xlabel(second_x_label) 88 | ax2.set_xticks(np.arange(0, len(second_x))) 89 | ax2.set_xticklabels(second_x) 90 | ax2.spines['top'].set_visible(False) 91 | ax2.spines['right'].set_visible(False) 92 | 93 | plt.savefig(save_path) 94 | plt.close() 95 | 96 | 97 | def plot_line_std_graph(x, y, std, x_label, y_label, title, save_path, colors=None, log_x=False, x_ticks=None, x_labels=None, y_labels=None, 
y_lim=None, size=(12, 8), annotate_last=False): 98 | # Set up the figure and axis 99 | fig, ax = plt.subplots() 100 | # Make figure larger 101 | fig.set_size_inches(size[0], size[1]) 102 | 103 | # Plotting multiple line graphs 104 | if colors is None: 105 | if isinstance(y[0], list) or isinstance(y[0], np.ndarray): 106 | # Different colors for each sublist 107 | colors = ['skyblue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'black', 'yellow', 'magenta'] 108 | else: 109 | colors = 'skyblue' 110 | 111 | if not isinstance(y[0], list) and not isinstance(y[0], np.ndarray): 112 | y = [y] 113 | std = [std] 114 | colors = [colors] 115 | 116 | x = np.array(x) 117 | if len(x.shape) == 1: 118 | x = [x] * len(y) 119 | 120 | for i, y_vals in enumerate(y): 121 | color = colors[i] 122 | if y_labels is not None and not annotate_last: 123 | label = y_labels[i] 124 | else: 125 | label = None 126 | 127 | ax.plot(x[i], y_vals, linestyle='-', color=color, label=label) 128 | ax.fill_between(x[i], y_vals - std[i], y_vals + std[i], color=color, alpha=0.2) 129 | 130 | if annotate_last and y_labels is not None: 131 | y_arr = np.array(y)[:,-1] 132 | # Order the labels by the last value 133 | y_labels = [y_labels[i] for i in np.argsort(y_arr)] 134 | colors = [colors[i] for i in np.argsort(y_arr)] 135 | offset = 0.1 * y_lim[1] 136 | 137 | # Annotate the labels in order 138 | for i in range(len(y_labels)): 139 | y_val = offset * i + offset 140 | ax.annotate(y_labels[i], xy=(x[i][-1], y_val), xytext=(x[i][-1]+0.25, y_val), va='center', color=colors[i], weight='bold') 141 | 142 | # Setting the x-axis scale 143 | if log_x: 144 | ax.set_xscale('log') 145 | ax.set_xticks(x[0]) 146 | ax.get_xaxis().set_major_formatter(plt.ScalarFormatter()) 147 | 148 | # Setting the x-axis tick positions and labels 149 | if x_ticks is not None: 150 | if x_labels is None: 151 | ax.set_xticks(np.arange(0, len(x_ticks))) 152 | ax.set_xticklabels(x_ticks) 153 | else: 154 | ax.set_xticks(x_ticks) 155 | ax.set_xticklabels(x_labels) 156 | 157 | # Setting the y-axis scale 158 | if y_lim is not None: 159 | ax.set_ylim(y_lim[0], y_lim[1]) 160 | 161 | # Adding labels, legend, and title 162 | ax.set_xlabel(x_label) 163 | ax.set_ylabel(y_label) 164 | ax.set_title(title) 165 | if y_labels is not None and not annotate_last: 166 | ax.legend() 167 | 168 | # Adjusting the appearance 169 | ax.spines['top'].set_visible(False) 170 | ax.spines['right'].set_visible(False) 171 | 172 | # Save the graph to disk 173 | # plt.tight_layout() 174 | plt.savefig(save_path) 175 | plt.close() -------------------------------------------------------------------------------- /src/continual_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/continual_learning/__init__.py -------------------------------------------------------------------------------- /src/continual_learning/metrics/diffusion_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torch.nn as nn 5 | from torchvision.models import resnet18 6 | from torchmetrics.image.fid import FrechetInceptionDistance 7 | from avalanche.training.templates import SupervisedTemplate 8 | from avalanche.evaluation import Metric, PluginMetric 9 | from avalanche.evaluation.metric_results import MetricResult, MetricValue 10 | from 
avalanche.evaluation.metric_utils import get_metric_name 11 | 12 | 13 | class FIDMetric(Metric[float]): 14 | """ 15 | This metric computes the Frechet Inception Distance (FID) between two 16 | distributions of images. It uses the FID implementation from 17 | torchmetrics. 18 | """ 19 | def __init__(self, device='cuda'): 20 | self.device = device 21 | self.reset() 22 | 23 | @torch.no_grad() 24 | def update_true( 25 | self, 26 | true_y: torch.Tensor, 27 | ) -> None: 28 | true_y = torch.as_tensor(true_y) 29 | 30 | if true_y.min() < 0: 31 | true_y = (true_y + 1) / 2 32 | 33 | if true_y.shape[1] == 1: 34 | true_y = torch.cat([true_y] * 3, dim=1) 35 | 36 | self.fid.update(true_y, real=True) 37 | 38 | @torch.no_grad() 39 | def update_predicted( 40 | self, 41 | predicted_y: torch.Tensor, 42 | ) -> None: 43 | predicted_y = torch.as_tensor(predicted_y) 44 | 45 | if len(predicted_y) == 3: 46 | # Not expected from a dm output 47 | predicted_y = predicted_y[0] 48 | 49 | if predicted_y.shape[1] == 1: 50 | predicted_y = torch.cat([predicted_y] * 3, dim=1) 51 | 52 | self.fid.update(predicted_y, real=False) 53 | 54 | def result(self) -> float: 55 | return self.fid.compute().cpu().detach().item() 56 | 57 | def reset(self): 58 | self.fid = FrechetInceptionDistance(normalize=True, feature=2048) 59 | if type(self.device) == str and self.device.startswith('cuda') \ 60 | or type(self.device) == torch.device and self.device.type == 'cuda': 61 | self.fid.cuda() 62 | 63 | 64 | class DistributionMetrics(Metric[float]): 65 | """ 66 | This metric computes the Absolute Ratio Difference (ARD) and the KLD 67 | between two distributions of class frequencies. 68 | 69 | ARD^{b} = \frac{1}{2} \sum_{j=1}^{b} \lvert \rho_j^{T_b} - \rho_j^{\epsilon_b} \rvert 70 | """ 71 | def __init__(self): 72 | self.reset() 73 | 74 | @torch.no_grad() 75 | def update_true( 76 | self, 77 | true_y: torch.Tensor, 78 | ) -> None: 79 | true_y = torch.as_tensor(true_y) 80 | hist_true = torch.zeros(10) 81 | for i in range(10): 82 | hist_true[i] = torch.sum(true_y == i) 83 | self.hist_true += hist_true 84 | 85 | @torch.no_grad() 86 | def update_predicted( 87 | self, 88 | predicted_y: torch.Tensor, 89 | ) -> None: 90 | predicted_y = torch.as_tensor(predicted_y) 91 | hist_pred = torch.zeros(10) 92 | for i in range(10): 93 | hist_pred[i] = torch.sum(predicted_y == i) 94 | self.hist_pred += hist_pred 95 | 96 | def result(self): 97 | ratio_true = self.hist_true / torch.sum(self.hist_true) 98 | ratio_pred = self.hist_pred / torch.sum(self.hist_pred) 99 | 100 | ard = torch.sum(torch.abs(ratio_true - ratio_pred)) / 2.0 101 | kl = torch.nn.KLDivLoss(reduction='batchmean') 102 | kl_div = kl((ratio_pred + 1e-8).log()[None, :], (ratio_true + 1e-8)[None, :]) 103 | return ard.cpu().detach().item(), kl_div.cpu().detach().item(), ratio_true, ratio_pred 104 | 105 | def reset(self): 106 | self.hist_true = torch.zeros(10) 107 | self.hist_pred = torch.zeros(10) 108 | 109 | 110 | class DiffusionMetricsMetric(PluginMetric[float]): 111 | 112 | def __init__(self, device='cuda', weights_path: str = "weights/cnn_fmnist/", n_samples: int = 10000, num_classes: int = 10): 113 | self.classifier = resnet18() 114 | self.classifier.fc = nn.Linear(self.classifier.fc.in_features, num_classes) 115 | self.classifier.to(device) 116 | self.classifier.load_state_dict(torch.load(os.path.join(weights_path, "resnet.pth"), map_location=device)) 117 | self.classifier.eval() 118 | 119 | self.fid_metric = FIDMetric(device) 120 | self.dist_metrics = DistributionMetrics() 121 | 122 | self.n_samples =
n_samples 123 | 124 | def result(self) -> float: 125 | return self.fid_metric.result() 126 | 127 | def reset(self): 128 | self.fid_metric.reset() 129 | self.dist_metrics.reset() 130 | 131 | def before_training_exp(self, strategy: "SupervisedTemplate") -> None: 132 | super().before_training_exp(strategy) 133 | self.train_exp_id = strategy.experience.current_experience 134 | 135 | def after_training_exp(self, strategy: SupervisedTemplate) -> MetricResult: 136 | self.reset() 137 | 138 | batch_size = strategy.eval_mb_size 139 | num_samples = self.n_samples 140 | num_batches = num_samples // batch_size 141 | remaining_samples = num_samples % batch_size 142 | 143 | for _ in range(num_batches): 144 | predicted_samples = strategy.generate_samples(batch_size) 145 | if predicted_samples.shape[1] == 1: 146 | predicted_samples = torch.cat([predicted_samples] * 3, dim=1) 147 | with torch.no_grad(): 148 | classes = torch.argmax(self.classifier((predicted_samples - 0.5) * 2), dim=1) 149 | classes_np = classes.cpu().numpy() 150 | self.fid_metric.update_predicted(predicted_samples) 151 | self.dist_metrics.update_predicted(classes_np) 152 | 153 | if remaining_samples > 0: 154 | predicted_samples = strategy.generate_samples(remaining_samples) 155 | if predicted_samples.shape[1] == 1: 156 | predicted_samples = torch.cat([predicted_samples] * 3, dim=1) 157 | with torch.no_grad(): 158 | classes = torch.argmax(self.classifier((predicted_samples - 0.5) * 2), dim=1) 159 | classes_np = classes.cpu().numpy() 160 | self.fid_metric.update_predicted(predicted_samples) 161 | self.dist_metrics.update_predicted(classes_np) 162 | 163 | return super().after_training_exp(strategy) 164 | 165 | def after_eval_iteration(self, strategy: 'PluggableStrategy'): 166 | """ 167 | Update the FID and distribution metrics with the real 168 | samples and targets of the current eval batch 169 | """ 170 | super().after_eval_iteration(strategy) 171 | 172 | if strategy.experience.current_experience <= self.train_exp_id: 173 | self.fid_metric.update_true(strategy.mb_x) 174 | self.dist_metrics.update_true(strategy.mb_y) 175 | 176 | def _package_result(self, strategy): 177 | """ 178 | Package the result for logging 179 | """ 180 | add_exp = False 181 | plot_x_position = strategy.clock.train_iterations 182 | metrics = [] 183 | 184 | metric_value = self.fid_metric.result() 185 | metric_name = get_metric_name("stream_fid", strategy, 186 | add_experience=add_exp, 187 | add_task=True) 188 | metrics.append(MetricValue(self, metric_name, metric_value, 189 | plot_x_position)) 190 | 191 | ard_val, kl_val, hist_true, hist_pred = self.dist_metrics.result() 192 | ard_name = get_metric_name("stream_ard", strategy, 193 | add_experience=add_exp, 194 | add_task=True) 195 | kl_name = get_metric_name("stream_kld", strategy, 196 | add_experience=add_exp, 197 | add_task=True) 198 | hist_pred_name = get_metric_name("stream_hist_pred", strategy, 199 | add_experience=add_exp, 200 | add_task=True) 201 | hist_true_name = get_metric_name("stream_hist_true", strategy, 202 | add_experience=add_exp, 203 | add_task=True) 204 | 205 | metrics.append(MetricValue(self, ard_name, ard_val, plot_x_position)) 206 | metrics.append(MetricValue(self, kl_name, kl_val, plot_x_position)) 207 | metrics.append(MetricValue(self, hist_pred_name, hist_pred, plot_x_position)) 208 | metrics.append(MetricValue(self, hist_true_name, hist_true, plot_x_position)) 209 | 210 | return metrics 211 | 212 | def after_eval(self, strategy: 'PluggableStrategy'): 213 | return self._package_result(strategy) 214 | 215 | def __str__(self): 216 | return
"DiffusionMetrics" -------------------------------------------------------------------------------- /src/continual_learning/plugins.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from avalanche.core import SupervisedPlugin 3 | import torch 4 | 5 | 6 | class OldGeneratorManager: 7 | """ 8 | OldGeneratorManager is a class that manages the old generator 9 | when using the UpdatedGenerativeReplayPlugin with multiple 10 | strategies. Stores the old generator as a singleton. 11 | """ 12 | _old_generator = None 13 | _current_experience = None 14 | 15 | @classmethod 16 | def update_and_get_old_generator(cls, generator, current_experience): 17 | """ 18 | Sets the old generator and the current experience. 19 | If the current experience is the same as the previous one, 20 | the old generator is not updated. 21 | 22 | Returns the old generator. 23 | """ 24 | if cls._current_experience != current_experience: 25 | cls._current_experience = current_experience 26 | new_device = generator.device 27 | 28 | # Check if there is a second GPU available 29 | if torch.cuda.device_count() > 1: 30 | # If yes, move the generator to the second GPU 31 | new_device = torch.device(torch.device("cuda:1")) 32 | 33 | cls._old_generator = deepcopy(generator) 34 | cls._old_generator.to(new_device) 35 | cls._old_generator.eval() 36 | 37 | return cls._old_generator 38 | 39 | 40 | class UpdatedGenerativeReplayPlugin(SupervisedPlugin): 41 | """ 42 | Experience generative replay plugin. 43 | 44 | Updates the current mbatch of a strategy before training an experience 45 | by sampling a generator model and concatenating the replay data to the 46 | current batch. 47 | 48 | :param generator_strategy: In case the plugin is applied to a non-generative 49 | model (e.g. a simple classifier), this should contain an Avalanche strategy 50 | for a model that implements a 'generate' method 51 | (see avalanche.models.generator.Generator). Defaults to None. 52 | :param untrained_solver: if True we assume this is the beginning of 53 | a continual learning task and add replay data only from the second 54 | experience onwards, otherwise we sample and add generative replay data 55 | before training the first experience. Default to True. 56 | :param replay_size: The user can specify the batch size of replays that 57 | should be added to each data batch. By default each data batch will be 58 | matched with replays of the same number. 59 | :param increasing_replay_size: If set to True, each experience this will 60 | double the amount of replay data added to each data batch. The effect 61 | will be that the older experiences will gradually increase in importance 62 | to the final loss. 63 | :param T: Temperature parameter for scaling logits of the replay data. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | generator_strategy=None, 69 | untrained_solver: bool = True, 70 | replay_size: int = None, 71 | increasing_replay_size: bool = False, 72 | # T: float = 2.0, 73 | ): 74 | """ 75 | Init. 
76 | """ 77 | super().__init__() 78 | self.generator_strategy = generator_strategy 79 | if self.generator_strategy: 80 | self.generator = generator_strategy.model 81 | else: 82 | self.generator = None 83 | self.untrained_solver = untrained_solver 84 | self.model_is_generator = False 85 | self.replay_size = replay_size 86 | self.increasing_replay_size = increasing_replay_size 87 | 88 | def before_training(self, strategy, *args, **kwargs): 89 | """ 90 | Checks whether we are using a user defined external generator 91 | or we use the strategy's model as the generator. 92 | If the generator is None after initialization 93 | we assume that strategy.model is the generator. 94 | (e.g. this would be the case when training a VAE with 95 | generative replay) 96 | """ 97 | if not self.generator_strategy: 98 | self.generator_strategy = strategy 99 | self.generator = strategy.model 100 | self.model_is_generator = True 101 | 102 | def before_training_exp( 103 | self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs 104 | ): 105 | """ 106 | Make deep copies of generator and solver before training new experience. 107 | Then, generate replay data and store it in the strategy's replay buffer. 108 | """ 109 | if self.untrained_solver: 110 | return 111 | 112 | self.old_generator = OldGeneratorManager.update_and_get_old_generator( 113 | self.generator, strategy.experience.current_experience 114 | ) 115 | 116 | if not self.model_is_generator: 117 | self.old_model = deepcopy(strategy.model) 118 | 119 | if torch.cuda.device_count() > 1: 120 | self.old_model.to(torch.device(torch.device("cuda:1"))) 121 | 122 | self.old_model.eval() 123 | 124 | torch.cuda.empty_cache() 125 | 126 | def after_training_exp( 127 | self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs 128 | ): 129 | """ 130 | Set untrained_solver boolean to False after (the first) experience, 131 | in order to start training with replay data from the second experience. 132 | """ 133 | self.untrained_solver = False 134 | 135 | def before_training_iteration(self, strategy, **kwargs): 136 | """ 137 | Appending replay data to current minibatch before 138 | each training iteration. 
139 | """ 140 | if self.untrained_solver: 141 | return 142 | 143 | self.current_batch_size = len(strategy.mbatch[0]) 144 | 145 | if self.replay_size: 146 | number_replays_to_generate = self.replay_size 147 | else: 148 | if self.increasing_replay_size: 149 | number_replays_to_generate = len(strategy.mbatch[0]) * ( 150 | strategy.experience.current_experience 151 | ) 152 | else: 153 | number_replays_to_generate = len(strategy.mbatch[0]) 154 | 155 | replay = self.old_generator.generate(number_replays_to_generate).to(strategy.device) 156 | strategy.mbatch[0] = torch.cat([strategy.mbatch[0], replay], dim=0) 157 | 158 | # extend y with predicted labels (or mock labels if model==generator) 159 | if not self.model_is_generator: 160 | if torch.cuda.device_count() > 1: 161 | replay_in_old_device = replay.to(torch.device(torch.device("cuda:1"))) 162 | else: 163 | replay_in_old_device = replay 164 | 165 | with torch.no_grad(): 166 | replay_output = self.old_model(replay_in_old_device) 167 | else: 168 | # Mock labels: 169 | replay_output = torch.zeros(replay.shape[0]) 170 | 171 | replay_output = replay_output.to(strategy.device) 172 | 173 | if replay_output.ndim > 1: 174 | # If we are using a classification model, we one-hot encode the labels 175 | # of the training data (so we can use soft labels for the replay data) 176 | strategy.mbatch[1] = torch.nn.functional.one_hot( 177 | strategy.mbatch[1], num_classes=replay_output.shape[1] 178 | ).to(strategy.device) 179 | 180 | # Then we append the replay data to the current minibatch 181 | strategy.mbatch[1] = torch.cat( 182 | [strategy.mbatch[1], replay_output], dim=0 183 | ) 184 | 185 | # extend task id batch (we implicitly assume a task-free case) 186 | strategy.mbatch[-1] = torch.cat( 187 | [ 188 | strategy.mbatch[-1], 189 | torch.ones(replay.shape[0]).to(strategy.device) 190 | * strategy.mbatch[-1][0], 191 | ], 192 | dim=0, 193 | ) 194 | 195 | 196 | class TrainGeneratorAfterExpPlugin(SupervisedPlugin): 197 | """ 198 | TrainGeneratorAfterExpPlugin makes sure that after each experience of 199 | training the solver of a scholar model, we also train the generator on the 200 | data of the current experience. 201 | """ 202 | 203 | def after_training_exp(self, strategy, **kwargs): 204 | """ 205 | The training method expects an Experience object 206 | with a 'dataset' parameter. 
207 | """ 208 | for plugin in strategy.plugins: 209 | if type(plugin) is UpdatedGenerativeReplayPlugin: 210 | plugin.generator_strategy.train(strategy.experience) 211 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def create_dataloader(batch_size: int = 128, 7 | train_transform: transforms.Compose = None, 8 | test_transform: transforms.Compose = None, 9 | classes: list = None 10 | ): 11 | # load dataset from the hub 12 | train_dataset = load_dataset("cifar10", split="train") 13 | test_dataset = load_dataset("cifar10", split="test") 14 | 15 | # filter dataset 16 | if classes is not None: 17 | train_dataset = train_dataset.filter(lambda example: example["label"] in classes) 18 | test_dataset = test_dataset.filter(lambda example: example["label"] in classes) 19 | 20 | if train_transform is None: 21 | train_transform = transforms.Compose([ 22 | transforms.ToTensor(), 23 | ]) 24 | if test_transform is None: 25 | test_transform = transforms.Compose([ 26 | transforms.ToTensor(), 27 | ]) 28 | 29 | def apply_train_transforms(examples): 30 | examples["pixel_values"] = [train_transform(image) for image in examples["img"]] 31 | del examples["img"] 32 | return examples 33 | 34 | def apply_test_transforms(examples): 35 | examples["pixel_values"] = [test_transform(image) for image in examples["img"]] 36 | del examples["img"] 37 | return examples 38 | 39 | transformed_train_dataset = train_dataset.with_transform(apply_train_transforms) 40 | transformed_test_dataset = test_dataset.with_transform(apply_test_transforms) 41 | 42 | # create dataloader 43 | train_dataloader = DataLoader(transformed_train_dataset, batch_size=batch_size, shuffle=True) 44 | test_dataloader = DataLoader(transformed_test_dataset, batch_size=batch_size, shuffle=False) 45 | 46 | return train_dataloader, test_dataloader -------------------------------------------------------------------------------- /src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def create_dataloader(batch_size: int = 128, transform: transforms.Compose = None, classes: list = None): 7 | # load dataset from the hub 8 | dataset = load_dataset("cifar100") 9 | 10 | # filter dataset 11 | if classes is not None: 12 | dataset = dataset.filter(lambda example: example["label"] in classes) 13 | 14 | if transform is None: 15 | transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | ]) 18 | 19 | def apply_transforms(examples): 20 | examples["pixel_values"] = [transform(image) for image in examples["img"]] 21 | del examples["img"] 22 | return examples 23 | 24 | transformed_dataset = dataset.with_transform(apply_transforms) 25 | 26 | # create dataloader 27 | train_dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True) 28 | test_dataloader = 
DataLoader(transformed_dataset["test"], batch_size=batch_size, shuffle=False) 29 | 30 | return train_dataloader, test_dataloader -------------------------------------------------------------------------------- /src/datasets/fashion_mnist.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def create_dataloader(batch_size: int = 128, transform: transforms.Compose = None, classes: list = None): 7 | # load dataset from the hub 8 | dataset = load_dataset("fashion_mnist") 9 | 10 | # filter dataset 11 | if classes is not None: 12 | dataset = dataset.filter(lambda example: example["label"] in classes) 13 | 14 | if transform is None: 15 | transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | ]) 18 | 19 | def apply_transforms(examples): 20 | examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]] 21 | del examples["image"] 22 | return examples 23 | 24 | transformed_dataset = dataset.with_transform(apply_transforms) 25 | 26 | # create dataloader 27 | train_dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True) 28 | test_dataloader = DataLoader(transformed_dataset["test"], batch_size=batch_size, shuffle=False) 29 | 30 | return train_dataloader, test_dataloader -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def create_dataloader(batch_size: int = 128, transform: transforms.Compose = None, classes: list = None): 7 | # load dataset from the hub 8 | dataset = load_dataset("mnist") 9 | 10 | # filter dataset 11 | if classes is not None: 12 | dataset = dataset.filter(lambda example: example["label"] in classes) 13 | 14 | if transform is None: 15 | transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | ]) 18 | 19 | def apply_transforms(examples): 20 | examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]] 21 | del examples["image"] 22 | return examples 23 | 24 | transformed_dataset = dataset.with_transform(apply_transforms) 25 | 26 | # create dataloader 27 | train_dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True) 28 | test_dataloader = DataLoader(transformed_dataset["test"], batch_size=batch_size, shuffle=False) 29 | 30 | return train_dataloader, test_dataloader -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from typing import Tuple 4 | 5 | 6 | class SimpleCNN(nn.Module): 7 | """ 8 | Convolutional Neural Network 9 | """ 10 | 11 | def __init__(self, n_channels: int = 3, num_classes: int = 10): 12 | super(SimpleCNN, self).__init__() 13 | 14 | self.features = nn.Sequential( 15 | nn.Conv2d(n_channels, 32, kernel_size=3, 
stride=1, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(32, 32, kernel_size=3, padding=0), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=2, stride=2), 20 | nn.Dropout(p=0.25), 21 | nn.Conv2d(32, 64, kernel_size=3, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(64, 128, kernel_size=3, padding=0), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=2, stride=2), 26 | nn.Dropout(p=0.25), 27 | nn.Conv2d(128, 128, kernel_size=1, padding=0), 28 | nn.ReLU(inplace=True), 29 | nn.AdaptiveMaxPool2d(1), 30 | nn.Dropout(p=0.25), 31 | ) 32 | self.classifier = nn.Sequential(nn.Linear(128, num_classes)) 33 | 34 | def forward(self, x): 35 | x = self.features(x) 36 | x = x.view(x.size(0), -1) 37 | x = self.classifier(x) 38 | return x -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Copyright (c) 2021 ContinualAI. # 3 | # Copyrights licensed under the MIT License. # 4 | # See the accompanying LICENSE file for terms. # 5 | # # 6 | # Date: 03-03-2022 # 7 | # Author: Florian Mies # 8 | # Website: https://github.com/travela # 9 | ################################################################################ 10 | 11 | """ 12 | 13 | File to place any kind of generative models 14 | and their respective helper functions. 15 | 16 | """ 17 | 18 | import torch 19 | import torch.nn as nn 20 | from abc import abstractmethod 21 | from torchvision import transforms 22 | from avalanche.models.utils import MLP, Flatten 23 | from avalanche.models.base_model import BaseModel 24 | 25 | 26 | class Generator(BaseModel): 27 | """ 28 | A base abstract class for generators 29 | """ 30 | 31 | @abstractmethod 32 | def generate(self, batch_size=None, condition=None): 33 | """ 34 | Lets the generator draw random samples. 35 | Output is either a single sample or, if provided, 36 | a batch of samples of size "batch_size" 37 | 38 | :param batch_size: Number of samples to generate 39 | :param condition: Possible condition for a conditional generator 40 | (e.g. a class label) 41 | """ 42 | 43 | 44 | ########################### 45 | # VARIATIONAL AUTOENCODER # 46 | ########################### 47 | 48 | 49 | class VAEMLPEncoder(nn.Module): 50 | """ 51 | Encoder part of the VAE, computes the latent representations of the input.
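        The forward pass returns a tuple (features, mean, logvar), where mean
        and logvar parameterize the diagonal Gaussian posterior over the
        latent space (see forward below).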
52 | 53 | :param shape: Shape of the input to the network: (channels, height, width) 54 | :param latent_dim: Dimension of the latent space 55 | """ 56 | 57 | def __init__(self, shape, units_dim: tuple = (400, 400), latent_dim: int = 100, use_bn: bool = False): 58 | super(VAEMLPEncoder, self).__init__() 59 | 60 | flatten_size = torch.Size(shape).numel() 61 | prev_size = flatten_size 62 | self.encode = [Flatten(), ] 63 | 64 | for i in range(len(units_dim)): 65 | self.encode.append(nn.Linear(prev_size, units_dim[i])) 66 | if use_bn: 67 | self.encode.append(nn.BatchNorm1d(units_dim[i])) 68 | self.encode.append(nn.ReLU()) 69 | prev_size = units_dim[i] 70 | 71 | self.encode = nn.Sequential(*self.encode) 72 | self.z_mean = nn.Linear(prev_size, latent_dim) 73 | self.z_log_var = nn.Linear(prev_size, latent_dim) 74 | 75 | def forward(self, x, y=None): 76 | x = self.encode(x) 77 | if torch.isnan(x).any(): 78 | print("NAN in VAE") 79 | mean = self.z_mean(x) 80 | logvar = self.z_log_var(x) 81 | return x, mean, logvar 82 | 83 | 84 | class VAEMLPDecoder(nn.Module): 85 | """ 86 | Decoder part of the VAE. Reverses the Encoder. 87 | 88 | :param shape: Shape of output: (channels, height, width). 89 | :param latent_dim: Dimension of the input latent vectors. 90 | """ 91 | 92 | def __init__(self, shape, units_dim: tuple = (400, 400), latent_dim: int = 100, use_bn: bool = False): 93 | super(VAEMLPDecoder, self).__init__() 94 | flattened_size = torch.Size(shape).numel() 95 | prev_size = latent_dim 96 | self.shape = shape 97 | self.decode = [] 98 | 99 | for i in range(len(units_dim)): 100 | self.decode.append(nn.Linear(prev_size, units_dim[i])) 101 | if use_bn: 102 | self.decode.append(nn.BatchNorm1d(units_dim[i])) 103 | self.decode.append(nn.ReLU()) 104 | prev_size = units_dim[i] 105 | 106 | self.decode.append(nn.Linear(prev_size, flattened_size)) 107 | self.decode.append(nn.Sigmoid()) 108 | self.decode = nn.Sequential(*self.decode) 109 | # self.inv_trans = transforms.Compose( 110 | # [transforms.Normalize((0.1307,), (0.3081,))] 111 | # ) 112 | 113 | def forward(self, z, y=None): 114 | if y is None: 115 | # return self.inv_trans(self.decode(z).view(-1, *self.shape)) 116 | return self.decode(z).view(-1, *self.shape) 117 | else: 118 | # return self.inv_trans(self.decode(torch.cat((z, y), dim=1)).view(-1, *self.shape)) 119 | return self.decode(torch.cat((z, y), dim=1)).view(-1, *self.shape) 120 | 121 | 122 | class MlpVAE(Generator, nn.Module): 123 | """ 124 | Variational autoencoder module: 125 | fully-connected and suited for any input shape and type. 126 | 127 | The encoder only computes the latent representations, 128 | and we then have two possible output heads: 129 | one for the usual output distribution and one for classification. 130 | The latter is an extension of the conventional VAE and incorporates 131 | a classifier into the network. 132 | More details can be found in: https://arxiv.org/abs/1809.10635 133 | """ 134 | 135 | def __init__(self, shape, encoder_dims, decoder_dims, latent_dim, n_classes=10, device="cpu"): 136 | """ 137 | :param shape: Shape of each input sample 138 | :param latent_dim: Dimension of the encoder's latent space.
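        :param encoder_dims: Sizes of the encoder's hidden layers.
        :param decoder_dims: Sizes of the decoder's hidden layers.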
139 | :param n_classes: Number of classes - 140 | defines the classification head's dimension 141 | """ 142 | super(MlpVAE, self).__init__() 143 | assert latent_dim % 2 == 0, "Latent dimension must be even" 144 | 145 | self.dim = latent_dim 146 | if device is None: 147 | device = 'cpu' 148 | 149 | self.device = torch.device(device) 150 | self.encoder = VAEMLPEncoder(shape, encoder_dims, latent_dim) 151 | self.decoder = VAEMLPDecoder(shape, decoder_dims, latent_dim) 152 | self.classification = nn.Linear(encoder_dims[-1], n_classes) 153 | 154 | def get_features(self, x): 155 | """ 156 | Get features from the encoder part given input x 157 | """ 158 | return self.encoder(x) 159 | 160 | def generate(self, batch_size=None): 161 | """ 162 | Generate random samples. 163 | Output is either a single sample if batch_size=None, 164 | else it is a batch of samples of size "batch_size". 165 | """ 166 | z = ( 167 | torch.randn((batch_size, self.dim)).to(self.device) 168 | if batch_size 169 | else torch.randn((1, self.dim)).to(self.device) 170 | ) 171 | 172 | with torch.no_grad(): 173 | res = self.decoder(z) 174 | 175 | if not batch_size: 176 | res = res.squeeze(0) 177 | return res 178 | 179 | def sampling(self, mean, logvar): 180 | """ 181 | VAE 'reparametrization trick' 182 | """ 183 | eps = torch.randn(mean.shape).to(self.device) 184 | sigma = torch.exp(0.5 * logvar)  # std = exp(logvar / 2), consistent with the KL term in VAE_loss 185 | return mean + eps * sigma 186 | 187 | def forward(self, x_o): 188 | """ 189 | Forward pass. 190 | """ 191 | x, mean, logvar = self.encoder(x_o) 192 | z = self.sampling(mean, logvar) 193 | x_hat = self.decoder(z) 194 | 195 | if torch.isnan(x_hat).any(): 196 | print("NAN in VAE") 197 | 198 | return x_hat, mean, logvar 199 | 200 | 201 | BCE_loss = nn.BCELoss(reduction="mean") 202 | MSE_loss = nn.MSELoss(reduction="mean") 203 | 204 | 205 | def VAE_loss(X, forward_output): 206 | """ 207 | Loss function of a VAE using binary cross-entropy as the reconstruction loss. 208 | This is the criterion for the VAE training loop. 209 | 210 | :param X: Original input batch. 211 | :param forward_output: Return value of a VAE.forward() call. 212 | Triplet consisting of (X_hat, mean, logvar), i.e.
213 | (Reconstructed input after subsequent Encoder and Decoder, 214 | mean of the VAE output distribution, 215 | logvar of the VAE output distribution) 216 | """ 217 | from torch.nn import functional as F 218 | X_hat, mean, logvar = forward_output 219 | batch_size = X.shape[0] 220 | 221 | if batch_size == 0: 222 | return torch.tensor(0.0) 223 | 224 | # reconstruction_loss = MSE_loss(X_hat, X) 225 | # reconstruction_loss /= X.shape[1] * X.shape[2] * X.shape[3] 226 | reconstruction_loss = F.binary_cross_entropy(input=X_hat.view(batch_size, -1), target=X.view(batch_size, -1), 227 | reduction='none') 228 | reconstruction_loss = torch.mean(reconstruction_loss, dim=1) 229 | reconstruction_loss = torch.mean(reconstruction_loss) 230 | KL_divergence = 0.5 * torch.sum(-1 - logvar + torch.exp(logvar) + mean ** 2, dim=1) 231 | KL_divergence = torch.mean(KL_divergence) 232 | KL_divergence /= X.shape[1] * X.shape[2] * X.shape[3] 233 | return reconstruction_loss + KL_divergence 234 | 235 | 236 | __all__ = ["MlpVAE"] 237 | -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from diffusers.utils import randn_tensor 20 | from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 21 | 22 | from src.schedulers.scheduler_ddim import DDIMScheduler 23 | 24 | 25 | class DDIMPipeline(DiffusionPipeline): 26 | r""" 27 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 28 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 29 | 30 | Parameters: 31 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 32 | scheduler ([`SchedulerMixin`]): 33 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 34 | [`DDPMScheduler`], or [`DDIMScheduler`]. 
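    Example (a hedged sketch, not code from this repository; it assumes a
    trained `UNet2DModel` named `unet` and a configured `DDIMScheduler`
    named `scheduler` already exist):

        >>> pipeline = DDIMPipeline(unet=unet, scheduler=scheduler)
        >>> # a (4, C, H, W) tensor of images rescaled to [0, 1]
        >>> images = pipeline(batch_size=4, num_inference_steps=20, output_type="torch")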
35 | """ 36 | 37 | def __init__(self, unet, scheduler): 38 | super().__init__() 39 | 40 | # make sure scheduler can always be converted to DDIM 41 | scheduler = DDIMScheduler.from_config(scheduler.config) 42 | 43 | self.register_modules(unet=unet, scheduler=scheduler) 44 | 45 | @torch.no_grad() 46 | def __call__( 47 | self, 48 | batch_size: int = 1, 49 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 50 | eta: float = 0.0, 51 | num_inference_steps: int = 50, 52 | target_steps: Union[List[int], int] = 0, 53 | use_clipped_model_output: Optional[bool] = None, 54 | output_type: Optional[str] = "torch", 55 | return_dict: bool = True, 56 | ) -> Union[ImagePipelineOutput, Tuple, torch.Tensor]: 57 | r""" 58 | Args: 59 | batch_size (`int`, *optional*, defaults to 1): 60 | The number of images to generate. 61 | generator (`torch.Generator`, *optional*): 62 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 63 | to make generation deterministic. 64 | eta (`float`, *optional*, defaults to 0.0): 65 | The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). 66 | num_inference_steps (`int`, *optional*, defaults to 50): 67 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 68 | expense of slower inference. 69 | target_step (`int`, *optional*, defaults to 0): 70 | The step at which to stop the denoising process. If `target_step` is 0, the denoising process will 71 | continue until the last step. If `target_step` is greater than 0, the denoising process will stop at 72 | `target_step` and the image will be denoised to the corresponding noise level. 73 | use_clipped_model_output (`bool`, *optional*, defaults to `None`): 74 | if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed 75 | downstream to the scheduler. So use `None` for schedulers which don't support this argument. 76 | output_type (`str`, *optional*, defaults to `"pil"`): 77 | The output format of the generate image. Choose between 78 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 79 | return_dict (`bool`, *optional*, defaults to `True`): 80 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 81 | 82 | Returns: 83 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 84 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 85 | """ 86 | 87 | # Sample gaussian noise to begin loop 88 | if isinstance(self.unet.config.sample_size, int): 89 | image_shape = ( 90 | batch_size, 91 | self.unet.config.in_channels, 92 | self.unet.config.sample_size, 93 | self.unet.config.sample_size, 94 | ) 95 | else: 96 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 97 | 98 | if isinstance(generator, list) and len(generator) != batch_size: 99 | raise ValueError( 100 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 101 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
102 | ) 103 | 104 | if isinstance(target_steps, torch.Tensor): 105 | target_steps = target_steps.cpu().numpy() 106 | 107 | image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) 108 | 109 | # set step values 110 | self.scheduler.set_timesteps(num_inference_steps, target_steps) 111 | 112 | for i, t in self.progress_bar(enumerate(self.scheduler.timesteps)): 113 | # 1. predict noise model_output 114 | model_output = self.unet(image, t.to(self.device)).sample 115 | 116 | # if torch.isnan(model_output).any(): 117 | # print("WARNING: NaNs encountered in model output.") 118 | 119 | # 2. predict previous mean of image x_t-1 and add variance depending on eta 120 | # eta corresponds to η in paper and should be between [0, 1] 121 | # do x_t -> x_t-1 122 | new_image = self.scheduler.step( 123 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 124 | ).prev_sample.type(self.unet.dtype) 125 | 126 | # 3. Update images with timestep different from the previous one 127 | if i == 0: 128 | image = new_image 129 | else: 130 | mask = t == self.scheduler.timesteps[i-1] 131 | # Make mask broadcastable 132 | if not mask.shape: 133 | mask = mask[None] 134 | mask = mask.unsqueeze(1).unsqueeze(1).unsqueeze(1).to(new_image.device) 135 | image = torch.where(mask, image, new_image) 136 | 137 | # if torch.isnan(image).any(): 138 | # print("WARNING: NaNs encountered in image.") 139 | 140 | if output_type == "torch_raw": 141 | return image 142 | 143 | image = (image / 2 + 0.5).clamp(0, 1) 144 | 145 | if output_type == "torch": 146 | return image 147 | 148 | # image = image.clamp(0, 1) 149 | image = image.cpu().permute(0, 2, 3, 1).numpy() 150 | if output_type == "pil": 151 | image = self.numpy_to_pil(image) 152 | 153 | if not return_dict: 154 | return (image,) 155 | 156 | return ImagePipelineOutput(images=image) -------------------------------------------------------------------------------- /src/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/schedulers/__init__.py -------------------------------------------------------------------------------- /src/standard_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/standard_training/__init__.py -------------------------------------------------------------------------------- /src/standard_training/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/standard_training/evaluators/__init__.py -------------------------------------------------------------------------------- /src/standard_training/evaluators/base_evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseEvaluator(ABC): 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def evaluate(self, *args, **kwargs): 10 | raise NotImplementedError -------------------------------------------------------------------------------- /src/standard_training/evaluators/generative_evaluator.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from torchmetrics.image.fid import FrechetInceptionDistance 7 | 8 | from src.common.diffusion_utils import make_grid 9 | from src.standard_training.evaluators.base_evaluator import BaseEvaluator 10 | 11 | 12 | class GenerativeModelEvaluator(BaseEvaluator): 13 | def __init__(self, device: str = "cuda", save_path="results", save_images: int = 0, fid_feature_size: int = 2048): 14 | self.device = device 15 | self.save_path = save_path 16 | self.fid_feature_size = fid_feature_size 17 | self.save_images = save_images 18 | self.real_features_computed = False 19 | self.fid = FrechetInceptionDistance(normalize=True, 20 | reset_real_features=False, 21 | feature=self.fid_feature_size) 22 | 23 | if self.device == "cuda": 24 | self.fid.cuda() 25 | 26 | @torch.no_grad() 27 | def evaluate_fid(self, model, dataloader, epoch: int = 0, fid_images: int = 10000, gensteps: int = 20) -> float: 28 | print("Evaluating FID...") 29 | self.fid.reset() 30 | batch_size = dataloader.batch_size 31 | 32 | if not self.real_features_computed: 33 | print("Processing real images...") 34 | 35 | for batch in tqdm(dataloader): 36 | batch = batch["pixel_values"].to(self.device) 37 | 38 | if batch.min() < 0: 39 | batch = (batch + 1) / 2 40 | 41 | if batch.shape[1] == 1: 42 | batch = torch.cat([batch] * 3, dim=1) 43 | 44 | self.fid.update(batch, real=True) 45 | 46 | self.real_features_computed = True 47 | 48 | print("Processing generated images...") 49 | 50 | if fid_images == 0: 51 | images_to_generate = len(dataloader.dataset) 52 | else: 53 | images_to_generate = fid_images 54 | 55 | bar = tqdm(total=images_to_generate // batch_size+1) 56 | while images_to_generate > batch_size: 57 | pred = model.generate(batch_size, generation_steps=gensteps, output_type="torch") 58 | 59 | if pred.shape[1] == 1: 60 | pred = torch.cat([pred] * 3, dim=1) 61 | 62 | self.fid.update(pred, real=False) 63 | images_to_generate -= batch_size 64 | bar.update(1) 65 | 66 | pred = model.generate(images_to_generate, generation_steps=gensteps, output_type="torch") 67 | 68 | if pred.shape[1] == 1: 69 | pred = torch.cat([pred] * 3, dim=1) 70 | 71 | self.fid.update(pred, real=False) 72 | bar.update(1) 73 | bar.close() 74 | 75 | fid = self.fid.compute().cpu().detach().item() 76 | print(f"Evaluation complete. FID: {fid}") 77 | 78 | if self.save_images > 0: 79 | generated_images = model.generate(self.save_images, generation_steps=gensteps, output_type="torch") 80 | # To PIL image 81 | generated_images = generated_images.mul(255).to(torch.uint8) 82 | generated_images = generated_images.permute(0, 2, 3, 1).cpu().numpy() 83 | generated_images = [Image.fromarray(img.squeeze()) for img in generated_images] 84 | nrows = int(self.save_images**0.5) 85 | ncols = self.save_images // nrows + self.save_images % nrows 86 | generated_images = make_grid(generated_images, rows=nrows, cols=ncols) 87 | out_dir = os.path.join(self.save_path, "samples") 88 | os.makedirs(out_dir, exist_ok=True) 89 | generated_images.save(os.path.join(out_dir, f"samples_epoch_{epoch}_gensteps_{gensteps}_fid_{fid:.4f}.png")) 90 | 91 | return fid 92 | 93 | def evaluate(self, model, dataloader, epoch: int = 0, fid_images: int = 10000, gensteps: int = 20, compute_auc: bool = True) -> dict: 94 | """ 95 | Computes the FID score for the given model and dataloader. 
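        FID compares statistics of Inception features between the real and the
        generated images; the real-image features are computed once and cached
        across calls (reset_real_features=False in the constructor).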
96 | 97 | If compute_auc is True, then the FID score is computed for gensteps 2, 5, 10, 20 98 | and the AUC is computed. 99 | 100 | Args: 101 | model (torch.nn.Module): The model to evaluate. 102 | dataloader (torch.utils.data.DataLoader): The dataloader to use for evaluation. 103 | epoch (int, optional): The epoch number. Defaults to 0. 104 | fid_images (int, optional): The number of images to use for FID computation. Defaults to 10000. 105 | gensteps (int, optional): The number of steps to use for generation. Defaults to 20. 106 | compute_auc (bool, optional): Whether to compute the AUC or not. Defaults to True. 107 | 108 | Returns: 109 | dict: A dictionary containing the FID score and the AUC (if compute_auc is True). 110 | """ 111 | if not compute_auc: 112 | fid = self.evaluate_fid(model, dataloader, epoch, fid_images, gensteps) 113 | return {"fid": fid} 114 | 115 | gensteps_list = [2, 5, 10, 20] 116 | fid_list = [] 117 | for genstep in gensteps_list: 118 | fid = self.evaluate_fid(model, dataloader, epoch, fid_images, genstep) 119 | fid_list.append(fid) 120 | 121 | auc = torch.trapz(torch.asarray(fid_list), x=torch.asarray([2, 5, 10, 20])).item() 122 | results = { 123 | "auc": auc, 124 | } 125 | for i, genstep in enumerate(gensteps_list): 126 | results[f"fid_{genstep}"] = fid_list[i] 127 | 128 | return results -------------------------------------------------------------------------------- /src/standard_training/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/standard_training/losses/__init__.py -------------------------------------------------------------------------------- /src/standard_training/losses/diffusion_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from abc import ABC 4 | from typing import Optional 5 | from torch.nn import functional as F 6 | from diffusers import SchedulerMixin 7 | 8 | from src.common.utils import extract_into_tensor 9 | 10 | 11 | class DiffusionLoss(ABC): 12 | 13 | def __init__(self, scheduler: SchedulerMixin): 14 | self.scheduler = scheduler 15 | 16 | def __call__(self, target: torch.Tensor, pred: torch.Tensor, timesteps: Optional[torch.Tensor] = None): 17 | raise NotImplementedError 18 | 19 | 20 | class MSELoss(DiffusionLoss): 21 | 22 | def __init__(self, scheduler: SchedulerMixin): 23 | super().__init__(scheduler) 24 | 25 | def __call__(self, target: torch.Tensor, pred: torch.Tensor, timesteps: Optional[torch.Tensor] = None): 26 | loss = F.mse_loss(target, pred) 27 | return loss 28 | 29 | 30 | class SmoothL1Loss(DiffusionLoss): 31 | 32 | def __init__(self, scheduler: SchedulerMixin): 33 | super().__init__(scheduler) 34 | 35 | def __call__(self, target: torch.Tensor, pred: torch.Tensor, timesteps: Optional[torch.Tensor] = None): 36 | loss = F.smooth_l1_loss(target, pred) 37 | return loss 38 | 39 | 40 | class MinSNRLoss(DiffusionLoss): 41 | """ 42 | Based on https://github.com/TiankaiHang/Min-SNR-Diffusion-Training 43 | """ 44 | 45 | def __init__(self, scheduler: SchedulerMixin, k: int = 5, divide_by_snr: bool = True, reduction: str = "mean"): 46 | super().__init__(scheduler) 47 | self.k = k 48 | self.divide_by_snr = divide_by_snr 49 | self.reduction = reduction 50 | 51 | def __call__(self, target: torch.Tensor, pred: torch.Tensor, timesteps: Optional[torch.Tensor] = None): 52 | assert timesteps is not None 53 | 54 | 
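        # Min-SNR weighting: with the scheduler's cumulative alphas,
        # alpha_t = sqrt(alphas_cumprod) and sigma_t = sqrt(1 - alphas_cumprod),
        # so SNR(t) = (alpha_t / sigma_t) ** 2. Each sample's loss is weighted
        # by min(SNR(t), k), optionally divided by SNR(t) for noise prediction.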
sqrt_alphas_cumprod = (self.scheduler.alphas_cumprod ** 0.5) 55 | sqrt_one_minus_alpha_prod = (1 - self.scheduler.alphas_cumprod) ** 0.5 56 | alpha = extract_into_tensor( 57 | sqrt_alphas_cumprod, timesteps, timesteps.shape) 58 | sigma = extract_into_tensor( 59 | sqrt_one_minus_alpha_prod, timesteps, timesteps.shape) 60 | snr = (alpha / sigma) ** 2 61 | mse_loss_weight = torch.stack( 62 | [snr, self.k * torch.ones_like(snr)], dim=1).min(dim=1)[0]  # ones_like(snr) keeps the stack in floating point 63 | 64 | if self.divide_by_snr: 65 | mse_loss_weight = mse_loss_weight / snr 66 | 67 | loss = mse_loss_weight * F.mse_loss(target, pred, reduction="none").mean(dim=list(range(1, target.ndim)))  # per-sample MSE so the per-timestep weights apply 68 | 69 | if self.reduction == "mean": 70 | loss = loss.mean() 71 | elif self.reduction == "sum": 72 | loss = loss.sum() 73 | else: 74 | raise NotImplementedError 75 | 76 | return loss 77 | -------------------------------------------------------------------------------- /src/standard_training/trackers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/standard_training/trackers/__init__.py -------------------------------------------------------------------------------- /src/standard_training/trackers/base_tracker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script from https://github.com/ArjanCodes/2021-data-science-refactor/blob/main/after/ds/tracking.py 3 | """ 4 | from enum import Enum, auto 5 | import torch 6 | from typing import List, Protocol 7 | import numpy as np 8 | 9 | 10 | class Stage(Enum): 11 | TRAIN = auto() 12 | VAL = auto() 13 | TEST = auto() 14 | 15 | 16 | class ExperimentTracker(Protocol): 17 | def save_checkpoint(self, epoch: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer): 18 | """Saves a checkpoint of the model.""" 19 | 20 | def set_stage(self, stage: Stage): 21 | """Sets the current stage of the experiment.""" 22 | 23 | def add_batch_metric(self, name: str, value: float, step: int, commit: bool = True): 24 | """Implements logging a batch-level metric.""" 25 | 26 | def add_epoch_metric(self, name: str, value: float, step: int): 27 | """Implements logging an epoch-level metric.""" 28 | 29 | def add_epoch_confusion_matrix( 30 | self, y_true: List[np.ndarray], y_pred: List[np.ndarray], step: int 31 | ): 32 | """Implements logging a confusion matrix at epoch-level.""" 33 | 34 | def flush(self): 35 | """Implements flushing the experiment tracker.""" 36 | 37 | def finish(self): 38 | """Implements finishing the experiment tracker.""" -------------------------------------------------------------------------------- /src/standard_training/trackers/csv_tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from src.standard_training.trackers.base_tracker import Stage, ExperimentTracker 5 | 6 | 7 | class CSVTracker(ExperimentTracker): 8 | """ 9 | Creates a tracker that implements the ExperimentTracker protocol and logs to CSV files.
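    Batch-level metrics are written to train.csv only during the TRAIN stage,
    while epoch-level metrics are written to test.csv in the other stages
    (see the write methods below).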
10 | """ 11 | 12 | def __init__(self, configs: dict, results_path: str): 13 | self.stage = Stage.TRAIN 14 | self.csv_file_train = open(os.path.join(results_path, "train.csv"), "w") 15 | self.csv_file_test = open(os.path.join(results_path, "test.csv"), "w") 16 | 17 | # Save configs 18 | with open(os.path.join(results_path, "configs.json"), "w") as f: 19 | json.dump(configs, f) 20 | 21 | # Create CSV files 22 | self.csv_file_train.write("step,metric,value\n") 23 | self.csv_file_test.write("epoch,metric,value\n") 24 | 25 | def set_stage(self, stage: Stage): 26 | self.stage = stage 27 | 28 | def add_batch_metric(self, name: str, value: float, step: int, commit: bool = True): 29 | if self.stage == Stage.TRAIN: 30 | self.csv_file_train.write(f"{step},{name},{value}\n") 31 | else: 32 | pass 33 | 34 | def add_epoch_metric(self, name: str, value: float, step: int, commit: bool = True): 35 | if self.stage == Stage.TRAIN: 36 | pass 37 | else: 38 | self.csv_file_test.write(f"{step},{name},{value}\n") 39 | 40 | def flush(self): 41 | self.csv_file_train.flush() 42 | self.csv_file_test.flush() 43 | 44 | def finish(self): 45 | self.csv_file_train.close() 46 | self.csv_file_test.close() -------------------------------------------------------------------------------- /src/standard_training/trackers/wandb_tracker.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | from typing import List 4 | 5 | from src.standard_training.trackers.base_tracker import Stage, ExperimentTracker 6 | 7 | 8 | class WandbTracker(ExperimentTracker): 9 | """ 10 | Creates a tracker that implements the ExperimentTracker protocol and logs to wandb. 11 | """ 12 | 13 | def __init__(self, configs: dict, experiment_name: str, project_name: str, tags: List[str] = None): 14 | self.stage = Stage.TRAIN 15 | 16 | self.run = wandb.init(project=project_name, 17 | name=experiment_name, 18 | config=configs, 19 | tags=tags) 20 | 21 | wandb.define_metric("batch_step") 22 | wandb.define_metric("epoch") 23 | 24 | # TODO: Make metrics dynamic 25 | for metric in ["loss"]: 26 | for stage in Stage: 27 | wandb.define_metric(f"{stage.name}/batch_{metric}", step_metric='batch_step') 28 | wandb.define_metric(f"{stage.name}/epoch_{metric}", step_metric='epoch') 29 | 30 | wandb.define_metric(f"{Stage.TEST}/batch_fid", step_metric='batch_step') 31 | wandb.define_metric(f"{Stage.TEST}/epoch_fid", step_metric='epoch') 32 | 33 | def set_stage(self, stage: Stage): 34 | self.stage = stage 35 | 36 | def add_batch_metric(self, name: str, value: float, step: int, commit: bool = True): 37 | wandb.log({f"{self.stage.name}/batch_{name}": value, "batch_step": step}, commit=commit) 38 | 39 | def add_epoch_metric(self, name: str, value: float, step: int, commit: bool = True): 40 | wandb.log({f"{self.stage.name}/epoch_{name}": value, "epoch": step}, commit=commit) 41 | 42 | def flush(self): 43 | pass 44 | 45 | def finish(self): 46 | wandb.finish() -------------------------------------------------------------------------------- /src/standard_training/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/src/standard_training/trainers/__init__.py -------------------------------------------------------------------------------- /src/standard_training/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import torch 3 | 4 | from typing import Any, Optional 5 | from abc import ABC 6 | 7 | from src.standard_training.evaluators.base_evaluator import BaseEvaluator 8 | 9 | 10 | class BaseTrainer(ABC): 11 | 12 | def __init__(self, 13 | model: torch.nn.Module, 14 | optimizer: torch.optim.Optimizer, 15 | criterion: Any, 16 | train_mb_size: int, 17 | train_epochs: int, 18 | eval_mb_size: int, 19 | device: str, 20 | evaluator: Optional[BaseEvaluator] = None, 21 | save_path: str = "./results/diffusion", 22 | ): 23 | """ 24 | Class for training generative models in a traditional way. 25 | """ 26 | self.model = model 27 | self.optimizer = optimizer 28 | self.criterion = criterion 29 | self.train_mb_size = train_mb_size 30 | self.train_epochs = train_epochs 31 | self.eval_mb_size = eval_mb_size 32 | self.device = device 33 | self.evaluator = evaluator 34 | self.best_model = None 35 | 36 | def save(self, path: str, epoch: int): 37 | model_path = os.path.join(path, "model") 38 | os.makedirs(model_path, exist_ok=True) 39 | 40 | # Save model and scheduler 41 | torch.save({ 42 | "epoch": epoch, 43 | "model_state_dict": self.model.state_dict(), 44 | "optimizer_state_dict": self.optimizer.state_dict(), 45 | }, os.path.join(model_path, f"model_{epoch}.pt")) 46 | 47 | def train(self, train_loader: Any, eval_loader: Any, save_path: str = "./results/generative", save_every: int = 1, **kwargs): 48 | raise NotImplementedError 49 | 50 | def evaluate(self, eval_loader, save_path: str = "./results/generative"): 51 | assert self.evaluator is not None 52 | assert self.best_model is not None 53 | metrics = self.evaluator.evaluate(self.best_model, eval_loader, 0, save_path=save_path) 54 | return metrics -------------------------------------------------------------------------------- /src/standard_training/trainers/diffusion_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from tqdm import tqdm 5 | from copy import deepcopy 6 | from typing import Optional, Tuple 7 | from diffusers import DDIMPipeline, SchedulerMixin, EMAModel 8 | 9 | from src.standard_training.losses.diffusion_losses import DiffusionLoss 10 | from src.standard_training.evaluators.base_evaluator import BaseEvaluator 11 | from src.standard_training.trackers.wandb_tracker import WandbTracker 12 | from src.standard_training.trackers.base_tracker import Stage 13 | 14 | 15 | class DiffusionTraining: 16 | 17 | def __init__(self, 18 | model: torch.nn.Module, 19 | scheduler: SchedulerMixin, 20 | optimizer: torch.optim.Optimizer, 21 | criterion: DiffusionLoss, 22 | train_mb_size: int, 23 | train_epochs: int, 24 | eval_mb_size: int, 25 | device: str, 26 | train_timesteps: int, 27 | evaluator: Optional[BaseEvaluator] = None, 28 | tracker: Optional[WandbTracker] = None, 29 | save_path: str = "./results/diffusion", 30 | ): 31 | self.model = model 32 | self.scheduler = scheduler 33 | self.optimizer = optimizer 34 | self.criterion = criterion 35 | self.train_mb_size = train_mb_size 36 | self.train_epochs = train_epochs 37 | self.eval_mb_size = eval_mb_size 38 | self.device = device 39 | self.evaluator = evaluator 40 | self.train_timesteps = train_timesteps 41 | self.tracker = tracker 42 | self.save_path = save_path 43 | self.best_model_path = os.path.join(self.save_path, "best_model") 44 | self.last_model_path = os.path.join(self.save_path, "last_model") 45 | 46 | # adjust = 1* args.batch_size * args.model_ema_steps / args.epochs 47 | # alpha = 1.0 - args.model_ema_decay 48 | # 
alpha = min(1.0, alpha * adjust) 49 | # self.model_ema = EMAModel(model.parameters(), power=3/4) 50 | 51 | self.current_epoch = 0 52 | self.best_auc = torch.inf 53 | self.best_model = None 54 | 55 | try: 56 | self.load(self.best_model_path) 57 | except Exception: 58 | pass  # no (valid) checkpoint yet; start training from scratch 59 | 60 | os.makedirs(self.best_model_path, exist_ok=True) 61 | os.makedirs(self.last_model_path, exist_ok=True) 62 | 63 | def load(self, path: str): 64 | if not os.path.exists(path): 65 | return 66 | 67 | pipeline = DDIMPipeline.from_pretrained(path) 68 | self.model = self.model.to("cpu") 69 | self.model.load_state_dict(pipeline.unet.state_dict()) 70 | self.model = self.model.to(self.device) 71 | del pipeline 72 | 73 | # Load optimizer and training state 74 | checkpoint = torch.load(os.path.join(path, "training_state.pt")) 75 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 76 | self.current_epoch = checkpoint["current_epoch"] 77 | self.best_auc = checkpoint["best_auc"] 78 | 79 | def save(self, path: str): 80 | pipeline = DDIMPipeline(self.model, self.scheduler) 81 | pipeline.save_pretrained(path) 82 | 83 | # Save optimizer and training state 84 | torch.save({ 85 | "optimizer": self.optimizer.state_dict(), 86 | "current_epoch": self.current_epoch, 87 | "best_auc": self.best_auc, 88 | }, os.path.join(path, "training_state.pt")) 89 | 90 | def forward(self, timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 91 | raise NotImplementedError 92 | 93 | def train(self, train_loader, eval_loader, save_every: int = 1): 94 | for epoch in range(self.current_epoch, self.train_epochs): 95 | print(f"Epoch {epoch}") 96 | self.current_epoch = epoch 97 | 98 | self.model.train() 99 | bar = tqdm(enumerate(train_loader), 100 | desc="Training loop", total=len(train_loader)) 101 | average_loss = 0 102 | 103 | if self.tracker is not None: 104 | self.tracker.set_stage(Stage.TRAIN) 105 | 106 | for step, clean_images in bar: 107 | self.optimizer.zero_grad() 108 | 109 | batch_size = clean_images["pixel_values"].shape[0] 110 | clean_images = clean_images["pixel_values"].to(self.device) 111 | 112 | noise = torch.randn(clean_images.shape).to(clean_images.device) 113 | timesteps = torch.randint( 114 | 0, self.train_timesteps, (batch_size,), device=self.device 115 | ).long() 116 | noisy_images = self.scheduler.add_noise( 117 | clean_images, noise, timesteps) 118 | 119 | noise_pred = self.model( 120 | noisy_images, timesteps, return_dict=False)[0] 121 | 122 | loss = self.criterion(noise_pred, noise, timesteps) 123 | 124 | loss.backward() 125 | # Clip gradients to avoid exploding gradients 126 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 127 | 128 | self.optimizer.step() 129 | # self.model_ema.step() 130 | 131 | if self.tracker is not None: 132 | self.tracker.add_batch_metric("loss", loss.item(), step + epoch * len(train_loader)) 133 | 134 | average_loss += loss.item() 135 | bar.set_postfix(loss=average_loss / (step + 1)) 136 | 137 | if self.tracker is not None: 138 | self.tracker.add_epoch_metric("loss", average_loss / len(train_loader), epoch) 139 | self.tracker.flush() 140 | 141 | auc = torch.inf 142 | 143 | if (save_every > 0 and epoch % save_every == 0 and epoch > 0) or epoch == self.train_epochs - 1: 144 | if self.evaluator is not None: 145 | if self.tracker is not None: 146 | self.tracker.set_stage(Stage.TEST) 147 | 148 | self.model.eval() 149 | metrics = self.evaluator.evaluate(self.model, eval_loader, epoch, compute_auc=True) 150 | auc = metrics["auc"] 151 | 152 | if self.tracker is not None: 153 | for key, value in
metrics.items(): 154 | self.tracker.add_epoch_metric(key, value, epoch) 155 | self.tracker.flush() 156 | 157 | if auc <= self.best_auc: 158 | print(f"New best model with AUC {auc}") 159 | self.best_auc = auc 160 | self.best_model = deepcopy(self.model) 161 | self.save(self.best_model_path) 162 | 163 | self.save(self.last_model_path) 164 | 165 | if self.tracker is not None: 166 | self.tracker.finish() 167 | 168 | -------------------------------------------------------------------------------- /src/standard_training/trainers/generative_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from tqdm import tqdm 5 | from typing import Optional, Any 6 | 7 | from src.standard_training.evaluators.base_evaluator import BaseEvaluator 8 | from src.standard_training.trainers.base_trainer import BaseTrainer 9 | 10 | 11 | class GenerativeTraining(BaseTrainer): 12 | 13 | def __init__(self, 14 | model: torch.nn.Module, 15 | optimizer: torch.optim.Optimizer, 16 | criterion: Any, 17 | train_mb_size: int, 18 | train_epochs: int, 19 | eval_mb_size: int, 20 | device: str, 21 | evaluator: Optional[BaseEvaluator] = None, 22 | ): 23 | """ 24 | Class for training generative models in a traditional way. 25 | """ 26 | super().__init__(model, optimizer, criterion, train_mb_size, train_epochs, eval_mb_size, device, evaluator) 27 | 28 | def train(self, train_loader, eval_loader, save_path: str = "./results/generative"): 29 | if self.evaluator is not None: 30 | assert eval_loader is not None 31 | 32 | best_fid = torch.inf 33 | 34 | for epoch in range(self.train_epochs): 35 | bar = tqdm(train_loader, desc=f"Training epoch {epoch}", total=len(train_loader)) 36 | 37 | for batch in bar: 38 | self.optimizer.zero_grad() 39 | 40 | batch = batch["pixel_values"].to(self.device) 41 | pred = self.model(batch) 42 | loss = self.criterion(batch, pred) 43 | 44 | loss.backward() 45 | self.optimizer.step() 46 | 47 | bar.set_postfix(loss=loss.item()) 48 | 49 | fid = torch.inf 50 | 51 | if self.evaluator is not None: 52 | fid = self.evaluator.evaluate(self.model, eval_loader, epoch)["fid"] 53 | 54 | if fid <= best_fid: 55 | best_fid = fid 56 | self.best_model = self.model 57 | self.save(save_path, epoch) 58 | 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/tests/__init__.py -------------------------------------------------------------------------------- /tests/continual_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/tests/continual_learning/__init__.py -------------------------------------------------------------------------------- /tests/continual_learning/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/tests/continual_learning/metrics/__init__.py -------------------------------------------------------------------------------- /tests/continual_learning/metrics/test_diffusion_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import unittest 3 | 4 | from src.continual_learning.metrics.diffusion_metrics import FIDMetric 5 | from src.continual_learning.metrics.diffusion_metrics import DistributionMetrics 6 | 7 | 8 | class TestFIDMetric(unittest.TestCase): 9 | def test_fid_metric(self): 10 | fid_metric = FIDMetric(device="cpu") 11 | true_y = torch.randn(10, 3, 32, 32) 12 | predicted_y = torch.randn(10, 3, 32, 32) 13 | 14 | fid_metric.update_true(true_y) 15 | fid_metric.update_predicted(predicted_y) 16 | 17 | fid_score = fid_metric.result() 18 | self.assertIsInstance(fid_score, float) 19 | 20 | 21 | class TestDistributionMetrics(unittest.TestCase): 22 | def test_distribution_metrics(self): 23 | distribution_metrics = DistributionMetrics() 24 | true_y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 25 | predicted_y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 26 | 27 | distribution_metrics.update_true(true_y) 28 | distribution_metrics.update_predicted(predicted_y) 29 | 30 | ard, kl_div, ratio_true, ratio_pred = distribution_metrics.result() 31 | self.assertIsInstance(ard, float) 32 | self.assertIsInstance(kl_div, float) 33 | self.assertIsInstance(ratio_true, torch.Tensor) 34 | self.assertIsInstance(ratio_pred, torch.Tensor) 35 | self.assertEqual(ratio_true.shape, (10,)) 36 | self.assertEqual(ratio_pred.shape, (10,)) 37 | self.assertAlmostEqual(ard, 0.0, delta=1e-6) 38 | self.assertAlmostEqual(kl_div, 0.0, delta=1e-6) 39 | 40 | def test_distribution_metrics_with_different_distributions(self): 41 | distribution_metrics = DistributionMetrics() 42 | true_y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 43 | predicted_y = torch.tensor([0, 1, 1, 1, 4, 5, 6, 7, 8, 8]) 44 | 45 | distribution_metrics.update_true(true_y) 46 | distribution_metrics.update_predicted(predicted_y) 47 | 48 | ard, kl_div, ratio_true, ratio_pred = distribution_metrics.result() 49 | self.assertIsInstance(ard, float) 50 | self.assertIsInstance(kl_div, float) 51 | self.assertIsInstance(ratio_true, torch.Tensor) 52 | self.assertIsInstance(ratio_pred, torch.Tensor) 53 | self.assertEqual(ratio_true.shape, (10,)) 54 | self.assertEqual(ratio_pred.shape, (10,)) 55 | self.assertGreater(ard, 0.0) 56 | self.assertGreater(kl_div, 0.0) -------------------------------------------------------------------------------- /tests/continual_learning/test_loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import shutil 4 | import unittest 5 | 6 | from unittest.mock import MagicMock 7 | 8 | from src.continual_learning.loggers import CSVLogger 9 | 10 | 11 | class TestCSVLogger(unittest.TestCase): 12 | def setUp(self): 13 | self.logger = CSVLogger(log_folder="test_logs") 14 | 15 | def tearDown(self): 16 | self.logger.close() 17 | shutil.rmtree("test_logs") 18 | 19 | def test_init(self): 20 | self.assertTrue(os.path.exists("test_logs")) 21 | self.assertTrue(os.path.exists("test_logs/training_results.csv")) 22 | self.assertTrue(os.path.exists("test_logs/eval_results.csv")) 23 | 24 | def test_print_csv_headers(self): 25 | with open("test_logs/training_results.csv", "r") as f: 26 | reader = csv.reader(f) 27 | header = next(reader) 28 | self.assertEqual(header, ["metric_name", "training_exp", "epoch", "x_plot", "value"]) 29 | 30 | with open("test_logs/eval_results.csv", "r") as f: 31 | reader = csv.reader(f) 32 | header = next(reader) 33 | self.assertEqual(header, ["metric_name", "eval_exp", "training_exp", "value"]) 34 | 35 | def test_log_single_metric(self): 36 | self.logger.metric_vals = {} 37 | 
self.logger.log_single_metric("accuracy", 0.9, 1) 38 | self.logger.log_single_metric("loss", 0.1, 1) 39 | self.assertEqual(self.logger.metric_vals, {"accuracy": [("accuracy", 1, 0.9)], "loss": [("loss", 1, 0.1)]}) 40 | 41 | def test_print_train_metrics(self): 42 | self.logger.training_file = MagicMock() 43 | self.logger.metric_vals = {"accuracy": [("accuracy", 1, 0.9)], "loss": [("loss", 1, 0.1)]} 44 | self.logger.print_train_metrics(1, 1) 45 | args = self.logger.training_file.write.call_args_list 46 | self.assertEqual("".join([arg[0][0] for arg in args]), "accuracy,1,1,1,0.9000\nloss,1,1,1,0.1000\n") 47 | 48 | def test_print_eval_metrics(self): 49 | self.logger.eval_file = MagicMock() 50 | self.logger.metric_vals = {"accuracy": [("accuracy", 1, 0.9)], "loss": [("loss", 1, 0.1)]} 51 | self.logger.print_eval_metrics(1, 2) 52 | # assert it's called many times to form "accuracy,1,2,0.9\nloss,1,2,0.1\n" 53 | args = self.logger.eval_file.write.call_args_list 54 | self.assertEqual("".join([arg[0][0] for arg in args]), "accuracy,1,2,0.9000\nloss,1,2,0.1000\n") 55 | -------------------------------------------------------------------------------- /tests/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/tests/pipelines/__init__.py -------------------------------------------------------------------------------- /tests/pipelines/test_pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from src.schedulers.scheduler_ddim import DDIMScheduler 5 | from src.pipelines.pipeline_ddim import DDIMPipeline 6 | 7 | class TestDDIMPipeline(unittest.TestCase): 8 | 9 | def dummy_model(self): 10 | class DummyModel: 11 | def __init__(self) -> None: 12 | config = unittest.mock.Mock() 13 | config.in_channels = torch.tensor(1) 14 | config.sample_size = torch.tensor((32, 32)) 15 | self.config = config 16 | self.dtype = torch.float32 17 | 18 | def __call__(self, sample, t, *args): 19 | mock_response = unittest.mock.Mock() 20 | mock_response.sample = sample.clone() 21 | 22 | for s in range(sample.shape[0]): 23 | for i in range(sample.shape[2]): 24 | if mock_response.sample[s, 0, 0, i] >= t[s]: 25 | continue 26 | mock_response.sample[s, 0, 0, i] = t[s] 27 | break 28 | 29 | return mock_response 30 | 31 | model = DummyModel() 32 | return model 33 | 34 | def setUp(self): 35 | self.unet = self.dummy_model() 36 | self.scheduler = DDIMScheduler() 37 | self.pipeline = DDIMPipeline(self.unet, self.scheduler) 38 | 39 | def test_mask(self): 40 | batch_size = 2 41 | generator = torch.Generator() 42 | generator.manual_seed(42) 43 | target_timesteps = torch.tensor([0, 999]) 44 | output = self.pipeline(batch_size=batch_size, num_inference_steps=3, 45 | target_steps=target_timesteps, use_clipped_model_output=None, 46 | output_type="torch_raw", return_dict=False, generator=generator) 47 | self.assertEqual(output.shape, (batch_size, 1, 32, 32)) 48 | unmasked_output = output[0, 0, 0] 49 | masked_output = output[1, 0, 0] 50 | self.assertGreater(masked_output[0], 2) 51 | self.assertGreater(masked_output[1], 2) 52 | self.assertLess(masked_output[2], 2) 53 | self.assertEqual(abs(unmasked_output[0]), 1) 54 | self.assertEqual(abs(unmasked_output[1]), 1) 55 | self.assertEqual(abs(unmasked_output[2]), 1) -------------------------------------------------------------------------------- 
/tests/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/tests/schedulers/__init__.py -------------------------------------------------------------------------------- /tests/schedulers/test_scheduler_ddim.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import numpy as np 4 | 5 | from src.schedulers.scheduler_ddim import DDIMScheduler 6 | 7 | class TestDDIMScheduler(unittest.TestCase): 8 | 9 | def test_set_timesteps(self): 10 | scheduler = DDIMScheduler() 11 | device = torch.device('cpu') 12 | num_inference_steps = 10 13 | 14 | target_steps = 0 15 | expected_timesteps = torch.from_numpy( 16 | np.array([999, 899, 799, 699, 599, 499, 399, 299, 199, 99]) + scheduler.config.steps_offset) 17 | scheduler.set_timesteps(num_inference_steps, target_steps, device) 18 | assert scheduler.num_inference_steps == num_inference_steps 19 | assert scheduler.target_steps == target_steps 20 | assert torch.equal(scheduler.timesteps, expected_timesteps) 21 | target_steps = 10 22 | expected_timesteps = torch.from_numpy( 23 | np.array([999, 900, 801, 702, 603, 504, 405, 306, 207, 108]) + scheduler.config.steps_offset) 24 | scheduler.set_timesteps(num_inference_steps, target_steps, device) 25 | assert scheduler.num_inference_steps == num_inference_steps 26 | assert scheduler.target_steps == target_steps 27 | assert torch.equal(scheduler.timesteps, expected_timesteps) 28 | 29 | target_steps = np.asarray([0, 10]) 30 | expected_timesteps = torch.from_numpy(np.vstack([ 31 | (np.array([999, 899, 799, 699, 599, 499, 399, 299, 199, 99]) + scheduler.config.steps_offset), 32 | (np.array([999, 900, 801, 702, 603, 504, 405, 306, 207, 108]) + scheduler.config.steps_offset), 33 | ])).T 34 | scheduler.set_timesteps(num_inference_steps, target_steps, device) 35 | assert scheduler.num_inference_steps == num_inference_steps 36 | assert np.equal(scheduler.target_steps, target_steps).all() 37 | assert torch.equal(scheduler.timesteps, expected_timesteps) 38 | 39 | -------------------------------------------------------------------------------- /train_kld_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torchvision.transforms as transforms 7 | from copy import deepcopy 8 | from tqdm import tqdm 9 | from torchvision.datasets import FashionMNIST, CIFAR10 10 | from torch.utils.data import DataLoader 11 | from torchvision.models import resnet18 12 | 13 | class AddGaussianNoise(object): 14 | def __init__(self, mean=0., std=1.): 15 | self.std = std 16 | self.mean = mean 17 | 18 | def __call__(self, tensor): 19 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 20 | 21 | def __repr__(self): 22 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 23 | 24 | 25 | def get_data_loaders(batch_size=64, dataset='FashionMNIST'): 26 | if dataset == 'FashionMNIST': 27 | transform_train = transforms.Compose([ 28 | transforms.Resize(32), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.RandomRotation(10), 31 | transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), 32 | transforms.ToTensor(), 33 | transforms.Lambda(lambda x: x.repeat(3, 1, 1)), 34 | transforms.Normalize((0.5,), (0.5,)), 35 | 
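            # AddGaussianNoise (defined above) runs after Normalize, so the
            # std=0.1 noise is added in the normalized pixel range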
AddGaussianNoise(0., 0.1), 36 | ]) 37 | 38 | transform_test = transforms.Compose([ 39 | transforms.Resize(32), 40 | transforms.ToTensor(), 41 | transforms.Lambda(lambda x: x.repeat(3, 1, 1)), 42 | transforms.Normalize((0.5,), (0.5,)) 43 | ]) 44 | train_dataset = FashionMNIST(root='./data', train=True, transform=transform_train, download=True) 45 | train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [50000, 10000]) 46 | test_dataset = FashionMNIST(root='./data', train=False, transform=transform_test, download=True) 47 | elif dataset == 'CIFAR10': 48 | transform_train = transforms.Compose([ 49 | transforms.Resize(32), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.RandomRotation(10), 52 | transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 55 | AddGaussianNoise(0., 0.1), 56 | ]) 57 | 58 | transform_test = transforms.Compose([ 59 | transforms.Resize(32), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 62 | ]) 63 | train_dataset = CIFAR10(root='./data', train=True, transform=transform_train, download=True) 64 | train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [45000, 5000]) 65 | test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True) 66 | else: 67 | raise ValueError(f"Dataset {dataset} not supported.") 68 | 69 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 70 | val_loader = DataLoader(val_dataset, batch_size=batch_size) 71 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 72 | 73 | return train_loader, val_loader, test_loader 74 | 75 | 76 | def get_model(num_classes=10): 77 | model = resnet18(pretrained=True)  # torchvision's `pretrained` flag is a bool, not a string 78 | model.fc = nn.Linear(model.fc.in_features, num_classes) 79 | return model 80 | 81 | 82 | def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.0001, device='cuda'): 83 | best_model = None 84 | best_accuracy = 0.0 85 | criterion = nn.CrossEntropyLoss() 86 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 87 | 88 | model.to(device) 89 | 90 | for epoch in range(num_epochs): 91 | model.train() 92 | running_loss = 0.0 93 | accuracy = 0.0 94 | bar = tqdm(train_loader) 95 | for inputs, labels in bar: 96 | inputs, labels = inputs.to(device), labels.to(device) 97 | optimizer.zero_grad() 98 | outputs = model(inputs) 99 | loss = criterion(outputs, labels) 100 | loss.backward() 101 | optimizer.step() 102 | running_loss += loss.item() 103 | accuracy += (outputs.argmax(1) == labels).float().mean() 104 | bar.set_description(f"Loss: {loss.item():.4f}, Accuracy: {accuracy.item() / (bar.n + 1):.4f}") 105 | 106 | print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy / len(train_loader)}") 107 | 108 | # Evaluation 109 | model.eval() 110 | correct = 0 111 | total = 0 112 | 113 | with torch.no_grad(): 114 | for inputs, labels in tqdm(val_loader): 115 | inputs, labels = inputs.to(device), labels.to(device) 116 | outputs = model(inputs) 117 | _, predicted = torch.max(outputs, 1) 118 | total += labels.size(0) 119 | correct += (predicted == labels).sum().item() 120 | 121 | val_accuracy = 100 * correct / total 122 | 123 | if val_accuracy > best_accuracy: 124 | best_accuracy = val_accuracy 125 | best_model = deepcopy(model) 126 | 127 | print(f"Validation Accuracy: {val_accuracy:.2f}%") 128 | 129 | return best_model 130 | 131 | 132 | def
evaluate_model(model, test_loader, device='cuda'): 133 | model.eval() 134 | correct = 0 135 | total = 0 136 | 137 | with torch.no_grad(): 138 | for inputs, labels in tqdm(test_loader): 139 | inputs, labels = inputs.to(device), labels.to(device) 140 | outputs = model(inputs) 141 | _, predicted = torch.max(outputs, 1) 142 | total += labels.size(0) 143 | correct += (predicted == labels).sum().item() 144 | 145 | print(f"Test Accuracy: {100 * correct / total:.2f}%") 146 | 147 | 148 | def save_model(model, save_folder='./weights/cnn_fmnist'): 149 | os.makedirs(save_folder, exist_ok=True) 150 | save_path = os.path.join(save_folder, 'resnet.pth') 151 | torch.save(model.state_dict(), save_path) 152 | print(f"Model saved at {save_path}") 153 | 154 | 155 | def main(): 156 | parser = argparse.ArgumentParser(description='Train and save a ResNet model on the given dataset.') 157 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size for data loaders') 158 | parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs for training') 159 | parser.add_argument('--learning_rate', type=float, default=0.0005, help='Learning rate for the optimizer') 160 | parser.add_argument('--dataset', type=str, default='FashionMNIST', help='Dataset to use for training') 161 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (cuda or cpu)') 162 | parser.add_argument('--output_path', type=str, default='./weights/cnn_fmnist', help='Path to save the model') 163 | args = parser.parse_args() 164 | 165 | train_loader, val_loader, test_loader = get_data_loaders(batch_size=args.batch_size, dataset=args.dataset) 166 | model = get_model() 167 | model = train_model(model, train_loader, val_loader, num_epochs=args.num_epochs, 168 | learning_rate=args.learning_rate, device=args.device) 169 | evaluate_model(model, test_loader, device=args.device) 170 | save_model(model, save_folder=args.output_path) 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /utils/compute_mnist_statistics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | from typing import Any 8 | import torch.nn as nn 9 | from torchvision import transforms 10 | from torchvision.models import resnet18 11 | 12 | import sys 13 | from pathlib import Path 14 | # This script should be run from the root of the project 15 | sys.path.append(str(Path(__file__).parent.parent)) 16 | 17 | # from src.models.simple_cnn import SimpleCNN 18 | from src.common.utils import get_configuration 19 | from src.datasets.fashion_mnist import create_dataloader 20 | from src.pipelines.pipeline_ddim import DDIMPipeline 21 | from src.common.visual import plot_bar 22 | from src.models.simple_cnn import SimpleCNN 23 | 24 | 25 | preprocess = transforms.Compose( 26 | [ 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.5], [0.5]), 29 | transforms.Resize((32, 32)), 30 | # Repeat channels to fit ResNet18 31 | transforms.Lambda(lambda x: x.repeat(3, 1, 1)), 32 | ] 33 | ) 34 | 35 | 36 | def __parse_args() -> argparse.Namespace: 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("--model_config_path", type=str, 40 | default="configs/model/resnet.json") 41 | parser.add_argument("--weights_path", type=str, 42 | default="weights/cnn_fmnist/") 43 | parser.add_argument("--generator_path", 
type=str, 44 | # default="results_fuji/smasipca/iid_results/fashion_mnist/diffusion/generation/ddim_medium_mse_teacher_2/42/best_model/") 45 | default="results/fashion_mnist/diffusion/None/ddim_medium_mse/42/best_model/") 46 | 47 | parser.add_argument("--classifier_batch_size", type=int, default=256) 48 | parser.add_argument("--generator_batch_size", type=int, default=128) 49 | 50 | parser.add_argument("--n_samples", type=int, default=2000) 51 | parser.add_argument("--n_steps", type=int, default=10) 52 | parser.add_argument("--eta", type=float, default=0.0) 53 | parser.add_argument("--device", type=str, default="cuda") 54 | 55 | return parser.parse_args() 56 | 57 | 58 | def main(args): 59 | device = args.device 60 | classifier = resnet18() 61 | classifier.fc = nn.Linear(classifier.fc.in_features, 10) 62 | print("Loading model from disk") 63 | classifier.load_state_dict(torch.load(os.path.join(args.weights_path, "resnet.pth"))) 64 | classifier.to(device) 65 | classifier.eval() 66 | 67 | _, test_loader = create_dataloader(args.classifier_batch_size, preprocess) 68 | 69 | samples_per_class = {i: 0 for i in range(10)} 70 | for batch in test_loader: 71 | with torch.no_grad(): 72 | batch_data = batch["pixel_values"].to(device) 73 | pred = classifier(batch_data) 74 | classes = torch.nn.functional.softmax(pred, dim=1) 75 | classes = torch.argmax(classes, dim=1) 76 | classes_np = classes.cpu().numpy() 77 | 78 | for c in classes_np: 79 | samples_per_class[c] += 1 80 | 81 | # Calculate entropy 82 | n_samples = sum(samples_per_class.values()) 83 | probabilities = np.array(list(samples_per_class.values())) / n_samples 84 | entropy = -np.sum(probabilities * np.log(probabilities)) 85 | print(f"Entropy: {entropy:.4f}") 86 | 87 | # Extract class names and sample counts 88 | class_names = list(samples_per_class.keys()) 89 | sample_counts = list(samples_per_class.values()) 90 | 91 | # Plot bar chart 92 | save_path = os.path.join(args.generator_path, f"mnist_samples_per_class_entropy_{entropy:.4f}.png") 93 | plot_bar( 94 | class_names, 95 | sample_counts, 96 | x_label="Classes", 97 | y_label="Number of samples", 98 | title="Number of samples for each class", 99 | save_path=save_path 100 | ) 101 | 102 | generator_pipeline = DDIMPipeline.from_pretrained(args.generator_path) 103 | generator_pipeline.set_progress_bar_config(disable=True) 104 | generator_pipeline = generator_pipeline.to(args.device) 105 | 106 | # initializes dict with the 10 classes to 0 107 | samples_per_class = {i: 0 for i in range(10)} 108 | entropy = 0.0 109 | n_iterations = args.n_samples // args.generator_batch_size 110 | pbar = tqdm(range(n_iterations)) 111 | for it in pbar: 112 | generated_samples = generator_pipeline( 113 | args.generator_batch_size, 114 | num_inference_steps=args.n_steps, 115 | eta=args.eta, 116 | output_type="torch_raw", 117 | ) 118 | 119 | # Repeat channels to fit the model 120 | generated_samples = generated_samples.repeat(1, 3, 1, 1) 121 | 122 | with torch.no_grad(): 123 | classes = classifier(generated_samples) 124 | 125 | # Compute entropy 126 | classes = torch.nn.functional.softmax(classes, dim=1) 127 | classes = torch.argmax(classes, dim=1) 128 | classes_np = classes.cpu().numpy() 129 | 130 | for c in classes_np: 131 | samples_per_class[c] += 1 132 | 133 | # Compute entropy 134 | n_samples = sum(samples_per_class.values()) 135 | probabilities = np.array(list(samples_per_class.values())) / n_samples 136 | entropy = -np.sum(probabilities * np.log(probabilities)) 137 | 138 | # Extract class names and sample counts 
139 | class_names = list(samples_per_class.keys()) 140 | sample_counts = list(samples_per_class.values()) 141 | 142 | # Plot bar chart 143 | save_path = os.path.join(args.generator_path, f"mnist_samples_per_class_{args.n_steps}_entropy_{entropy:.4f}.png") 144 | plot_bar( 145 | class_names, 146 | sample_counts, 147 | x_label="Classes", 148 | y_label="Number of samples", 149 | title="Number of samples for each class", 150 | save_path=save_path 151 | ) 152 | 153 | 154 | if __name__ == "__main__": 155 | args = __parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /utils/generate_auc_vs_teachersteps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import random 6 | import json 7 | 8 | from torch.optim import Adam 9 | from torchvision import transforms 10 | from diffusers import UNet2DModel, DDIMScheduler 11 | 12 | import sys 13 | from pathlib import Path 14 | # This script should be run from the root of the project 15 | sys.path.append(str(Path(__file__).parent.parent)) 16 | 17 | from src.datasets.fashion_mnist import create_dataloader as create_fashion_mnist_dataloader 18 | from src.datasets.mnist import create_dataloader as create_mnist_dataloader 19 | from src.common.utils import get_configuration 20 | from src.common.diffusion_utils import wrap_in_pipeline 21 | from src.pipelines.pipeline_ddim import DDIMPipeline 22 | from src.standard_training.losses.diffusion_losses import MSELoss, MinSNRLoss 23 | from src.standard_training.trainers.diffusion_distillation import ( 24 | GaussianDistillation, 25 | PartialGenerationDistillation, 26 | GenerationDistillation, 27 | NoDistillation 28 | ) 29 | from src.standard_training.evaluators.generative_evaluator import GenerativeModelEvaluator 30 | from src.common.visual import plot_line_graph 31 | from src.standard_training.trackers.csv_tracker import CSVTracker 32 | 33 | 34 | def __parse_args() -> argparse.Namespace: 35 | parser = argparse.ArgumentParser() 36 | # 28 for vae, 32 for unet 37 | parser.add_argument("--image_size", type=int, default=32) 38 | parser.add_argument("--channels", type=int, default=1) 39 | 40 | parser.add_argument("--dataset", type=str, default="fashion_mnist") 41 | 42 | parser.add_argument("--model_config_path", type=str, 43 | default="configs/model/ddim_medium.json") 44 | parser.add_argument("--distillation_type", type=str, default="generation", 45 | help="Type of distillation to use (gaussian, generation, partial_generation, no_distillation)") 46 | parser.add_argument("--teacher_path", type=str, default="results/fashion_mnist/diffusion/None/ddim_medium_mse/42/best_model", 47 | help="Path to teacher model (only for distillation)") 48 | parser.add_argument("--criterion", type=str, default="mse", 49 | help="Criterion to use for training (mse, min_snr)") 50 | 51 | parser.add_argument("--generation_steps", type=int, default=20) 52 | parser.add_argument("--eta", type=float, default=0.0) 53 | parser.add_argument("--teacher_eta", type=float, default=0.0) 54 | 55 | parser.add_argument("--num_epochs", type=int, default=10000) 56 | parser.add_argument("--batch_size", type=int, default=128) 57 | parser.add_argument("--eval_batch_size", type=int, default=128) 58 | 59 | parser.add_argument("--save_every", type=int, default=1000, 60 | help="Save model every n iterations (only for distillation)") 61 | parser.add_argument("--seed", type=int, default=42) 62 | return 
parser.parse_args() 63 | 64 | 65 | def create_model(config): 66 | return UNet2DModel( 67 | sample_size=config.model.input_size, 68 | in_channels=config.model.in_channels, 69 | out_channels=config.model.out_channels, 70 | layers_per_block=config.model.layers_per_block, 71 | block_out_channels=config.model.block_out_channels, 72 | norm_num_groups=config.model.norm_num_groups, 73 | down_block_types=config.model.down_block_types, 74 | up_block_types=config.model.up_block_types, 75 | ) 76 | 77 | 78 | def main(args): 79 | torch.manual_seed(args.seed) 80 | np.random.seed(args.seed) 81 | random.seed(args.seed) 82 | 83 | run_name = f"fid_vs_gensteps_{args.distillation_type}_{args.criterion}_eta_{args.teacher_eta}_{args.seed}" 84 | results_folder = os.path.join("results", run_name) 85 | os.makedirs(results_folder, exist_ok=True) 86 | 87 | device = "cuda" if torch.cuda.is_available() else "cpu" 88 | 89 | preprocess = transforms.Compose( 90 | [ 91 | transforms.Resize((args.image_size, args.image_size)), 92 | transforms.ToTensor(), 93 | transforms.Normalize([0.5], [0.5]), 94 | ] 95 | ) 96 | 97 | if args.dataset == "mnist": 98 | train_dataloader, test_dataloader = create_mnist_dataloader( 99 | args.batch_size, preprocess) 100 | elif args.dataset == "fashion_mnist": 101 | train_dataloader, test_dataloader = create_fashion_mnist_dataloader( 102 | args.batch_size, preprocess) 103 | else: 104 | raise NotImplementedError 105 | 106 | model_config = get_configuration(args.model_config_path) 107 | 108 | noise_scheduler = DDIMScheduler( 109 | num_train_timesteps=model_config.scheduler.train_timesteps) 110 | 111 | if args.criterion == "mse": 112 | criterion = MSELoss(noise_scheduler) 113 | elif args.criterion == "min_snr": 114 | criterion = MinSNRLoss(noise_scheduler) 115 | else: 116 | raise NotImplementedError 117 | 118 | if args.distillation_type == "gaussian": 119 | trainer_class = GaussianDistillation 120 | elif args.distillation_type == "generation": 121 | trainer_class = GenerationDistillation 122 | elif args.distillation_type == "partial_generation": 123 | trainer_class = PartialGenerationDistillation 124 | elif args.distillation_type == "no_distillation": 125 | trainer_class = NoDistillation 126 | else: 127 | raise NotImplementedError 128 | 129 | assert args.teacher_path is not None 130 | teacher_pipeline = DDIMPipeline.from_pretrained(args.teacher_path) 131 | teacher_pipeline.set_progress_bar_config(disable=True) 132 | teacher = teacher_pipeline.unet.to(device) 133 | 134 | auc_list = [] 135 | time_list = [] 136 | gen_steps = [1, 2, 5, 10, 20, 50, 100] 137 | for gen_step in gen_steps: 138 | save_path = os.path.join(results_folder, f"gen_step_{gen_step}") 139 | os.makedirs(save_path, exist_ok=True) 140 | evaluator = GenerativeModelEvaluator( 141 | device=device, save_images=100, save_path=save_path) 142 | all_configs = { 143 | "args": vars(args), 144 | "model_config": model_config, 145 | } 146 | tracker = CSVTracker(all_configs, save_path) 147 | 148 | print( 149 | f"\n\n======= Training with {gen_step} generation steps =======\n") 150 | 151 | wrap_in_pipeline(teacher, noise_scheduler, 152 | DDIMPipeline, gen_step, args.teacher_eta, def_output_type="torch_raw") 153 | 154 | student = create_model(model_config).to(device) 155 | wrap_in_pipeline(student, noise_scheduler, DDIMPipeline, 156 | args.generation_steps, args.eta) 157 | 158 | optimizer = Adam(student.parameters(), lr=model_config.optimizer.lr) 159 | 160 | trainer = trainer_class( 161 | model=student, 162 | scheduler=noise_scheduler, 163 | 
optimizer=optimizer, 164 | criterion=criterion, 165 | train_mb_size=args.batch_size, 166 | train_iterations=args.num_epochs, 167 | eval_mb_size=args.eval_batch_size, 168 | device=device, 169 | train_timesteps=model_config.scheduler.train_timesteps, 170 | evaluator=evaluator, 171 | tracker=tracker, 172 | ) 173 | 174 | start = torch.cuda.Event(enable_timing=True) 175 | end = torch.cuda.Event(enable_timing=True) 176 | 177 | start.record() 178 | metrics = trainer.train(teacher, eval_loader=test_dataloader, 179 | save_every=args.save_every, save_path=save_path) 180 | end.record() 181 | 182 | torch.cuda.synchronize() 183 | time_list.append(start.elapsed_time(end)) 184 | auc_list.append(metrics["auc"]) 185 | 186 | print(f"Time taken: {start.elapsed_time(end) / 1000} s") 187 | for metric, value in metrics.items(): 188 | print(f"Best {metric}: {value}") 189 | 190 | # Save results as json 191 | results = { 192 | "config": { 193 | "args": vars(args), 194 | "model_config": model_config, 195 | }, 196 | "results": { 197 | "auc_list": auc_list, 198 | "time_list": time_list, 199 | "gen_steps": gen_steps 200 | } 201 | } 202 | 203 | with open(os.path.join(results_folder, "results.json"), "w") as f: 204 | json.dump(results, f, indent=4) 205 | 206 | # Save graph 207 | plot_line_graph( 208 | gen_steps, 209 | auc_list, 210 | "Teacher's Generation Steps", 211 | "Student AUC", 212 | "Student AUC vs Teacher's Generation Steps", 213 | os.path.join(results_folder, "fid_vs_gensteps.png") 214 | ) 215 | 216 | 217 | if __name__ == "__main__": 218 | args = __parse_args() 219 | main(args) 220 | -------------------------------------------------------------------------------- /utils/generate_fid_accuracy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | 6 | from typing import List 7 | 8 | import sys 9 | from pathlib import Path 10 | # This script should be run from the root of the project 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | 13 | from src.common.visual import plot_line_graph 14 | 15 | 16 | def __parse_args() -> argparse.Namespace: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--experiments_path", type=str, 19 | default="results_fuji/smasipca/generative_replay_gensteps/split_fmnist/") 20 | return parser.parse_args() 21 | 22 | 23 | def plot_metrics(experiment_paths: List[str], save_path: str): 24 | metrics = {} 25 | 26 | for experiment_path in experiment_paths: 27 | values = {} 28 | 29 | for seed_folder in os.listdir(experiment_path): 30 | seed_path = os.path.join(experiment_path, seed_folder) 31 | 32 | if not os.path.isdir(seed_path): 33 | continue 34 | 35 | seed_path = os.path.join(seed_path, 'logs') 36 | 37 | for file_name in os.listdir(seed_path): 38 | if file_name.endswith('.csv') and file_name.startswith('eval'): 39 | file_path = os.path.join(seed_path, file_name) 40 | df = pd.read_csv(file_path) 41 | last_experience = df['training_exp'].max() 42 | df = df[df['training_exp'] == last_experience] 43 | df = df[df['eval_exp'] == last_experience] 44 | 45 | for metric_name in ["Top1_Acc_Stream/eval_phase/test_stream/Task000", "StreamForgetting/eval_phase/test_stream", "stream_fid/eval_phase/test_stream/Task000"]: 46 | # Get metric row 47 | metric_row = df[df['metric_name'] == metric_name] 48 | # Get metric value (last column) 49 | metric_value = float(metric_row.iloc[0, -1]) 50 | if metric_name not in values: 51 | values[metric_name] = [] 52 | 
values[metric_name].append(metric_value) 53 | 54 | for metric_name in values: 55 | if metric_name not in metrics: 56 | metrics[metric_name] = {} 57 | metrics[metric_name]['means'] = [] 58 | metrics[metric_name]['stds'] = [] 59 | metrics[metric_name]['means'].append(np.mean(values[metric_name])) 60 | metrics[metric_name]['stds'].append(np.std(values[metric_name])) 61 | 62 | # Plot fid vs accuracy 63 | output_path = os.path.join(save_path, f"fid_vs_accuracy.pgf") 64 | y_lim = [0, 1] 65 | x = metrics["stream_fid/eval_phase/test_stream/Task000"]["means"] 66 | y = metrics["Top1_Acc_Stream/eval_phase/test_stream/Task000"]["means"] 67 | x_index = np.argsort(x) 68 | x = np.array(x)[x_index] 69 | y = np.array(y)[x_index] 70 | plot_line_graph(x, y, "FID", "Accuracy", "FID vs Accuracy", output_path, log_x=False, y_lim=y_lim, size=(3, 2)) 71 | 72 | # Plot gens vs accuracy 73 | output_path = os.path.join(save_path, f"gens_vs_accuracy.pgf") 74 | x = [20, 10, 5, 2] 75 | y = metrics["Top1_Acc_Stream/eval_phase/test_stream/Task000"]["means"] 76 | y = np.array(y)[x_index] 77 | plot_line_graph(x, y, "DDIM steps", "Accuracy", "DDIM steps vs Accuracy", output_path, x_ticks=x, log_x=False, y_lim=y_lim, size=(3, 2)) 78 | 79 | # Plot gens vs forgetting 80 | output_path = os.path.join(save_path, f"gens_vs_forgetting.pgf") 81 | y = metrics["StreamForgetting/eval_phase/test_stream"]["means"] 82 | y = np.array(y)[x_index] 83 | plot_line_graph(x, y, "DDIM steps", "Forgetting", "DDIM steps vs Forgetting", output_path, x_ticks=x, log_x=False, y_lim=y_lim, size=(3, 2)) 84 | 85 | # Plot fid vs forgetting 86 | output_path = os.path.join(save_path, f"fid_vs_forgetting.pgf") 87 | x = metrics["stream_fid/eval_phase/test_stream/Task000"]["means"] 88 | y = metrics["StreamForgetting/eval_phase/test_stream"]["means"] 89 | x = np.array(x)[x_index] 90 | y = np.array(y)[x_index] 91 | plot_line_graph(x, y, "FID", "Forgetting", "FID vs Forgetting", output_path, log_x=False, y_lim=y_lim, size=(3, 2)) 92 | 93 | 94 | if __name__ == '__main__': 95 | import matplotlib 96 | matplotlib.use("pgf") 97 | matplotlib.rcParams.update({ 98 | "pgf.texsystem": "pdflatex", 99 | 'font.family': 'serif', 100 | 'font.size': 8, 101 | 'text.usetex': True, 102 | 'pgf.rcfonts': False, 103 | 'figure.autolayout': True, 104 | }) 105 | args = __parse_args() 106 | experiment_names = [] 107 | 108 | for experiment_name in os.listdir(args.experiments_path): 109 | if "lambd_4.0" in experiment_name: 110 | continue 111 | 112 | experiment_path = os.path.join(args.experiments_path, experiment_name) 113 | if os.path.isdir(experiment_path): 114 | experiment_names.append(experiment_path) 115 | 116 | plot_metrics(experiment_names, save_path=args.experiments_path) 117 | 118 | -------------------------------------------------------------------------------- /utils/generate_fid_vs_samples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import random 6 | import json 7 | 8 | from torchvision import transforms 9 | from diffusers import DDIMScheduler 10 | 11 | import sys 12 | from pathlib import Path 13 | # This script should be run from the root of the project 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | from src.datasets.fashion_mnist import create_dataloader as create_fashion_mnist_dataloader 17 | from src.datasets.mnist import create_dataloader as create_mnist_dataloader 18 | from src.common.utils import get_configuration 19 | from 
src.common.diffusion_utils import wrap_in_pipeline 20 | from src.pipelines.pipeline_ddim import DDIMPipeline 21 | from src.standard_training.evaluators.generative_evaluator import GenerativeModelEvaluator 22 | from src.common.visual import plot_line_std_graph 23 | 24 | 25 | def __parse_args() -> argparse.Namespace: 26 | parser = argparse.ArgumentParser() 27 | # 28 for vae, 32 for unet 28 | parser.add_argument("--image_size", type=int, default=32) 29 | parser.add_argument("--channels", type=int, default=1) 30 | 31 | parser.add_argument("--dataset", type=str, default="fashion_mnist") 32 | 33 | parser.add_argument("--model_config_path", type=str, 34 | default="configs/model/ddim_medium.json") 35 | parser.add_argument("--model_path", type=str, default="results/fashion_mnist/diffusion/None/ddim_medium_mse/42/best_model", 36 | help="Path to teacher model (only for distillation)") 37 | 38 | parser.add_argument("--eta", type=float, default=0.0) 39 | 40 | parser.add_argument("--batch_size", type=int, default=128) 41 | 42 | parser.add_argument("--seed", type=int, default=42) 43 | return parser.parse_args() 44 | 45 | 46 | def main(args): 47 | torch.manual_seed(args.seed) 48 | np.random.seed(args.seed) 49 | random.seed(args.seed) 50 | 51 | run_name = f"fid_vs_samples_eta_{args.eta}_seed_{args.seed}" 52 | results_folder = os.path.join("results", run_name) 53 | os.makedirs(results_folder, exist_ok=True) 54 | 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | 57 | preprocess = transforms.Compose( 58 | [ 59 | transforms.Resize((args.image_size, args.image_size)), 60 | transforms.ToTensor(), 61 | ] 62 | ) 63 | 64 | if args.dataset == "mnist": 65 | train_dataloader, test_dataloader = create_mnist_dataloader( 66 | args.batch_size, preprocess) 67 | elif args.dataset == "fashion_mnist": 68 | train_dataloader, test_dataloader = create_fashion_mnist_dataloader( 69 | args.batch_size, preprocess) 70 | else: 71 | raise NotImplementedError 72 | 73 | model_config = get_configuration(args.model_config_path) 74 | 75 | noise_scheduler = DDIMScheduler( 76 | num_train_timesteps=model_config.scheduler.train_timesteps) 77 | 78 | assert args.model_path is not None 79 | model_pipeline = DDIMPipeline.from_pretrained(args.model_path) 80 | model_pipeline.set_progress_bar_config(disable=True) 81 | model = model_pipeline.unet.to(device) 82 | 83 | all_fid_list = [] 84 | all_time_list = [] 85 | num_samples = [100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000] 86 | for n_samples in num_samples: 87 | print(f"Running for n_samples: {n_samples}") 88 | save_path = os.path.join(results_folder, f"gen_step_{n_samples}") 89 | os.makedirs(save_path, exist_ok=True) 90 | evaluator = GenerativeModelEvaluator( 91 | device=device, save_images=0, save_path=save_path) 92 | 93 | wrap_in_pipeline(model, noise_scheduler, 94 | DDIMPipeline, n_samples, args.eta, def_output_type="torch") 95 | 96 | start = torch.cuda.Event(enable_timing=True) 97 | end = torch.cuda.Event(enable_timing=True) 98 | 99 | fid_list = [] 100 | time_list = [] 101 | for _ in range(5): 102 | start.record() 103 | fid = evaluator.evaluate(model, test_dataloader, fid_images=n_samples, gensteps=20, compute_auc=False)["fid"] 104 | end.record() 105 | torch.cuda.synchronize() 106 | 107 | time_list.append(start.elapsed_time(end)) 108 | fid_list.append(fid) 109 | 110 | all_fid_list.append(fid_list) 111 | all_time_list.append(time_list) 112 | 113 | print(f"Time taken: {np.mean(time_list) / 1000} +- {np.std(time_list) / 1000} s") 114 | print(f"FID: {np.mean(fid_list)} +- 
{np.std(fid_list)}") 115 | 116 | # Save results as json 117 | results = { 118 | "config": { 119 | "args": vars(args), 120 | "model_config": model_config, 121 | }, 122 | "results": { 123 | "fid_list": all_fid_list, 124 | "time_list": all_time_list, 125 | "num_samples": num_samples 126 | } 127 | } 128 | 129 | with open(os.path.join(results_folder, "fid_vs_samples_results.json"), "w") as f: 130 | json.dump(results, f, indent=4) 131 | 132 | # Save graph 133 | plot_line_std_graph( 134 | num_samples, 135 | np.array([np.mean(t) for t in all_fid_list]), 136 | np.array([np.std(t) for t in all_fid_list]), 137 | "Num Samples", 138 | "FID", 139 | "FID vs Number of Samples", 140 | os.path.join(results_folder, "fid_vs_samples.png") 141 | ) 142 | 143 | plot_line_std_graph( 144 | num_samples, 145 | np.array([np.mean(t) / 1000 for t in all_time_list]), 146 | np.array([np.std(t) / 1000 for t in all_time_list]), 147 | "Num Samples", 148 | "Time (s)", 149 | "Time vs Number of Samples", 150 | os.path.join(results_folder, "time_vs_samples.png") 151 | ) 152 | 153 | 154 | if __name__ == "__main__": 155 | args = __parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /utils/generate_fid_vs_time.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import random 6 | import json 7 | 8 | from torchvision import transforms 9 | from diffusers import DDIMScheduler 10 | 11 | import sys 12 | from pathlib import Path 13 | # This script should be run from the root of the project 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | from src.datasets.fashion_mnist import create_dataloader as create_fashion_mnist_dataloader 17 | from src.datasets.mnist import create_dataloader as create_mnist_dataloader 18 | from src.common.utils import get_configuration 19 | from src.common.diffusion_utils import wrap_in_pipeline 20 | from src.pipelines.pipeline_ddim import DDIMPipeline 21 | from src.standard_training.evaluators.generative_evaluator import GenerativeModelEvaluator 22 | from src.common.visual import plot_line_graph 23 | 24 | 25 | def __parse_args() -> argparse.Namespace: 26 | parser = argparse.ArgumentParser() 27 | # 28 for vae, 32 for unet 28 | parser.add_argument("--image_size", type=int, default=32) 29 | parser.add_argument("--channels", type=int, default=1) 30 | 31 | parser.add_argument("--dataset", type=str, default="fashion_mnist") 32 | 33 | parser.add_argument("--model_config_path", type=str, 34 | default="configs/model/ddim_medium.json") 35 | parser.add_argument("--model_path", type=str, default="results/fashion_mnist/diffusion/None/ddim_medium_mse/42/best_model", 36 | help="Path to teacher model (only for distillation)") 37 | 38 | parser.add_argument("--eta", type=float, default=0.0) 39 | 40 | parser.add_argument("--batch_size", type=int, default=128) 41 | 42 | parser.add_argument("--seed", type=int, default=42) 43 | return parser.parse_args() 44 | 45 | 46 | def main(args): 47 | torch.manual_seed(args.seed) 48 | np.random.seed(args.seed) 49 | random.seed(args.seed) 50 | 51 | run_name = f"fid_vs_time_eta_{args.eta}_seed_{args.seed}" 52 | results_folder = os.path.join("results", run_name) 53 | os.makedirs(results_folder, exist_ok=True) 54 | 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | 57 | preprocess = transforms.Compose( 58 | [ 59 | transforms.Resize((args.image_size, args.image_size)), 60 | transforms.ToTensor(), 61 
| ] 62 | ) 63 | 64 | if args.dataset == "mnist": 65 | train_dataloader, test_dataloader = create_mnist_dataloader( 66 | args.batch_size, preprocess) 67 | elif args.dataset == "fashion_mnist": 68 | train_dataloader, test_dataloader = create_fashion_mnist_dataloader( 69 | args.batch_size, preprocess) 70 | else: 71 | raise NotImplementedError 72 | 73 | model_config = get_configuration(args.model_config_path) 74 | 75 | noise_scheduler = DDIMScheduler( 76 | num_train_timesteps=model_config.scheduler.train_timesteps) 77 | 78 | assert args.model_path is not None 79 | model_pipeline = DDIMPipeline.from_pretrained(args.model_path) 80 | model_pipeline.set_progress_bar_config(disable=True) 81 | model = model_pipeline.unet.to(device) 82 | 83 | fid_list = [] 84 | time_list = [] 85 | gen_steps = [1, 2, 5, 10, 20, 50, 100] 86 | for gen_step in gen_steps: 87 | print(f"Running for gen_step: {gen_step}") 88 | save_path = os.path.join(results_folder, f"gen_step_{gen_step}") 89 | os.makedirs(save_path, exist_ok=True) 90 | evaluator = GenerativeModelEvaluator( 91 | device=device, save_images=100, save_path=save_path) 92 | 93 | wrap_in_pipeline(model, noise_scheduler, 94 | DDIMPipeline, gen_step, args.eta, def_output_type="torch") 95 | 96 | start = torch.cuda.Event(enable_timing=True) 97 | end = torch.cuda.Event(enable_timing=True) 98 | 99 | start.record() 100 | fid = evaluator.evaluate(model, test_dataloader, gen_step, gensteps=gen_step, compute_auc=False)["fid"] 101 | end.record() 102 | 103 | torch.cuda.synchronize() 104 | time_list.append(start.elapsed_time(end)) 105 | fid_list.append(fid) 106 | 107 | print(f"Time taken: {start.elapsed_time(end)/1000} s") 108 | print(f"FID: {fid}") 109 | 110 | # Save results as json 111 | results = { 112 | "config": { 113 | "args": vars(args), 114 | "model_config": model_config, 115 | }, 116 | "results": { 117 | "fid_list": fid_list, 118 | "time_list": time_list, 119 | "gen_steps": gen_steps 120 | } 121 | } 122 | 123 | with open(os.path.join(results_folder, "fid_vs_time_results.json"), "w") as f: 124 | json.dump(results, f, indent=4) 125 | 126 | # Save graph 127 | plot_line_graph( 128 | gen_steps, 129 | fid_list, 130 | "Generation Steps", 131 | "FID", 132 | "FID vs Generation Steps", 133 | os.path.join(results_folder, "fid_vs_gensteps.png") 134 | ) 135 | 136 | plot_line_graph( 137 | gen_steps, 138 | [t / 1000 for t in time_list], 139 | "Generation Steps", 140 | "Time (s)", 141 | "Time vs Generation Steps", 142 | os.path.join(results_folder, "time_vs_gensteps.png") 143 | ) 144 | 145 | 146 | if __name__ == "__main__": 147 | args = __parse_args() 148 | main(args) 149 | -------------------------------------------------------------------------------- /utils/save_cifar10_examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import random 6 | 7 | 8 | if __name__ == "__main__": 9 | # Define the path where the examples will be saved 10 | output_folder = "examples_cifar10" 11 | os.makedirs(output_folder, exist_ok=True) 12 | 13 | # Define the CIFAR-10 dataset and dataloader 14 | transform = transforms.Compose([transforms.ToTensor()]) 15 | dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) 16 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) 17 | 18 | # Define class names for CIFAR-10 19 | class_names = [ 20 | "airplane", 21 | "automobile", 22 | "bird",
23 | "cat", 24 | "deer", 25 | "dog", 26 | "frog", 27 | "horse", 28 | "ship", 29 | "truck", 30 | ] 31 | # Task orders: {3, 7}, {8, 0}, {1, 4}, {2, 6}, {9, 5} 32 | task_orders = [1, 2, 3, 0, 2, 4, 3, 0, 1, 4] 33 | 34 | # Initialize counters for each class 35 | class_counters = {class_name: 0 for class_name in class_names} 36 | num_samples_per_class = 1 # Number of samples to save per class 37 | 38 | # Loop through the dataset and save random samples for each class 39 | for image, label in dataloader: 40 | class_name = class_names[label] 41 | task_order = task_orders[label] 42 | 43 | if class_counters[class_name] < num_samples_per_class: 44 | # Generate a random file name 45 | random_suffix = random.randint(0, 99999) 46 | file_name = f"{class_name.replace('/', '_')}_{class_counters[class_name]}_task_{task_order}_{random_suffix}.png" 47 | file_path = os.path.join(output_folder, file_name) 48 | 49 | # Save the image 50 | torchvision.utils.save_image(image, file_path) 51 | 52 | # Increment the counter for the class 53 | class_counters[class_name] += 1 54 | 55 | # Check if we have saved enough samples for all classes 56 | if all(count >= num_samples_per_class for count in class_counters.values()): 57 | break 58 | 59 | print("Random samples saved to the 'examples_cifar10' folder.") 60 | -------------------------------------------------------------------------------- /utils/save_fmnist_examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import random 6 | 7 | 8 | if __name__ == "__main__": 9 | # Define the path where the examples will be saved 10 | output_folder = "examples" 11 | os.makedirs(output_folder, exist_ok=True) 12 | 13 | # Define the Fashion MNIST dataset and dataloader 14 | transform = transforms.Compose([transforms.ToTensor()]) 15 | dataset = torchvision.datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform) 16 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) 17 | 18 | # Define class names for Fashion MNIST {0, 2}, {3, 5}, {1, 4}, {9, 6}, {8, 7} 19 | class_names = [ 20 | "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", 21 | "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot" 22 | ] 23 | 24 | # Initialize counters for each class 25 | class_counters = {class_name: 0 for class_name in class_names} 26 | num_samples_per_class = 1 # Number of samples to save per class 27 | 28 | # Loop through the dataset and save random samples for each class 29 | for image, label in dataloader: 30 | class_name = class_names[label] 31 | if class_counters[class_name] < num_samples_per_class: 32 | # Generate a random file name 33 | random_suffix = random.randint(0, 99999) 34 | file_name = f"{class_name.replace('/', '_')}_{class_counters[class_name]}_{random_suffix}.png" 35 | file_path = os.path.join(output_folder, file_name) 36 | 37 | # Save the image 38 | torchvision.utils.save_image(image, file_path) 39 | 40 | # Increment the counter for the class 41 | class_counters[class_name] += 1 42 | 43 | # Check if we have saved enough samples for all classes 44 | if all(count >= num_samples_per_class for count in class_counters.values()): 45 | break 46 | 47 | print("Random samples saved to the 'examples' folder.") 48 | -------------------------------------------------------------------------------- /weights/cnn_cifar10/resnet.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/weights/cnn_cifar10/resnet.pth -------------------------------------------------------------------------------- /weights/cnn_fmnist/resnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atenrev/diffusion_continual_learning/0d88ac5dea1ed7595b3f1f2f333b3bc1ce555f53/weights/cnn_fmnist/resnet.pth --------------------------------------------------------------------------------
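Both `resnet.pth` checkpoints store the `state_dict` of a `resnet18` whose final fully connected layer was replaced by a 10-class head, as produced by `train_kld_classifier.py`. A minimal loading sketch, mirroring how `utils/compute_mnist_statistics.py` restores the Fashion-MNIST classifier:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

# Rebuild the architecture exactly as in train_kld_classifier.py,
# then load the saved weights.
classifier = resnet18()
classifier.fc = nn.Linear(classifier.fc.in_features, 10)
classifier.load_state_dict(torch.load("weights/cnn_fmnist/resnet.pth"))
classifier.eval()
```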