├── .gitignore ├── README.md ├── best-submission ├── doxa_cli.py ├── example.ipynb ├── figures │ └── lr_find.png ├── inference_unet.ipynb ├── submission │ ├── climatehack.py │ ├── doxa.yaml │ ├── evaluate.py │ ├── model.py │ └── validate.py ├── submit.sh ├── test_and_visualize.ipynb ├── timm.ipynb ├── train.sh ├── train_timm.py ├── train_unet.ipynb ├── train_unet.py ├── train_unet_fast.ipynb ├── utils │ ├── __init__.py │ ├── data.py │ ├── loss.py │ ├── models.py │ └── preprocess.py └── validate.sh ├── common ├── __init__.py ├── checkpointing.py ├── denoiser.py ├── loss_utils.py ├── q20.mat └── utils.py ├── data ├── download_data.ipynb └── download_good_sun.ipynb ├── environment.yaml ├── experiments ├── climatehack-submission │ ├── .gitignore │ ├── doxa_cli.py │ ├── submission │ │ ├── climatehack.py │ │ ├── dgmr-oneshot │ │ │ └── dgmr │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── dgmr.py │ │ │ │ ├── discriminators.py │ │ │ │ ├── generators.py │ │ │ │ ├── hub.py │ │ │ │ ├── layers │ │ │ │ ├── Attention.py │ │ │ │ ├── ConvGRU.py │ │ │ │ ├── CoordConv.py │ │ │ │ ├── __init__.py │ │ │ │ └── utils.py │ │ │ │ └── losses.py │ │ ├── doxa.yaml │ │ ├── evaluate.py │ │ └── validate.py │ └── submit.sh ├── dgmr-dct │ └── dgmr │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dgmr.py │ │ ├── discriminators.py │ │ ├── generators.py │ │ ├── hub.py │ │ ├── layers │ │ ├── Attention.py │ │ ├── ConvGRU.py │ │ ├── CoordConv.py │ │ ├── __init__.py │ │ └── utils.py │ │ └── losses.py ├── dgmr-multichannel │ └── dgmr │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dgmr.py │ │ ├── discriminators.py │ │ ├── generators.py │ │ ├── hub.py │ │ ├── layers │ │ ├── Attention.py │ │ ├── ConvGRU.py │ │ ├── CoordConv.py │ │ ├── __init__.py │ │ └── utils.py │ │ └── losses.py ├── dgmr-oneshot-multichannel │ └── dgmr │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dgmr.py │ │ ├── discriminators.py │ │ ├── generators.py │ │ ├── hub.py │ │ ├── layers │ │ ├── Attention.py │ │ ├── ConvGRU.py │ │ ├── CoordConv.py │ │ ├── __init__.py │ │ └── utils.py │ │ └── losses.py ├── dgmr-oneshot │ └── dgmr │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dgmr.py │ │ ├── discriminators.py │ │ ├── generators.py │ │ ├── hub.py │ │ ├── layers │ │ ├── Attention.py │ │ ├── ConvGRU.py │ │ ├── CoordConv.py │ │ ├── __init__.py │ │ └── utils.py │ │ └── losses.py ├── dgmr-original │ └── dgmr │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dgmr.py │ │ ├── discriminators.py │ │ ├── generators.py │ │ ├── hub.py │ │ ├── layers │ │ ├── Attention.py │ │ ├── ConvGRU.py │ │ ├── CoordConv.py │ │ ├── __init__.py │ │ └── utils.py │ │ └── losses.py ├── optical_flow_hyperopt.ipynb ├── plot_hist.ipynb ├── train_gen.ipynb ├── train_gen_dct.ipynb ├── train_gen_oneshot-coordconv.ipynb ├── train_gen_oneshot.ipynb ├── train_gen_oneshot_optflow.ipynb └── visualize_and_test.ipynb └── figs ├── final_leaderboard.png └── model_predictions.gif /.gitignore: -------------------------------------------------------------------------------- 1 | # OWN 2 | *.pt 3 | *.npz 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from 
a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Custom 136 | data/ 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Climatehack 2 | 3 | This is the repository for Illinois's Climatehack Team. We earned first place on the [leaderboard](https://climatehack.ai/compete/leaderboard/universities) with a final score of 0.87992. 4 | 5 |

6 | ![Final Leaderboard](figs/final_leaderboard.png) 7 |
8 | 9 | An overview of our approach can be found [here](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing). 10 | 11 | Example predictions: 12 | ![](figs/model_predictions.gif) 13 | 14 | # Setup 15 | ```bash 16 | conda env create -f environment.yaml 17 | conda activate climatehack 18 | python -m ipykernel install --user --name=climatehack 19 | ``` 20 | 21 | First, download data by running `data/download_data.ipynb`. Alternatively, you can find preprocessed data files [here](https://drive.google.com/drive/folders/1JkPKjOBtm3dlOl2fRTvaLkSu7KnZsJGw?usp=sharing). Save them into the `data` folder. We used `train.npz` and `test.npz`. They consist of data temporally cropped from 10am to 4pm UK time across the entire dataset. You could also use `data_good_sun_2020.npz` and `data_good_sun_2021.npz`, which consist of all samples where the sun elevation is at least 10 degrees. Because these crops produced datasets that could fit in memory, all our dataloaders work in-memory. 22 | 23 | 24 | # Best Submission 25 | Our best submission earned scores exceeding 0.85 on the Climatehack [leaderboard](https://climatehack.ai/compete/leaderboard). It is relatively simple and uses the `fastai` library to pick a base model, optimizer, and learning rate scheduler. After some experimentation, we chose `xse_resnext50_deeper`. We turned it into a UNET and trained it. More info is in the [slides](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing). 26 | 27 | To train: 28 | ```bash 29 | cd best-submission 30 | bash train.sh 31 | ``` 32 | 33 | To submit, first move the trained model `xse_resnext50_deeper.pth` into `best-submission/submission`. 34 | ```bash 35 | cd best-submission 36 | python doxa_cli.py user login 37 | bash submit.sh 38 | ``` 39 | 40 | Also, check out `best-submission/test_and_visualize.ipynb` to test the model and visualize results in a nice animation. This is how we produced the animations found in `figs/model_predictions.gif`. 41 | 42 | # Experiments 43 | We conducted several experiments that showed improvements on a strong baseline. The baseline was OpenClimateFix's skillful nowcasting [repo](https://github.com/openclimatefix/skillful_nowcasting), which is itself an implementation of DeepMind's precipitation forecasting GAN. This baseline is more-or-less copied to `experiments/dgmr-original`. One important difference is that instead of training the GAN, we just train the generator. This worked well for us, and training the full GAN converged much more slowly. This baseline will actually train to a score greater than 0.8 on the Climatehack [leaderboard](https://climatehack.ai/compete/leaderboard). We didn't have time to properly test these experiments on top of our best model, but we suspect they would improve results. The experiments are summarized below: 44 | 45 | Experiment | Description | Results | 46 | --- | --- | --- | 47 | DCT-Trick | Inspired by [this](https://proceedings.neurips.cc/paper/2018/file/7af6266cc52234b5aa339b16695f7fc4-Paper.pdf), we use the DCT to turn 128x128 -> 64x16x16 and IDCT to turn 64x16x16 -> 128x128. This leads to a shallower network that is autoregressive at fewer spatial resolutions. We believe this is the first time this has been done with UNETs. A fast implementation is in `common/utils.py:create_conv_dct_filter` and `common/utils.py:get_idct_filter`. | 1.8-2x speedup, small <0.005 performance drop | 48 | Denoising | We noticed a lot of blocky artifacts in predictions.
These artifacts are reminiscent of JPEG/H.264 compression artifacts. We show a comparison of these artifacts in the [slides](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing). We found a pretrained neural network to fix them. This can definitely be done better, but we show a proof-of-concept. | No performance drop, small visual improvement. The slides have an example. | 49 | CoordConv | Meteorological phenomenon are correlated with geographic coordinates. We add 2 input channels for the geographic coordinates in OSGB form. | +0.0072 MS-SSIM improvement | 50 | Optical Flow | Optical flow does well for the first few timesteps. We add 2 input channels for the optical flow vectors. | +0.0034 MS-SSIM improvement | 51 | 52 | The folder `experiments/climatehack-submission` was used to submit these experiments. 53 | ```bash 54 | cd experiments/climatehack-submission 55 | python doxa_cli.py user login 56 | bash submit.sh 57 | ``` 58 | 59 | Use `experiments/test_and_visualize.ipynb` to test the model and visualize results in a nice animation. 60 | -------------------------------------------------------------------------------- /best-submission/doxa_cli.py: -------------------------------------------------------------------------------- 1 | # NOTE: 2 | # This file downloads the doxa_cli to allow you to upload an agent. 3 | # You do NOT need to edit this file, use it as is. 4 | 5 | import sys 6 | 7 | if sys.version_info[0] != 3: 8 | print("Please run this script using python3") 9 | sys.exit(1) 10 | 11 | import json 12 | import os 13 | import platform 14 | import stat 15 | import subprocess 16 | import tarfile 17 | import urllib.error 18 | import urllib.request 19 | 20 | 21 | # Returns `windows`, `darwin` (macos) or `linux` 22 | def get_os(): 23 | system = platform.system() 24 | 25 | if system == "Linux": 26 | # The exe version works better for WSL 27 | if "microsoft" in platform.platform(): 28 | return "windows" 29 | 30 | return "linux" 31 | elif system == "Windows": 32 | return "windows" 33 | elif system == "Darwin": 34 | return "darwin" 35 | else: 36 | raise Exception(f"Unknown platform {system}") 37 | 38 | 39 | def get_bin_name(): 40 | bin_name = "doxa_cli" 41 | if get_os() == "windows": 42 | bin_name = "doxa_cli.exe" 43 | return bin_name 44 | 45 | 46 | def get_bin_dir(): 47 | return os.path.join(os.path.dirname(__file__), "bin") 48 | 49 | 50 | def get_binary(): 51 | return os.path.join(get_bin_dir(), get_bin_name()) 52 | 53 | 54 | def install_binary(): 55 | match_release = None 56 | try: 57 | match_release = sys.argv[1] 58 | # Arguments are not required 59 | except IndexError: 60 | pass 61 | 62 | REPO_RELEASE_URL = "https://api.github.com/repos/louisdewar/doxa/releases/latest" 63 | try: 64 | f = urllib.request.urlopen(REPO_RELEASE_URL) 65 | except urllib.error.URLError: 66 | print("There was an SSL cert verification error") 67 | print( 68 | 'If you are on a mac and you have recently installed a new version of\ 69 | python then you should navigate to "/Applications/Python {VERSION}/"' 70 | ) 71 | print('Then run a script in that folder called "Install Certificates.command"') 72 | sys.exit(1) 73 | 74 | response = json.loads(f.read()) 75 | 76 | print("Current version tag:", response["tag_name"]) 77 | 78 | assets = [ 79 | asset for asset in response["assets"] if asset["name"].endswith(".tar.gz") 80 | ] 81 | 82 | # Find the release for this OS 83 | match_release = get_os() 84 | try: 85 | asset_choice = next(asset for asset in assets if match_release 
in asset["name"]) 86 | print( 87 | 'Automatically picked {} to download based on match "{}"\n'.format( 88 | asset_choice["name"], match_release 89 | ) 90 | ) 91 | except StopIteration: 92 | print('Couldn\'t find "{}" in releases'.format(match_release)) 93 | sys.exit(1) 94 | 95 | download_url = asset_choice["browser_download_url"] 96 | 97 | # Folder where this script is + bin 98 | bin_dir = get_bin_dir() 99 | 100 | print("Downloading", asset_choice["name"], "to", bin_dir) 101 | print("({})".format(download_url)) 102 | 103 | # Clear bin directory if it exists 104 | if not os.path.exists(bin_dir): 105 | os.mkdir(bin_dir) 106 | 107 | zip_path = os.path.join(bin_dir, asset_choice["name"]) 108 | 109 | # Download zip file 110 | urllib.request.urlretrieve(download_url, zip_path) 111 | 112 | # Open and extract zip file 113 | tar_file = tarfile.open(zip_path) 114 | tar_file.extractall(bin_dir) 115 | tar_file.close() 116 | 117 | # Delete zip file 118 | os.remove(zip_path) 119 | 120 | # Path to the actual binary program (called doxa_cli or doxa_cli.exe) 121 | bin_name = get_bin_name() 122 | binary_path = os.path.join(bin_dir, bin_name) 123 | 124 | if not os.path.exists(binary_path): 125 | print(f"Couldn't find the binary file `{bin_name}` in the bin directory") 126 | print("This probably means that there was a problem with the download") 127 | sys.exit(1) 128 | 129 | if get_os() != "windows": 130 | # Make binary executable 131 | st = os.stat(binary_path) 132 | os.chmod(binary_path, st.st_mode | stat.S_IEXEC) 133 | 134 | # Run help 135 | print("Installed binary\n\n") 136 | 137 | 138 | def run_command(args): 139 | bin_path = get_binary() 140 | 141 | if not os.path.exists(bin_path): 142 | install_binary() 143 | subprocess.call([bin_path] + args) 144 | 145 | 146 | if __name__ == "__main__": 147 | run_command(sys.argv[1:]) 148 | -------------------------------------------------------------------------------- /best-submission/figures/lr_find.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/best-submission/figures/lr_find.png -------------------------------------------------------------------------------- /best-submission/submission/climatehack.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | 7 | 8 | class BaseEvaluator: 9 | def __init__(self) -> None: 10 | self.setup() 11 | 12 | def setup(self): 13 | """Sets up anything required for evaluation, e.g. models.""" 14 | pass 15 | 16 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: 17 | """Makes a prediction for the next two hours of satellite imagery. 18 | 19 | Args: 20 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128) 21 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128) 22 | 23 | Returns: 24 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64) 25 | """ 26 | 27 | raise NotImplementedError( 28 | "You need to extend this class to use your trained model(s)." 29 | ) 30 | 31 | def _get_io_paths(self) -> Tuple[Path, Path]: 32 | """Gets the input and output directory paths from DOXA. 
33 | 34 | Returns: 35 | Tuple[Path, Path]: The input and output paths 36 | """ 37 | try: 38 | return Path(sys.argv[1]), Path(sys.argv[2]) 39 | except IndexError: 40 | raise Exception( 41 | f"Run using: {sys.argv[0]} [input directory] [output directory]" 42 | ) 43 | 44 | def _get_group_path(self) -> str: 45 | """Gets the path for the next group to be processed. 46 | 47 | Raises: 48 | ValueError: An unknown message was received from DOXA. 49 | 50 | Returns: 51 | str: The path of the next group. 52 | """ 53 | 54 | msg = input() 55 | if not msg.startswith("Process "): 56 | raise ValueError(f"Unknown messsage {msg}") 57 | 58 | return msg[8:] 59 | 60 | def _evaluate_group(self, group: dict) -> List[np.ndarray]: 61 | """Evaluates a group of satellite image sequences using 62 | the user-implemented model(s). 63 | 64 | Args: 65 | group (dict): The OSGB and satellite imagery data. 66 | 67 | Returns: 68 | List[np.ndarray]: The predictions. 69 | """ 70 | 71 | return [self.predict(*datum) for datum in zip(group["osgb"], group["data"])] 72 | 73 | def evaluate(self): 74 | """Evaluates the user's model on DOXA. 75 | 76 | Messages are sent and received through stdio. 77 | 78 | The input data is loaded from a directory in groups. 79 | 80 | The predictions are written to another directory in groups. 81 | 82 | Raises: 83 | Exception: An error occurred somewhere. 84 | """ 85 | 86 | print("STARTUP") 87 | input_path, output_path = self._get_io_paths() 88 | 89 | # process test data groups 90 | while True: 91 | # load the data for the group DOXA requests 92 | group_path = self._get_group_path() 93 | group_data = np.load(input_path / group_path) 94 | 95 | # make predictions for this group 96 | try: 97 | predictions = self._evaluate_group(group_data) 98 | except Exception as err: 99 | raise Exception(f"Error while processing {group_path}: {str(err)}") 100 | 101 | # save the output group predictions 102 | np.savez( 103 | output_path / group_path, 104 | data=np.stack(predictions), 105 | ) 106 | print(f"Exported {group_path}") 107 | -------------------------------------------------------------------------------- /best-submission/submission/doxa.yaml: -------------------------------------------------------------------------------- 1 | language: python 2 | entrypoint: evaluate.py 3 | -------------------------------------------------------------------------------- /best-submission/submission/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial 5 | 6 | from climatehack import BaseEvaluator 7 | 8 | # from model import Model 9 | 10 | from fastai.vision.all import create_unet_model, models 11 | from fastai.vision.models.xresnet import * 12 | from fastai.layers import Mish 13 | 14 | 15 | class Evaluator(BaseEvaluator): 16 | def setup(self): 17 | """Sets up anything required for evaluation. 
18 | 19 | In this case, it loads the trained model (in evaluation mode).""" 20 | 21 | arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True) 22 | self.model = create_unet_model( 23 | arch=arch, 24 | img_size=(128, 128), 25 | n_out=24, 26 | pretrained=False, 27 | n_in=12, 28 | self_attention=True, 29 | ).cpu() 30 | 31 | self.model.load_state_dict( 32 | torch.load("xse_resnext50_deeper.pth", map_location="cpu") 33 | ) 34 | self.model.eval() 35 | 36 | self.arcnn = ARCNN(weights).to("cpu").eval() 37 | 38 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: 39 | """Makes a prediction for the next two hours of satellite imagery. 40 | 41 | Args: 42 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128) 43 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128) 44 | 45 | Returns: 46 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64) 47 | """ 48 | 49 | assert coordinates.shape == (2, 128, 128) 50 | assert data.shape == (12, 128, 128) 51 | 52 | with torch.no_grad(): 53 | 54 | inp = torch.from_numpy(data).unsqueeze(0) 55 | preds = self.model(inp) 56 | assert prediction.shape == (24, 64, 64) 57 | 58 | return prediction 59 | 60 | 61 | def main(): 62 | evaluator = Evaluator() 63 | evaluator.evaluate() 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /best-submission/submission/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fastai.vision.all import * 5 | 6 | ######################################### 7 | # Improve this basic model! # 8 | ######################################### 9 | 10 | 11 | class Model(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.layer1 = nn.Linear(in_features=12 * 128 * 128, out_features=256) 16 | self.layer2 = nn.Linear(in_features=256, out_features=256) 17 | self.layer3 = nn.Linear(in_features=256, out_features=24 * 64 * 64) 18 | 19 | def forward(self, features): 20 | x = features.view(-1, 12 * 128 * 128) / 1024.0 21 | x = torch.relu(self.layer1(x)) 22 | x = torch.relu(self.layer2(x)) 23 | x = torch.relu(self.layer3(x)) 24 | 25 | return x.view(-1, 24, 64, 64) * 1024.0 26 | 27 | 28 | class CenterCrop(nn.Module): 29 | def __init__(self, size=(64, 64)): 30 | super().__init__() 31 | 32 | def forward(self, x): 33 | return x[:, :, 32:96, 32:96] 34 | 35 | 36 | class UNet(nn.Module): 37 | def __init__(self, in_channels=2, out_channels=24, weights=None): 38 | super(UNet, self).__init__() 39 | self.model = create_unet_model( 40 | arch=models.resnet50, 41 | img_size=(128, 128), 42 | n_out=in_channels, 43 | pretrained=True, 44 | n_in=out_channels, 45 | self_attention=True, 46 | ) 47 | 48 | if weights: 49 | self.load_state_dict(torch.load(weights)) 50 | self.model.layers.add_module("CenterCrop", CenterCrop()) 51 | 52 | def forward(self, x): 53 | return self.model(x) 54 | 55 | 56 | class EnsembleNet(nn.Module): 57 | def __init__(self, in_channels=2, out_channels=24, weights=None): 58 | super(EnsembleNet, self).__init__() 59 | self.model = create_unet_model( 60 | arch=models.resnet50, 61 | img_size=(128, 128), 62 | n_out=in_channels, 63 | pretrained=True, 64 | n_in=out_channels, 65 | self_attention=True, 66 | ) 67 | 68 | self.models = models 69 | 70 | assert len(weights) == len(self.models) 71 | 72 | for i, model in enumerate(self.models): 73 | model.layers.add_module("CenterCrop", 
CenterCrop()) 74 | self.add_module(f"model_{i}", model) 75 | 76 | if weights: 77 | self.load_state_dict(torch.load(weights)) 78 | self.model.layers.add_module("CenterCrop", CenterCrop()) 79 | 80 | def forward(self, x): 81 | 82 | outputs = [] 83 | for model in self.models: 84 | outputs.append(model(x)) 85 | 86 | return torch.stack(outputs, dim=1).mean(dim=1) 87 | -------------------------------------------------------------------------------- /best-submission/submission/validate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_msssim import MS_SSIM 3 | from torch import from_numpy 4 | 5 | from evaluate import Evaluator 6 | 7 | 8 | def main(): 9 | features = np.load("../climatehack-submission/features.npz") 10 | targets = np.load("../climatehack-submission/targets.npz") 11 | 12 | criterion = MS_SSIM(data_range=1023.0, size_average=True, win_size=3, channel=1) 13 | evaluator = Evaluator() 14 | 15 | scores = [ 16 | criterion( 17 | from_numpy(evaluator.predict(*datum)).view(24, 64, 64).unsqueeze(dim=1), 18 | from_numpy(target).view(24, 64, 64).unsqueeze(dim=1), 19 | ).item() 20 | for *datum, target in zip(features["osgb"], features["data"], targets["data"]) 21 | ] 22 | 23 | print(f"Score: {np.mean(scores)} ({np.std(scores)})") 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /best-submission/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eEou pipefail 4 | 5 | python doxa_cli.py agent upload climatehack ./submission 6 | -------------------------------------------------------------------------------- /best-submission/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 conda run -n climatehack --no-capture-output --live-stream python /home/iyaja/Git/climatehack/train_unet.py 3 | -------------------------------------------------------------------------------- /best-submission/train_timm.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | import torchvision.models as models 4 | from torch.utils.data import DataLoader, ConcatDataset 5 | import xarray as xr 6 | import wandb 7 | 8 | # custom 9 | from utils.loss import MS_SSIMLoss 10 | from utils.data import ClimatehackDataset, CustomDataset 11 | from utils.models import create_timm_model, TimmUnet 12 | 13 | # %% 14 | from fastai.vision.all import * 15 | from fastai.callback.wandb import * 16 | from fastai.callback.tracker import * 17 | from fastai.distributed import * 18 | from fastai.vision.models.xresnet import * 19 | from fastai.optimizer import ranger 20 | from fastai.layers import Mish 21 | 22 | # %% 23 | ARCH = "efficientnetv2_xl" 24 | BATCH_SIZE = 32 25 | FORECAST = 24 26 | 27 | wandb.init(project="climatehack", group=ARCH) 28 | # %% 29 | # train_ds = ClimatehackDataset("data/data.npz", with_osgb=True) 30 | train_ds = CustomDataset("../data/train.npz") 31 | valid_ds = CustomDataset("../data/test.npz") 32 | 33 | train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) 34 | valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False) 35 | 36 | dls = DataLoaders(train_loader, valid_loader, device=torch.device("cuda")) 37 | # %% 38 | criterion = MS_SSIMLoss(channels=FORECAST, crop=True) 39 | 40 | # %% 41 | # arch = 
partial(xse_resnext50_deeper, act_cls=Mish, sa=True) 42 | # model = create_unet_model( 43 | # arch=arch, 44 | # img_size=(128, 128), 45 | # n_out=24, 46 | # pretrained=False, 47 | # n_in=train_ds[0][0].shape[0], 48 | # self_attention=True, 49 | # ) 50 | 51 | model = TimmUnet( 52 | encoder=ARCH, 53 | # img_size=(128, 128), 54 | n_in=train_ds[0][0].shape[0], 55 | n_out=train_ds[0][1].shape[0], 56 | pretrained=True, 57 | # bottleneck='conv', 58 | # self_attention=True, 59 | ) 60 | 61 | # %% 62 | callbacks = [ 63 | SaveModelCallback(monitor="train_loss", fname=ARCH), 64 | # ReduceLROnPlateau(monitor="train_loss"), 65 | # EarlyStoppingCallback(monitor="val_loss", patience=10, mode="min"), 66 | WandbCallback(), 67 | ] 68 | learn = Learner( 69 | dls, 70 | model, 71 | loss_func=criterion, 72 | cbs=callbacks, 73 | model_dir="checkpoints", 74 | opt_func=ranger, 75 | ) 76 | 77 | # %% 78 | learn.fit_flat_cos(100, 1e-3) 79 | -------------------------------------------------------------------------------- /best-submission/train_unet.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | import torchvision.models as models 4 | from torch.utils.data import DataLoader, ConcatDataset 5 | import xarray as xr 6 | import wandb 7 | 8 | # custom 9 | from utils.loss import MS_SSIMLoss 10 | from utils.data import ClimatehackDataset, CustomDataset 11 | 12 | # %% 13 | from fastai.vision.all import * 14 | from fastai.callback.wandb import * 15 | from fastai.callback.tracker import * 16 | from fastai.distributed import * 17 | from fastai.vision.models.xresnet import * 18 | from fastai.optimizer import ranger 19 | from fastai.layers import Mish 20 | 21 | # %% 22 | NAME = "xse_resnext50_deeper" 23 | BATCH_SIZE = 32 24 | FORECAST = 24 25 | 26 | wandb.init(project="climatehack", group=NAME) 27 | # %% 28 | # train_ds = ClimatehackDataset("data/data.npz", with_osgb=True) 29 | train_ds = CustomDataset("../data/train.npz") 30 | valid_ds = CustomDataset("../data/test.npz") 31 | 32 | train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) 33 | valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False) 34 | 35 | dls = DataLoaders(train_loader, valid_loader, device=torch.device("cuda")) 36 | # %% 37 | criterion = MS_SSIMLoss(channels=FORECAST, crop=True) 38 | 39 | # %% 40 | arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True) 41 | model = create_unet_model( 42 | arch=arch, 43 | img_size=(128, 128), 44 | n_out=24, 45 | pretrained=False, 46 | n_in=train_ds[0][0].shape[0], 47 | self_attention=True, 48 | ) 49 | 50 | # %% 51 | callbacks = [ 52 | SaveModelCallback(monitor="train_loss", fname=NAME), 53 | ReduceLROnPlateau(monitor="train_loss", factor=2), 54 | # EarlyStoppingCallback(monitor="val_loss", patience=10, mode="min"), 55 | WandbCallback(), 56 | ] 57 | learn = Learner( 58 | dls, 59 | model, 60 | loss_func=criterion, 61 | cbs=callbacks, 62 | model_dir="checkpoints", 63 | opt_func=ranger, 64 | ) 65 | 66 | # %% 67 | learn.fit_flat_cos(100, 1e-3) 68 | -------------------------------------------------------------------------------- /best-submission/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/best-submission/utils/__init__.py -------------------------------------------------------------------------------- /best-submission/utils/loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_msssim import MS_SSIM 4 | 5 | 6 | class MS_SSIMLoss(nn.Module): 7 | """Multi-Scale SSIM Loss""" 8 | 9 | def __init__(self, channels=1, crop=True, **kwargs): 10 | """ 11 | Initialize 12 | Args: 13 | convert_range: Convert input from -1,1 to 0,1 range 14 | **kwargs: Kwargs to pass through to MS_SSIM 15 | """ 16 | super(MS_SSIMLoss, self).__init__() 17 | self.crop = crop 18 | self.ssim_module = MS_SSIM( 19 | data_range=1023.0, size_average=True, win_size=3, channel=channels, **kwargs 20 | ) 21 | 22 | def forward(self, x: torch.Tensor, y: torch.Tensor): 23 | """ 24 | Forward method 25 | Args: 26 | x: tensor one 27 | y: tensor two 28 | Returns: multi-scale SSIM Loss 29 | """ 30 | if self.crop: 31 | return 1.0 - self.ssim_module(x[:, :, 32:96, 32:96], y[:, :, 32:96, 32:96]) 32 | else: 33 | return 1.0 - self.ssim_module(x, y) 34 | -------------------------------------------------------------------------------- /best-submission/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, time, timedelta 2 | 3 | import numpy as np 4 | import xarray as xr 5 | from numpy import float32 6 | from tqdm import tqdm 7 | 8 | SATELLITE_ZARR_PATH = "data/eumetsat_seviri_hrv_uk.zarr" 9 | 10 | 11 | def main(): 12 | dataset = xr.open_dataset( 13 | SATELLITE_ZARR_PATH, 14 | engine="zarr", 15 | chunks="auto", 16 | ) 17 | 18 | times = dataset.get_index("time") 19 | min_date = times[0].date() 20 | max_date = times[-1].date() 21 | 22 | start_time = time(9, 0) 23 | end_time = time(16, 0) 24 | 25 | data = [] 26 | date = min_date 27 | print(min_date, max_date, (max_date - min_date).days) 28 | with tqdm(total=(max_date - min_date).days) as pbar: 29 | 30 | # if you only want to preprocess a certain number of days, swap the while loop for the for loop below! 31 | # for _ in range(10): 32 | 33 | while date <= max_date: 34 | print(date) 35 | selection = ( 36 | dataset["data"].sel( 37 | time=slice( 38 | datetime.combine(date, start_time), 39 | datetime.combine(date, end_time), 40 | ), 41 | ) 42 | # comment out the .isel if you want the whole image 43 | .isel( 44 | x=slice(550, 950), 45 | y=slice(375, 700), 46 | ) 47 | ) 48 | 49 | if selection.shape == (85, 325, 400): 50 | data.append(selection.astype(float32).values) 51 | 52 | date += timedelta(days=1) 53 | pbar.update(1) 54 | 55 | clipped_data = np.clip(np.stack(data), 0.0, 1023.0) 56 | 57 | x_osgb = ( 58 | dataset["x_osgb"] 59 | # comment out the .isel if you want the whole image 60 | .isel( 61 | x=slice(550, 950), 62 | y=slice(375, 700), 63 | ).values.astype(float32) 64 | ) 65 | 66 | y_osgb = ( 67 | dataset["y_osgb"] 68 | # comment out the .isel if you want the whole image 69 | .isel( 70 | x=slice(550, 950), 71 | y=slice(375, 700), 72 | ).values.astype(float32) 73 | ) 74 | 75 | # you can also use np.savez_compressed if you'd rather compress everything! 
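# For example, the compressed variant mentioned above takes the same arguments as the
# np.savez call below and trades a slower save for a smaller file on disk:
# np.savez_compressed("data/sample", x_osgb=x_osgb, y_osgb=y_osgb, data=clipped_data)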
76 | np.savez("data/sample", x_osgb=x_osgb, y_osgb=y_osgb, data=clipped_data) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /best-submission/validate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd submissionv2 3 | conda run -n climatehack --no-capture-output --live-stream python /home/iyaja/Git/climatehack/submissionv2/validate.py 4 | conda run -n climatehack --no-capture-output --live-stream python /home/iyaja/Git/climatehack/submissionv2/validate_timestep.py 5 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/common/__init__.py -------------------------------------------------------------------------------- /common/checkpointing.py: -------------------------------------------------------------------------------- 1 | """ 2 | I rolled my own checkpointing for no good reason. Oops. 3 | """ 4 | import os 5 | import pathlib 6 | import torch 7 | 8 | 9 | class Checkpointer: 10 | def __init__(self, save_folder: pathlib.Path): 11 | self.save_folder = save_folder.absolute() 12 | self.hist = save_folder / "hist.txt" 13 | self.hist.touch() 14 | 15 | def get_best_model_info(self): 16 | """ 17 | Returns info on the best model so far. 18 | """ 19 | files = self.hist.read_text().split("\n") 20 | actual = [f for f in files if f.endswith(".pt")] 21 | best_fname = None 22 | best_loss = float("inf") 23 | best_epoch = None 24 | for fname in actual: 25 | if "batch" in fname: 26 | continue 27 | f = fname[:-3] # remove .pt 28 | _, epoch_str, loss_str = f.split("_") 29 | _, epoch = epoch_str.split("=") 30 | _, loss = loss_str.split("=") 31 | epoch = int(epoch) 32 | loss = float(loss) 33 | if loss < best_loss: 34 | best_loss = loss 35 | best_epoch = epoch 36 | best_fname = fname 37 | if best_fname is None: 38 | return None, None, None 39 | return self.save_folder / best_fname, best_epoch, best_loss 40 | 41 | def save_checkpoint(self, model, optimizer, epoch, avg_loss): 42 | """ 43 | Save the model and optimizer state if the loss has decreased. 44 | """ 45 | # see if we should save weights 46 | best_fpath, best_epoch, best_loss = self.get_best_model_info() 47 | if best_fpath is None or avg_loss < best_loss: 48 | print(f"Loss decreased from {best_loss} to {avg_loss}, saving...") 49 | checkpoint = { 50 | "model": model.state_dict(), 51 | "optimizer": optimizer.state_dict(), 52 | } 53 | avg_loss = round(avg_loss, 4) 54 | torch.save( 55 | checkpoint, 56 | self.save_folder / f"checkpoint_epochs={epoch}_loss={avg_loss}.pt", 57 | ) 58 | if best_fpath is not None: 59 | print("Deleting old best:", best_fpath) 60 | os.remove(best_fpath) 61 | 62 | # save to history file regardless 63 | with self.hist.open("a") as f: 64 | f.write(f"checkpoint_epochs={epoch}_loss={avg_loss}.pt\n") 65 | 66 | def load_checkpoint(self): 67 | """ 68 | Loads the best model and optimizer states. 
69 | """ 70 | best_fpath, best_epoch, best_loss = self.get_best_model_info() 71 | if best_fpath is None: 72 | return None, None 73 | checkpoint = torch.load(best_fpath) 74 | return checkpoint, best_epoch 75 | -------------------------------------------------------------------------------- /common/denoiser.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple denoiser that uses a pretrained neural network to fix some of the block artifacts that were found in model predictions. 3 | """ 4 | import os 5 | import scipy.io 6 | import pathlib 7 | import torch.nn as nn 8 | import torch 9 | 10 | mydir = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) 11 | 12 | 13 | class ARCNN(nn.Module): 14 | def __init__(self, weight): 15 | super().__init__() 16 | 17 | # PyTorch's Conv2D uses zero-padding while the matlab code uses replicate 18 | # So we need to use separate padding modules 19 | self.conv1 = nn.Conv2d(1, 64, kernel_size=9) 20 | self.conv2 = nn.Conv2d(64, 32, kernel_size=7) 21 | self.conv22 = nn.Conv2d(32, 16, kernel_size=1) 22 | self.conv3 = nn.Conv2d(16, 1, kernel_size=5) 23 | 24 | self.pad2 = nn.ReplicationPad2d(2) 25 | self.pad3 = nn.ReplicationPad2d(3) 26 | self.pad4 = nn.ReplicationPad2d(4) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | # Load the weights from the weight dict 31 | self.conv1.weight.data = torch.from_numpy( 32 | weight["weights_conv1"] 33 | .transpose(2, 0, 1) 34 | .reshape(64, 1, 9, 9) 35 | .transpose(0, 1, 3, 2) 36 | ).float() 37 | self.conv1.bias.data = torch.from_numpy( 38 | weight["biases_conv1"].reshape(64) 39 | ).float() 40 | 41 | self.conv2.weight.data = torch.from_numpy( 42 | weight["weights_conv2"] 43 | .transpose(2, 0, 1) 44 | .reshape(32, 64, 7, 7) 45 | .transpose(0, 1, 3, 2) 46 | ).float() 47 | self.conv2.bias.data = torch.from_numpy( 48 | weight["biases_conv2"].reshape(32) 49 | ).float() 50 | 51 | self.conv22.weight.data = torch.from_numpy( 52 | weight["weights_conv22"] 53 | .transpose(2, 0, 1) 54 | .reshape(16, 32, 1, 1) 55 | .transpose(0, 1, 3, 2) 56 | ).float() 57 | self.conv22.bias.data = torch.from_numpy( 58 | weight["biases_conv22"].reshape(16) 59 | ).float() 60 | 61 | self.conv3.weight.data = torch.from_numpy( 62 | weight["weights_conv3"].reshape(1, 16, 5, 5).transpose(0, 1, 3, 2) 63 | ).float() 64 | self.conv3.bias.data = torch.from_numpy( 65 | weight["biases_conv3"].reshape(1) 66 | ).float() 67 | 68 | def forward(self, x): 69 | x = self.pad4(x) 70 | x = self.relu(self.conv1(x)) 71 | 72 | x = self.pad3(x) 73 | x = self.relu(self.conv2(x)) 74 | x = self.relu(self.conv22(x)) 75 | 76 | x = self.pad2(x) 77 | x = self.conv3(x) 78 | 79 | return x 80 | 81 | 82 | class Denoiser: 83 | """ 84 | A neural network denoiser. 
85 | """ 86 | 87 | def __init__(self, device): 88 | weights = scipy.io.loadmat(mydir / "q20.mat") 89 | self.arcnn = ARCNN(weights).to(device).eval() 90 | 91 | def denoise(self, x): 92 | with torch.no_grad(): 93 | b, t, h, w = x.shape 94 | x = x.reshape(-1, 1, 128, 128) 95 | x = self.arcnn(x) 96 | x = x.reshape(b, t, h, w) 97 | return x 98 | -------------------------------------------------------------------------------- /common/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_msssim import MS_SSIM 4 | 5 | 6 | class MS_SSIMLoss(nn.Module): 7 | """Multi-Scale SSIM Loss""" 8 | def __init__(self, data_range, channels=1, **kwargs): 9 | """ 10 | Initialize 11 | Args: 12 | convert_range: Convert input from -1,1 to 0,1 range 13 | **kwargs: Kwargs to pass through to MS_SSIM 14 | """ 15 | super(MS_SSIMLoss, self).__init__() 16 | self.ssim_module = MS_SSIM( 17 | data_range=data_range, size_average=True, win_size=3, channel=channels, **kwargs 18 | ) 19 | 20 | def forward(self, x: torch.Tensor, y: torch.Tensor): 21 | """ 22 | Forward method 23 | Args: 24 | x: tensor one 25 | y: tensor two 26 | Returns: multi-scale SSIM Loss 27 | """ 28 | return 1.0 - self.ssim_module(x, y) 29 | -------------------------------------------------------------------------------- /common/q20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/common/q20.mat -------------------------------------------------------------------------------- /data/download_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ee483ada", 6 | "metadata": {}, 7 | "source": [ 8 | " Creates `train.npz` and `test.npz` by making a simle temporal crop from 10am to 4pm across the full dataset." 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "3a64f122", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import xarray as xr\n", 19 | "import numpy as np\n", 20 | "import pathlib\n", 21 | "import datetime\n", 22 | "import pandas as pd\n", 23 | "import tqdm" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "273c7e8b", 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "\n", 37 | "Dimensions: (time: 173624, y: 891, x: 1843)\n", 38 | "Coordinates:\n", 39 | " * time (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00\n", 40 | " * x (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06\n", 41 | " x_osgb (y, x) float32 dask.array\n", 42 | " * y (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 
5.087e+06 5.088e+06\n", 43 | " y_osgb (y, x) float32 dask.array\n", 44 | "Data variables:\n", 45 | " data (time, y, x) int16 dask.array\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "SATELLITE_ZARR_PATH = \"gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr\"\n", 51 | "\n", 52 | "dataset = xr.open_dataset(\n", 53 | " SATELLITE_ZARR_PATH, \n", 54 | " engine=\"zarr\",\n", 55 | " chunks=\"auto\", # Load the data as a Dask array\n", 56 | ")\n", 57 | "\n", 58 | "print(dataset)\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "id": "1ea7fb67", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "def get_day_slice(date):\n", 69 | " data_slice = dataset.loc[\n", 70 | " {\n", 71 | " # 10am to 4pm\n", 72 | " \"time\": slice(\n", 73 | " date + datetime.timedelta(hours=10),\n", 74 | " date + datetime.timedelta(hours=16),\n", 75 | " )\n", 76 | " }\n", 77 | " ].isel(\n", 78 | " x=slice(550, 950),\n", 79 | " y=slice(375, 700),\n", 80 | " )\n", 81 | " \n", 82 | " # sometimes there is no data\n", 83 | " if len(data_slice.time) == 0:\n", 84 | " return None\n", 85 | " return data_slice" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "id": "853691db", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# it might be worth it to do this in batches of one year... a full download will take at least\n", 96 | "# 30 minutes if not hours\n", 97 | "start_date = datetime.datetime(2020, 1, 1)\n", 98 | "end_date = datetime.datetime(2021, 12, 31)\n", 99 | "\n", 100 | "cur = start_date\n", 101 | "days_to_get = []\n", 102 | "while cur != end_date + datetime.timedelta(days=1):\n", 103 | " days_to_get.append(cur)\n", 104 | " cur = cur + datetime.timedelta(days=1)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "id": "2133c65e", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stderr", 115 | "output_type": "stream", 116 | "text": [ 117 | "100%|██████████| 731/731 [00:00<00:00, 817.82it/s]\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "slices = []\n", 123 | "for date in tqdm.tqdm(days_to_get):\n", 124 | " slc = get_day_slice(date)\n", 125 | " if slc is None:\n", 126 | " continue\n", 127 | " slices.append(slc)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "id": "0c36633f", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "606" 140 | ] 141 | }, 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "len(slices)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "98b4495b", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "combined = xr.concat(slices, dim='time')" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "id": "e52d267f", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "# takes a while\n", 169 | "times = combined['time'].to_numpy()\n", 170 | "x = combined['x'].to_numpy()\n", 171 | "x_osgb = combined['x_osgb'].to_numpy()\n", 172 | "y = combined['y'].to_numpy()\n", 173 | "y_osgb = combined['y_osgb'].to_numpy()\n", 174 | "%time data = combined['data'].to_numpy()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "a2f4d989", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | 
"times.shape, data.shape\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "0c797430", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# due to some weird things in how we originally made `train.npz` and `test.npz`,\n", 195 | "# this will not be equivalent to the datasets we linked in the drive. But they\n", 196 | "# should still work.\n", 197 | "test_days = 30\n", 198 | "\n", 199 | "np.random.seed(7)\n", 200 | "test_dates = np.random.choice(days_to_get, size=test_days, replace=False)\n", 201 | "\n", 202 | "test_indices = []\n", 203 | "train_indices = []\n", 204 | "for i, t in enumerate(times):\n", 205 | " d = pd.Timestamp(t).date()\n", 206 | " if d in test_dates:\n", 207 | " test_indices.append(i)\n", 208 | " else:\n", 209 | " train_indices.append(i)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "8ffee4e0", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "test_data = data[test_indices]\n", 220 | "test_times = times[test_indices]\n", 221 | "\n", 222 | "train_data = data[train_indices]\n", 223 | "train_times = times[train_indices]" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "4d1fa9e8", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# save data\n", 234 | "p = pathlib.Path(f'train.npz')\n", 235 | "if p.exists():\n", 236 | " raise ValueError(f'Path {p} already exists!')\n", 237 | "\n", 238 | "np.savez(\n", 239 | " p,\n", 240 | " times=train_times,\n", 241 | " data=train_data,\n", 242 | ")\n", 243 | "\n", 244 | "p = pathlib.Path(f'test.npz')\n", 245 | "if p.exists():\n", 246 | " raise ValueError(f'Path {p} already exists!')\n", 247 | "\n", 248 | "np.savez(\n", 249 | " p,\n", 250 | " times=test_times,\n", 251 | " data=test_data,\n", 252 | ")\n", 253 | "\n", 254 | "p = pathlib.Path(f'coords.npz')\n", 255 | "if p.exists():\n", 256 | " raise ValueError(f'Path {p} already exists!')\n", 257 | "\n", 258 | "np.savez(\n", 259 | " p,\n", 260 | " x=x,\n", 261 | " x_osgb=x_osgb,\n", 262 | " y=y,\n", 263 | " y_osgb=y_osgb,\n", 264 | ")" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "id": "e270bb59", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "climatehack", 279 | "language": "python", 280 | "name": "climatehack" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.9.7" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 5 297 | } 298 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: climatehack 2 | channels: 3 | - pvlib 4 | - pytorch 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - aiohttp=3.8.1=py39h7f8727e_1 11 | - aiosignal=1.2.0=pyhd3eb1b0_0 12 | - asciitree=0.3.3=py_2 13 | - asttokens=2.0.5=pyhd3eb1b0_0 14 | - async-timeout=4.0.1=pyhd3eb1b0_0 15 | - attrs=21.4.0=pyhd3eb1b0_0 16 | - backcall=0.2.0=pyhd3eb1b0_0 17 | - blas=1.0=mkl 18 | - blinker=1.4=py39h06a4308_0 19 | - bokeh=2.4.2=py39h06a4308_0 20 | 
- bottleneck=1.3.4=py39hce1f21e_0 21 | - brotli=1.0.9=he6710b0_2 22 | - brotlipy=0.7.0=py39h27cfd23_1003 23 | - bzip2=1.0.8=h7b6447c_0 24 | - c-ares=1.18.1=h7f8727e_0 25 | - ca-certificates=2022.3.18=h06a4308_0 26 | - cachetools=4.2.2=pyhd3eb1b0_0 27 | - cartopy=0.18.0=py39hc576cba_1 28 | - certifi=2021.10.8=py39h06a4308_2 29 | - cffi=1.15.0=py39hd667e15_1 30 | - cftime=1.5.1.1=py39hce1f21e_0 31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 32 | - click=8.0.4=py39h06a4308_0 33 | - cloudpickle=2.0.0=pyhd3eb1b0_0 34 | - cryptography=3.4.8=py39hd23ed53_0 35 | - cudatoolkit=11.3.1=h2bc3f7f_2 36 | - curl=7.80.0=h7f8727e_0 37 | - cycler=0.11.0=pyhd3eb1b0_0 38 | - cytoolz=0.11.0=py39h27cfd23_0 39 | - dask=2022.2.1=pyhd3eb1b0_0 40 | - dask-core=2022.2.1=pyhd3eb1b0_0 41 | - dataclasses=0.8=pyh6d0b6a4_7 42 | - dbus=1.13.18=hb2f20db_0 43 | - debugpy=1.5.1=py39h295c915_0 44 | - decorator=5.1.1=pyhd3eb1b0_0 45 | - distributed=2022.2.1=pyhd3eb1b0_0 46 | - entrypoints=0.3=py39h06a4308_0 47 | - ephem=4.1.2=py39h7f8727e_0 48 | - executing=0.8.3=pyhd3eb1b0_0 49 | - expat=2.4.4=h295c915_0 50 | - fasteners=0.16.3=pyhd3eb1b0_0 51 | - ffmpeg=4.3=hf484d3e_0 52 | - fontconfig=2.13.1=h6c09931_0 53 | - fonttools=4.25.0=pyhd3eb1b0_0 54 | - freetype=2.11.0=h70c0345_0 55 | - frozenlist=1.2.0=py39h7f8727e_0 56 | - fsspec=2022.2.0=pyhd3eb1b0_0 57 | - gcsfs=2022.2.0=pyhd8ed1ab_0 58 | - geos=3.8.0=he6710b0_0 59 | - giflib=5.2.1=h7b6447c_0 60 | - glib=2.69.1=h4ff587b_1 61 | - gmp=6.2.1=h2531618_2 62 | - gnutls=3.6.15=he1e5248_0 63 | - google-api-core=2.2.2=pyhd3eb1b0_0 64 | - google-auth=2.6.0=pyhd3eb1b0_0 65 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 66 | - google-cloud-core=2.2.2=pyhd3eb1b0_0 67 | - google-cloud-storage=2.1.0=pyh6c4a22f_0 68 | - google-crc32c=1.1.2=py39h27cfd23_0 69 | - google-resumable-media=1.3.1=pyhd3eb1b0_1 70 | - googleapis-common-protos=1.53.0=py39h06a4308_0 71 | - grpcio=1.42.0=py39hce63b2e_0 72 | - gst-plugins-base=1.14.0=h8213a91_2 73 | - gstreamer=1.14.0=h28cd5cc_2 74 | - h5py=3.6.0=py39ha0f2276_0 75 | - hdf4=4.2.13=h3ca952b_2 76 | - hdf5=1.10.6=hb1b8bf9_0 77 | - heapdict=1.0.1=pyhd3eb1b0_0 78 | - icu=58.2=he6710b0_3 79 | - idna=3.3=pyhd3eb1b0_0 80 | - intel-openmp=2021.4.0=h06a4308_3561 81 | - ipykernel=6.9.1=py39h06a4308_0 82 | - ipython=8.1.1=py39h06a4308_0 83 | - jedi=0.18.1=py39h06a4308_1 84 | - jinja2=3.0.3=pyhd3eb1b0_0 85 | - jpeg=9d=h7f8727e_0 86 | - jupyter_client=7.1.2=pyhd3eb1b0_0 87 | - jupyter_core=4.9.2=py39h06a4308_0 88 | - kiwisolver=1.3.2=py39h295c915_0 89 | - krb5=1.19.2=hac12032_0 90 | - lame=3.100=h7b6447c_0 91 | - lcms2=2.12=h3be6417_0 92 | - ld_impl_linux-64=2.35.1=h7274673_9 93 | - libcrc32c=1.1.1=he6710b0_2 94 | - libcurl=7.80.0=h0b77cf5_0 95 | - libedit=3.1.20210910=h7f8727e_0 96 | - libev=4.33=h7f8727e_1 97 | - libffi=3.3=he6710b0_2 98 | - libgcc-ng=9.3.0=h5101ec6_17 99 | - libgfortran-ng=7.5.0=ha8ba4b0_17 100 | - libgfortran4=7.5.0=ha8ba4b0_17 101 | - libgomp=9.3.0=h5101ec6_17 102 | - libiconv=1.15=h63c8f33_5 103 | - libidn2=2.3.2=h7f8727e_0 104 | - libllvm11=11.1.0=h3826bc1_1 105 | - libnetcdf=4.8.1=h42ceab0_1 106 | - libnghttp2=1.46.0=hce63b2e_0 107 | - libpng=1.6.37=hbc83047_0 108 | - libprotobuf=3.19.1=h4ff587b_0 109 | - libsodium=1.0.18=h7b6447c_0 110 | - libssh2=1.9.0=h1ba5d50_1 111 | - libstdcxx-ng=9.3.0=hd4cf53a_17 112 | - libtasn1=4.16.0=h27cfd23_0 113 | - libtiff=4.2.0=h85742a9_0 114 | - libunistring=0.9.10=h27cfd23_0 115 | - libuuid=1.0.3=h7f8727e_2 116 | - libuv=1.40.0=h7b6447c_0 117 | - libwebp=1.2.2=h55f646e_0 118 | - libwebp-base=1.2.2=h7f8727e_0 119 | - 
libxcb=1.14=h7b6447c_0 120 | - libxml2=2.9.12=h03d6c58_0 121 | - libzip=1.5.1=h8d318fa_1003 122 | - llvmlite=0.38.0=py39h4ff587b_0 123 | - locket=0.2.1=py39h06a4308_2 124 | - lz4-c=1.9.3=h295c915_1 125 | - markupsafe=2.0.1=py39h27cfd23_0 126 | - matplotlib=3.5.1=py39h06a4308_1 127 | - matplotlib-base=3.5.1=py39ha18d171_1 128 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 129 | - mkl=2021.4.0=h06a4308_640 130 | - mkl-service=2.4.0=py39h7f8727e_0 131 | - mkl_fft=1.3.1=py39hd3c417c_0 132 | - mkl_random=1.2.2=py39h51133e4_0 133 | - msgpack-python=1.0.2=py39hff7bd54_1 134 | - multidict=5.2.0=py39h7f8727e_2 135 | - munkres=1.1.4=py_0 136 | - ncurses=6.3=h7f8727e_2 137 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 138 | - netcdf4=1.5.7=py39ha0f2276_1 139 | - nettle=3.7.3=hbbd107a_1 140 | - numba=0.55.1=py39h51133e4_0 141 | - numcodecs=0.9.1=py39h295c915_0 142 | - numexpr=2.8.1=py39h6abb31d_0 143 | - numpy=1.21.2=py39h20f2e39_0 144 | - numpy-base=1.21.2=py39h79a1101_0 145 | - oauthlib=3.2.0=pyhd3eb1b0_0 146 | - openh264=2.1.1=h4ff587b_0 147 | - openssl=1.1.1n=h7f8727e_0 148 | - packaging=21.3=pyhd3eb1b0_0 149 | - pandas=1.4.1=py39h295c915_1 150 | - parso=0.8.3=pyhd3eb1b0_0 151 | - partd=1.2.0=pyhd3eb1b0_1 152 | - patsy=0.5.2=py39h06a4308_1 153 | - pcre=8.45=h295c915_0 154 | - pexpect=4.8.0=pyhd3eb1b0_3 155 | - pickleshare=0.7.5=pyhd3eb1b0_1003 156 | - pillow=9.0.1=py39h22f2fdc_0 157 | - pip=21.2.4=py39h06a4308_0 158 | - proj=7.0.1=h59a7b90_1 159 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 160 | - protobuf=3.19.1=py39h295c915_0 161 | - psutil=5.8.0=py39h27cfd23_1 162 | - ptyprocess=0.7.0=pyhd3eb1b0_2 163 | - pure_eval=0.2.2=pyhd3eb1b0_0 164 | - pvlib=0.9.0=py_1 165 | - pyasn1=0.4.8=pyhd3eb1b0_0 166 | - pyasn1-modules=0.2.8=py_0 167 | - pycparser=2.21=pyhd3eb1b0_0 168 | - pygments=2.11.2=pyhd3eb1b0_0 169 | - pyjwt=2.1.0=py39h06a4308_0 170 | - pyopenssl=20.0.1=pyhd8ed1ab_0 171 | - pyparsing=3.0.4=pyhd3eb1b0_0 172 | - pyqt=5.9.2=py39h2531618_6 173 | - pyshp=2.1.3=pyhd3eb1b0_0 174 | - pysocks=1.7.1=py39h06a4308_0 175 | - python=3.9.7=h12debd9_1 176 | - python-dateutil=2.8.2=pyhd3eb1b0_0 177 | - pytorch-mutex=1.0=cuda 178 | - pytz=2021.3=pyhd3eb1b0_0 179 | - pyyaml=6.0=py39h7f8727e_1 180 | - pyzmq=22.3.0=py39h295c915_2 181 | - qt=5.9.7=h5867ecd_1 182 | - readline=8.1.2=h7f8727e_1 183 | - requests=2.27.1=pyhd3eb1b0_0 184 | - requests-oauthlib=1.3.0=py_0 185 | - rsa=4.7.2=pyhd3eb1b0_1 186 | - scipy=1.7.3=py39hc147768_0 187 | - shapely=1.7.1=py39h1728cc4_0 188 | - sip=4.19.13=py39h295c915_0 189 | - six=1.16.0=pyhd3eb1b0_1 190 | - sortedcontainers=2.4.0=pyhd3eb1b0_0 191 | - sqlite=3.38.0=hc218d9a_0 192 | - stack_data=0.2.0=pyhd3eb1b0_0 193 | - statsmodels=0.13.2=py39h7f8727e_0 194 | - tbb=2021.5.0=hd09550d_0 195 | - tblib=1.7.0=pyhd3eb1b0_0 196 | - tk=8.6.11=h1ccaba5_0 197 | - toolz=0.11.2=pyhd3eb1b0_0 198 | - torchaudio=0.11.0=py39_cu113 199 | - tornado=6.1=py39h27cfd23_0 200 | - traitlets=5.1.1=pyhd3eb1b0_0 201 | - typing-extensions=4.1.1=hd3eb1b0_0 202 | - typing_extensions=4.1.1=pyh06a4308_0 203 | - tzdata=2021e=hda174b7_0 204 | - urllib3=1.26.8=pyhd3eb1b0_0 205 | - wcwidth=0.2.5=pyhd3eb1b0_0 206 | - wheel=0.37.1=pyhd3eb1b0_0 207 | - xarray=0.20.1=pyhd3eb1b0_1 208 | - xz=5.2.5=h7b6447c_0 209 | - yaml=0.2.5=h7b6447c_0 210 | - yarl=1.6.3=py39h27cfd23_0 211 | - zarr=2.8.1=pyhd3eb1b0_0 212 | - zeromq=4.3.4=h2531618_0 213 | - zict=2.0.0=pyhd3eb1b0_0 214 | - zlib=1.2.11=h7f8727e_4 215 | - zstd=1.4.9=haebb681_0 216 | - pip: 217 | - absl-py==1.0.0 218 | - black==22.1.0 219 | - blis==0.7.7 220 | - catalogue==2.0.7 221 | - 
cymem==2.0.6 222 | - einops==0.4.1 223 | - fastai==2.5.3 224 | - fastcore==1.3.29 225 | - fastdownload==0.0.5 226 | - fastprogress==1.0.2 227 | - filelock==3.6.0 228 | - future==0.18.2 229 | - huggingface-hub==0.4.0 230 | - hyperopt==0.2.7 231 | - importlib-metadata==4.11.3 232 | - joblib==1.1.0 233 | - langcodes==3.3.0 234 | - markdown==3.3.6 235 | - murmurhash==1.0.6 236 | - mypy-extensions==0.4.3 237 | - networkx==2.7.1 238 | - pathspec==0.9.0 239 | - pathy==0.6.1 240 | - platformdirs==2.5.1 241 | - preshed==3.0.6 242 | - py4j==0.10.9.5 243 | - pydantic==1.8.2 244 | - pydeprecate==0.3.1 245 | - pyproj==3.3.0 246 | - pytorch-lightning==1.5.10 247 | - pytorch-msssim==0.2.1 248 | - scikit-learn==1.0.2 249 | - setuptools==59.5.0 250 | - smart-open==5.2.1 251 | - spacy==3.2.3 252 | - spacy-legacy==3.0.9 253 | - spacy-loggers==1.0.1 254 | - srsly==2.4.2 255 | - tensorboard==2.8.0 256 | - tensorboard-data-server==0.6.1 257 | - tensorboard-plugin-wit==1.8.1 258 | - thinc==8.0.15 259 | - threadpoolctl==3.1.0 260 | - tomli==2.0.1 261 | - torch==1.10.2 262 | - torchmetrics==0.7.3 263 | - torchvision==0.11.3 264 | - tqdm==4.63.1 265 | - typer==0.4.0 266 | - wasabi==0.9.0 267 | - werkzeug==2.0.3 268 | - zipp==3.7.0 269 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/doxa_cli.py: -------------------------------------------------------------------------------- 1 | # NOTE: 2 | # This file downloads the doxa_cli to allow you to upload an agent. 3 | # You do NOT need to edit this file, use it as is. 
4 | 5 | import sys 6 | 7 | if sys.version_info[0] != 3: 8 | print("Please run this script using python3") 9 | sys.exit(1) 10 | 11 | import json 12 | import os 13 | import platform 14 | import stat 15 | import subprocess 16 | import tarfile 17 | import urllib.error 18 | import urllib.request 19 | 20 | 21 | # Returns `windows`, `darwin` (macos) or `linux` 22 | def get_os(): 23 | system = platform.system() 24 | 25 | if system == "Linux": 26 | # The exe version works better for WSL 27 | if "microsoft" in platform.platform(): 28 | return "windows" 29 | 30 | return "linux" 31 | elif system == "Windows": 32 | return "windows" 33 | elif system == "Darwin": 34 | return "darwin" 35 | else: 36 | raise Exception(f"Unknown platform {system}") 37 | 38 | 39 | def get_bin_name(): 40 | bin_name = "doxa_cli" 41 | if get_os() == "windows": 42 | bin_name = "doxa_cli.exe" 43 | return bin_name 44 | 45 | 46 | def get_bin_dir(): 47 | return os.path.join(os.path.dirname(__file__), "bin") 48 | 49 | 50 | def get_binary(): 51 | return os.path.join(get_bin_dir(), get_bin_name()) 52 | 53 | 54 | def install_binary(): 55 | match_release = None 56 | try: 57 | match_release = sys.argv[1] 58 | # Arguments are not required 59 | except IndexError: 60 | pass 61 | 62 | REPO_RELEASE_URL = "https://api.github.com/repos/louisdewar/doxa/releases/latest" 63 | try: 64 | f = urllib.request.urlopen(REPO_RELEASE_URL) 65 | except urllib.error.URLError: 66 | print("There was an SSL cert verification error") 67 | print( 68 | 'If you are on a mac and you have recently installed a new version of\ 69 | python then you should navigate to "/Applications/Python {VERSION}/"' 70 | ) 71 | print('Then run a script in that folder called "Install Certificates.command"') 72 | sys.exit(1) 73 | 74 | response = json.loads(f.read()) 75 | 76 | print("Current version tag:", response["tag_name"]) 77 | 78 | assets = [ 79 | asset for asset in response["assets"] if asset["name"].endswith(".tar.gz") 80 | ] 81 | 82 | # Find the release for this OS 83 | match_release = get_os() 84 | try: 85 | asset_choice = next(asset for asset in assets if match_release in asset["name"]) 86 | print( 87 | 'Automatically picked {} to download based on match "{}"\n'.format( 88 | asset_choice["name"], match_release 89 | ) 90 | ) 91 | except StopIteration: 92 | print('Couldn\'t find "{}" in releases'.format(match_release)) 93 | sys.exit(1) 94 | 95 | download_url = asset_choice["browser_download_url"] 96 | 97 | # Folder where this script is + bin 98 | bin_dir = get_bin_dir() 99 | 100 | print("Downloading", asset_choice["name"], "to", bin_dir) 101 | print("({})".format(download_url)) 102 | 103 | # Clear bin directory if it exists 104 | if not os.path.exists(bin_dir): 105 | os.mkdir(bin_dir) 106 | 107 | zip_path = os.path.join(bin_dir, asset_choice["name"]) 108 | 109 | # Download zip file 110 | urllib.request.urlretrieve(download_url, zip_path) 111 | 112 | # Open and extract zip file 113 | tar_file = tarfile.open(zip_path) 114 | tar_file.extractall(bin_dir) 115 | tar_file.close() 116 | 117 | # Delete zip file 118 | os.remove(zip_path) 119 | 120 | # Path to the actual binary program (called doxa_cli or doxa_cli.exe) 121 | bin_name = get_bin_name() 122 | binary_path = os.path.join(bin_dir, bin_name) 123 | 124 | if not os.path.exists(binary_path): 125 | print(f"Couldn't find the binary file `{bin_name}` in the bin directory") 126 | print("This probably means that there was a problem with the download") 127 | sys.exit(1) 128 | 129 | if get_os() != "windows": 130 | # Make binary executable 131 
| st = os.stat(binary_path) 132 | os.chmod(binary_path, st.st_mode | stat.S_IEXEC) 133 | 134 | # Run help 135 | print("Installed binary\n\n") 136 | 137 | 138 | def run_command(args): 139 | bin_path = get_binary() 140 | 141 | if not os.path.exists(bin_path): 142 | install_binary() 143 | subprocess.call([bin_path] + args) 144 | 145 | 146 | if __name__ == "__main__": 147 | run_command(sys.argv[1:]) 148 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/climatehack.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | 7 | 8 | class BaseEvaluator: 9 | def __init__(self) -> None: 10 | self.setup() 11 | 12 | def setup(self): 13 | """Sets up anything required for evaluation, e.g. models.""" 14 | pass 15 | 16 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: 17 | """Makes a prediction for the next two hours of satellite imagery. 18 | 19 | Args: 20 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128) 21 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128) 22 | 23 | Returns: 24 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64) 25 | """ 26 | 27 | raise NotImplementedError( 28 | "You need to extend this class to use your trained model(s)." 29 | ) 30 | 31 | def _get_io_paths(self) -> Tuple[Path, Path]: 32 | """Gets the input and output directory paths from DOXA. 33 | 34 | Returns: 35 | Tuple[Path, Path]: The input and output paths 36 | """ 37 | try: 38 | return Path(sys.argv[1]), Path(sys.argv[2]) 39 | except IndexError: 40 | raise Exception( 41 | f"Run using: {sys.argv[0]} [input directory] [output directory]" 42 | ) 43 | 44 | def _get_group_path(self) -> str: 45 | """Gets the path for the next group to be processed. 46 | 47 | Raises: 48 | ValueError: An unknown message was received from DOXA. 49 | 50 | Returns: 51 | str: The path of the next group. 52 | """ 53 | 54 | msg = input() 55 | if not msg.startswith("Process "): 56 | raise ValueError(f"Unknown messsage {msg}") 57 | 58 | return msg[8:] 59 | 60 | def _evaluate_group(self, group: dict) -> List[np.ndarray]: 61 | """Evaluates a group of satellite image sequences using 62 | the user-implemented model(s). 63 | 64 | Args: 65 | group (dict): The OSGB and satellite imagery data. 66 | 67 | Returns: 68 | List[np.ndarray]: The predictions. 69 | """ 70 | batch_size = 16 71 | split_num = len(group["data"]) // batch_size 72 | osgb_splits = np.array_split(group["osgb"], split_num, axis=0) 73 | data_splits = np.array_split(group["data"], split_num, axis=0) 74 | 75 | all_preds = [] 76 | for (osgb, data) in zip(osgb_splits, data_splits): 77 | bs = len(data) 78 | preds = self.predict(osgb, data) 79 | for b in range(bs): 80 | all_preds.append(preds[b]) 81 | return all_preds 82 | # return [self.predict(*datum) for datum in zip(group["osgb"], group["data"])] 83 | 84 | def evaluate(self): 85 | """Evaluates the user's model on DOXA. 86 | 87 | Messages are sent and received through stdio. 88 | 89 | The input data is loaded from a directory in groups. 90 | 91 | The predictions are written to another directory in groups. 92 | 93 | Raises: 94 | Exception: An error occurred somewhere. 
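        DOXA requests each group with a stdin message of the form "Process <group path>";
        the group's predictions are saved to the output directory as an .npz archive
        containing a single stacked `data` array.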
95 | """ 96 | 97 | print("STARTUP") 98 | input_path, output_path = self._get_io_paths() 99 | 100 | # process test data groups 101 | while True: 102 | # load the data for the group DOXA requests 103 | group_path = self._get_group_path() 104 | group_data = np.load(input_path / group_path) 105 | 106 | # make predictions for this group 107 | try: 108 | predictions = self._evaluate_group(group_data) 109 | except Exception as err: 110 | raise Exception(f"Error while processing {group_path}: {str(err)}") 111 | 112 | # save the output group predictions 113 | np.savez( 114 | output_path / group_path, 115 | data=np.stack(predictions), 116 | ) 117 | print(f"Exported {group_path}") 118 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/generators.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.pixelshuffle import PixelShuffle 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | from typing import List 7 | from dgmr.common import GBlock, UpsampleGBlock 8 | from dgmr.layers import ConvGRU 9 | from huggingface_hub import PyTorchModelHubMixin 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.WARN) 14 | 15 | 16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin): 17 | def __init__( 18 | self, 19 | forecast_steps: int = 18, 20 | context_channels: int = 384, 21 | latent_channels: int = 384, 22 | output_channels: int = 1, 23 | **kwargs 24 | ): 25 | """ 26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf 27 | 28 | The sampler takes the output from the Latent and Context conditioning stacks and 29 | creates one stack of ConvGRU layers per future timestep. 
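        In this one-shot variant, however, the ConvGRU recurrence is not used: all forecast
        steps are produced in a single forward pass by widening the channel dimension by
        `forecast_steps` and restoring the spatial resolution with a final PixelShuffle.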
30 | Args: 31 | forecast_steps: Number of forecast steps 32 | latent_channels: Number of input channels to the lowest ConvGRU layer 33 | """ 34 | super().__init__() 35 | config = locals() 36 | config.pop("__class__") 37 | config.pop("self") 38 | self.config = kwargs.get("config", config) 39 | self.forecast_steps = self.config["forecast_steps"] 40 | latent_channels = self.config["latent_channels"] 41 | context_channels = self.config["context_channels"] 42 | output_channels = self.config["output_channels"] 43 | 44 | self.gru_conv_1x1 = spectral_norm( 45 | torch.nn.Conv2d( 46 | in_channels=context_channels, 47 | out_channels=latent_channels * self.forecast_steps, 48 | kernel_size=(1, 1), 49 | ) 50 | ) 51 | self.g1 = GBlock( 52 | input_channels=latent_channels * self.forecast_steps, 53 | output_channels=latent_channels * self.forecast_steps, 54 | ) 55 | self.up_g1 = UpsampleGBlock( 56 | input_channels=latent_channels * self.forecast_steps, 57 | output_channels=latent_channels * self.forecast_steps // 2, 58 | ) 59 | 60 | self.gru_conv_1x1_2 = spectral_norm( 61 | torch.nn.Conv2d( 62 | in_channels=self.up_g1.output_channels + context_channels // 2, 63 | out_channels=latent_channels * self.forecast_steps // 2, 64 | kernel_size=(1, 1), 65 | ) 66 | ) 67 | self.g2 = GBlock( 68 | input_channels=latent_channels * self.forecast_steps // 2, 69 | output_channels=latent_channels * self.forecast_steps // 2, 70 | ) 71 | self.up_g2 = UpsampleGBlock( 72 | input_channels=latent_channels * self.forecast_steps // 2, 73 | output_channels=latent_channels * self.forecast_steps // 4, 74 | ) 75 | 76 | self.gru_conv_1x1_3 = spectral_norm( 77 | torch.nn.Conv2d( 78 | in_channels=self.up_g2.output_channels + context_channels // 4, 79 | out_channels=latent_channels * self.forecast_steps // 4, 80 | kernel_size=(1, 1), 81 | ) 82 | ) 83 | self.g3 = GBlock( 84 | input_channels=latent_channels * self.forecast_steps // 4, 85 | output_channels=latent_channels * self.forecast_steps // 4, 86 | ) 87 | self.up_g3 = UpsampleGBlock( 88 | input_channels=latent_channels * self.forecast_steps // 4, 89 | output_channels=latent_channels * self.forecast_steps // 8, 90 | ) 91 | 92 | self.gru_conv_1x1_4 = spectral_norm( 93 | torch.nn.Conv2d( 94 | in_channels=self.up_g3.output_channels + context_channels // 8, 95 | out_channels=latent_channels * self.forecast_steps // 8, 96 | kernel_size=(1, 1), 97 | ) 98 | ) 99 | self.g4 = GBlock( 100 | input_channels=latent_channels * self.forecast_steps // 8, 101 | output_channels=latent_channels * self.forecast_steps // 8, 102 | ) 103 | self.up_g4 = UpsampleGBlock( 104 | input_channels=latent_channels * self.forecast_steps // 8, 105 | output_channels=latent_channels * self.forecast_steps // 16, 106 | ) 107 | 108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16) 109 | self.relu = torch.nn.ReLU() 110 | self.conv_1x1 = spectral_norm( 111 | torch.nn.Conv2d( 112 | in_channels=latent_channels * self.forecast_steps // 16, 113 | out_channels=4 * output_channels * self.forecast_steps, 114 | kernel_size=(1, 1), 115 | ) 116 | ) 117 | 118 | self.depth2space = PixelShuffle(upscale_factor=2) 119 | 120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor: 121 | """ 122 | Perform the sampling from Skillful Nowcasting with GANs 123 | Args: 124 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially 125 | latent_dim: Output from `LatentConditioningStack` for input into the ConvGRUs 126 | 127 | 
Returns: 128 | forecast_steps-length output of images for future timesteps 129 | 130 | """ 131 | # Iterate through each forecast step 132 | # Initialize with conditioning state for first one, output for second one 133 | init_states = conditioning_states 134 | 135 | layer4_states = self.gru_conv_1x1(init_states[3]) 136 | layer4_states = self.g1(layer4_states) 137 | layer4_states = self.up_g1(layer4_states) 138 | 139 | # Layer 3. 140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1) 141 | layer3_states = self.gru_conv_1x1_2(layer3_states) 142 | layer3_states = self.g2(layer3_states) 143 | layer3_states = self.up_g2(layer3_states) 144 | 145 | # Layer 2. 146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1) 147 | layer2_states = self.gru_conv_1x1_3(layer2_states) 148 | layer2_states = self.g3(layer2_states) 149 | layer2_states = self.up_g3(layer2_states) 150 | 151 | # Layer 1 (top-most). 152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1) 153 | layer1_states = self.gru_conv_1x1_4(layer1_states) 154 | layer1_states = self.g4(layer1_states) 155 | layer1_states = self.up_g4(layer1_states) 156 | 157 | # Final stuff 158 | output_states = self.relu(self.bn(layer1_states)) 159 | output_states = self.conv_1x1(output_states) 160 | output_states = self.depth2space(output_states) 161 | 162 | # The satellite dimension was lost, add it back 163 | output_states = torch.unsqueeze(output_states, dim=2) 164 | 165 | return output_states 166 | 167 | 168 | class Generator(torch.nn.Module, PyTorchModelHubMixin): 169 | def __init__( 170 | self, 171 | conditioning_stack: torch.nn.Module, 172 | sampler: torch.nn.Module, 173 | ): 174 | """ 175 | Wraps the three parts of the generator for simpler calling 176 | Args: 177 | conditioning_stack: 178 | latent_stack: 179 | sampler: 180 | """ 181 | super().__init__() 182 | self.conditioning_stack = conditioning_stack 183 | self.sampler = sampler 184 | 185 | def forward(self, x): 186 | conditioning_states = self.conditioning_stack(x) 187 | x = self.sampler(conditioning_states) 188 | return x 189 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 22 | cached_download = None 23 | 24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | 
[More information needed] 61 | 62 | ## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | local_files_only=local_files_only, 154 | ) 155 | model = cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, 
w, c].""" 9 | 10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__( 10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001 11 | ): 12 | """Constructor. 13 | 14 | Args: 15 | kernel_size: kernel size of the convolutions. Default: 3. 16 | sn_eps: constant for spectral normalization. Default: 1e-4. 
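            input_channels: number of channels of the concatenated input and previous state.
            output_channels: number of channels of the gates and of the new state.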
17 | """ 18 | super().__init__() 19 | self._kernel_size = kernel_size 20 | self._sn_eps = sn_eps 21 | self.read_gate_conv = spectral_norm( 22 | torch.nn.Conv2d( 23 | in_channels=input_channels, 24 | out_channels=output_channels, 25 | kernel_size=(kernel_size, kernel_size), 26 | padding=1, 27 | ), 28 | eps=sn_eps, 29 | ) 30 | self.update_gate_conv = spectral_norm( 31 | torch.nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=output_channels, 34 | kernel_size=(kernel_size, kernel_size), 35 | padding=1, 36 | ), 37 | eps=sn_eps, 38 | ) 39 | self.output_conv = spectral_norm( 40 | torch.nn.Conv2d( 41 | in_channels=input_channels, 42 | out_channels=output_channels, 43 | kernel_size=(kernel_size, kernel_size), 44 | padding=1, 45 | ), 46 | eps=sn_eps, 47 | ) 48 | 49 | def forward(self, x, prev_state): 50 | """ 51 | ConvGRU forward, returning the current+new state 52 | 53 | Args: 54 | x: Input tensor 55 | prev_state: Previous state 56 | 57 | Returns: 58 | New tensor plus the new state 59 | """ 60 | # Concatenate the inputs and previous state along the channel axis. 61 | xh = torch.cat([x, prev_state], dim=1) 62 | 63 | # Read gate of the GRU. 64 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 65 | 66 | # Update gate of the GRU. 67 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 68 | 69 | # Gate the inputs. 70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 71 | 72 | # Gate the cell and state / outputs. 73 | c = F.relu(self.output_conv(gated_input)) 74 | out = update_gate * prev_state + (1.0 - update_gate) * c 75 | new_state = out 76 | 77 | return out, new_state 78 | 79 | 80 | class ConvGRU(torch.nn.Module): 81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 82 | 83 | def __init__( 84 | self, 85 | input_channels: int, 86 | output_channels: int, 87 | kernel_size: int = 3, 88 | sn_eps=0.0001, 89 | ): 90 | super().__init__() 91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 92 | 93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 94 | outputs = [] 95 | for step in range(len(x)): 96 | # Compute current timestep 97 | output, hidden_state = self.cell(x[step], hidden_state) 98 | outputs.append(output) 99 | # Stack outputs to return as tensor 100 | outputs = torch.stack(outputs, dim=0) 101 | return outputs 102 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [ 31 | input_tensor, 32 | xx_channel.type_as(input_tensor), 33 | yy_channel.type_as(input_tensor), 34 | 
], 35 | dim=1, 36 | ) 37 | 38 | if self.with_r: 39 | rr = torch.sqrt( 40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 42 | ) 43 | ret = torch.cat([ret, rr], dim=1) 44 | 45 | return ret 46 | 47 | 48 | class CoordConv(nn.Module): 49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 50 | super().__init__() 51 | self.addcoords = AddCoords(with_r=with_r) 52 | in_size = in_channels + 2 53 | if with_r: 54 | in_size += 1 55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 56 | 57 | def forward(self, x): 58 | ret = self.addcoords(x) 59 | ret = self.conv(ret) 60 | return ret 61 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/doxa.yaml: -------------------------------------------------------------------------------- 1 | language: python 2 | entrypoint: evaluate.py 3 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | from climatehack import BaseEvaluator 6 | 7 | import sys 8 | 9 | sys.path.append("./dgmr-oneshot") 10 | import dgmr 11 | 12 | # DEVICE = torch.device("cpu") 13 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | _MEAN_PIXEL = 240.3414 16 | _STD_PIXEL = 146.52366 17 | 18 | 19 | def transform(x): 20 | return (x - _MEAN_PIXEL) / _STD_PIXEL 21 | 22 | 23 | def inv_transform(x): 24 | return (x * _STD_PIXEL) + _MEAN_PIXEL 25 | 26 | 27 | def warp_flow(img, flow): 28 | h, w = flow.shape[:2] 29 | flow = -flow 30 | flow[:, :, 0] += np.arange(w) 31 | flow[:, :, 1] += np.arange(h)[:, np.newaxis] 32 | res = cv2.remap(img, flow, None, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE) 33 | return res 34 | 35 | 36 | class Evaluator(BaseEvaluator): 37 | def setup(self): 38 | """Sets up anything required for evaluation. 
39 | In this case, it loads the trained model (in evaluation mode).""" 40 | ccs = dgmr.common.ContextConditioningStack( 41 | input_channels=1, 42 | conv_type="standard", 43 | output_channels=160, 44 | ) 45 | 46 | sampler = dgmr.generators.Sampler( 47 | forecast_steps=24, 48 | latent_channels=96, 49 | context_channels=160, 50 | output_channels=1, 51 | ) 52 | model = dgmr.generators.Generator(ccs, sampler) 53 | model.load_state_dict(torch.load("weights/model.pt", map_location=DEVICE)) 54 | self.model = model.to(DEVICE) 55 | self.model.train() 56 | # print("DOING TRAIN MODE") 57 | 58 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: 59 | """Makes a prediction for the next two hours of satellite imagery. 60 | Args: 61 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128) 62 | data (np.ndarray): an array of 12 128*128 satellite images (bs, 12, 128, 128) 63 | Returns: 64 | np.ndarray: an array of 24 64*64 satellite image predictions (bs, 24, 64, 64) 65 | """ 66 | # prediction_opt_flow = self._predict_opt_flow(data) 67 | prediction_dgmr = self._predict_dgmr(coordinates, data) 68 | # copy the opt flow predictions in 69 | # prediction_dgmr[:, : prediction_opt_flow.shape[1]] = prediction_opt_flow 70 | prediction = prediction_dgmr 71 | return prediction 72 | 73 | def _predict_dgmr(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: 74 | bs = data.shape[0] 75 | data = data[:, -4:] 76 | data = torch.FloatTensor(transform(data)).float().to(DEVICE) 77 | # add a satellite dimension 78 | data = torch.unsqueeze(data, dim=2) 79 | # make a batch to help with norm 80 | # data = torch.cat([data, self.default_batch], dim=0) 81 | with torch.no_grad(): 82 | prediction = self.model(data) 83 | # remove the satellite dimension and grab the inner 64x64 84 | prediction = inv_transform(prediction[:, :, 0, 32:96, 32:96]) 85 | if prediction.device == "cpu": 86 | prediction = prediction.numpy() 87 | else: 88 | prediction = prediction.detach().cpu().numpy() 89 | return prediction 90 | 91 | def _predict_opt_flow(self, data: np.ndarray) -> np.ndarray: 92 | bs = data.shape[0] 93 | forecast = 1 94 | prediction = np.zeros((bs, forecast, 64, 64), dtype=np.float32) 95 | test_params = { 96 | "pyr_scale": 0.5, 97 | "levels": 2, 98 | "winsize": 40, 99 | "iterations": 3, 100 | "poly_n": 5, 101 | "poly_sigma": 0.7, 102 | } 103 | for i in range(bs): 104 | sample = data[i] 105 | cur = sample[-1].astype(np.float32) 106 | flow = cv2.calcOpticalFlowFarneback( 107 | prev=sample[-2], 108 | next=sample[-1], 109 | flow=None, 110 | **test_params, 111 | flags=cv2.OPTFLOW_FARNEBACK_GAUSSIAN, 112 | ) 113 | for j in range(forecast): 114 | cur = warp_flow(cur, flow) 115 | prediction[i, j] = cur[32:96, 32:96] 116 | return prediction 117 | 118 | 119 | def main(): 120 | evaluator = Evaluator() 121 | evaluator.evaluate() 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submission/validate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_msssim import MS_SSIM 3 | from torch import from_numpy 4 | import torch 5 | import tqdm 6 | 7 | from evaluate import Evaluator 8 | 9 | 10 | def main(): 11 | features = np.load("../features.npz") 12 | targets = np.load("../targets.npz") 13 | 14 | criterion = MS_SSIM(data_range=1023.0, size_average=True, win_size=3, channel=1) 15 | evaluator = Evaluator() 16 | 17 
| batch_size = 16 18 | split_num = len(features["data"]) // batch_size 19 | osgb_splits = np.array_split(features["osgb"], split_num, axis=0) 20 | data_splits = np.array_split(features["data"], split_num, axis=0) 21 | targets_splits = np.array_split(targets["data"], split_num, axis=0) 22 | 23 | pbar = tqdm.tqdm( 24 | zip(osgb_splits, data_splits, targets_splits), 25 | total=len(data_splits), 26 | ) 27 | 28 | scores = [] 29 | for (osgb, data, target) in pbar: 30 | bs = len(data) 31 | preds = from_numpy(evaluator.predict(osgb, data)) 32 | trgs = from_numpy(target) 33 | # this and the batch indexing essentially makes the 24 timesteps the batch 34 | preds = torch.unsqueeze(preds, dim=2) 35 | trgs = torch.unsqueeze(trgs, dim=2) 36 | 37 | for i in range(bs): 38 | # grab the current batch 39 | score = criterion(preds[i], trgs[i]).item() 40 | scores.append(score) 41 | pbar.set_description(f"Avg: {np.mean(scores)}") 42 | 43 | # scores = [ 44 | # criterion( 45 | # from_numpy(evaluator.predict(*datum)).view(24, 64, 64).unsqueeze(dim=1), 46 | # from_numpy(target).view(24, 64, 64).unsqueeze(dim=1), 47 | # ).item() 48 | # for *datum, target in tqdm.tqdm( 49 | # zip(features["osgb"], features["data"], targets["data"]), 50 | # total=len(features["data"]), 51 | # ) 52 | # ] 53 | 54 | print(f"Score: {np.mean(scores)} ({np.std(scores)})") 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /experiments/climatehack-submission/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eEou pipefail 4 | 5 | python doxa_cli.py agent upload climatehack ./submission 6 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/dgmr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.losses import ( 3 | NowcastingLoss, 4 | GridCellLoss, 5 | loss_hinge_disc, 6 | loss_hinge_gen, 7 | grid_cell_regularizer, 8 | ) 9 | import pytorch_lightning as pl 10 | import torchvision 11 | from dgmr.common import LatentConditioningStack, ContextConditioningStack 12 | from dgmr.generators import Sampler, Generator 13 | from dgmr.discriminators import Discriminator 14 | from dgmr.hub import NowcastingModelHubMixin 15 | 16 | 17 | class DGMR(pl.LightningModule, NowcastingModelHubMixin): 18 | """Deep Generative Model of Radar""" 19 | 20 | def __init__( 21 | self, 22 | forecast_steps: int = 18, 23 | input_channels: int = 1, 24 | output_shape: int = 256, 25 | gen_lr: float = 5e-5, 26 | disc_lr: float = 2e-4, 27 | visualize: bool = False, 28 | conv_type: str = "standard", 29 | num_samples: int = 6, 30 | grid_lambda: float = 20.0, 31 | beta1: float = 0.0, 32 | beta2: float = 0.999, 33 | latent_channels: int = 768, 34 | context_channels: int = 384, 35 | **kwargs, 36 | ): 37 | """ 38 | Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954 39 | but slightly modified for multiple satellite 
channels 40 | 41 | Args: 42 | forecast_steps: Number of steps to predict in the future 43 | input_channels: Number of input channels per image 44 | visualize: Whether to visualize output during training 45 | gen_lr: Learning rate for the generator 46 | disc_lr: Learning rate for the discriminators, shared for both temporal and spatial discriminator 47 | conv_type: Type of 2d convolution to use, see satflow/models/utils.py for options 48 | beta1: Beta1 for Adam optimizer 49 | beta2: Beta2 for Adam optimizer 50 | num_samples: Number of samples of the latent space to sample for training/validation 51 | grid_lambda: Lambda for the grid regularization loss 52 | output_shape: Shape of the output predictions, generally should be same as the input shape 53 | latent_channels: Number of channels that the latent space should be reshaped to, 54 | input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs 55 | pretrained: 56 | """ 57 | super().__init__() 58 | config = locals() 59 | config.pop("__class__") 60 | config.pop("self") 61 | self.config = kwargs.get("config", config) 62 | input_channels = self.config["input_channels"] 63 | forecast_steps = self.config["forecast_steps"] 64 | output_shape = self.config["output_shape"] 65 | gen_lr = self.config["gen_lr"] 66 | disc_lr = self.config["disc_lr"] 67 | conv_type = self.config["conv_type"] 68 | num_samples = self.config["num_samples"] 69 | grid_lambda = self.config["grid_lambda"] 70 | beta1 = self.config["beta1"] 71 | beta2 = self.config["beta2"] 72 | latent_channels = self.config["latent_channels"] 73 | context_channels = self.config["context_channels"] 74 | visualize = self.config["visualize"] 75 | self.gen_lr = gen_lr 76 | self.disc_lr = disc_lr 77 | self.beta1 = beta1 78 | self.beta2 = beta2 79 | self.discriminator_loss = NowcastingLoss() 80 | self.grid_regularizer = GridCellLoss() 81 | self.grid_lambda = grid_lambda 82 | self.num_samples = num_samples 83 | self.visualize = visualize 84 | self.latent_channels = latent_channels 85 | self.context_channels = context_channels 86 | self.input_channels = input_channels 87 | self.conditioning_stack = ContextConditioningStack( 88 | input_channels=input_channels, 89 | conv_type=conv_type, 90 | output_channels=self.context_channels, 91 | ) 92 | self.latent_stack = LatentConditioningStack( 93 | shape=(8 * self.input_channels, output_shape // 32, output_shape // 32), 94 | output_channels=self.latent_channels, 95 | ) 96 | self.sampler = Sampler( 97 | forecast_steps=forecast_steps, 98 | latent_channels=self.latent_channels, 99 | context_channels=self.context_channels, 100 | ) 101 | self.generator = Generator( 102 | self.conditioning_stack, self.latent_stack, self.sampler 103 | ) 104 | self.discriminator = Discriminator(input_channels) 105 | self.save_hyperparameters() 106 | 107 | self.global_iteration = 0 108 | 109 | # Important: This property activates manual optimization. 
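        # Both the generator and discriminator optimizers are stepped by hand in training_step.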
110 | self.automatic_optimization = False 111 | torch.autograd.set_detect_anomaly(True) 112 | 113 | def forward(self, x): 114 | x = self.generator(x) 115 | return x 116 | 117 | def training_step(self, batch, batch_idx): 118 | images, future_images = batch 119 | 120 | self.global_iteration += 1 121 | g_opt, d_opt = self.optimizers() 122 | ########################## 123 | # Optimize Discriminator # 124 | ########################## 125 | # Two discriminator steps per generator step 126 | for _ in range(2): 127 | predictions = self(images) 128 | # Cat along time dimension [B, T, C, H, W] 129 | generated_sequence = torch.cat([images, predictions], dim=1) 130 | real_sequence = torch.cat([images, future_images], dim=1) 131 | # Cat long batch for the real+generated 132 | concatenated_inputs = torch.cat([real_sequence, generated_sequence], dim=0) 133 | 134 | concatenated_outputs = self.discriminator(concatenated_inputs) 135 | 136 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1) 137 | discriminator_loss = loss_hinge_disc(score_generated, score_real) 138 | d_opt.zero_grad() 139 | self.manual_backward(discriminator_loss) 140 | d_opt.step() 141 | 142 | ###################### 143 | # Optimize Generator # 144 | ###################### 145 | predictions = [self(images) for _ in range(6)] 146 | grid_cell_reg = grid_cell_regularizer( 147 | torch.stack(predictions, dim=0), future_images 148 | ) 149 | # Concat along time dimension 150 | generated_sequence = [torch.cat([images, x], dim=1) for x in predictions] 151 | real_sequence = torch.cat([images, future_images], dim=1) 152 | # Cat long batch for the real+generated, for each example in the range 153 | # For each of the 6 examples 154 | generated_scores = [] 155 | for g_seq in generated_sequence: 156 | concatenated_inputs = torch.cat([real_sequence, g_seq], dim=0) 157 | concatenated_outputs = self.discriminator(concatenated_inputs) 158 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1) 159 | generated_scores.append(score_generated) 160 | generator_disc_loss = loss_hinge_gen(torch.cat(generated_scores, dim=0)) 161 | generator_loss = generator_disc_loss + self.grid_lambda * grid_cell_reg 162 | g_opt.zero_grad() 163 | self.manual_backward(generator_loss) 164 | g_opt.step() 165 | 166 | self.log_dict( 167 | { 168 | "train/d_loss": discriminator_loss, 169 | "train/g_loss": generator_loss, 170 | "train/grid_loss": grid_cell_reg, 171 | }, 172 | prog_bar=True, 173 | ) 174 | 175 | # generate images 176 | generated_images = self(images) 177 | # log sampled images 178 | if self.visualize: 179 | self.visualize_step( 180 | images, 181 | future_images, 182 | generated_images, 183 | self.global_iteration, 184 | step="train", 185 | ) 186 | 187 | def configure_optimizers(self): 188 | b1 = self.beta1 189 | b2 = self.beta2 190 | 191 | opt_g = torch.optim.Adam( 192 | self.generator.parameters(), lr=self.gen_lr, betas=(b1, b2) 193 | ) 194 | opt_d = torch.optim.Adam( 195 | self.discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2) 196 | ) 197 | 198 | return [opt_g, opt_d], [] 199 | 200 | def visualize_step( 201 | self, 202 | x: torch.Tensor, 203 | y: torch.Tensor, 204 | y_hat: torch.Tensor, 205 | batch_idx: int, 206 | step: str, 207 | ) -> None: 208 | # the logger you used (in this case tensorboard) 209 | tensorboard = self.logger.experiment[0] 210 | # Timesteps per channel 211 | images = x[0].cpu().detach() 212 | future_images = y[0].cpu().detach() 213 | generated_images = y_hat[0].cpu().detach() 214 | for i, t in 
enumerate(images): # Now would be (C, H, W) 215 | t = [torch.unsqueeze(img, dim=0) for img in t] 216 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 217 | tensorboard.add_image( 218 | f"{step}/Input_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx 219 | ) 220 | t = [torch.unsqueeze(img, dim=0) for img in future_images[i]] 221 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 222 | tensorboard.add_image( 223 | f"{step}/Target_Image_Frame_{i}", image_grid, global_step=batch_idx 224 | ) 225 | t = [torch.unsqueeze(img, dim=0) for img in generated_images[i]] 226 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 227 | tensorboard.add_image( 228 | f"{step}/Generated_Image_Frame_{i}", image_grid, global_step=batch_idx 229 | ) 230 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 22 | cached_download = None 23 | 24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | [More information needed] 61 | 62 | ## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 
93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | local_files_only=local_files_only, 154 | ) 155 | model = cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, w, c].""" 9 | 10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 
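    # beta has shape [h, w, L]; contracting it with v of shape [L, c] yields [h, w, c].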
18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__( 10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001 11 | ): 12 | """Constructor. 13 | 14 | Args: 15 | kernel_size: kernel size of the convolutions. Default: 3. 16 | sn_eps: constant for spectral normalization. Default: 1e-4. 
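            input_channels: number of channels of the concatenated input and previous state.
            output_channels: number of channels of the gates and of the new state.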
17 | """ 18 | super().__init__() 19 | self._kernel_size = kernel_size 20 | self._sn_eps = sn_eps 21 | self.read_gate_conv = spectral_norm( 22 | torch.nn.Conv2d( 23 | in_channels=input_channels, 24 | out_channels=output_channels, 25 | kernel_size=(kernel_size, kernel_size), 26 | padding=1, 27 | ), 28 | eps=sn_eps, 29 | ) 30 | self.update_gate_conv = spectral_norm( 31 | torch.nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=output_channels, 34 | kernel_size=(kernel_size, kernel_size), 35 | padding=1, 36 | ), 37 | eps=sn_eps, 38 | ) 39 | self.output_conv = spectral_norm( 40 | torch.nn.Conv2d( 41 | in_channels=input_channels, 42 | out_channels=output_channels, 43 | kernel_size=(kernel_size, kernel_size), 44 | padding=1, 45 | ), 46 | eps=sn_eps, 47 | ) 48 | 49 | def forward(self, x, prev_state): 50 | """ 51 | ConvGRU forward, returning the current+new state 52 | 53 | Args: 54 | x: Input tensor 55 | prev_state: Previous state 56 | 57 | Returns: 58 | New tensor plus the new state 59 | """ 60 | # Concatenate the inputs and previous state along the channel axis. 61 | xh = torch.cat([x, prev_state], dim=1) 62 | 63 | # Read gate of the GRU. 64 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 65 | 66 | # Update gate of the GRU. 67 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 68 | 69 | # Gate the inputs. 70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 71 | 72 | # Gate the cell and state / outputs. 73 | c = F.relu(self.output_conv(gated_input)) 74 | out = update_gate * prev_state + (1.0 - update_gate) * c 75 | new_state = out 76 | 77 | return out, new_state 78 | 79 | 80 | class ConvGRU(torch.nn.Module): 81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 82 | 83 | def __init__( 84 | self, 85 | input_channels: int, 86 | output_channels: int, 87 | kernel_size: int = 3, 88 | sn_eps=0.0001, 89 | ): 90 | super().__init__() 91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 92 | 93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 94 | outputs = [] 95 | for step in range(len(x)): 96 | # Compute current timestep 97 | output, hidden_state = self.cell(x[step], hidden_state) 98 | outputs.append(output) 99 | # Stack outputs to return as tensor 100 | outputs = torch.stack(outputs, dim=0) 101 | return outputs 102 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [ 31 | input_tensor, 32 | xx_channel.type_as(input_tensor), 33 | yy_channel.type_as(input_tensor), 34 | ], 35 | dim=1, 36 | ) 37 | 38 | if 
self.with_r: 39 | rr = torch.sqrt( 40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 42 | ) 43 | ret = torch.cat([ret, rr], dim=1) 44 | 45 | return ret 46 | 47 | 48 | class CoordConv(nn.Module): 49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 50 | super().__init__() 51 | self.addcoords = AddCoords(with_r=with_r) 52 | in_size = in_channels + 2 53 | if with_r: 54 | in_size += 1 55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 56 | 57 | def forward(self, x): 58 | ret = self.addcoords(x) 59 | ret = self.conv(ret) 60 | return ret 61 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/dgmr-dct/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/generators.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.pixelshuffle import PixelShuffle 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | from typing import List 7 | from dgmr.common import GBlock, UpsampleGBlock 8 | from dgmr.layers import ConvGRU 9 | from huggingface_hub import PyTorchModelHubMixin 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.WARN) 14 | 15 | 16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin): 17 | def __init__( 18 | self, 19 | forecast_steps: int = 18, 20 | latent_channels: int = 768, 21 | context_channels: int = 384, 22 | output_channels: int = 1, 23 | **kwargs 24 | ): 25 | """ 26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf 27 | 28 | The sampler takes the output from the Latent and Context conditioning stacks and 29 | creates one stack of ConvGRU layers per future timestep. 
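        Note: the latent conditioning output is repeated `forecast_steps` times to form the
        ConvGRU input sequence, while the context states provide the initial hidden state at
        each spatial scale.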
30 | Args: 31 | forecast_steps: Number of forecast steps 32 | latent_channels: Number of input channels to the lowest ConvGRU layer 33 | """ 34 | super().__init__() 35 | config = locals() 36 | config.pop("__class__") 37 | config.pop("self") 38 | self.config = kwargs.get("config", config) 39 | self.forecast_steps = self.config["forecast_steps"] 40 | latent_channels = self.config["latent_channels"] 41 | context_channels = self.config["context_channels"] 42 | output_channels = self.config["output_channels"] 43 | 44 | self.convGRU1 = ConvGRU( 45 | input_channels=latent_channels + context_channels, 46 | output_channels=context_channels, 47 | kernel_size=3, 48 | ) 49 | self.gru_conv_1x1 = spectral_norm( 50 | torch.nn.Conv2d( 51 | in_channels=context_channels, 52 | out_channels=latent_channels, 53 | kernel_size=(1, 1), 54 | ) 55 | ) 56 | self.g1 = GBlock( 57 | input_channels=latent_channels, output_channels=latent_channels 58 | ) 59 | self.up_g1 = UpsampleGBlock( 60 | input_channels=latent_channels, output_channels=latent_channels // 2 61 | ) 62 | 63 | self.convGRU2 = ConvGRU( 64 | input_channels=latent_channels // 2 + context_channels // 2, 65 | output_channels=context_channels // 2, 66 | kernel_size=3, 67 | ) 68 | self.gru_conv_1x1_2 = spectral_norm( 69 | torch.nn.Conv2d( 70 | in_channels=context_channels // 2, 71 | out_channels=latent_channels // 2, 72 | kernel_size=(1, 1), 73 | ) 74 | ) 75 | self.g2 = GBlock( 76 | input_channels=latent_channels // 2, output_channels=latent_channels // 2 77 | ) 78 | self.up_g2 = UpsampleGBlock( 79 | input_channels=latent_channels // 2, output_channels=latent_channels // 4 80 | ) 81 | 82 | self.convGRU3 = ConvGRU( 83 | input_channels=latent_channels // 4 + context_channels // 4, 84 | output_channels=context_channels // 4, 85 | kernel_size=3, 86 | ) 87 | self.gru_conv_1x1_3 = spectral_norm( 88 | torch.nn.Conv2d( 89 | in_channels=context_channels // 4, 90 | out_channels=latent_channels // 4, 91 | kernel_size=(1, 1), 92 | ) 93 | ) 94 | self.g3 = GBlock( 95 | input_channels=latent_channels // 4, output_channels=latent_channels // 4 96 | ) 97 | self.up_g3 = UpsampleGBlock( 98 | input_channels=latent_channels // 4, output_channels=latent_channels // 8 99 | ) 100 | 101 | self.convGRU4 = ConvGRU( 102 | input_channels=latent_channels // 8 + context_channels // 8, 103 | output_channels=context_channels // 8, 104 | kernel_size=3, 105 | ) 106 | self.gru_conv_1x1_4 = spectral_norm( 107 | torch.nn.Conv2d( 108 | in_channels=context_channels // 8, 109 | out_channels=latent_channels // 8, 110 | kernel_size=(1, 1), 111 | ) 112 | ) 113 | self.g4 = GBlock( 114 | input_channels=latent_channels // 8, output_channels=latent_channels // 8 115 | ) 116 | self.up_g4 = UpsampleGBlock( 117 | input_channels=latent_channels // 8, output_channels=latent_channels // 16 118 | ) 119 | 120 | self.bn = torch.nn.BatchNorm2d(latent_channels // 16) 121 | self.relu = torch.nn.ReLU() 122 | self.conv_1x1 = spectral_norm( 123 | torch.nn.Conv2d( 124 | in_channels=latent_channels // 16, 125 | out_channels=4 * output_channels, 126 | kernel_size=(1, 1), 127 | ) 128 | ) 129 | 130 | self.depth2space = PixelShuffle(upscale_factor=2) 131 | 132 | def forward( 133 | self, conditioning_states: List[torch.Tensor], latent_dim: torch.Tensor 134 | ) -> torch.Tensor: 135 | """ 136 | Perform the sampling from Skillful Nowcasting with GANs 137 | Args: 138 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially 139 | latent_dim: Output from 
`LatentConditioningStack` for input into the ConvGRUs 140 | 141 | Returns: 142 | forecast_steps-length output of images for future timesteps 143 | 144 | """ 145 | # Iterate through each forecast step 146 | # Initialize with conditioning state for first one, output for second one 147 | init_states = conditioning_states 148 | # Expand latent dim to match batch size 149 | latent_dim = einops.repeat( 150 | latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0] 151 | ) 152 | hidden_states = [latent_dim] * self.forecast_steps 153 | 154 | # TODO: can we make this into a UNET and remove the LSTM? 155 | # Layer 4 (bottom most) 156 | hidden_states = self.convGRU1(hidden_states, init_states[3]) 157 | hidden_states = [self.gru_conv_1x1(h) for h in hidden_states] 158 | hidden_states = [self.g1(h) for h in hidden_states] 159 | hidden_states = [self.up_g1(h) for h in hidden_states] 160 | 161 | # Layer 3. 162 | hidden_states = self.convGRU2(hidden_states, init_states[2]) 163 | hidden_states = [self.gru_conv_1x1_2(h) for h in hidden_states] 164 | hidden_states = [self.g2(h) for h in hidden_states] 165 | hidden_states = [self.up_g2(h) for h in hidden_states] 166 | 167 | # Layer 2. 168 | hidden_states = self.convGRU3(hidden_states, init_states[1]) 169 | hidden_states = [self.gru_conv_1x1_3(h) for h in hidden_states] 170 | hidden_states = [self.g3(h) for h in hidden_states] 171 | hidden_states = [self.up_g3(h) for h in hidden_states] 172 | 173 | # Layer 1 (top-most). 174 | hidden_states = self.convGRU4(hidden_states, init_states[0]) 175 | hidden_states = [self.gru_conv_1x1_4(h) for h in hidden_states] 176 | hidden_states = [self.g4(h) for h in hidden_states] 177 | hidden_states = [self.up_g4(h) for h in hidden_states] 178 | 179 | # Output layer. 180 | hidden_states = [F.relu(self.bn(h)) for h in hidden_states] 181 | hidden_states = [self.conv_1x1(h) for h in hidden_states] 182 | hidden_states = [self.depth2space(h) for h in hidden_states] 183 | 184 | # Convert forecasts to a torch Tensor 185 | forecasts = torch.stack(hidden_states, dim=1) 186 | return forecasts 187 | 188 | 189 | class Generator(torch.nn.Module, PyTorchModelHubMixin): 190 | def __init__( 191 | self, 192 | conditioning_stack: torch.nn.Module, 193 | latent_stack: torch.nn.Module, 194 | sampler: torch.nn.Module, 195 | ): 196 | """ 197 | Wraps the three parts of the generator for simpler calling 198 | Args: 199 | conditioning_stack: 200 | latent_stack: 201 | sampler: 202 | """ 203 | super().__init__() 204 | self.conditioning_stack = conditioning_stack 205 | self.latent_stack = latent_stack 206 | self.sampler = sampler 207 | 208 | def forward(self, x): 209 | conditioning_states = self.conditioning_stack(x) 210 | latent_dim = self.latent_stack(x) 211 | x = self.sampler(conditioning_states, latent_dim) 212 | return x 213 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 
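    # If the installed huggingface_hub does not expose cached_download /
    # hf_hub_url, fall back to None so importing this module still works;
    # the download path in _from_pretrained below uses hf_hub_download instead.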
22 | cached_download = None 23 | 24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | [More information needed] 61 | 62 | ## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | 
local_files_only=local_files_only, 154 | ) 155 | model = cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, w, c].""" 9 | 10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 
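        # self.gamma is initialised to zero (see __init__), so the layer starts
        # out as an identity mapping and the attention contribution is blended
        # in gradually as gamma is learned (a SAGAN-style gated residual).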
80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__( 10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001 11 | ): 12 | """Constructor. 13 | 14 | Args: 15 | kernel_size: kernel size of the convolutions. Default: 3. 16 | sn_eps: constant for spectral normalization. Default: 1e-4. 17 | """ 18 | super().__init__() 19 | self._kernel_size = kernel_size 20 | self._sn_eps = sn_eps 21 | self.read_gate_conv = spectral_norm( 22 | torch.nn.Conv2d( 23 | in_channels=input_channels, 24 | out_channels=output_channels, 25 | kernel_size=(kernel_size, kernel_size), 26 | padding=1, 27 | ), 28 | eps=sn_eps, 29 | ) 30 | self.update_gate_conv = spectral_norm( 31 | torch.nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=output_channels, 34 | kernel_size=(kernel_size, kernel_size), 35 | padding=1, 36 | ), 37 | eps=sn_eps, 38 | ) 39 | self.output_conv = spectral_norm( 40 | torch.nn.Conv2d( 41 | in_channels=input_channels, 42 | out_channels=output_channels, 43 | kernel_size=(kernel_size, kernel_size), 44 | padding=1, 45 | ), 46 | eps=sn_eps, 47 | ) 48 | 49 | def forward(self, x, prev_state): 50 | """ 51 | ConvGRU forward, returning the current+new state 52 | 53 | Args: 54 | x: Input tensor 55 | prev_state: Previous state 56 | 57 | Returns: 58 | New tensor plus the new state 59 | """ 60 | # Concatenate the inputs and previous state along the channel axis. 61 | xh = torch.cat([x, prev_state], dim=1) 62 | 63 | # Read gate of the GRU. 64 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 65 | 66 | # Update gate of the GRU. 67 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 68 | 69 | # Gate the inputs. 70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 71 | 72 | # Gate the cell and state / outputs. 
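        # Candidate state: c = ReLU(Conv([x, read_gate * prev_state])), then the
        # standard GRU blend out = update_gate * prev_state + (1 - update_gate) * c.
        # This is a conventional ConvGRU update, with ReLU in place of tanh for
        # the candidate activation.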
73 | c = F.relu(self.output_conv(gated_input)) 74 | out = update_gate * prev_state + (1.0 - update_gate) * c 75 | new_state = out 76 | 77 | return out, new_state 78 | 79 | 80 | class ConvGRU(torch.nn.Module): 81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 82 | 83 | def __init__( 84 | self, 85 | input_channels: int, 86 | output_channels: int, 87 | kernel_size: int = 3, 88 | sn_eps=0.0001, 89 | ): 90 | super().__init__() 91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 92 | 93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 94 | outputs = [] 95 | for step in range(len(x)): 96 | # Compute current timestep 97 | output, hidden_state = self.cell(x[step], hidden_state) 98 | outputs.append(output) 99 | # Stack outputs to return as tensor 100 | outputs = torch.stack(outputs, dim=0) 101 | return outputs 102 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [ 31 | input_tensor, 32 | xx_channel.type_as(input_tensor), 33 | yy_channel.type_as(input_tensor), 34 | ], 35 | dim=1, 36 | ) 37 | 38 | if self.with_r: 39 | rr = torch.sqrt( 40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 42 | ) 43 | ret = torch.cat([ret, rr], dim=1) 44 | 45 | return ret 46 | 47 | 48 | class CoordConv(nn.Module): 49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 50 | super().__init__() 51 | self.addcoords = AddCoords(with_r=with_r) 52 | in_size = in_channels + 2 53 | if with_r: 54 | in_size += 1 55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 56 | 57 | def forward(self, x): 58 | ret = self.addcoords(x) 59 | ret = self.conv(ret) 60 | return ret 61 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/dgmr-multichannel/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = 
CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/generators.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.pixelshuffle import PixelShuffle 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | from typing import List 7 | from dgmr.common import GBlock, UpsampleGBlock 8 | from dgmr.layers import ConvGRU 9 | from huggingface_hub import PyTorchModelHubMixin 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.WARN) 14 | 15 | 16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin): 17 | def __init__( 18 | self, 19 | forecast_steps: int = 18, 20 | context_channels: int = 384, 21 | latent_channels: int = 384, 22 | output_channels: int = 1, 23 | **kwargs 24 | ): 25 | """ 26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf 27 | 28 | The sampler takes the output from the Latent and Context conditioning stacks and 29 | creates one stack of ConvGRU layers per future timestep. 
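        Note: this one-shot variant departs from that description. The
        per-timestep ConvGRUs are replaced by 1x1 convolutions whose output
        width is latent_channels * forecast_steps, so all forecast steps are
        emitted in a single forward pass, stacked along the channel dimension.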
30 | Args: 31 | forecast_steps: Number of forecast steps 32 | latent_channels: Number of input channels to the lowest ConvGRU layer 33 | """ 34 | super().__init__() 35 | config = locals() 36 | config.pop("__class__") 37 | config.pop("self") 38 | self.config = kwargs.get("config", config) 39 | self.forecast_steps = self.config["forecast_steps"] 40 | latent_channels = self.config["latent_channels"] 41 | context_channels = self.config["context_channels"] 42 | output_channels = self.config["output_channels"] 43 | 44 | self.gru_conv_1x1 = spectral_norm( 45 | torch.nn.Conv2d( 46 | in_channels=context_channels, 47 | out_channels=latent_channels * self.forecast_steps, 48 | kernel_size=(1, 1), 49 | ) 50 | ) 51 | self.g1 = GBlock( 52 | input_channels=latent_channels * self.forecast_steps, 53 | output_channels=latent_channels * self.forecast_steps, 54 | ) 55 | self.up_g1 = UpsampleGBlock( 56 | input_channels=latent_channels * self.forecast_steps, 57 | output_channels=latent_channels * self.forecast_steps // 2, 58 | ) 59 | 60 | self.gru_conv_1x1_2 = spectral_norm( 61 | torch.nn.Conv2d( 62 | in_channels=self.up_g1.output_channels + context_channels // 2, 63 | out_channels=latent_channels * self.forecast_steps // 2, 64 | kernel_size=(1, 1), 65 | ) 66 | ) 67 | self.g2 = GBlock( 68 | input_channels=latent_channels * self.forecast_steps // 2, 69 | output_channels=latent_channels * self.forecast_steps // 2, 70 | ) 71 | self.up_g2 = UpsampleGBlock( 72 | input_channels=latent_channels * self.forecast_steps // 2, 73 | output_channels=latent_channels * self.forecast_steps // 4, 74 | ) 75 | 76 | self.gru_conv_1x1_3 = spectral_norm( 77 | torch.nn.Conv2d( 78 | in_channels=self.up_g2.output_channels + context_channels // 4, 79 | out_channels=latent_channels * self.forecast_steps // 4, 80 | kernel_size=(1, 1), 81 | ) 82 | ) 83 | self.g3 = GBlock( 84 | input_channels=latent_channels * self.forecast_steps // 4, 85 | output_channels=latent_channels * self.forecast_steps // 4, 86 | ) 87 | self.up_g3 = UpsampleGBlock( 88 | input_channels=latent_channels * self.forecast_steps // 4, 89 | output_channels=latent_channels * self.forecast_steps // 8, 90 | ) 91 | 92 | self.gru_conv_1x1_4 = spectral_norm( 93 | torch.nn.Conv2d( 94 | in_channels=self.up_g3.output_channels + context_channels // 8, 95 | out_channels=latent_channels * self.forecast_steps // 8, 96 | kernel_size=(1, 1), 97 | ) 98 | ) 99 | self.g4 = GBlock( 100 | input_channels=latent_channels * self.forecast_steps // 8, 101 | output_channels=latent_channels * self.forecast_steps // 8, 102 | ) 103 | self.up_g4 = UpsampleGBlock( 104 | input_channels=latent_channels * self.forecast_steps // 8, 105 | output_channels=latent_channels * self.forecast_steps // 16, 106 | ) 107 | 108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16) 109 | self.relu = torch.nn.ReLU() 110 | self.conv_1x1 = spectral_norm( 111 | torch.nn.Conv2d( 112 | in_channels=latent_channels * self.forecast_steps // 16, 113 | out_channels=4 * output_channels * self.forecast_steps, 114 | kernel_size=(1, 1), 115 | ) 116 | ) 117 | 118 | self.depth2space = PixelShuffle(upscale_factor=2) 119 | 120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor: 121 | """ 122 | Perform the sampling from Skillful Nowcasting with GANs 123 | Args: 124 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially 125 | latent_dim: Output from `LatentConditioningStack` for input into the ConvGRUs 126 | 127 | 
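            (In this one-shot variant latent_dim is not actually an argument;
            forward() consumes only the conditioning states.)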
Returns: 128 | forecast_steps-length output of images for future timesteps 129 | 130 | """ 131 | # Iterate through each forecast step 132 | # Initialize with conditioning state for first one, output for second one 133 | init_states = conditioning_states 134 | 135 | layer4_states = self.gru_conv_1x1(init_states[3]) 136 | layer4_states = self.g1(layer4_states) 137 | layer4_states = self.up_g1(layer4_states) 138 | 139 | # Layer 3. 140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1) 141 | layer3_states = self.gru_conv_1x1_2(layer3_states) 142 | layer3_states = self.g2(layer3_states) 143 | layer3_states = self.up_g2(layer3_states) 144 | 145 | # Layer 2. 146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1) 147 | layer2_states = self.gru_conv_1x1_3(layer2_states) 148 | layer2_states = self.g3(layer2_states) 149 | layer2_states = self.up_g3(layer2_states) 150 | 151 | # Layer 1 (top-most). 152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1) 153 | layer1_states = self.gru_conv_1x1_4(layer1_states) 154 | layer1_states = self.g4(layer1_states) 155 | layer1_states = self.up_g4(layer1_states) 156 | 157 | # Final stuff 158 | output_states = self.relu(self.bn(layer1_states)) 159 | output_states = self.conv_1x1(output_states) 160 | output_states = self.depth2space(output_states) 161 | 162 | # The satellite dimension was lost, add it back 163 | output_states = torch.unsqueeze(output_states, dim=2) 164 | 165 | return output_states 166 | 167 | 168 | class Generator(torch.nn.Module, PyTorchModelHubMixin): 169 | def __init__( 170 | self, 171 | conditioning_stack: torch.nn.Module, 172 | sampler: torch.nn.Module, 173 | ): 174 | """ 175 | Wraps the three parts of the generator for simpler calling 176 | Args: 177 | conditioning_stack: 178 | latent_stack: 179 | sampler: 180 | """ 181 | super().__init__() 182 | self.conditioning_stack = conditioning_stack 183 | self.sampler = sampler 184 | 185 | def forward(self, x): 186 | conditioning_states = self.conditioning_stack(x) 187 | x = self.sampler(conditioning_states) 188 | return x 189 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 22 | cached_download = None 23 | 24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | [More information 
needed] 61 | 62 | ## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | local_files_only=local_files_only, 154 | ) 155 | model = cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, w, c].""" 9 | 10 | # Reshape 3D tensors 
to 2D tensor with first dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__( 10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001 11 | ): 12 | """Constructor. 13 | 14 | Args: 15 | kernel_size: kernel size of the convolutions. Default: 3. 16 | sn_eps: constant for spectral normalization. Default: 1e-4. 
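            input_channels: channels of the concatenated [x, prev_state] input
                seen by the gate convolutions (channels of x plus channels of
                the hidden state).
            output_channels: width of the gates and of the hidden state carried
                between steps.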
17 | """ 18 | super().__init__() 19 | self._kernel_size = kernel_size 20 | self._sn_eps = sn_eps 21 | self.read_gate_conv = spectral_norm( 22 | torch.nn.Conv2d( 23 | in_channels=input_channels, 24 | out_channels=output_channels, 25 | kernel_size=(kernel_size, kernel_size), 26 | padding=1, 27 | ), 28 | eps=sn_eps, 29 | ) 30 | self.update_gate_conv = spectral_norm( 31 | torch.nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=output_channels, 34 | kernel_size=(kernel_size, kernel_size), 35 | padding=1, 36 | ), 37 | eps=sn_eps, 38 | ) 39 | self.output_conv = spectral_norm( 40 | torch.nn.Conv2d( 41 | in_channels=input_channels, 42 | out_channels=output_channels, 43 | kernel_size=(kernel_size, kernel_size), 44 | padding=1, 45 | ), 46 | eps=sn_eps, 47 | ) 48 | 49 | def forward(self, x, prev_state): 50 | """ 51 | ConvGRU forward, returning the current+new state 52 | 53 | Args: 54 | x: Input tensor 55 | prev_state: Previous state 56 | 57 | Returns: 58 | New tensor plus the new state 59 | """ 60 | # Concatenate the inputs and previous state along the channel axis. 61 | xh = torch.cat([x, prev_state], dim=1) 62 | 63 | # Read gate of the GRU. 64 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 65 | 66 | # Update gate of the GRU. 67 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 68 | 69 | # Gate the inputs. 70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 71 | 72 | # Gate the cell and state / outputs. 73 | c = F.relu(self.output_conv(gated_input)) 74 | out = update_gate * prev_state + (1.0 - update_gate) * c 75 | new_state = out 76 | 77 | return out, new_state 78 | 79 | 80 | class ConvGRU(torch.nn.Module): 81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 82 | 83 | def __init__( 84 | self, 85 | input_channels: int, 86 | output_channels: int, 87 | kernel_size: int = 3, 88 | sn_eps=0.0001, 89 | ): 90 | super().__init__() 91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 92 | 93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 94 | outputs = [] 95 | for step in range(len(x)): 96 | # Compute current timestep 97 | output, hidden_state = self.cell(x[step], hidden_state) 98 | outputs.append(output) 99 | # Stack outputs to return as tensor 100 | outputs = torch.stack(outputs, dim=0) 101 | return outputs 102 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [ 31 | input_tensor, 32 | xx_channel.type_as(input_tensor), 33 | yy_channel.type_as(input_tensor), 34 | ], 35 | dim=1, 36 | 
) 37 | 38 | if self.with_r: 39 | rr = torch.sqrt( 40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 42 | ) 43 | ret = torch.cat([ret, rr], dim=1) 44 | 45 | return ret 46 | 47 | 48 | class CoordConv(nn.Module): 49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 50 | super().__init__() 51 | self.addcoords = AddCoords(with_r=with_r) 52 | in_size = in_channels + 2 53 | if with_r: 54 | in_size += 1 55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 56 | 57 | def forward(self, x): 58 | ret = self.addcoords(x) 59 | ret = self.conv(ret) 60 | return ret 61 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot-multichannel/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/dgmr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.losses import ( 3 | NowcastingLoss, 4 | GridCellLoss, 5 | loss_hinge_disc, 6 | loss_hinge_gen, 7 | grid_cell_regularizer, 8 | ) 9 | import pytorch_lightning as pl 10 | import torchvision 11 | from dgmr.common import LatentConditioningStack, ContextConditioningStack 12 | from dgmr.generators import Sampler, Generator 13 | from dgmr.discriminators import Discriminator 14 | from dgmr.hub import NowcastingModelHubMixin 15 | 16 | 17 | class DGMR(pl.LightningModule, NowcastingModelHubMixin): 18 | """Deep Generative Model of Radar""" 19 | 20 | def __init__( 21 | self, 22 | forecast_steps: int = 18, 23 | input_channels: int = 1, 24 | output_shape: int = 256, 25 | gen_lr: float = 5e-5, 26 | disc_lr: float = 2e-4, 27 | visualize: bool = False, 28 | conv_type: str = "standard", 29 | num_samples: int = 6, 30 | grid_lambda: float = 20.0, 31 | beta1: float = 0.0, 32 | beta2: float = 0.999, 33 | latent_channels: int = 768, 34 | context_channels: int = 384, 35 | **kwargs, 36 | ): 37 | """ 38 | Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954 39 | but slightly modified for multiple satellite channels 40 | 41 | Args: 42 | forecast_steps: Number of steps 
to predict in the future 43 | input_channels: Number of input channels per image 44 | visualize: Whether to visualize output during training 45 | gen_lr: Learning rate for the generator 46 | disc_lr: Learning rate for the discriminators, shared for both temporal and spatial discriminator 47 | conv_type: Type of 2d convolution to use, see satflow/models/utils.py for options 48 | beta1: Beta1 for Adam optimizer 49 | beta2: Beta2 for Adam optimizer 50 | num_samples: Number of samples of the latent space to sample for training/validation 51 | grid_lambda: Lambda for the grid regularization loss 52 | output_shape: Shape of the output predictions, generally should be same as the input shape 53 | latent_channels: Number of channels that the latent space should be reshaped to, 54 | input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs 55 | pretrained: 56 | """ 57 | super().__init__() 58 | config = locals() 59 | config.pop("__class__") 60 | config.pop("self") 61 | self.config = kwargs.get("config", config) 62 | input_channels = self.config["input_channels"] 63 | forecast_steps = self.config["forecast_steps"] 64 | output_shape = self.config["output_shape"] 65 | gen_lr = self.config["gen_lr"] 66 | disc_lr = self.config["disc_lr"] 67 | conv_type = self.config["conv_type"] 68 | num_samples = self.config["num_samples"] 69 | grid_lambda = self.config["grid_lambda"] 70 | beta1 = self.config["beta1"] 71 | beta2 = self.config["beta2"] 72 | latent_channels = self.config["latent_channels"] 73 | context_channels = self.config["context_channels"] 74 | visualize = self.config["visualize"] 75 | self.gen_lr = gen_lr 76 | self.disc_lr = disc_lr 77 | self.beta1 = beta1 78 | self.beta2 = beta2 79 | self.discriminator_loss = NowcastingLoss() 80 | self.grid_regularizer = GridCellLoss() 81 | self.grid_lambda = grid_lambda 82 | self.num_samples = num_samples 83 | self.visualize = visualize 84 | self.latent_channels = latent_channels 85 | self.context_channels = context_channels 86 | self.input_channels = input_channels 87 | self.conditioning_stack = ContextConditioningStack( 88 | input_channels=input_channels, 89 | conv_type=conv_type, 90 | output_channels=self.context_channels, 91 | ) 92 | self.latent_stack = LatentConditioningStack( 93 | shape=(8 * self.input_channels, output_shape // 32, output_shape // 32), 94 | output_channels=self.latent_channels, 95 | ) 96 | self.sampler = Sampler( 97 | forecast_steps=forecast_steps, 98 | latent_channels=self.latent_channels, 99 | context_channels=self.context_channels, 100 | ) 101 | self.generator = Generator( 102 | self.conditioning_stack, self.latent_stack, self.sampler 103 | ) 104 | self.discriminator = Discriminator(input_channels) 105 | self.save_hyperparameters() 106 | 107 | self.global_iteration = 0 108 | 109 | # Important: This property activates manual optimization. 
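        # With automatic optimization disabled, training_step() below fetches
        # both optimizers via self.optimizers() and calls manual_backward() /
        # step() itself, which is what allows two discriminator updates per
        # generator update. set_detect_anomaly(True) just below is a debugging
        # aid and noticeably slows training.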
110 | self.automatic_optimization = False 111 | torch.autograd.set_detect_anomaly(True) 112 | 113 | def forward(self, x): 114 | x = self.generator(x) 115 | return x 116 | 117 | def training_step(self, batch, batch_idx): 118 | images, future_images = batch 119 | 120 | self.global_iteration += 1 121 | g_opt, d_opt = self.optimizers() 122 | ########################## 123 | # Optimize Discriminator # 124 | ########################## 125 | # Two discriminator steps per generator step 126 | for _ in range(2): 127 | predictions = self(images) 128 | # Cat along time dimension [B, T, C, H, W] 129 | generated_sequence = torch.cat([images, predictions], dim=1) 130 | real_sequence = torch.cat([images, future_images], dim=1) 131 | # Cat long batch for the real+generated 132 | concatenated_inputs = torch.cat([real_sequence, generated_sequence], dim=0) 133 | 134 | concatenated_outputs = self.discriminator(concatenated_inputs) 135 | 136 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1) 137 | discriminator_loss = loss_hinge_disc(score_generated, score_real) 138 | d_opt.zero_grad() 139 | self.manual_backward(discriminator_loss) 140 | d_opt.step() 141 | 142 | ###################### 143 | # Optimize Generator # 144 | ###################### 145 | predictions = [self(images) for _ in range(6)] 146 | grid_cell_reg = grid_cell_regularizer( 147 | torch.stack(predictions, dim=0), future_images 148 | ) 149 | # Concat along time dimension 150 | generated_sequence = [torch.cat([images, x], dim=1) for x in predictions] 151 | real_sequence = torch.cat([images, future_images], dim=1) 152 | # Cat long batch for the real+generated, for each example in the range 153 | # For each of the 6 examples 154 | generated_scores = [] 155 | for g_seq in generated_sequence: 156 | concatenated_inputs = torch.cat([real_sequence, g_seq], dim=0) 157 | concatenated_outputs = self.discriminator(concatenated_inputs) 158 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1) 159 | generated_scores.append(score_generated) 160 | generator_disc_loss = loss_hinge_gen(torch.cat(generated_scores, dim=0)) 161 | generator_loss = generator_disc_loss + self.grid_lambda * grid_cell_reg 162 | g_opt.zero_grad() 163 | self.manual_backward(generator_loss) 164 | g_opt.step() 165 | 166 | self.log_dict( 167 | { 168 | "train/d_loss": discriminator_loss, 169 | "train/g_loss": generator_loss, 170 | "train/grid_loss": grid_cell_reg, 171 | }, 172 | prog_bar=True, 173 | ) 174 | 175 | # generate images 176 | generated_images = self(images) 177 | # log sampled images 178 | if self.visualize: 179 | self.visualize_step( 180 | images, 181 | future_images, 182 | generated_images, 183 | self.global_iteration, 184 | step="train", 185 | ) 186 | 187 | def configure_optimizers(self): 188 | b1 = self.beta1 189 | b2 = self.beta2 190 | 191 | opt_g = torch.optim.Adam( 192 | self.generator.parameters(), lr=self.gen_lr, betas=(b1, b2) 193 | ) 194 | opt_d = torch.optim.Adam( 195 | self.discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2) 196 | ) 197 | 198 | return [opt_g, opt_d], [] 199 | 200 | def visualize_step( 201 | self, 202 | x: torch.Tensor, 203 | y: torch.Tensor, 204 | y_hat: torch.Tensor, 205 | batch_idx: int, 206 | step: str, 207 | ) -> None: 208 | # the logger you used (in this case tensorboard) 209 | tensorboard = self.logger.experiment[0] 210 | # Timesteps per channel 211 | images = x[0].cpu().detach() 212 | future_images = y[0].cpu().detach() 213 | generated_images = y_hat[0].cpu().detach() 214 | for i, t in 
enumerate(images): # Now would be (C, H, W) 215 | t = [torch.unsqueeze(img, dim=0) for img in t] 216 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 217 | tensorboard.add_image( 218 | f"{step}/Input_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx 219 | ) 220 | t = [torch.unsqueeze(img, dim=0) for img in future_images[i]] 221 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 222 | tensorboard.add_image( 223 | f"{step}/Target_Image_Frame_{i}", image_grid, global_step=batch_idx 224 | ) 225 | t = [torch.unsqueeze(img, dim=0) for img in generated_images[i]] 226 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels) 227 | tensorboard.add_image( 228 | f"{step}/Generated_Image_Frame_{i}", image_grid, global_step=batch_idx 229 | ) 230 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/generators.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.pixelshuffle import PixelShuffle 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | from typing import List 7 | from dgmr.common import GBlock, UpsampleGBlock 8 | from dgmr.layers import ConvGRU 9 | from huggingface_hub import PyTorchModelHubMixin 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.WARN) 14 | 15 | 16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin): 17 | def __init__( 18 | self, 19 | forecast_steps: int = 18, 20 | context_channels: int = 384, 21 | latent_channels: int = 384, 22 | output_channels: int = 1, 23 | **kwargs 24 | ): 25 | """ 26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf 27 | 28 | The sampler takes the output from the Latent and Context conditioning stacks and 29 | creates one stack of ConvGRU layers per future timestep. 
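        As in the other one-shot experiment, the ConvGRU stack described above
        is replaced by 1x1 convolutions sized latent_channels * forecast_steps,
        and forward() takes only the conditioning states (no latent_dim input).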
30 | Args: 31 | forecast_steps: Number of forecast steps 32 | latent_channels: Number of input channels to the lowest ConvGRU layer 33 | """ 34 | super().__init__() 35 | config = locals() 36 | config.pop("__class__") 37 | config.pop("self") 38 | self.config = kwargs.get("config", config) 39 | self.forecast_steps = self.config["forecast_steps"] 40 | latent_channels = self.config["latent_channels"] 41 | context_channels = self.config["context_channels"] 42 | output_channels = self.config["output_channels"] 43 | 44 | self.gru_conv_1x1 = spectral_norm( 45 | torch.nn.Conv2d( 46 | in_channels=context_channels, 47 | out_channels=latent_channels * self.forecast_steps, 48 | kernel_size=(1, 1), 49 | ) 50 | ) 51 | self.g1 = GBlock( 52 | input_channels=latent_channels * self.forecast_steps, 53 | output_channels=latent_channels * self.forecast_steps, 54 | ) 55 | self.up_g1 = UpsampleGBlock( 56 | input_channels=latent_channels * self.forecast_steps, 57 | output_channels=latent_channels * self.forecast_steps // 2, 58 | ) 59 | 60 | self.gru_conv_1x1_2 = spectral_norm( 61 | torch.nn.Conv2d( 62 | in_channels=self.up_g1.output_channels + context_channels // 2, 63 | out_channels=latent_channels * self.forecast_steps // 2, 64 | kernel_size=(1, 1), 65 | ) 66 | ) 67 | self.g2 = GBlock( 68 | input_channels=latent_channels * self.forecast_steps // 2, 69 | output_channels=latent_channels * self.forecast_steps // 2, 70 | ) 71 | self.up_g2 = UpsampleGBlock( 72 | input_channels=latent_channels * self.forecast_steps // 2, 73 | output_channels=latent_channels * self.forecast_steps // 4, 74 | ) 75 | 76 | self.gru_conv_1x1_3 = spectral_norm( 77 | torch.nn.Conv2d( 78 | in_channels=self.up_g2.output_channels + context_channels // 4, 79 | out_channels=latent_channels * self.forecast_steps // 4, 80 | kernel_size=(1, 1), 81 | ) 82 | ) 83 | self.g3 = GBlock( 84 | input_channels=latent_channels * self.forecast_steps // 4, 85 | output_channels=latent_channels * self.forecast_steps // 4, 86 | ) 87 | self.up_g3 = UpsampleGBlock( 88 | input_channels=latent_channels * self.forecast_steps // 4, 89 | output_channels=latent_channels * self.forecast_steps // 8, 90 | ) 91 | 92 | self.gru_conv_1x1_4 = spectral_norm( 93 | torch.nn.Conv2d( 94 | in_channels=self.up_g3.output_channels + context_channels // 8, 95 | out_channels=latent_channels * self.forecast_steps // 8, 96 | kernel_size=(1, 1), 97 | ) 98 | ) 99 | self.g4 = GBlock( 100 | input_channels=latent_channels * self.forecast_steps // 8, 101 | output_channels=latent_channels * self.forecast_steps // 8, 102 | ) 103 | self.up_g4 = UpsampleGBlock( 104 | input_channels=latent_channels * self.forecast_steps // 8, 105 | output_channels=latent_channels * self.forecast_steps // 16, 106 | ) 107 | 108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16) 109 | self.relu = torch.nn.ReLU() 110 | self.conv_1x1 = spectral_norm( 111 | torch.nn.Conv2d( 112 | in_channels=latent_channels * self.forecast_steps // 16, 113 | out_channels=4 * output_channels * self.forecast_steps, 114 | kernel_size=(1, 1), 115 | ) 116 | ) 117 | 118 | self.depth2space = PixelShuffle(upscale_factor=2) 119 | 120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor: 121 | """ 122 | Perform the sampling from Skillful Nowcasting with GANs 123 | Args: 124 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially 125 | latent_dim: Output from `LatentConditioningStack` for input into the ConvGRUs 126 | 127 | 
Returns: 128 | forecast_steps-length output of images for future timesteps 129 | 130 | """ 131 | # Iterate through each forecast step 132 | # Initialize with conditioning state for first one, output for second one 133 | init_states = conditioning_states 134 | 135 | layer4_states = self.gru_conv_1x1(init_states[3]) 136 | layer4_states = self.g1(layer4_states) 137 | layer4_states = self.up_g1(layer4_states) 138 | 139 | # Layer 3. 140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1) 141 | layer3_states = self.gru_conv_1x1_2(layer3_states) 142 | layer3_states = self.g2(layer3_states) 143 | layer3_states = self.up_g2(layer3_states) 144 | 145 | # Layer 2. 146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1) 147 | layer2_states = self.gru_conv_1x1_3(layer2_states) 148 | layer2_states = self.g3(layer2_states) 149 | layer2_states = self.up_g3(layer2_states) 150 | 151 | # Layer 1 (top-most). 152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1) 153 | layer1_states = self.gru_conv_1x1_4(layer1_states) 154 | layer1_states = self.g4(layer1_states) 155 | layer1_states = self.up_g4(layer1_states) 156 | 157 | # Final stuff 158 | output_states = self.relu(self.bn(layer1_states)) 159 | output_states = self.conv_1x1(output_states) 160 | output_states = self.depth2space(output_states) 161 | 162 | # The satellite dimension was lost, add it back 163 | output_states = torch.unsqueeze(output_states, dim=2) 164 | 165 | return output_states 166 | 167 | 168 | class Generator(torch.nn.Module, PyTorchModelHubMixin): 169 | def __init__( 170 | self, 171 | conditioning_stack: torch.nn.Module, 172 | sampler: torch.nn.Module, 173 | ): 174 | """ 175 | Wraps the three parts of the generator for simpler calling 176 | Args: 177 | conditioning_stack: 178 | latent_stack: 179 | sampler: 180 | """ 181 | super().__init__() 182 | self.conditioning_stack = conditioning_stack 183 | self.sampler = sampler 184 | 185 | def forward(self, x): 186 | conditioning_states = self.conditioning_stack(x) 187 | x = self.sampler(conditioning_states) 188 | return x 189 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 22 | cached_download = None 23 | 24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | [More information needed] 61 | 62 | 
## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | local_files_only=local_files_only, 154 | ) 155 | model = cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, w, c].""" 9 | 10 | # Reshape 3D tensors to 2D tensor with first 
dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__( 10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001 11 | ): 12 | """Constructor. 13 | 14 | Args: 15 | kernel_size: kernel size of the convolutions. Default: 3. 16 | sn_eps: constant for spectral normalization. Default: 1e-4. 
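input_channels: number of channels of the concatenated [input, previous state] tensor seen by each gate convolution.
output_channels: number of channels of the hidden state and of the returned output.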
17 | """ 18 | super().__init__() 19 | self._kernel_size = kernel_size 20 | self._sn_eps = sn_eps 21 | self.read_gate_conv = spectral_norm( 22 | torch.nn.Conv2d( 23 | in_channels=input_channels, 24 | out_channels=output_channels, 25 | kernel_size=(kernel_size, kernel_size), 26 | padding=1, 27 | ), 28 | eps=sn_eps, 29 | ) 30 | self.update_gate_conv = spectral_norm( 31 | torch.nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=output_channels, 34 | kernel_size=(kernel_size, kernel_size), 35 | padding=1, 36 | ), 37 | eps=sn_eps, 38 | ) 39 | self.output_conv = spectral_norm( 40 | torch.nn.Conv2d( 41 | in_channels=input_channels, 42 | out_channels=output_channels, 43 | kernel_size=(kernel_size, kernel_size), 44 | padding=1, 45 | ), 46 | eps=sn_eps, 47 | ) 48 | 49 | def forward(self, x, prev_state): 50 | """ 51 | ConvGRU forward, returning the current+new state 52 | 53 | Args: 54 | x: Input tensor 55 | prev_state: Previous state 56 | 57 | Returns: 58 | New tensor plus the new state 59 | """ 60 | # Concatenate the inputs and previous state along the channel axis. 61 | xh = torch.cat([x, prev_state], dim=1) 62 | 63 | # Read gate of the GRU. 64 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 65 | 66 | # Update gate of the GRU. 67 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 68 | 69 | # Gate the inputs. 70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 71 | 72 | # Gate the cell and state / outputs. 73 | c = F.relu(self.output_conv(gated_input)) 74 | out = update_gate * prev_state + (1.0 - update_gate) * c 75 | new_state = out 76 | 77 | return out, new_state 78 | 79 | 80 | class ConvGRU(torch.nn.Module): 81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 82 | 83 | def __init__( 84 | self, 85 | input_channels: int, 86 | output_channels: int, 87 | kernel_size: int = 3, 88 | sn_eps=0.0001, 89 | ): 90 | super().__init__() 91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 92 | 93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 94 | outputs = [] 95 | for step in range(len(x)): 96 | # Compute current timestep 97 | output, hidden_state = self.cell(x[step], hidden_state) 98 | outputs.append(output) 99 | # Stack outputs to return as tensor 100 | outputs = torch.stack(outputs, dim=0) 101 | return outputs 102 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 30 | [ 31 | input_tensor, 32 | xx_channel.type_as(input_tensor), 33 | yy_channel.type_as(input_tensor), 34 | ], 35 | dim=1, 36 | ) 37 | 38 | 
if self.with_r: 39 | rr = torch.sqrt( 40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 42 | ) 43 | ret = torch.cat([ret, rr], dim=1) 44 | 45 | return ret 46 | 47 | 48 | class CoordConv(nn.Module): 49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 50 | super().__init__() 51 | self.addcoords = AddCoords(with_r=with_r) 52 | in_size = in_channels + 2 53 | if with_r: 54 | in_size += 1 55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 56 | 57 | def forward(self, x): 58 | ret = self.addcoords(x) 59 | ret = self.conv(ret) 60 | return ret 61 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/dgmr-oneshot/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgmr import DGMR 2 | from .generators import Sampler, Generator 3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator 4 | from .common import LatentConditioningStack, ContextConditioningStack 5 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/generators.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.pixelshuffle import PixelShuffle 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | from typing import List 7 | from dgmr.common import GBlock, UpsampleGBlock 8 | from dgmr.layers import ConvGRU 9 | from huggingface_hub import PyTorchModelHubMixin 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.WARN) 14 | 15 | 16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin): 17 | def __init__( 18 | self, 19 | forecast_steps: int = 18, 20 | latent_channels: int = 768, 21 | context_channels: int = 384, 22 | output_channels: int = 1, 23 | **kwargs 24 | ): 25 | """ 26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf 27 | 28 | The sampler takes the output from the Latent and Context conditioning stacks and 29 | creates one stack of ConvGRU layers per future timestep. 
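In this implementation there are four ConvGRU stages, one per spatial scale of the context states; each stage's output is projected by a 1x1 convolution, refined by a GBlock, and upsampled by an UpsampleGBlock before feeding the next stage, and a final BatchNorm/ReLU/1x1 convolution followed by a depth-to-space (PixelShuffle) step produces the output frames.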
30 | Args: 31 | forecast_steps: Number of forecast steps 32 | latent_channels: Number of input channels to the lowest ConvGRU layer 33 | """ 34 | super().__init__() 35 | config = locals() 36 | config.pop("__class__") 37 | config.pop("self") 38 | self.config = kwargs.get("config", config) 39 | self.forecast_steps = self.config["forecast_steps"] 40 | latent_channels = self.config["latent_channels"] 41 | context_channels = self.config["context_channels"] 42 | output_channels = self.config["output_channels"] 43 | 44 | self.convGRU1 = ConvGRU( 45 | input_channels=latent_channels + context_channels, 46 | output_channels=context_channels, 47 | kernel_size=3, 48 | ) 49 | self.gru_conv_1x1 = spectral_norm( 50 | torch.nn.Conv2d( 51 | in_channels=context_channels, 52 | out_channels=latent_channels, 53 | kernel_size=(1, 1), 54 | ) 55 | ) 56 | self.g1 = GBlock( 57 | input_channels=latent_channels, output_channels=latent_channels 58 | ) 59 | self.up_g1 = UpsampleGBlock( 60 | input_channels=latent_channels, output_channels=latent_channels // 2 61 | ) 62 | 63 | self.convGRU2 = ConvGRU( 64 | input_channels=latent_channels // 2 + context_channels // 2, 65 | output_channels=context_channels // 2, 66 | kernel_size=3, 67 | ) 68 | self.gru_conv_1x1_2 = spectral_norm( 69 | torch.nn.Conv2d( 70 | in_channels=context_channels // 2, 71 | out_channels=latent_channels // 2, 72 | kernel_size=(1, 1), 73 | ) 74 | ) 75 | self.g2 = GBlock( 76 | input_channels=latent_channels // 2, output_channels=latent_channels // 2 77 | ) 78 | self.up_g2 = UpsampleGBlock( 79 | input_channels=latent_channels // 2, output_channels=latent_channels // 4 80 | ) 81 | 82 | self.convGRU3 = ConvGRU( 83 | input_channels=latent_channels // 4 + context_channels // 4, 84 | output_channels=context_channels // 4, 85 | kernel_size=3, 86 | ) 87 | self.gru_conv_1x1_3 = spectral_norm( 88 | torch.nn.Conv2d( 89 | in_channels=context_channels // 4, 90 | out_channels=latent_channels // 4, 91 | kernel_size=(1, 1), 92 | ) 93 | ) 94 | self.g3 = GBlock( 95 | input_channels=latent_channels // 4, output_channels=latent_channels // 4 96 | ) 97 | self.up_g3 = UpsampleGBlock( 98 | input_channels=latent_channels // 4, output_channels=latent_channels // 8 99 | ) 100 | 101 | self.convGRU4 = ConvGRU( 102 | input_channels=latent_channels // 8 + context_channels // 8, 103 | output_channels=context_channels // 8, 104 | kernel_size=3, 105 | ) 106 | self.gru_conv_1x1_4 = spectral_norm( 107 | torch.nn.Conv2d( 108 | in_channels=context_channels // 8, 109 | out_channels=latent_channels // 8, 110 | kernel_size=(1, 1), 111 | ) 112 | ) 113 | self.g4 = GBlock( 114 | input_channels=latent_channels // 8, output_channels=latent_channels // 8 115 | ) 116 | self.up_g4 = UpsampleGBlock( 117 | input_channels=latent_channels // 8, output_channels=latent_channels // 16 118 | ) 119 | 120 | self.bn = torch.nn.BatchNorm2d(latent_channels // 16) 121 | self.relu = torch.nn.ReLU() 122 | self.conv_1x1 = spectral_norm( 123 | torch.nn.Conv2d( 124 | in_channels=latent_channels // 16, 125 | out_channels=4 * output_channels, 126 | kernel_size=(1, 1), 127 | ) 128 | ) 129 | 130 | self.depth2space = PixelShuffle(upscale_factor=2) 131 | 132 | def forward( 133 | self, conditioning_states: List[torch.Tensor], latent_dim: torch.Tensor 134 | ) -> torch.Tensor: 135 | """ 136 | Perform the sampling from Skillful Nowcasting with GANs 137 | Args: 138 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially 139 | latent_dim: Output from 
`LatentConditioningStack` for input into the ConvGRUs 140 | 141 | Returns: 142 | forecast_steps-length output of images for future timesteps 143 | 144 | """ 145 | # Iterate through each forecast step 146 | # Initialize with conditioning state for first one, output for second one 147 | init_states = conditioning_states 148 | # Expand latent dim to match batch size 149 | latent_dim = einops.repeat( 150 | latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0] 151 | ) 152 | hidden_states = [latent_dim] * self.forecast_steps 153 | 154 | # Layer 4 (bottom most) 155 | hidden_states = self.convGRU1(hidden_states, init_states[3]) 156 | hidden_states = [self.gru_conv_1x1(h) for h in hidden_states] 157 | hidden_states = [self.g1(h) for h in hidden_states] 158 | hidden_states = [self.up_g1(h) for h in hidden_states] 159 | 160 | # Layer 3. 161 | hidden_states = self.convGRU2(hidden_states, init_states[2]) 162 | hidden_states = [self.gru_conv_1x1_2(h) for h in hidden_states] 163 | hidden_states = [self.g2(h) for h in hidden_states] 164 | hidden_states = [self.up_g2(h) for h in hidden_states] 165 | 166 | # Layer 2. 167 | hidden_states = self.convGRU3(hidden_states, init_states[1]) 168 | hidden_states = [self.gru_conv_1x1_3(h) for h in hidden_states] 169 | hidden_states = [self.g3(h) for h in hidden_states] 170 | hidden_states = [self.up_g3(h) for h in hidden_states] 171 | 172 | # Layer 1 (top-most). 173 | hidden_states = self.convGRU4(hidden_states, init_states[0]) 174 | hidden_states = [self.gru_conv_1x1_4(h) for h in hidden_states] 175 | hidden_states = [self.g4(h) for h in hidden_states] 176 | hidden_states = [self.up_g4(h) for h in hidden_states] 177 | 178 | # Output layer. 179 | hidden_states = [F.relu(self.bn(h)) for h in hidden_states] 180 | hidden_states = [self.conv_1x1(h) for h in hidden_states] 181 | hidden_states = [self.depth2space(h) for h in hidden_states] 182 | 183 | # Convert forecasts to a torch Tensor 184 | forecasts = torch.stack(hidden_states, dim=1) 185 | return forecasts 186 | 187 | 188 | class Generator(torch.nn.Module, PyTorchModelHubMixin): 189 | def __init__( 190 | self, 191 | conditioning_stack: torch.nn.Module, 192 | latent_stack: torch.nn.Module, 193 | sampler: torch.nn.Module, 194 | ): 195 | """ 196 | Wraps the three parts of the generator for simpler calling 197 | Args: 198 | conditioning_stack: 199 | latent_stack: 200 | sampler: 201 | """ 202 | super().__init__() 203 | self.conditioning_stack = conditioning_stack 204 | self.latent_stack = latent_stack 205 | self.sampler = sampler 206 | 207 | def forward(self, x): 208 | conditioning_states = self.conditioning_stack(x) 209 | latent_dim = self.latent_stack(x) 210 | x = self.sampler(conditioning_states, latent_dim) 211 | return x 212 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally Taken from https://github.com/rwightman/ 3 | 4 | https://github.com/rwightman/pytorch-image-models/ 5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | from functools import partial 12 | 13 | import torch 14 | 15 | 16 | try: 17 | from huggingface_hub import cached_download, hf_hub_url 18 | 19 | cached_download = partial(cached_download, library_name="dgmr") 20 | except ImportError: 21 | hf_hub_url = None 22 | cached_download = None 23 | 24 | from huggingface_hub import 
CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download 25 | 26 | MODEL_CARD_MARKDOWN = """--- 27 | license: mit 28 | tags: 29 | - nowcasting 30 | - forecasting 31 | - timeseries 32 | - remote-sensing 33 | - gan 34 | --- 35 | 36 | # {model_name} 37 | 38 | ## Model description 39 | 40 | [More information needed] 41 | 42 | ## Intended uses & limitations 43 | 44 | [More information needed] 45 | 46 | ## How to use 47 | 48 | [More information needed] 49 | 50 | ## Limitations and bias 51 | 52 | [More information needed] 53 | 54 | ## Training data 55 | 56 | [More information needed] 57 | 58 | ## Training procedure 59 | 60 | [More information needed] 61 | 62 | ## Evaluation results 63 | 64 | [More information needed] 65 | 66 | """ 67 | 68 | _logger = logging.getLogger(__name__) 69 | 70 | 71 | class NowcastingModelHubMixin(ModelHubMixin): 72 | """ 73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | """ 78 | Mixin for pl.LightningModule and Hugging Face 79 | 80 | Mix this class with your pl.LightningModule class to easily push / download 81 | the model via the Hugging Face Hub 82 | 83 | Example:: 84 | 85 | >>> from dgmr.hub import NowcastingModelHubMixin 86 | 87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin): 88 | ... def __init__(self, **kwargs): 89 | ... super().__init__() 90 | ... self.layer = ... 91 | ... def forward(self, ...) 92 | ... return ... 93 | 94 | >>> model = MyModel() 95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub 96 | 97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights 98 | >>> model = MyModel.from_pretrained("username/mymodel") 99 | """ 100 | 101 | def _create_model_card(self, path): 102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__) 103 | with open(os.path.join(path, "README.md"), "w") as f: 104 | f.write(model_card) 105 | 106 | def _save_config(self, module, save_directory): 107 | config = dict(module.hparams) 108 | path = os.path.join(save_directory, CONFIG_NAME) 109 | with open(path, "w") as f: 110 | json.dump(config, f) 111 | 112 | def _save_pretrained(self, save_directory: str, save_config: bool = True): 113 | # Save model weights 114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) 115 | model_to_save = self.module if hasattr(self, "module") else self 116 | torch.save(model_to_save.state_dict(), path) 117 | # Save model config 118 | if save_config and model_to_save.hparams: 119 | self._save_config(model_to_save, save_directory) 120 | # Save model card 121 | self._create_model_card(save_directory) 122 | 123 | @classmethod 124 | def _from_pretrained( 125 | cls, 126 | model_id, 127 | revision, 128 | cache_dir, 129 | force_download, 130 | proxies, 131 | resume_download, 132 | local_files_only, 133 | use_auth_token, 134 | map_location="cpu", 135 | strict=False, 136 | **model_kwargs, 137 | ): 138 | map_location = torch.device(map_location) 139 | 140 | if os.path.isdir(model_id): 141 | print("Loading weights from local directory") 142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) 143 | else: 144 | model_file = hf_hub_download( 145 | repo_id=model_id, 146 | filename=PYTORCH_WEIGHTS_NAME, 147 | revision=revision, 148 | cache_dir=cache_dir, 149 | force_download=force_download, 150 | proxies=proxies, 151 | resume_download=resume_download, 152 | use_auth_token=use_auth_token, 153 | local_files_only=local_files_only, 154 | ) 155 | model = 
cls(**model_kwargs["config"]) 156 | 157 | state_dict = torch.load(model_file, map_location=map_location) 158 | model.load_state_dict(state_dict, strict=strict) 159 | model.eval() 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/layers/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import einops 5 | 6 | 7 | def attention_einsum(q, k, v): 8 | """Apply the attention operator to tensors of shape [h, w, c].""" 9 | 10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w. 11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] 12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] 13 | 14 | # Einstein summation corresponding to the query * key operation. 15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1) 16 | 17 | # Einstein summation corresponding to the attention * value operation. 18 | out = torch.einsum("hwL, Lc->hwc", beta, v) 19 | return out 20 | 21 | 22 | class AttentionLayer(torch.nn.Module): 23 | """Attention Module""" 24 | 25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8): 26 | super(AttentionLayer, self).__init__() 27 | 28 | self.ratio_kq = ratio_kq 29 | self.ratio_v = ratio_v 30 | self.output_channels = output_channels 31 | self.input_channels = input_channels 32 | 33 | # Compute query, key and value using 1x1 convolutions. 34 | self.query = torch.nn.Conv2d( 35 | in_channels=input_channels, 36 | out_channels=self.output_channels // self.ratio_kq, 37 | kernel_size=(1, 1), 38 | padding="valid", 39 | bias=False, 40 | ) 41 | self.key = torch.nn.Conv2d( 42 | in_channels=input_channels, 43 | out_channels=self.output_channels // self.ratio_kq, 44 | kernel_size=(1, 1), 45 | padding="valid", 46 | bias=False, 47 | ) 48 | self.value = torch.nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=self.output_channels // self.ratio_v, 51 | kernel_size=(1, 1), 52 | padding="valid", 53 | bias=False, 54 | ) 55 | 56 | self.last_conv = torch.nn.Conv2d( 57 | in_channels=self.output_channels // 8, 58 | out_channels=self.output_channels, 59 | kernel_size=(1, 1), 60 | padding="valid", 61 | bias=False, 62 | ) 63 | 64 | # Learnable gain parameter 65 | self.gamma = nn.Parameter(torch.zeros(1)) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | # Compute query, key and value using 1x1 convolutions. 69 | query = self.query(x) 70 | key = self.key(x) 71 | value = self.value(x) 72 | # Apply the attention operation. 73 | out = [] 74 | for b in range(x.shape[0]): 75 | # Apply to each in batch 76 | out.append(attention_einsum(query[b], key[b], value[b])) 77 | out = torch.stack(out, dim=0) 78 | out = self.gamma * self.last_conv(out) 79 | # Residual connection. 80 | return out + x 81 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/layers/ConvGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.parametrizations import spectral_norm 4 | 5 | 6 | class ConvGRUCell(torch.nn.Module): 7 | """A ConvGRU implementation.""" 8 | 9 | def __init__(self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001): 10 | """Constructor. 
11 | 12 | Args: 13 | kernel_size: kernel size of the convolutions. Default: 3. 14 | sn_eps: constant for spectral normalization. Default: 1e-4. 15 | """ 16 | super().__init__() 17 | self._kernel_size = kernel_size 18 | self._sn_eps = sn_eps 19 | self.read_gate_conv = spectral_norm( 20 | torch.nn.Conv2d( 21 | in_channels=input_channels, 22 | out_channels=output_channels, 23 | kernel_size=(kernel_size, kernel_size), 24 | padding=1, 25 | ), 26 | eps=sn_eps, 27 | ) 28 | self.update_gate_conv = spectral_norm( 29 | torch.nn.Conv2d( 30 | in_channels=input_channels, 31 | out_channels=output_channels, 32 | kernel_size=(kernel_size, kernel_size), 33 | padding=1, 34 | ), 35 | eps=sn_eps, 36 | ) 37 | self.output_conv = spectral_norm( 38 | torch.nn.Conv2d( 39 | in_channels=input_channels, 40 | out_channels=output_channels, 41 | kernel_size=(kernel_size, kernel_size), 42 | padding=1, 43 | ), 44 | eps=sn_eps, 45 | ) 46 | 47 | def forward(self, x, prev_state): 48 | """ 49 | ConvGRU forward, returning the current+new state 50 | 51 | Args: 52 | x: Input tensor 53 | prev_state: Previous state 54 | 55 | Returns: 56 | New tensor plus the new state 57 | """ 58 | # Concatenate the inputs and previous state along the channel axis. 59 | xh = torch.cat([x, prev_state], dim=1) 60 | 61 | # Read gate of the GRU. 62 | read_gate = F.sigmoid(self.read_gate_conv(xh)) 63 | 64 | # Update gate of the GRU. 65 | update_gate = F.sigmoid(self.update_gate_conv(xh)) 66 | 67 | # Gate the inputs. 68 | gated_input = torch.cat([x, read_gate * prev_state], dim=1) 69 | 70 | # Gate the cell and state / outputs. 71 | c = F.relu(self.output_conv(gated_input)) 72 | out = update_gate * prev_state + (1.0 - update_gate) * c 73 | new_state = out 74 | 75 | return out, new_state 76 | 77 | 78 | class ConvGRU(torch.nn.Module): 79 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" 80 | 81 | def __init__( 82 | self, 83 | input_channels: int, 84 | output_channels: int, 85 | kernel_size: int = 3, 86 | sn_eps=0.0001, 87 | ): 88 | super().__init__() 89 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) 90 | 91 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: 92 | outputs = [] 93 | for step in range(len(x)): 94 | # Compute current timestep 95 | output, hidden_state = self.cell(x[step], hidden_state) 96 | outputs.append(output) 97 | # Stack outputs to return as tensor 98 | outputs = torch.stack(outputs, dim=0) 99 | return outputs 100 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/layers/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoords(nn.Module): 6 | def __init__(self, with_r=False): 7 | super().__init__() 8 | self.with_r = with_r 9 | 10 | def forward(self, input_tensor): 11 | """ 12 | Args: 13 | input_tensor: shape(batch, channel, x_dim, y_dim) 14 | """ 15 | batch_size, _, x_dim, y_dim = input_tensor.size() 16 | 17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 19 | 20 | xx_channel = xx_channel.float() / (x_dim - 1) 21 | yy_channel = yy_channel.float() / (y_dim - 1) 22 | 23 | xx_channel = xx_channel * 2 - 1 24 | yy_channel = yy_channel * 2 - 1 25 | 26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 28 | 29 | ret = torch.cat( 
30 | [input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)], 31 | dim=1, 32 | ) 33 | 34 | if self.with_r: 35 | rr = torch.sqrt( 36 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) 37 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) 38 | ) 39 | ret = torch.cat([ret, rr], dim=1) 40 | 41 | return ret 42 | 43 | 44 | class CoordConv(nn.Module): 45 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 46 | super().__init__() 47 | self.addcoords = AddCoords(with_r=with_r) 48 | in_size = in_channels + 2 49 | if with_r: 50 | in_size += 1 51 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 52 | 53 | def forward(self, x): 54 | ret = self.addcoords(x) 55 | ret = self.conv(ret) 56 | return ret 57 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Attention import AttentionLayer 2 | from .ConvGRU import ConvGRU 3 | from .CoordConv import CoordConv 4 | -------------------------------------------------------------------------------- /experiments/dgmr-original/dgmr/layers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmr.layers import CoordConv 3 | 4 | 5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: 6 | if conv_type == "standard": 7 | conv_layer = torch.nn.Conv2d 8 | elif conv_type == "coord": 9 | conv_layer = CoordConv 10 | elif conv_type == "3d": 11 | conv_layer = torch.nn.Conv3d 12 | else: 13 | raise ValueError(f"{conv_type} is not a recognized Conv method") 14 | return conv_layer 15 | -------------------------------------------------------------------------------- /figs/final_leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/figs/final_leaderboard.png -------------------------------------------------------------------------------- /figs/model_predictions.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/figs/model_predictions.gif --------------------------------------------------------------------------------
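A minimal usage sketch for the two layer modules defined above (dgmr/layers/ConvGRU.py and dgmr/layers/CoordConv.py). It is not part of the repository; it assumes one of the dgmr package directories (e.g. experiments/dgmr-original) is on the Python path, and the channel counts, spatial sizes and initial hidden state are illustrative values only.

import torch
from dgmr.layers import ConvGRU, CoordConv

# CoordConv: AddCoords appends two normalized coordinate channels, so a 3-channel
# input is convolved as a 5-channel tensor.
coord_conv = CoordConv(in_channels=3, out_channels=8, kernel_size=3, padding=1)
y = coord_conv(torch.randn(2, 3, 16, 16))  # -> (2, 8, 16, 16)

# ConvGRU: each step's input is concatenated with the previous hidden state along the
# channel axis, so input_channels must equal step channels + hidden-state channels.
gru = ConvGRU(input_channels=16, output_channels=8)
steps = [torch.randn(2, 8, 32, 32) for _ in range(3)]  # 3 timesteps, 8 channels each
h0 = torch.zeros(2, 8, 32, 32)                         # assumed initial hidden state
outputs = gru(steps, hidden_state=h0)                  # -> (3, 2, 8, 32, 32), one output per step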