├── .gitignore
├── requirements.txt
├── graphcast.yaml
├── constants.py
├── README.md
├── graphcast.dockerfile
└── prediction.py

/.gitignore:
--------------------------------------------------------------------------------
# The CDS API key.
.cdsapirc

# The stats files.
model/*

# Downloaded CDS files.
*.zip
*.nc
*.csv

# Cache files.
*.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
cdsapi==0.7.5
cfgrib==0.9.15.0
Cython==3.0.12
dm-tree==0.1.8
Flask==2.0.2
flask-restx==0.5.0
gevent==24.10.3
google-cloud-storage==2.13.0
gunicorn==20.0.4
isodate==0.6.0
netCDF4==1.7.2
pymongo==3.12.1
pysolar==0.11
python-dotenv==0.19.2
werkzeug==2.1.2
--------------------------------------------------------------------------------
/graphcast.yaml:
--------------------------------------------------------------------------------
services:
  graphcast:
    container_name: graphcast
    image: graphcast
    stdin_open: true
    tty: true
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    volumes:
      # Replace with the path to this repository on the host machine.
      - C:\Users\path:/app
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
#pyright: reportMissingImports=false

from enum import Enum

class Constants:

    # Field names as they appear in CDS/ERA5 downloads.
    class CDSConstants(Enum):

        TIME_FIELD = 'valid_time'
        LAT_FIELD = 'latitude'
        LON_FIELD = 'longitude'
        PRESSURE_FIELD = 'pressure_level'

    # Field names as GraphCast expects them.
    class Graphcast(Enum):

        TIME_FIELD = 'time'
        LAT_FIELD = 'latitude'
        LON_FIELD = 'longitude'
        PRESSURE_FIELD = 'level'
        BATCH_FIELD = 'batch'
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Steps:
1. Build the Docker image.
2. Start a container from the image with Docker Compose.
3. While the container is running, open another terminal and enter it.
4. Execute `python3 prediction.py`.
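### Setting up the CDS API key:

The pipeline downloads ERA5 data through `cdsapi`, which reads its credentials from a `.cdsapirc` file in the repository root (the Dockerfile later moves it to `/root/.cdsapirc`). A minimal sketch of the file, assuming the current CDS API endpoint; replace the placeholder with the personal access token from your CDS profile page:

```
url: https://cds.climate.copernicus.eu/api
key: <your-personal-access-token>
```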
### Command to build the Docker image:

```
docker build -f graphcast.dockerfile -t graphcast .
```

### Command to start the container using the compose file:

```
docker compose -f graphcast.yaml up
```

### Command to enter the running container:

```
docker exec -it graphcast /bin/bash
```

### Running the prediction file:

```
python3 prediction.py
```
--------------------------------------------------------------------------------
/graphcast.dockerfile:
--------------------------------------------------------------------------------
# Base Ubuntu image from https://hub.docker.com/_/ubuntu.
FROM ubuntu:22.04

# Update the package lists and install Python and other required tools.
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update
RUN apt-get install -y python3 python3-pip curl unzip

# Install GraphCast and its dependencies.
RUN pip3 install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

# Create the app directory.
RUN mkdir -p /app
WORKDIR /app

# Copy all files into the image.
COPY . .

# Install the remaining packages; shapely is rebuilt from source against
# libgeos, and JAX is installed with CUDA 12 support.
RUN apt-get install -y libgeos-dev
RUN pip3 uninstall -y shapely
RUN pip3 install -r requirements.txt
RUN pip3 install shapely --no-binary shapely
RUN pip3 install -U "jax[cuda12_pip]" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"

# Move the CDS key to root's home directory, where cdsapi looks for it.
RUN mv /app/.cdsapirc /root/.cdsapirc
--------------------------------------------------------------------------------
/prediction.py:
--------------------------------------------------------------------------------
#pyright: reportMissingImports=false

import cdsapi
import datetime
import functools
from google.cloud import storage
from graphcast import autoregressive, casting, checkpoint, data_utils as du, graphcast, normalization, rollout
import haiku as hk
import isodate
import jax
import math
import netCDF4
import numpy as np
import os
import pandas as pd
from pysolar.radiation import get_radiation_direct
from pysolar.solar import get_altitude
import pytz
from typing import Dict
import warnings
import xarray
import zipfile
warnings.filterwarnings('ignore')

from constants import Constants

client = cdsapi.Client()

# Anonymous access to the public bucket hosting the GraphCast weights.
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket('dm_graphcast')

# ERA5 short names -> CDS variable names for the single-level request.
singlelevelfields = {
    'u10': '10m_u_component_of_wind',
    'v10': '10m_v_component_of_wind',
    't2m': '2m_temperature',
    'z': 'geopotential',
    'lsm': 'land_sea_mask',
    'msl': 'mean_sea_level_pressure',
    'tisr': 'toa_incident_solar_radiation',
    'tp': 'total_precipitation'
}
# ERA5 short names -> CDS variable names for the pressure-level request.
pressurelevelfields = {
    'u': 'u_component_of_wind',
    'v': 'v_component_of_wind',
    'z': 'geopotential',
    'q': 'specific_humidity',
    't': 'temperature',
    'w': 'vertical_velocity'
}
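# Note: 'z' (geopotential) appears in both mappings above. In the single-level
# download it is the surface geopotential, and getSingleAndPressureValues
# renames it to 'geopotential_at_surface' after retrieval, so it does not
# clash with the pressure-level geopotential listed in the prediction fields
# below.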
# The fields GraphCast predicts.
predictionFields = [
    'u_component_of_wind',
    'v_component_of_wind',
    'geopotential',
    'specific_humidity',
    'temperature',
    'vertical_velocity',
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '2m_temperature',
    'mean_sea_level_pressure',
    'total_precipitation_6hr'
]
pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
pi = math.pi
gap = 6                   # Hours between consecutive prediction steps.
predictions_steps = 4     # Number of autoregressive prediction steps.
watts_to_joules = 3600    # Seconds per hour: converts W/m^2 into J/m^2 accumulated over an hour.
first_prediction = datetime.datetime(2025, 1, 1, 18, 0)
lat_range = range(-90, 91, 1)
lon_range = range(0, 360, 1)

# The coordinates each variable must carry in the datasets fed to GraphCast.
class AssignCoordinates:

    coordinates = {
        '2m_temperature': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        'mean_sea_level_pressure': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        '10m_v_component_of_wind': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        '10m_u_component_of_wind': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        'total_precipitation_6hr': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        'temperature': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'geopotential': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'u_component_of_wind': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'v_component_of_wind': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'vertical_velocity': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'specific_humidity': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'level', 'time'],
        'toa_incident_solar_radiation': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'lat', 'time'],
        'year_progress_cos': [Constants.Graphcast.BATCH_FIELD.value, 'time'],
        'year_progress_sin': [Constants.Graphcast.BATCH_FIELD.value, 'time'],
        'day_progress_cos': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'time'],
        'day_progress_sin': [Constants.Graphcast.BATCH_FIELD.value, 'lon', 'time'],
        'geopotential_at_surface': ['lon', 'lat'],
        'land_sea_mask': ['lon', 'lat'],
    }

print('Connecting to dm_graphcast bucket...')
with gcs_bucket.blob('params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz').open('rb') as model:
    ckpt = checkpoint.load(model, graphcast.CheckPoint)
params = ckpt.params
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config
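# The three stats files loaded below live under model/stats, which is
# gitignored. If they are missing, they can be fetched once from the same
# public bucket -- a minimal sketch, assuming the 'stats/' blob names used by
# the official GraphCast demo:
#
#   os.makedirs('model/stats', exist_ok = True)
#   for name in ['diffs_stddev_by_level.nc', 'mean_by_level.nc', 'stddev_by_level.nc']:
#       gcs_bucket.blob('stats/{}'.format(name)).download_to_filename('model/stats/{}'.format(name))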
print('Loading the diffs_stddev_by_level.nc file...')
with open('model/stats/diffs_stddev_by_level.nc', 'rb') as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()

print('Loading the mean_by_level.nc file...')
with open('model/stats/mean_by_level.nc', 'rb') as f:
    mean_by_level = xarray.load_dataset(f).compute()

print('Loading the stddev_by_level.nc file...')
with open('model/stats/stddev_by_level.nc', 'rb') as f:
    stddev_by_level = xarray.load_dataset(f).compute()

# Wraps the bare GraphCast predictor with bfloat16 casting, normalization of
# inputs and residuals, and autoregressive rollout.
def construct_wrapped_graphcast(model_config:graphcast.ModelConfig, task_config:graphcast.TaskConfig):

    predictor = graphcast.GraphCast(model_config, task_config)
    predictor = casting.Bfloat16Cast(predictor)
    predictor = normalization.InputsAndResiduals(predictor, diffs_stddev_by_level = diffs_stddev_by_level, mean_by_level = mean_by_level, stddev_by_level = stddev_by_level)
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)

    return predictor

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):

    predictor = construct_wrapped_graphcast(model_config, task_config)

    return predictor(inputs, targets_template = targets_template, forcings = forcings)

# Binds the loaded model and task configuration to a function.
def with_configs(fn):

    return functools.partial(fn, model_config = model_config, task_config = task_config)

# Binds the loaded parameters and state to a function.
def with_params(fn):

    return functools.partial(fn, params = params, state = state)

# Haiku's apply returns (outputs, state); keep only the outputs.
def drop_state(fn):

    return lambda **kw: fn(**kw)[0]

run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))

class Predictor:

    @classmethod
    def predict(cls, inputs, targets, forcings) -> xarray.Dataset:

        # Roll the model forward one chunk at a time over the target template.
        predictions = rollout.chunked_prediction(run_forward_jitted, rng = jax.random.PRNGKey(0), inputs = inputs, targets_template = targets, forcings = forcings)

        return predictions

# Converting the variable to a datetime object.
def toDatetime(dt) -> datetime.datetime:

    # datetime.datetime is a subclass of datetime.date, so check it first.
    if isinstance(dt, datetime.datetime):

        return dt

    elif isinstance(dt, datetime.date):

        return datetime.datetime.combine(dt, datetime.datetime.min.time())

    elif isinstance(dt, str):

        if 'T' in dt:
            return isodate.parse_datetime(dt)
        else:
            return datetime.datetime.combine(isodate.parse_date(dt), datetime.datetime.min.time())

# An array of NaNs with the given shape.
def nans(*args) -> np.ndarray:

    return np.full(args, np.nan)

# Shifting a datetime by the given timedelta arguments.
def deltaTime(dt, **delta) -> datetime.datetime:

    return dt + datetime.timedelta(**delta)

# Attaching a timezone (UTC by default) to a datetime.
def addTimezone(dt, tz = pytz.UTC) -> datetime.datetime:

    dt = toDatetime(dt)
    if dt.tzinfo is None:
        return pytz.UTC.localize(dt).astimezone(tz)
    else:
        return dt.astimezone(tz)
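# Quick sanity examples for the helpers above (values follow directly from
# the definitions):
#   toDatetime('2025-01-01')               -> datetime.datetime(2025, 1, 1, 0, 0)
#   deltaTime(first_prediction, hours = 6) -> datetime.datetime(2025, 1, 2, 0, 0)
#   nans(2, 3)                             -> a 2x3 numpy array filled with NaN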
def remove_junk_columns(df:pd.DataFrame):

    # ERA5 downloads can include bookkeeping columns such as the ensemble
    # 'number' and the experiment version 'expver'; drop them if present.
    for col in ['number', 'expver']:
        if col in df.columns.values.tolist():
            df.pop(col)

    return df

def getSingleLevelValues(filename):

    # The single-level download is a zip of netCDF files; extract and merge them.
    extract_to = filename.split('.')[0]
    with zipfile.ZipFile(filename, 'r') as f:
        f.extractall(extract_to)

    dfs = []
    for i in os.listdir(extract_to):
        extension = i.split('.')[-1]
        if extension == 'nc':
            df = xarray.open_dataset('{}/{}'.format(extract_to, i), engine = netCDF4.__name__.lower()).to_dataframe()
            df = remove_junk_columns(df)
            dfs.append(df)

    single_level_df = pd.concat(dfs, axis = 1)

    return single_level_df

# Getting the single and pressure level values.
def getSingleAndPressureValues():

    # Hourly single-level fields for 2025-01-01, 00:00 through 12:00.
    client.retrieve(
        'reanalysis-era5-single-levels',
        {
            'product_type': 'reanalysis',
            'variable': list(singlelevelfields.values()),
            'grid': '1.0/1.0',
            'year': [2025],
            'month': [1],
            'day': [1],
            'time': ['00:00', '01:00', '02:00', '03:00', '04:00', '05:00', '06:00', '07:00', '08:00', '09:00', '10:00', '11:00', '12:00'],
            'data_format': 'netcdf',
            'download_format': 'zip'
        }
    ).download('single-level.zip')
    singlelevel = getSingleLevelValues('single-level.zip')
    singlelevel = singlelevel.rename(columns = {col:singlelevelfields[col] for col in singlelevel.columns.values.tolist() if col in singlelevelfields})
    singlelevel = singlelevel.rename(columns = {'geopotential': 'geopotential_at_surface'})

    # Calculating the sum of the last 6 hours of rainfall per grid cell. The
    # grouping uses the named latitude/longitude levels so that the rolling
    # window runs along the time axis.
    singlelevel = singlelevel.sort_index()
    singlelevel['total_precipitation_6hr'] = singlelevel.groupby(
        level = [Constants.CDSConstants.LAT_FIELD.value, Constants.CDSConstants.LON_FIELD.value]
    )['total_precipitation'].rolling(window = 6, min_periods = 1).sum().reset_index(level = [0, 1], drop = True)
    singlelevel.pop('total_precipitation')

    # Pressure-level fields for the two input timestamps.
    client.retrieve(
        'reanalysis-era5-pressure-levels',
        {
            'product_type': 'reanalysis',
            'variable': list(pressurelevelfields.values()),
            'grid': '1.0/1.0',
            'year': [2025],
            'month': [1],
            'day': [1],
            'time': ['06:00', '12:00'],
            'pressure_level': pressure_levels,
            'data_format': 'netcdf',
            'download_format': 'unarchived'
        }
    ).download('pressure-level.nc')
    pressurelevel = xarray.open_dataset('pressure-level.nc', engine = netCDF4.__name__.lower()).to_dataframe()
    pressurelevel = remove_junk_columns(pressurelevel)
    pressurelevel = pressurelevel.rename(columns = {col:pressurelevelfields[col] for col in pressurelevel.columns.values.tolist() if col in pressurelevelfields})

    return singlelevel, pressurelevel

# Adding sin and cos of the year progress.
def addYearProgress(secs, data):

    progress = du.get_year_progress(secs)
    data['year_progress_sin'] = math.sin(2 * pi * progress)
    data['year_progress_cos'] = math.cos(2 * pi * progress)

    return data
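# du.get_year_progress returns the fraction of the (average) year elapsed at
# the given timestamp, so the sin/cos pair above encodes the season cyclically
# with no jump at year boundaries; e.g. a progress of 0.5 gives
# year_progress_sin ~ 0 and year_progress_cos ~ -1.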
# Adding sin and cos of the day progress.
def addDayProgress(secs, lon:str, data:pd.DataFrame):

    # Day progress depends on longitude: local solar time varies across the globe.
    lons = data.index.get_level_values(lon).unique()
    progress:np.ndarray = du.get_day_progress(secs, np.array(lons))
    prxlon = {ln:prog for ln, prog in zip(list(lons), progress.tolist())}
    data['day_progress_sin'] = data.index.get_level_values(lon).map(lambda x: math.sin(2 * pi * prxlon[x]))
    data['day_progress_cos'] = data.index.get_level_values(lon).map(lambda x: math.cos(2 * pi * prxlon[x]))

    return data

# Adding the year and day progress features for every timestamp in the frame.
def integrateProgress(data:pd.DataFrame):

    time_field = Constants.CDSConstants.TIME_FIELD.value
    lon_field = 'longitude' if 'longitude' in data.index.names else 'lon'

    # Compute the features per timestamp; assigning them to the whole frame in
    # one pass would leave every row with the values of the last timestamp.
    parts = []
    for dt in data.index.get_level_values(time_field).unique():
        part = data[data.index.get_level_values(time_field) == dt].copy()
        seconds_since_epoch = toDatetime(dt).timestamp()
        part = addYearProgress(seconds_since_epoch, part)
        part = addDayProgress(seconds_since_epoch, lon_field, part)
        parts.append(part)

    return pd.concat(parts).sort_index()

def getSolarRadiation(longitude, latitude, dt):

    # pysolar returns W/m^2; scale by 3600 seconds to approximate the hourly
    # accumulated J/m^2 that ERA5 uses for TOA incident solar radiation.
    altitude_degrees = get_altitude(latitude, longitude, addTimezone(dt))
    solar_radiation = get_radiation_direct(dt, altitude_degrees) if altitude_degrees > 0 else 0

    return solar_radiation * watts_to_joules

def integrateSolarRadiation(data:pd.DataFrame):

    dates = list(data.index.get_level_values(Constants.CDSConstants.TIME_FIELD.value).unique())
    coords = [[lat, lon] for lat in lat_range for lon in lon_range]
    values = []

    # Tabulate the radiation for every grid cell at every target timestamp.
    for dt in dates:
        values.extend([{
            Constants.CDSConstants.TIME_FIELD.value: dt,
            Constants.CDSConstants.LON_FIELD.value: coord[1],
            Constants.CDSConstants.LAT_FIELD.value: coord[0],
            'toa_incident_solar_radiation': getSolarRadiation(coord[1], coord[0], dt)
        } for coord in coords])

    values = pd.DataFrame(values).set_index(keys = [Constants.CDSConstants.LAT_FIELD.value, Constants.CDSConstants.LON_FIELD.value, Constants.CDSConstants.TIME_FIELD.value])

    return pd.merge(data, values, left_index = True, right_index = True, how = 'inner')

# Dropping the coordinates GraphCast does not expect for each variable.
def modifyCoordinates(data:xarray.Dataset):

    for var in list(data.data_vars):
        varArray:xarray.DataArray = data[var]
        nonIndices = list(set(list(varArray.coords)).difference(set(AssignCoordinates.coordinates[var])))
        data[var] = varArray.isel(**{coord: 0 for coord in nonIndices})
    data = data.drop_vars(Constants.Graphcast.BATCH_FIELD.value)

    return data

# Renaming the CDS index fields to the names GraphCast expects and converting
# the frame into an xarray dataset.
def makeXarray(data:pd.DataFrame) -> xarray.Dataset:

    data = data.rename_axis(index={
        Constants.CDSConstants.TIME_FIELD.value: Constants.Graphcast.TIME_FIELD.value,
        Constants.CDSConstants.PRESSURE_FIELD.value: Constants.Graphcast.PRESSURE_FIELD.value
    })
    data = data.to_xarray()
    data = modifyCoordinates(data)

    return data

# Renaming latitude/longitude to lat/lon and adding the batch index level.
def formatData(data:pd.DataFrame) -> pd.DataFrame:

    data = data.rename_axis(index = {Constants.CDSConstants.LAT_FIELD.value: 'lat', Constants.CDSConstants.LON_FIELD.value: 'lon'})
    if Constants.Graphcast.BATCH_FIELD.value not in data.index.names:
        data[Constants.Graphcast.BATCH_FIELD.value] = 0
        data = data.set_index(Constants.Graphcast.BATCH_FIELD.value, append = True)

    return data
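# With gap = 6 and predictions_steps = 4, getTargets below builds a template
# whose time axis covers lead times of 0, 6, 12 and 18 hours from
# first_prediction, i.e. 2025-01-01 18:00 through 2025-01-02 12:00. The NaN
# values are placeholders the model fills in during the rollout.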
# Building a NaN-filled targets template covering the prediction window.
def getTargets(dt, data:pd.DataFrame):

    # The unique values along each axis of the template.
    lat = sorted(data.index.get_level_values('lat').unique().tolist())
    lon = sorted(data.index.get_level_values('lon').unique().tolist())
    levels = sorted(data.index.get_level_values('pressure_level').unique().tolist())
    batch = data.index.get_level_values(Constants.Graphcast.BATCH_FIELD.value).unique().tolist()
    time = [deltaTime(dt, hours = step * gap) for step in range(predictions_steps)]

    target = xarray.Dataset({field: (['lat', 'lon', 'level', Constants.CDSConstants.TIME_FIELD.value], nans(len(lat), len(lon), len(levels), len(time))) for field in predictionFields}, coords = {'lat': lat, 'lon': lon, 'level': levels, Constants.CDSConstants.TIME_FIELD.value: time, Constants.Graphcast.BATCH_FIELD.value: batch})

    return target.to_dataframe()

# Forcings are the inputs GraphCast needs but does not predict: the progress
# features and the TOA solar radiation at the target timestamps.
def getForcings(data:pd.DataFrame):

    forcingdf = data.reset_index(level = 'level', drop = True).drop(labels = predictionFields, axis = 1)

    # Keep one row per remaining index combination by dropping duplicates.
    forcingdf = pd.DataFrame(index = forcingdf.index.drop_duplicates(keep = 'first'))
    forcingdf = integrateProgress(forcingdf)
    forcingdf = integrateSolarRadiation(forcingdf)

    return forcingdf

if __name__ == '__main__':

    values:Dict[str, xarray.Dataset] = {}

    # Download the ERA5 inputs, add the derived features and reshape them.
    single, pressure = getSingleAndPressureValues()
    values['inputs'] = pd.merge(pressure, single, left_index = True, right_index = True, how = 'inner')
    values['inputs'] = integrateProgress(values['inputs'])
    values['inputs'] = formatData(values['inputs'])

    # Build the empty targets and matching forcings, then run the rollout.
    values['targets'] = getTargets(first_prediction, values['inputs'])
    values['forcings'] = getForcings(values['targets'])
    values = {value:makeXarray(values[value]) for value in values}

    predictions = Predictor.predict(values['inputs'], values['targets'], values['forcings'])
    predictions.to_dataframe().to_csv('predictions.csv', sep = ',')
--------------------------------------------------------------------------------