├── .gitignore
├── README.md
├── best-submission
│   ├── doxa_cli.py
│   ├── example.ipynb
│   ├── figures
│   │   └── lr_find.png
│   ├── inference_unet.ipynb
│   ├── submission
│   │   ├── climatehack.py
│   │   ├── doxa.yaml
│   │   ├── evaluate.py
│   │   ├── model.py
│   │   └── validate.py
│   ├── submit.sh
│   ├── test_and_visualize.ipynb
│   ├── timm.ipynb
│   ├── train.sh
│   ├── train_timm.py
│   ├── train_unet.ipynb
│   ├── train_unet.py
│   ├── train_unet_fast.ipynb
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── loss.py
│   │   ├── models.py
│   │   └── preprocess.py
│   └── validate.sh
├── common
│   ├── __init__.py
│   ├── checkpointing.py
│   ├── denoiser.py
│   ├── loss_utils.py
│   ├── q20.mat
│   └── utils.py
├── data
│   ├── download_data.ipynb
│   └── download_good_sun.ipynb
├── environment.yaml
├── experiments
│   ├── climatehack-submission
│   │   ├── .gitignore
│   │   ├── doxa_cli.py
│   │   ├── submission
│   │   │   ├── climatehack.py
│   │   │   ├── dgmr-oneshot
│   │   │   │   └── dgmr
│   │   │   │       ├── __init__.py
│   │   │   │       ├── common.py
│   │   │   │       ├── dgmr.py
│   │   │   │       ├── discriminators.py
│   │   │   │       ├── generators.py
│   │   │   │       ├── hub.py
│   │   │   │       ├── layers
│   │   │   │       │   ├── Attention.py
│   │   │   │       │   ├── ConvGRU.py
│   │   │   │       │   ├── CoordConv.py
│   │   │   │       │   ├── __init__.py
│   │   │   │       │   └── utils.py
│   │   │   │       └── losses.py
│   │   │   ├── doxa.yaml
│   │   │   ├── evaluate.py
│   │   │   └── validate.py
│   │   └── submit.sh
│   ├── dgmr-dct
│   │   └── dgmr
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── dgmr.py
│   │       ├── discriminators.py
│   │       ├── generators.py
│   │       ├── hub.py
│   │       ├── layers
│   │       │   ├── Attention.py
│   │       │   ├── ConvGRU.py
│   │       │   ├── CoordConv.py
│   │       │   ├── __init__.py
│   │       │   └── utils.py
│   │       └── losses.py
│   ├── dgmr-multichannel
│   │   └── dgmr
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── dgmr.py
│   │       ├── discriminators.py
│   │       ├── generators.py
│   │       ├── hub.py
│   │       ├── layers
│   │       │   ├── Attention.py
│   │       │   ├── ConvGRU.py
│   │       │   ├── CoordConv.py
│   │       │   ├── __init__.py
│   │       │   └── utils.py
│   │       └── losses.py
│   ├── dgmr-oneshot-multichannel
│   │   └── dgmr
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── dgmr.py
│   │       ├── discriminators.py
│   │       ├── generators.py
│   │       ├── hub.py
│   │       ├── layers
│   │       │   ├── Attention.py
│   │       │   ├── ConvGRU.py
│   │       │   ├── CoordConv.py
│   │       │   ├── __init__.py
│   │       │   └── utils.py
│   │       └── losses.py
│   ├── dgmr-oneshot
│   │   └── dgmr
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── dgmr.py
│   │       ├── discriminators.py
│   │       ├── generators.py
│   │       ├── hub.py
│   │       ├── layers
│   │       │   ├── Attention.py
│   │       │   ├── ConvGRU.py
│   │       │   ├── CoordConv.py
│   │       │   ├── __init__.py
│   │       │   └── utils.py
│   │       └── losses.py
│   ├── dgmr-original
│   │   └── dgmr
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── dgmr.py
│   │       ├── discriminators.py
│   │       ├── generators.py
│   │       ├── hub.py
│   │       ├── layers
│   │       │   ├── Attention.py
│   │       │   ├── ConvGRU.py
│   │       │   ├── CoordConv.py
│   │       │   ├── __init__.py
│   │       │   └── utils.py
│   │       └── losses.py
│   ├── optical_flow_hyperopt.ipynb
│   ├── plot_hist.ipynb
│   ├── train_gen.ipynb
│   ├── train_gen_dct.ipynb
│   ├── train_gen_oneshot-coordconv.ipynb
│   ├── train_gen_oneshot.ipynb
│   ├── train_gen_oneshot_optflow.ipynb
│   └── visualize_and_test.ipynb
└── figs
    ├── final_leaderboard.png
    └── model_predictions.gif
/.gitignore:
--------------------------------------------------------------------------------
1 | # OWN
2 | *.pt
3 | *.npz
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # Custom
136 | data/
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Climatehack
2 |
3 | This is the repository for Illinois's Climatehack Team. We earned first place on the [leaderboard](https://climatehack.ai/compete/leaderboard/universities) with a final score of 0.87992.
4 |
5 |
6 |
7 |
8 |
9 | An overview of our approach can be found [here](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing).
10 |
11 | Example predictions:
12 | ![Model predictions](figs/model_predictions.gif)
13 |
14 | # Setup
15 | ```bash
16 | conda env create -f environment.yaml
17 | conda activate climatehack
18 | python -m ipykernel install --user --name=climatehack
19 | ```
20 |
21 | First, download the data by running `data/download_data.ipynb`. Alternatively, you can find preprocessed data files [here](https://drive.google.com/drive/folders/1JkPKjOBtm3dlOl2fRTvaLkSu7KnZsJGw?usp=sharing). Save them into the `data` folder. We used `train.npz` and `test.npz`, which consist of data temporally cropped from 10am to 4pm UK time across the entire dataset. You could also use `data_good_sun_2020.npz` and `data_good_sun_2021.npz`, which consist of all samples where the sun elevation is at least 10 degrees. Because these crops produce datasets that fit in memory, all our dataloaders work fully in memory.
22 |
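For illustration, here is a minimal sketch of how the in-memory `.npz` files can be consumed (the actual `Dataset` classes live in `best-submission/utils/data.py`; the slicing below is illustrative rather than their exact implementation):

```python
import numpy as np
import torch

# train.npz / test.npz store `times` and `data` (time, y, x), as written by data/download_data.ipynb
arrays = np.load("data/train.npz")
frames = arrays["data"]  # e.g. (T, 325, 400), raw HRV values in [0, 1023]

# one illustrative sample: 12 input frames (128x128) -> 24 target frames (central 64x64)
t, y0, x0 = 0, 100, 100
x = torch.from_numpy(frames[t : t + 12, y0 : y0 + 128, x0 : x0 + 128].astype(np.float32))
y = torch.from_numpy(
    frames[t + 12 : t + 36, y0 + 32 : y0 + 96, x0 + 32 : x0 + 96].astype(np.float32)
)
print(x.shape, y.shape)  # torch.Size([12, 128, 128]) torch.Size([24, 64, 64])
```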
23 |
24 | # Best Submission
25 | Our best submission earned scores exceeding 0.85 on the Climatehack [leaderboard](https://climatehack.ai/compete/leaderboard). It is relatively simple: it uses the `fastai` library to pick a base model, optimizer, and learning rate scheduler. After some experimentation, we chose `xse_resnext50_deeper`, turned it into a U-Net, and trained it. More info is in the [slides](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing).
26 |
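The model itself is built with fastai's dynamic U-Net, as in `best-submission/submission/evaluate.py` and `best-submission/train_unet.py`:

```python
from functools import partial

from fastai.layers import Mish
from fastai.vision.all import create_unet_model
from fastai.vision.models.xresnet import xse_resnext50_deeper

# 12 input frames in, 24 forecast frames out; trained from scratch (no pretrained weights)
arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True)
model = create_unet_model(
    arch=arch,
    img_size=(128, 128),
    n_out=24,
    pretrained=False,
    n_in=12,
    self_attention=True,
)
```

Training minimizes an MS-SSIM loss on the central 64x64 crop (`best-submission/utils/loss.py`), so at inference time the 128x128 output is centre-cropped to the 64x64 prediction region.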
27 | To train:
28 | ```bash
29 | cd best-submission
30 | bash train.sh
31 | ```
32 |
33 | To submit, first move the trained model `xse_resnext50_deeper.pth` into `best-submission/submission`.
34 | ```bash
35 | cd best-submission
36 | python doxa_cli.py user login
37 | bash submit.sh
38 | ```
39 |
40 | Also, check out `best-submission/test_and_visualize.ipynb` to test the model and visualize results in a nice animation. This is how we produced the animations found in `figs/model_predictions.gif`.
41 |
42 | # Experiments
43 | We conducted several experiments that showed improvements over a strong baseline. The baseline was OpenClimateFix's skillful nowcasting [repo](https://github.com/openclimatefix/skillful_nowcasting), which is itself an implementation of DeepMind's precipitation-forecasting GAN. This baseline is more-or-less copied to `experiments/dgmr-original`. One important difference is that instead of training the full GAN, we train only the generator: the generator alone was already performing well for us, and full GAN training converged much more slowly. This baseline will train to a score greater than 0.8 on the Climatehack [leaderboard](https://climatehack.ai/compete/leaderboard). We didn't have time to properly test these experiments on top of our best model, but we suspect they would improve results. The experiments are summarized below:
44 |
45 | Experiment | Description | Results |
46 | --- | --- | --- |
47 | DCT-Trick | Inspired by [this](https://proceedings.neurips.cc/paper/2018/file/7af6266cc52234b5aa339b16695f7fc4-Paper.pdf), we use the DCT to turn 128x128 -> 64x16x16 and the IDCT to turn 64x16x16 -> 128x128. This leads to a shallower network that is autoregressive at fewer spatial resolutions. We believe this is the first time this has been done with U-Nets. A fast implementation is in `common/utils.py:create_conv_dct_filter` and `common/utils.py:get_idct_filter`; a sketch of the idea appears below the table. | 1.8-2x speedup, small <0.005 performance drop |
48 | Denoising | We noticed a lot of blocky artifacts in predictions, reminiscent of JPEG/H.264 compression artifacts. We show a comparison of these artifacts in the [slides](https://docs.google.com/presentation/d/1P_cv3R7gTRXG41wFPXT2lZe9E1GnKqtaJVqe-vsAvL0/edit?usp=sharing). We use a pretrained neural network (`common/denoiser.py`) to fix them. This can definitely be done better, but it serves as a proof of concept. | No performance drop, small visual improvement. The slides have an example. |
49 | CoordConv | Meteorological phenomena are correlated with geographic coordinates. We add 2 input channels for the geographic coordinates in OSGB form. | +0.0072 MS-SSIM improvement |
50 | Optical Flow | Optical flow does well for the first few timesteps. We add 2 input channels for the optical flow vectors. | +0.0034 MS-SSIM improvement |
51 |
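For a concrete picture of the DCT trick referenced in the table, here is a minimal, illustrative sketch of the blockwise-DCT-as-convolution idea (a plain reimplementation, not the optimized filters in `common/utils.py:create_conv_dct_filter`/`get_idct_filter`): an orthonormal 8x8 DCT-II basis maps each 128x128 frame to 64 channels of 16x16 coefficients, and a transposed convolution with the same filters inverts it.

```python
import math

import torch
import torch.nn.functional as F


def dct_conv_weight(block: int = 8) -> torch.Tensor:
    """Orthonormal 2D DCT-II basis as conv filters, shape (block*block, 1, block, block)."""
    filters = torch.zeros(block * block, 1, block, block)
    for u in range(block):
        for v in range(block):
            au = math.sqrt(1.0 / block) if u == 0 else math.sqrt(2.0 / block)
            av = math.sqrt(1.0 / block) if v == 0 else math.sqrt(2.0 / block)
            for px in range(block):
                for py in range(block):
                    filters[u * block + v, 0, px, py] = (
                        au * av
                        * math.cos((2 * px + 1) * u * math.pi / (2 * block))
                        * math.cos((2 * py + 1) * v * math.pi / (2 * block))
                    )
    return filters


weight = dct_conv_weight(8)                            # (64, 1, 8, 8)
frame = torch.randn(1, 1, 128, 128)                    # one satellite frame
coeffs = F.conv2d(frame, weight, stride=8)             # (1, 64, 16, 16): blockwise DCT
recon = F.conv_transpose2d(coeffs, weight, stride=8)   # (1, 1, 128, 128): blockwise IDCT
assert torch.allclose(recon, frame, atol=1e-4)
```

Because the basis is orthonormal, the transposed convolution with the same filters is an exact inverse, so the network can work autoregressively on the 64x16x16 coefficient tensor and only return to pixel space at the end.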
52 | The folder `experiments/climatehack-submission` was used to submit these experiments.
53 | ```bash
54 | cd experiments/climatehack-submission
55 | python doxa_cli.py user login
56 | bash submit.sh
57 | ```
58 |
59 | Use `experiments/visualize_and_test.ipynb` to test the model and visualize results in a nice animation.
60 |
--------------------------------------------------------------------------------
/best-submission/doxa_cli.py:
--------------------------------------------------------------------------------
1 | # NOTE:
2 | # This file downloads the doxa_cli to allow you to upload an agent.
3 | # You do NOT need to edit this file, use it as is.
4 |
5 | import sys
6 |
7 | if sys.version_info[0] != 3:
8 | print("Please run this script using python3")
9 | sys.exit(1)
10 |
11 | import json
12 | import os
13 | import platform
14 | import stat
15 | import subprocess
16 | import tarfile
17 | import urllib.error
18 | import urllib.request
19 |
20 |
21 | # Returns `windows`, `darwin` (macos) or `linux`
22 | def get_os():
23 | system = platform.system()
24 |
25 | if system == "Linux":
26 | # The exe version works better for WSL
27 | if "microsoft" in platform.platform():
28 | return "windows"
29 |
30 | return "linux"
31 | elif system == "Windows":
32 | return "windows"
33 | elif system == "Darwin":
34 | return "darwin"
35 | else:
36 | raise Exception(f"Unknown platform {system}")
37 |
38 |
39 | def get_bin_name():
40 | bin_name = "doxa_cli"
41 | if get_os() == "windows":
42 | bin_name = "doxa_cli.exe"
43 | return bin_name
44 |
45 |
46 | def get_bin_dir():
47 | return os.path.join(os.path.dirname(__file__), "bin")
48 |
49 |
50 | def get_binary():
51 | return os.path.join(get_bin_dir(), get_bin_name())
52 |
53 |
54 | def install_binary():
55 | match_release = None
56 | try:
57 | match_release = sys.argv[1]
58 | # Arguments are not required
59 | except IndexError:
60 | pass
61 |
62 | REPO_RELEASE_URL = "https://api.github.com/repos/louisdewar/doxa/releases/latest"
63 | try:
64 | f = urllib.request.urlopen(REPO_RELEASE_URL)
65 | except urllib.error.URLError:
66 | print("There was an SSL cert verification error")
67 | print(
68 | 'If you are on a mac and you have recently installed a new version of\
69 | python then you should navigate to "/Applications/Python {VERSION}/"'
70 | )
71 | print('Then run a script in that folder called "Install Certificates.command"')
72 | sys.exit(1)
73 |
74 | response = json.loads(f.read())
75 |
76 | print("Current version tag:", response["tag_name"])
77 |
78 | assets = [
79 | asset for asset in response["assets"] if asset["name"].endswith(".tar.gz")
80 | ]
81 |
82 | # Find the release for this OS
83 | match_release = get_os()
84 | try:
85 | asset_choice = next(asset for asset in assets if match_release in asset["name"])
86 | print(
87 | 'Automatically picked {} to download based on match "{}"\n'.format(
88 | asset_choice["name"], match_release
89 | )
90 | )
91 | except StopIteration:
92 | print('Couldn\'t find "{}" in releases'.format(match_release))
93 | sys.exit(1)
94 |
95 | download_url = asset_choice["browser_download_url"]
96 |
97 | # Folder where this script is + bin
98 | bin_dir = get_bin_dir()
99 |
100 | print("Downloading", asset_choice["name"], "to", bin_dir)
101 | print("({})".format(download_url))
102 |
103 | # Clear bin directory if it exists
104 | if not os.path.exists(bin_dir):
105 | os.mkdir(bin_dir)
106 |
107 | zip_path = os.path.join(bin_dir, asset_choice["name"])
108 |
109 | # Download zip file
110 | urllib.request.urlretrieve(download_url, zip_path)
111 |
112 | # Open and extract zip file
113 | tar_file = tarfile.open(zip_path)
114 | tar_file.extractall(bin_dir)
115 | tar_file.close()
116 |
117 | # Delete zip file
118 | os.remove(zip_path)
119 |
120 | # Path to the actual binary program (called doxa_cli or doxa_cli.exe)
121 | bin_name = get_bin_name()
122 | binary_path = os.path.join(bin_dir, bin_name)
123 |
124 | if not os.path.exists(binary_path):
125 | print(f"Couldn't find the binary file `{bin_name}` in the bin directory")
126 | print("This probably means that there was a problem with the download")
127 | sys.exit(1)
128 |
129 | if get_os() != "windows":
130 | # Make binary executable
131 | st = os.stat(binary_path)
132 | os.chmod(binary_path, st.st_mode | stat.S_IEXEC)
133 |
134 | # Run help
135 | print("Installed binary\n\n")
136 |
137 |
138 | def run_command(args):
139 | bin_path = get_binary()
140 |
141 | if not os.path.exists(bin_path):
142 | install_binary()
143 | subprocess.call([bin_path] + args)
144 |
145 |
146 | if __name__ == "__main__":
147 | run_command(sys.argv[1:])
148 |
--------------------------------------------------------------------------------
/best-submission/figures/lr_find.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/best-submission/figures/lr_find.png
--------------------------------------------------------------------------------
/best-submission/submission/climatehack.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | from typing import List, Tuple
4 |
5 | import numpy as np
6 |
7 |
8 | class BaseEvaluator:
9 | def __init__(self) -> None:
10 | self.setup()
11 |
12 | def setup(self):
13 | """Sets up anything required for evaluation, e.g. models."""
14 | pass
15 |
16 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
17 | """Makes a prediction for the next two hours of satellite imagery.
18 |
19 | Args:
20 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128)
21 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128)
22 |
23 | Returns:
24 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64)
25 | """
26 |
27 | raise NotImplementedError(
28 | "You need to extend this class to use your trained model(s)."
29 | )
30 |
31 | def _get_io_paths(self) -> Tuple[Path, Path]:
32 | """Gets the input and output directory paths from DOXA.
33 |
34 | Returns:
35 | Tuple[Path, Path]: The input and output paths
36 | """
37 | try:
38 | return Path(sys.argv[1]), Path(sys.argv[2])
39 | except IndexError:
40 | raise Exception(
41 | f"Run using: {sys.argv[0]} [input directory] [output directory]"
42 | )
43 |
44 | def _get_group_path(self) -> str:
45 | """Gets the path for the next group to be processed.
46 |
47 | Raises:
48 | ValueError: An unknown message was received from DOXA.
49 |
50 | Returns:
51 | str: The path of the next group.
52 | """
53 |
54 | msg = input()
55 | if not msg.startswith("Process "):
56 | raise ValueError(f"Unknown message {msg}")
57 |
58 | return msg[8:]
59 |
60 | def _evaluate_group(self, group: dict) -> List[np.ndarray]:
61 | """Evaluates a group of satellite image sequences using
62 | the user-implemented model(s).
63 |
64 | Args:
65 | group (dict): The OSGB and satellite imagery data.
66 |
67 | Returns:
68 | List[np.ndarray]: The predictions.
69 | """
70 |
71 | return [self.predict(*datum) for datum in zip(group["osgb"], group["data"])]
72 |
73 | def evaluate(self):
74 | """Evaluates the user's model on DOXA.
75 |
76 | Messages are sent and received through stdio.
77 |
78 | The input data is loaded from a directory in groups.
79 |
80 | The predictions are written to another directory in groups.
81 |
82 | Raises:
83 | Exception: An error occurred somewhere.
84 | """
85 |
86 | print("STARTUP")
87 | input_path, output_path = self._get_io_paths()
88 |
89 | # process test data groups
90 | while True:
91 | # load the data for the group DOXA requests
92 | group_path = self._get_group_path()
93 | group_data = np.load(input_path / group_path)
94 |
95 | # make predictions for this group
96 | try:
97 | predictions = self._evaluate_group(group_data)
98 | except Exception as err:
99 | raise Exception(f"Error while processing {group_path}: {str(err)}")
100 |
101 | # save the output group predictions
102 | np.savez(
103 | output_path / group_path,
104 | data=np.stack(predictions),
105 | )
106 | print(f"Exported {group_path}")
107 |
--------------------------------------------------------------------------------
/best-submission/submission/doxa.yaml:
--------------------------------------------------------------------------------
1 | language: python
2 | entrypoint: evaluate.py
3 |
--------------------------------------------------------------------------------
/best-submission/submission/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from functools import partial
5 |
6 | from climatehack import BaseEvaluator
7 |
8 | # from model import Model
9 |
10 | from fastai.vision.all import create_unet_model, models
11 | from fastai.vision.models.xresnet import *
12 | from fastai.layers import Mish
13 |
14 |
15 | class Evaluator(BaseEvaluator):
16 | def setup(self):
17 | """Sets up anything required for evaluation.
18 |
19 | In this case, it loads the trained model (in evaluation mode)."""
20 |
21 | arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True)
22 | self.model = create_unet_model(
23 | arch=arch,
24 | img_size=(128, 128),
25 | n_out=24,
26 | pretrained=False,
27 | n_in=12,
28 | self_attention=True,
29 | ).cpu()
30 |
31 | self.model.load_state_dict(
32 | torch.load("xse_resnext50_deeper.pth", map_location="cpu")
33 | )
34 | self.model.eval()
35 |
36 | # self.arcnn = ARCNN(weights).to("cpu").eval()  # disabled: ARCNN and weights are not defined in this file (the denoiser lives in common/denoiser.py)
37 |
38 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
39 | """Makes a prediction for the next two hours of satellite imagery.
40 |
41 | Args:
42 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128)
43 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128)
44 |
45 | Returns:
46 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64)
47 | """
48 |
49 | assert coordinates.shape == (2, 128, 128)
50 | assert data.shape == (12, 128, 128)
51 |
52 | with torch.no_grad():
53 |
54 | inp = torch.from_numpy(data.astype("float32")).unsqueeze(0)
55 | preds = self.model(inp)
56 | prediction = preds[0, :, 32:96, 32:96].numpy()  # centre 64x64 crop, matching training
57 | assert prediction.shape == (24, 64, 64)
58 | return prediction
59 |
60 |
61 | def main():
62 | evaluator = Evaluator()
63 | evaluator.evaluate()
64 |
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
--------------------------------------------------------------------------------
/best-submission/submission/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from fastai.vision.all import *
5 |
6 | #########################################
7 | # Improve this basic model! #
8 | #########################################
9 |
10 |
11 | class Model(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | self.layer1 = nn.Linear(in_features=12 * 128 * 128, out_features=256)
16 | self.layer2 = nn.Linear(in_features=256, out_features=256)
17 | self.layer3 = nn.Linear(in_features=256, out_features=24 * 64 * 64)
18 |
19 | def forward(self, features):
20 | x = features.view(-1, 12 * 128 * 128) / 1024.0
21 | x = torch.relu(self.layer1(x))
22 | x = torch.relu(self.layer2(x))
23 | x = torch.relu(self.layer3(x))
24 |
25 | return x.view(-1, 24, 64, 64) * 1024.0
26 |
27 |
28 | class CenterCrop(nn.Module):
29 | def __init__(self, size=(64, 64)):
30 | super().__init__()
31 |
32 | def forward(self, x):
33 | return x[:, :, 32:96, 32:96]
34 |
35 |
36 | class UNet(nn.Module):
37 | def __init__(self, in_channels=2, out_channels=24, weights=None):
38 | super(UNet, self).__init__()
39 | self.model = create_unet_model(
40 | arch=models.resnet50,
41 | img_size=(128, 128),
42 | n_out=out_channels,
43 | pretrained=True,
44 | n_in=in_channels,
45 | self_attention=True,
46 | )
47 |
48 | if weights:
49 | self.load_state_dict(torch.load(weights))
50 | self.model.layers.add_module("CenterCrop", CenterCrop())
51 |
52 | def forward(self, x):
53 | return self.model(x)
54 |
55 |
56 | class EnsembleNet(nn.Module):
57 |     def __init__(self, in_channels=12, out_channels=24, weights=None):
58 |         super(EnsembleNet, self).__init__()
59 |         # `weights` is assumed to be a list of state-dict paths, one per ensemble member
60 |         self.models = nn.ModuleList()
61 |         for weight_path in weights or []:
62 |             model = create_unet_model(
63 |                 arch=models.resnet50,
64 |                 img_size=(128, 128),
65 |                 n_out=out_channels,
66 |                 pretrained=True,
67 |                 n_in=in_channels,
68 |                 self_attention=True,
69 |             )
70 |             model.load_state_dict(torch.load(weight_path))
71 |             model.layers.add_module("CenterCrop", CenterCrop())
72 |             self.models.append(model)
73 | 
74 |     def forward(self, x):
75 |         # average the forecasts of all ensemble members
76 |         outputs = [model(x) for model in self.models]
77 |         return torch.stack(outputs, dim=1).mean(dim=1)
78 | 
--------------------------------------------------------------------------------
/best-submission/submission/validate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_msssim import MS_SSIM
3 | from torch import from_numpy
4 |
5 | from evaluate import Evaluator
6 |
7 |
8 | def main():
9 | features = np.load("../climatehack-submission/features.npz")
10 | targets = np.load("../climatehack-submission/targets.npz")
11 |
12 | criterion = MS_SSIM(data_range=1023.0, size_average=True, win_size=3, channel=1)
13 | evaluator = Evaluator()
14 |
15 | scores = [
16 | criterion(
17 | from_numpy(evaluator.predict(*datum)).view(24, 64, 64).unsqueeze(dim=1),
18 | from_numpy(target).view(24, 64, 64).unsqueeze(dim=1),
19 | ).item()
20 | for *datum, target in zip(features["osgb"], features["data"], targets["data"])
21 | ]
22 |
23 | print(f"Score: {np.mean(scores)} ({np.std(scores)})")
24 |
25 |
26 | if __name__ == "__main__":
27 | main()
28 |
--------------------------------------------------------------------------------
/best-submission/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -eEou pipefail
4 |
5 | python doxa_cli.py agent upload climatehack ./submission
6 |
--------------------------------------------------------------------------------
/best-submission/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=0 conda run -n climatehack --no-capture-output --live-stream python train_unet.py
3 |
--------------------------------------------------------------------------------
/best-submission/train_timm.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torchvision.models as models
4 | from torch.utils.data import DataLoader, ConcatDataset
5 | import xarray as xr
6 | import wandb
7 |
8 | # custom
9 | from utils.loss import MS_SSIMLoss
10 | from utils.data import ClimatehackDataset, CustomDataset
11 | from utils.models import create_timm_model, TimmUnet
12 |
13 | # %%
14 | from fastai.vision.all import *
15 | from fastai.callback.wandb import *
16 | from fastai.callback.tracker import *
17 | from fastai.distributed import *
18 | from fastai.vision.models.xresnet import *
19 | from fastai.optimizer import ranger
20 | from fastai.layers import Mish
21 |
22 | # %%
23 | ARCH = "efficientnetv2_xl"
24 | BATCH_SIZE = 32
25 | FORECAST = 24
26 |
27 | wandb.init(project="climatehack", group=ARCH)
28 | # %%
29 | # train_ds = ClimatehackDataset("data/data.npz", with_osgb=True)
30 | train_ds = CustomDataset("../data/train.npz")
31 | valid_ds = CustomDataset("../data/test.npz")
32 |
33 | train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
34 | valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
35 |
36 | dls = DataLoaders(train_loader, valid_loader, device=torch.device("cuda"))
37 | # %%
38 | criterion = MS_SSIMLoss(channels=FORECAST, crop=True)
39 |
40 | # %%
41 | # arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True)
42 | # model = create_unet_model(
43 | # arch=arch,
44 | # img_size=(128, 128),
45 | # n_out=24,
46 | # pretrained=False,
47 | # n_in=train_ds[0][0].shape[0],
48 | # self_attention=True,
49 | # )
50 |
51 | model = TimmUnet(
52 | encoder=ARCH,
53 | # img_size=(128, 128),
54 | n_in=train_ds[0][0].shape[0],
55 | n_out=train_ds[0][1].shape[0],
56 | pretrained=True,
57 | # bottleneck='conv',
58 | # self_attention=True,
59 | )
60 |
61 | # %%
62 | callbacks = [
63 | SaveModelCallback(monitor="train_loss", fname=ARCH),
64 | # ReduceLROnPlateau(monitor="train_loss"),
65 | # EarlyStoppingCallback(monitor="val_loss", patience=10, mode="min"),
66 | WandbCallback(),
67 | ]
68 | learn = Learner(
69 | dls,
70 | model,
71 | loss_func=criterion,
72 | cbs=callbacks,
73 | model_dir="checkpoints",
74 | opt_func=ranger,
75 | )
76 |
77 | # %%
78 | learn.fit_flat_cos(100, 1e-3)
79 |
--------------------------------------------------------------------------------
/best-submission/train_unet.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torchvision.models as models
4 | from torch.utils.data import DataLoader, ConcatDataset
5 | import xarray as xr
6 | import wandb
7 |
8 | # custom
9 | from utils.loss import MS_SSIMLoss
10 | from utils.data import ClimatehackDataset, CustomDataset
11 |
12 | # %%
13 | from fastai.vision.all import *
14 | from fastai.callback.wandb import *
15 | from fastai.callback.tracker import *
16 | from fastai.distributed import *
17 | from fastai.vision.models.xresnet import *
18 | from fastai.optimizer import ranger
19 | from fastai.layers import Mish
20 |
21 | # %%
22 | NAME = "xse_resnext50_deeper"
23 | BATCH_SIZE = 32
24 | FORECAST = 24
25 |
26 | wandb.init(project="climatehack", group=NAME)
27 | # %%
28 | # train_ds = ClimatehackDataset("data/data.npz", with_osgb=True)
29 | train_ds = CustomDataset("../data/train.npz")
30 | valid_ds = CustomDataset("../data/test.npz")
31 |
32 | train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
33 | valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
34 |
35 | dls = DataLoaders(train_loader, valid_loader, device=torch.device("cuda"))
36 | # %%
37 | criterion = MS_SSIMLoss(channels=FORECAST, crop=True)
38 |
39 | # %%
40 | arch = partial(xse_resnext50_deeper, act_cls=Mish, sa=True)
41 | model = create_unet_model(
42 | arch=arch,
43 | img_size=(128, 128),
44 | n_out=24,
45 | pretrained=False,
46 | n_in=train_ds[0][0].shape[0],
47 | self_attention=True,
48 | )
49 |
50 | # %%
51 | callbacks = [
52 | SaveModelCallback(monitor="train_loss", fname=NAME),
53 | ReduceLROnPlateau(monitor="train_loss", factor=2),
54 | # EarlyStoppingCallback(monitor="val_loss", patience=10, mode="min"),
55 | WandbCallback(),
56 | ]
57 | learn = Learner(
58 | dls,
59 | model,
60 | loss_func=criterion,
61 | cbs=callbacks,
62 | model_dir="checkpoints",
63 | opt_func=ranger,
64 | )
65 |
66 | # %%
67 | learn.fit_flat_cos(100, 1e-3)
68 |
--------------------------------------------------------------------------------
/best-submission/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/best-submission/utils/__init__.py
--------------------------------------------------------------------------------
/best-submission/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_msssim import MS_SSIM
4 |
5 |
6 | class MS_SSIMLoss(nn.Module):
7 | """Multi-Scale SSIM Loss"""
8 |
9 | def __init__(self, channels=1, crop=True, **kwargs):
10 | """
11 | Initialize
12 | Args:
13 | channels: number of image channels compared by MS-SSIM; crop: if True, score only the central 64x64 region
14 | **kwargs: Kwargs to pass through to MS_SSIM
15 | """
16 | super(MS_SSIMLoss, self).__init__()
17 | self.crop = crop
18 | self.ssim_module = MS_SSIM(
19 | data_range=1023.0, size_average=True, win_size=3, channel=channels, **kwargs
20 | )
21 |
22 | def forward(self, x: torch.Tensor, y: torch.Tensor):
23 | """
24 | Forward method
25 | Args:
26 | x: tensor one
27 | y: tensor two
28 | Returns: multi-scale SSIM Loss
29 | """
30 | if self.crop:
31 | return 1.0 - self.ssim_module(x[:, :, 32:96, 32:96], y[:, :, 32:96, 32:96])
32 | else:
33 | return 1.0 - self.ssim_module(x, y)
34 |
--------------------------------------------------------------------------------
/best-submission/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, time, timedelta
2 |
3 | import numpy as np
4 | import xarray as xr
5 | from numpy import float32
6 | from tqdm import tqdm
7 |
8 | SATELLITE_ZARR_PATH = "data/eumetsat_seviri_hrv_uk.zarr"
9 |
10 |
11 | def main():
12 | dataset = xr.open_dataset(
13 | SATELLITE_ZARR_PATH,
14 | engine="zarr",
15 | chunks="auto",
16 | )
17 |
18 | times = dataset.get_index("time")
19 | min_date = times[0].date()
20 | max_date = times[-1].date()
21 |
22 | start_time = time(9, 0)
23 | end_time = time(16, 0)
24 |
25 | data = []
26 | date = min_date
27 | print(min_date, max_date, (max_date - min_date).days)
28 | with tqdm(total=(max_date - min_date).days) as pbar:
29 |
30 | # if you only want to preprocess a certain number of days, swap the while loop for the for loop below!
31 | # for _ in range(10):
32 |
33 | while date <= max_date:
34 | print(date)
35 | selection = (
36 | dataset["data"].sel(
37 | time=slice(
38 | datetime.combine(date, start_time),
39 | datetime.combine(date, end_time),
40 | ),
41 | )
42 | # comment out the .isel if you want the whole image
43 | .isel(
44 | x=slice(550, 950),
45 | y=slice(375, 700),
46 | )
47 | )
48 |
49 | if selection.shape == (85, 325, 400):
50 | data.append(selection.astype(float32).values)
51 |
52 | date += timedelta(days=1)
53 | pbar.update(1)
54 |
55 | clipped_data = np.clip(np.stack(data), 0.0, 1023.0)
56 |
57 | x_osgb = (
58 | dataset["x_osgb"]
59 | # comment out the .isel if you want the whole image
60 | .isel(
61 | x=slice(550, 950),
62 | y=slice(375, 700),
63 | ).values.astype(float32)
64 | )
65 |
66 | y_osgb = (
67 | dataset["y_osgb"]
68 | # comment out the .isel if you want the whole image
69 | .isel(
70 | x=slice(550, 950),
71 | y=slice(375, 700),
72 | ).values.astype(float32)
73 | )
74 |
75 | # you can also use np.savez_compressed if you'd rather compress everything!
76 | np.savez("data/sample", x_osgb=x_osgb, y_osgb=y_osgb, data=clipped_data)
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
--------------------------------------------------------------------------------
/best-submission/validate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd submissionv2
3 | conda run -n climatehack --no-capture-output --live-stream python /home/iyaja/Git/climatehack/submissionv2/validate.py
4 | conda run -n climatehack --no-capture-output --live-stream python /home/iyaja/Git/climatehack/submissionv2/validate_timestep.py
5 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/common/__init__.py
--------------------------------------------------------------------------------
/common/checkpointing.py:
--------------------------------------------------------------------------------
1 | """
2 | I rolled my own checkpointing for no good reason. Oops.
3 | """
4 | import os
5 | import pathlib
6 | import torch
7 |
8 |
9 | class Checkpointer:
10 | def __init__(self, save_folder: pathlib.Path):
11 | self.save_folder = save_folder.absolute()
12 | self.hist = save_folder / "hist.txt"
13 | self.hist.touch()
14 |
15 | def get_best_model_info(self):
16 | """
17 | Returns info on the best model so far.
18 | """
19 | files = self.hist.read_text().split("\n")
20 | actual = [f for f in files if f.endswith(".pt")]
21 | best_fname = None
22 | best_loss = float("inf")
23 | best_epoch = None
24 | for fname in actual:
25 | if "batch" in fname:
26 | continue
27 | f = fname[:-3] # remove .pt
28 | _, epoch_str, loss_str = f.split("_")
29 | _, epoch = epoch_str.split("=")
30 | _, loss = loss_str.split("=")
31 | epoch = int(epoch)
32 | loss = float(loss)
33 | if loss < best_loss:
34 | best_loss = loss
35 | best_epoch = epoch
36 | best_fname = fname
37 | if best_fname is None:
38 | return None, None, None
39 | return self.save_folder / best_fname, best_epoch, best_loss
40 |
41 | def save_checkpoint(self, model, optimizer, epoch, avg_loss):
42 | """
43 | Save the model and optimizer state if the loss has decreased.
44 | """
45 | # see if we should save weights
46 | best_fpath, best_epoch, best_loss = self.get_best_model_info()
47 | if best_fpath is None or avg_loss < best_loss:
48 | print(f"Loss decreased from {best_loss} to {avg_loss}, saving...")
49 | checkpoint = {
50 | "model": model.state_dict(),
51 | "optimizer": optimizer.state_dict(),
52 | }
53 | avg_loss = round(avg_loss, 4)
54 | torch.save(
55 | checkpoint,
56 | self.save_folder / f"checkpoint_epochs={epoch}_loss={avg_loss}.pt",
57 | )
58 | if best_fpath is not None:
59 | print("Deleting old best:", best_fpath)
60 | os.remove(best_fpath)
61 |
62 | # save to history file regardless
63 | with self.hist.open("a") as f:
64 | f.write(f"checkpoint_epochs={epoch}_loss={avg_loss}.pt\n")
65 |
66 | def load_checkpoint(self):
67 | """
68 | Loads the best model and optimizer states.
69 | """
70 | best_fpath, best_epoch, best_loss = self.get_best_model_info()
71 | if best_fpath is None:
72 | return None, None
73 | checkpoint = torch.load(best_fpath)
74 | return checkpoint, best_epoch
75 |
--------------------------------------------------------------------------------
/common/denoiser.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple denoiser that uses a pretrained neural network to fix some of the block artifacts that were found in model predictions.
3 | """
4 | import os
5 | import scipy.io
6 | import pathlib
7 | import torch.nn as nn
8 | import torch
9 |
10 | mydir = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
11 |
12 |
13 | class ARCNN(nn.Module):
14 | def __init__(self, weight):
15 | super().__init__()
16 |
17 | # PyTorch's Conv2D uses zero-padding while the matlab code uses replicate
18 | # So we need to use separate padding modules
19 | self.conv1 = nn.Conv2d(1, 64, kernel_size=9)
20 | self.conv2 = nn.Conv2d(64, 32, kernel_size=7)
21 | self.conv22 = nn.Conv2d(32, 16, kernel_size=1)
22 | self.conv3 = nn.Conv2d(16, 1, kernel_size=5)
23 |
24 | self.pad2 = nn.ReplicationPad2d(2)
25 | self.pad3 = nn.ReplicationPad2d(3)
26 | self.pad4 = nn.ReplicationPad2d(4)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 |
30 | # Load the weights from the weight dict
31 | self.conv1.weight.data = torch.from_numpy(
32 | weight["weights_conv1"]
33 | .transpose(2, 0, 1)
34 | .reshape(64, 1, 9, 9)
35 | .transpose(0, 1, 3, 2)
36 | ).float()
37 | self.conv1.bias.data = torch.from_numpy(
38 | weight["biases_conv1"].reshape(64)
39 | ).float()
40 |
41 | self.conv2.weight.data = torch.from_numpy(
42 | weight["weights_conv2"]
43 | .transpose(2, 0, 1)
44 | .reshape(32, 64, 7, 7)
45 | .transpose(0, 1, 3, 2)
46 | ).float()
47 | self.conv2.bias.data = torch.from_numpy(
48 | weight["biases_conv2"].reshape(32)
49 | ).float()
50 |
51 | self.conv22.weight.data = torch.from_numpy(
52 | weight["weights_conv22"]
53 | .transpose(2, 0, 1)
54 | .reshape(16, 32, 1, 1)
55 | .transpose(0, 1, 3, 2)
56 | ).float()
57 | self.conv22.bias.data = torch.from_numpy(
58 | weight["biases_conv22"].reshape(16)
59 | ).float()
60 |
61 | self.conv3.weight.data = torch.from_numpy(
62 | weight["weights_conv3"].reshape(1, 16, 5, 5).transpose(0, 1, 3, 2)
63 | ).float()
64 | self.conv3.bias.data = torch.from_numpy(
65 | weight["biases_conv3"].reshape(1)
66 | ).float()
67 |
68 | def forward(self, x):
69 | x = self.pad4(x)
70 | x = self.relu(self.conv1(x))
71 |
72 | x = self.pad3(x)
73 | x = self.relu(self.conv2(x))
74 | x = self.relu(self.conv22(x))
75 |
76 | x = self.pad2(x)
77 | x = self.conv3(x)
78 |
79 | return x
80 |
81 |
82 | class Denoiser:
83 | """
84 | A neural network denoiser.
85 | """
86 |
87 | def __init__(self, device):
88 | weights = scipy.io.loadmat(mydir / "q20.mat")
89 | self.arcnn = ARCNN(weights).to(device).eval()
90 |
91 | def denoise(self, x):
92 | with torch.no_grad():
93 | b, t, h, w = x.shape
94 | x = x.reshape(-1, 1, 128, 128)
95 | x = self.arcnn(x)
96 | x = x.reshape(b, t, h, w)
97 | return x
98 |
--------------------------------------------------------------------------------
/common/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_msssim import MS_SSIM
4 |
5 |
6 | class MS_SSIMLoss(nn.Module):
7 | """Multi-Scale SSIM Loss"""
8 | def __init__(self, data_range, channels=1, **kwargs):
9 | """
10 | Initialize
11 | Args:
12 | data_range: dynamic range of the input images; channels: number of image channels compared by MS-SSIM
13 | **kwargs: Kwargs to pass through to MS_SSIM
14 | """
15 | super(MS_SSIMLoss, self).__init__()
16 | self.ssim_module = MS_SSIM(
17 | data_range=data_range, size_average=True, win_size=3, channel=channels, **kwargs
18 | )
19 |
20 | def forward(self, x: torch.Tensor, y: torch.Tensor):
21 | """
22 | Forward method
23 | Args:
24 | x: tensor one
25 | y: tensor two
26 | Returns: multi-scale SSIM Loss
27 | """
28 | return 1.0 - self.ssim_module(x, y)
29 |
--------------------------------------------------------------------------------
/common/q20.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/common/q20.mat
--------------------------------------------------------------------------------
/data/download_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "ee483ada",
6 | "metadata": {},
7 | "source": [
8 | " Creates `train.npz` and `test.npz` by making a simle temporal crop from 10am to 4pm across the full dataset."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "3a64f122",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import xarray as xr\n",
19 | "import numpy as np\n",
20 | "import pathlib\n",
21 | "import datetime\n",
22 | "import pandas as pd\n",
23 | "import tqdm"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "273c7e8b",
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "name": "stdout",
34 | "output_type": "stream",
35 | "text": [
36 | "\n",
37 | "Dimensions: (time: 173624, y: 891, x: 1843)\n",
38 | "Coordinates:\n",
39 | " * time (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00\n",
40 | " * x (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06\n",
41 | " x_osgb (y, x) float32 dask.array\n",
42 | " * y (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06\n",
43 | " y_osgb (y, x) float32 dask.array\n",
44 | "Data variables:\n",
45 | " data (time, y, x) int16 dask.array\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "SATELLITE_ZARR_PATH = \"gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr\"\n",
51 | "\n",
52 | "dataset = xr.open_dataset(\n",
53 | " SATELLITE_ZARR_PATH, \n",
54 | " engine=\"zarr\",\n",
55 | " chunks=\"auto\", # Load the data as a Dask array\n",
56 | ")\n",
57 | "\n",
58 | "print(dataset)\n"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 3,
64 | "id": "1ea7fb67",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "def get_day_slice(date):\n",
69 | " data_slice = dataset.loc[\n",
70 | " {\n",
71 | " # 10am to 4pm\n",
72 | " \"time\": slice(\n",
73 | " date + datetime.timedelta(hours=10),\n",
74 | " date + datetime.timedelta(hours=16),\n",
75 | " )\n",
76 | " }\n",
77 | " ].isel(\n",
78 | " x=slice(550, 950),\n",
79 | " y=slice(375, 700),\n",
80 | " )\n",
81 | " \n",
82 | " # sometimes there is no data\n",
83 | " if len(data_slice.time) == 0:\n",
84 | " return None\n",
85 | " return data_slice"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "id": "853691db",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "# it might be worth it to do this in batches of one year... a full download will take at least\n",
96 | "# 30 minutes if not hours\n",
97 | "start_date = datetime.datetime(2020, 1, 1)\n",
98 | "end_date = datetime.datetime(2021, 12, 31)\n",
99 | "\n",
100 | "cur = start_date\n",
101 | "days_to_get = []\n",
102 | "while cur != end_date + datetime.timedelta(days=1):\n",
103 | " days_to_get.append(cur)\n",
104 | " cur = cur + datetime.timedelta(days=1)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 5,
110 | "id": "2133c65e",
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stderr",
115 | "output_type": "stream",
116 | "text": [
117 | "100%|██████████| 731/731 [00:00<00:00, 817.82it/s]\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "slices = []\n",
123 | "for date in tqdm.tqdm(days_to_get):\n",
124 | " slc = get_day_slice(date)\n",
125 | " if slc is None:\n",
126 | " continue\n",
127 | " slices.append(slc)"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 6,
133 | "id": "0c36633f",
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "data": {
138 | "text/plain": [
139 | "606"
140 | ]
141 | },
142 | "execution_count": 6,
143 | "metadata": {},
144 | "output_type": "execute_result"
145 | }
146 | ],
147 | "source": [
148 | "len(slices)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "98b4495b",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "combined = xr.concat(slices, dim='time')"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "id": "e52d267f",
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "# takes a while\n",
169 | "times = combined['time'].to_numpy()\n",
170 | "x = combined['x'].to_numpy()\n",
171 | "x_osgb = combined['x_osgb'].to_numpy()\n",
172 | "y = combined['y'].to_numpy()\n",
173 | "y_osgb = combined['y_osgb'].to_numpy()\n",
174 | "%time data = combined['data'].to_numpy()"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "id": "a2f4d989",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "times.shape, data.shape\n"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "id": "0c797430",
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "# due to some weird things in how we originally made `train.npz` and `test.npz`,\n",
195 | "# this will not be equivalent to the datasets we linked in the drive. But they\n",
196 | "# should still work.\n",
197 | "test_days = 30\n",
198 | "\n",
199 | "np.random.seed(7)\n",
200 | "test_dates = np.random.choice(days_to_get, size=test_days, replace=False)\n",
201 | "\n",
202 | "test_indices = []\n",
203 | "train_indices = []\n",
204 | "for i, t in enumerate(times):\n",
205 | " d = pd.Timestamp(t).date()\n",
206 | " if d in test_dates:\n",
207 | " test_indices.append(i)\n",
208 | " else:\n",
209 | " train_indices.append(i)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "id": "8ffee4e0",
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "test_data = data[test_indices]\n",
220 | "test_times = times[test_indices]\n",
221 | "\n",
222 | "train_data = data[train_indices]\n",
223 | "train_times = times[train_indices]"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": null,
229 | "id": "4d1fa9e8",
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "# save data\n",
234 | "p = pathlib.Path(f'train.npz')\n",
235 | "if p.exists():\n",
236 | " raise ValueError(f'Path {p} already exists!')\n",
237 | "\n",
238 | "np.savez(\n",
239 | " p,\n",
240 | " times=train_times,\n",
241 | " data=train_data,\n",
242 | ")\n",
243 | "\n",
244 | "p = pathlib.Path(f'test.npz')\n",
245 | "if p.exists():\n",
246 | " raise ValueError(f'Path {p} already exists!')\n",
247 | "\n",
248 | "np.savez(\n",
249 | " p,\n",
250 | " times=test_times,\n",
251 | " data=test_data,\n",
252 | ")\n",
253 | "\n",
254 | "p = pathlib.Path(f'coords.npz')\n",
255 | "if p.exists():\n",
256 | " raise ValueError(f'Path {p} already exists!')\n",
257 | "\n",
258 | "np.savez(\n",
259 | " p,\n",
260 | " x=x,\n",
261 | " x_osgb=x_osgb,\n",
262 | " y=y,\n",
263 | " y_osgb=y_osgb,\n",
264 | ")"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "id": "e270bb59",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": []
274 | }
275 | ],
276 | "metadata": {
277 | "kernelspec": {
278 | "display_name": "climatehack",
279 | "language": "python",
280 | "name": "climatehack"
281 | },
282 | "language_info": {
283 | "codemirror_mode": {
284 | "name": "ipython",
285 | "version": 3
286 | },
287 | "file_extension": ".py",
288 | "mimetype": "text/x-python",
289 | "name": "python",
290 | "nbconvert_exporter": "python",
291 | "pygments_lexer": "ipython3",
292 | "version": "3.9.7"
293 | }
294 | },
295 | "nbformat": 4,
296 | "nbformat_minor": 5
297 | }
298 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: climatehack
2 | channels:
3 | - pvlib
4 | - pytorch
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=4.5=1_gnu
10 | - aiohttp=3.8.1=py39h7f8727e_1
11 | - aiosignal=1.2.0=pyhd3eb1b0_0
12 | - asciitree=0.3.3=py_2
13 | - asttokens=2.0.5=pyhd3eb1b0_0
14 | - async-timeout=4.0.1=pyhd3eb1b0_0
15 | - attrs=21.4.0=pyhd3eb1b0_0
16 | - backcall=0.2.0=pyhd3eb1b0_0
17 | - blas=1.0=mkl
18 | - blinker=1.4=py39h06a4308_0
19 | - bokeh=2.4.2=py39h06a4308_0
20 | - bottleneck=1.3.4=py39hce1f21e_0
21 | - brotli=1.0.9=he6710b0_2
22 | - brotlipy=0.7.0=py39h27cfd23_1003
23 | - bzip2=1.0.8=h7b6447c_0
24 | - c-ares=1.18.1=h7f8727e_0
25 | - ca-certificates=2022.3.18=h06a4308_0
26 | - cachetools=4.2.2=pyhd3eb1b0_0
27 | - cartopy=0.18.0=py39hc576cba_1
28 | - certifi=2021.10.8=py39h06a4308_2
29 | - cffi=1.15.0=py39hd667e15_1
30 | - cftime=1.5.1.1=py39hce1f21e_0
31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
32 | - click=8.0.4=py39h06a4308_0
33 | - cloudpickle=2.0.0=pyhd3eb1b0_0
34 | - cryptography=3.4.8=py39hd23ed53_0
35 | - cudatoolkit=11.3.1=h2bc3f7f_2
36 | - curl=7.80.0=h7f8727e_0
37 | - cycler=0.11.0=pyhd3eb1b0_0
38 | - cytoolz=0.11.0=py39h27cfd23_0
39 | - dask=2022.2.1=pyhd3eb1b0_0
40 | - dask-core=2022.2.1=pyhd3eb1b0_0
41 | - dataclasses=0.8=pyh6d0b6a4_7
42 | - dbus=1.13.18=hb2f20db_0
43 | - debugpy=1.5.1=py39h295c915_0
44 | - decorator=5.1.1=pyhd3eb1b0_0
45 | - distributed=2022.2.1=pyhd3eb1b0_0
46 | - entrypoints=0.3=py39h06a4308_0
47 | - ephem=4.1.2=py39h7f8727e_0
48 | - executing=0.8.3=pyhd3eb1b0_0
49 | - expat=2.4.4=h295c915_0
50 | - fasteners=0.16.3=pyhd3eb1b0_0
51 | - ffmpeg=4.3=hf484d3e_0
52 | - fontconfig=2.13.1=h6c09931_0
53 | - fonttools=4.25.0=pyhd3eb1b0_0
54 | - freetype=2.11.0=h70c0345_0
55 | - frozenlist=1.2.0=py39h7f8727e_0
56 | - fsspec=2022.2.0=pyhd3eb1b0_0
57 | - gcsfs=2022.2.0=pyhd8ed1ab_0
58 | - geos=3.8.0=he6710b0_0
59 | - giflib=5.2.1=h7b6447c_0
60 | - glib=2.69.1=h4ff587b_1
61 | - gmp=6.2.1=h2531618_2
62 | - gnutls=3.6.15=he1e5248_0
63 | - google-api-core=2.2.2=pyhd3eb1b0_0
64 | - google-auth=2.6.0=pyhd3eb1b0_0
65 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
66 | - google-cloud-core=2.2.2=pyhd3eb1b0_0
67 | - google-cloud-storage=2.1.0=pyh6c4a22f_0
68 | - google-crc32c=1.1.2=py39h27cfd23_0
69 | - google-resumable-media=1.3.1=pyhd3eb1b0_1
70 | - googleapis-common-protos=1.53.0=py39h06a4308_0
71 | - grpcio=1.42.0=py39hce63b2e_0
72 | - gst-plugins-base=1.14.0=h8213a91_2
73 | - gstreamer=1.14.0=h28cd5cc_2
74 | - h5py=3.6.0=py39ha0f2276_0
75 | - hdf4=4.2.13=h3ca952b_2
76 | - hdf5=1.10.6=hb1b8bf9_0
77 | - heapdict=1.0.1=pyhd3eb1b0_0
78 | - icu=58.2=he6710b0_3
79 | - idna=3.3=pyhd3eb1b0_0
80 | - intel-openmp=2021.4.0=h06a4308_3561
81 | - ipykernel=6.9.1=py39h06a4308_0
82 | - ipython=8.1.1=py39h06a4308_0
83 | - jedi=0.18.1=py39h06a4308_1
84 | - jinja2=3.0.3=pyhd3eb1b0_0
85 | - jpeg=9d=h7f8727e_0
86 | - jupyter_client=7.1.2=pyhd3eb1b0_0
87 | - jupyter_core=4.9.2=py39h06a4308_0
88 | - kiwisolver=1.3.2=py39h295c915_0
89 | - krb5=1.19.2=hac12032_0
90 | - lame=3.100=h7b6447c_0
91 | - lcms2=2.12=h3be6417_0
92 | - ld_impl_linux-64=2.35.1=h7274673_9
93 | - libcrc32c=1.1.1=he6710b0_2
94 | - libcurl=7.80.0=h0b77cf5_0
95 | - libedit=3.1.20210910=h7f8727e_0
96 | - libev=4.33=h7f8727e_1
97 | - libffi=3.3=he6710b0_2
98 | - libgcc-ng=9.3.0=h5101ec6_17
99 | - libgfortran-ng=7.5.0=ha8ba4b0_17
100 | - libgfortran4=7.5.0=ha8ba4b0_17
101 | - libgomp=9.3.0=h5101ec6_17
102 | - libiconv=1.15=h63c8f33_5
103 | - libidn2=2.3.2=h7f8727e_0
104 | - libllvm11=11.1.0=h3826bc1_1
105 | - libnetcdf=4.8.1=h42ceab0_1
106 | - libnghttp2=1.46.0=hce63b2e_0
107 | - libpng=1.6.37=hbc83047_0
108 | - libprotobuf=3.19.1=h4ff587b_0
109 | - libsodium=1.0.18=h7b6447c_0
110 | - libssh2=1.9.0=h1ba5d50_1
111 | - libstdcxx-ng=9.3.0=hd4cf53a_17
112 | - libtasn1=4.16.0=h27cfd23_0
113 | - libtiff=4.2.0=h85742a9_0
114 | - libunistring=0.9.10=h27cfd23_0
115 | - libuuid=1.0.3=h7f8727e_2
116 | - libuv=1.40.0=h7b6447c_0
117 | - libwebp=1.2.2=h55f646e_0
118 | - libwebp-base=1.2.2=h7f8727e_0
119 | - libxcb=1.14=h7b6447c_0
120 | - libxml2=2.9.12=h03d6c58_0
121 | - libzip=1.5.1=h8d318fa_1003
122 | - llvmlite=0.38.0=py39h4ff587b_0
123 | - locket=0.2.1=py39h06a4308_2
124 | - lz4-c=1.9.3=h295c915_1
125 | - markupsafe=2.0.1=py39h27cfd23_0
126 | - matplotlib=3.5.1=py39h06a4308_1
127 | - matplotlib-base=3.5.1=py39ha18d171_1
128 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
129 | - mkl=2021.4.0=h06a4308_640
130 | - mkl-service=2.4.0=py39h7f8727e_0
131 | - mkl_fft=1.3.1=py39hd3c417c_0
132 | - mkl_random=1.2.2=py39h51133e4_0
133 | - msgpack-python=1.0.2=py39hff7bd54_1
134 | - multidict=5.2.0=py39h7f8727e_2
135 | - munkres=1.1.4=py_0
136 | - ncurses=6.3=h7f8727e_2
137 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
138 | - netcdf4=1.5.7=py39ha0f2276_1
139 | - nettle=3.7.3=hbbd107a_1
140 | - numba=0.55.1=py39h51133e4_0
141 | - numcodecs=0.9.1=py39h295c915_0
142 | - numexpr=2.8.1=py39h6abb31d_0
143 | - numpy=1.21.2=py39h20f2e39_0
144 | - numpy-base=1.21.2=py39h79a1101_0
145 | - oauthlib=3.2.0=pyhd3eb1b0_0
146 | - openh264=2.1.1=h4ff587b_0
147 | - openssl=1.1.1n=h7f8727e_0
148 | - packaging=21.3=pyhd3eb1b0_0
149 | - pandas=1.4.1=py39h295c915_1
150 | - parso=0.8.3=pyhd3eb1b0_0
151 | - partd=1.2.0=pyhd3eb1b0_1
152 | - patsy=0.5.2=py39h06a4308_1
153 | - pcre=8.45=h295c915_0
154 | - pexpect=4.8.0=pyhd3eb1b0_3
155 | - pickleshare=0.7.5=pyhd3eb1b0_1003
156 | - pillow=9.0.1=py39h22f2fdc_0
157 | - pip=21.2.4=py39h06a4308_0
158 | - proj=7.0.1=h59a7b90_1
159 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
160 | - protobuf=3.19.1=py39h295c915_0
161 | - psutil=5.8.0=py39h27cfd23_1
162 | - ptyprocess=0.7.0=pyhd3eb1b0_2
163 | - pure_eval=0.2.2=pyhd3eb1b0_0
164 | - pvlib=0.9.0=py_1
165 | - pyasn1=0.4.8=pyhd3eb1b0_0
166 | - pyasn1-modules=0.2.8=py_0
167 | - pycparser=2.21=pyhd3eb1b0_0
168 | - pygments=2.11.2=pyhd3eb1b0_0
169 | - pyjwt=2.1.0=py39h06a4308_0
170 | - pyopenssl=20.0.1=pyhd8ed1ab_0
171 | - pyparsing=3.0.4=pyhd3eb1b0_0
172 | - pyqt=5.9.2=py39h2531618_6
173 | - pyshp=2.1.3=pyhd3eb1b0_0
174 | - pysocks=1.7.1=py39h06a4308_0
175 | - python=3.9.7=h12debd9_1
176 | - python-dateutil=2.8.2=pyhd3eb1b0_0
177 | - pytorch-mutex=1.0=cuda
178 | - pytz=2021.3=pyhd3eb1b0_0
179 | - pyyaml=6.0=py39h7f8727e_1
180 | - pyzmq=22.3.0=py39h295c915_2
181 | - qt=5.9.7=h5867ecd_1
182 | - readline=8.1.2=h7f8727e_1
183 | - requests=2.27.1=pyhd3eb1b0_0
184 | - requests-oauthlib=1.3.0=py_0
185 | - rsa=4.7.2=pyhd3eb1b0_1
186 | - scipy=1.7.3=py39hc147768_0
187 | - shapely=1.7.1=py39h1728cc4_0
188 | - sip=4.19.13=py39h295c915_0
189 | - six=1.16.0=pyhd3eb1b0_1
190 | - sortedcontainers=2.4.0=pyhd3eb1b0_0
191 | - sqlite=3.38.0=hc218d9a_0
192 | - stack_data=0.2.0=pyhd3eb1b0_0
193 | - statsmodels=0.13.2=py39h7f8727e_0
194 | - tbb=2021.5.0=hd09550d_0
195 | - tblib=1.7.0=pyhd3eb1b0_0
196 | - tk=8.6.11=h1ccaba5_0
197 | - toolz=0.11.2=pyhd3eb1b0_0
198 | - torchaudio=0.11.0=py39_cu113
199 | - tornado=6.1=py39h27cfd23_0
200 | - traitlets=5.1.1=pyhd3eb1b0_0
201 | - typing-extensions=4.1.1=hd3eb1b0_0
202 | - typing_extensions=4.1.1=pyh06a4308_0
203 | - tzdata=2021e=hda174b7_0
204 | - urllib3=1.26.8=pyhd3eb1b0_0
205 | - wcwidth=0.2.5=pyhd3eb1b0_0
206 | - wheel=0.37.1=pyhd3eb1b0_0
207 | - xarray=0.20.1=pyhd3eb1b0_1
208 | - xz=5.2.5=h7b6447c_0
209 | - yaml=0.2.5=h7b6447c_0
210 | - yarl=1.6.3=py39h27cfd23_0
211 | - zarr=2.8.1=pyhd3eb1b0_0
212 | - zeromq=4.3.4=h2531618_0
213 | - zict=2.0.0=pyhd3eb1b0_0
214 | - zlib=1.2.11=h7f8727e_4
215 | - zstd=1.4.9=haebb681_0
216 | - pip:
217 | - absl-py==1.0.0
218 | - black==22.1.0
219 | - blis==0.7.7
220 | - catalogue==2.0.7
221 | - cymem==2.0.6
222 | - einops==0.4.1
223 | - fastai==2.5.3
224 | - fastcore==1.3.29
225 | - fastdownload==0.0.5
226 | - fastprogress==1.0.2
227 | - filelock==3.6.0
228 | - future==0.18.2
229 | - huggingface-hub==0.4.0
230 | - hyperopt==0.2.7
231 | - importlib-metadata==4.11.3
232 | - joblib==1.1.0
233 | - langcodes==3.3.0
234 | - markdown==3.3.6
235 | - murmurhash==1.0.6
236 | - mypy-extensions==0.4.3
237 | - networkx==2.7.1
238 | - pathspec==0.9.0
239 | - pathy==0.6.1
240 | - platformdirs==2.5.1
241 | - preshed==3.0.6
242 | - py4j==0.10.9.5
243 | - pydantic==1.8.2
244 | - pydeprecate==0.3.1
245 | - pyproj==3.3.0
246 | - pytorch-lightning==1.5.10
247 | - pytorch-msssim==0.2.1
248 | - scikit-learn==1.0.2
249 | - setuptools==59.5.0
250 | - smart-open==5.2.1
251 | - spacy==3.2.3
252 | - spacy-legacy==3.0.9
253 | - spacy-loggers==1.0.1
254 | - srsly==2.4.2
255 | - tensorboard==2.8.0
256 | - tensorboard-data-server==0.6.1
257 | - tensorboard-plugin-wit==1.8.1
258 | - thinc==8.0.15
259 | - threadpoolctl==3.1.0
260 | - tomli==2.0.1
261 | - torch==1.10.2
262 | - torchmetrics==0.7.3
263 | - torchvision==0.11.3
264 | - tqdm==4.63.1
265 | - typer==0.4.0
266 | - wasabi==0.9.0
267 | - werkzeug==2.0.3
268 | - zipp==3.7.0
269 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/.gitignore:
--------------------------------------------------------------------------------
1 | bin/
2 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/doxa_cli.py:
--------------------------------------------------------------------------------
1 | # NOTE:
2 | # This file downloads the doxa_cli to allow you to upload an agent.
3 | # You do NOT need to edit this file, use it as is.
4 |
5 | import sys
6 |
7 | if sys.version_info[0] != 3:
8 | print("Please run this script using python3")
9 | sys.exit(1)
10 |
11 | import json
12 | import os
13 | import platform
14 | import stat
15 | import subprocess
16 | import tarfile
17 | import urllib.error
18 | import urllib.request
19 |
20 |
21 | # Returns `windows`, `darwin` (macos) or `linux`
22 | def get_os():
23 | system = platform.system()
24 |
25 | if system == "Linux":
26 | # The exe version works better for WSL
27 | if "microsoft" in platform.platform():
28 | return "windows"
29 |
30 | return "linux"
31 | elif system == "Windows":
32 | return "windows"
33 | elif system == "Darwin":
34 | return "darwin"
35 | else:
36 | raise Exception(f"Unknown platform {system}")
37 |
38 |
39 | def get_bin_name():
40 | bin_name = "doxa_cli"
41 | if get_os() == "windows":
42 | bin_name = "doxa_cli.exe"
43 | return bin_name
44 |
45 |
46 | def get_bin_dir():
47 | return os.path.join(os.path.dirname(__file__), "bin")
48 |
49 |
50 | def get_binary():
51 | return os.path.join(get_bin_dir(), get_bin_name())
52 |
53 |
54 | def install_binary():
55 | match_release = None
56 | try:
57 | match_release = sys.argv[1]
58 | # Arguments are not required
59 | except IndexError:
60 | pass
61 |
62 | REPO_RELEASE_URL = "https://api.github.com/repos/louisdewar/doxa/releases/latest"
63 | try:
64 | f = urllib.request.urlopen(REPO_RELEASE_URL)
65 | except urllib.error.URLError:
66 | print("There was an SSL cert verification error")
67 | print(
68 |             'If you are on a mac and you have recently installed a new version of \
69 | python then you should navigate to "/Applications/Python {VERSION}/"'
70 | )
71 | print('Then run a script in that folder called "Install Certificates.command"')
72 | sys.exit(1)
73 |
74 | response = json.loads(f.read())
75 |
76 | print("Current version tag:", response["tag_name"])
77 |
78 | assets = [
79 | asset for asset in response["assets"] if asset["name"].endswith(".tar.gz")
80 | ]
81 |
82 | # Find the release for this OS
83 | match_release = get_os()
84 | try:
85 | asset_choice = next(asset for asset in assets if match_release in asset["name"])
86 | print(
87 | 'Automatically picked {} to download based on match "{}"\n'.format(
88 | asset_choice["name"], match_release
89 | )
90 | )
91 | except StopIteration:
92 | print('Couldn\'t find "{}" in releases'.format(match_release))
93 | sys.exit(1)
94 |
95 | download_url = asset_choice["browser_download_url"]
96 |
97 | # Folder where this script is + bin
98 | bin_dir = get_bin_dir()
99 |
100 | print("Downloading", asset_choice["name"], "to", bin_dir)
101 | print("({})".format(download_url))
102 |
103 |     # Create the bin directory if it does not already exist
104 | if not os.path.exists(bin_dir):
105 | os.mkdir(bin_dir)
106 |
107 | zip_path = os.path.join(bin_dir, asset_choice["name"])
108 |
109 |     # Download the release archive (a .tar.gz file)
110 | urllib.request.urlretrieve(download_url, zip_path)
111 |
112 |     # Open and extract the downloaded archive
113 | tar_file = tarfile.open(zip_path)
114 | tar_file.extractall(bin_dir)
115 | tar_file.close()
116 |
117 |     # Delete the downloaded archive
118 | os.remove(zip_path)
119 |
120 | # Path to the actual binary program (called doxa_cli or doxa_cli.exe)
121 | bin_name = get_bin_name()
122 | binary_path = os.path.join(bin_dir, bin_name)
123 |
124 | if not os.path.exists(binary_path):
125 | print(f"Couldn't find the binary file `{bin_name}` in the bin directory")
126 | print("This probably means that there was a problem with the download")
127 | sys.exit(1)
128 |
129 | if get_os() != "windows":
130 | # Make binary executable
131 | st = os.stat(binary_path)
132 | os.chmod(binary_path, st.st_mode | stat.S_IEXEC)
133 |
134 |     # Report that the binary was installed
135 | print("Installed binary\n\n")
136 |
137 |
138 | def run_command(args):
139 | bin_path = get_binary()
140 |
141 | if not os.path.exists(bin_path):
142 | install_binary()
143 | subprocess.call([bin_path] + args)
144 |
145 |
146 | if __name__ == "__main__":
147 | run_command(sys.argv[1:])
148 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/climatehack.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | from typing import List, Tuple
4 |
5 | import numpy as np
6 |
7 |
8 | class BaseEvaluator:
9 | def __init__(self) -> None:
10 | self.setup()
11 |
12 | def setup(self):
13 | """Sets up anything required for evaluation, e.g. models."""
14 | pass
15 |
16 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
17 | """Makes a prediction for the next two hours of satellite imagery.
18 |
19 | Args:
20 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128)
21 | data (np.ndarray): an array of 12 128*128 satellite images (12, 128, 128)
22 |
23 | Returns:
24 | np.ndarray: an array of 24 64*64 satellite image predictions (24, 64, 64)
25 | """
26 |
27 | raise NotImplementedError(
28 | "You need to extend this class to use your trained model(s)."
29 | )
30 |
31 | def _get_io_paths(self) -> Tuple[Path, Path]:
32 | """Gets the input and output directory paths from DOXA.
33 |
34 | Returns:
35 | Tuple[Path, Path]: The input and output paths
36 | """
37 | try:
38 | return Path(sys.argv[1]), Path(sys.argv[2])
39 | except IndexError:
40 | raise Exception(
41 | f"Run using: {sys.argv[0]} [input directory] [output directory]"
42 | )
43 |
44 | def _get_group_path(self) -> str:
45 | """Gets the path for the next group to be processed.
46 |
47 | Raises:
48 | ValueError: An unknown message was received from DOXA.
49 |
50 | Returns:
51 | str: The path of the next group.
52 | """
53 |
54 | msg = input()
55 | if not msg.startswith("Process "):
56 |             raise ValueError(f"Unknown message {msg}")
57 |
58 | return msg[8:]
59 |
60 | def _evaluate_group(self, group: dict) -> List[np.ndarray]:
61 | """Evaluates a group of satellite image sequences using
62 | the user-implemented model(s).
63 |
64 | Args:
65 | group (dict): The OSGB and satellite imagery data.
66 |
67 | Returns:
68 | List[np.ndarray]: The predictions.
69 | """
70 | batch_size = 16
71 | split_num = len(group["data"]) // batch_size
72 | osgb_splits = np.array_split(group["osgb"], split_num, axis=0)
73 | data_splits = np.array_split(group["data"], split_num, axis=0)
74 |
75 | all_preds = []
76 | for (osgb, data) in zip(osgb_splits, data_splits):
77 | bs = len(data)
78 | preds = self.predict(osgb, data)
79 | for b in range(bs):
80 | all_preds.append(preds[b])
81 | return all_preds
82 | # return [self.predict(*datum) for datum in zip(group["osgb"], group["data"])]
83 |
84 | def evaluate(self):
85 | """Evaluates the user's model on DOXA.
86 |
87 | Messages are sent and received through stdio.
88 |
89 | The input data is loaded from a directory in groups.
90 |
91 | The predictions are written to another directory in groups.
92 |
93 | Raises:
94 | Exception: An error occurred somewhere.
95 | """
96 |
97 | print("STARTUP")
98 | input_path, output_path = self._get_io_paths()
99 |
100 | # process test data groups
101 | while True:
102 | # load the data for the group DOXA requests
103 | group_path = self._get_group_path()
104 | group_data = np.load(input_path / group_path)
105 |
106 | # make predictions for this group
107 | try:
108 | predictions = self._evaluate_group(group_data)
109 | except Exception as err:
110 | raise Exception(f"Error while processing {group_path}: {str(err)}")
111 |
112 | # save the output group predictions
113 | np.savez(
114 | output_path / group_path,
115 | data=np.stack(predictions),
116 | )
117 | print(f"Exported {group_path}")
118 |
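
A minimal sketch of extending BaseEvaluator, using the batched arrays that `_evaluate_group` passes to `predict` (data of shape (bs, 12, 128, 128)); the `PersistenceEvaluator` name and its naive last-frame baseline are illustrative, not part of this repository:

import numpy as np

from climatehack import BaseEvaluator


class PersistenceEvaluator(BaseEvaluator):
    """Hypothetical baseline: repeat the last observed frame for all 24 steps."""

    def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
        last = data[:, -1, 32:96, 32:96]                    # central 64x64 crop, (bs, 64, 64)
        return np.repeat(last[:, np.newaxis], 24, axis=1)   # (bs, 24, 64, 64)


if __name__ == "__main__":
    PersistenceEvaluator().evaluate()
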
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/generators.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.modules.pixelshuffle import PixelShuffle
5 | from torch.nn.utils.parametrizations import spectral_norm
6 | from typing import List
7 | from dgmr.common import GBlock, UpsampleGBlock
8 | from dgmr.layers import ConvGRU
9 | from huggingface_hub import PyTorchModelHubMixin
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.WARN)
14 |
15 |
16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin):
17 | def __init__(
18 | self,
19 | forecast_steps: int = 18,
20 | context_channels: int = 384,
21 | latent_channels: int = 384,
22 | output_channels: int = 1,
23 | **kwargs
24 | ):
25 | """
26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
27 |
28 |         In this one-shot variant, the sampler takes the output from the Context conditioning stack and
29 |         predicts every future timestep in a single forward pass, rather than one ConvGRU stack per timestep.
30 | Args:
31 | forecast_steps: Number of forecast steps
32 | latent_channels: Number of input channels to the lowest ConvGRU layer
33 | """
34 | super().__init__()
35 | config = locals()
36 | config.pop("__class__")
37 | config.pop("self")
38 | self.config = kwargs.get("config", config)
39 | self.forecast_steps = self.config["forecast_steps"]
40 | latent_channels = self.config["latent_channels"]
41 | context_channels = self.config["context_channels"]
42 | output_channels = self.config["output_channels"]
43 |
44 | self.gru_conv_1x1 = spectral_norm(
45 | torch.nn.Conv2d(
46 | in_channels=context_channels,
47 | out_channels=latent_channels * self.forecast_steps,
48 | kernel_size=(1, 1),
49 | )
50 | )
51 | self.g1 = GBlock(
52 | input_channels=latent_channels * self.forecast_steps,
53 | output_channels=latent_channels * self.forecast_steps,
54 | )
55 | self.up_g1 = UpsampleGBlock(
56 | input_channels=latent_channels * self.forecast_steps,
57 | output_channels=latent_channels * self.forecast_steps // 2,
58 | )
59 |
60 | self.gru_conv_1x1_2 = spectral_norm(
61 | torch.nn.Conv2d(
62 | in_channels=self.up_g1.output_channels + context_channels // 2,
63 | out_channels=latent_channels * self.forecast_steps // 2,
64 | kernel_size=(1, 1),
65 | )
66 | )
67 | self.g2 = GBlock(
68 | input_channels=latent_channels * self.forecast_steps // 2,
69 | output_channels=latent_channels * self.forecast_steps // 2,
70 | )
71 | self.up_g2 = UpsampleGBlock(
72 | input_channels=latent_channels * self.forecast_steps // 2,
73 | output_channels=latent_channels * self.forecast_steps // 4,
74 | )
75 |
76 | self.gru_conv_1x1_3 = spectral_norm(
77 | torch.nn.Conv2d(
78 | in_channels=self.up_g2.output_channels + context_channels // 4,
79 | out_channels=latent_channels * self.forecast_steps // 4,
80 | kernel_size=(1, 1),
81 | )
82 | )
83 | self.g3 = GBlock(
84 | input_channels=latent_channels * self.forecast_steps // 4,
85 | output_channels=latent_channels * self.forecast_steps // 4,
86 | )
87 | self.up_g3 = UpsampleGBlock(
88 | input_channels=latent_channels * self.forecast_steps // 4,
89 | output_channels=latent_channels * self.forecast_steps // 8,
90 | )
91 |
92 | self.gru_conv_1x1_4 = spectral_norm(
93 | torch.nn.Conv2d(
94 | in_channels=self.up_g3.output_channels + context_channels // 8,
95 | out_channels=latent_channels * self.forecast_steps // 8,
96 | kernel_size=(1, 1),
97 | )
98 | )
99 | self.g4 = GBlock(
100 | input_channels=latent_channels * self.forecast_steps // 8,
101 | output_channels=latent_channels * self.forecast_steps // 8,
102 | )
103 | self.up_g4 = UpsampleGBlock(
104 | input_channels=latent_channels * self.forecast_steps // 8,
105 | output_channels=latent_channels * self.forecast_steps // 16,
106 | )
107 |
108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16)
109 | self.relu = torch.nn.ReLU()
110 | self.conv_1x1 = spectral_norm(
111 | torch.nn.Conv2d(
112 | in_channels=latent_channels * self.forecast_steps // 16,
113 | out_channels=4 * output_channels * self.forecast_steps,
114 | kernel_size=(1, 1),
115 | )
116 | )
117 |
118 | self.depth2space = PixelShuffle(upscale_factor=2)
119 |
120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor:
121 | """
122 | Perform the sampling from Skillful Nowcasting with GANs
123 | Args:
124 |             conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states,
125 |                 ordered from largest to smallest spatially (this one-shot variant takes no separate latent input)
126 |
127 | Returns:
128 | forecast_steps-length output of images for future timesteps
129 |
130 | """
131 | # Iterate through each forecast step
132 | # Initialize with conditioning state for first one, output for second one
133 | init_states = conditioning_states
134 |
135 | layer4_states = self.gru_conv_1x1(init_states[3])
136 | layer4_states = self.g1(layer4_states)
137 | layer4_states = self.up_g1(layer4_states)
138 |
139 | # Layer 3.
140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1)
141 | layer3_states = self.gru_conv_1x1_2(layer3_states)
142 | layer3_states = self.g2(layer3_states)
143 | layer3_states = self.up_g2(layer3_states)
144 |
145 | # Layer 2.
146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1)
147 | layer2_states = self.gru_conv_1x1_3(layer2_states)
148 | layer2_states = self.g3(layer2_states)
149 | layer2_states = self.up_g3(layer2_states)
150 |
151 | # Layer 1 (top-most).
152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1)
153 | layer1_states = self.gru_conv_1x1_4(layer1_states)
154 | layer1_states = self.g4(layer1_states)
155 | layer1_states = self.up_g4(layer1_states)
156 |
157 |         # Final batch norm + ReLU, 1x1 convolution and depth-to-space upsampling
158 | output_states = self.relu(self.bn(layer1_states))
159 | output_states = self.conv_1x1(output_states)
160 | output_states = self.depth2space(output_states)
161 |
162 | # The satellite dimension was lost, add it back
163 | output_states = torch.unsqueeze(output_states, dim=2)
164 |
165 | return output_states
166 |
167 |
168 | class Generator(torch.nn.Module, PyTorchModelHubMixin):
169 | def __init__(
170 | self,
171 | conditioning_stack: torch.nn.Module,
172 | sampler: torch.nn.Module,
173 | ):
174 | """
175 |         Wraps the two parts of this one-shot generator for simpler calling
176 |         Args:
177 |             conditioning_stack: context conditioning stack producing the
178 |                 multi-scale conditioning states
179 |             sampler: sampler that decodes those states into the forecast frames
180 | """
181 | super().__init__()
182 | self.conditioning_stack = conditioning_stack
183 | self.sampler = sampler
184 |
185 | def forward(self, x):
186 | conditioning_states = self.conditioning_stack(x)
187 | x = self.sampler(conditioning_states)
188 | return x
189 |
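
A shape sketch for this one-shot variant, wired with the same configuration the submission's evaluate.py (below) uses; the expected output shape is inferred from how that code slices the prediction, and the `dgmr-oneshot` directory is assumed to be on the import path:

import sys

import torch

sys.path.append("./dgmr-oneshot")
import dgmr

ccs = dgmr.common.ContextConditioningStack(
    input_channels=1, conv_type="standard", output_channels=160
)
sampler = dgmr.generators.Sampler(
    forecast_steps=24, latent_channels=96, context_channels=160, output_channels=1
)
model = dgmr.generators.Generator(ccs, sampler)

x = torch.randn(2, 4, 1, 128, 128)  # (batch, 4 context frames, channel, H, W)
with torch.no_grad():
    y = model(x)
# expected y: (2, 24, 1, 128, 128); evaluate.py keeps the centre via y[:, :, 0, 32:96, 32:96]
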
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally Taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 | ... def forward(self, ...)
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
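
A quick, illustrative shape check for the layer above: because of the residual connection `out + x`, `output_channels` must match `input_channels` for the layer to run, and the output shape equals the input shape:

import torch

from dgmr.layers import AttentionLayer

layer = AttentionLayer(input_channels=192, output_channels=192)
x = torch.randn(2, 192, 16, 16)
assert layer(x).shape == x.shape  # residual connection preserves the shape
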
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(
10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001
11 | ):
12 | """Constructor.
13 |
14 | Args:
15 | kernel_size: kernel size of the convolutions. Default: 3.
16 | sn_eps: constant for spectral normalization. Default: 1e-4.
17 | """
18 | super().__init__()
19 | self._kernel_size = kernel_size
20 | self._sn_eps = sn_eps
21 | self.read_gate_conv = spectral_norm(
22 | torch.nn.Conv2d(
23 | in_channels=input_channels,
24 | out_channels=output_channels,
25 | kernel_size=(kernel_size, kernel_size),
26 | padding=1,
27 | ),
28 | eps=sn_eps,
29 | )
30 | self.update_gate_conv = spectral_norm(
31 | torch.nn.Conv2d(
32 | in_channels=input_channels,
33 | out_channels=output_channels,
34 | kernel_size=(kernel_size, kernel_size),
35 | padding=1,
36 | ),
37 | eps=sn_eps,
38 | )
39 | self.output_conv = spectral_norm(
40 | torch.nn.Conv2d(
41 | in_channels=input_channels,
42 | out_channels=output_channels,
43 | kernel_size=(kernel_size, kernel_size),
44 | padding=1,
45 | ),
46 | eps=sn_eps,
47 | )
48 |
49 | def forward(self, x, prev_state):
50 | """
51 | ConvGRU forward, returning the current+new state
52 |
53 | Args:
54 | x: Input tensor
55 | prev_state: Previous state
56 |
57 | Returns:
58 | New tensor plus the new state
59 | """
60 | # Concatenate the inputs and previous state along the channel axis.
61 | xh = torch.cat([x, prev_state], dim=1)
62 |
63 | # Read gate of the GRU.
64 | read_gate = F.sigmoid(self.read_gate_conv(xh))
65 |
66 | # Update gate of the GRU.
67 | update_gate = F.sigmoid(self.update_gate_conv(xh))
68 |
69 | # Gate the inputs.
70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
71 |
72 | # Gate the cell and state / outputs.
73 | c = F.relu(self.output_conv(gated_input))
74 | out = update_gate * prev_state + (1.0 - update_gate) * c
75 | new_state = out
76 |
77 | return out, new_state
78 |
79 |
80 | class ConvGRU(torch.nn.Module):
81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
82 |
83 | def __init__(
84 | self,
85 | input_channels: int,
86 | output_channels: int,
87 | kernel_size: int = 3,
88 | sn_eps=0.0001,
89 | ):
90 | super().__init__()
91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
92 |
93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
94 | outputs = []
95 | for step in range(len(x)):
96 | # Compute current timestep
97 | output, hidden_state = self.cell(x[step], hidden_state)
98 | outputs.append(output)
99 | # Stack outputs to return as tensor
100 | outputs = torch.stack(outputs, dim=0)
101 | return outputs
102 |
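
A usage sketch for the wrapper above (dummy shapes chosen for illustration): `ConvGRUCell` concatenates the input and previous state along channels, so `input_channels` must equal the sum of the two, and an initial hidden state has to be passed explicitly, since the `hidden_state=None` default would fail at the first `torch.cat`:

import torch

from dgmr.layers import ConvGRU

gru = ConvGRU(input_channels=32 + 64, output_channels=64)
x = torch.randn(4, 2, 32, 16, 16)  # (time, batch, channels, H, W)
h0 = torch.zeros(2, 64, 16, 16)    # initial hidden state (batch, 64, H, W)
out = gru(x, hidden_state=h0)      # (4, 2, 64, 16, 16)
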
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [
31 | input_tensor,
32 | xx_channel.type_as(input_tensor),
33 | yy_channel.type_as(input_tensor),
34 | ],
35 | dim=1,
36 | )
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
42 | )
43 | ret = torch.cat([ret, rr], dim=1)
44 |
45 | return ret
46 |
47 |
48 | class CoordConv(nn.Module):
49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
50 | super().__init__()
51 | self.addcoords = AddCoords(with_r=with_r)
52 | in_size = in_channels + 2
53 | if with_r:
54 | in_size += 1
55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
56 |
57 | def forward(self, x):
58 | ret = self.addcoords(x)
59 | ret = self.conv(ret)
60 | return ret
61 |
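
For illustration, `AddCoords` appends two extra coordinate channels (three with `with_r=True`), so the wrapped `Conv2d` internally sees `in_channels + 2` inputs while the caller only specifies the original channel count:

import torch

from dgmr.layers import CoordConv

conv = CoordConv(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)
y = conv(x)  # (2, 8, 32, 32); coordinate channels are added internally
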
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/dgmr-oneshot/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/doxa.yaml:
--------------------------------------------------------------------------------
1 | language: python
2 | entrypoint: evaluate.py
3 |
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 |
5 | from climatehack import BaseEvaluator
6 |
7 | import sys
8 |
9 | sys.path.append("./dgmr-oneshot")
10 | import dgmr
11 |
12 | # DEVICE = torch.device("cpu")
13 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 | _MEAN_PIXEL = 240.3414
16 | _STD_PIXEL = 146.52366
17 |
18 |
19 | def transform(x):
20 | return (x - _MEAN_PIXEL) / _STD_PIXEL
21 |
22 |
23 | def inv_transform(x):
24 | return (x * _STD_PIXEL) + _MEAN_PIXEL
25 |
26 |
27 | def warp_flow(img, flow):
28 | h, w = flow.shape[:2]
29 | flow = -flow
30 | flow[:, :, 0] += np.arange(w)
31 | flow[:, :, 1] += np.arange(h)[:, np.newaxis]
32 | res = cv2.remap(img, flow, None, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
33 | return res
34 |
35 |
36 | class Evaluator(BaseEvaluator):
37 | def setup(self):
38 | """Sets up anything required for evaluation.
39 |         In this case, it loads the trained model and keeps it in train mode."""
40 | ccs = dgmr.common.ContextConditioningStack(
41 | input_channels=1,
42 | conv_type="standard",
43 | output_channels=160,
44 | )
45 |
46 | sampler = dgmr.generators.Sampler(
47 | forecast_steps=24,
48 | latent_channels=96,
49 | context_channels=160,
50 | output_channels=1,
51 | )
52 | model = dgmr.generators.Generator(ccs, sampler)
53 | model.load_state_dict(torch.load("weights/model.pt", map_location=DEVICE))
54 | self.model = model.to(DEVICE)
55 | self.model.train()
56 | # print("DOING TRAIN MODE")
57 |
58 | def predict(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
59 | """Makes a prediction for the next two hours of satellite imagery.
60 | Args:
61 | coordinates (np.ndarray): the OSGB x and y coordinates (2, 128, 128)
62 | data (np.ndarray): an array of 12 128*128 satellite images (bs, 12, 128, 128)
63 | Returns:
64 | np.ndarray: an array of 24 64*64 satellite image predictions (bs, 24, 64, 64)
65 | """
66 | # prediction_opt_flow = self._predict_opt_flow(data)
67 | prediction_dgmr = self._predict_dgmr(coordinates, data)
68 | # copy the opt flow predictions in
69 | # prediction_dgmr[:, : prediction_opt_flow.shape[1]] = prediction_opt_flow
70 | prediction = prediction_dgmr
71 | return prediction
72 |
73 | def _predict_dgmr(self, coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
74 | bs = data.shape[0]
75 | data = data[:, -4:]
76 | data = torch.FloatTensor(transform(data)).float().to(DEVICE)
77 | # add a satellite dimension
78 | data = torch.unsqueeze(data, dim=2)
79 | # make a batch to help with norm
80 | # data = torch.cat([data, self.default_batch], dim=0)
81 | with torch.no_grad():
82 | prediction = self.model(data)
83 | # remove the satellite dimension and grab the inner 64x64
84 | prediction = inv_transform(prediction[:, :, 0, 32:96, 32:96])
85 |         if prediction.device.type == "cpu":
86 | prediction = prediction.numpy()
87 | else:
88 | prediction = prediction.detach().cpu().numpy()
89 | return prediction
90 |
91 | def _predict_opt_flow(self, data: np.ndarray) -> np.ndarray:
92 | bs = data.shape[0]
93 | forecast = 1
94 | prediction = np.zeros((bs, forecast, 64, 64), dtype=np.float32)
95 | test_params = {
96 | "pyr_scale": 0.5,
97 | "levels": 2,
98 | "winsize": 40,
99 | "iterations": 3,
100 | "poly_n": 5,
101 | "poly_sigma": 0.7,
102 | }
103 | for i in range(bs):
104 | sample = data[i]
105 | cur = sample[-1].astype(np.float32)
106 | flow = cv2.calcOpticalFlowFarneback(
107 | prev=sample[-2],
108 | next=sample[-1],
109 | flow=None,
110 | **test_params,
111 | flags=cv2.OPTFLOW_FARNEBACK_GAUSSIAN,
112 | )
113 | for j in range(forecast):
114 | cur = warp_flow(cur, flow)
115 | prediction[i, j] = cur[32:96, 32:96]
116 | return prediction
117 |
118 |
119 | def main():
120 | evaluator = Evaluator()
121 | evaluator.evaluate()
122 |
123 |
124 | if __name__ == "__main__":
125 | main()
126 |
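
An illustrative, standalone use of the `warp_flow` helper above, following the same Farneback recipe as `_predict_opt_flow`; dummy uint8 frames stand in for real satellite data, and importing from `evaluate` assumes its module-level imports (torch, cv2, the dgmr-oneshot package) resolve:

import cv2
import numpy as np

from evaluate import warp_flow

prev_frame = (np.random.rand(128, 128) * 255).astype(np.uint8)
next_frame = (np.random.rand(128, 128) * 255).astype(np.uint8)

flow = cv2.calcOpticalFlowFarneback(
    prev=prev_frame,
    next=next_frame,
    flow=None,
    pyr_scale=0.5,
    levels=2,
    winsize=40,
    iterations=3,
    poly_n=5,
    poly_sigma=0.7,
    flags=cv2.OPTFLOW_FARNEBACK_GAUSSIAN,
)
extrapolated = warp_flow(next_frame.astype(np.float32), flow)  # one step ahead
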
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submission/validate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_msssim import MS_SSIM
3 | from torch import from_numpy
4 | import torch
5 | import tqdm
6 |
7 | from evaluate import Evaluator
8 |
9 |
10 | def main():
11 | features = np.load("../features.npz")
12 | targets = np.load("../targets.npz")
13 |
14 | criterion = MS_SSIM(data_range=1023.0, size_average=True, win_size=3, channel=1)
15 | evaluator = Evaluator()
16 |
17 | batch_size = 16
18 | split_num = len(features["data"]) // batch_size
19 | osgb_splits = np.array_split(features["osgb"], split_num, axis=0)
20 | data_splits = np.array_split(features["data"], split_num, axis=0)
21 | targets_splits = np.array_split(targets["data"], split_num, axis=0)
22 |
23 | pbar = tqdm.tqdm(
24 | zip(osgb_splits, data_splits, targets_splits),
25 | total=len(data_splits),
26 | )
27 |
28 | scores = []
29 | for (osgb, data, target) in pbar:
30 | bs = len(data)
31 | preds = from_numpy(evaluator.predict(osgb, data))
32 | trgs = from_numpy(target)
33 |         # together with the per-sample indexing below, this treats the 24 timesteps as the batch dimension
34 | preds = torch.unsqueeze(preds, dim=2)
35 | trgs = torch.unsqueeze(trgs, dim=2)
36 |
37 | for i in range(bs):
38 | # grab the current batch
39 | score = criterion(preds[i], trgs[i]).item()
40 | scores.append(score)
41 | pbar.set_description(f"Avg: {np.mean(scores)}")
42 |
43 | # scores = [
44 | # criterion(
45 | # from_numpy(evaluator.predict(*datum)).view(24, 64, 64).unsqueeze(dim=1),
46 | # from_numpy(target).view(24, 64, 64).unsqueeze(dim=1),
47 | # ).item()
48 | # for *datum, target in tqdm.tqdm(
49 | # zip(features["osgb"], features["data"], targets["data"]),
50 | # total=len(features["data"]),
51 | # )
52 | # ]
53 |
54 | print(f"Score: {np.mean(scores)} ({np.std(scores)})")
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
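
A minimal illustration of the MS-SSIM criterion configured above: pytorch_msssim expects (N, C, H, W) floats in [0, data_range], and as in the loop above the 24 forecast timesteps play the role of the batch dimension (random tensors used purely for shape):

import torch
from pytorch_msssim import MS_SSIM

criterion = MS_SSIM(data_range=1023.0, size_average=True, win_size=3, channel=1)
pred = torch.rand(24, 1, 64, 64) * 1023.0    # 24 predicted frames as the "batch"
target = torch.rand(24, 1, 64, 64) * 1023.0  # 24 ground-truth frames
score = criterion(pred, target).item()       # 1.0 would be a perfect match
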
--------------------------------------------------------------------------------
/experiments/climatehack-submission/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -eEou pipefail
4 |
5 | python doxa_cli.py agent upload climatehack ./submission
6 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/dgmr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.losses import (
3 | NowcastingLoss,
4 | GridCellLoss,
5 | loss_hinge_disc,
6 | loss_hinge_gen,
7 | grid_cell_regularizer,
8 | )
9 | import pytorch_lightning as pl
10 | import torchvision
11 | from dgmr.common import LatentConditioningStack, ContextConditioningStack
12 | from dgmr.generators import Sampler, Generator
13 | from dgmr.discriminators import Discriminator
14 | from dgmr.hub import NowcastingModelHubMixin
15 |
16 |
17 | class DGMR(pl.LightningModule, NowcastingModelHubMixin):
18 | """Deep Generative Model of Radar"""
19 |
20 | def __init__(
21 | self,
22 | forecast_steps: int = 18,
23 | input_channels: int = 1,
24 | output_shape: int = 256,
25 | gen_lr: float = 5e-5,
26 | disc_lr: float = 2e-4,
27 | visualize: bool = False,
28 | conv_type: str = "standard",
29 | num_samples: int = 6,
30 | grid_lambda: float = 20.0,
31 | beta1: float = 0.0,
32 | beta2: float = 0.999,
33 | latent_channels: int = 768,
34 | context_channels: int = 384,
35 | **kwargs,
36 | ):
37 | """
38 | Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954
39 | but slightly modified for multiple satellite channels
40 |
41 | Args:
42 | forecast_steps: Number of steps to predict in the future
43 | input_channels: Number of input channels per image
44 | visualize: Whether to visualize output during training
45 | gen_lr: Learning rate for the generator
46 | disc_lr: Learning rate for the discriminators, shared for both temporal and spatial discriminator
47 | conv_type: Type of 2d convolution to use, see satflow/models/utils.py for options
48 | beta1: Beta1 for Adam optimizer
49 | beta2: Beta2 for Adam optimizer
50 | num_samples: Number of samples of the latent space to sample for training/validation
51 | grid_lambda: Lambda for the grid regularization loss
52 | output_shape: Shape of the output predictions, generally should be same as the input shape
53 | latent_channels: Number of channels that the latent space should be reshaped to,
54 | input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs
55 |             context_channels: Number of output channels from the context conditioning stack
56 | """
57 | super().__init__()
58 | config = locals()
59 | config.pop("__class__")
60 | config.pop("self")
61 | self.config = kwargs.get("config", config)
62 | input_channels = self.config["input_channels"]
63 | forecast_steps = self.config["forecast_steps"]
64 | output_shape = self.config["output_shape"]
65 | gen_lr = self.config["gen_lr"]
66 | disc_lr = self.config["disc_lr"]
67 | conv_type = self.config["conv_type"]
68 | num_samples = self.config["num_samples"]
69 | grid_lambda = self.config["grid_lambda"]
70 | beta1 = self.config["beta1"]
71 | beta2 = self.config["beta2"]
72 | latent_channels = self.config["latent_channels"]
73 | context_channels = self.config["context_channels"]
74 | visualize = self.config["visualize"]
75 | self.gen_lr = gen_lr
76 | self.disc_lr = disc_lr
77 | self.beta1 = beta1
78 | self.beta2 = beta2
79 | self.discriminator_loss = NowcastingLoss()
80 | self.grid_regularizer = GridCellLoss()
81 | self.grid_lambda = grid_lambda
82 | self.num_samples = num_samples
83 | self.visualize = visualize
84 | self.latent_channels = latent_channels
85 | self.context_channels = context_channels
86 | self.input_channels = input_channels
87 | self.conditioning_stack = ContextConditioningStack(
88 | input_channels=input_channels,
89 | conv_type=conv_type,
90 | output_channels=self.context_channels,
91 | )
92 | self.latent_stack = LatentConditioningStack(
93 | shape=(8 * self.input_channels, output_shape // 32, output_shape // 32),
94 | output_channels=self.latent_channels,
95 | )
96 | self.sampler = Sampler(
97 | forecast_steps=forecast_steps,
98 | latent_channels=self.latent_channels,
99 | context_channels=self.context_channels,
100 | )
101 | self.generator = Generator(
102 | self.conditioning_stack, self.latent_stack, self.sampler
103 | )
104 | self.discriminator = Discriminator(input_channels)
105 | self.save_hyperparameters()
106 |
107 | self.global_iteration = 0
108 |
109 | # Important: This property activates manual optimization.
110 | self.automatic_optimization = False
111 | torch.autograd.set_detect_anomaly(True)
112 |
113 | def forward(self, x):
114 | x = self.generator(x)
115 | return x
116 |
117 | def training_step(self, batch, batch_idx):
118 | images, future_images = batch
119 |
120 | self.global_iteration += 1
121 | g_opt, d_opt = self.optimizers()
122 | ##########################
123 | # Optimize Discriminator #
124 | ##########################
125 | # Two discriminator steps per generator step
126 | for _ in range(2):
127 | predictions = self(images)
128 | # Cat along time dimension [B, T, C, H, W]
129 | generated_sequence = torch.cat([images, predictions], dim=1)
130 | real_sequence = torch.cat([images, future_images], dim=1)
131 |             # Cat along batch for the real+generated
132 | concatenated_inputs = torch.cat([real_sequence, generated_sequence], dim=0)
133 |
134 | concatenated_outputs = self.discriminator(concatenated_inputs)
135 |
136 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1)
137 | discriminator_loss = loss_hinge_disc(score_generated, score_real)
138 | d_opt.zero_grad()
139 | self.manual_backward(discriminator_loss)
140 | d_opt.step()
141 |
142 | ######################
143 | # Optimize Generator #
144 | ######################
145 |         predictions = [self(images) for _ in range(self.num_samples)]
146 | grid_cell_reg = grid_cell_regularizer(
147 | torch.stack(predictions, dim=0), future_images
148 | )
149 | # Concat along time dimension
150 | generated_sequence = [torch.cat([images, x], dim=1) for x in predictions]
151 | real_sequence = torch.cat([images, future_images], dim=1)
152 |         # Cat along batch for the real+generated, for each example in the range
153 |         # For each of the sampled predictions
154 | generated_scores = []
155 | for g_seq in generated_sequence:
156 | concatenated_inputs = torch.cat([real_sequence, g_seq], dim=0)
157 | concatenated_outputs = self.discriminator(concatenated_inputs)
158 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1)
159 | generated_scores.append(score_generated)
160 | generator_disc_loss = loss_hinge_gen(torch.cat(generated_scores, dim=0))
161 | generator_loss = generator_disc_loss + self.grid_lambda * grid_cell_reg
162 | g_opt.zero_grad()
163 | self.manual_backward(generator_loss)
164 | g_opt.step()
165 |
166 | self.log_dict(
167 | {
168 | "train/d_loss": discriminator_loss,
169 | "train/g_loss": generator_loss,
170 | "train/grid_loss": grid_cell_reg,
171 | },
172 | prog_bar=True,
173 | )
174 |
175 | # generate images
176 | generated_images = self(images)
177 | # log sampled images
178 | if self.visualize:
179 | self.visualize_step(
180 | images,
181 | future_images,
182 | generated_images,
183 | self.global_iteration,
184 | step="train",
185 | )
186 |
187 | def configure_optimizers(self):
188 | b1 = self.beta1
189 | b2 = self.beta2
190 |
191 | opt_g = torch.optim.Adam(
192 | self.generator.parameters(), lr=self.gen_lr, betas=(b1, b2)
193 | )
194 | opt_d = torch.optim.Adam(
195 | self.discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2)
196 | )
197 |
198 | return [opt_g, opt_d], []
199 |
200 | def visualize_step(
201 | self,
202 | x: torch.Tensor,
203 | y: torch.Tensor,
204 | y_hat: torch.Tensor,
205 | batch_idx: int,
206 | step: str,
207 | ) -> None:
208 | # the logger you used (in this case tensorboard)
209 | tensorboard = self.logger.experiment[0]
210 | # Timesteps per channel
211 | images = x[0].cpu().detach()
212 | future_images = y[0].cpu().detach()
213 | generated_images = y_hat[0].cpu().detach()
214 | for i, t in enumerate(images): # Now would be (C, H, W)
215 | t = [torch.unsqueeze(img, dim=0) for img in t]
216 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
217 | tensorboard.add_image(
218 | f"{step}/Input_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx
219 | )
220 | t = [torch.unsqueeze(img, dim=0) for img in future_images[i]]
221 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
222 | tensorboard.add_image(
223 | f"{step}/Target_Image_Frame_{i}", image_grid, global_step=batch_idx
224 | )
225 | t = [torch.unsqueeze(img, dim=0) for img in generated_images[i]]
226 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
227 | tensorboard.add_image(
228 | f"{step}/Generated_Image_Frame_{i}", image_grid, global_step=batch_idx
229 | )
230 |
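
A minimal training sketch for the LightningModule above, assuming the dgmr package from this experiment directory is importable and using a placeholder `nowcasting_dataset` (hypothetical) whose items are (context_frames, future_frames) tensor pairs of shape (T, C, H, W); the module uses manual optimization, so the Trainer simply drives `training_step`:

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from dgmr import DGMR

model = DGMR(forecast_steps=18, input_channels=1, output_shape=256)
trainer = pl.Trainer(gpus=1, max_epochs=1)

# `nowcasting_dataset` is a placeholder for a dataset yielding
# (context_frames, future_frames) pairs of shape (T, C, H, W).
# trainer.fit(model, DataLoader(nowcasting_dataset, batch_size=2))
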
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally Taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 | HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 | ... def forward(self, ...)
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(
10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001
11 | ):
12 | """Constructor.
13 |
14 | Args:
15 | kernel_size: kernel size of the convolutions. Default: 3.
16 | sn_eps: constant for spectral normalization. Default: 1e-4.
17 | """
18 | super().__init__()
19 | self._kernel_size = kernel_size
20 | self._sn_eps = sn_eps
21 | self.read_gate_conv = spectral_norm(
22 | torch.nn.Conv2d(
23 | in_channels=input_channels,
24 | out_channels=output_channels,
25 | kernel_size=(kernel_size, kernel_size),
26 | padding=1,
27 | ),
28 | eps=sn_eps,
29 | )
30 | self.update_gate_conv = spectral_norm(
31 | torch.nn.Conv2d(
32 | in_channels=input_channels,
33 | out_channels=output_channels,
34 | kernel_size=(kernel_size, kernel_size),
35 | padding=1,
36 | ),
37 | eps=sn_eps,
38 | )
39 | self.output_conv = spectral_norm(
40 | torch.nn.Conv2d(
41 | in_channels=input_channels,
42 | out_channels=output_channels,
43 | kernel_size=(kernel_size, kernel_size),
44 | padding=1,
45 | ),
46 | eps=sn_eps,
47 | )
48 |
49 | def forward(self, x, prev_state):
50 | """
51 | ConvGRU forward, returning the current+new state
52 |
53 | Args:
54 | x: Input tensor
55 | prev_state: Previous state
56 |
57 | Returns:
58 | New tensor plus the new state
59 | """
60 | # Concatenate the inputs and previous state along the channel axis.
61 | xh = torch.cat([x, prev_state], dim=1)
62 |
63 | # Read gate of the GRU.
64 | read_gate = F.sigmoid(self.read_gate_conv(xh))
65 |
66 | # Update gate of the GRU.
67 | update_gate = F.sigmoid(self.update_gate_conv(xh))
68 |
69 | # Gate the inputs.
70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
71 |
72 | # Gate the cell and state / outputs.
73 | c = F.relu(self.output_conv(gated_input))
74 | out = update_gate * prev_state + (1.0 - update_gate) * c
75 | new_state = out
76 |
77 | return out, new_state
78 |
79 |
80 | class ConvGRU(torch.nn.Module):
81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
82 |
83 | def __init__(
84 | self,
85 | input_channels: int,
86 | output_channels: int,
87 | kernel_size: int = 3,
88 | sn_eps=0.0001,
89 | ):
90 | super().__init__()
91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
92 |
93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
94 | outputs = []
95 | for step in range(len(x)):
96 | # Compute current timestep
97 | output, hidden_state = self.cell(x[step], hidden_state)
98 | outputs.append(output)
99 | # Stack outputs to return as tensor
100 | outputs = torch.stack(outputs, dim=0)
101 | return outputs
102 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [
31 | input_tensor,
32 | xx_channel.type_as(input_tensor),
33 | yy_channel.type_as(input_tensor),
34 | ],
35 | dim=1,
36 | )
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
42 | )
43 | ret = torch.cat([ret, rr], dim=1)
44 |
45 | return ret
46 |
47 |
48 | class CoordConv(nn.Module):
49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
50 | super().__init__()
51 | self.addcoords = AddCoords(with_r=with_r)
52 | in_size = in_channels + 2
53 | if with_r:
54 | in_size += 1
55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
56 |
57 | def forward(self, x):
58 | ret = self.addcoords(x)
59 | ret = self.conv(ret)
60 | return ret
61 |
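A small illustration of the CoordConv module above (hypothetical channel counts, assuming the dgmr package is importable): AddCoords appends two normalised coordinate channels (three with with_r=True), so the wrapped convolution sees in_channels + 2 channels while the caller still passes the original in_channels:

import torch
from dgmr.layers import CoordConv

conv = CoordConv(in_channels=1, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 1, 32, 32)
print(conv(x).shape)                  # torch.Size([2, 8, 32, 32])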
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/dgmr-dct/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
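get_conv_layer returns a layer class rather than an instance, so the result is instantiated with the usual Conv2d-style arguments. A short sketch (package assumed importable; channel numbers are arbitrary):

import torch
from dgmr.layers.utils import get_conv_layer

Conv = get_conv_layer("coord")        # returns the CoordConv class
layer = Conv(in_channels=4, out_channels=16, kernel_size=3, padding=1)
print(layer(torch.randn(1, 4, 32, 32)).shape)   # torch.Size([1, 16, 32, 32])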
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/generators.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.modules.pixelshuffle import PixelShuffle
5 | from torch.nn.utils.parametrizations import spectral_norm
6 | from typing import List
7 | from dgmr.common import GBlock, UpsampleGBlock
8 | from dgmr.layers import ConvGRU
9 | from huggingface_hub import PyTorchModelHubMixin
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.WARN)
14 |
15 |
16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin):
17 | def __init__(
18 | self,
19 | forecast_steps: int = 18,
20 | latent_channels: int = 768,
21 | context_channels: int = 384,
22 | output_channels: int = 1,
23 | **kwargs
24 | ):
25 | """
26 |         Sampler from the Skillful Nowcasting paper, see https://arxiv.org/pdf/2104.00954.pdf
27 |
28 | The sampler takes the output from the Latent and Context conditioning stacks and
29 | creates one stack of ConvGRU layers per future timestep.
30 | Args:
31 | forecast_steps: Number of forecast steps
32 | latent_channels: Number of input channels to the lowest ConvGRU layer
33 | """
34 | super().__init__()
35 | config = locals()
36 | config.pop("__class__")
37 | config.pop("self")
38 | self.config = kwargs.get("config", config)
39 | self.forecast_steps = self.config["forecast_steps"]
40 | latent_channels = self.config["latent_channels"]
41 | context_channels = self.config["context_channels"]
42 | output_channels = self.config["output_channels"]
43 |
44 | self.convGRU1 = ConvGRU(
45 | input_channels=latent_channels + context_channels,
46 | output_channels=context_channels,
47 | kernel_size=3,
48 | )
49 | self.gru_conv_1x1 = spectral_norm(
50 | torch.nn.Conv2d(
51 | in_channels=context_channels,
52 | out_channels=latent_channels,
53 | kernel_size=(1, 1),
54 | )
55 | )
56 | self.g1 = GBlock(
57 | input_channels=latent_channels, output_channels=latent_channels
58 | )
59 | self.up_g1 = UpsampleGBlock(
60 | input_channels=latent_channels, output_channels=latent_channels // 2
61 | )
62 |
63 | self.convGRU2 = ConvGRU(
64 | input_channels=latent_channels // 2 + context_channels // 2,
65 | output_channels=context_channels // 2,
66 | kernel_size=3,
67 | )
68 | self.gru_conv_1x1_2 = spectral_norm(
69 | torch.nn.Conv2d(
70 | in_channels=context_channels // 2,
71 | out_channels=latent_channels // 2,
72 | kernel_size=(1, 1),
73 | )
74 | )
75 | self.g2 = GBlock(
76 | input_channels=latent_channels // 2, output_channels=latent_channels // 2
77 | )
78 | self.up_g2 = UpsampleGBlock(
79 | input_channels=latent_channels // 2, output_channels=latent_channels // 4
80 | )
81 |
82 | self.convGRU3 = ConvGRU(
83 | input_channels=latent_channels // 4 + context_channels // 4,
84 | output_channels=context_channels // 4,
85 | kernel_size=3,
86 | )
87 | self.gru_conv_1x1_3 = spectral_norm(
88 | torch.nn.Conv2d(
89 | in_channels=context_channels // 4,
90 | out_channels=latent_channels // 4,
91 | kernel_size=(1, 1),
92 | )
93 | )
94 | self.g3 = GBlock(
95 | input_channels=latent_channels // 4, output_channels=latent_channels // 4
96 | )
97 | self.up_g3 = UpsampleGBlock(
98 | input_channels=latent_channels // 4, output_channels=latent_channels // 8
99 | )
100 |
101 | self.convGRU4 = ConvGRU(
102 | input_channels=latent_channels // 8 + context_channels // 8,
103 | output_channels=context_channels // 8,
104 | kernel_size=3,
105 | )
106 | self.gru_conv_1x1_4 = spectral_norm(
107 | torch.nn.Conv2d(
108 | in_channels=context_channels // 8,
109 | out_channels=latent_channels // 8,
110 | kernel_size=(1, 1),
111 | )
112 | )
113 | self.g4 = GBlock(
114 | input_channels=latent_channels // 8, output_channels=latent_channels // 8
115 | )
116 | self.up_g4 = UpsampleGBlock(
117 | input_channels=latent_channels // 8, output_channels=latent_channels // 16
118 | )
119 |
120 | self.bn = torch.nn.BatchNorm2d(latent_channels // 16)
121 | self.relu = torch.nn.ReLU()
122 | self.conv_1x1 = spectral_norm(
123 | torch.nn.Conv2d(
124 | in_channels=latent_channels // 16,
125 | out_channels=4 * output_channels,
126 | kernel_size=(1, 1),
127 | )
128 | )
129 |
130 | self.depth2space = PixelShuffle(upscale_factor=2)
131 |
132 | def forward(
133 | self, conditioning_states: List[torch.Tensor], latent_dim: torch.Tensor
134 | ) -> torch.Tensor:
135 | """
136 | Perform the sampling from Skillful Nowcasting with GANs
137 | Args:
138 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially
139 | latent_dim: Output from `LatentConditioningStack` for input into the ConvGRUs
140 |
141 | Returns:
142 | forecast_steps-length output of images for future timesteps
143 |
144 | """
145 | # Iterate through each forecast step
146 | # Initialize with conditioning state for first one, output for second one
147 | init_states = conditioning_states
148 | # Expand latent dim to match batch size
149 | latent_dim = einops.repeat(
150 | latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0]
151 | )
152 | hidden_states = [latent_dim] * self.forecast_steps
153 |
154 | # TODO: can we make this into a UNET and remove the LSTM?
155 | # Layer 4 (bottom most)
156 | hidden_states = self.convGRU1(hidden_states, init_states[3])
157 | hidden_states = [self.gru_conv_1x1(h) for h in hidden_states]
158 | hidden_states = [self.g1(h) for h in hidden_states]
159 | hidden_states = [self.up_g1(h) for h in hidden_states]
160 |
161 | # Layer 3.
162 | hidden_states = self.convGRU2(hidden_states, init_states[2])
163 | hidden_states = [self.gru_conv_1x1_2(h) for h in hidden_states]
164 | hidden_states = [self.g2(h) for h in hidden_states]
165 | hidden_states = [self.up_g2(h) for h in hidden_states]
166 |
167 | # Layer 2.
168 | hidden_states = self.convGRU3(hidden_states, init_states[1])
169 | hidden_states = [self.gru_conv_1x1_3(h) for h in hidden_states]
170 | hidden_states = [self.g3(h) for h in hidden_states]
171 | hidden_states = [self.up_g3(h) for h in hidden_states]
172 |
173 | # Layer 1 (top-most).
174 | hidden_states = self.convGRU4(hidden_states, init_states[0])
175 | hidden_states = [self.gru_conv_1x1_4(h) for h in hidden_states]
176 | hidden_states = [self.g4(h) for h in hidden_states]
177 | hidden_states = [self.up_g4(h) for h in hidden_states]
178 |
179 | # Output layer.
180 | hidden_states = [F.relu(self.bn(h)) for h in hidden_states]
181 | hidden_states = [self.conv_1x1(h) for h in hidden_states]
182 | hidden_states = [self.depth2space(h) for h in hidden_states]
183 |
184 | # Convert forecasts to a torch Tensor
185 | forecasts = torch.stack(hidden_states, dim=1)
186 | return forecasts
187 |
188 |
189 | class Generator(torch.nn.Module, PyTorchModelHubMixin):
190 | def __init__(
191 | self,
192 | conditioning_stack: torch.nn.Module,
193 | latent_stack: torch.nn.Module,
194 | sampler: torch.nn.Module,
195 | ):
196 | """
197 | Wraps the three parts of the generator for simpler calling
198 | Args:
199 | conditioning_stack:
200 | latent_stack:
201 | sampler:
202 | """
203 | super().__init__()
204 | self.conditioning_stack = conditioning_stack
205 | self.latent_stack = latent_stack
206 | self.sampler = sampler
207 |
208 | def forward(self, x):
209 | conditioning_states = self.conditioning_stack(x)
210 | latent_dim = self.latent_stack(x)
211 | x = self.sampler(conditioning_states, latent_dim)
212 | return x
213 |
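A standalone sketch of how the forward pass above broadcasts the latent to the batch and replicates it once per forecast step before the first ConvGRU (shapes are illustrative; latent_channels=768 and context_channels=384 match the defaults):

import einops
import torch

latent_dim = torch.randn(1, 768, 8, 8)          # single latent sample
init_state = torch.randn(2, 384, 8, 8)          # coarsest conditioning state, batch of 2
latent_dim = einops.repeat(
    latent_dim, "b c h w -> (repeat b) c h w", repeat=init_state.shape[0]
)
hidden_states = [latent_dim] * 3                # forecast_steps copies feed the first ConvGRU
print(latent_dim.shape, len(hidden_states))     # torch.Size([2, 768, 8, 8]) 3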
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally Taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 |     HuggingFace ModelHubMixin containing specific adaptations for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 | ... def forward(self, ...)
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
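A behavioural note on the AttentionLayer above, shown as a runnable sketch (assuming the dgmr package is importable, with matching input/output channels so the residual add is valid): gamma is initialised to zero, so the block starts out as the identity and the attention branch is blended in during training:

import torch
from dgmr.layers import AttentionLayer

layer = AttentionLayer(input_channels=192, output_channels=192)
x = torch.randn(1, 192, 8, 8)
with torch.no_grad():
    print(torch.allclose(layer(x), x))   # True: gamma == 0, so only the residual passes through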
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(
10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001
11 | ):
12 | """Constructor.
13 |
14 | Args:
15 | kernel_size: kernel size of the convolutions. Default: 3.
16 | sn_eps: constant for spectral normalization. Default: 1e-4.
17 | """
18 | super().__init__()
19 | self._kernel_size = kernel_size
20 | self._sn_eps = sn_eps
21 | self.read_gate_conv = spectral_norm(
22 | torch.nn.Conv2d(
23 | in_channels=input_channels,
24 | out_channels=output_channels,
25 | kernel_size=(kernel_size, kernel_size),
26 | padding=1,
27 | ),
28 | eps=sn_eps,
29 | )
30 | self.update_gate_conv = spectral_norm(
31 | torch.nn.Conv2d(
32 | in_channels=input_channels,
33 | out_channels=output_channels,
34 | kernel_size=(kernel_size, kernel_size),
35 | padding=1,
36 | ),
37 | eps=sn_eps,
38 | )
39 | self.output_conv = spectral_norm(
40 | torch.nn.Conv2d(
41 | in_channels=input_channels,
42 | out_channels=output_channels,
43 | kernel_size=(kernel_size, kernel_size),
44 | padding=1,
45 | ),
46 | eps=sn_eps,
47 | )
48 |
49 | def forward(self, x, prev_state):
50 | """
51 | ConvGRU forward, returning the current+new state
52 |
53 | Args:
54 | x: Input tensor
55 | prev_state: Previous state
56 |
57 | Returns:
58 | New tensor plus the new state
59 | """
60 | # Concatenate the inputs and previous state along the channel axis.
61 | xh = torch.cat([x, prev_state], dim=1)
62 |
63 | # Read gate of the GRU.
64 | read_gate = F.sigmoid(self.read_gate_conv(xh))
65 |
66 | # Update gate of the GRU.
67 | update_gate = F.sigmoid(self.update_gate_conv(xh))
68 |
69 | # Gate the inputs.
70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
71 |
72 | # Gate the cell and state / outputs.
73 | c = F.relu(self.output_conv(gated_input))
74 | out = update_gate * prev_state + (1.0 - update_gate) * c
75 | new_state = out
76 |
77 | return out, new_state
78 |
79 |
80 | class ConvGRU(torch.nn.Module):
81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
82 |
83 | def __init__(
84 | self,
85 | input_channels: int,
86 | output_channels: int,
87 | kernel_size: int = 3,
88 | sn_eps=0.0001,
89 | ):
90 | super().__init__()
91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
92 |
93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
94 | outputs = []
95 | for step in range(len(x)):
96 | # Compute current timestep
97 | output, hidden_state = self.cell(x[step], hidden_state)
98 | outputs.append(output)
99 | # Stack outputs to return as tensor
100 | outputs = torch.stack(outputs, dim=0)
101 | return outputs
102 |
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [
31 | input_tensor,
32 | xx_channel.type_as(input_tensor),
33 | yy_channel.type_as(input_tensor),
34 | ],
35 | dim=1,
36 | )
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
42 | )
43 | ret = torch.cat([ret, rr], dim=1)
44 |
45 | return ret
46 |
47 |
48 | class CoordConv(nn.Module):
49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
50 | super().__init__()
51 | self.addcoords = AddCoords(with_r=with_r)
52 | in_size = in_channels + 2
53 | if with_r:
54 | in_size += 1
55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
56 |
57 | def forward(self, x):
58 | ret = self.addcoords(x)
59 | ret = self.conv(ret)
60 | return ret
61 |
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/dgmr-multichannel/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/generators.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.modules.pixelshuffle import PixelShuffle
5 | from torch.nn.utils.parametrizations import spectral_norm
6 | from typing import List
7 | from dgmr.common import GBlock, UpsampleGBlock
8 | from dgmr.layers import ConvGRU
9 | from huggingface_hub import PyTorchModelHubMixin
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.WARN)
14 |
15 |
16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin):
17 | def __init__(
18 | self,
19 | forecast_steps: int = 18,
20 | context_channels: int = 384,
21 | latent_channels: int = 384,
22 | output_channels: int = 1,
23 | **kwargs
24 | ):
25 | """
26 |         Sampler adapted from the Skillful Nowcasting paper, see https://arxiv.org/pdf/2104.00954.pdf
27 |
28 |         This one-shot variant takes the output of the Context conditioning stack and
29 |         produces all forecast steps in a single pass instead of one ConvGRU stack per timestep.
30 |         Args:
31 |             forecast_steps: Number of forecast steps
32 |             latent_channels: Per-step channel budget of the coarsest block (scaled by forecast_steps)
33 | """
34 | super().__init__()
35 | config = locals()
36 | config.pop("__class__")
37 | config.pop("self")
38 | self.config = kwargs.get("config", config)
39 | self.forecast_steps = self.config["forecast_steps"]
40 | latent_channels = self.config["latent_channels"]
41 | context_channels = self.config["context_channels"]
42 | output_channels = self.config["output_channels"]
43 |
44 | self.gru_conv_1x1 = spectral_norm(
45 | torch.nn.Conv2d(
46 | in_channels=context_channels,
47 | out_channels=latent_channels * self.forecast_steps,
48 | kernel_size=(1, 1),
49 | )
50 | )
51 | self.g1 = GBlock(
52 | input_channels=latent_channels * self.forecast_steps,
53 | output_channels=latent_channels * self.forecast_steps,
54 | )
55 | self.up_g1 = UpsampleGBlock(
56 | input_channels=latent_channels * self.forecast_steps,
57 | output_channels=latent_channels * self.forecast_steps // 2,
58 | )
59 |
60 | self.gru_conv_1x1_2 = spectral_norm(
61 | torch.nn.Conv2d(
62 | in_channels=self.up_g1.output_channels + context_channels // 2,
63 | out_channels=latent_channels * self.forecast_steps // 2,
64 | kernel_size=(1, 1),
65 | )
66 | )
67 | self.g2 = GBlock(
68 | input_channels=latent_channels * self.forecast_steps // 2,
69 | output_channels=latent_channels * self.forecast_steps // 2,
70 | )
71 | self.up_g2 = UpsampleGBlock(
72 | input_channels=latent_channels * self.forecast_steps // 2,
73 | output_channels=latent_channels * self.forecast_steps // 4,
74 | )
75 |
76 | self.gru_conv_1x1_3 = spectral_norm(
77 | torch.nn.Conv2d(
78 | in_channels=self.up_g2.output_channels + context_channels // 4,
79 | out_channels=latent_channels * self.forecast_steps // 4,
80 | kernel_size=(1, 1),
81 | )
82 | )
83 | self.g3 = GBlock(
84 | input_channels=latent_channels * self.forecast_steps // 4,
85 | output_channels=latent_channels * self.forecast_steps // 4,
86 | )
87 | self.up_g3 = UpsampleGBlock(
88 | input_channels=latent_channels * self.forecast_steps // 4,
89 | output_channels=latent_channels * self.forecast_steps // 8,
90 | )
91 |
92 | self.gru_conv_1x1_4 = spectral_norm(
93 | torch.nn.Conv2d(
94 | in_channels=self.up_g3.output_channels + context_channels // 8,
95 | out_channels=latent_channels * self.forecast_steps // 8,
96 | kernel_size=(1, 1),
97 | )
98 | )
99 | self.g4 = GBlock(
100 | input_channels=latent_channels * self.forecast_steps // 8,
101 | output_channels=latent_channels * self.forecast_steps // 8,
102 | )
103 | self.up_g4 = UpsampleGBlock(
104 | input_channels=latent_channels * self.forecast_steps // 8,
105 | output_channels=latent_channels * self.forecast_steps // 16,
106 | )
107 |
108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16)
109 | self.relu = torch.nn.ReLU()
110 | self.conv_1x1 = spectral_norm(
111 | torch.nn.Conv2d(
112 | in_channels=latent_channels * self.forecast_steps // 16,
113 | out_channels=4 * output_channels * self.forecast_steps,
114 | kernel_size=(1, 1),
115 | )
116 | )
117 |
118 | self.depth2space = PixelShuffle(upscale_factor=2)
119 |
120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor:
121 | """
122 | Perform the sampling from Skillful Nowcasting with GANs
123 | Args:
124 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially
125 |                 (this one-shot variant takes no separate latent input)
126 |
127 | Returns:
128 | forecast_steps-length output of images for future timesteps
129 |
130 | """
131 | # Iterate through each forecast step
132 | # Initialize with conditioning state for first one, output for second one
133 | init_states = conditioning_states
134 |
135 | layer4_states = self.gru_conv_1x1(init_states[3])
136 | layer4_states = self.g1(layer4_states)
137 | layer4_states = self.up_g1(layer4_states)
138 |
139 | # Layer 3.
140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1)
141 | layer3_states = self.gru_conv_1x1_2(layer3_states)
142 | layer3_states = self.g2(layer3_states)
143 | layer3_states = self.up_g2(layer3_states)
144 |
145 | # Layer 2.
146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1)
147 | layer2_states = self.gru_conv_1x1_3(layer2_states)
148 | layer2_states = self.g3(layer2_states)
149 | layer2_states = self.up_g3(layer2_states)
150 |
151 | # Layer 1 (top-most).
152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1)
153 | layer1_states = self.gru_conv_1x1_4(layer1_states)
154 | layer1_states = self.g4(layer1_states)
155 | layer1_states = self.up_g4(layer1_states)
156 |
157 |         # Output head: BatchNorm, ReLU, 1x1 conv, then depth-to-space
158 | output_states = self.relu(self.bn(layer1_states))
159 | output_states = self.conv_1x1(output_states)
160 | output_states = self.depth2space(output_states)
161 |
162 | # The satellite dimension was lost, add it back
163 | output_states = torch.unsqueeze(output_states, dim=2)
164 |
165 | return output_states
166 |
167 |
168 | class Generator(torch.nn.Module, PyTorchModelHubMixin):
169 | def __init__(
170 | self,
171 | conditioning_stack: torch.nn.Module,
172 | sampler: torch.nn.Module,
173 | ):
174 | """
175 |         Wraps the two parts of the generator for simpler calling
176 |         Args:
177 |             conditioning_stack: context conditioning stack
178 |             sampler: one-shot sampler (this variant has no separate latent stack)
179 |
180 | """
181 | super().__init__()
182 | self.conditioning_stack = conditioning_stack
183 | self.sampler = sampler
184 |
185 | def forward(self, x):
186 | conditioning_states = self.conditioning_stack(x)
187 | x = self.sampler(conditioning_states)
188 | return x
189 |
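A shape walk-through of the one-shot output head above (illustrative numbers; the real module first applies BatchNorm, ReLU and a spectrally normalised 1x1 conv). All forecast steps are folded into the channel dimension, PixelShuffle trades 4x channels for 2x spatial resolution, and the final unsqueeze only recovers a clean (B, T, C, H, W) layout when output_channels == 1:

import torch

T, C = 24, 1                                   # forecast_steps, output_channels
x = torch.randn(2, 4 * C * T, 64, 64)          # stand-in for the conv_1x1 output
x = torch.nn.PixelShuffle(2)(x)                # (2, C * T, 128, 128)
x = torch.unsqueeze(x, dim=2)                  # (2, T, 1, 128, 128) when C == 1
print(x.shape)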
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally Taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 |     HuggingFace ModelHubMixin containing specific adaptations for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 | ... def forward(self, ...)
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 | >>> # Downloading weights from hf-hub & model will be initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(
10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001
11 | ):
12 | """Constructor.
13 |
14 | Args:
15 | kernel_size: kernel size of the convolutions. Default: 3.
16 | sn_eps: constant for spectral normalization. Default: 1e-4.
17 | """
18 | super().__init__()
19 | self._kernel_size = kernel_size
20 | self._sn_eps = sn_eps
21 | self.read_gate_conv = spectral_norm(
22 | torch.nn.Conv2d(
23 | in_channels=input_channels,
24 | out_channels=output_channels,
25 | kernel_size=(kernel_size, kernel_size),
26 | padding=1,
27 | ),
28 | eps=sn_eps,
29 | )
30 | self.update_gate_conv = spectral_norm(
31 | torch.nn.Conv2d(
32 | in_channels=input_channels,
33 | out_channels=output_channels,
34 | kernel_size=(kernel_size, kernel_size),
35 | padding=1,
36 | ),
37 | eps=sn_eps,
38 | )
39 | self.output_conv = spectral_norm(
40 | torch.nn.Conv2d(
41 | in_channels=input_channels,
42 | out_channels=output_channels,
43 | kernel_size=(kernel_size, kernel_size),
44 | padding=1,
45 | ),
46 | eps=sn_eps,
47 | )
48 |
49 | def forward(self, x, prev_state):
50 | """
51 | ConvGRU forward, returning the current+new state
52 |
53 | Args:
54 | x: Input tensor
55 | prev_state: Previous state
56 |
57 | Returns:
58 | New tensor plus the new state
59 | """
60 | # Concatenate the inputs and previous state along the channel axis.
61 | xh = torch.cat([x, prev_state], dim=1)
62 |
63 | # Read gate of the GRU.
64 | read_gate = F.sigmoid(self.read_gate_conv(xh))
65 |
66 | # Update gate of the GRU.
67 | update_gate = F.sigmoid(self.update_gate_conv(xh))
68 |
69 | # Gate the inputs.
70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
71 |
72 | # Gate the cell and state / outputs.
73 | c = F.relu(self.output_conv(gated_input))
74 | out = update_gate * prev_state + (1.0 - update_gate) * c
75 | new_state = out
76 |
77 | return out, new_state
78 |
79 |
80 | class ConvGRU(torch.nn.Module):
81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
82 |
83 | def __init__(
84 | self,
85 | input_channels: int,
86 | output_channels: int,
87 | kernel_size: int = 3,
88 | sn_eps=0.0001,
89 | ):
90 | super().__init__()
91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
92 |
93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
94 | outputs = []
95 | for step in range(len(x)):
96 | # Compute current timestep
97 | output, hidden_state = self.cell(x[step], hidden_state)
98 | outputs.append(output)
99 | # Stack outputs to return as tensor
100 | outputs = torch.stack(outputs, dim=0)
101 | return outputs
102 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [
31 | input_tensor,
32 | xx_channel.type_as(input_tensor),
33 | yy_channel.type_as(input_tensor),
34 | ],
35 | dim=1,
36 | )
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
42 | )
43 | ret = torch.cat([ret, rr], dim=1)
44 |
45 | return ret
46 |
47 |
48 | class CoordConv(nn.Module):
49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
50 | super().__init__()
51 | self.addcoords = AddCoords(with_r=with_r)
52 | in_size = in_channels + 2
53 | if with_r:
54 | in_size += 1
55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
56 |
57 | def forward(self, x):
58 | ret = self.addcoords(x)
59 | ret = self.conv(ret)
60 | return ret
61 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot-multichannel/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot/dgmr/dgmr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.losses import (
3 | NowcastingLoss,
4 | GridCellLoss,
5 | loss_hinge_disc,
6 | loss_hinge_gen,
7 | grid_cell_regularizer,
8 | )
9 | import pytorch_lightning as pl
10 | import torchvision
11 | from dgmr.common import LatentConditioningStack, ContextConditioningStack
12 | from dgmr.generators import Sampler, Generator
13 | from dgmr.discriminators import Discriminator
14 | from dgmr.hub import NowcastingModelHubMixin
15 |
16 |
17 | class DGMR(pl.LightningModule, NowcastingModelHubMixin):
18 | """Deep Generative Model of Radar"""
19 |
20 | def __init__(
21 | self,
22 | forecast_steps: int = 18,
23 | input_channels: int = 1,
24 | output_shape: int = 256,
25 | gen_lr: float = 5e-5,
26 | disc_lr: float = 2e-4,
27 | visualize: bool = False,
28 | conv_type: str = "standard",
29 | num_samples: int = 6,
30 | grid_lambda: float = 20.0,
31 | beta1: float = 0.0,
32 | beta2: float = 0.999,
33 | latent_channels: int = 768,
34 | context_channels: int = 384,
35 | **kwargs,
36 | ):
37 | """
38 | Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954
39 | but slightly modified for multiple satellite channels
40 |
41 | Args:
42 | forecast_steps: Number of steps to predict in the future
43 | input_channels: Number of input channels per image
44 | visualize: Whether to visualize output during training
45 | gen_lr: Learning rate for the generator
46 | disc_lr: Learning rate for the discriminators, shared for both temporal and spatial discriminator
47 | conv_type: Type of 2d convolution to use, see satflow/models/utils.py for options
48 | beta1: Beta1 for Adam optimizer
49 | beta2: Beta2 for Adam optimizer
50 | num_samples: Number of samples of the latent space to sample for training/validation
51 | grid_lambda: Lambda for the grid regularization loss
52 | output_shape: Shape of the output predictions, generally should be same as the input shape
53 | latent_channels: Number of channels that the latent space should be reshaped to,
54 | input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs
55 |             context_channels: Number of output channels of the context conditioning stack
56 | """
57 | super().__init__()
58 | config = locals()
59 | config.pop("__class__")
60 | config.pop("self")
61 | self.config = kwargs.get("config", config)
62 | input_channels = self.config["input_channels"]
63 | forecast_steps = self.config["forecast_steps"]
64 | output_shape = self.config["output_shape"]
65 | gen_lr = self.config["gen_lr"]
66 | disc_lr = self.config["disc_lr"]
67 | conv_type = self.config["conv_type"]
68 | num_samples = self.config["num_samples"]
69 | grid_lambda = self.config["grid_lambda"]
70 | beta1 = self.config["beta1"]
71 | beta2 = self.config["beta2"]
72 | latent_channels = self.config["latent_channels"]
73 | context_channels = self.config["context_channels"]
74 | visualize = self.config["visualize"]
75 | self.gen_lr = gen_lr
76 | self.disc_lr = disc_lr
77 | self.beta1 = beta1
78 | self.beta2 = beta2
79 | self.discriminator_loss = NowcastingLoss()
80 | self.grid_regularizer = GridCellLoss()
81 | self.grid_lambda = grid_lambda
82 | self.num_samples = num_samples
83 | self.visualize = visualize
84 | self.latent_channels = latent_channels
85 | self.context_channels = context_channels
86 | self.input_channels = input_channels
87 | self.conditioning_stack = ContextConditioningStack(
88 | input_channels=input_channels,
89 | conv_type=conv_type,
90 | output_channels=self.context_channels,
91 | )
92 | self.latent_stack = LatentConditioningStack(
93 | shape=(8 * self.input_channels, output_shape // 32, output_shape // 32),
94 | output_channels=self.latent_channels,
95 | )
96 | self.sampler = Sampler(
97 | forecast_steps=forecast_steps,
98 | latent_channels=self.latent_channels,
99 | context_channels=self.context_channels,
100 | )
101 | self.generator = Generator(
102 |             self.conditioning_stack, self.sampler  # the one-shot Generator takes no latent stack
103 | )
104 | self.discriminator = Discriminator(input_channels)
105 | self.save_hyperparameters()
106 |
107 | self.global_iteration = 0
108 |
109 | # Important: This property activates manual optimization.
110 | self.automatic_optimization = False
111 | torch.autograd.set_detect_anomaly(True)
112 |
113 | def forward(self, x):
114 | x = self.generator(x)
115 | return x
116 |
117 | def training_step(self, batch, batch_idx):
118 | images, future_images = batch
119 |
120 | self.global_iteration += 1
121 | g_opt, d_opt = self.optimizers()
122 | ##########################
123 | # Optimize Discriminator #
124 | ##########################
125 | # Two discriminator steps per generator step
126 | for _ in range(2):
127 | predictions = self(images)
128 | # Cat along time dimension [B, T, C, H, W]
129 | generated_sequence = torch.cat([images, predictions], dim=1)
130 | real_sequence = torch.cat([images, future_images], dim=1)
131 |             # Cat along the batch dimension for the real + generated sequences
132 | concatenated_inputs = torch.cat([real_sequence, generated_sequence], dim=0)
133 |
134 | concatenated_outputs = self.discriminator(concatenated_inputs)
135 |
136 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1)
137 | discriminator_loss = loss_hinge_disc(score_generated, score_real)
138 | d_opt.zero_grad()
139 | self.manual_backward(discriminator_loss)
140 | d_opt.step()
141 |
142 | ######################
143 | # Optimize Generator #
144 | ######################
145 |         predictions = [self(images) for _ in range(self.num_samples)]
146 | grid_cell_reg = grid_cell_regularizer(
147 | torch.stack(predictions, dim=0), future_images
148 | )
149 | # Concat along time dimension
150 | generated_sequence = [torch.cat([images, x], dim=1) for x in predictions]
151 | real_sequence = torch.cat([images, future_images], dim=1)
152 |         # Cat along the batch dimension for the real + generated sequences,
153 |         # once for each of the num_samples sampled predictions
154 | generated_scores = []
155 | for g_seq in generated_sequence:
156 | concatenated_inputs = torch.cat([real_sequence, g_seq], dim=0)
157 | concatenated_outputs = self.discriminator(concatenated_inputs)
158 | score_real, score_generated = torch.split(concatenated_outputs, 1, dim=1)
159 | generated_scores.append(score_generated)
160 | generator_disc_loss = loss_hinge_gen(torch.cat(generated_scores, dim=0))
161 | generator_loss = generator_disc_loss + self.grid_lambda * grid_cell_reg
162 | g_opt.zero_grad()
163 | self.manual_backward(generator_loss)
164 | g_opt.step()
165 |
166 | self.log_dict(
167 | {
168 | "train/d_loss": discriminator_loss,
169 | "train/g_loss": generator_loss,
170 | "train/grid_loss": grid_cell_reg,
171 | },
172 | prog_bar=True,
173 | )
174 |
175 | # generate images
176 | generated_images = self(images)
177 | # log sampled images
178 | if self.visualize:
179 | self.visualize_step(
180 | images,
181 | future_images,
182 | generated_images,
183 | self.global_iteration,
184 | step="train",
185 | )
186 |
187 | def configure_optimizers(self):
188 | b1 = self.beta1
189 | b2 = self.beta2
190 |
191 | opt_g = torch.optim.Adam(
192 | self.generator.parameters(), lr=self.gen_lr, betas=(b1, b2)
193 | )
194 | opt_d = torch.optim.Adam(
195 | self.discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2)
196 | )
197 |
198 | return [opt_g, opt_d], []
199 |
200 | def visualize_step(
201 | self,
202 | x: torch.Tensor,
203 | y: torch.Tensor,
204 | y_hat: torch.Tensor,
205 | batch_idx: int,
206 | step: str,
207 | ) -> None:
208 | # the logger you used (in this case tensorboard)
209 | tensorboard = self.logger.experiment[0]
210 | # Timesteps per channel
211 | images = x[0].cpu().detach()
212 | future_images = y[0].cpu().detach()
213 | generated_images = y_hat[0].cpu().detach()
214 | for i, t in enumerate(images): # Now would be (C, H, W)
215 | t = [torch.unsqueeze(img, dim=0) for img in t]
216 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
217 | tensorboard.add_image(
218 | f"{step}/Input_Image_Stack_Frame_{i}", image_grid, global_step=batch_idx
219 | )
220 | t = [torch.unsqueeze(img, dim=0) for img in future_images[i]]
221 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
222 | tensorboard.add_image(
223 | f"{step}/Target_Image_Frame_{i}", image_grid, global_step=batch_idx
224 | )
225 | t = [torch.unsqueeze(img, dim=0) for img in generated_images[i]]
226 | image_grid = torchvision.utils.make_grid(t, nrow=self.input_channels)
227 | tensorboard.add_image(
228 | f"{step}/Generated_Image_Frame_{i}", image_grid, global_step=batch_idx
229 | )
230 |
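The training step above alternates two discriminator updates with one generator update, using hinge losses plus a grid-cell regulariser. The exact definitions live in dgmr/losses.py; the following is only a sketch of the standard GAN hinge formulation they are assumed to follow (reductions and weighting may differ):

import torch
import torch.nn.functional as F

def loss_hinge_disc_sketch(score_generated, score_real):
    # Discriminator pushes real scores above +1 and generated scores below -1.
    return F.relu(1.0 - score_real).mean() + F.relu(1.0 + score_generated).mean()

def loss_hinge_gen_sketch(score_generated):
    # Generator simply pushes the generated scores up.
    return -score_generated.mean()

real, fake = torch.randn(4, 1), torch.randn(4, 1)
print(loss_hinge_disc_sketch(fake, real).item(), loss_hinge_gen_sketch(fake).item())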
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot/dgmr/generators.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.modules.pixelshuffle import PixelShuffle
5 | from torch.nn.utils.parametrizations import spectral_norm
6 | from typing import List
7 | from dgmr.common import GBlock, UpsampleGBlock
8 | from dgmr.layers import ConvGRU
9 | from huggingface_hub import PyTorchModelHubMixin
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.WARN)
14 |
15 |
16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin):
17 | def __init__(
18 | self,
19 | forecast_steps: int = 18,
20 | context_channels: int = 384,
21 | latent_channels: int = 384,
22 | output_channels: int = 1,
23 | **kwargs
24 | ):
25 | """
26 |         Sampler adapted from the Skillful Nowcasting paper, see https://arxiv.org/pdf/2104.00954.pdf
27 |
28 |         This one-shot variant takes the output of the Context conditioning stack and
29 |         produces all forecast steps in a single pass instead of one ConvGRU stack per timestep.
30 |         Args:
31 |             forecast_steps: Number of forecast steps
32 |             latent_channels: Per-step channel budget of the coarsest block (scaled by forecast_steps)
33 | """
34 | super().__init__()
35 | config = locals()
36 | config.pop("__class__")
37 | config.pop("self")
38 | self.config = kwargs.get("config", config)
39 | self.forecast_steps = self.config["forecast_steps"]
40 | latent_channels = self.config["latent_channels"]
41 | context_channels = self.config["context_channels"]
42 | output_channels = self.config["output_channels"]
43 |
44 | self.gru_conv_1x1 = spectral_norm(
45 | torch.nn.Conv2d(
46 | in_channels=context_channels,
47 | out_channels=latent_channels * self.forecast_steps,
48 | kernel_size=(1, 1),
49 | )
50 | )
51 | self.g1 = GBlock(
52 | input_channels=latent_channels * self.forecast_steps,
53 | output_channels=latent_channels * self.forecast_steps,
54 | )
55 | self.up_g1 = UpsampleGBlock(
56 | input_channels=latent_channels * self.forecast_steps,
57 | output_channels=latent_channels * self.forecast_steps // 2,
58 | )
59 |
60 | self.gru_conv_1x1_2 = spectral_norm(
61 | torch.nn.Conv2d(
62 | in_channels=self.up_g1.output_channels + context_channels // 2,
63 | out_channels=latent_channels * self.forecast_steps // 2,
64 | kernel_size=(1, 1),
65 | )
66 | )
67 | self.g2 = GBlock(
68 | input_channels=latent_channels * self.forecast_steps // 2,
69 | output_channels=latent_channels * self.forecast_steps // 2,
70 | )
71 | self.up_g2 = UpsampleGBlock(
72 | input_channels=latent_channels * self.forecast_steps // 2,
73 | output_channels=latent_channels * self.forecast_steps // 4,
74 | )
75 |
76 | self.gru_conv_1x1_3 = spectral_norm(
77 | torch.nn.Conv2d(
78 | in_channels=self.up_g2.output_channels + context_channels // 4,
79 | out_channels=latent_channels * self.forecast_steps // 4,
80 | kernel_size=(1, 1),
81 | )
82 | )
83 | self.g3 = GBlock(
84 | input_channels=latent_channels * self.forecast_steps // 4,
85 | output_channels=latent_channels * self.forecast_steps // 4,
86 | )
87 | self.up_g3 = UpsampleGBlock(
88 | input_channels=latent_channels * self.forecast_steps // 4,
89 | output_channels=latent_channels * self.forecast_steps // 8,
90 | )
91 |
92 | self.gru_conv_1x1_4 = spectral_norm(
93 | torch.nn.Conv2d(
94 | in_channels=self.up_g3.output_channels + context_channels // 8,
95 | out_channels=latent_channels * self.forecast_steps // 8,
96 | kernel_size=(1, 1),
97 | )
98 | )
99 | self.g4 = GBlock(
100 | input_channels=latent_channels * self.forecast_steps // 8,
101 | output_channels=latent_channels * self.forecast_steps // 8,
102 | )
103 | self.up_g4 = UpsampleGBlock(
104 | input_channels=latent_channels * self.forecast_steps // 8,
105 | output_channels=latent_channels * self.forecast_steps // 16,
106 | )
107 |
108 | self.bn = torch.nn.BatchNorm2d(latent_channels * self.forecast_steps // 16)
109 | self.relu = torch.nn.ReLU()
110 | self.conv_1x1 = spectral_norm(
111 | torch.nn.Conv2d(
112 | in_channels=latent_channels * self.forecast_steps // 16,
113 | out_channels=4 * output_channels * self.forecast_steps,
114 | kernel_size=(1, 1),
115 | )
116 | )
117 |
118 | self.depth2space = PixelShuffle(upscale_factor=2)
119 |
120 | def forward(self, conditioning_states: List[torch.Tensor]) -> torch.Tensor:
121 | """
122 | Perform the sampling from Skillful Nowcasting with GANs
123 | Args:
124 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially
125 |                 (this one-shot variant takes no separate latent input)
126 |
127 | Returns:
128 | forecast_steps-length output of images for future timesteps
129 |
130 | """
131 | # Iterate through each forecast step
132 | # Initialize with conditioning state for first one, output for second one
133 | init_states = conditioning_states
134 |
135 | layer4_states = self.gru_conv_1x1(init_states[3])
136 | layer4_states = self.g1(layer4_states)
137 | layer4_states = self.up_g1(layer4_states)
138 |
139 | # Layer 3.
140 | layer3_states = torch.cat([layer4_states, init_states[2]], dim=1)
141 | layer3_states = self.gru_conv_1x1_2(layer3_states)
142 | layer3_states = self.g2(layer3_states)
143 | layer3_states = self.up_g2(layer3_states)
144 |
145 | # Layer 2.
146 | layer2_states = torch.cat([layer3_states, init_states[1]], dim=1)
147 | layer2_states = self.gru_conv_1x1_3(layer2_states)
148 | layer2_states = self.g3(layer2_states)
149 | layer2_states = self.up_g3(layer2_states)
150 |
151 | # Layer 1 (top-most).
152 | layer1_states = torch.cat([layer2_states, init_states[0]], dim=1)
153 | layer1_states = self.gru_conv_1x1_4(layer1_states)
154 | layer1_states = self.g4(layer1_states)
155 | layer1_states = self.up_g4(layer1_states)
156 |
157 |         # Final output layers: batch norm, ReLU, 1x1 conv, then depth-to-space
158 | output_states = self.relu(self.bn(layer1_states))
159 | output_states = self.conv_1x1(output_states)
160 | output_states = self.depth2space(output_states)
161 |
162 |         # Re-insert the singleton satellite-channel dimension dropped earlier
163 | output_states = torch.unsqueeze(output_states, dim=2)
164 |
165 | return output_states
166 |
167 |
168 | class Generator(torch.nn.Module, PyTorchModelHubMixin):
169 | def __init__(
170 | self,
171 | conditioning_stack: torch.nn.Module,
172 | sampler: torch.nn.Module,
173 | ):
174 |         """
175 |         Wraps the two parts of the one-shot generator for simpler calling
176 | 
177 |         Args:
178 |             conditioning_stack: `ContextConditioningStack` producing the multi-scale context states
179 |             sampler: one-shot `Sampler` mapping those states to all forecast frames
180 |         """
181 | super().__init__()
182 | self.conditioning_stack = conditioning_stack
183 | self.sampler = sampler
184 |
185 | def forward(self, x):
186 | conditioning_states = self.conditioning_stack(x)
187 | x = self.sampler(conditioning_states)
188 | return x
189 |
--------------------------------------------------------------------------------
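
A minimal usage sketch for the one-shot Sampler/Generator above. This is not code from the repository: the default ContextConditioningStack arguments, the channel counts, and the (batch, history steps, channels, height, width) input shape are assumptions for illustration only.

    import torch
    from dgmr.common import ContextConditioningStack
    from dgmr.generators import Generator, Sampler

    # Assumed channel counts; they only need to be consistent between the two modules.
    conditioning_stack = ContextConditioningStack()  # assumed: defaults yield 384 context channels
    sampler = Sampler(
        forecast_steps=24,
        latent_channels=768,
        context_channels=384,
        output_channels=1,
    )
    generator = Generator(conditioning_stack=conditioning_stack, sampler=sampler)

    x = torch.randn(1, 4, 1, 128, 128)  # assumed input: (batch, history steps, channels, H, W)
    with torch.no_grad():
        forecast = generator(x)  # every forecast step comes out of a single forward pass

Unlike the original generator in /experiments/dgmr-original/dgmr/generators.py below, there is no LatentConditioningStack here: the sampler widens its channels by a factor of forecast_steps and emits all frames at once.
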
/experiments/dgmr-oneshot/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 |     HuggingFace ModelHubMixin containing specific adaptations for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 |             ...     def forward(self, ...):
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 |             >>> # Download weights from the hf-hub; the model is initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
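
A sketch of how NowcastingModelHubMixin is meant to be used, based on its docstring and `_from_pretrained` above. The class and repo id below are placeholders, not anything shipped with the repository.

    import torch.nn as nn
    from dgmr.hub import NowcastingModelHubMixin

    class TinyNowcaster(nn.Module, NowcastingModelHubMixin):
        """Placeholder model used only to illustrate the mixin."""

        def __init__(self, hidden_channels: int = 8):
            super().__init__()
            self.layer = nn.Conv2d(1, hidden_channels, kernel_size=1)

        def forward(self, x):
            return self.layer(x)

    # Loading: the stored config is forwarded to the constructor via model_kwargs["config"],
    # the downloaded state dict is applied with strict=False, and the model is set to eval().
    # model = TinyNowcaster.from_pretrained("username/tiny-nowcaster")  # placeholder repo id

Note that saving via `push_to_hub` expects Lightning-style `hparams` on the module, since `_save_pretrained` serializes `module.hparams` as the config file.
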
/experiments/dgmr-oneshot/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
--------------------------------------------------------------------------------
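
A minimal sketch of the AttentionLayer above, with assumed sizes. Because the attended output is added back to the input, output_channels must equal input_channels, and since `gamma` is initialized to zero the block starts out as an identity mapping.

    import torch
    from dgmr.layers import AttentionLayer

    attn = AttentionLayer(input_channels=96, output_channels=96)  # assumed channel count
    x = torch.randn(2, 96, 16, 16)  # (batch, channels, H, W)
    y = attn(x)
    assert y.shape == x.shape   # residual connection preserves the shape
    assert torch.equal(y, x)    # gamma starts at zero, so initially the layer is an identity
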
/experiments/dgmr-oneshot/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(
10 | self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001
11 | ):
12 | """Constructor.
13 |
14 | Args:
15 | kernel_size: kernel size of the convolutions. Default: 3.
16 | sn_eps: constant for spectral normalization. Default: 1e-4.
17 | """
18 | super().__init__()
19 | self._kernel_size = kernel_size
20 | self._sn_eps = sn_eps
21 | self.read_gate_conv = spectral_norm(
22 | torch.nn.Conv2d(
23 | in_channels=input_channels,
24 | out_channels=output_channels,
25 | kernel_size=(kernel_size, kernel_size),
26 | padding=1,
27 | ),
28 | eps=sn_eps,
29 | )
30 | self.update_gate_conv = spectral_norm(
31 | torch.nn.Conv2d(
32 | in_channels=input_channels,
33 | out_channels=output_channels,
34 | kernel_size=(kernel_size, kernel_size),
35 | padding=1,
36 | ),
37 | eps=sn_eps,
38 | )
39 | self.output_conv = spectral_norm(
40 | torch.nn.Conv2d(
41 | in_channels=input_channels,
42 | out_channels=output_channels,
43 | kernel_size=(kernel_size, kernel_size),
44 | padding=1,
45 | ),
46 | eps=sn_eps,
47 | )
48 |
49 | def forward(self, x, prev_state):
50 | """
51 | ConvGRU forward, returning the current+new state
52 |
53 | Args:
54 | x: Input tensor
55 | prev_state: Previous state
56 |
57 | Returns:
58 | New tensor plus the new state
59 | """
60 | # Concatenate the inputs and previous state along the channel axis.
61 | xh = torch.cat([x, prev_state], dim=1)
62 |
63 | # Read gate of the GRU.
64 | read_gate = F.sigmoid(self.read_gate_conv(xh))
65 |
66 | # Update gate of the GRU.
67 | update_gate = F.sigmoid(self.update_gate_conv(xh))
68 |
69 | # Gate the inputs.
70 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
71 |
72 | # Gate the cell and state / outputs.
73 | c = F.relu(self.output_conv(gated_input))
74 | out = update_gate * prev_state + (1.0 - update_gate) * c
75 | new_state = out
76 |
77 | return out, new_state
78 |
79 |
80 | class ConvGRU(torch.nn.Module):
81 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
82 |
83 | def __init__(
84 | self,
85 | input_channels: int,
86 | output_channels: int,
87 | kernel_size: int = 3,
88 | sn_eps=0.0001,
89 | ):
90 | super().__init__()
91 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
92 |
93 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
94 | outputs = []
95 | for step in range(len(x)):
96 | # Compute current timestep
97 | output, hidden_state = self.cell(x[step], hidden_state)
98 | outputs.append(output)
99 | # Stack outputs to return as tensor
100 | outputs = torch.stack(outputs, dim=0)
101 | return outputs
102 |
--------------------------------------------------------------------------------
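
A minimal sketch of unrolling the ConvGRU wrapper above; the channel sizes are assumptions. Each step's input is concatenated with the hidden state, so the cell's input_channels must equal the per-step input channels plus the hidden-state channels, mirroring how the original Sampler in /experiments/dgmr-original/dgmr/generators.py calls `self.convGRU1(hidden_states, init_states[3])`.

    import torch
    from dgmr.layers import ConvGRU

    steps, batch = 4, 2
    inputs = [torch.randn(batch, 768, 8, 8) for _ in range(steps)]  # one tensor per forecast step
    hidden = torch.randn(batch, 384, 8, 8)  # initial state, e.g. a conditioning-stack output

    gru = ConvGRU(input_channels=768 + 384, output_channels=384, kernel_size=3)
    outputs = gru(inputs, hidden)  # per-step outputs stacked along dim 0
    assert outputs.shape == (steps, batch, 384, 8, 8)
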
/experiments/dgmr-oneshot/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [
31 | input_tensor,
32 | xx_channel.type_as(input_tensor),
33 | yy_channel.type_as(input_tensor),
34 | ],
35 | dim=1,
36 | )
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
41 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
42 | )
43 | ret = torch.cat([ret, rr], dim=1)
44 |
45 | return ret
46 |
47 |
48 | class CoordConv(nn.Module):
49 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
50 | super().__init__()
51 | self.addcoords = AddCoords(with_r=with_r)
52 | in_size = in_channels + 2
53 | if with_r:
54 | in_size += 1
55 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
56 |
57 | def forward(self, x):
58 | ret = self.addcoords(x)
59 | ret = self.conv(ret)
60 | return ret
61 |
--------------------------------------------------------------------------------
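
A minimal sketch of CoordConv above: AddCoords appends two normalized coordinate channels (plus an optional radius channel when with_r=True), and CoordConv sizes the wrapped nn.Conv2d accordingly. The channel counts here are arbitrary.

    import torch
    from dgmr.layers import CoordConv

    conv = CoordConv(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 32, 32)
    y = conv(x)  # the wrapped conv actually sees 3 + 2 = 5 input channels
    assert y.shape == (2, 16, 32, 32)
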
/experiments/dgmr-oneshot/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/dgmr-oneshot/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
--------------------------------------------------------------------------------
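
A small sketch of get_conv_layer above: it returns a layer class rather than an instance, so the caller constructs the convolution with whatever arguments the chosen class expects. The sizes below are arbitrary.

    from dgmr.layers.utils import get_conv_layer

    conv_cls = get_conv_layer("coord")  # returns the CoordConv class itself
    conv = conv_cls(in_channels=4, out_channels=8, kernel_size=3, padding=1)
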
/experiments/dgmr-original/dgmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .dgmr import DGMR
2 | from .generators import Sampler, Generator
3 | from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
4 | from .common import LatentConditioningStack, ContextConditioningStack
5 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/generators.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.modules.pixelshuffle import PixelShuffle
5 | from torch.nn.utils.parametrizations import spectral_norm
6 | from typing import List
7 | from dgmr.common import GBlock, UpsampleGBlock
8 | from dgmr.layers import ConvGRU
9 | from huggingface_hub import PyTorchModelHubMixin
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.WARN)
14 |
15 |
16 | class Sampler(torch.nn.Module, PyTorchModelHubMixin):
17 | def __init__(
18 | self,
19 | forecast_steps: int = 18,
20 | latent_channels: int = 768,
21 | context_channels: int = 384,
22 | output_channels: int = 1,
23 | **kwargs
24 | ):
25 | """
26 | Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
27 |
28 | The sampler takes the output from the Latent and Context conditioning stacks and
29 | creates one stack of ConvGRU layers per future timestep.
30 | Args:
31 | forecast_steps: Number of forecast steps
32 | latent_channels: Number of input channels to the lowest ConvGRU layer
33 | """
34 | super().__init__()
35 | config = locals()
36 | config.pop("__class__")
37 | config.pop("self")
38 | self.config = kwargs.get("config", config)
39 | self.forecast_steps = self.config["forecast_steps"]
40 | latent_channels = self.config["latent_channels"]
41 | context_channels = self.config["context_channels"]
42 | output_channels = self.config["output_channels"]
43 |
44 | self.convGRU1 = ConvGRU(
45 | input_channels=latent_channels + context_channels,
46 | output_channels=context_channels,
47 | kernel_size=3,
48 | )
49 | self.gru_conv_1x1 = spectral_norm(
50 | torch.nn.Conv2d(
51 | in_channels=context_channels,
52 | out_channels=latent_channels,
53 | kernel_size=(1, 1),
54 | )
55 | )
56 | self.g1 = GBlock(
57 | input_channels=latent_channels, output_channels=latent_channels
58 | )
59 | self.up_g1 = UpsampleGBlock(
60 | input_channels=latent_channels, output_channels=latent_channels // 2
61 | )
62 |
63 | self.convGRU2 = ConvGRU(
64 | input_channels=latent_channels // 2 + context_channels // 2,
65 | output_channels=context_channels // 2,
66 | kernel_size=3,
67 | )
68 | self.gru_conv_1x1_2 = spectral_norm(
69 | torch.nn.Conv2d(
70 | in_channels=context_channels // 2,
71 | out_channels=latent_channels // 2,
72 | kernel_size=(1, 1),
73 | )
74 | )
75 | self.g2 = GBlock(
76 | input_channels=latent_channels // 2, output_channels=latent_channels // 2
77 | )
78 | self.up_g2 = UpsampleGBlock(
79 | input_channels=latent_channels // 2, output_channels=latent_channels // 4
80 | )
81 |
82 | self.convGRU3 = ConvGRU(
83 | input_channels=latent_channels // 4 + context_channels // 4,
84 | output_channels=context_channels // 4,
85 | kernel_size=3,
86 | )
87 | self.gru_conv_1x1_3 = spectral_norm(
88 | torch.nn.Conv2d(
89 | in_channels=context_channels // 4,
90 | out_channels=latent_channels // 4,
91 | kernel_size=(1, 1),
92 | )
93 | )
94 | self.g3 = GBlock(
95 | input_channels=latent_channels // 4, output_channels=latent_channels // 4
96 | )
97 | self.up_g3 = UpsampleGBlock(
98 | input_channels=latent_channels // 4, output_channels=latent_channels // 8
99 | )
100 |
101 | self.convGRU4 = ConvGRU(
102 | input_channels=latent_channels // 8 + context_channels // 8,
103 | output_channels=context_channels // 8,
104 | kernel_size=3,
105 | )
106 | self.gru_conv_1x1_4 = spectral_norm(
107 | torch.nn.Conv2d(
108 | in_channels=context_channels // 8,
109 | out_channels=latent_channels // 8,
110 | kernel_size=(1, 1),
111 | )
112 | )
113 | self.g4 = GBlock(
114 | input_channels=latent_channels // 8, output_channels=latent_channels // 8
115 | )
116 | self.up_g4 = UpsampleGBlock(
117 | input_channels=latent_channels // 8, output_channels=latent_channels // 16
118 | )
119 |
120 | self.bn = torch.nn.BatchNorm2d(latent_channels // 16)
121 | self.relu = torch.nn.ReLU()
122 | self.conv_1x1 = spectral_norm(
123 | torch.nn.Conv2d(
124 | in_channels=latent_channels // 16,
125 | out_channels=4 * output_channels,
126 | kernel_size=(1, 1),
127 | )
128 | )
129 |
130 | self.depth2space = PixelShuffle(upscale_factor=2)
131 |
132 | def forward(
133 | self, conditioning_states: List[torch.Tensor], latent_dim: torch.Tensor
134 | ) -> torch.Tensor:
135 | """
136 | Perform the sampling from Skillful Nowcasting with GANs
137 | Args:
138 | conditioning_states: Outputs from the `ContextConditioningStack` with the 4 input states, ordered from largest to smallest spatially
139 | latent_dim: Output from `LatentConditioningStack` for input into the ConvGRUs
140 |
141 | Returns:
142 | forecast_steps-length output of images for future timesteps
143 |
144 | """
145 | # Iterate through each forecast step
146 | # Initialize with conditioning state for first one, output for second one
147 | init_states = conditioning_states
148 | # Expand latent dim to match batch size
149 | latent_dim = einops.repeat(
150 | latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0]
151 | )
152 | hidden_states = [latent_dim] * self.forecast_steps
153 |
154 | # Layer 4 (bottom most)
155 | hidden_states = self.convGRU1(hidden_states, init_states[3])
156 | hidden_states = [self.gru_conv_1x1(h) for h in hidden_states]
157 | hidden_states = [self.g1(h) for h in hidden_states]
158 | hidden_states = [self.up_g1(h) for h in hidden_states]
159 |
160 | # Layer 3.
161 | hidden_states = self.convGRU2(hidden_states, init_states[2])
162 | hidden_states = [self.gru_conv_1x1_2(h) for h in hidden_states]
163 | hidden_states = [self.g2(h) for h in hidden_states]
164 | hidden_states = [self.up_g2(h) for h in hidden_states]
165 |
166 | # Layer 2.
167 | hidden_states = self.convGRU3(hidden_states, init_states[1])
168 | hidden_states = [self.gru_conv_1x1_3(h) for h in hidden_states]
169 | hidden_states = [self.g3(h) for h in hidden_states]
170 | hidden_states = [self.up_g3(h) for h in hidden_states]
171 |
172 | # Layer 1 (top-most).
173 | hidden_states = self.convGRU4(hidden_states, init_states[0])
174 | hidden_states = [self.gru_conv_1x1_4(h) for h in hidden_states]
175 | hidden_states = [self.g4(h) for h in hidden_states]
176 | hidden_states = [self.up_g4(h) for h in hidden_states]
177 |
178 | # Output layer.
179 | hidden_states = [F.relu(self.bn(h)) for h in hidden_states]
180 | hidden_states = [self.conv_1x1(h) for h in hidden_states]
181 | hidden_states = [self.depth2space(h) for h in hidden_states]
182 |
183 | # Convert forecasts to a torch Tensor
184 | forecasts = torch.stack(hidden_states, dim=1)
185 | return forecasts
186 |
187 |
188 | class Generator(torch.nn.Module, PyTorchModelHubMixin):
189 | def __init__(
190 | self,
191 | conditioning_stack: torch.nn.Module,
192 | latent_stack: torch.nn.Module,
193 | sampler: torch.nn.Module,
194 | ):
195 | """
196 | Wraps the three parts of the generator for simpler calling
197 | Args:
198 | conditioning_stack:
199 | latent_stack:
200 | sampler:
201 | """
202 | super().__init__()
203 | self.conditioning_stack = conditioning_stack
204 | self.latent_stack = latent_stack
205 | self.sampler = sampler
206 |
207 | def forward(self, x):
208 | conditioning_states = self.conditioning_stack(x)
209 | latent_dim = self.latent_stack(x)
210 | x = self.sampler(conditioning_states, latent_dim)
211 | return x
212 |
--------------------------------------------------------------------------------
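
A minimal usage sketch for the original Sampler/Generator above, to contrast with the one-shot variant in /experiments/dgmr-oneshot/dgmr/generators.py: this version also needs a latent tensor from the LatentConditioningStack and unrolls one ConvGRU stack over the forecast steps. Default stack arguments, channel counts, and the input shape are assumptions for illustration only.

    import torch
    from dgmr.common import ContextConditioningStack, LatentConditioningStack
    from dgmr.generators import Generator, Sampler

    conditioning_stack = ContextConditioningStack()  # assumed defaults
    latent_stack = LatentConditioningStack()         # assumed defaults
    sampler = Sampler(
        forecast_steps=18,
        latent_channels=768,
        context_channels=384,
        output_channels=1,
    )
    generator = Generator(conditioning_stack, latent_stack, sampler)

    x = torch.randn(1, 4, 1, 128, 128)  # assumed input: (batch, history steps, channels, H, W)
    with torch.no_grad():
        forecast = generator(x)  # per-step frames stacked along dim 1
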
/experiments/dgmr-original/dgmr/hub.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally taken from https://github.com/rwightman/
3 |
4 | https://github.com/rwightman/pytorch-image-models/
5 | blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | from functools import partial
12 |
13 | import torch
14 |
15 |
16 | try:
17 | from huggingface_hub import cached_download, hf_hub_url
18 |
19 | cached_download = partial(cached_download, library_name="dgmr")
20 | except ImportError:
21 | hf_hub_url = None
22 | cached_download = None
23 |
24 | from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
25 |
26 | MODEL_CARD_MARKDOWN = """---
27 | license: mit
28 | tags:
29 | - nowcasting
30 | - forecasting
31 | - timeseries
32 | - remote-sensing
33 | - gan
34 | ---
35 |
36 | # {model_name}
37 |
38 | ## Model description
39 |
40 | [More information needed]
41 |
42 | ## Intended uses & limitations
43 |
44 | [More information needed]
45 |
46 | ## How to use
47 |
48 | [More information needed]
49 |
50 | ## Limitations and bias
51 |
52 | [More information needed]
53 |
54 | ## Training data
55 |
56 | [More information needed]
57 |
58 | ## Training procedure
59 |
60 | [More information needed]
61 |
62 | ## Evaluation results
63 |
64 | [More information needed]
65 |
66 | """
67 |
68 | _logger = logging.getLogger(__name__)
69 |
70 |
71 | class NowcastingModelHubMixin(ModelHubMixin):
72 | """
73 |     HuggingFace ModelHubMixin containing specific adaptations for Nowcasting models
74 | """
75 |
76 | def __init__(self, *args, **kwargs):
77 | """
78 | Mixin for pl.LightningModule and Hugging Face
79 |
80 | Mix this class with your pl.LightningModule class to easily push / download
81 | the model via the Hugging Face Hub
82 |
83 | Example::
84 |
85 | >>> from dgmr.hub import NowcastingModelHubMixin
86 |
87 | >>> class MyModel(nn.Module, NowcastingModelHubMixin):
88 | ... def __init__(self, **kwargs):
89 | ... super().__init__()
90 | ... self.layer = ...
91 |             ...     def forward(self, ...):
92 | ... return ...
93 |
94 | >>> model = MyModel()
95 | >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
96 |
97 |             >>> # Download weights from the hf-hub; the model is initialized from those weights
98 | >>> model = MyModel.from_pretrained("username/mymodel")
99 | """
100 |
101 | def _create_model_card(self, path):
102 | model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
103 | with open(os.path.join(path, "README.md"), "w") as f:
104 | f.write(model_card)
105 |
106 | def _save_config(self, module, save_directory):
107 | config = dict(module.hparams)
108 | path = os.path.join(save_directory, CONFIG_NAME)
109 | with open(path, "w") as f:
110 | json.dump(config, f)
111 |
112 | def _save_pretrained(self, save_directory: str, save_config: bool = True):
113 | # Save model weights
114 | path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
115 | model_to_save = self.module if hasattr(self, "module") else self
116 | torch.save(model_to_save.state_dict(), path)
117 | # Save model config
118 | if save_config and model_to_save.hparams:
119 | self._save_config(model_to_save, save_directory)
120 | # Save model card
121 | self._create_model_card(save_directory)
122 |
123 | @classmethod
124 | def _from_pretrained(
125 | cls,
126 | model_id,
127 | revision,
128 | cache_dir,
129 | force_download,
130 | proxies,
131 | resume_download,
132 | local_files_only,
133 | use_auth_token,
134 | map_location="cpu",
135 | strict=False,
136 | **model_kwargs,
137 | ):
138 | map_location = torch.device(map_location)
139 |
140 | if os.path.isdir(model_id):
141 | print("Loading weights from local directory")
142 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
143 | else:
144 | model_file = hf_hub_download(
145 | repo_id=model_id,
146 | filename=PYTORCH_WEIGHTS_NAME,
147 | revision=revision,
148 | cache_dir=cache_dir,
149 | force_download=force_download,
150 | proxies=proxies,
151 | resume_download=resume_download,
152 | use_auth_token=use_auth_token,
153 | local_files_only=local_files_only,
154 | )
155 | model = cls(**model_kwargs["config"])
156 |
157 | state_dict = torch.load(model_file, map_location=map_location)
158 | model.load_state_dict(state_dict, strict=strict)
159 | model.eval()
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/layers/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import einops
5 |
6 |
7 | def attention_einsum(q, k, v):
8 | """Apply the attention operator to tensors of shape [h, w, c]."""
9 |
10 | # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
11 | k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
12 | v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
13 |
14 | # Einstein summation corresponding to the query * key operation.
15 | beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
16 |
17 | # Einstein summation corresponding to the attention * value operation.
18 | out = torch.einsum("hwL, Lc->hwc", beta, v)
19 | return out
20 |
21 |
22 | class AttentionLayer(torch.nn.Module):
23 | """Attention Module"""
24 |
25 | def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
26 | super(AttentionLayer, self).__init__()
27 |
28 | self.ratio_kq = ratio_kq
29 | self.ratio_v = ratio_v
30 | self.output_channels = output_channels
31 | self.input_channels = input_channels
32 |
33 | # Compute query, key and value using 1x1 convolutions.
34 | self.query = torch.nn.Conv2d(
35 | in_channels=input_channels,
36 | out_channels=self.output_channels // self.ratio_kq,
37 | kernel_size=(1, 1),
38 | padding="valid",
39 | bias=False,
40 | )
41 | self.key = torch.nn.Conv2d(
42 | in_channels=input_channels,
43 | out_channels=self.output_channels // self.ratio_kq,
44 | kernel_size=(1, 1),
45 | padding="valid",
46 | bias=False,
47 | )
48 | self.value = torch.nn.Conv2d(
49 | in_channels=input_channels,
50 | out_channels=self.output_channels // self.ratio_v,
51 | kernel_size=(1, 1),
52 | padding="valid",
53 | bias=False,
54 | )
55 |
56 | self.last_conv = torch.nn.Conv2d(
57 | in_channels=self.output_channels // 8,
58 | out_channels=self.output_channels,
59 | kernel_size=(1, 1),
60 | padding="valid",
61 | bias=False,
62 | )
63 |
64 | # Learnable gain parameter
65 | self.gamma = nn.Parameter(torch.zeros(1))
66 |
67 | def forward(self, x: torch.Tensor) -> torch.Tensor:
68 | # Compute query, key and value using 1x1 convolutions.
69 | query = self.query(x)
70 | key = self.key(x)
71 | value = self.value(x)
72 | # Apply the attention operation.
73 | out = []
74 | for b in range(x.shape[0]):
75 | # Apply to each in batch
76 | out.append(attention_einsum(query[b], key[b], value[b]))
77 | out = torch.stack(out, dim=0)
78 | out = self.gamma * self.last_conv(out)
79 | # Residual connection.
80 | return out + x
81 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/layers/ConvGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.utils.parametrizations import spectral_norm
4 |
5 |
6 | class ConvGRUCell(torch.nn.Module):
7 | """A ConvGRU implementation."""
8 |
9 | def __init__(self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001):
10 | """Constructor.
11 |
12 | Args:
13 | kernel_size: kernel size of the convolutions. Default: 3.
14 | sn_eps: constant for spectral normalization. Default: 1e-4.
15 | """
16 | super().__init__()
17 | self._kernel_size = kernel_size
18 | self._sn_eps = sn_eps
19 | self.read_gate_conv = spectral_norm(
20 | torch.nn.Conv2d(
21 | in_channels=input_channels,
22 | out_channels=output_channels,
23 | kernel_size=(kernel_size, kernel_size),
24 | padding=1,
25 | ),
26 | eps=sn_eps,
27 | )
28 | self.update_gate_conv = spectral_norm(
29 | torch.nn.Conv2d(
30 | in_channels=input_channels,
31 | out_channels=output_channels,
32 | kernel_size=(kernel_size, kernel_size),
33 | padding=1,
34 | ),
35 | eps=sn_eps,
36 | )
37 | self.output_conv = spectral_norm(
38 | torch.nn.Conv2d(
39 | in_channels=input_channels,
40 | out_channels=output_channels,
41 | kernel_size=(kernel_size, kernel_size),
42 | padding=1,
43 | ),
44 | eps=sn_eps,
45 | )
46 |
47 | def forward(self, x, prev_state):
48 | """
49 | ConvGRU forward, returning the current+new state
50 |
51 | Args:
52 | x: Input tensor
53 | prev_state: Previous state
54 |
55 | Returns:
56 | New tensor plus the new state
57 | """
58 | # Concatenate the inputs and previous state along the channel axis.
59 | xh = torch.cat([x, prev_state], dim=1)
60 |
61 | # Read gate of the GRU.
62 | read_gate = F.sigmoid(self.read_gate_conv(xh))
63 |
64 | # Update gate of the GRU.
65 | update_gate = F.sigmoid(self.update_gate_conv(xh))
66 |
67 | # Gate the inputs.
68 | gated_input = torch.cat([x, read_gate * prev_state], dim=1)
69 |
70 | # Gate the cell and state / outputs.
71 | c = F.relu(self.output_conv(gated_input))
72 | out = update_gate * prev_state + (1.0 - update_gate) * c
73 | new_state = out
74 |
75 | return out, new_state
76 |
77 |
78 | class ConvGRU(torch.nn.Module):
79 | """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation"""
80 |
81 | def __init__(
82 | self,
83 | input_channels: int,
84 | output_channels: int,
85 | kernel_size: int = 3,
86 | sn_eps=0.0001,
87 | ):
88 | super().__init__()
89 | self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps)
90 |
91 | def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
92 | outputs = []
93 | for step in range(len(x)):
94 | # Compute current timestep
95 | output, hidden_state = self.cell(x[step], hidden_state)
96 | outputs.append(output)
97 | # Stack outputs to return as tensor
98 | outputs = torch.stack(outputs, dim=0)
99 | return outputs
100 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/layers/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoords(nn.Module):
6 | def __init__(self, with_r=False):
7 | super().__init__()
8 | self.with_r = with_r
9 |
10 | def forward(self, input_tensor):
11 | """
12 | Args:
13 | input_tensor: shape(batch, channel, x_dim, y_dim)
14 | """
15 | batch_size, _, x_dim, y_dim = input_tensor.size()
16 |
17 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
18 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
19 |
20 | xx_channel = xx_channel.float() / (x_dim - 1)
21 | yy_channel = yy_channel.float() / (y_dim - 1)
22 |
23 | xx_channel = xx_channel * 2 - 1
24 | yy_channel = yy_channel * 2 - 1
25 |
26 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
27 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
28 |
29 | ret = torch.cat(
30 | [input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)],
31 | dim=1,
32 | )
33 |
34 | if self.with_r:
35 | rr = torch.sqrt(
36 | torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2)
37 | + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)
38 | )
39 | ret = torch.cat([ret, rr], dim=1)
40 |
41 | return ret
42 |
43 |
44 | class CoordConv(nn.Module):
45 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
46 | super().__init__()
47 | self.addcoords = AddCoords(with_r=with_r)
48 | in_size = in_channels + 2
49 | if with_r:
50 | in_size += 1
51 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
52 |
53 | def forward(self, x):
54 | ret = self.addcoords(x)
55 | ret = self.conv(ret)
56 | return ret
57 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .Attention import AttentionLayer
2 | from .ConvGRU import ConvGRU
3 | from .CoordConv import CoordConv
4 |
--------------------------------------------------------------------------------
/experiments/dgmr-original/dgmr/layers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dgmr.layers import CoordConv
3 |
4 |
5 | def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module:
6 | if conv_type == "standard":
7 | conv_layer = torch.nn.Conv2d
8 | elif conv_type == "coord":
9 | conv_layer = CoordConv
10 | elif conv_type == "3d":
11 | conv_layer = torch.nn.Conv3d
12 | else:
13 | raise ValueError(f"{conv_type} is not a recognized Conv method")
14 | return conv_layer
15 |
--------------------------------------------------------------------------------
/figs/final_leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/figs/final_leaderboard.png
--------------------------------------------------------------------------------
/figs/model_predictions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmathur25/climatehack/96a551131e2482601265595a80bce315c60874ac/figs/model_predictions.gif
--------------------------------------------------------------------------------