├── satflow ├── baseline │ ├── __init__.py │ ├── README.md │ └── optical_flow.py ├── core │ ├── __init__.py │ └── utils.py ├── data │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── datasets.py │ └── datamodules.py ├── experiments │ ├── __init__.py │ └── train.py ├── configs │ ├── callbacks │ │ ├── none.yaml │ │ ├── gan.yaml │ │ └── default.yaml │ ├── model │ │ ├── hf_perceiver.yaml │ │ ├── attention_unet.yaml │ │ ├── unet.yaml │ │ ├── runet.yaml │ │ ├── attention_runet.yaml │ │ ├── attention_runet_coord.yaml │ │ ├── convlstm.yaml │ │ ├── runet_coord.yaml │ │ ├── fcn_r50.yaml │ │ ├── fcn_r101.yaml │ │ ├── convlstm_coord.yaml │ │ ├── deeplabv3_r101.yaml │ │ ├── deeplabv3_r50.yaml │ │ ├── metnet.yaml │ │ ├── metnet_ssim.yaml │ │ ├── nowcasting_gan.yaml │ │ ├── pix2pix.yaml │ │ ├── cloudgan_runet.yaml │ │ ├── cloudgan_convlstm.yaml │ │ ├── perceiver_metnet.yaml │ │ ├── perceiver_single.yaml │ │ ├── perceiver.yaml │ │ └── perceiver_encoder.yaml │ ├── trainer │ │ ├── ddp.yaml │ │ ├── debug.yaml │ │ ├── minimal.yaml │ │ ├── default.yaml │ │ ├── half.yaml │ │ ├── simple_profiler.yaml │ │ ├── pytorch_profiler.yaml │ │ ├── deepspeed.yaml │ │ └── deepspeed_zero_three.yaml │ ├── logger │ │ ├── many_loggers.yaml │ │ ├── csv.yaml │ │ ├── tensorboard.yaml │ │ └── neptune.yaml │ ├── datamodule │ │ ├── local_perceiver.yaml │ │ ├── local.yaml │ │ ├── aws.yaml │ │ └── gcp.yaml │ ├── hydra │ │ └── default.yaml │ ├── experiment │ │ ├── unet_simple.yaml │ │ ├── metnet_simple.yaml │ │ ├── convlstm_simple.yaml │ │ ├── nowcasting_gan_simple.yaml │ │ ├── perceiver_simple.yaml │ │ ├── example_simple.yaml │ │ └── example_full.yaml │ ├── configurations │ │ ├── aws.yaml │ │ ├── gcp.yaml │ │ └── local.yaml │ ├── config.yaml │ └── hparams_search │ │ ├── convlstm_optuna.yaml │ │ ├── unet_optuna.yaml │ │ ├── metnet_optuna.yaml │ │ ├── nowcasting_gan_optuna.yaml │ │ └── perceiver_optuna.yaml ├── version.py ├── __init__.py ├── models │ ├── gan │ │ ├── __init__.py │ │ └── common.py │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── ConditionTime.py │ │ ├── CoordConv.py │ │ ├── TimeDistributed.py │ │ ├── ConvLSTM.py │ │ ├── Normalization.py │ │ ├── GResBlock.py │ │ ├── SpatioTemporalLSTMCell_memory_decoupling.py │ │ ├── RUnetLayers.py │ │ └── Generator.py │ ├── pixel_cnn.py │ ├── utils.py │ ├── perceiverio.py │ ├── unet.py │ ├── fcn.py │ ├── deeplabv3.py │ ├── pl_metnet.py │ ├── runet.py │ ├── pix2pix.py │ └── conv_lstm.py ├── run.py └── examples │ └── metnet_example.py ├── MANIFEST.in ├── docs ├── requirements.txt └── .readthedocs-custom-steps.yml ├── .deepsource.toml ├── .readthedocs.yml ├── .bumpversion.cfg ├── .github └── workflows │ ├── linters.yaml │ ├── release.yaml │ ├── workflows.yaml │ └── docker.yaml ├── .dockerignore ├── requirements.txt ├── pydoc-markdown.yml ├── .all-contributorsrc ├── LICENSE ├── setup.py ├── .pre-commit-config.yaml ├── Dockerfile ├── README.md ├── .gitignore ├── tests └── test_models.py └── environment.yml /satflow/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /satflow/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /satflow/data/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt 2 | -------------------------------------------------------------------------------- /satflow/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /satflow/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /satflow/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /satflow/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.36" 2 | -------------------------------------------------------------------------------- /satflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | readthedocs-custom-steps==0.5.1 2 | pydoc-markdown 3 | -------------------------------------------------------------------------------- /docs/.readthedocs-custom-steps.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | - | 3 | pydoc-markdown --build --site-dir "$PWD/_build/html" 4 | -------------------------------------------------------------------------------- /satflow/configs/model/hf_perceiver.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.perceiverio.HuggingFacePerceiver 2 | input_size: 64 3 | -------------------------------------------------------------------------------- /satflow/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | gpus: 4 5 | accelerator: ddp 6 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "python" 5 | enabled = true 6 | 7 | [analyzers.meta] 8 | runtime_version = "3.x.x" 9 | -------------------------------------------------------------------------------- /satflow/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | # train with many loggers at once 3 | 4 | defaults: 5 | - tensorboard.yaml 6 | - neptune.yaml 7 | -------------------------------------------------------------------------------- /satflow/models/gan/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminators import GANLoss, NLayerDiscriminator, PixelDiscriminator, define_discriminator 2 | from .generators import define_generator 3 | -------------------------------------------------------------------------------- /satflow/configs/model/attention_unet.yaml:
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.attention_unet.AttentionUnet 3 | input_channels: 68 4 | forecast_steps: 48 5 | visualize: True 6 | lr: 0.0001 7 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | mkdocs: {} # tell readthedocs to use mkdocs 3 | python: 4 | version: 3.7 5 | install: 6 | - method: pip 7 | path: . 8 | - requirements: docs/requirements.txt 9 | -------------------------------------------------------------------------------- /satflow/configs/model/unet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.unet.Unet 3 | input_channels: 68 4 | hidden_dim: 64 5 | num_layers: 5 6 | forecast_steps: 24 7 | visualize: True 8 | lr: 0.0001 9 | -------------------------------------------------------------------------------- /satflow/configs/model/runet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.runet.RUnet 3 | input_channels: 68 4 | hidden_dim: 64 5 | num_layers: 5 6 | forecast_steps: 24 7 | visualize: True 8 | lr: 0.0001 9 | -------------------------------------------------------------------------------- /satflow/configs/model/attention_runet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.attention_unet.AttentionRUnet 3 | input_channels: 68 4 | forecast_steps: 48 5 | visualize: True 6 | lr: 0.0001 7 | conv_type: "standard" 8 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | tag = False 4 | current_version = 0.3.36 5 | 6 | [bumpversion:file:satflow/version.py] 7 | search = __version__ = "{current_version}" 8 | replace = __version__ = "{new_version}" 9 | -------------------------------------------------------------------------------- /satflow/configs/model/attention_runet_coord.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.attention_unet.AttentionRUnet 3 | input_channels: 68 4 | forecast_steps: 48 5 | visualize: True 6 | lr: 0.0001 7 | conv_type: "coord" 8 | -------------------------------------------------------------------------------- /satflow/configs/model/convlstm.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.conv_lstm.EncoderDecoderConvLSTM 3 | input_channels: 17 4 | hidden_dim: 64 5 | out_channels: 1 6 | forecast_steps: 24 7 | lr: 0.0001 8 | visualize: True 9 | -------------------------------------------------------------------------------- /.github/workflows/linters.yaml: -------------------------------------------------------------------------------- 1 | name: Lint Python 2 | 3 | on: [push] 4 | 5 | jobs: 6 | call-run-python-linters: 7 | uses: openclimatefix/.github/.github/workflows/python-lint.yml@main 8 | with: 9 | folder: "satflow" 10 | -------------------------------------------------------------------------------- /satflow/configs/logger/csv.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | # csv logger built in lightning 3 | 4 | csv: 5 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 6 | save_dir: "." 7 | name: "csv/" 8 | version: null 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /satflow/configs/model/runet_coord.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.runet.RUnet 3 | input_channels: 68 4 | hidden_dim: 64 5 | num_layers: 5 6 | forecast_steps: 24 7 | visualize: True 8 | lr: 0.0001 9 | conv_type: "coord" 10 | -------------------------------------------------------------------------------- /satflow/configs/model/fcn_r50.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.fcn.FCN 3 | forecast_steps: 6 4 | input_channels: 12 5 | lr: 0.001 6 | make_vis: False 7 | loss: "bce" 8 | backbone: "resnet50" 9 | pretrained: False 10 | aux_loss: False 11 | -------------------------------------------------------------------------------- /satflow/configs/model/fcn_r101.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.fcn.FCN 3 | forecast_steps: 6 4 | input_channels: 12 5 | lr: 0.001 6 | make_vis: False 7 | loss: "bce" 8 | backbone: "resnet101" 9 | pretrained: False 10 | aux_loss: False 11 | -------------------------------------------------------------------------------- /satflow/configs/model/convlstm_coord.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.conv_lstm.EncoderDecoderConvLSTM 3 | input_channels: 17 4 | hidden_dim: 64 5 | out_channels: 1 6 | forecast_steps: 24 7 | lr: 0.0001 8 | visualize: True 9 | conv_type: "coord" 10 | -------------------------------------------------------------------------------- /satflow/configs/model/deeplabv3_r101.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.deeplabv3.DeepLabV3 3 | forecast_steps: 6 4 | input_channels: 12 5 | lr: 0.001 6 | make_vis: False 7 | loss: "bce" 8 | backbone: "resnet101" 9 | pretrained: False 10 | aux_loss: False 11 | -------------------------------------------------------------------------------- /satflow/configs/model/deeplabv3_r50.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.deeplabv3.DeepLabV3 3 | forecast_steps: 6 4 | input_channels: 12 5 | lr: 0.001 6 | make_vis: False 7 | loss: "bce" 8 | backbone: "resnet50" 9 | pretrained: False 10 | aux_loss: False 11 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | datasets 3 | .github 4 | .github/* 5 | dist/* 6 | satflow/logs/ 7 | satflow/logs/* 8 | satflow.egg-info/ 9 | satflow.egg-info/* 10 | .idea/* 11 | .idea 12 | .pytest_cache/* 13 | .pytest_cache 14 | *.png 15 | .env 16 | satflow/.env 17 | .idea 18 | .idea/* 19 | .git 20 | .git/* 21 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: 
-------------------------------------------------------------------------------- 1 | name: Bump version and auto-release 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | call-run-python-release: 8 | uses: openclimatefix/.github/.github/workflows/python-release.yml@main 9 | secrets: 10 | token: ${{ secrets.PYPI_API_TOKEN }} 11 | -------------------------------------------------------------------------------- /satflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from nowcasting_utils.models.base import create_model, get_model 2 | 3 | from .attention_unet import AttU_Net, R2AttU_Net 4 | from .conv_lstm import ConvLSTM, EncoderDecoderConvLSTM 5 | from .perceiver import Perceiver 6 | from .pl_metnet import LitMetNet 7 | from .runet import R2U_Net, RUnet 8 | -------------------------------------------------------------------------------- /satflow/configs/model/metnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.metnet.MetNet 2 | input_channels: 16 3 | output_channels: 1 4 | sat_channels: 12 5 | input_size: 64 6 | hidden_dim: 32 7 | forecast_steps: 24 8 | lr: 0.001 9 | kernel_size: 3 10 | num_layers: 1 11 | num_att_layers: 1 12 | temporal_dropout: 0.2 13 | visualize: False 14 | -------------------------------------------------------------------------------- /satflow/configs/model/metnet_ssim.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.metnet.MetNet 3 | input_channels: 17 4 | hidden_dim: 32 5 | forecast_steps: 24 6 | lr: 0.001 7 | kernel_size: 3 8 | num_layers: 1 9 | num_att_layers: 1 10 | temporal_dropout: 0.2 11 | output_channels: 12 12 | visualize: True 13 | loss: "ssim" 14 | -------------------------------------------------------------------------------- /satflow/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | # https://www.tensorflow.org/tensorboard/ 3 | 4 | tensorboard: 5 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 6 | save_dir: "tensorboard/" 7 | name: "default" 8 | version: null 9 | log_graph: False 10 | default_hp_metric: True 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /satflow/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import SelfAttention, SelfAttention2d 2 | from .ConditionTime import ConditionTime 3 | from .ConvLSTM import ConvLSTMCell 4 | from .RUnetLayers import Recurrent_block, RRCNN_block 5 | from .SpatioTemporalLSTMCell_memory_decoupling import SpatioTemporalLSTMCell 6 | from .TimeDistributed import TimeDistributed 7 | -------------------------------------------------------------------------------- /.github/workflows/workflows.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | jobs: 5 | call-run-python-tests: 6 | uses: openclimatefix/.github/.github/workflows/python-test.yml@main 7 | with: 8 | # 0 means don't use pytest-xdist 9 | pytest_numcpus: "4" 10 | # pytest-cov looks at this folder 11 | pytest_cov_dir: "satflow" 12 | -------------------------------------------------------------------------------- /satflow/configs/logger/neptune.yaml: 
-------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: neptune.new.integrations.pytorch_lightning.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: OpenClimateFix/forecasting-satellite-images 7 | close_after_fit: False 8 | prefix: "" 9 | name: "Sat" 10 | -------------------------------------------------------------------------------- /satflow/configs/model/nowcasting_gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.nowcasting_gan.NowcastingGAN 2 | forecast_steps: 24 3 | input_channels: 1 4 | output_shape: 128 5 | gen_lr: 0.00005 6 | disc_lr: 0.0002 7 | visualize: True 8 | pretrained: False 9 | conv_type: "standard" 10 | num_samples: 3 11 | grid_lambda: 20.0 12 | beta1: 0.0 13 | beta2: 0.999 14 | latent_channels: 768 15 | context_channels: 768 16 | -------------------------------------------------------------------------------- /satflow/configs/datamodule/local_perceiver.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.data.datamodules.SatFlowDataModule 2 | 3 | temp_path: "." 4 | n_train_data: 24900 5 | n_val_data: 1000 6 | cloud: "local" 7 | num_workers: 8 8 | pin_memory: True 9 | configuration_filename: "satflow/configs/configurations/local.yaml" 10 | fake_data: False 11 | required_keys: 12 | - sat_data 13 | - sat_x_coords 14 | - sat_y_coords 15 | history_minutes: 10 16 | forecast_minutes: 5 17 | -------------------------------------------------------------------------------- /satflow/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | sweep: 5 | dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} 6 | subdir: ${hydra.job.num} 7 | 8 | # you can set here environment variables that are universal for all users 9 | # for system specific variables (like data paths) it's better to use .env file! 
10 | job: 11 | env_set: 12 | EXAMPLE_VAR: "example_value" 13 | -------------------------------------------------------------------------------- /satflow/configs/model/pix2pix.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.pix2pix.Pix2Pix 3 | forecast_steps: 6 4 | input_channels: 12 5 | lr: 0.0002 6 | beta1: 0.5 7 | beta2: 0.999 8 | num_filters: 128 9 | generator_model: "unet_128" 10 | norm: "batch" 11 | use_dropout: false 12 | discriminator_model: "basic" 13 | discriminator_layers: 0 14 | loss: "vanilla" 15 | scheduler: "cosine" 16 | lr_epochs: 10 17 | lambda_l1: 0.01 18 | channels_per_timestep: 16 19 | -------------------------------------------------------------------------------- /satflow/configs/model/cloudgan_runet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.cloudgan.CloudGAN 3 | forecast_steps: 48 4 | input_channels: 12 5 | lr: 0.0002 6 | beta1: 0.5 7 | beta2: 0.999 8 | num_filters: 128 9 | generator_model: "runet" 10 | norm: "batch" 11 | use_dropout: false 12 | discriminator_model: "basic" 13 | discriminator_layers: 0 14 | loss: "vanilla" 15 | scheduler: "cosine" 16 | lr_epochs: 10 17 | lambda_l1: 5 18 | channels_per_timestep: 16 19 | -------------------------------------------------------------------------------- /satflow/configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | gpus: 0 5 | 6 | min_epochs: 1 7 | max_epochs: 2 8 | 9 | # prints 10 | progress_bar_refresh_rate: null 11 | weights_summary: null 12 | profiler: null 13 | 14 | # debugs 15 | num_sanity_val_steps: 2 16 | fast_dev_run: False 17 | overfit_batches: 0 18 | limit_train_batches: 1.0 19 | limit_val_batches: 1.0 20 | limit_test_batches: 1.0 21 | track_grad_norm: -1 22 | terminate_on_nan: False 23 | -------------------------------------------------------------------------------- /satflow/configs/model/cloudgan_convlstm.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: satflow.models.cloudgan.CloudGAN 3 | forecast_steps: 24 4 | input_channels: 12 5 | lr: 0.0002 6 | beta1: 0.5 7 | beta2: 0.999 8 | num_filters: 32 9 | generator_model: "convlstm" 10 | norm: "batch" 11 | use_dropout: false 12 | discriminator_model: "basic" 13 | discriminator_layers: 0 14 | loss: "vanilla" 15 | scheduler: "cosine" 16 | lr_epochs: 10 17 | lambda_l1: 1 18 | channels_per_timestep: 12 19 | condition_time: True 20 | -------------------------------------------------------------------------------- /satflow/configs/model/perceiver_metnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.perceiver.Perceiver 2 | input_channels: 16 3 | sat_channels: 12 4 | forecast_steps: 24 5 | lr: 0.005 6 | input_size: 32 7 | max_frequency: 16.0 8 | depth: 6 9 | num_latents: 256 10 | cross_heads: 1 11 | latent_heads: 8 12 | cross_dim_heads: 8 13 | latent_dim: 256 14 | weight_tie_layers: False 15 | decoder_ff: True 16 | dim: 32 17 | logits_dim: null 18 | queries_dim: 32 19 | latent_dim_heads: 64 20 | visualize: False 21 | preprocessor_type: "metnet" 22 | -------------------------------------------------------------------------------- /satflow/configs/datamodule/local.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: satflow.data.datamodules.SatFlowDataModule 2 | 3 | temp_path: "." 4 | n_train_data: 24900 5 | n_val_data: 1000 6 | cloud: "local" 7 | num_workers: 8 8 | pin_memory: True 9 | configuration_filename: "satflow/configs/configurations/local.yaml" 10 | fake_data: False 11 | required_keys: 12 | - sat_data 13 | - nwp_data 14 | - topo_data 15 | - sat_x_coords 16 | - sat_y_coords 17 | - hour_of_day_sin 18 | - hour_of_day_cos 19 | - day_of_year_sin 20 | - day_of_year_cos 21 | history_minutes: 30 22 | forecast_minutes: 60 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations>=1.0.3 2 | antialiased-cnns>=0.3 3 | hydra-core>=1.1.0 4 | hydra-optuna-sweeper>=1.1.0 5 | hydra-colorlog 6 | lightning-bolts>=0.3.4 7 | neptune-client>=0.10.2 8 | neptune-pytorch-lightning>=0.9.7 9 | pytest>=6.2.4 10 | python-dotenv>=0.19.0 11 | pytorch-msssim>=0.2.1 12 | rich>=10.6.0 13 | torchvision>=0.10.0 14 | affine==2.3.0 15 | torch_optimizer 16 | huggingface_hub>=0.0.16 17 | einops>=0.3.2 18 | metnet>=0.0.3 19 | skillful_nowcasting>=0.0.2 20 | perceiver-model>=0.7.0 21 | nowcasting-dataset>=1.0.21 22 | nowcasting-utils>=0.0.8 23 | transformers 24 | torch 25 | -------------------------------------------------------------------------------- /satflow/configs/model/perceiver_single.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.perceiver.SinglePassPerceiver 2 | input_channels: 16 3 | sat_channels: 12 4 | forecast_steps: 3 5 | lr: 0.005 6 | input_size: 32 7 | max_frequency: 16.0 8 | depth: 6 9 | num_latents: 256 10 | cross_heads: 1 11 | latent_heads: 8 12 | cross_dim_heads: 8 13 | latent_dim: 256 14 | weight_tie_layers: False 15 | self_per_cross_attention: 2 16 | dim: 32 17 | logits_dim: null 18 | queries_dim: 128 19 | latent_dim_heads: 64 20 | visualize: False 21 | preprocessor_type: "metnet" 22 | use_input_as_query: True 23 | use_learnable_query: False 24 | output_shape: [3, 32, 32] 25 | -------------------------------------------------------------------------------- /satflow/configs/model/perceiver.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.perceiver.Perceiver 2 | input_channels: 23 3 | sat_channels: 12 4 | nwp_channels: 10 5 | forecast_steps: 12 6 | lr: 0.0005 7 | input_size: 64 8 | max_frequency: 32.0 9 | depth: 6 10 | num_latents: 256 11 | cross_heads: 1 12 | latent_heads: 8 13 | cross_dim_heads: 8 14 | latent_dim: 256 15 | weight_tie_layers: False 16 | decoder_ff: True 17 | dim: 32 18 | logits_dim: null 19 | queries_dim: 32 20 | latent_dim_heads: 64 21 | visualize: False 22 | use_learnable_query: False 23 | predict_timesteps_together: False 24 | nwp_modality: True 25 | datetime_modality: False 26 | history_steps: 6 27 | -------------------------------------------------------------------------------- /pydoc-markdown.yml: -------------------------------------------------------------------------------- 1 | loaders: 2 | - type: python 3 | search_path: [satflow/] 4 | processors: 5 | - type: filter 6 | - type: smart 7 | renderer: 8 | type: mkdocs 9 | pages: 10 | - title: Home 11 | name: index 12 | source: README.md 13 | - title: API Documentation 14 | children: 15 | - title: Baseline 16 | contents: [baseline, baseline.*] 17 | - title: Core 18 | 
contents: [core, core.*] 19 | - title: Models 20 | contents: [models, models.*] 21 | mkdocs_config: 22 | site_name: SatFlow 23 | theme: readthedocs 24 | repo_url: https://github.com/openclimatefix/satflow 25 | -------------------------------------------------------------------------------- /satflow/configs/datamodule/aws.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.data.datamodules.SatFlowDataModule 2 | 3 | temp_path: "." 4 | n_train_data: 24900 5 | n_val_data: 1000 6 | cloud: "aws" 7 | num_workers: 8 8 | pin_memory: True 9 | configuration_filename: "satflow/configurations/aws.yaml" 10 | fake_data: False 11 | required_keys: 12 | - sat_data 13 | - sat_x_coords 14 | - sat_y_coords 15 | - sat_datetime_index 16 | - nwp_data 17 | - nwp_x_coords 18 | - nwp_y_coords 19 | - topo_data 20 | - topo_x_coords 21 | - topo_y_coords 22 | - hour_of_day_sin 23 | - hour_of_day_cos 24 | - day_of_year_sin 25 | - day_of_year_cos 26 | history_minutes: 30 27 | forecast_minutes: 120 28 | -------------------------------------------------------------------------------- /satflow/configs/datamodule/gcp.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.data.datamodules.SatFlowDataModule 2 | 3 | temp_path: "." 4 | n_train_data: 24900 5 | n_val_data: 1000 6 | cloud: "gcp" 7 | num_workers: 8 8 | pin_memory: True 9 | configuration_filename: "satflow/configs/configurations/gcp.yaml" 10 | fake_data: False 11 | required_keys: 12 | - sat_data 13 | - sat_x_coords 14 | - sat_y_coords 15 | - sat_datetime_index 16 | - nwp_data 17 | - nwp_x_coords 18 | - nwp_y_coords 19 | - topo_data 20 | - topo_x_coords 21 | - topo_y_coords 22 | - hour_of_day_sin 23 | - hour_of_day_cos 24 | - day_of_year_sin 25 | - day_of_year_cos 26 | history_minutes: 30 27 | forecast_minutes: 120 28 | -------------------------------------------------------------------------------- /satflow/configs/trainer/minimal.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `1` to train on GPU, `0` to train on CPU only 4 | gpus: 1 5 | 6 | min_epochs: 1 7 | max_epochs: 9999 8 | min_steps: 0 9 | max_steps: 999999999 10 | limit_train_batches: 2000 11 | limit_val_batches: 2000 12 | limit_test_batches: 10000 13 | 14 | terminate_on_nan: False 15 | auto_lr_find: False 16 | auto_scale_batch_size: False 17 | accumulate_grad_batches: 1 18 | precision: 32 19 | # stochastic_weight_avg: True 20 | fast_dev_run: False 21 | 22 | reload_dataloaders_every_epoch: True 23 | 24 | weights_summary: null 25 | progress_bar_refresh_rate: 1 26 | # resume_from_checkpoint: "/home/jacob/Development/satflow/satflow/logs/runs/2021-08-09/09-13-09/checkpoints/epoch=09.ckpt" 27 | -------------------------------------------------------------------------------- /satflow/configs/experiment/unet_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: unet.yaml 9 | - override /datamodule: unet_dataloaders.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: null 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified 
parameters 15 | 16 | seed: 12345 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 5 21 | min_steps: 1000 22 | max_steps: 2000 23 | val_check_interval: 100 24 | limit_train_batches: 200 25 | limit_val_batches: 100 26 | -------------------------------------------------------------------------------- /satflow/configs/model/perceiver_encoder.yaml: -------------------------------------------------------------------------------- 1 | _target_: satflow.models.perceiver.Perceiver 2 | input_channels: 12 3 | sat_channels: 12 4 | forecast_steps: 12 5 | lr: 0.0005 6 | input_size: 32 7 | max_frequency: 16.0 8 | depth: 6 9 | num_latents: 256 10 | cross_heads: 1 11 | latent_heads: 8 12 | cross_dim_heads: 8 13 | latent_dim: 256 14 | weight_tie_layers: False 15 | decoder_ff: True 16 | dim: 32 17 | logits_dim: null 18 | queries_dim: 32 19 | latent_dim_heads: 64 20 | visualize: False 21 | preprocessor_type: "conv" 22 | encoder_kwargs: 23 | spatial_downsample: 4 24 | temporal_downsample: 1 25 | output_channels: 64 26 | conv_after_patching: False 27 | conv2d_use_batchnorm: True 28 | postprocessor_type: "conv" 29 | decoder_kwargs: 30 | spatial_upsample: 1 31 | temporal_upsample: 1 32 | output_channels: 12 33 | -------------------------------------------------------------------------------- /satflow/configs/experiment/metnet_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: metnet.yaml 9 | - override /datamodule: metnet.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: neptune.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 12345 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 # Warmup is the first 10 epochs 21 | min_steps: 200 22 | max_steps: 200000 23 | limit_train_batches: 250 24 | limit_val_batches: 500 25 | limit_test_batches: 100 26 | -------------------------------------------------------------------------------- /satflow/configs/experiment/convlstm_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: convlstm.yaml 9 | - override /datamodule: satflow_dataloaders.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: null 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 12345 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 21 | gradient_clip_val: 0.5 22 | min_steps: 2000 23 | max_steps: 20000 24 | val_check_interval: 100 25 | limit_train_batches: 2000 26 | limit_val_batches: 500 27 | -------------------------------------------------------------------------------- /satflow/configs/experiment/nowcasting_gan_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | 
- override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: nowcasting_gan.yaml 9 | - override /datamodule: nowcasting_gan_hrv_aws.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: neptune.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 12345 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 # Warmup is the first 10 epochs 21 | min_steps: 200 22 | max_steps: 2000000 23 | limit_train_batches: 1000 24 | limit_val_batches: 1000 25 | limit_test_batches: 100 26 | -------------------------------------------------------------------------------- /satflow/configs/experiment/perceiver_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: perceiver_metnet.yaml 9 | - override /datamodule: perceiver_metnet_aws.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: neptune.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: null 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 # Warmup is the first 10 epochs 21 | min_steps: 200 22 | max_steps: 2000000 23 | gradient_clip_val: 0.5 24 | limit_train_batches: 500 25 | limit_val_batches: 1000 26 | limit_test_batches: 100 27 | -------------------------------------------------------------------------------- /satflow/configs/experiment/example_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: convlstm.yaml 9 | - override /datamodule: satflow_dataloaders.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: null 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 12345 17 | 18 | trainer: 19 | min_epochs: 1 20 | max_epochs: 10 21 | gradient_clip_val: 0.5 22 | 23 | model: 24 | lin1_size: 128 25 | lin2_size: 256 26 | lin3_size: 64 27 | lr: 0.002 28 | 29 | datamodule: 30 | batch_size: 64 31 | train_val_test_split: [55_000, 5_000, 10_000] 32 | -------------------------------------------------------------------------------- /satflow/configs/callbacks/gan.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/g_loss" # name of the logged metric which determines when model is improving 4 | save_top_k: 1 # save k best models (determined by above metric) 5 | save_last: True # additionaly always save model from last epoch 6 | mode: "min" # can be "max" or "min" 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "{epoch:02d}" 10 | # TODO Get this working with FID or some other metric 11 | #early_stopping: 12 | # _target_: pytorch_lightning.callbacks.EarlyStopping 13 | # monitor: "val/loss" # 
name of the logged metric which determines when model is improving 14 | # patience: 30 # how many epochs of not improving until training stops 15 | # mode: "min" # can be "max" or "min" 16 | # min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 17 | -------------------------------------------------------------------------------- /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "commitConvention": "angular", 8 | "contributors": [ 9 | { 10 | "login": "jacobbieker", 11 | "name": "Jacob Bieker", 12 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 13 | "profile": "https://www.jacobbieker.com", 14 | "contributions": [ 15 | "code" 16 | ] 17 | }, 18 | { 19 | "login": "lewtun", 20 | "name": "lewtun", 21 | "avatar_url": "https://avatars.githubusercontent.com/u/26859204?v=4", 22 | "profile": "https://lewtun.github.io/blog/", 23 | "contributions": [ 24 | "code" 25 | ] 26 | } 27 | ], 28 | "contributorsPerLine": 7, 29 | "skipCi": true, 30 | "repoType": "github", 31 | "repoHost": "https://github.com", 32 | "projectName": "satflow", 33 | "projectOwner": "openclimatefix" 34 | } 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Open Climate Fix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from pathlib import Path 3 | 4 | this_directory = Path(__file__).parent 5 | install_requires = (this_directory / "requirements.txt").read_text().splitlines() 6 | long_description = (this_directory / "README.md").read_text() 7 | 8 | exec(open("satflow/version.py").read()) 9 | setup( 10 | name="satflow", 11 | version=__version__, 12 | packages=["satflow", "satflow.data", "satflow.models"], 13 | url="https://github.com/openclimatefix/satflow", 14 | license="MIT License", 15 | company="Open Climate Fix Ltd", 16 | author="Jacob Bieker", 17 | install_requires=install_requires, 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | author_email="jacob@openclimatefix.org", 21 | description="Satellite Optical Flow", 22 | classifiers=[ 23 | "Development Status :: 4 - Beta", 24 | "Intended Audience :: Developers", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "License :: OSI Approved :: MIT License", 27 | "Programming Language :: Python :: 3.8", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /satflow/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/loss" # name of the logged metric which determines when model is improving 4 | save_top_k: 1 # save k best models (determined by above metric) 5 | save_last: True # additionally always save model from last epoch 6 | save_weights_only: True # Save only weights and hyperparams, makes smaller and doesn't include callbacks/optimizer/etc.
Generally, this should be True, as haven't really been restarting training runs much 7 | mode: "min" # can be "max" or "min" 8 | verbose: False 9 | dirpath: "checkpoints/" 10 | filename: "best" 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "val/loss" # name of the logged metric which determines when model is improving 15 | patience: 10 # how many epochs of not improving until training stops 16 | mode: "min" # can be "max" or "min" 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 18 | 19 | model_logging: 20 | _target_: nowcasting_utils.training.callbacks.NeptuneModelLogger 21 | -------------------------------------------------------------------------------- /satflow/configs/configurations/aws.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: AWS configuration 3 | name: aws 4 | input_data: 5 | bucket: solar-pv-nowcasting-data 6 | nwp_base_path: NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr 7 | satellite_filename: satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr 8 | solar_pv_data_filename: UK_PV_timeseries_batch.nc 9 | solar_pv_metadata_filename: UK_PV_metadata.csv 10 | solar_pv_path: PV/PVOutput.org 11 | topographic_filename: Topographic/europe_dem_1km_osgb.tif 12 | output_data: 13 | filepath: solar-pv-nowcasting-data/prepared_ML_training_data/v6/ 14 | process: 15 | seed: 1234 16 | batch_size: 32 17 | forecast_minutes: 60 18 | history_minutes: 30 19 | satellite_image_size_pixels: 64 20 | nwp_image_size_pixels: 64 21 | nwp_channels: 22 | - t 23 | - dswrf 24 | - prate 25 | - r 26 | - sde 27 | - si10 28 | - vis 29 | - lcc 30 | - mcc 31 | - hcc 32 | sat_channels: 33 | - HRV 34 | - IR_016 35 | - IR_039 36 | - IR_087 37 | - IR_097 38 | - IR_108 39 | - IR_120 40 | - IR_134 41 | - VIS006 42 | - VIS008 43 | - WV_062 44 | - WV_073 45 | val_check_interval: 1000 46 | -------------------------------------------------------------------------------- /satflow/configs/configurations/gcp.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: GCP configuration 3 | name: gcp 4 | input_data: 5 | bucket: solar-pv-nowcasting-data 6 | nwp_base_path: NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr 7 | satellite_filename: satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr 8 | solar_pv_data_filename: UK_PV_timeseries_batch.nc 9 | solar_pv_metadata_filename: UK_PV_metadata.csv 10 | solar_pv_path: PV/PVOutput.org 11 | topographic_filename: Topographic/europe_dem_1km_osgb.tif 12 | output_data: 13 | filepath: solar-pv-nowcasting-data/prepared_ML_training_data/v6/ 14 | process: 15 | seed: 1234 16 | batch_size: 32 17 | forecast_minutes: 60 18 | history_minutes: 30 19 | satellite_image_size_pixels: 64 20 | nwp_image_size_pixels: 64 21 | nwp_channels: 22 | - t 23 | - dswrf 24 | - prate 25 | - r 26 | - sde 27 | - si10 28 | - vis 29 | - lcc 30 | - mcc 31 | - hcc 32 | sat_channels: 33 | - HRV 34 | - IR_016 35 | - IR_039 36 | - IR_087 37 | - IR_097 38 | - IR_108 39 | - IR_120 40 | - IR_134 41 | - VIS006 42 | - VIS008 43 | - WV_062 44 | - WV_073 45 | val_check_interval: 1000 46 | -------------------------------------------------------------------------------- /satflow/configs/configurations/local.yaml: 
-------------------------------------------------------------------------------- 1 | general: 2 | description: local configuration 3 | name: local 4 | input_data: 5 | bucket: solar-pv-nowcasting-data 6 | nwp_base_path: NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr 7 | satellite_filename: satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr 8 | solar_pv_data_filename: UK_PV_timeseries_batch.nc 9 | solar_pv_metadata_filename: UK_PV_metadata.csv 10 | solar_pv_path: PV/PVOutput.org 11 | topographic_filename: Topographic/europe_dem_1km_osgb.tif 12 | output_data: 13 | filepath: solar-pv-nowcasting-data/prepared_ML_training_data/v6/ 14 | process: 15 | seed: 1234 16 | batch_size: 32 17 | forecast_minutes: 60 18 | history_minutes: 30 19 | satellite_image_size_pixels: 64 20 | nwp_image_size_pixels: 64 21 | nwp_channels: 22 | - t 23 | - dswrf 24 | - prate 25 | - r 26 | - sde 27 | - si10 28 | - vis 29 | - lcc 30 | - mcc 31 | - hcc 32 | sat_channels: 33 | - HRV 34 | - IR_016 35 | - IR_039 36 | - IR_087 37 | - IR_097 38 | - IR_108 39 | - IR_120 40 | - IR_134 41 | - VIS006 42 | - VIS008 43 | - WV_062 44 | - WV_073 45 | val_check_interval: 1000 46 | -------------------------------------------------------------------------------- /satflow/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["HYDRA_FULL_ERROR"] = "1" 4 | import dotenv 5 | import hydra 6 | from omegaconf import DictConfig 7 | 8 | # load environment variables from `.env` file if it exists 9 | # recursively searches for `.env` in all folders starting from work dir 10 | dotenv.load_dotenv(override=True) 11 | 12 | 13 | @hydra.main(config_path="configs/", config_name="config.yaml") 14 | def main(config: DictConfig): 15 | 16 | # Imports should be nested inside @hydra.main to optimize tab completion 17 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 18 | from satflow.core import utils 19 | from satflow.experiments.train import train 20 | 21 | # A couple of optional utilities: 22 | # - disabling python warnings 23 | # - easier access to debug mode 24 | # - forcing debug friendly configuration 25 | # - forcing multi-gpu friendly configuration 26 | # You can safely get rid of this line if you don't want those 27 | utils.extras(config) 28 | 29 | # 30 | 31 | # Pretty print config using Rich library 32 | if config.get("print_config"): 33 | utils.print_config(config, resolve=True) 34 | 35 | # Train model 36 | return train(config) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /satflow/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: null 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | max_epochs: 1 22 | min_epochs: 1 23 | max_steps: null 24 | min_steps: null 25 | limit_train_batches: 1.0 26 | limit_val_batches: 1.0 27 | limit_test_batches: 1.0 28 | val_check_interval: 
1.0 29 | flush_logs_every_n_steps: 100 30 | log_every_n_steps: 50 31 | accelerator: null 32 | sync_batchnorm: False 33 | precision: 32 34 | weights_summary: "top" 35 | weights_save_path: null 36 | num_sanity_val_steps: 2 37 | truncated_bptt_steps: null 38 | resume_from_checkpoint: null 39 | profiler: null 40 | benchmark: False 41 | deterministic: False 42 | reload_dataloaders_every_epoch: False 43 | auto_lr_find: False 44 | replace_sampler_ddp: True 45 | terminate_on_nan: False 46 | auto_scale_batch_size: False 47 | prepare_data_per_node: True 48 | plugins: null 49 | amp_backend: "native" 50 | amp_level: "O2" 51 | move_metrics_to_cpu: False 52 | -------------------------------------------------------------------------------- /satflow/configs/trainer/half.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: 1 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | min_epochs: 0 22 | max_epochs: 50 23 | min_steps: 2000 24 | max_steps: 200000 25 | val_check_interval: 1000 26 | limit_train_batches: 5000 27 | limit_val_batches: 500 28 | limit_test_batches: 5000 29 | flush_logs_every_n_steps: 100 30 | log_every_n_steps: 50 31 | accelerator: null 32 | sync_batchnorm: False 33 | precision: 16 34 | weights_summary: "top" 35 | weights_save_path: null 36 | num_sanity_val_steps: 2 37 | truncated_bptt_steps: null 38 | resume_from_checkpoint: null 39 | profiler: null 40 | benchmark: False 41 | deterministic: False 42 | reload_dataloaders_every_epoch: False 43 | auto_lr_find: False 44 | replace_sampler_ddp: True 45 | terminate_on_nan: False 46 | auto_scale_batch_size: False 47 | prepare_data_per_node: True 48 | plugins: null 49 | amp_backend: "native" 50 | amp_level: "O2" 51 | move_metrics_to_cpu: False 52 | -------------------------------------------------------------------------------- /satflow/configs/trainer/simple_profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: 1 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | limit_test_batches: 1.0 22 | flush_logs_every_n_steps: 100 23 | log_every_n_steps: 50 24 | accelerator: null 25 | sync_batchnorm: False 26 | precision: 32 27 | weights_summary: "top" 28 | weights_save_path: null 29 | num_sanity_val_steps: 2 30 | truncated_bptt_steps: null 31 | resume_from_checkpoint: null 32 | profiler: "simple" 33 | benchmark: False 34 | deterministic: False 35 | reload_dataloaders_every_epoch: False 36 | auto_lr_find: False 37 | replace_sampler_ddp: True 38 | terminate_on_nan: False 39 | auto_scale_batch_size: False 40 | prepare_data_per_node: True 41 | 
plugins: null 42 | amp_backend: "native" 43 | amp_level: "O2" 44 | move_metrics_to_cpu: False 45 | min_epochs: 0 46 | max_epochs: 5 47 | min_steps: 200 48 | max_steps: 2000 49 | val_check_interval: 100 50 | limit_train_batches: 200 51 | limit_val_batches: 50 52 | -------------------------------------------------------------------------------- /satflow/configs/trainer/pytorch_profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: 1 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | limit_test_batches: 1.0 22 | flush_logs_every_n_steps: 100 23 | log_every_n_steps: 50 24 | accelerator: null 25 | sync_batchnorm: False 26 | precision: 32 27 | weights_summary: "top" 28 | weights_save_path: null 29 | num_sanity_val_steps: 2 30 | truncated_bptt_steps: null 31 | resume_from_checkpoint: null 32 | profiler: "pytorch" 33 | benchmark: False 34 | deterministic: False 35 | reload_dataloaders_every_epoch: False 36 | auto_lr_find: False 37 | replace_sampler_ddp: True 38 | terminate_on_nan: False 39 | auto_scale_batch_size: False 40 | prepare_data_per_node: True 41 | plugins: null 42 | amp_backend: "native" 43 | amp_level: "O2" 44 | move_metrics_to_cpu: False 45 | min_epochs: 0 46 | max_epochs: 5 47 | min_steps: 200 48 | max_steps: 2000 49 | val_check_interval: 100 50 | limit_train_batches: 200 51 | limit_val_batches: 50 52 | -------------------------------------------------------------------------------- /satflow/models/layers/ConditionTime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | 5 | def condition_time(x, i=0, size=(12, 16), seq_len=15): 6 | "create one hot encoded time image-layers, i in [1, seq_len]" 7 | assert i < seq_len 8 | times = (torch.eye(seq_len, dtype=x.dtype, device=x.device)[i]).unsqueeze(-1).unsqueeze(-1) 9 | ones = torch.ones(1, *size, dtype=x.dtype, device=x.device) 10 | return times * ones 11 | 12 | 13 | class ConditionTime(nn.Module): 14 | "Condition Time on a stack of images, adds `horizon` channels to image" 15 | 16 | def __init__(self, horizon, ch_dim=2, num_dims=5): 17 | super().__init__() 18 | self.horizon = horizon 19 | self.ch_dim = ch_dim 20 | self.num_dims = num_dims 21 | 22 | def forward(self, x, fstep=0): 23 | "x stack of images, fsteps" 24 | if self.num_dims == 5: 25 | bs, seq_len, ch, h, w = x.shape 26 | ct = condition_time(x, fstep, (h, w), seq_len=self.horizon).repeat(bs, seq_len, 1, 1, 1) 27 | else: 28 | bs, h, w, ch = x.shape 29 | ct = condition_time(x, fstep, (h, w), seq_len=self.horizon).repeat(bs, 1, 1, 1) 30 | ct = ct.permute(0, 2, 3, 1) 31 | x = torch.cat([x, ct], dim=self.ch_dim) 32 | assert x.shape[self.ch_dim] == (ch + self.horizon) # check if it makes sense 33 | return x 34 | -------------------------------------------------------------------------------- /satflow/configs/trainer/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default 
values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: 1 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | min_epochs: 0 22 | max_epochs: 50 23 | min_steps: 2000 24 | max_steps: 200000 25 | val_check_interval: 1000 26 | limit_train_batches: 5000 27 | limit_val_batches: 500 28 | limit_test_batches: 5000 29 | flush_logs_every_n_steps: 100 30 | log_every_n_steps: 50 31 | accelerator: null 32 | sync_batchnorm: False 33 | precision: 16 34 | weights_summary: "top" 35 | weights_save_path: null 36 | num_sanity_val_steps: 2 37 | truncated_bptt_steps: null 38 | resume_from_checkpoint: null 39 | profiler: null 40 | benchmark: False 41 | deterministic: False 42 | reload_dataloaders_every_epoch: False 43 | auto_lr_find: False 44 | replace_sampler_ddp: True 45 | terminate_on_nan: False 46 | auto_scale_batch_size: False 47 | prepare_data_per_node: True 48 | plugins: "deepspeed_stage_2_offload" 49 | amp_backend: "native" 50 | amp_level: "O2" 51 | move_metrics_to_cpu: False 52 | -------------------------------------------------------------------------------- /satflow/configs/trainer/deepspeed_zero_three.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default values for all trainer parameters 5 | checkpoint_callback: True 6 | default_root_dir: null 7 | gradient_clip_val: 0.0 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | gpus: 1 12 | auto_select_gpus: False 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: 1 21 | min_epochs: 0 22 | max_epochs: 50 23 | min_steps: 2000 24 | max_steps: 200000 25 | val_check_interval: 1000 26 | limit_train_batches: 5000 27 | limit_val_batches: 500 28 | limit_test_batches: 5000 29 | flush_logs_every_n_steps: 100 30 | log_every_n_steps: 50 31 | accelerator: null 32 | sync_batchnorm: False 33 | precision: 16 34 | weights_summary: "top" 35 | weights_save_path: null 36 | num_sanity_val_steps: 2 37 | truncated_bptt_steps: null 38 | resume_from_checkpoint: null 39 | profiler: null 40 | benchmark: False 41 | deterministic: False 42 | reload_dataloaders_every_epoch: False 43 | auto_lr_find: False 44 | replace_sampler_ddp: True 45 | terminate_on_nan: False 46 | auto_scale_batch_size: False 47 | prepare_data_per_node: True 48 | plugins: "deepspeed_stage_3_offload" 49 | amp_backend: "native" 50 | amp_level: "O2" 51 | move_metrics_to_cpu: False 52 | -------------------------------------------------------------------------------- /satflow/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - trainer: minimal.yaml 6 | - model: metnet.yaml 7 | - datamodule: local.yaml 8 | - callbacks: default.yaml # set this to null if you don't want to use callbacks 9 | - logger: neptune.yaml # set logger here or use command line (e.g. 
`python run.py logger=wandb`) 10 | 11 | - experiment: null 12 | - hparams_search: metnet_optuna.yaml 13 | 14 | - hydra: default.yaml 15 | 16 | # enable color logging 17 | - override hydra/hydra_logging: colorlog 18 | - override hydra/job_logging: colorlog 19 | 20 | # path to original working directory 21 | # hydra hijacks working directory by changing it to the current log directory, 22 | # so it's useful to have this path as a special variable 23 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 24 | work_dir: ${hydra:runtime.cwd} 25 | 26 | # path to folder with data 27 | data_dir: /run/media/jacob/data/ 28 | 29 | # use `python run.py debug=true` for easy debugging! 30 | # this will run 1 train, val and test loop with only 1 batch 31 | # equivalent to running `python run.py trainer.fast_dev_run=true` 32 | # (this is placed here just for easier access from command line) 33 | debug: False 34 | 35 | # pretty print config at the start of the run using Rich library 36 | print_config: True 37 | 38 | # disable python warnings if they annoy you 39 | ignore_warnings: True 40 | -------------------------------------------------------------------------------- /satflow/examples/metnet_example.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | 3 | import torch 4 | 5 | from satflow.models import LitMetNet 6 | 7 | 8 | def get_input_target(number: int): 9 | url = f"https://github.com/openclimatefix/satflow/releases/download/v0.0.3/input_{number}.pth" 10 | filename, headers = urllib.request.urlretrieve(url, filename=f"input_{number}.pth") 11 | input_data = torch.load(filename) 12 | return input_data 13 | 14 | 15 | # Setup the model (need to add loading weights from HuggingFace :) 16 | # 12 satellite channels + 1 Topographic + 3 Lat/Lon + 1 Cloud Mask 17 | # Output Channels: 1 Cloud mask, 12 for Satellite image 18 | model = LitMetNet(input_channels=17, sat_channels=13, input_size=64, out_channels=1) 19 | torch.set_grad_enabled(False) 20 | model.eval() 21 | # The inputs are Tensors of size (Batch, Curr+Prev Timesteps, Channel, Width, Height) 22 | # MetNet uses the last 90min of data, the previous 6 timesteps + Current one 23 | # This gives an input of (Batch, 7, 256, 256, 286), for Satflow, we use (Batch, 7, 17, 256, 256) and do the preprocessing 24 | # in the model 25 | 26 | # Data processing from raw satellite to Tensors is described in satflow/examples/create_webdataset.py and satflow/data/datasets.py 27 | # This just takes the output from the Dataloader, which has been stored here 28 | 29 | for i in range(11): 30 | forecast = model(get_input_target(i)) 31 | print(forecast.size()) 32 | 33 | # Output for this segmentation model is (Batch, 24, 1, 16, 16) for Satflow, MetNet has an output of (Batch, 480, 1, 256, 256) 34 | -------------------------------------------------------------------------------- /satflow/baseline/README.md: -------------------------------------------------------------------------------- 1 | ## Baseline 2 | 3 | To see if our ML models are actually improving the predictions of the cloud masks, we 4 | have benchmarked it against OpenCV's dense optical flow predictions, as well as a naive 5 | baseline of just predicting the current image for all future timesteps. 6 | 7 | As we come up with and implement more metrics to compare these models, they will be added 8 | here. 
Currently, the only metric tested is the mean squared error between the predicted frame 9 | and the ground truth frame. To get a sense of whether there is a temporal dependence, the mean loss is 10 | computed not just for the overall predictions, but for each of the future timesteps, going up to 4 hours (48 timesteps) 11 | in the future. 12 | 13 | On average, the optical flow approach has an MSE of 0.1541. The naive baseline has an MSE of 0.1566, 14 | so optical flow beats out the naive baseline by about 1.6%. 15 | 16 | ## Caveats 17 | 18 | We tried obtaining the optical flow of consecutive, or even very temporally separated, cloud masks, 19 | but the optical flow usually ended up not actually changing anything. Instead, we used the 20 | MSG HRV satellite channel to compute the optical flow. This was chosen as that is the highest 21 | resolution satellite channel available, and it resulted in optical flow actually computing some movement. 22 | This flow field was then applied to the cloud masks directly to obtain the flow results. 23 | 24 | Avg Total Loss: 0.15720261434381796 Avg Baseline Loss: 0.1598897848692671 25 | Overall Loss: 0.15720261434381738 Baseline: 0.1598897848692671 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | - id: debug-statements 13 | - id: detect-private-key 14 | 15 | # python code formatting/linting 16 | - repo: https://github.com/PyCQA/pydocstyle 17 | rev: 6.1.1 18 | hooks: 19 | - id: pydocstyle 20 | args: 21 | [ 22 | --convention=google, 23 | "--add-ignore=D200,D202,D210,D212,D415", 24 | "satflow", 25 | ] 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 5.0.4 28 | hooks: 29 | - id: flake8 30 | args: 31 | [ 32 | --max-line-length, 33 | "100", 34 | --extend-ignore=E203, 35 | --per-file-ignores, 36 | "__init__.py:F401", 37 | "satflow", 38 | ] 39 | - repo: https://github.com/PyCQA/isort 40 | rev: 5.10.1 41 | hooks: 42 | - id: isort 43 | args: [--profile, black, --line-length, "100", "satflow"] 44 | - repo: https://github.com/psf/black 45 | rev: 22.6.0 46 | hooks: 47 | - id: black 48 | args: [--line-length, "100"] 49 | 50 | # yaml formatting 51 | - repo: https://github.com/pre-commit/mirrors-prettier 52 | rev: v2.7.1 53 | hooks: 54 | - id: prettier 55 | types: [yaml] 56 | -------------------------------------------------------------------------------- /.github/workflows/docker.yaml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation.
5 | 6 | name: Publish Docker image 7 | 8 | on: 9 | [push] 10 | # release: 11 | # types: [published] 12 | 13 | jobs: 14 | push_to_registries: 15 | name: Push Docker image to multiple registries 16 | runs-on: ubuntu-latest 17 | permissions: 18 | packages: write 19 | contents: read 20 | steps: 21 | - name: Check out the repo 22 | uses: actions/checkout@v2 23 | 24 | - name: Log in to Docker Hub 25 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 26 | with: 27 | username: ${{ secrets.DOCKERUSERNAME }} 28 | password: ${{ secrets.DOCKERPASSWORD }} 29 | 30 | - name: Log in to the Container registry 31 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 32 | with: 33 | registry: ghcr.io 34 | username: ${{ github.actor }} 35 | password: ${{ secrets.GITHUB_TOKEN }} 36 | 37 | - name: Extract metadata (tags, labels) for Docker 38 | id: meta 39 | uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 40 | with: 41 | images: | 42 | jacobbieker/satflow 43 | ghcr.io/${{ github.repository }} 44 | 45 | - name: Build and push Docker images 46 | uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc 47 | with: 48 | context: . 49 | push: true 50 | tags: ${{ steps.meta.outputs.tags }} 51 | labels: ${{ steps.meta.outputs.labels }} 52 | -------------------------------------------------------------------------------- /satflow/configs/experiment/example_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_full.yaml 5 | 6 | defaults: 7 | - override /trainer: null # override trainer to null so it's not loaded from main config defaults... 8 | - override /model: null 9 | - override /datamodule: null 10 | - override /callbacks: null 11 | - override /logger: null 12 | 13 | # we override default configurations with nulls to prevent them from loading at all 14 | # instead we define all modules and their paths directly in this config, 15 | # so everything is stored in one place 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | _target_: pytorch_lightning.Trainer 21 | gpus: 0 22 | min_epochs: 1 23 | max_epochs: 10 24 | gradient_clip_val: 0.5 25 | accumulate_grad_batches: 2 26 | weights_summary: null 27 | # resume_from_checkpoint: ${work_dir}/last.ckpt 28 | 29 | model: 30 | _target_: src.models.mnist_model.MNISTLitModel 31 | lr: 0.001 32 | weight_decay: 0.00005 33 | architecture: SimpleDenseNet 34 | input_size: 784 35 | lin1_size: 256 36 | lin2_size: 256 37 | lin3_size: 128 38 | output_size: 10 39 | 40 | datamodule: 41 | _target_: src.datamodules.mnist_datamodule.MNISTDataModule 42 | data_dir: ${data_dir} 43 | batch_size: 64 44 | train_val_test_split: [55_000, 5_000, 10_000] 45 | num_workers: 0 46 | pin_memory: False 47 | 48 | callbacks: 49 | model_checkpoint: 50 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 51 | monitor: "val/acc" 52 | save_top_k: 2 53 | save_last: True 54 | mode: "max" 55 | dirpath: "checkpoints/" 56 | filename: "sample-mnist-{epoch:02d}" 57 | early_stopping: 58 | _target_: pytorch_lightning.callbacks.EarlyStopping 59 | monitor: "val/acc" 60 | patience: 10 61 | mode: "max" 62 | 63 | logger: 64 | csv_logger: 65 | save_dir: "." 66 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Build: sudo docker build -t . 
2 | # Run: sudo docker run -v $(pwd):/workspace/project --gpus all -it --rm 3 | 4 | 5 | FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu20.04 6 | 7 | 8 | ENV CONDA_ENV_NAME=satflow 9 | ENV PYTHON_VERSION=3.9 10 | 11 | 12 | # Basic setup 13 | RUN apt update && apt install -y bash \ 14 | build-essential \ 15 | git \ 16 | curl \ 17 | ca-certificates \ 18 | wget \ 19 | libaio-dev \ 20 | && rm -rf /var/lib/apt/lists 21 | 22 | # Install Miniconda and create main env 23 | ADD https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh miniconda3.sh 24 | RUN /bin/bash miniconda3.sh -b -p /conda \ 25 | && echo export PATH=/conda/bin:$PATH >> .bashrc \ 26 | && rm miniconda3.sh 27 | ENV PATH="/conda/bin:${PATH}" 28 | # RUN conda create -n ${CONDA_ENV_NAME} python=${PYTHON_VERSION} pytorch::pytorch=1.9 torchvision cudatoolkit=11.1 iris rasterio numpy cartopy satpy matplotlib hydra-core pytorch-lightning optuna eccodes -c conda-forge -c nvidia -c pytorch 29 | COPY environment.yml ./ 30 | RUN conda env create --file environment.yml && rm environment.yml 31 | 32 | # Switch to bash shell 33 | SHELL ["/bin/bash", "-c"] 34 | 35 | # Install DeepSpeed and build extensions 36 | RUN source activate ${CONDA_ENV_NAME} \ 37 | && DS_BUILD_OPS=1 pip install deepspeed 38 | 39 | 40 | # Install requirements 41 | COPY requirements.txt ./ 42 | RUN source activate ${CONDA_ENV_NAME} \ 43 | && pip install --no-cache-dir -r requirements.txt \ 44 | && rm requirements.txt 45 | 46 | 47 | # Set ${CONDA_ENV_NAME} to default virutal environment 48 | RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc 49 | 50 | # Cp in the development directory and install 51 | COPY . ./ 52 | RUN source activate ${CONDA_ENV_NAME} && pip install -e . 53 | -------------------------------------------------------------------------------- /satflow/models/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)], 31 | dim=1, 32 | ) 33 | 34 | if self.with_r: 35 | rr = torch.sqrt( 36 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 37 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 38 | ) 39 | ret = torch.cat([ret, rr], dim=1) 40 | 41 | return ret 42 | 43 | 44 | class CoordConv(nn.Module): 45 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 46 | super().__init__() 47 | self.addcoords = AddCoords(with_r=with_r) 48 | in_size = in_channels + 2 49 | if with_r: 50 | in_size += 1 51 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 52 | 53 | def forward(self, x): 54 | ret = self.addcoords(x) 55 | ret = self.conv(ret) 56 | return ret 57 | 
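The CoordConv layer above follows the CoordConv idea of concatenating normalised x/y coordinate grids (and optionally a radial channel) onto the input before an ordinary convolution, so the filters can condition on absolute position. A minimal usage sketch, importing directly from the module above; the batch and image sizes here are purely illustrative:

```python
import torch

from satflow.models.layers.CoordConv import AddCoords, CoordConv

# A batch of 4 single-channel 32x32 frames (e.g. cloud masks)
x = torch.randn(4, 1, 32, 32)

# AddCoords appends normalised x and y grids, growing the channels from 1 to 3
with_coords = AddCoords(with_r=False)(x)
assert with_coords.shape == (4, 3, 32, 32)

# CoordConv does the same internally, then applies a standard Conv2d
layer = CoordConv(in_channels=1, out_channels=8, kernel_size=3, padding=1)
out = layer(x)
assert out.shape == (4, 8, 32, 32)
```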
-------------------------------------------------------------------------------- /satflow/models/layers/TimeDistributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def _stack_tups(tuples, stack_dim=1): 6 | "Stack tuple of tensors along `stack_dim`" 7 | return tuple( 8 | torch.stack([t[i] for t in tuples], dim=stack_dim) for i in list(range(len(tuples[0]))) 9 | ) 10 | 11 | 12 | class TimeDistributed(nn.Module): 13 | "Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time." 14 | 15 | def __init__(self, module, low_mem=False, tdim=1): 16 | super().__init__() 17 | self.module = module 18 | self.low_mem = low_mem 19 | self.tdim = tdim 20 | 21 | def forward(self, *tensors, **kwargs): 22 | "input x with shape:(bs,seq_len,channels,width,height)" 23 | if self.low_mem or self.tdim != 1: 24 | return self.low_mem_forward(*tensors, **kwargs) 25 | # only support tdim=1 26 | inp_shape = tensors[0].shape 27 | bs, seq_len = inp_shape[0], inp_shape[1] 28 | out = self.module(*[x.view(bs * seq_len, *x.shape[2:]) for x in tensors], **kwargs) 29 | return self.format_output(out, bs, seq_len) 30 | 31 | def low_mem_forward(self, *tensors, **kwargs): 32 | "input x with shape:(bs,seq_len,channels,width,height)" 33 | seq_len = tensors[0].shape[self.tdim] 34 | args_split = [torch.unbind(x, dim=self.tdim) for x in tensors] 35 | out = [] 36 | for i in range(seq_len): 37 | out.append(self.module(*[args[i] for args in args_split]), **kwargs) 38 | if isinstance(out[0], tuple): 39 | return _stack_tups(out, stack_dim=self.tdim) 40 | return torch.stack(out, dim=self.tdim) 41 | 42 | def format_output(self, out, bs, seq_len): 43 | "unstack from batchsize outputs" 44 | if isinstance(out, tuple): 45 | return tuple(out_i.view(bs, seq_len, *out_i.shape[1:]) for out_i in out) 46 | return out.view(bs, seq_len, *out.shape[1:]) 47 | 48 | def __repr__(self): 49 | return f"TimeDistributed({self.module})" 50 | -------------------------------------------------------------------------------- /satflow/configs/hparams_search/convlstm_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple 5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb 7 | 8 | defaults: 9 | - override /hydra/sweeper: optuna 10 | 11 | # choose metric which will be optimized by Optuna 12 | optimized_metric: "val/loss" 13 | 14 | hydra: 15 | # here we define Optuna hyperparameter search 16 | # it optimizes for value returned from function with @hydra.main decorator 17 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 18 | sweeper: 19 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 20 | storage: null 21 | study_name: null 22 | n_jobs: 1 23 | 24 | # 'minimize' or 'maximize' the objective 25 | direction: minimize 26 | 27 | # number of experiments that will be executed 28 | n_trials: 20 29 | 30 | # choose Optuna hyperparameter sampler 31 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 32 | sampler: 33 | _target_: optuna.samplers.TPESampler 34 | seed: 12345 35 | consider_prior: true 36 | prior_weight: 1.0 37 | 
consider_magic_clip: true 38 | consider_endpoints: false 39 | n_startup_trials: 10 40 | n_ei_candidates: 24 41 | multivariate: false 42 | warn_independent_sampling: true 43 | 44 | # define range of hyperparameters 45 | search_space: 46 | datamodule.batch_size: 47 | type: categorical 48 | choices: [1, 2, 4] 49 | datamodule.num_timesteps: 50 | type: categorical 51 | choices: [1, 3, 6, 9, 12, 15, 18, 21, 24] 52 | model.learning_rate: 53 | type: float 54 | low: 0.0001 55 | high: 0.2 56 | model.num_layers: 57 | type: categorical 58 | choices: [8, 16, 32, 64, 128] 59 | -------------------------------------------------------------------------------- /satflow/configs/hparams_search/unet_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple 5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb 7 | 8 | defaults: 9 | - override /hydra/sweeper: optuna 10 | 11 | # choose metric which will be optimized by Optuna 12 | optimized_metric: "val/loss" 13 | 14 | hydra: 15 | # here we define Optuna hyperparameter search 16 | # it optimizes for value returned from function with @hydra.main decorator 17 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 18 | sweeper: 19 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 20 | storage: null 21 | study_name: null 22 | n_jobs: 1 23 | 24 | # 'minimize' or 'maximize' the objective 25 | direction: minimize 26 | 27 | # number of experiments that will be executed 28 | n_trials: 20 29 | 30 | # choose Optuna hyperparameter sampler 31 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 32 | sampler: 33 | _target_: optuna.samplers.TPESampler 34 | seed: 12345 35 | consider_prior: true 36 | prior_weight: 1.0 37 | consider_magic_clip: true 38 | consider_endpoints: false 39 | n_startup_trials: 10 40 | n_ei_candidates: 24 41 | multivariate: false 42 | warn_independent_sampling: true 43 | 44 | # define range of hyperparameters 45 | search_space: 46 | datamodule.batch_size: 47 | type: categorical 48 | choices: [1, 2, 4] 49 | datamodule.num_timesteps: 50 | type: categorical 51 | choices: [1, 3, 6, 9, 12, 15, 18, 21, 24] 52 | model.learning_rate: 53 | type: float 54 | low: 0.00001 55 | high: 0.2 56 | model.num_layers: 57 | type: categorical 58 | choices: [2, 3, 5, 7, 9] 59 | model.features_start: 60 | type: categorical 61 | choices: [8, 16, 32, 64, 96, 128] 62 | -------------------------------------------------------------------------------- /satflow/models/layers/ConvLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from satflow.models.utils import get_conv_layer 5 | 6 | 7 | class ConvLSTMCell(nn.Module): 8 | def __init__(self, input_dim, hidden_dim, kernel_size, bias, conv_type: str = "standard"): 9 | """ 10 | Initialize ConvLSTM cell. 11 | 12 | Parameters 13 | ---------- 14 | input_dim: int 15 | Number of channels of input tensor. 16 | hidden_dim: int 17 | Number of channels of hidden state. 18 | kernel_size: (int, int) 19 | Size of the convolutional kernel. 20 | bias: bool 21 | Whether or not to add the bias. 
22 | """ 23 | 24 | super(ConvLSTMCell, self).__init__() 25 | 26 | self.input_dim = input_dim 27 | self.hidden_dim = hidden_dim 28 | 29 | self.kernel_size = kernel_size 30 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 31 | self.bias = bias 32 | conv2d = get_conv_layer(conv_type) 33 | 34 | self.conv = conv2d( 35 | in_channels=self.input_dim + self.hidden_dim, 36 | out_channels=4 * self.hidden_dim, 37 | kernel_size=self.kernel_size, 38 | padding=self.padding, 39 | bias=self.bias, 40 | ) 41 | 42 | def forward(self, input_tensor, cur_state): 43 | h_cur, c_cur = cur_state 44 | 45 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 46 | 47 | combined_conv = self.conv(combined) 48 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 49 | i = torch.sigmoid(cc_i) 50 | f = torch.sigmoid(cc_f) 51 | o = torch.sigmoid(cc_o) 52 | g = torch.tanh(cc_g) 53 | 54 | c_next = f * c_cur + i * g 55 | h_next = o * torch.tanh(c_next) 56 | 57 | return h_next, c_next 58 | 59 | def init_hidden(self, batch_size, image_size): 60 | height, width = image_size 61 | return ( 62 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 63 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 64 | ) 65 | -------------------------------------------------------------------------------- /satflow/configs/hparams_search/metnet_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple 5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb 7 | 8 | defaults: 9 | - override /hydra/sweeper: optuna 10 | 11 | # choose metric which will be optimized by Optuna 12 | optimized_metric: "val/loss" 13 | 14 | hydra: 15 | # here we define Optuna hyperparameter search 16 | # it optimizes for value returned from function with @hydra.main decorator 17 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 18 | sweeper: 19 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 20 | storage: null #"sqlite:///metnet.db" 21 | study_name: metnet 22 | n_jobs: 1 23 | 24 | # 'minimize' or 'maximize' the objective 25 | direction: minimize 26 | 27 | # number of experiments that will be executed 28 | n_trials: 20 29 | 30 | # choose Optuna hyperparameter sampler 31 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 32 | sampler: 33 | _target_: optuna.samplers.TPESampler 34 | seed: 12345 35 | consider_prior: true 36 | prior_weight: 1.0 37 | consider_magic_clip: true 38 | consider_endpoints: false 39 | n_startup_trials: 10 40 | n_ei_candidates: 24 41 | multivariate: false 42 | warn_independent_sampling: true 43 | 44 | # define range of hyperparameters 45 | search_space: 46 | datamodule.batch_size: 47 | type: categorical 48 | choices: [2] 49 | datamodule.config.num_timesteps: 50 | type: categorical 51 | choices: [3, 6] 52 | datamodule.config.skip_timesteps: 53 | type: categorical 54 | choices: [1] 55 | model.lr: 56 | type: float 57 | low: 0.0001 58 | high: 0.2 59 | model.hidden_dim: 60 | type: categorical 61 | choices: [32, 64, 128] 62 | model.num_layers: 63 | type: categorical 64 | choices: [1, 2, 3] 65 | 
model.num_att_layers: 66 | type: categorical 67 | choices: [1, 2, 3] 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SatFlow 2 | 3 | [![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) 4 | 5 | ***Sat***ellite Optical ***Flow*** with machine learning models. 6 | 7 | The goal of this repo is to improve upon optical flow models for predicting 8 | future satellite images from current and past ones, focused primarily on EUMETSAT data. 9 | 10 | ## Installation 11 | 12 | Clone the repository, then run 13 | ```shell 14 | conda env create -f environment.yml 15 | conda activate satflow 16 | pip install -e . 17 | ``` 18 | 19 | Alternatively, you can install a released (and usually older) version with ```pip install satflow``` 20 | 21 | ## Data 22 | 23 | The data used here is a combination of the UK Met Office's rainfall radar data, EUMETSAT MSG 24 | satellite data (12 channels), derived data from the MSG satellites (cloud masks, etc.), and 25 | numerical weather prediction data. Currently, some example transformed EUMETSAT data can be downloaded 26 | from the tagged release, as well as included under ```datasets/```. 27 | 28 | ## Contributors ✨ 29 | 30 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 |

Jacob Bieker (💻 code), lewtun (💻 code)
43 | 44 | 45 | 46 | 47 | 48 | 49 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! -------------------------------------------------------------------------------- /satflow/models/pixel_cnn.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | from nowcasting_utils.models.base import register_model 5 | from pl_bolts.models.vision import PixelCNN as Pixcnn 6 | 7 | 8 | @register_model 9 | class PixelCNN(pl.LightningModule): 10 | def __init__( 11 | self, 12 | future_timesteps: int, 13 | input_channels: int = 3, 14 | num_layers: int = 5, 15 | num_hidden: int = 64, 16 | pretrained: bool = False, 17 | lr: float = 0.001, 18 | ): 19 | super(PixelCNN, self).__init__() 20 | self.lr = lr 21 | self.model = Pixcnn( 22 | input_channels=input_channels, hidden_channels=num_hidden, num_blocks=num_layers 23 | ) 24 | 25 | @classmethod 26 | def from_config(cls, config): 27 | return PixelCNN( 28 | future_timesteps=config.get("future_timesteps", 12), 29 | input_channels=config.get("in_channels", 12), 30 | features_start=config.get("features", 64), 31 | num_layers=config.get("num_layers", 5), 32 | bilinear=config.get("bilinear", False), 33 | lr=config.get("lr", 0.001), 34 | ) 35 | 36 | def forward(self, x): 37 | self.model.forward(x) 38 | 39 | def configure_optimizers(self): 40 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 41 | # optimizer = torch.optim.adam() 42 | return torch.optim.Adam(self.parameters(), lr=self.lr) 43 | 44 | def training_step(self, batch, batch_idx): 45 | x, y = batch 46 | y_hat = self(x) 47 | # Generally only care about the center x crop, so the model can take into account the clouds in the area without 48 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 49 | loss = F.mse_loss(y_hat, y) 50 | self.log("train/loss", loss, on_step=True) 51 | return loss 52 | 53 | def validation_step(self, batch, batch_idx): 54 | x, y = batch 55 | y_hat = self(x) 56 | val_loss = F.mse_loss(y_hat, y) 57 | self.log("val/loss", val_loss, on_step=True, on_epoch=True) 58 | return val_loss 59 | 60 | def test_step(self, batch, batch_idx): 61 | x, y = batch 62 | y_hat = self(x, self.forecast_steps) 63 | loss = F.mse_loss(y_hat, y) 64 | return loss 65 | -------------------------------------------------------------------------------- /satflow/models/utils.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import numpy as np 3 | import torch 4 | 5 | from satflow.models.layers import CoordConv 6 | 7 | 8 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 9 | if conv_type == "standard": 10 | conv_layer = torch.nn.Conv2d 11 | elif conv_type == "coord": 12 | conv_layer = CoordConv 13 | elif conv_type == "antialiased": 14 | # TODO Add anti-aliased coordconv here 15 | conv_layer = torch.nn.Conv2d 16 | elif conv_type == "3d": 17 | conv_layer = torch.nn.Conv3d 18 | else: 19 | raise ValueError(f"{conv_type} is not a recognized Conv method") 20 | return conv_layer 21 | 22 | 23 | def reverse_space_to_depth( 24 | frames: np.ndarray, temporal_block_size: int = 1, spatial_block_size: int = 1 25 | ) -> np.ndarray: 26 | """Reverse space to depth transform.""" 27 | if len(frames.shape) == 4: 28 | return einops.rearrange( 29 | frames, 30 | "b h w (dh dw c) -> b (h 
dh) (w dw) c", 31 | dh=spatial_block_size, 32 | dw=spatial_block_size, 33 | ) 34 | if len(frames.shape) == 5: 35 | return einops.rearrange( 36 | frames, 37 | "b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c", 38 | dt=temporal_block_size, 39 | dh=spatial_block_size, 40 | dw=spatial_block_size, 41 | ) 42 | raise ValueError( 43 | "Frames should be of rank 4 (batch, height, width, channels)" 44 | " or rank 5 (batch, time, height, width, channels)" 45 | ) 46 | 47 | 48 | def space_to_depth( 49 | frames: np.ndarray, temporal_block_size: int = 1, spatial_block_size: int = 1 50 | ) -> np.ndarray: 51 | """Space to depth transform.""" 52 | if len(frames.shape) == 4: 53 | return einops.rearrange( 54 | frames, 55 | "b (h dh) (w dw) c -> b h w (dh dw c)", 56 | dh=spatial_block_size, 57 | dw=spatial_block_size, 58 | ) 59 | if len(frames.shape) == 5: 60 | return einops.rearrange( 61 | frames, 62 | "b (t dt) (h dh) (w dw) c -> b t h w (dt dh dw c)", 63 | dt=temporal_block_size, 64 | dh=spatial_block_size, 65 | dw=spatial_block_size, 66 | ) 67 | raise ValueError( 68 | "Frames should be of rank 4 (batch, height, width, channels)" 69 | " or rank 5 (batch, time, height, width, channels)" 70 | ) 71 | -------------------------------------------------------------------------------- /satflow/configs/hparams_search/nowcasting_gan_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple 5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb 7 | 8 | defaults: 9 | - override /hydra/sweeper: optuna 10 | 11 | # choose metric which will be optimized by Optuna 12 | optimized_metric: "val/loss" 13 | 14 | hydra: 15 | # here we define Optuna hyperparameter search 16 | # it optimizes for value returned from function with @hydra.main decorator 17 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 18 | sweeper: 19 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 20 | storage: "sqlite:///nowcasting_gan.db" 21 | study_name: nowcasting_gan 22 | n_jobs: 1 23 | 24 | # 'minimize' or 'maximize' the objective 25 | direction: minimize 26 | 27 | # number of experiments that will be executed 28 | n_trials: 20 29 | 30 | # choose Optuna hyperparameter sampler 31 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 32 | sampler: 33 | _target_: optuna.samplers.TPESampler 34 | seed: 12345 35 | consider_prior: true 36 | prior_weight: 1.0 37 | consider_magic_clip: true 38 | consider_endpoints: false 39 | n_startup_trials: 10 40 | n_ei_candidates: 24 41 | multivariate: false 42 | warn_independent_sampling: true 43 | 44 | # define range of hyperparameters 45 | search_space: 46 | datamodule.batch_size: 47 | type: categorical 48 | choices: [2] 49 | datamodule.config.num_timesteps: 50 | type: categorical 51 | choices: [1, 3, 6, 9] 52 | datamodule.config.skip_timesteps: 53 | type: categorical 54 | choices: [1, 2, 3] 55 | model.grid_lambda: 56 | type: float 57 | low: 0.1 58 | high: 200.0 59 | model.num_samples: 60 | type: categorical 61 | choices: [1, 3, 5, 7] 62 | model.conv_type: 63 | type: categorical 64 | choices: ["standard", "coord"] 65 | model.disc_lr: 66 | type: float 67 | low: 0.00002 68 | high: 0.02 69 | 
model.gen_lr: 70 | type: float 71 | low: 0.000005 72 | high: 0.005 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .idea/ 7 | satflow/logs/ 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /satflow/configs/hparams_search/perceiver_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple 5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30 6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb 7 | 8 | defaults: 9 | - override /hydra/sweeper: optuna 10 | 11 | # choose metric which will be optimized by Optuna 12 | optimized_metric: "val/loss" 13 | 14 | hydra: 15 | # here we define Optuna hyperparameter search 16 | # it optimizes for value returned from function with @hydra.main decorator 17 | # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper 18 | sweeper: 19 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 20 | storage: "sqlite:///perceiver.db" 21 | study_name: perceiver 22 | n_jobs: 1 23 | 24 | # 'minimize' or 'maximize' the objective 25 | direction: minimize 26 | 27 | # number of experiments that will be executed 28 | n_trials: 50 29 | 30 | # choose Optuna hyperparameter sampler 31 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 32 | sampler: 33 | _target_: optuna.samplers.TPESampler 34 | seed: null 35 | consider_prior: true 36 | prior_weight: 1.0 37 | consider_magic_clip: true 38 | consider_endpoints: false 39 | n_startup_trials: 10 40 | n_ei_candidates: 24 41 | multivariate: false 42 | warn_independent_sampling: true 43 | 44 | # define range of hyperparameters 45 | search_space: 46 | datamodule.batch_size: 47 | type: categorical 48 | choices: [2] 49 | datamodule.config.num_timesteps: 50 | type: categorical 51 | choices: [1, 3, 6] 52 | datamodule.config.skip_timesteps: 53 | type: categorical 54 | choices: [1, 2, 3] 55 | model.lr: 56 | type: float 57 | low: 0.0001 58 | high: 0.02 59 | model.depth: 60 | type: categorical 61 | choices: [8, 6, 4, 2] 62 | model.cross_heads: 63 | type: categorical 64 | choices: [1, 2] 65 | model.latent_heads: 66 | type: categorical 67 | choices: [2, 4, 8] 68 | model.cross_dim_heads: 69 | type: categorical 70 | choices: [1, 2, 4, 8] 71 | model.self_per_cross_attention: 72 | type: categorical 73 | choices: [1, 2, 4, 8] 74 | model.num_latents: 75 | type: categorical 76 | choices: [64, 128, 256, 512] 77 | model.latent_dim: 78 | type: categorical 79 | choices: [64, 128, 256, 512] 80 | model.max_frequency: 81 | type: float 82 | low: 2.0 83 | high: 32.0 84 | model.preprocessor_type: 85 | type: categorical 86 | choices: ["metnet"] 87 | -------------------------------------------------------------------------------- /satflow/models/perceiverio.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, Iterable, Optional, Tuple, Union 3 | 4 | import einops 5 | import pandas as pd 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_optimizer as optim 9 | from einops import rearrange, repeat 10 | from nowcasting_dataloader.batch import BatchML 11 | from nowcasting_dataset.consts import ( 12 | DEFAULT_N_GSP_PER_EXAMPLE, 13 | DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, 14 | GSP_DATETIME_INDEX, 15 | GSP_ID, 16 | GSP_YIELD, 17 | NWP_DATA, 18 | PV_SYSTEM_ID, 19 | PV_YIELD, 20 | SATELLITE_DATA, 21 | TOPOGRAPHIC_DATA, 22 | ) 23 | from nowcasting_utils.metrics.validation import ( 24 | make_validation_results, 25 | save_validation_results_to_logger, 26 | ) 27 | from nowcasting_utils.models.base import BaseModel, register_model 28 | from nowcasting_utils.models.loss import get_loss 29 | from nowcasting_utils.visualization.line import plot_batch_results 30 | from nowcasting_utils.visualization.visualization import plot_example 31 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 32 | from transformers import ( 33 | PerceiverConfig, 34 | PerceiverForImageClassificationLearned, 35 | PerceiverForMultimodalAutoencoding, 36 | PerceiverForOpticalFlow, 37 | PerceiverModel, 38 | ) 39 | 40 | logger = logging.getLogger("satflow.model") 41 | logger.setLevel(logging.WARN) 42 | 43 | HRV_KEY = "hrv_" + SATELLITE_DATA 44 | 45 | 46 | class HuggingFacePerceiver(BaseModel): 47 | def __init__(self, input_size: int = 224, lr: float = 0.001): 48 | super().__init__() 49 | self.lr = lr # learning rate used by configure_optimizers 50 | self.model = PerceiverForOpticalFlow.from_pretrained( 51 | "deepmind/optical-flow-perceiver", 52 | ignore_mismatched_sizes=True, 53 | train_size=[input_size, input_size], 54 | ) 55 | 56 | self.channel_change = torch.nn.Conv2d(in_channels=2, out_channels=11, kernel_size=1) # 1x1 conv (kernel size assumed) mapping the 2 flow channels to the 11 satellite channels 57 | self.predict_satellite = False 58 | self.predict_hrv_satellite = True 59 | self.hrv_channel_change = torch.nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1) # 1x1 conv (kernel size assumed) mapping the 2 flow channels to the single HRV channel 60 | 61 | def forward(self, x, **kwargs) -> Any: 62 | return self.model(inputs=x) 63 | 64 | def _train_or_validate_step(self, batch, batch_idx, is_training: bool = True): 65 | x, y = batch 66 | # Now run predictions for all the queries 67 | # Predicting all future ones at once 68 | losses = [] 69 | if self.predict_satellite: 70 | sat_y_hat = self.model(inputs=x) 71 | sat_y_hat = self.channel_change(sat_y_hat) 72 | # Satellite losses 73 | sat_loss, sat_frame_loss = self.mse(sat_y_hat, y[SATELLITE_DATA]) 74 | losses.append(sat_loss) 75 | if self.predict_hrv_satellite: 76 | hrv_sat_y_hat = self.model(inputs=x) 77 | hrv_sat_y_hat = self.hrv_channel_change(hrv_sat_y_hat) 78 | # HRV Satellite losses 79 | hrv_sat_loss, sat_frame_loss = self.mse(hrv_sat_y_hat, y[HRV_KEY]) 80 | losses.append(hrv_sat_loss) 81 | loss = losses[0] 82 | for sat_loss in losses[1:]: 83 | loss += sat_loss 84 | self.log_dict({f"{'train' if is_training else 'val'}/loss": loss}) 85 | if is_training: 86 | return loss 87 | else: 88 | # Return the model outputs as well 89 | return loss 90 | 91 | def configure_optimizers(self): 92 | return torch.optim.AdamW(self.parameters(), lr=self.lr) 93 | -------------------------------------------------------------------------------- /satflow/models/layers/Normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | 5 | 6 | def l2normalize(v, eps=1e-12): 7 | return v / (v.norm() + eps) 8 | 9 | 10 | class
SpectralNorm(nn.Module): 11 | def __init__(self, module, name="weight", power_iterations=1): 12 | super(SpectralNorm, self).__init__() 13 | self.module = module 14 | self.name = name 15 | self.power_iterations = power_iterations 16 | if not self._made_params(): 17 | self._make_params() 18 | 19 | def _update_u_v(self): 20 | u = getattr(self.module, self.name + "_u") 21 | v = getattr(self.module, self.name + "_v") 22 | w = getattr(self.module, self.name + "_bar") 23 | 24 | height = w.data.shape[0] 25 | for _ in range(self.power_iterations): 26 | v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) 27 | u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) 28 | 29 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 30 | sigma = u.dot(w.view(height, -1).mv(v)) 31 | setattr(self.module, self.name, w / sigma.expand_as(w)) 32 | 33 | def _made_params(self): 34 | try: 35 | u = getattr(self.module, self.name + "_u") 36 | v = getattr(self.module, self.name + "_v") 37 | w = getattr(self.module, self.name + "_bar") 38 | return True 39 | except AttributeError: 40 | return False 41 | 42 | def _make_params(self): 43 | w = getattr(self.module, self.name) 44 | 45 | height = w.data.shape[0] 46 | width = w.view(height, -1).data.shape[1] 47 | 48 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 49 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 50 | u.data = l2normalize(u.data) 51 | v.data = l2normalize(v.data) 52 | w_bar = Parameter(w.data) 53 | 54 | del self.module._parameters[self.name] 55 | 56 | self.module.register_parameter(self.name + "_u", u) 57 | self.module.register_parameter(self.name + "_v", v) 58 | self.module.register_parameter(self.name + "_bar", w_bar) 59 | 60 | def forward(self, *args): 61 | self._update_u_v() 62 | return self.module.forward(*args) 63 | 64 | 65 | class ConditionalNorm(nn.Module): 66 | def __init__(self, in_channel, n_condition=96): 67 | super().__init__() 68 | 69 | self.in_channel = in_channel 70 | self.bn = nn.BatchNorm2d(self.in_channel, affine=False) 71 | 72 | self.embed = nn.Linear(n_condition, self.in_channel * 2) 73 | self.embed.weight.data[:, : self.in_channel].normal_(1, 0.02) 74 | self.embed.weight.data[:, self.in_channel :].zero_() 75 | 76 | def forward(self, x, class_id): 77 | out = self.bn(x) 78 | embed = self.embed(class_id) 79 | gamma, beta = embed.chunk(2, 1) 80 | # gamma = gamma.unsqueeze(2).unsqueeze(3) 81 | # beta = beta.unsqueeze(2).unsqueeze(3) 82 | gamma = gamma.view(-1, self.in_channel, 1, 1) 83 | beta = beta.view(-1, self.in_channel, 1, 1) 84 | out = gamma * out + beta 85 | 86 | return out 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | cn = ConditionalNorm(3, 2) 92 | x = torch.rand([4, 3, 64, 64]) 93 | class_id = torch.rand([4, 2]) 94 | y = cn(x, class_id) 95 | print(cn) 96 | print(x.size()) 97 | print(y.size()) 98 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import yaml 4 | from nowcasting_dataset.consts import NWP_DATA, SATELLITE_DATA, TOPOGRAPHIC_DATA 5 | from nowcasting_utils.models.base import create_model, list_models 6 | 7 | from satflow.models import LitMetNet, Perceiver 8 | 9 | 10 | def load_config(config_file): 11 | with open(config_file, "r") as cfg: 12 | return yaml.load(cfg, Loader=yaml.FullLoader) 13 | 14 | 15 | def test_perceiver_creation(): 16 | config = 
load_config("satflow/configs/model/perceiver.yaml") 17 | config.pop("_target_") # This is only for Hydra 18 | model = Perceiver(**config) 19 | x = { 20 | SATELLITE_DATA: torch.randn( 21 | (2, 6, config["input_size"], config["input_size"], config["sat_channels"]) 22 | ), 23 | TOPOGRAPHIC_DATA: torch.randn((2, config["input_size"], config["input_size"], 1)), 24 | NWP_DATA: torch.randn( 25 | (2, 6, config["input_size"], config["input_size"], config["nwp_channels"]) 26 | ), 27 | "forecast_time": torch.randn(2, config["forecast_steps"], 1), 28 | } 29 | query = torch.randn((2, config["input_size"] * config["sat_channels"], config["queries_dim"])) 30 | model.eval() 31 | with torch.no_grad(): 32 | out = model(x, query=query) 33 | # MetNet creates predictions for the center 1/4th 34 | assert out.size() == ( 35 | 2, 36 | config["forecast_steps"] * config["input_size"], 37 | config["sat_channels"] * config["input_size"], 38 | ) 39 | assert not torch.isnan(out).any(), "Output included NaNs" 40 | 41 | 42 | def test_metnet_creation(): 43 | config = load_config("satflow/configs/model/metnet.yaml") 44 | config.pop("_target_") # This is only for Hydra 45 | model = LitMetNet(**config) 46 | # MetNet expects original HxW to be 4x the input size 47 | x = torch.randn( 48 | (2, 12, config["input_channels"], config["input_size"] * 4, config["input_size"] * 4) 49 | ) 50 | model.eval() 51 | with torch.no_grad(): 52 | out = model(x) 53 | # MetNet creates predictions for the center 1/4th 54 | assert out.size() == ( 55 | 2, 56 | config["forecast_steps"], 57 | config["output_channels"], 58 | config["input_size"] // 4, 59 | config["input_size"] // 4, 60 | ) 61 | assert not torch.isnan(out).any(), "Output included NaNs" 62 | 63 | 64 | @pytest.mark.parametrize("model_name", list_models()) 65 | def test_create_model(model_name): 66 | """ 67 | Test that create model works for all registered models 68 | Args: 69 | model_name: 70 | 71 | Returns: 72 | 73 | """ 74 | # TODO Load options from all configs and make sure they work 75 | model = create_model(model_name) 76 | pass 77 | 78 | 79 | @pytest.mark.skip( 80 | "Perceiver has changed in SatFlow, doesn't have the same options as the one on HF" 81 | ) 82 | def test_load_hf(): 83 | """ 84 | Current only HF model is PerceiverIO, change in future to do all ones 85 | Returns: 86 | 87 | """ 88 | model = create_model("hf_hub:openclimatefix/perceiver-io") 89 | pass 90 | 91 | 92 | @pytest.mark.skip( 93 | "Perceiver has changed in SatFlow, doesn't have the same options as the one on HF" 94 | ) 95 | def test_load_hf_pretrained(): 96 | """ 97 | Current only HF model is PerceiverIO, change in future to do all ones 98 | Returns: 99 | 100 | """ 101 | model = create_model("hf_hub:openclimatefix/perceiver-io", pretrained=True) 102 | pass 103 | -------------------------------------------------------------------------------- /satflow/experiments/train.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import ( 6 | Callback, 7 | LightningDataModule, 8 | LightningModule, 9 | Trainer, 10 | seed_everything, 11 | ) 12 | from pytorch_lightning.callbacks import LearningRateMonitor 13 | from pytorch_lightning.loggers import LightningLoggerBase 14 | 15 | from satflow.core import utils 16 | from satflow.core.callbacks import NeptuneModelLogger 17 | 18 | log = utils.get_logger(__name__) 19 | 20 | 21 | def train(config: DictConfig) -> Optional[float]: 22 
| """Contains training pipeline. 23 | Instantiates all PyTorch Lightning objects from config. 24 | 25 | Args: 26 | config (DictConfig): Configuration composed by Hydra. 27 | 28 | Returns: 29 | Optional[float]: Metric score for hyperparameter optimization. 30 | """ 31 | 32 | # Set seed for random number generators in pytorch, numpy and python.random 33 | if "seed" in config: 34 | seed_everything(config.seed, workers=True) 35 | 36 | # If required 37 | # Init Dataloaders 38 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 39 | datamodule: LightningDataModule = hydra.utils.instantiate( 40 | config.datamodule, _convert_="partial" 41 | ) 42 | 43 | # Init Lightning model 44 | log.info(f"Instantiating model <{config.model._target_}>") 45 | model: LightningModule = hydra.utils.instantiate(config.model) 46 | 47 | # Init Lightning callbacks 48 | lr_monitor = LearningRateMonitor(logging_interval="step") 49 | callbacks: List[Callback] = [lr_monitor, NeptuneModelLogger()] 50 | if "callbacks" in config: 51 | for _, cb_conf in config["callbacks"].items(): 52 | if "_target_" in cb_conf: 53 | log.info(f"Instantiating callback <{cb_conf._target_}>") 54 | callbacks.append(hydra.utils.instantiate(cb_conf)) 55 | 56 | # Init Lightning loggers 57 | logger: List[LightningLoggerBase] = [] 58 | if "logger" in config: 59 | for _, lg_conf in config["logger"].items(): 60 | if "_target_" in lg_conf: 61 | log.info(f"Instantiating logger <{lg_conf._target_}>") 62 | logger.append(hydra.utils.instantiate(lg_conf)) 63 | 64 | # Init Lightning trainer 65 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 66 | trainer: Trainer = hydra.utils.instantiate( 67 | config.trainer, 68 | callbacks=callbacks, 69 | logger=logger, 70 | ) 71 | 72 | # Send some parameters from config to all lightning loggers 73 | log.info("Logging hyperparameters!") 74 | utils.log_hyperparameters( 75 | config=config, 76 | model=model, 77 | trainer=trainer, 78 | ) 79 | 80 | # Train the model 81 | if config.trainer.auto_lr_find or config.trainer.auto_scale_batch_size: 82 | log.info("Starting tuning!") 83 | trainer.tune(model=model, datamodule=datamodule) 84 | log.info("Starting training!") 85 | trainer.fit(model=model, datamodule=datamodule) 86 | 87 | # Evaluate model on test set after training 88 | if not config.trainer.get("fast_dev_run", False): 89 | log.info("Starting testing!") 90 | trainer.test(model=model, datamodule=datamodule) 91 | 92 | # Print path to best checkpoint 93 | log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") 94 | 95 | # Return metric score for hyperparameter optimization 96 | optimized_metric = config.get("optimized_metric") 97 | if optimized_metric: 98 | return trainer.callback_metrics[optimized_metric] 99 | -------------------------------------------------------------------------------- /satflow/baseline/optical_flow.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import webdataset as wds 5 | import yaml 6 | 7 | from satflow.data.datasets import SatFlowDataset 8 | 9 | 10 | def load_config(config_file): 11 | with open(config_file, "r") as cfg: 12 | return yaml.load(cfg, Loader=yaml.FullLoader)["config"] 13 | 14 | 15 | config = load_config("/satflow/configs/datamodule/optical_flow.yaml") 16 | dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar") 17 | 18 | dataset = SatFlowDataset([dset], config=config) 19 | 20 | import 
matplotlib.pyplot as plt 21 | import torch 22 | 23 | 24 | def warp_flow(img, flow): 25 | h, w = flow.shape[:2] 26 | flow = -flow 27 | flow[:, :, 0] += np.arange(w) 28 | flow[:, :, 1] += np.arange(h)[:, np.newaxis] 29 | res = cv2.remap(img, flow, None, cv2.INTER_LINEAR) 30 | return res 31 | 32 | 33 | debug = False 34 | total_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep 35 | channel_total_losses = np.array([total_losses for _ in range(12)]) 36 | count = 0 37 | baseline_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep 38 | channel_baseline_losses = np.array([baseline_losses for _ in range(12)]) 39 | 40 | for data in dataset: 41 | tmp_loss = 0 42 | tmp_base = 0 43 | count += 1 44 | past_frames, next_frames = data 45 | prev_frame = past_frames[1] 46 | curr_frame = past_frames[0] 47 | # Do it for each of the 12 channels 48 | for ch in range(12): 49 | # prev_frame = np.moveaxis(prev_frame, [0], [2]) 50 | # curr_frame = np.moveaxis(curr_frame, [0], [2]) 51 | flow = cv2.calcOpticalFlowFarneback( 52 | past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0 53 | ) 54 | warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow) 55 | warped_frame = np.expand_dims(warped_frame, axis=-1) 56 | loss = F.mse_loss( 57 | torch.from_numpy(warped_frame), 58 | torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)), 59 | ) 60 | channel_total_losses[ch][0] += loss.item() 61 | loss = F.mse_loss( 62 | torch.from_numpy(curr_frame[ch].astype(np.float32)), 63 | torch.from_numpy(next_frames[0][ch]), 64 | ) 65 | channel_baseline_losses[ch][0] += loss.item() 66 | 67 | for i in range(1, 48): 68 | warped_frame = warp_flow(warped_frame.astype(np.float32), flow) 69 | warped_frame = np.expand_dims(warped_frame, axis=-1) 70 | loss = F.mse_loss( 71 | torch.from_numpy(warped_frame), 72 | torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)), 73 | ) 74 | channel_total_losses[ch][i] += loss.item() 75 | tmp_loss += loss.item() 76 | loss = F.mse_loss( 77 | torch.from_numpy(curr_frame[ch].astype(np.float32)), 78 | torch.from_numpy(next_frames[i][ch]), 79 | ) 80 | channel_baseline_losses[ch][i] += loss.item() 81 | print( 82 | f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}" 83 | ) 84 | if count % 100 == 0: 85 | np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count) 86 | np.save( 87 | "baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count 88 | ) 89 | np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count) 90 | np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count) 91 | -------------------------------------------------------------------------------- /satflow/models/layers/GResBlock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from satflow.models.layers.Normalization import ConditionalNorm, SpectralNorm 6 | 7 | 8 | class GResBlock(nn.Module): 9 | def __init__( 10 | self, 11 | in_channel, 12 | out_channel, 13 | kernel_size=None, 14 | padding=1, 15 | stride=1, 16 | n_class=96, 17 | bn=True, 18 | activation=F.relu, 19 | upsample_factor=2, 20 | downsample_factor=1, 21 | ): 22 | super().__init__() 23 | 24 | self.upsample_factor = upsample_factor if downsample_factor is 1 else 1 25 | self.downsample_factor = 
downsample_factor 26 | self.activation = activation 27 | self.bn = bn if downsample_factor is 1 else False 28 | 29 | if kernel_size is None: 30 | kernel_size = [3, 3] 31 | 32 | self.conv0 = SpectralNorm( 33 | nn.Conv2d( 34 | in_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True 35 | ) 36 | ) 37 | self.conv1 = SpectralNorm( 38 | nn.Conv2d( 39 | out_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True 40 | ) 41 | ) 42 | 43 | self.skip_proj = True 44 | self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0)) 45 | 46 | # if in_channel != out_channel or upsample_factor or downsample_factor: 47 | # self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0)) 48 | # self.skip_proj = True 49 | 50 | if bn: 51 | self.CBNorm1 = ConditionalNorm(in_channel, n_class) # TODO 2 x noise.size[1] 52 | self.CBNorm2 = ConditionalNorm(out_channel, n_class) 53 | 54 | def forward(self, x, condition=None): 55 | 56 | # The time dimension is combined with the batch dimension here, so each frame proceeds 57 | # through the blocks independently 58 | BT, C, W, H = x.size() 59 | out = x 60 | 61 | if self.bn: 62 | out = self.CBNorm1(out, condition) 63 | 64 | out = self.activation(out) 65 | 66 | if self.upsample_factor != 1: 67 | out = F.interpolate(out, scale_factor=self.upsample_factor) 68 | 69 | out = self.conv0(out) 70 | 71 | if self.bn: 72 | out = out.view(BT, -1, W * self.upsample_factor, H * self.upsample_factor) 73 | out = self.CBNorm2(out, condition) 74 | 75 | out = self.activation(out) 76 | out = self.conv1(out) 77 | 78 | if self.downsample_factor != 1: 79 | out = F.avg_pool2d(out, self.downsample_factor) 80 | 81 | if self.skip_proj: 82 | skip = x 83 | if self.upsample_factor != 1: 84 | skip = F.interpolate(skip, scale_factor=self.upsample_factor) 85 | skip = self.conv_sc(skip) 86 | if self.downsample_factor != 1: 87 | skip = F.avg_pool2d(skip, self.downsample_factor) 88 | else: 89 | skip = x 90 | 91 | y = out + skip 92 | y = y.view( 93 | BT, 94 | -1, 95 | W * self.upsample_factor // self.downsample_factor, 96 | H * self.upsample_factor // self.downsample_factor, 97 | ) 98 | 99 | return y 100 | 101 | 102 | if __name__ == "__main__": 103 | 104 | n_class = 96 105 | batch_size = 4 106 | n_frames = 20 107 | 108 | gResBlock = GResBlock(3, 100, [3, 3]) 109 | x = torch.rand([batch_size * n_frames, 3, 64, 64]) 110 | condition = torch.rand([batch_size, n_class]) 111 | condition = condition.repeat(n_frames, 1) 112 | y = gResBlock(x, condition) 113 | print(gResBlock) 114 | print(x.size()) 115 | print(y.size()) 116 | 117 | # with SummaryWriter(comment='gResBlock') as w: 118 | # w.add_graph(gResBlock, [x, condition, ]) 119 | -------------------------------------------------------------------------------- /satflow/data/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from nowcasting_dataset.config.model import Configuration 5 | from nowcasting_dataset.consts import ( 6 | DATETIME_FEATURE_NAMES, 7 | NWP_DATA, 8 | NWP_X_COORDS, 9 | NWP_Y_COORDS, 10 | SATELLITE_DATA, 11 | SATELLITE_DATETIME_INDEX, 12 | SATELLITE_X_COORDS, 13 | SATELLITE_Y_COORDS, 14 | TOPOGRAPHIC_DATA, 15 | ) 16 | from nowcasting_dataset.dataset.datasets import NetCDFDataset 17 | 18 | 19 | class SatFlowDataset(NetCDFDataset): 20 | """Loads data saved by the `prepare_ml_training_data.py` script. 
21 | Adapted from predict_pv_yield 22 | """ 23 | 24 | def __init__( 25 | self, 26 | n_batches: int, 27 | src_path: str, 28 | tmp_path: str, 29 | configuration: Configuration, 30 | cloud: str = "gcp", 31 | required_keys: Union[Tuple[str], List[str]] = [ 32 | NWP_DATA, 33 | NWP_X_COORDS, 34 | NWP_Y_COORDS, 35 | SATELLITE_DATA, 36 | SATELLITE_X_COORDS, 37 | SATELLITE_Y_COORDS, 38 | SATELLITE_DATETIME_INDEX, 39 | TOPOGRAPHIC_DATA, 40 | ] 41 | + list(DATETIME_FEATURE_NAMES), 42 | history_minutes: int = 30, 43 | forecast_minutes: int = 60, 44 | combine_inputs: bool = False, 45 | ): 46 | """ 47 | Args: 48 | n_batches: Number of batches available on disk. 49 | src_path: The full path (including 'gs://') to the data on 50 | Google Cloud storage. 51 | tmp_path: The full path to the local temporary directory 52 | (on a local filesystem). 53 | batch_size: Batch size, if requested, will subset data along batch dimension 54 | """ 55 | super().__init__( 56 | n_batches, 57 | src_path, 58 | tmp_path, 59 | configuration, 60 | cloud, 61 | required_keys, 62 | history_minutes, 63 | forecast_minutes, 64 | ) 65 | # SatFlow specific changes, i.e. which timestep to split on 66 | self.required_keys = list(required_keys) 67 | self.combine_inputs = combine_inputs 68 | self.current_timestep_index = (history_minutes // 5) + 1 69 | 70 | def __getitem__(self, batch_idx: int): 71 | batch = super().__getitem__(batch_idx) 72 | 73 | # Need to partition out past and future sat images here, along with the rest of the data 74 | past_satellite_data = batch[SATELLITE_DATA][:, : self.current_timestep_index] 75 | future_sat_data = batch[SATELLITE_DATA][:, self.current_timestep_index :] 76 | x = { 77 | SATELLITE_DATA: past_satellite_data, 78 | SATELLITE_X_COORDS: batch.get(SATELLITE_X_COORDS, None), 79 | SATELLITE_Y_COORDS: batch.get(SATELLITE_Y_COORDS, None), 80 | SATELLITE_DATETIME_INDEX: batch[SATELLITE_DATETIME_INDEX][ 81 | :, : self.current_timestep_index 82 | ], 83 | } 84 | y = { 85 | SATELLITE_DATA: future_sat_data, 86 | SATELLITE_DATETIME_INDEX: batch[SATELLITE_DATETIME_INDEX][ 87 | :, self.current_timestep_index : 88 | ], 89 | } 90 | 91 | for k in list(DATETIME_FEATURE_NAMES): 92 | if k in self.required_keys: 93 | x[k] = batch[k][:, : self.current_timestep_index] 94 | 95 | if NWP_DATA in self.required_keys: 96 | past_nwp_data = batch[NWP_DATA][:, :, : self.current_timestep_index] 97 | x[NWP_DATA] = past_nwp_data 98 | x[NWP_X_COORDS] = batch.get(NWP_X_COORDS, None) 99 | x[NWP_Y_COORDS] = batch.get(NWP_Y_COORDS, None) 100 | 101 | if TOPOGRAPHIC_DATA in self.required_keys: 102 | # Need to expand dims to get a single channel one 103 | # Results in topographic maps with [Batch, Channel, H, W] 104 | x[TOPOGRAPHIC_DATA] = np.expand_dims(batch[TOPOGRAPHIC_DATA], axis=1) 105 | 106 | return x, y 107 | -------------------------------------------------------------------------------- /satflow/models/unet.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchvision 7 | from nowcasting_utils.models.base import register_model 8 | from nowcasting_utils.models.loss import get_loss 9 | from pl_bolts.models.vision import UNet 10 | 11 | 12 | @register_model 13 | class Unet(pl.LightningModule): 14 | def __init__( 15 | self, 16 | forecast_steps: int, 17 | input_channels: int = 3, 18 | num_layers: int = 5, 19 | hidden_dim: int = 64, 20 | bilinear: bool = False, 21 | lr: float = 0.001, 22 | 
visualize: bool = False, 23 | loss: Union[str, torch.nn.Module] = "mse", 24 | pretrained: bool = False, 25 | ): 26 | super(Unet, self).__init__() 27 | self.lr = lr 28 | self.input_channels = input_channels 29 | self.forecast_steps = forecast_steps 30 | self.criterion = get_loss(loss=loss) 31 | self.visualize = visualize 32 | self.model = UNet(forecast_steps, input_channels, num_layers, hidden_dim, bilinear) 33 | self.save_hyperparameters() 34 | 35 | @classmethod 36 | def from_config(cls, config): 37 | return Unet( 38 | forecast_steps=config.get("forecast_steps", 12), 39 | input_channels=config.get("in_channels", 12), 40 | hidden_dim=config.get("features", 64), 41 | num_layers=config.get("num_layers", 5), 42 | bilinear=config.get("bilinear", False), 43 | lr=config.get("lr", 0.001), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.model.forward(x) 48 | 49 | def configure_optimizers(self): 50 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 51 | # optimizer = torch.optim.adam() 52 | return torch.optim.Adam(self.parameters(), lr=self.lr) 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, y = batch 56 | x = x.float() 57 | y_hat = self(x) 58 | 59 | if self.visualize: 60 | if np.random.random() < 0.01: 61 | self.visualize_step(x, y, y_hat, batch_idx) 62 | # Generally only care about the center x crop, so the model can take into account the clouds in the area without 63 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 64 | loss = self.criterion(y_hat, y) 65 | self.log("train/loss", loss, on_step=True) 66 | frame_loss_dict = {} 67 | for f in range(self.forecast_steps): 68 | frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item() 69 | frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss 70 | self.log_dict(frame_loss_dict) 71 | return loss 72 | 73 | def validation_step(self, batch, batch_idx): 74 | x, y = batch 75 | x = x.float() 76 | y_hat = self(x) 77 | val_loss = self.criterion(y_hat, y) 78 | self.log("val/loss", val_loss) 79 | # Save out loss per frame as well 80 | frame_loss_dict = {} 81 | for f in range(self.forecast_steps): 82 | frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item() 83 | frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss 84 | self.log_dict(frame_loss_dict) 85 | return val_loss 86 | 87 | def test_step(self, batch, batch_idx): 88 | x, y = batch 89 | x = x.float() 90 | y_hat = self(x) 91 | loss = self.criterion(y_hat, y) 92 | return loss 93 | 94 | def visualize_step(self, x, y, y_hat, batch_idx, step="train"): 95 | tensorboard = self.logger.experiment[0] 96 | # Add all the different timesteps for a single prediction, 0.1% of the time 97 | images = x[0].cpu().detach() 98 | images = [torch.unsqueeze(img, dim=0) for img in images] 99 | image_grid = torchvision.utils.make_grid(images, nrow=self.input_channels) 100 | tensorboard.add_image(f"{step}/Input_Image_Stack", image_grid, global_step=batch_idx) 101 | images = y[0].cpu().detach() 102 | images = [torch.unsqueeze(img, dim=0) for img in images] 103 | image_grid = torchvision.utils.make_grid(images, nrow=12) 104 | tensorboard.add_image(f"{step}/Target_Image_Stack", image_grid, global_step=batch_idx) 105 | images = y_hat[0].cpu().detach() 106 | images = [torch.unsqueeze(img, dim=0) for img in images] 107 | image_grid = torchvision.utils.make_grid(images, nrow=12) 108 | tensorboard.add_image(f"{step}/Generated_Image_Stack", image_grid, global_step=batch_idx) 109 | 
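A minimal usage sketch for the Unet module above (an editorial illustration, not a file in this repository): it assumes inputs already flattened to [batch, input_channels, height, width] and targets of shape [batch, forecast_steps, height, width], and it substitutes a random in-memory dataset for the real datamodule; the hyperparameter values and the CPU-only Lightning Trainer flags are assumptions for a quick smoke test, not project defaults.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from satflow.models.unet import Unet

# Assumed toy shapes: 12 stacked input channels, 48 forecast steps, 64x64 crops
x = torch.randn(8, 12, 64, 64)
y = torch.randn(8, 48, 64, 64)
loader = DataLoader(TensorDataset(x, y), batch_size=2)

# forecast_steps becomes the number of output channels, one per future frame
model = Unet(forecast_steps=48, input_channels=12, num_layers=5, hidden_dim=64, lr=1e-3)
trainer = pl.Trainer(max_epochs=1, gpus=0)  # CPU smoke test; real runs come from the Hydra trainer configs
trainer.fit(model, loader)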
-------------------------------------------------------------------------------- /satflow/models/fcn.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | from nowcasting_utils.models.base import register_model 8 | from nowcasting_utils.models.losses.FocalLoss import FocalLoss 9 | from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101 10 | 11 | 12 | @register_model 13 | class FCN(pl.LightningModule): 14 | def __init__( 15 | self, 16 | forecast_steps: int = 48, 17 | input_channels: int = 12, 18 | lr: float = 0.001, 19 | make_vis: bool = False, 20 | loss: Union[str, torch.nn.Module] = "mse", 21 | backbone: str = "resnet50", 22 | pretrained: bool = False, 23 | ): 24 | super(FCN, self).__init__() 25 | self.lr = lr 26 | assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"] 27 | if loss == "mse": 28 | self.criterion = F.mse_loss 29 | elif loss in ["bce", "binary_crossentropy", "crossentropy"]: 30 | self.criterion = F.nll_loss 31 | elif loss in ["focal"]: 32 | self.criterion = FocalLoss() 33 | else: 34 | raise ValueError(f"loss {loss} not recognized") 35 | self.make_vis = make_vis 36 | if backbone in ["r101", "resnet101"]: 37 | self.model = fcn_resnet101(pretrained=pretrained, num_classes=forecast_steps) 38 | else: 39 | self.model = fcn_resnet50(pretrained=pretrained, num_classes=forecast_steps) 40 | 41 | if input_channels != 3: 42 | self.model.backbone.conv1 = torch.nn.Conv2d( 43 | input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 44 | ) 45 | self.save_hyperparameters() 46 | 47 | @classmethod 48 | def from_config(cls, config): 49 | return FCN( 50 | forecast_steps=config.get("forecast_steps", 12), 51 | input_channels=config.get("in_channels", 12), 52 | loss=config.get("loss", "mse"), 53 | backbone=config.get("backbone", "resnet50"), 54 | pretrained=config.get("pretrained", False), 55 | lr=config.get("lr", 0.001), 56 | ) 57 | 58 | def forward(self, x): 59 | return self.model(x)["out"]  # torchvision segmentation models return a dict; take the main output 60 | 61 | def configure_optimizers(self): 62 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 63 | # optimizer = torch.optim.adam() 64 | return torch.optim.Adam(self.parameters(), lr=self.lr) 65 | 66 | def training_step(self, batch, batch_idx): 67 | x, y = batch 68 | y_hat = self(x) 69 | 70 | if self.make_vis: 71 | if np.random.random() < 0.01: 72 | self.visualize(x, y, y_hat, batch_idx) 73 | # Generally only care about the center x crop, so the model can take into account the clouds in the area without 74 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 75 | loss = self.criterion(y_hat, y) 76 | self.log("train/loss", loss, on_step=True) 77 | return loss 78 | 79 | def validation_step(self, batch, batch_idx): 80 | x, y = batch 81 | y_hat = self(x) 82 | val_loss = self.criterion(y_hat, y) 83 | self.log("val/loss", val_loss, on_step=True, on_epoch=True) 84 | return val_loss 85 | 86 | def test_step(self, batch, batch_idx): 87 | x, y = batch 88 | y_hat = self(x) 89 | loss = self.criterion(y_hat, y) 90 | return loss 91 | 92 | def visualize(self, x, y, y_hat, batch_idx): 93 | # the logger you used (in this case tensorboard) 94 | tensorboard = self.logger.experiment 95 | # Add all the different timesteps for a single prediction, 0.1% of the time 96 | in_image = ( 97 | 
x[0].cpu().detach().numpy() 98 | ) # Input image stack, Unet takes everything in channels, so no time dimension 99 | for i, in_slice in enumerate(in_image): 100 | j = 0 101 | if i % self.input_channels == 0: # First one 102 | j += 1 103 | tensorboard.add_image( 104 | f"Input_Image_{j}_Channel_{i}", in_slice, global_step=batch_idx 105 | ) # Each Channel 106 | out_image = y_hat[0].cpu().detach().numpy() 107 | for i, out_slice in enumerate(out_image): 108 | tensorboard.add_image( 109 | f"Output_Image_{i}", out_slice, global_step=batch_idx 110 | ) # Each Channel 111 | out_image = y[0].cpu().detach().numpy() 112 | for i, out_slice in enumerate(out_image): 113 | tensorboard.add_image( 114 | f"Target_Image_{i}", out_slice, global_step=batch_idx 115 | ) # Each Channel 116 | -------------------------------------------------------------------------------- /satflow/models/deeplabv3.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | from nowcasting_utils.models.base import register_model 8 | from nowcasting_utils.models.losses.FocalLoss import FocalLoss 9 | from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_resnet101 10 | 11 | 12 | @register_model 13 | class DeeplabV3(pl.LightningModule): 14 | def __init__( 15 | self, 16 | forecast_steps: int = 48, 17 | input_channels: int = 12, 18 | lr: float = 0.001, 19 | make_vis: bool = False, 20 | loss: Union[str, torch.nn.Module] = "mse", 21 | backbone: str = "resnet50", 22 | pretrained: bool = False, 23 | aux_loss: bool = False, 24 | ): 25 | super(DeeplabV3, self).__init__() 26 | self.lr = lr 27 | assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"] 28 | if loss == "mse": 29 | self.criterion = F.mse_loss 30 | elif loss in ["bce", "binary_crossentropy", "crossentropy"]: 31 | self.criterion = F.nll_loss 32 | elif loss in ["focal"]: 33 | self.criterion = FocalLoss() 34 | else: 35 | raise ValueError(f"loss {loss} not recognized") 36 | self.make_vis = make_vis 37 | if backbone in ["r101", "resnet101"]: 38 | self.model = deeplabv3_resnet101( 39 | pretrained=pretrained, num_classes=forecast_steps, aux_loss=aux_loss 40 | ) 41 | else: 42 | self.model = deeplabv3_resnet50( 43 | pretrained=pretrained, num_classes=forecast_steps, aux_loss=aux_loss 44 | ) 45 | 46 | if input_channels != 3: 47 | self.model.backbone.conv1 = torch.nn.Conv2d( 48 | input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 49 | ) 50 | self.save_hyperparameters() 51 | 52 | @classmethod 53 | def from_config(cls, config): 54 | return DeeplabV3( 55 | forecast_steps=config.get("forecast_steps", 12), 56 | input_channels=config.get("in_channels", 12), 57 | hidden_dim=config.get("features", 64), 58 | num_layers=config.get("num_layers", 5), 59 | bilinear=config.get("bilinear", False), 60 | lr=config.get("lr", 0.001), 61 | ) 62 | 63 | def forward(self, x): 64 | return self.model.forward(x) 65 | 66 | def configure_optimizers(self): 67 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 68 | # optimizer = torch.optim.adam() 69 | return torch.optim.Adam(self.parameters(), lr=self.lr) 70 | 71 | def training_step(self, batch, batch_idx): 72 | x, y = batch 73 | y_hat = self(x) 74 | 75 | if self.make_vis: 76 | if np.random.random() < 0.01: 77 | self.visualize(x, y, y_hat, batch_idx) 78 | # Generally only care about the center x crop, so the 
model can take into account the clouds in the area without 79 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 80 | loss = self.criterion(y_hat, y) 81 | self.log("train/loss", loss, on_step=True) 82 | return loss 83 | 84 | def validation_step(self, batch, batch_idx): 85 | x, y = batch 86 | y_hat = self(x) 87 | val_loss = self.criterion(y_hat, y) 88 | self.log("val/loss", val_loss, on_step=True, on_epoch=True) 89 | return val_loss 90 | 91 | def test_step(self, batch, batch_idx): 92 | x, y = batch 93 | y_hat = self(x, self.forecast_steps) 94 | loss = self.criterion(y_hat, y) 95 | return loss 96 | 97 | def visualize(self, x, y, y_hat, batch_idx): 98 | # the logger you used (in this case tensorboard) 99 | tensorboard = self.logger.experiment 100 | # Add all the different timesteps for a single prediction, 0.1% of the time 101 | in_image = ( 102 | x[0].cpu().detach().numpy() 103 | ) # Input image stack, Unet takes everything in channels, so no time dimension 104 | for i, in_slice in enumerate(in_image): 105 | j = 0 106 | if i % self.input_channels == 0: # First one 107 | j += 1 108 | tensorboard.add_image( 109 | f"Input_Image_{j}_Channel_{i}", in_slice, global_step=batch_idx 110 | ) # Each Channel 111 | out_image = y_hat[0].cpu().detach().numpy() 112 | for i, out_slice in enumerate(out_image): 113 | tensorboard.add_image( 114 | f"Output_Image_{i}", out_slice, global_step=batch_idx 115 | ) # Each Channel 116 | out_image = y[0].cpu().detach().numpy() 117 | for i, out_slice in enumerate(out_image): 118 | tensorboard.add_image( 119 | f"Target_Image_{i}", out_slice, global_step=batch_idx 120 | ) # Each Channel 121 | -------------------------------------------------------------------------------- /satflow/models/layers/SpatioTemporalLSTMCell_memory_decoupling.py: -------------------------------------------------------------------------------- 1 | __author__ = "yunbo" 2 | 3 | """ 4 | 5 | PredNN v2 adapted from https://github.com/thuml/predrnn-pytorch 6 | 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class SpatioTemporalLSTMCell(nn.Module): 14 | def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): 15 | super(SpatioTemporalLSTMCell, self).__init__() 16 | 17 | self.num_hidden = num_hidden 18 | self.padding = filter_size // 2 19 | self._forget_bias = 1.0 20 | if layer_norm: 21 | self.conv_x = nn.Sequential( 22 | nn.Conv2d( 23 | in_channel, 24 | num_hidden * 7, 25 | kernel_size=filter_size, 26 | stride=stride, 27 | padding=self.padding, 28 | bias=False, 29 | ), 30 | nn.LayerNorm([num_hidden * 7, width, width]), 31 | ) 32 | self.conv_h = nn.Sequential( 33 | nn.Conv2d( 34 | num_hidden, 35 | num_hidden * 4, 36 | kernel_size=filter_size, 37 | stride=stride, 38 | padding=self.padding, 39 | bias=False, 40 | ), 41 | nn.LayerNorm([num_hidden * 4, width, width]), 42 | ) 43 | self.conv_m = nn.Sequential( 44 | nn.Conv2d( 45 | num_hidden, 46 | num_hidden * 3, 47 | kernel_size=filter_size, 48 | stride=stride, 49 | padding=self.padding, 50 | bias=False, 51 | ), 52 | nn.LayerNorm([num_hidden * 3, width, width]), 53 | ) 54 | self.conv_o = nn.Sequential( 55 | nn.Conv2d( 56 | num_hidden * 2, 57 | num_hidden, 58 | kernel_size=filter_size, 59 | stride=stride, 60 | padding=self.padding, 61 | bias=False, 62 | ), 63 | nn.LayerNorm([num_hidden, width, width]), 64 | ) 65 | else: 66 | self.conv_x = nn.Sequential( 67 | nn.Conv2d( 68 | in_channel, 69 | num_hidden * 7, 70 | kernel_size=filter_size, 71 | stride=stride, 72 
| padding=self.padding, 73 | bias=False, 74 | ), 75 | ) 76 | self.conv_h = nn.Sequential( 77 | nn.Conv2d( 78 | num_hidden, 79 | num_hidden * 4, 80 | kernel_size=filter_size, 81 | stride=stride, 82 | padding=self.padding, 83 | bias=False, 84 | ), 85 | ) 86 | self.conv_m = nn.Sequential( 87 | nn.Conv2d( 88 | num_hidden, 89 | num_hidden * 3, 90 | kernel_size=filter_size, 91 | stride=stride, 92 | padding=self.padding, 93 | bias=False, 94 | ), 95 | ) 96 | self.conv_o = nn.Sequential( 97 | nn.Conv2d( 98 | num_hidden * 2, 99 | num_hidden, 100 | kernel_size=filter_size, 101 | stride=stride, 102 | padding=self.padding, 103 | bias=False, 104 | ), 105 | ) 106 | self.conv_last = nn.Conv2d( 107 | num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False 108 | ) 109 | 110 | def forward(self, x_t, h_t, c_t, m_t): 111 | x_concat = self.conv_x(x_t) 112 | h_concat = self.conv_h(h_t) 113 | m_concat = self.conv_m(m_t) 114 | i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split( 115 | x_concat, self.num_hidden, dim=1 116 | ) 117 | i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1) 118 | i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1) 119 | 120 | i_t = torch.sigmoid(i_x + i_h) 121 | f_t = torch.sigmoid(f_x + f_h + self._forget_bias) 122 | g_t = torch.tanh(g_x + g_h) 123 | 124 | delta_c = i_t * g_t 125 | c_new = f_t * c_t + delta_c 126 | 127 | i_t_prime = torch.sigmoid(i_x_prime + i_m) 128 | f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias) 129 | g_t_prime = torch.tanh(g_x_prime + g_m) 130 | 131 | delta_m = i_t_prime * g_t_prime 132 | m_new = f_t_prime * m_t + delta_m 133 | 134 | mem = torch.cat((c_new, m_new), 1) 135 | o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem)) 136 | h_new = o_t * torch.tanh(self.conv_last(mem)) 137 | 138 | return h_new, c_new, m_new, delta_c, delta_m 139 | -------------------------------------------------------------------------------- /satflow/models/pl_metnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import einops 4 | import torch 5 | import torch.nn as nn 6 | from metnet import MetNet 7 | from nowcasting_dataset.consts import NWP_DATA, SATELLITE_DATA, TOPOGRAPHIC_DATA 8 | from nowcasting_utils.models.base import BaseModel, register_model 9 | from nowcasting_utils.models.loss import get_loss 10 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 11 | 12 | head_to_module = {"identity": nn.Identity()} 13 | 14 | 15 | @register_model 16 | class LitMetNet(BaseModel): 17 | def __init__( 18 | self, 19 | image_encoder: str = "downsampler", 20 | input_channels: int = 12, 21 | sat_channels: int = 12, 22 | input_size: int = 256, 23 | output_channels: int = 12, 24 | hidden_dim: int = 64, 25 | kernel_size: int = 3, 26 | num_layers: int = 1, 27 | num_att_layers: int = 1, 28 | head: str = "identity", 29 | forecast_steps: int = 48, 30 | temporal_dropout: float = 0.2, 31 | lr: float = 0.001, 32 | pretrained: bool = False, 33 | visualize: bool = False, 34 | loss: str = "mse", 35 | ): 36 | super(BaseModel, self).__init__() 37 | self.forecast_steps = forecast_steps 38 | self.input_channels = input_channels 39 | self.lr = lr 40 | self.pretrained = pretrained 41 | self.visualize = visualize 42 | self.output_channels = output_channels 43 | self.criterion = get_loss( 44 | loss, channel=output_channels, nonnegative_ssim=True, convert_range=True 45 | ) 46 | self.model = MetNet( 47 | image_encoder=image_encoder, 48 | 
input_channels=input_channels, 49 | sat_channels=sat_channels, 50 | input_size=input_size, 51 | output_channels=output_channels, 52 | hidden_dim=hidden_dim, 53 | kernel_size=kernel_size, 54 | num_layers=num_layers, 55 | num_att_layers=num_att_layers, 56 | head=head_to_module[head], 57 | forecast_steps=forecast_steps, 58 | temporal_dropout=temporal_dropout, 59 | ) 60 | # TODO: Would be nice to have this automatically applied to all classes 61 | # that inherit from BaseModel 62 | self.save_hyperparameters() 63 | 64 | def forward(self, imgs, **kwargs) -> Any: 65 | return self.model(imgs) 66 | 67 | def configure_optimizers(self): 68 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 69 | # optimizer = torch.optim.adam() 70 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 71 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100) 72 | lr_dict = { 73 | # REQUIRED: The scheduler instance 74 | "scheduler": scheduler, 75 | # The unit of the scheduler's step size, could also be 'step'. 76 | # 'epoch' updates the scheduler on epoch end whereas 'step' 77 | # updates it after a optimizer update. 78 | "interval": "step", 79 | # How many epochs/steps should pass between calls to 80 | # `scheduler.step()`. 1 corresponds to updating the learning 81 | # rate after every epoch/step. 82 | "frequency": 1, 83 | # If using the `LearningRateMonitor` callback to monitor the 84 | # learning rate progress, this keyword can be used to specify 85 | # a custom logged name 86 | "name": None, 87 | } 88 | return {"optimizer": optimizer, "lr_scheduler": lr_dict} 89 | 90 | def _combine_data_sources(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: 91 | """ 92 | Combine different data sources from nowcasting dataset into a single input array for each example 93 | 94 | Mostly useful for adding topographic data to satellite 95 | 96 | Args: 97 | x: Dictionary containing mappings from nowcasting dataset names to the data 98 | 99 | Returns: 100 | Numpy array of [Batch, C, T, H, W] to give to model 101 | """ 102 | timesteps = x[SATELLITE_DATA].shape[2] 103 | topographic_repeat = einops.repeat(x[TOPOGRAPHIC_DATA], "b c h w -> b c t h w", t=timesteps) 104 | to_concat = [x[SATELLITE_DATA], topographic_repeat] 105 | to_concat = to_concat + x.get(NWP_DATA, []) 106 | input_data = torch.cat(to_concat, dim=1).float() # Cat along channel dim 107 | return input_data 108 | 109 | def _train_or_validate_step(self, batch, batch_idx, is_training: bool = True): 110 | x, y = batch 111 | y[SATELLITE_DATA] = y[SATELLITE_DATA].float() 112 | 113 | y_hat = self(self._combine_data_sources(x)) 114 | 115 | if self.visualize: 116 | if batch_idx == 1: 117 | self.visualize_step(x, y, y_hat, batch_idx, step="train" if is_training else "val") 118 | loss = self.criterion(y_hat, y[SATELLITE_DATA]) 119 | self.log(f"{'train' if is_training else 'val'}/loss", loss, prog_bar=True) 120 | frame_loss_dict = {} 121 | for f in range(self.forecast_steps): 122 | frame_loss = self.criterion(y_hat[:, f, :, :], y[SATELLITE_DATA][:, f, :, :]).item() 123 | frame_loss_dict[f"{'train' if is_training else 'val'}/frame_{f}_loss"] = frame_loss 124 | self.log_dict(frame_loss_dict) 125 | -------------------------------------------------------------------------------- /satflow/models/layers/RUnetLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import init 3 | 4 | from satflow.models.utils import get_conv_layer 5 | 6 | 7 | def 
init_weights(net, init_type="normal", gain=0.02): 8 | def init_func(m): 9 | classname = m.__class__.__name__ 10 | if hasattr(m, "weight") and ( 11 | classname.find("Conv") != -1 or classname.find("Linear") != -1 12 | ): 13 | if init_type == "normal": 14 | init.normal_(m.weight.data, 0.0, gain) 15 | elif init_type == "xavier": 16 | init.xavier_normal_(m.weight.data, gain=gain) 17 | elif init_type == "kaiming": 18 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 19 | elif init_type == "orthogonal": 20 | init.orthogonal_(m.weight.data, gain=gain) 21 | else: 22 | raise NotImplementedError( 23 | "initialization method [%s] is not implemented" % init_type 24 | ) 25 | if hasattr(m, "bias") and m.bias is not None: 26 | init.constant_(m.bias.data, 0.0) 27 | elif classname.find("BatchNorm2d") != -1: 28 | init.normal_(m.weight.data, 1.0, gain) 29 | init.constant_(m.bias.data, 0.0) 30 | 31 | print("initialize network with %s" % init_type) 32 | net.apply(init_func) 33 | 34 | 35 | class conv_block(nn.Module): 36 | def __init__(self, ch_in, ch_out, conv_type: str = "standard"): 37 | super(conv_block, self).__init__() 38 | conv2d = get_conv_layer(conv_type) 39 | self.conv = nn.Sequential( 40 | conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 41 | nn.BatchNorm2d(ch_out), 42 | nn.ReLU(inplace=True), 43 | conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 44 | nn.BatchNorm2d(ch_out), 45 | nn.ReLU(inplace=True), 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.conv(x) 50 | return x 51 | 52 | 53 | class up_conv(nn.Module): 54 | def __init__(self, ch_in, ch_out, conv_type: str = "standard"): 55 | super(up_conv, self).__init__() 56 | conv2d = get_conv_layer(conv_type) 57 | self.up = nn.Sequential( 58 | nn.Upsample(scale_factor=2), 59 | conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 60 | nn.BatchNorm2d(ch_out), 61 | nn.ReLU(inplace=True), 62 | ) 63 | 64 | def forward(self, x): 65 | x = self.up(x) 66 | return x 67 | 68 | 69 | class Recurrent_block(nn.Module): 70 | def __init__(self, ch_out, t=2, conv_type: str = "standard"): 71 | super(Recurrent_block, self).__init__() 72 | conv2d = get_conv_layer(conv_type) 73 | self.t = t 74 | self.ch_out = ch_out 75 | self.conv = nn.Sequential( 76 | conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 77 | nn.BatchNorm2d(ch_out), 78 | nn.ReLU(inplace=True), 79 | ) 80 | 81 | def forward(self, x): 82 | for i in range(self.t): 83 | 84 | if i == 0: 85 | x1 = self.conv(x) 86 | 87 | x1 = self.conv(x + x1) 88 | return x1 89 | 90 | 91 | class RRCNN_block(nn.Module): 92 | def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"): 93 | super(RRCNN_block, self).__init__() 94 | conv2d = get_conv_layer(conv_type) 95 | self.RCNN = nn.Sequential( 96 | Recurrent_block(ch_out, t=t, conv_type=conv_type), 97 | Recurrent_block(ch_out, t=t, conv_type=conv_type), 98 | ) 99 | self.Conv_1x1 = conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) 100 | 101 | def forward(self, x): 102 | x = self.Conv_1x1(x) 103 | x1 = self.RCNN(x) 104 | return x + x1 105 | 106 | 107 | class single_conv(nn.Module): 108 | def __init__(self, ch_in, ch_out, conv_type: str = "standard"): 109 | super(single_conv, self).__init__() 110 | conv2d = get_conv_layer(conv_type) 111 | self.conv = nn.Sequential( 112 | conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 113 | nn.BatchNorm2d(ch_out), 114 | nn.ReLU(inplace=True), 115 | ) 116 | 117 | def forward(self, x): 118 | x = self.conv(x) 119 | return x 120 | 
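# Editorial shape-check sketch (not part of the original file), assuming square inputs and torch imported:
# the recurrent blocks above keep the spatial size and only change the channel count, e.g.
#     block = RRCNN_block(ch_in=12, ch_out=64, t=2)
#     out = block(torch.randn(1, 12, 64, 64))  # -> torch.Size([1, 64, 64, 64])
# while up_conv doubles height and width via nn.Upsample(scale_factor=2) before its 3x3 convolution.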
121 | 122 | class Attention_block(nn.Module): 123 | def __init__(self, F_g, F_l, F_int, conv_type: str = "standard"): 124 | super(Attention_block, self).__init__() 125 | conv2d = get_conv_layer(conv_type) 126 | self.W_g = nn.Sequential( 127 | conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 128 | nn.BatchNorm2d(F_int), 129 | ) 130 | 131 | self.W_x = nn.Sequential( 132 | conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 133 | nn.BatchNorm2d(F_int), 134 | ) 135 | 136 | self.psi = nn.Sequential( 137 | conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 138 | nn.BatchNorm2d(1), 139 | nn.Sigmoid(), 140 | ) 141 | 142 | self.relu = nn.ReLU(inplace=True) 143 | 144 | def forward(self, g, x): 145 | g1 = self.W_g(g) 146 | x1 = self.W_x(x) 147 | psi = self.relu(g1 + x1) 148 | psi = self.psi(psi) 149 | 150 | return x * psi 151 | -------------------------------------------------------------------------------- /satflow/models/layers/Generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from satflow.models.layers.ConvGRU import ConvGRU 6 | from satflow.models.layers.GResBlock import GResBlock 7 | from satflow.models.layers.Normalization import SpectralNorm 8 | 9 | # from Module.CrossReplicaBN import ScaledCrossReplicaBatchNorm2d 10 | 11 | 12 | class Generator(nn.Module): 13 | def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hierar_flag=False): 14 | super().__init__() 15 | 16 | self.in_dim = in_dim 17 | self.latent_dim = latent_dim 18 | self.n_class = n_class 19 | self.ch = ch 20 | self.hierar_flag = hierar_flag 21 | self.n_frames = n_frames 22 | 23 | self.embedding = nn.Embedding(n_class, in_dim) 24 | 25 | self.affine_transfrom = nn.Linear(in_dim * 2, latent_dim * latent_dim * 8 * ch) 26 | 27 | self.conv = nn.ModuleList( 28 | [ 29 | ConvGRU( 30 | 8 * ch, 31 | hidden_sizes=[8 * ch, 16 * ch, 8 * ch], 32 | kernel_sizes=[3, 5, 3], 33 | n_layers=3, 34 | ), 35 | # ConvGRU(8 * ch, hidden_sizes=[8 * ch, 8 * ch], kernel_sizes=[3, 3], n_layers=2), 36 | GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1), 37 | GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2), 38 | ConvGRU( 39 | 8 * ch, 40 | hidden_sizes=[8 * ch, 16 * ch, 8 * ch], 41 | kernel_sizes=[3, 5, 3], 42 | n_layers=3, 43 | ), 44 | # ConvGRU(8 * ch, hidden_sizes=[8 * ch, 8 * ch], kernel_sizes=[3, 3], n_layers=2), 45 | GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1), 46 | GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2), 47 | ConvGRU( 48 | 8 * ch, 49 | hidden_sizes=[8 * ch, 16 * ch, 8 * ch], 50 | kernel_sizes=[3, 5, 3], 51 | n_layers=3, 52 | ), 53 | # ConvGRU(8 * ch, hidden_sizes=[8 * ch, 8 * ch], kernel_sizes=[3, 3], n_layers=2), 54 | GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1), 55 | GResBlock(8 * ch, 4 * ch, n_class=in_dim * 2), 56 | ConvGRU( 57 | 4 * ch, 58 | hidden_sizes=[4 * ch, 8 * ch, 4 * ch], 59 | kernel_sizes=[3, 5, 5], 60 | n_layers=3, 61 | ), 62 | # ConvGRU(4 * ch, hidden_sizes=[4 * ch, 4 * ch], kernel_sizes=[3, 5], n_layers=2), 63 | GResBlock(4 * ch, 4 * ch, n_class=in_dim * 2, upsample_factor=1), 64 | GResBlock(4 * ch, 2 * ch, n_class=in_dim * 2), 65 | ] 66 | ) 67 | 68 | self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1)) 69 | 70 | def forward(self, x, class_id): 71 | 72 | if self.hierar_flag is True: 73 | noise_emb = torch.split(x, self.in_dim, dim=1) 74 | else: 75 | 
noise_emb = x 76 | 77 | class_emb = self.embedding(class_id) 78 | 79 | if self.hierar_flag is True: 80 | y = self.affine_transfrom( 81 | torch.cat((noise_emb[0], class_emb), dim=1) 82 | ) # B x (2 x ld x ch) 83 | else: 84 | y = self.affine_transfrom(torch.cat((noise_emb, class_emb), dim=1)) # B x (2 x ld x ch) 85 | 86 | y = y.view(-1, 8 * self.ch, self.latent_dim, self.latent_dim) # B x ch x ld x ld 87 | 88 | for k, conv in enumerate(self.conv): 89 | if isinstance(conv, ConvGRU): 90 | 91 | if k > 0: 92 | _, C, W, H = y.size() 93 | y = y.view(-1, self.n_frames, C, W, H).contiguous() 94 | 95 | frame_list = [] 96 | for i in range(self.n_frames): 97 | if k == 0: 98 | if i == 0: 99 | frame_list.append(conv(y)) # T x [B x ch x ld x ld] 100 | else: 101 | frame_list.append(conv(y, frame_list[i - 1])) 102 | else: 103 | if i == 0: 104 | frame_list.append( 105 | conv(y[:, 0, :, :, :].squeeze(1)) 106 | ) # T x [B x ch x ld x ld] 107 | else: 108 | frame_list.append(conv(y[:, i, :, :, :].squeeze(1), frame_list[i - 1])) 109 | frame_hidden_list = [] 110 | for i in frame_list: 111 | frame_hidden_list.append(i[-1].unsqueeze(0)) 112 | y = torch.cat(frame_hidden_list, dim=0) # T x B x ch x ld x ld 113 | 114 | y = y.permute(1, 0, 2, 3, 4).contiguous() # B x T x ch x ld x ld 115 | # print(y.size()) 116 | B, T, C, W, H = y.size() 117 | y = y.view(-1, C, W, H) 118 | 119 | elif isinstance(conv, GResBlock): 120 | condition = torch.cat([noise_emb, class_emb], dim=1) 121 | condition = condition.repeat(self.n_frames, 1) 122 | y = conv(y, condition) # BT, C, W, H 123 | 124 | y = F.relu(y) 125 | y = self.colorize(y) 126 | y = torch.tanh(y) 127 | 128 | BT, C, W, H = y.size() 129 | y = y.view(-1, self.n_frames, C, W, H) # B, T, C, W, H 130 | 131 | return y 132 | 133 | 134 | if __name__ == "__main__": 135 | 136 | batch_size = 5 137 | in_dim = 120 138 | n_class = 4 139 | n_frames = 4 140 | 141 | x = torch.randn(batch_size, in_dim).cuda() 142 | class_label = torch.randint(low=0, high=3, size=(batch_size,)).cuda() 143 | generator = Generator(in_dim, n_class=n_class, ch=3, n_frames=n_frames).cuda() 144 | y = generator(x, class_label) 145 | 146 | print(x.size()) 147 | print(y.size()) 148 | -------------------------------------------------------------------------------- /satflow/models/gan/common.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | from torch.nn import init 5 | 6 | 7 | def get_norm_layer(norm_type="instance"): 8 | """Return a normalization layer 9 | 10 | Parameters: 11 | norm_type (str) -- the name of the normalization layer: batch | instance | none 12 | 13 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 14 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 15 | """ 16 | if norm_type == "batch": 17 | norm_layer = functools.partial(torch.nn.BatchNorm2d, affine=True, track_running_stats=True) 18 | elif norm_type == "instance": 19 | norm_layer = functools.partial( 20 | torch.nn.InstanceNorm2d, affine=False, track_running_stats=False 21 | ) 22 | elif norm_type == "none": 23 | 24 | def norm_layer(x): 25 | return torch.nn.Identity() 26 | 27 | else: 28 | raise NotImplementedError("normalization layer [%s] is not found" % norm_type) 29 | return norm_layer 30 | 31 | 32 | def init_weights(net, init_type="normal", init_gain=0.02): 33 | """Initialize network weights. 
34 | 35 | Parameters: 36 | net (network) -- network to be initialized 37 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 38 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 39 | 40 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 41 | work better for some applications. Feel free to try yourself. 42 | """ 43 | 44 | def init_func(m): # define the initialization function 45 | classname = m.__class__.__name__ 46 | if hasattr(m, "weight") and ( 47 | classname.find("Conv") != -1 or classname.find("Linear") != -1 48 | ): 49 | if init_type == "normal": 50 | init.normal_(m.weight.data, 0.0, init_gain) 51 | elif init_type == "xavier": 52 | init.xavier_normal_(m.weight.data, gain=init_gain) 53 | elif init_type == "kaiming": 54 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 55 | elif init_type == "orthogonal": 56 | init.orthogonal_(m.weight.data, gain=init_gain) 57 | else: 58 | raise NotImplementedError( 59 | "initialization method [%s] is not implemented" % init_type 60 | ) 61 | if hasattr(m, "bias") and m.bias is not None: 62 | init.constant_(m.bias.data, 0.0) 63 | elif ( 64 | classname.find("BatchNorm2d") != -1 65 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 66 | init.normal_(m.weight.data, 1.0, init_gain) 67 | init.constant_(m.bias.data, 0.0) 68 | 69 | print("initialize network with %s" % init_type) 70 | net.apply(init_func) # apply the initialization function 71 | 72 | 73 | def init_net(net, init_type="normal", init_gain=0.02): 74 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 75 | Parameters: 76 | net (network) -- the network to be initialized 77 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 78 | gain (float) -- scaling factor for normal, xavier and orthogonal. 79 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 80 | 81 | Return an initialized network. 82 | """ 83 | init_weights(net, init_type, init_gain=init_gain) 84 | return net 85 | 86 | 87 | def cal_gradient_penalty( 88 | netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0 89 | ): 90 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 91 | 92 | Arguments: 93 | netD (network) -- discriminator network 94 | real_data (tensor array) -- real images 95 | fake_data (tensor array) -- generated images from the generator 96 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 97 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 98 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 99 | lambda_gp (float) -- weight for this loss 100 | 101 | Returns the gradient penalty loss 102 | """ 103 | if lambda_gp > 0.0: 104 | if type == "real": # either use real images, fake images, or a linear interpolation of two. 
105 | interpolatesv = real_data 106 | elif type == "fake": 107 | interpolatesv = fake_data 108 | elif type == "mixed": 109 | alpha = torch.rand(real_data.shape[0], 1, device=device) 110 | alpha = ( 111 | alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]) 112 | .contiguous() 113 | .view(*real_data.shape) 114 | ) 115 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 116 | else: 117 | raise NotImplementedError("{} not implemented".format(type)) 118 | interpolatesv.requires_grad_(True) 119 | disc_interpolates = netD(interpolatesv) 120 | gradients = torch.autograd.grad( 121 | outputs=disc_interpolates, 122 | inputs=interpolatesv, 123 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 124 | create_graph=True, 125 | retain_graph=True, 126 | only_inputs=True, 127 | ) 128 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 129 | gradient_penalty = ( 130 | ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2 131 | ).mean() * lambda_gp # added eps 132 | return gradient_penalty, gradients 133 | return 0.0, None 134 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: satflow 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_llvm 9 | - affine=2.3.0=py_0 10 | - appdirs=1.4.4=pyh9f0ad1d_0 11 | - asciitree=0.3.3=py_2 12 | - attrs=21.4.0=pyhd8ed1ab_0 13 | - blas=2.113=mkl 14 | - blas-devel=3.9.0=13_linux64_mkl 15 | - blosc=1.21.0=h9c3ff4c_0 16 | - bokeh=2.4.2=py39hf3d152e_0 17 | - boost-cpp=1.74.0=h359cf19_5 18 | - brotlipy=0.7.0=py39h3811e60_1003 19 | - bzip2=1.0.8=h7f98852_4 20 | - c-ares=1.18.1=h7f98852_0 21 | - ca-certificates=2021.10.8=ha878542_0 22 | - cached-property=1.5.2=hd8ed1ab_1 23 | - cached_property=1.5.2=pyha770c72_1 24 | - cairo=1.16.0=ha00ac49_1009 25 | - certifi=2021.10.8=py39hf3d152e_1 26 | - cffi=1.15.0=py39h4bc2ebd_0 27 | - cfitsio=4.0.0=h9a35b8e_0 28 | - cftime=1.5.2=py39hce5d2b2_0 29 | - charset-normalizer=2.0.10=pyhd8ed1ab_0 30 | - click=8.0.3=py39hf3d152e_1 31 | - click-plugins=1.1.1=py_0 32 | - cligj=0.7.2=pyhd8ed1ab_1 33 | - cloudpickle=2.0.0=pyhd8ed1ab_0 34 | - colorama=0.4.4=pyh9f0ad1d_0 35 | - configobj=5.0.6=py_0 36 | - cryptography=36.0.1=py39h95dcef6_0 37 | - cudatoolkit=11.3.1=h2bc3f7f_2 38 | - curl=7.81.0=h2574ce0_0 39 | - cytoolz=0.11.2=py39h3811e60_1 40 | - dask=2022.1.0=pyhd8ed1ab_0 41 | - dask-core=2022.1.0=pyhd8ed1ab_0 42 | - distributed=2022.1.0=py39hf3d152e_0 43 | - docutils=0.18.1=py39hf3d152e_0 44 | - donfig=0.6.0=pyhd8ed1ab_0 45 | - eccodes=2.24.2=h11d1a29_0 46 | - expat=2.4.3=h9c3ff4c_0 47 | - fasteners=0.16=pyhd8ed1ab_0 48 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 49 | - font-ttf-inconsolata=3.000=h77eed37_0 50 | - font-ttf-source-code-pro=2.038=h77eed37_0 51 | - font-ttf-ubuntu=0.83=hab24e00_0 52 | - fontconfig=2.13.1=hba837de_1005 53 | - fonts-conda-ecosystem=1=0 54 | - fonts-conda-forge=1=0 55 | - freeglut=3.2.1=h9c3ff4c_2 56 | - freetype=2.10.4=h0708190_1 57 | - freexl=1.0.6=h7f98852_0 58 | - fsspec=2022.1.0=pyhd8ed1ab_0 59 | - geos=3.10.2=h9c3ff4c_0 60 | - geotiff=1.7.0=h6593c0a_6 61 | - gettext=0.19.8.1=h73d1719_1008 62 | - giflib=5.2.1=h36c2ea0_2 63 | - h5py=3.6.0=nompi_py39h7e08c79_100 64 | - hdf4=4.2.15=h10796ff_3 65 | - hdf5=1.12.1=nompi_h2750804_103 66 | - heapdict=1.0.1=py_0 67 | - icu=69.1=h9c3ff4c_0 68 | - idna=3.3=pyhd8ed1ab_0 69 | - 
importlib-metadata=4.10.1=py39hf3d152e_0 70 | - importlib_metadata=4.10.1=hd8ed1ab_0 71 | - jasper=2.0.33=ha77e612_0 72 | - jbig=2.1=h7f98852_2003 73 | - jinja2=3.0.3=pyhd8ed1ab_0 74 | - jpeg=9d=h36c2ea0_0 75 | - json-c=0.15=h98cffda_0 76 | - kealib=1.4.14=h87e4c3c_3 77 | - krb5=1.19.2=hcc1bbae_3 78 | - lcms2=2.12=hddcbb42_0 79 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 80 | - lerc=3.0=h9c3ff4c_0 81 | - libaec=1.0.6=h9c3ff4c_0 82 | - libblas=3.9.0=13_linux64_mkl 83 | - libcblas=3.9.0=13_linux64_mkl 84 | - libcurl=7.81.0=h2574ce0_0 85 | - libdap4=3.20.6=hd7c4107_2 86 | - libdeflate=1.8=h7f98852_0 87 | - libedit=3.1.20191231=he28a2e2_2 88 | - libev=4.33=h516909a_1 89 | - libffi=3.4.2=h7f98852_5 90 | - libgcc-ng=11.2.0=h1d223b6_11 91 | - libgdal=3.4.1=h7b6f8d3_2 92 | - libgfortran-ng=11.2.0=h69a702a_11 93 | - libgfortran5=11.2.0=h5c6108e_11 94 | - libglib=2.70.2=h174f98d_1 95 | - libglu=9.0.0=he1b5a44_1001 96 | - libiconv=1.16=h516909a_0 97 | - libkml=1.3.0=h238a007_1014 98 | - liblapack=3.9.0=13_linux64_mkl 99 | - liblapacke=3.9.0=13_linux64_mkl 100 | - libnetcdf=4.8.1=nompi_hb3fd0d9_101 101 | - libnghttp2=1.43.0=h812cca2_1 102 | - libnsl=2.0.0=h7f98852_0 103 | - libpng=1.6.37=h21135ba_2 104 | - libpq=14.1=hd57d9b9_1 105 | - librttopo=1.1.0=hf69c175_9 106 | - libspatialite=5.0.1=h0e567f8_14 107 | - libssh2=1.10.0=ha56f1ee_2 108 | - libstdcxx-ng=11.2.0=he4da1e4_11 109 | - libtiff=4.3.0=h6f004c6_2 110 | - libuuid=2.32.1=h7f98852_1000 111 | - libuv=1.43.0=h7f98852_0 112 | - libwebp-base=1.2.2=h7f98852_1 113 | - libxcb=1.13=h7f98852_1004 114 | - libxml2=2.9.12=h885dcf4_1 115 | - libzip=1.8.0=h4de3113_1 116 | - libzlib=1.2.11=h36c2ea0_1013 117 | - llvm-openmp=12.0.1=h4bd325d_1 118 | - locket=0.2.0=py_2 119 | - lz4-c=1.9.3=h9c3ff4c_1 120 | - markupsafe=2.0.1=py39h3811e60_1 121 | - mkl=2022.0.1=h8d4b97c_803 122 | - mkl-devel=2022.0.1=ha770c72_804 123 | - mkl-include=2022.0.1=h8d4b97c_803 124 | - monotonic=1.5=py_0 125 | - msgpack-python=1.0.3=py39h1a9c180_0 126 | - ncurses=6.3=h9c3ff4c_0 127 | - netcdf4=1.5.8=nompi_py39h64b754b_101 128 | - nspr=4.32=h9c3ff4c_1 129 | - nss=3.74=hb5efdd6_0 130 | - numcodecs=0.9.1=py39he80948d_2 131 | - olefile=0.46=pyh9f0ad1d_1 132 | - openjpeg=2.4.0=hb52868f_1 133 | - openssl=1.1.1l=h7f98852_0 134 | - packaging=21.3=pyhd8ed1ab_0 135 | - pandas=1.4.0=py39hde0f152_0 136 | - partd=1.2.0=pyhd8ed1ab_0 137 | - pcre=8.45=h9c3ff4c_0 138 | - pillow=8.4.0=py39ha612740_0 139 | - pip=21.3.1=pyhd8ed1ab_0 140 | - pixman=0.40.0=h36c2ea0_0 141 | - pooch=1.5.2=pyhd8ed1ab_0 142 | - poppler=21.11.0=ha39eefc_0 143 | - poppler-data=0.4.11=hd8ed1ab_0 144 | - postgresql=14.1=h2510834_1 145 | - proj=8.2.1=h277dcde_0 146 | - psutil=5.9.0=py39h3811e60_0 147 | - pthread-stubs=0.4=h36c2ea0_1001 148 | - pycparser=2.21=pyhd8ed1ab_0 149 | - pykdtree=1.3.4=py39hce5d2b2_2 150 | - pyopenssl=21.0.0=pyhd8ed1ab_0 151 | - pyorbital=1.7.1=pyhd8ed1ab_0 152 | - pyparsing=3.0.7=pyhd8ed1ab_0 153 | - pyproj=3.3.0=py39hab5ddba_1 154 | - pyresample=1.22.3=py39hde0f152_0 155 | - pysocks=1.7.1=py39hf3d152e_4 156 | - pyspectral=0.10.6=pyhd8ed1ab_0 157 | - python=3.9.9=h62f1059_0_cpython 158 | - python-dateutil=2.8.2=pyhd8ed1ab_0 159 | - python-geotiepoints=1.3.0=py39hce5d2b2_2 160 | - python_abi=3.9=2_cp39 161 | - pytorch-mutex=1.0=cuda 162 | - pytorch 163 | - pytz=2021.3=pyhd8ed1ab_0 164 | - pyyaml=6.0=py39h3811e60_3 165 | - rasterio=1.2.10=py39h0401cea_4 166 | - readline=8.1=h46c0cb4_0 167 | - requests=2.27.1=pyhd8ed1ab_0 168 | - satpy=0.33.1=pyhd8ed1ab_0 169 | - scipy=1.7.3=py39hee8e79c_0 170 | - 
six=1.16.0=pyh6c4a22f_0 171 | - snuggs=1.4.7=py_0 172 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 173 | - sqlite=3.37.0=h9cd32fc_0 174 | - tbb=2021.5.0=h4bd325d_0 175 | - tblib=1.7.0=pyhd8ed1ab_0 176 | - tiledb=2.6.1=h2038895_0 177 | - tk=8.6.11=h27826a3_1 178 | - toolz=0.11.2=pyhd8ed1ab_0 179 | - tornado=6.1=py39h3811e60_2 180 | - tqdm=4.62.3=pyhd8ed1ab_0 181 | - trollimage=1.17.0=pyhd8ed1ab_0 182 | - trollsift=0.3.5=pyh44b312d_0 183 | - typing_extensions=4.0.1=pyha770c72_0 184 | - tzcode=2021e=h7f98852_0 185 | - tzdata=2021e=he74cb21_0 186 | - urllib3=1.26.8=pyhd8ed1ab_1 187 | - wheel=0.37.1=pyhd8ed1ab_0 188 | - xarray=0.20.2=pyhd8ed1ab_0 189 | - xerces-c=3.2.3=h8ce2273_4 190 | - xorg-fixesproto=5.0=h7f98852_1002 191 | - xorg-inputproto=2.3.2=h7f98852_1002 192 | - xorg-kbproto=1.0.7=h7f98852_1002 193 | - xorg-libice=1.0.10=h7f98852_0 194 | - xorg-libsm=1.2.3=hd9c2040_1000 195 | - xorg-libx11=1.7.2=h7f98852_0 196 | - xorg-libxau=1.0.9=h7f98852_0 197 | - xorg-libxdmcp=1.1.3=h7f98852_0 198 | - xorg-libxext=1.3.4=h7f98852_1 199 | - xorg-libxfixes=5.0.3=h7f98852_1004 200 | - xorg-libxi=1.7.10=h7f98852_0 201 | - xorg-libxrender=0.9.10=h7f98852_1003 202 | - xorg-renderproto=0.11.1=h7f98852_1002 203 | - xorg-xextproto=7.3.0=h7f98852_1002 204 | - xorg-xproto=7.0.31=h7f98852_1007 205 | - xz=5.2.5=h516909a_1 206 | - yaml=0.2.5=h7f98852_2 207 | - zarr=2.10.3=pyhd8ed1ab_0 208 | - zict=2.0.0=py_0 209 | - zipp=3.7.0=pyhd8ed1ab_0 210 | - zlib=1.2.11=h36c2ea0_1013 211 | - zstd=1.5.2=ha95c52a_0 212 | -------------------------------------------------------------------------------- /satflow/data/utils/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import re 4 | 5 | import affine 6 | import numpy as np 7 | import yaml 8 | 9 | try: 10 | from pyresample import load_area 11 | from satpy import Scene 12 | 13 | _SAT_LIBS = True 14 | except: 15 | print("No pyresample or satpy") 16 | _SAT_LIBS = False 17 | 18 | 19 | def eumetsat_filename_to_datetime(inner_tar_name): 20 | """Takes a file from the EUMETSAT API and returns 21 | the date and time part of the filename""" 22 | 23 | p = re.compile("^MSG[23]-SEVI-MSG15-0100-NA-(\d*)\.") 24 | title_match = p.match(inner_tar_name) 25 | date_str = title_match.group(1) 26 | return datetime.datetime.strptime(date_str, "%Y%m%d%H%M%S") 27 | 28 | 29 | def eumetsat_name_to_datetime(filename: str): 30 | date_str = filename.split("0100-0100-")[-1].split(".")[0] 31 | return datetime.datetime.strptime(date_str, "%Y%m%d%H%M%S") 32 | 33 | 34 | def retrieve_pixel_value(geo_coord, data_source): 35 | """Return floating-point value that corresponds to given point. 
36 | Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal""" 37 | x, y = geo_coord[0], geo_coord[1] 38 | forward_transform = affine.Affine.from_gdal(*data_source.GetGeoTransform()) 39 | reverse_transform = ~forward_transform 40 | px, py = reverse_transform * (x, y) 41 | px, py = int(px + 0.5), int(py + 0.5) 42 | pixel_coord = px, py 43 | 44 | data_array = np.array(data_source.GetRasterBand(1).ReadAsArray()) 45 | return data_array[pixel_coord[0]][pixel_coord[1]] 46 | 47 | 48 | def map_satellite_to_mercator( 49 | native_satellite=None, 50 | grib_files=None, 51 | bufr_files=None, 52 | bands=( 53 | "HRV", 54 | "IR_016", 55 | "IR_039", 56 | "IR_087", 57 | "IR_097", 58 | "IR_108", 59 | "IR_120", 60 | "IR_134", 61 | "VIS006", 62 | "VIS008", 63 | "WV_062", 64 | "WV_073", 65 | ), 66 | save_scene="geotiff", 67 | save_loc=None, 68 | ): 69 | """ 70 | Opens, transforms to Transverse Mercator over Europe, and optionally saves it to files on disk. 71 | :param native_satellite: 72 | :param grib_files: 73 | :param bufr_files: 74 | :param bands: 75 | :param save_scene: 76 | :param save_loc: Save location 77 | :return: 78 | """ 79 | if not _SAT_LIBS: 80 | raise EnvironmentError("Pyresample or Satpy are not installed, please install them first") 81 | areas = load_area("/home/bieker/Development/satflow/satflow/resources/areas.yaml") 82 | filenames = {} 83 | if native_satellite is not None: 84 | filenames["seviri_l1b_native"] = [native_satellite] 85 | if grib_files is not None: 86 | filenames["seviri_l2_grib"] = [grib_files] 87 | if bufr_files is not None: 88 | filenames["seviri_l2_bufr"] = [bufr_files] 89 | scene = Scene(filenames=filenames) 90 | scene.load(bands) 91 | # By default resamples to 3km, as thats the native resolution of all bands other than HRV 92 | scene = scene.resample(areas[0]) 93 | if save_loc is not None: 94 | # Now the relvant data is all together, just need to save it somehow, or return it to the calling process 95 | scene.save_datasets(writer=save_scene, base_dir=save_loc, enhance=False) 96 | return scene 97 | 98 | 99 | def create_time_layer(dt: datetime.datetime, shape): 100 | """Create 3 layer for current time of observation""" 101 | month = dt.month / 12 102 | day = dt.day / 31 103 | hour = dt.hour / 24 104 | # minute = dt.minute / 60 105 | return np.stack([np.full(shape, month), np.full(shape, day), np.full(shape, hour)], axis=-1) 106 | 107 | 108 | def load_np(data): 109 | import numpy.lib.format 110 | 111 | stream = io.BytesIO(data) 112 | return numpy.lib.format.read_array(stream) 113 | 114 | 115 | def binarize_mask(mask): 116 | """Binarize mask, taking max value as the data, and setting everything else to 0""" 117 | tmp_mask = np.zeros_like(mask) 118 | tmp_mask[np.isclose(np.round(mask), 2)] = 1 119 | return tmp_mask 120 | 121 | 122 | def create_pixel_coord_layers(x_dim: int, y_dim: int, with_r: bool = False) -> np.ndarray: 123 | """ 124 | Creates Coord layer for CoordConv model 125 | 126 | :param x_dim: size of x dimension for output 127 | :param y_dim: size of y dimension for output 128 | :param with_r: Whether to include polar coordinates from center 129 | :return: (2, x_dim, y_dim) or (3, x_dim, y_dim) array of the pixel coordinates 130 | """ 131 | xx_ones = np.ones([1, x_dim], dtype=np.int32) 132 | xx_ones = np.expand_dims(xx_ones, -1) 133 | 134 | xx_range = np.expand_dims(np.arange(x_dim), 0) 135 | xx_range = np.expand_dims(xx_range, 1) 136 | 137 | xx_channel = np.matmul(xx_ones, xx_range) 138 | xx_channel 
= np.expand_dims(xx_channel, -1) 139 | 140 | yy_ones = np.ones([1, y_dim], dtype=np.int32) 141 | yy_ones = np.expand_dims(yy_ones, 1) 142 | 143 | yy_range = np.expand_dims(np.arange(y_dim), 0) 144 | yy_range = np.expand_dims(yy_range, -1) 145 | 146 | yy_channel = np.matmul(yy_range, yy_ones) 147 | yy_channel = np.expand_dims(yy_channel, -1) 148 | 149 | xx_channel = xx_channel.astype("float32") / (x_dim - 1) 150 | yy_channel = yy_channel.astype("float32") / (y_dim - 1) 151 | 152 | xx_channel = xx_channel * 2 - 1 153 | yy_channel = yy_channel * 2 - 1 154 | ret = np.stack([xx_channel, yy_channel], axis=0) 155 | 156 | if with_r: 157 | rr = np.sqrt(np.square(xx_channel - 0.5) + np.square(yy_channel - 0.5)) 158 | ret = np.concatenate([ret, np.expand_dims(rr, axis=0)], axis=0) 159 | ret = np.moveaxis(ret, [1], [0]) 160 | return ret 161 | 162 | 163 | def check_channels(config: dict) -> int: 164 | """ 165 | Checks the number of channels needed per timestep, to use for preallocating the numpy array 166 | Is not the same as the one for training, as that includes the number of channels after the array is partly 167 | flattened 168 | Args: 169 | config: 170 | 171 | Returns: 172 | 173 | """ 174 | channels = len(config.get("bands", [])) 175 | channels = channels + 1 if config.get("use_mask", False) else channels 176 | channels = ( 177 | channels + 3 178 | if config.get("use_time", False) and not config.get("time_aux", False) 179 | else channels 180 | ) 181 | # if config.get("time_as_channels", False): 182 | # Calc number of channels + inital ones 183 | # channels = channels * (config["num_timesteps"] + 1) 184 | channels = channels + 1 if config.get("use_topo", False) else channels 185 | channels = channels + 3 if config.get("use_latlon", False) else channels 186 | channels = channels + 2 if config.get("add_pixel_coords", False) else channels 187 | channels = channels + 1 if config.get("add_polar_coords", False) else channels 188 | return channels 189 | 190 | 191 | def crop_center(img: np.ndarray, cropx: int, cropy: int) -> np.ndarray: 192 | """Crops center of image through timestack, fails if all the images are concatenated as channels""" 193 | t, c, y, x = img.shape 194 | startx = x // 2 - (cropx // 2) 195 | starty = y // 2 - (cropy // 2) 196 | return img[:, :, starty : starty + cropy, startx : startx + cropx] 197 | 198 | 199 | def load_config(config_file): 200 | with open(config_file, "r") as cfg: 201 | return yaml.load(cfg, Loader=yaml.FullLoader)["config"] 202 | -------------------------------------------------------------------------------- /satflow/models/runet.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import antialiased_cnns 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import torchvision 8 | from nowcasting_utils.models.base import register_model 9 | from nowcasting_utils.models.loss import get_loss 10 | 11 | from satflow.models.layers.RUnetLayers import * 12 | 13 | 14 | @register_model 15 | class RUnet(pl.LightningModule): 16 | def __init__( 17 | self, 18 | input_channels: int = 12, 19 | forecast_steps: int = 48, 20 | recurrent_steps: int = 2, 21 | loss: Union[str, torch.nn.Module] = "mse", 22 | lr: float = 0.001, 23 | visualize: bool = False, 24 | conv_type: str = "standard", 25 | pretrained: bool = False, 26 | ): 27 | super().__init__() 28 | self.input_channels = input_channels 29 | self.forecast_steps = forecast_steps 30 | self.module = R2U_Net( 31 | img_ch=input_channels, 
output_ch=forecast_steps, t=recurrent_steps, conv_type=conv_type 32 | ) 33 | self.lr = lr 34 | self.input_channels = input_channels 35 | self.forecast_steps = forecast_steps 36 | self.criterion = get_loss(loss=loss) 37 | self.visualize = visualize 38 | self.save_hyperparameters() 39 | 40 | @classmethod 41 | def from_config(cls, config): 42 | return RUnet( 43 | forecast_steps=config.get("forecast_steps", 12), 44 | input_channels=config.get("in_channels", 12), 45 | lr=config.get("lr", 0.001), 46 | ) 47 | 48 | def forward(self, x): 49 | return self.module.forward(x)  # the R2U_Net backbone is stored as self.module 50 | 51 | def configure_optimizers(self): 52 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 53 | # optimizer = torch.optim.adam() 54 | return torch.optim.Adam(self.parameters(), lr=self.lr) 55 | 56 | def training_step(self, batch, batch_idx): 57 | x, y = batch 58 | x = x.float() 59 | y_hat = self(x) 60 | 61 | if self.visualize: 62 | if np.random.random() < 0.01: 63 | self.visualize_step(x, y, y_hat, batch_idx) 64 | # Generally only care about the center x crop, so the model can take into account the clouds in the area without 65 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 66 | loss = self.criterion(y_hat, y) 67 | self.log("train/loss", loss, on_step=True) 68 | frame_loss_dict = {} 69 | for f in range(self.forecast_steps): 70 | frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item() 71 | frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss 72 | self.log_dict(frame_loss_dict) 73 | return loss 74 | 75 | def validation_step(self, batch, batch_idx): 76 | x, y = batch 77 | x = x.float() 78 | y_hat = self(x) 79 | val_loss = self.criterion(y_hat, y) 80 | self.log("val/loss", val_loss) 81 | # Save out loss per frame as well 82 | frame_loss_dict = {} 83 | for f in range(self.forecast_steps): 84 | frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item() 85 | frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss 86 | self.log_dict(frame_loss_dict) 87 | return val_loss 88 | 89 | def test_step(self, batch, batch_idx): 90 | x, y = batch 91 | x = x.float() 92 | y_hat = self(x) 93 | loss = self.criterion(y_hat, y) 94 | return loss 95 | 96 | def visualize_step(self, x, y, y_hat, batch_idx, step="train"): 97 | tensorboard = self.logger.experiment[0] 98 | # Add all the different timesteps for a single prediction, 0.1% of the time 99 | images = x[0].cpu().detach() 100 | images = [torch.unsqueeze(img, dim=0) for img in images] 101 | image_grid = torchvision.utils.make_grid(images, nrow=self.input_channels) 102 | tensorboard.add_image(f"{step}/Input_Image_Stack", image_grid, global_step=batch_idx) 103 | images = y[0].cpu().detach() 104 | images = [torch.unsqueeze(img, dim=0) for img in images] 105 | image_grid = torchvision.utils.make_grid(images, nrow=12) 106 | tensorboard.add_image(f"{step}/Target_Image_Stack", image_grid, global_step=batch_idx) 107 | images = y_hat[0].cpu().detach() 108 | images = [torch.unsqueeze(img, dim=0) for img in images] 109 | image_grid = torchvision.utils.make_grid(images, nrow=12) 110 | tensorboard.add_image(f"{step}/Generated_Image_Stack", image_grid, global_step=batch_idx) 111 | 112 | 113 | class R2U_Net(nn.Module): 114 | def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"): 115 | super(R2U_Net, self).__init__() 116 | if conv_type == "antialiased": 117 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=1) 118 | self.antialiased = True 119 | else: 120 | self.Maxpool = 
nn.MaxPool2d(kernel_size=2, stride=2) 121 | self.antialiased = False 122 | 123 | self.Upsample = nn.Upsample(scale_factor=2) 124 | 125 | self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t, conv_type=conv_type) 126 | self.Blur1 = antialiased_cnns.BlurPool(64, stride=2) if self.antialiased else nn.Identity() 127 | self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type) 128 | self.Blur2 = antialiased_cnns.BlurPool(128, stride=2) if self.antialiased else nn.Identity() 129 | 130 | self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type) 131 | self.Blur3 = antialiased_cnns.BlurPool(256, stride=2) if self.antialiased else nn.Identity() 132 | 133 | self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type) 134 | self.Blur4 = antialiased_cnns.BlurPool(512, stride=2) if self.antialiased else nn.Identity() 135 | 136 | self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type) 137 | 138 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 139 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type) 140 | 141 | self.Up4 = up_conv(ch_in=512, ch_out=256) 142 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type) 143 | 144 | self.Up3 = up_conv(ch_in=256, ch_out=128) 145 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type) 146 | 147 | self.Up2 = up_conv(ch_in=128, ch_out=64) 148 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type) 149 | 150 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 151 | 152 | def forward(self, x): 153 | # encoding path 154 | x1 = self.RRCNN1(x) 155 | 156 | x2 = self.Maxpool(x1) 157 | x2 = self.Blur1(x2) 158 | x2 = self.RRCNN2(x2) 159 | 160 | x3 = self.Maxpool(x2) 161 | x3 = self.Blur2(x3) 162 | x3 = self.RRCNN3(x3) 163 | 164 | x4 = self.Maxpool(x3) 165 | x4 = self.Blur3(x4) 166 | x4 = self.RRCNN4(x4) 167 | 168 | x5 = self.Maxpool(x4) 169 | x5 = self.Blur4(x5) 170 | x5 = self.RRCNN5(x5) 171 | 172 | # decoding + concat path 173 | d5 = self.Up5(x5) 174 | d5 = torch.cat((x4, d5), dim=1) 175 | d5 = self.Up_RRCNN5(d5) 176 | 177 | d4 = self.Up4(d5) 178 | d4 = torch.cat((x3, d4), dim=1) 179 | d4 = self.Up_RRCNN4(d4) 180 | 181 | d3 = self.Up3(d4) 182 | d3 = torch.cat((x2, d3), dim=1) 183 | d3 = self.Up_RRCNN3(d3) 184 | 185 | d2 = self.Up2(d3) 186 | d2 = torch.cat((x1, d2), dim=1) 187 | d2 = self.Up_RRCNN2(d2) 188 | 189 | d1 = self.Conv_1x1(d2) 190 | 191 | return d1 192 | -------------------------------------------------------------------------------- /satflow/models/pix2pix.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchvision 7 | from nowcasting_utils.models.base import register_model 8 | from torch.optim import lr_scheduler 9 | 10 | from satflow.models.gan import define_generator 11 | from satflow.models.gan.discriminators import GANLoss, define_discriminator 12 | 13 | 14 | @register_model 15 | class Pix2Pix(pl.LightningModule): 16 | def __init__( 17 | self, 18 | forecast_steps: int = 48, 19 | input_channels: int = 12, 20 | lr: float = 0.0002, 21 | beta1: float = 0.5, 22 | beta2: float = 0.999, 23 | num_filters: int = 64, 24 | generator_model: str = "unet_128", 25 | norm: str = "batch", 26 | use_dropout: bool = False, 27 | discriminator_model: str = "basic", 28 | discriminator_layers: int = 0, 29 | loss: str = "vanilla", 30 | 
scheduler: str = "plateau", 31 | lr_epochs: int = 10, 32 | lambda_l1: float = 100.0, 33 | channels_per_timestep: int = 12, 34 | pretrained: bool = False, 35 | ): 36 | super().__init__() 37 | self.lr = lr 38 | self.b1 = beta1 39 | self.b2 = beta2 40 | self.loss = loss 41 | self.lambda_l1 = lambda_l1 42 | self.lr_epochs = lr_epochs 43 | self.lr_method = scheduler 44 | self.forecast_steps = forecast_steps 45 | self.input_channels = input_channels 46 | self.output_channels = forecast_steps * 12 47 | self.channels_per_timestep = channels_per_timestep 48 | 49 | # define networks (both generator and discriminator) 50 | self.generator = define_generator( 51 | input_channels, self.output_channels, num_filters, generator_model, norm, use_dropout 52 | ) 53 | 54 | self.discriminator = define_discriminator( 55 | input_channels + self.output_channels, 56 | num_filters, 57 | discriminator_model, 58 | discriminator_layers, 59 | norm, 60 | ) 61 | 62 | # define loss functions 63 | self.criterionGAN = GANLoss(loss) 64 | self.criterionL1 = torch.nn.L1Loss() 65 | # initialize optimizers; schedulers will be automatically created by function .\ 66 | self.save_hyperparameters() 67 | 68 | def forward(self, x): 69 | return self.generator(x) 70 | 71 | def visualize_step(self, x, y, y_hat, batch_idx, step): 72 | # the logger you used (in this case tensorboard) 73 | tensorboard = self.logger.experiment[0] 74 | # Add all the different timesteps for a single prediction, 0.1% of the time 75 | images = x[0].cpu().detach() 76 | images = [torch.unsqueeze(img, dim=0) for img in images] 77 | image_grid = torchvision.utils.make_grid(images, nrow=self.channels_per_timestep) 78 | tensorboard.add_image(f"{step}/Input_Image_Stack", image_grid, global_step=batch_idx) 79 | images = y[0].cpu().detach() 80 | images = [torch.unsqueeze(img, dim=0) for img in images] 81 | image_grid = torchvision.utils.make_grid(images, nrow=12) 82 | tensorboard.add_image(f"{step}/Target_Image_Stack", image_grid, global_step=batch_idx) 83 | images = y_hat[0].cpu().detach() 84 | images = [torch.unsqueeze(img, dim=0) for img in images] 85 | image_grid = torchvision.utils.make_grid(images, nrow=12) 86 | tensorboard.add_image(f"{step}/Generated_Image_Stack", image_grid, global_step=batch_idx) 87 | 88 | def training_step(self, batch, batch_idx, optimizer_idx): 89 | images, future_images, future_masks = batch 90 | # train generator 91 | if optimizer_idx == 0: 92 | # generate images 93 | generated_images = self(images) 94 | fake = torch.cat((images, generated_images), 1) 95 | # log sampled images 96 | # if np.random.random() < 0.01: 97 | self.visualize_step(images, future_images, generated_images, batch_idx, step="train") 98 | 99 | # adversarial loss is binary cross-entropy 100 | gan_loss = self.criterionGAN(self.discriminator(fake), True) 101 | l1_loss = self.criterionL1(generated_images, future_images) * self.lambda_l1 102 | g_loss = gan_loss + l1_loss 103 | tqdm_dict = {"g_loss": g_loss} 104 | output = OrderedDict({"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}) 105 | self.log_dict({"train/g_loss": g_loss}) 106 | return output 107 | 108 | # train discriminator 109 | if optimizer_idx == 1: 110 | # Measure discriminator's ability to classify real from generated samples 111 | 112 | # how well can it label as real? 113 | real = torch.cat((images, future_images), 1) 114 | real_loss = self.criterionGAN(self.discriminator(real), True) 115 | 116 | # how well can it label as fake? 
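# Note: the fake pair built below uses detached generator output and a False target,
# so the discriminator update does not backpropagate into the generator.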
117 | gen_output = self(images) 118 | fake = torch.cat((images, gen_output.detach()), 1) 119 | fake_loss = self.criterionGAN(self.discriminator(fake), False) 120 | 121 | # discriminator loss is the average of these 122 | d_loss = (real_loss + fake_loss) / 2 123 | tqdm_dict = {"d_loss": d_loss} 124 | output = OrderedDict({"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}) 125 | self.log_dict({"train/d_loss": d_loss}) 126 | return output 127 | 128 | def validation_step(self, batch, batch_idx): 129 | images, future_images, future_masks = batch 130 | # generate images 131 | generated_images = self(images) 132 | fake = torch.cat((images, generated_images), 1) 133 | # log sampled images 134 | if np.random.random() < 0.01: 135 | self.visualize_step(images, future_images, generated_images, batch_idx, step="val") 136 | 137 | # adversarial loss is binary cross-entropy 138 | gan_loss = self.criterionGAN(self.discriminator(fake), True) 139 | l1_loss = self.criterionL1(generated_images, future_images) * self.lambda_l1 140 | g_loss = gan_loss + l1_loss 141 | # how well can it label as real? 142 | real = torch.cat((images, future_images), 1) 143 | real_loss = self.criterionGAN(self.discriminator(real), True) 144 | 145 | # how well can it label as fake? 146 | fake_loss = self.criterionGAN(self.discriminator(fake), False) 147 | 148 | # discriminator loss is the average of these 149 | d_loss = (real_loss + fake_loss) / 2 150 | tqdm_dict = {"d_loss": d_loss} 151 | output = OrderedDict( 152 | { 153 | "val/discriminator_loss": d_loss, 154 | "val/generator_loss": g_loss, 155 | "progress_bar": tqdm_dict, 156 | "log": tqdm_dict, 157 | } 158 | ) 159 | self.log_dict({"val/d_loss": d_loss, "val/g_loss": g_loss, "val/loss": d_loss + g_loss}) 160 | return output 161 | 162 | def configure_optimizers(self): 163 | lr = self.lr 164 | b1 = self.b1 165 | b2 = self.b2 166 | 167 | opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) 168 | opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) 169 | if self.lr_method == "plateau": 170 | g_scheduler = lr_scheduler.ReduceLROnPlateau( 171 | opt_g, mode="min", factor=0.2, threshold=0.01, patience=10 172 | ) 173 | d_scheduler = lr_scheduler.ReduceLROnPlateau( 174 | opt_d, mode="min", factor=0.2, threshold=0.01, patience=10 175 | ) 176 | elif self.lr_method == "cosine": 177 | g_scheduler = lr_scheduler.CosineAnnealingLR(opt_g, T_max=self.lr_epochs, eta_min=0) 178 | d_scheduler = lr_scheduler.CosineAnnealingLR(opt_d, T_max=self.lr_epochs, eta_min=0) 179 | else: 180 | raise NotImplementedError("learning rate policy is not implemented") 181 | 182 | return [opt_g, opt_d], [g_scheduler, d_scheduler] 183 | -------------------------------------------------------------------------------- /satflow/core/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import yaml 5 | from nowcasting_dataset.config.load import load_yaml_configuration 6 | 7 | 8 | def load_config(file_path: str) -> Dict: 9 | with open(file_path, "r") as f: 10 | config = yaml.load(f, Loader=yaml.FullLoader) 11 | return config 12 | 13 | 14 | def make_logger(name: str, level=logging.DEBUG) -> logging.Logger: 15 | logger = logging.getLogger(name) 16 | logger.setLevel(level=level) 17 | return logger 18 | 19 | 20 | import warnings 21 | from typing import List, Sequence 22 | 23 | import pytorch_lightning as pl 24 | import rich.syntax 25 | import rich.tree 26 | from omegaconf import DictConfig, OmegaConf 27 | from
pytorch_lightning.utilities import rank_zero_only 28 | 29 | 30 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 31 | """Initializes multi-GPU-friendly python logger.""" 32 | 33 | logger = logging.getLogger(name) 34 | logger.setLevel(level) 35 | 36 | # this ensures all logging levels get marked with the rank zero decorator 37 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 38 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 39 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 40 | 41 | return logger 42 | 43 | 44 | def extras(config: DictConfig) -> None: 45 | """A couple of optional utilities, controlled by main config file: 46 | - disabling warnings 47 | - easier access to debug mode 48 | - forcing debug friendly configuration 49 | - forcing multi-gpu friendly configuration 50 | - Ensure correct number of timesteps/etc for all of them 51 | 52 | Modifies DictConfig in place. 53 | 54 | Args: 55 | config (DictConfig): Configuration composed by Hydra. 56 | """ 57 | 58 | log = get_logger() 59 | 60 | # enable adding new keys to config 61 | OmegaConf.set_struct(config, False) 62 | # Ensure that model and dataloader are doing the same thing 63 | config.datamodule.config.forecast_times = ( 64 | config.model.forecast_steps * 5 65 | ) # Convert from steps to minutes 66 | # Get number of channels from config 67 | dataset_config = load_yaml_configuration(config.datamodule.configuration_filename) 68 | 69 | channels = len(dataset_config.process.sat_channels) 70 | log.info(f"Channels: (Bands) {channels}") 71 | channels = channels + 1 if "topo_data" in config.datamodule.required_keys else channels 72 | channels = ( 73 | channels + len(dataset_config.process.nwp_channels) 74 | if "nwp_data" in config.datamodule.required_keys 75 | else channels 76 | ) 77 | log.info(f"Channels: (Use Topo) {channels}") 78 | # Check lat/lon, would only use one coord for MetNet, basic check if using Perceiver or not, only single set of coords 79 | # Perceiver input channels also makes less sense, as each one is put in separately, so NWP and Sat won't be concatenated 80 | if ( 81 | "sat_x_coords" in config.datamodule.required_keys 82 | and "nwp_x_coords" not in config.datamodule.required_keys 83 | ): 84 | channels = channels + 2 if "sat_x_coords" in config.datamodule.required_keys else channels 85 | # If one datetime is in there, all will be, 1 layer for each value 86 | channels = ( 87 | channels + 4 if "hour_of_day_sin" in config.datamodule.required_keys else channels 88 | ) 89 | 90 | config.model.input_channels = channels 91 | 92 | # Update number of iterations per epoch based on accumulate 93 | if config.trainer.get("accumulate_grad_batches"): 94 | config.trainer.limit_train_batches = ( 95 | config.trainer.limit_train_batches * config.trainer.accumulate_grad_batches 96 | ) 97 | 98 | # disable python warnings if 99 | if config.get("ignore_warnings"): 100 | log.info("Disabling python warnings! ") 101 | warnings.filterwarnings("ignore") 102 | 103 | # set if 104 | if config.get("debug"): 105 | log.info("Running in debug mode! ") 106 | config.trainer.fast_dev_run = True 107 | 108 | # force debugger friendly configuration if 109 | if config.trainer.get("fast_dev_run"): 110 | log.info("Forcing debugger friendly configuration! 
") 111 | # Debuggers don't like GPUs or multiprocessing 112 | if config.trainer.get("gpus"): 113 | config.trainer.gpus = 0 114 | if config.datamodule.get("pin_memory"): 115 | config.datamodule.pin_memory = False 116 | if config.datamodule.get("num_workers"): 117 | config.datamodule.num_workers = 0 118 | 119 | # force multi-gpu friendly configuration if 120 | accelerator = config.trainer.get("accelerator") 121 | if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: 122 | log.info(f"Forcing ddp friendly configuration! ") 123 | if config.datamodule.get("num_workers"): 124 | config.datamodule.num_workers = 0 125 | if config.datamodule.get("pin_memory"): 126 | config.datamodule.pin_memory = False 127 | 128 | # disable adding new keys to config 129 | OmegaConf.set_struct(config, True) 130 | 131 | 132 | @rank_zero_only 133 | def print_config( 134 | config: DictConfig, 135 | fields: Sequence[str] = ( 136 | "trainer", 137 | "model", 138 | "datamodule", 139 | "callbacks", 140 | "logger", 141 | "hparams_search" 142 | # "logger", 143 | # "seed", 144 | ), 145 | resolve: bool = True, 146 | ) -> None: 147 | """Prints content of DictConfig using Rich library and its tree structure. 148 | 149 | Args: 150 | config (DictConfig): Configuration composed by Hydra. 151 | fields (Sequence[str], optional): Determines which main fields from config will 152 | be printed and in what order. 153 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 154 | """ 155 | 156 | style = "dim" 157 | tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style) 158 | 159 | for field in fields: 160 | branch = tree.add(field, style=style, guide_style=style) 161 | 162 | config_section = config.get(field) 163 | branch_content = str(config_section) 164 | if isinstance(config_section, DictConfig): 165 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 166 | 167 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 168 | 169 | rich.print(tree) 170 | 171 | 172 | def empty(*args, **kwargs): 173 | pass 174 | 175 | 176 | @rank_zero_only 177 | def log_hyperparameters( 178 | config: DictConfig, 179 | model: pl.LightningModule, 180 | trainer: pl.Trainer, 181 | ) -> None: 182 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 
183 | 184 | Additionaly saves: 185 | - number of trainable model parameters 186 | """ 187 | 188 | hparams = {} 189 | 190 | # choose which parts of hydra config will be saved to loggers 191 | hparams["trainer"] = config["trainer"] 192 | hparams["model"] = config["model"] 193 | hparams["datamodule"] = config["datamodule"] 194 | if "callbacks" in config: 195 | hparams["callbacks"] = config["callbacks"] 196 | 197 | # save number of model parameters 198 | hparams["model/params_total"] = sum(p.numel() for p in model.parameters()) 199 | hparams["model/params_trainable"] = sum( 200 | p.numel() for p in model.parameters() if p.requires_grad 201 | ) 202 | hparams["model/params_not_trainable"] = sum( 203 | p.numel() for p in model.parameters() if not p.requires_grad 204 | ) 205 | 206 | # send hparams to all loggers 207 | trainer.logger.log_hyperparams(hparams) 208 | 209 | # disable logging any more hyperparameters for all loggers 210 | # this is just a trick to prevent trainer from logging hparams of model, 211 | # since we already did that above 212 | trainer.logger.log_hyperparams = empty 213 | -------------------------------------------------------------------------------- /satflow/data/datamodules.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from nowcasting_dataset.config.load import load_yaml_configuration 7 | from nowcasting_dataset.consts import ( 8 | DATETIME_FEATURE_NAMES, 9 | NWP_DATA, 10 | NWP_X_COORDS, 11 | NWP_Y_COORDS, 12 | SATELLITE_DATA, 13 | SATELLITE_DATETIME_INDEX, 14 | SATELLITE_X_COORDS, 15 | SATELLITE_Y_COORDS, 16 | TOPOGRAPHIC_DATA, 17 | TOPOGRAPHIC_X_COORDS, 18 | TOPOGRAPHIC_Y_COORDS, 19 | ) 20 | from nowcasting_dataset.dataset.datasets import worker_init_fn 21 | from pytorch_lightning import LightningDataModule 22 | 23 | from satflow.data.datasets import SatFlowDataset 24 | 25 | _LOG = logging.getLogger(__name__) 26 | _LOG.setLevel(logging.DEBUG) 27 | 28 | 29 | class SatFlowDataModule(LightningDataModule): 30 | """ 31 | Example of LightningDataModule for NETCDF dataset. 32 | A DataModule implements 5 key methods: 33 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 34 | - setup (things to do on every accelerator in distributed mode) 35 | - train_dataloader (the training dataloader) 36 | - val_dataloader (the validation dataloader(s)) 37 | - test_dataloader (the test dataloader(s)) 38 | This allows you to share a full dataset without explaining how to download, 39 | split, transform and process the data. 40 | Read the docs: 41 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 42 | """ 43 | 44 | def __init__( 45 | self, 46 | temp_path: str = ".", 47 | n_train_data: int = 24900, 48 | n_val_data: int = 1000, 49 | cloud: str = "aws", 50 | num_workers: int = 8, 51 | pin_memory: bool = True, 52 | configuration_filename="satflow/configs/local.yaml", 53 | fake_data: bool = False, 54 | required_keys: Union[Tuple[str], List[str]] = [ 55 | NWP_DATA, 56 | NWP_X_COORDS, 57 | NWP_Y_COORDS, 58 | SATELLITE_DATA, 59 | SATELLITE_X_COORDS, 60 | SATELLITE_Y_COORDS, 61 | SATELLITE_DATETIME_INDEX, 62 | TOPOGRAPHIC_DATA, 63 | TOPOGRAPHIC_X_COORDS, 64 | TOPOGRAPHIC_Y_COORDS, 65 | ] 66 | + list(DATETIME_FEATURE_NAMES), 67 | history_minutes: Optional[int] = None, 68 | forecast_minutes: Optional[int] = None, 69 | ): 70 | """ 71 | fake_data: random data is created and used instead. 
This is useful for testing 72 | """ 73 | super().__init__() 74 | 75 | self.temp_path = temp_path 76 | self.configuration = load_yaml_configuration(configuration_filename) 77 | self.cloud = cloud 78 | self.n_train_data = n_train_data 79 | self.n_val_data = n_val_data 80 | self.num_workers = num_workers 81 | self.pin_memory = pin_memory 82 | self.fake_data = fake_data 83 | self.required_keys = required_keys 84 | self.forecast_minutes = forecast_minutes 85 | self.history_minutes = history_minutes 86 | 87 | self.dataloader_config = dict( 88 | pin_memory=self.pin_memory, 89 | num_workers=self.num_workers, 90 | prefetch_factor=8, 91 | worker_init_fn=worker_init_fn, 92 | persistent_workers=True, 93 | # Disable automatic batching because dataset 94 | # returns complete batches. 95 | batch_size=None, 96 | ) 97 | 98 | def train_dataloader(self): 99 | if self.fake_data: 100 | train_dataset = FakeDataset( 101 | history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes 102 | ) 103 | else: 104 | train_dataset = SatFlowDataset( 105 | self.n_train_data, 106 | os.path.join(self.configuration.output_data.filepath, "train"), 107 | os.path.join(self.temp_path, "train"), 108 | configuration=self.configuration, 109 | cloud=self.cloud, 110 | required_keys=self.required_keys, 111 | history_minutes=self.history_minutes, 112 | forecast_minutes=self.forecast_minutes, 113 | ) 114 | 115 | return torch.utils.data.DataLoader(train_dataset, **self.dataloader_config) 116 | 117 | def val_dataloader(self): 118 | if self.fake_data: 119 | val_dataset = FakeDataset( 120 | history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes 121 | ) 122 | else: 123 | val_dataset = SatFlowDataset( 124 | self.n_val_data, 125 | os.path.join(self.configuration.output_data.filepath, "validation"), 126 | os.path.join(self.temp_path, "validation"), 127 | configuration=self.configuration, 128 | cloud=self.cloud, 129 | required_keys=self.required_keys, 130 | history_minutes=self.history_minutes, 131 | forecast_minutes=self.forecast_minutes, 132 | ) 133 | 134 | return torch.utils.data.DataLoader(val_dataset, **self.dataloader_config) 135 | 136 | def test_dataloader(self): 137 | if self.fake_data: 138 | test_dataset = FakeDataset( 139 | history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes 140 | ) 141 | else: 142 | # TODO need to change this to a test folder 143 | test_dataset = SatFlowDataset( 144 | self.n_val_data, 145 | os.path.join(self.configuration.output_data.filepath, "test"), 146 | os.path.join(self.temp_path, "test"), 147 | configuration=self.configuration, 148 | cloud=self.cloud, 149 | required_keys=self.required_keys, 150 | history_minutes=self.history_minutes, 151 | forecast_minutes=self.forecast_minutes, 152 | ) 153 | 154 | return torch.utils.data.DataLoader(test_dataset, **self.dataloader_config) 155 | 156 | 157 | class FakeDataset(torch.utils.data.Dataset): 158 | """Fake dataset.""" 159 | 160 | def __init__( 161 | self, 162 | batch_size=32, 163 | width=16, 164 | height=16, 165 | number_sat_channels=12, 166 | length=10, 167 | history_minutes=30, 168 | forecast_minutes=30, 169 | ): 170 | self.batch_size = batch_size 171 | if history_minutes is None or forecast_minutes is None: 172 | history_minutes = 30 # Half an hour 173 | forecast_minutes = 240 # 4 hours 174 | self.history_steps = history_minutes // 5 175 | self.forecast_steps = forecast_minutes // 5 176 | self.seq_length = self.history_steps + 1 177 | self.width = width 178 | self.height = height 179 | 
self.number_sat_channels = number_sat_channels 180 | self.length = length 181 | 182 | def __len__(self): 183 | return self.length 184 | 185 | def per_worker_init(self, worker_id: int): 186 | pass 187 | 188 | def __getitem__(self, idx): 189 | 190 | x = { 191 | SATELLITE_DATA: torch.randn( 192 | self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels 193 | ), 194 | NWP_DATA: torch.randn(self.batch_size, 10, self.seq_length, 2, 2), 195 | "hour_of_day_sin": torch.randn(self.batch_size, self.seq_length), 196 | "hour_of_day_cos": torch.randn(self.batch_size, self.seq_length), 197 | "day_of_year_sin": torch.randn(self.batch_size, self.seq_length), 198 | "day_of_year_cos": torch.randn(self.batch_size, self.seq_length), 199 | } 200 | 201 | # add fake x and y coords, and make sure they are sorted 202 | x[SATELLITE_X_COORDS], _ = torch.sort(torch.randn(self.batch_size, self.seq_length)) 203 | x[SATELLITE_Y_COORDS], _ = torch.sort( 204 | torch.randn(self.batch_size, self.seq_length), descending=True 205 | ) 206 | 207 | # add sorted (fake) time series 208 | x[SATELLITE_DATETIME_INDEX], _ = torch.sort(torch.randn(self.batch_size, self.seq_length)) 209 | 210 | y = { 211 | SATELLITE_DATA: torch.randn( 212 | self.batch_size, 213 | self.forecast_steps, 214 | self.width, 215 | self.height, 216 | self.number_sat_channels, 217 | ), 218 | } 219 | return x, y 220 | -------------------------------------------------------------------------------- /satflow/models/conv_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Union 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from nowcasting_utils.models.base import register_model 8 | from nowcasting_utils.models.loss import get_loss 9 | 10 | from satflow.models.layers.ConvLSTM import ConvLSTMCell 11 | 12 | 13 | @register_model 14 | class EncoderDecoderConvLSTM(pl.LightningModule): 15 | def __init__( 16 | self, 17 | hidden_dim: int = 64, 18 | input_channels: int = 12, 19 | out_channels: int = 1, 20 | forecast_steps: int = 48, 21 | lr: float = 0.001, 22 | visualize: bool = False, 23 | loss: Union[str, torch.nn.Module] = "mse", 24 | pretrained: bool = False, 25 | conv_type: str = "standard", 26 | ): 27 | super(EncoderDecoderConvLSTM, self).__init__() 28 | self.forecast_steps = forecast_steps 29 | self.criterion = get_loss(loss) 30 | self.lr = lr 31 | self.visualize = visualize 32 | self.model = ConvLSTM(input_channels, hidden_dim, out_channels, conv_type=conv_type) 33 | self.save_hyperparameters() 34 | 35 | @classmethod 36 | def from_config(cls, config): 37 | return EncoderDecoderConvLSTM( 38 | hidden_dim=config.get("num_hidden", 64), 39 | input_channels=config.get("in_channels", 12), 40 | out_channels=config.get("out_channels", 1), 41 | forecast_steps=config.get("forecast_steps", 1), 42 | lr=config.get("lr", 0.001), 43 | ) 44 | 45 | def forward(self, x, future_seq=0, hidden_state=None): 46 | return self.model.forward(x, future_seq, hidden_state) 47 | 48 | def configure_optimizers(self): 49 | # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) 50 | # optimizer = torch.optim.adam() 51 | return torch.optim.Adam(self.parameters(), lr=self.lr) 52 | 53 | def training_step(self, batch, batch_idx): 54 | x, y = batch 55 | y_hat = self(x, self.forecast_steps) 56 | y_hat = torch.permute(y_hat, dims=(0, 2, 1, 3, 4)) 57 | # Generally only care about the center x crop, so the model can take into account the 
clouds in the area without 58 | # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels 59 | # the logger you used (in this case tensorboard) 60 | # if self.visualize: 61 | # if np.random.random() < 0.01: 62 | # self.visualize_step(x, y, y_hat, batch_idx) 63 | loss = self.criterion(y_hat, y) 64 | self.log("train/loss", loss, on_step=True) 65 | frame_loss_dict = {} 66 | for f in range(self.forecast_steps): 67 | frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item() 68 | frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss 69 | self.log_dict(frame_loss_dict, on_step=False, on_epoch=True) 70 | return loss 71 | 72 | def validation_step(self, batch, batch_idx): 73 | x, y = batch 74 | y_hat = self(x, self.forecast_steps) 75 | y_hat = torch.permute(y_hat, dims=(0, 2, 1, 3, 4)) 76 | val_loss = self.criterion(y_hat, y) 77 | # Save out loss per frame as well 78 | frame_loss_dict = {} 79 | # y_hat = torch.moveaxis(y_hat, 2, 1) 80 | for f in range(self.forecast_steps): 81 | frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item() 82 | frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss 83 | self.log("val/loss", val_loss, on_step=True, on_epoch=True) 84 | self.log_dict(frame_loss_dict, on_step=False, on_epoch=True) 85 | return val_loss 86 | 87 | def test_step(self, batch, batch_idx): 88 | x, y = batch 89 | y_hat = self(x, self.forecast_steps) 90 | loss = self.criterion(y_hat, y) 91 | return loss 92 | 93 | def visualize_step(self, x, y, y_hat, batch_idx, step="train"): 94 | tensorboard = self.logger.experiment[0] 95 | # Add all the different timesteps for a single prediction, 0.1% of the time 96 | if len(x.shape) == 5: 97 | # Timesteps per channel 98 | images = x[0].cpu().detach() 99 | for i, t in enumerate(images): # Now would be (C, H, W) 100 | t = [torch.unsqueeze(img, dim=0) for img in t] 101 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 102 | tensorboard.add_image( 103 | f"{step}/Input_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx 104 | ) 105 | images = y[0].cpu().detach() 106 | for i, t in enumerate(images): # Now would be (C, H, W) 107 | t = [torch.unsqueeze(img, dim=0) for img in t] 108 | image_grid = torchvision.utils.make_grid(t, nrow=self.output_channels) 109 | tensorboard.add_image( 110 | f"{step}/Target_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx 111 | ) 112 | images = y_hat[0].cpu().detach() 113 | for i, t in enumerate(images): # Now would be (C, H, W) 114 | t = [torch.unsqueeze(img, dim=0) for img in t] 115 | image_grid = torchvision.utils.make_grid(t, nrow=self.output_channels) 116 | tensorboard.add_image( 117 | f"{step}/Generated_Stack_Frame_{i}", image_grid, global_step=batch_idx 118 | ) 119 | 120 | 121 | class ConvLSTM(torch.nn.Module): 122 | def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "standard"): 123 | super().__init__() 124 | """ ARCHITECTURE 125 | 126 | # Encoder (ConvLSTM) 127 | # Encoder Vector (final hidden state of encoder) 128 | # Decoder (ConvLSTM) - takes Encoder Vector as input 129 | # Decoder (3D CNN) - produces regression predictions for our model 130 | 131 | """ 132 | self.encoder_1_convlstm = ConvLSTMCell( 133 | input_dim=input_channels, 134 | hidden_dim=hidden_dim, 135 | kernel_size=(3, 3), 136 | bias=True, 137 | conv_type=conv_type, 138 | ) 139 | 140 | self.encoder_2_convlstm = ConvLSTMCell( 141 | input_dim=hidden_dim, 142 | hidden_dim=hidden_dim, 143 | kernel_size=(3, 3), 144 | bias=True, 145 | 
conv_type=conv_type, 146 | ) 147 | 148 | self.decoder_1_convlstm = ConvLSTMCell( 149 | input_dim=hidden_dim, 150 | hidden_dim=hidden_dim, 151 | kernel_size=(3, 3), 152 | bias=True, # nf + 1 153 | conv_type=conv_type, 154 | ) 155 | 156 | self.decoder_2_convlstm = ConvLSTMCell( 157 | input_dim=hidden_dim, 158 | hidden_dim=hidden_dim, 159 | kernel_size=(3, 3), 160 | bias=True, 161 | conv_type=conv_type, 162 | ) 163 | 164 | self.decoder_CNN = nn.Conv3d( 165 | in_channels=hidden_dim, 166 | out_channels=out_channels, 167 | kernel_size=(1, 3, 3), 168 | padding=(0, 1, 1), 169 | ) 170 | 171 | def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4): 172 | 173 | outputs = [] 174 | 175 | # encoder 176 | for t in range(seq_len): 177 | h_t, c_t = self.encoder_1_convlstm( 178 | input_tensor=x[:, t, :, :], cur_state=[h_t, c_t] 179 | ) # we could concat to provide skip conn here 180 | h_t2, c_t2 = self.encoder_2_convlstm( 181 | input_tensor=h_t, cur_state=[h_t2, c_t2] 182 | ) # we could concat to provide skip conn here 183 | 184 | # encoder_vector 185 | encoder_vector = h_t2 186 | 187 | # decoder 188 | for t in range(future_step): 189 | h_t3, c_t3 = self.decoder_1_convlstm( 190 | input_tensor=encoder_vector, cur_state=[h_t3, c_t3] 191 | ) # we could concat to provide skip conn here 192 | h_t4, c_t4 = self.decoder_2_convlstm( 193 | input_tensor=h_t3, cur_state=[h_t4, c_t4] 194 | ) # we could concat to provide skip conn here 195 | encoder_vector = h_t4 196 | outputs += [h_t4] # predictions 197 | 198 | outputs = torch.stack(outputs, 1) 199 | outputs = outputs.permute(0, 2, 1, 3, 4) 200 | outputs = self.decoder_CNN(outputs) 201 | outputs = torch.nn.Sigmoid()(outputs) 202 | 203 | return outputs 204 | 205 | def forward(self, x, forecast_steps=0, hidden_state=None): 206 | 207 | """ 208 | Parameters 209 | ---------- 210 | input_tensor: 211 | 5-D Tensor of shape (b, t, c, h, w) # batch, time, channel, height, width 212 | """ 213 | 214 | # find size of different input dimensions 215 | b, seq_len, _, h, w = x.size() 216 | 217 | # initialize hidden states 218 | h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w)) 219 | h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w)) 220 | h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w)) 221 | h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w)) 222 | 223 | # autoencoder forward 224 | outputs = self.autoencoder( 225 | x, seq_len, forecast_steps, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4 226 | ) 227 | 228 | return outputs 229 | --------------------------------------------------------------------------------
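The ConvLSTM above consumes a 5-D tensor of shape (b, t, c, h, w) and returns the predicted future frames stacked along the time axis. A minimal usage sketch follows; the batch size, frame counts, image size, and channel count are illustrative assumptions rather than values taken from the configs, and it assumes nowcasting_utils is installed so the loss can be constructed:

import torch

from satflow.models.conv_lstm import EncoderDecoderConvLSTM

# Illustrative shapes only: 2 samples, 6 history frames, 12 channels, 64x64 pixels
model = EncoderDecoderConvLSTM(
    hidden_dim=64, input_channels=12, out_channels=1, forecast_steps=4
)
x = torch.randn(2, 6, 12, 64, 64)  # (b, t, c, h, w), as documented in forward()
y_hat = model(x, future_seq=4)  # (b, out_channels, forecast_steps, h, w)
print(y_hat.shape)  # torch.Size([2, 1, 4, 64, 64])
# training_step permutes this to (b, forecast_steps, out_channels, h, w) before computing the loss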