├── .gitattributes
├── .github
│   └── workflows
│       └── pylint.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── Data
│   ├── WeatherReal-ISD-2021.nc
│   ├── WeatherReal-ISD-2022.nc
│   └── WeatherReal-ISD-2023.nc
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── conda.yml
├── evaluate.py
├── evaluation
│   ├── __init__.py
│   ├── forecast_reformat_catalog.py
│   ├── metric_catalog.py
│   ├── obs_reformat_catalog.py
│   └── utils.py
├── metric_config.yml
├── pylintrc
└── quality_control
    ├── README.md
    ├── __init__.py
    ├── algo
    │   ├── __init__.py
    │   ├── cluster.py
    │   ├── config.yaml
    │   ├── cross_variable.py
    │   ├── distributional_gap.py
    │   ├── diurnal_cycle.py
    │   ├── fine_tuning.py
    │   ├── neighbouring_stations.py
    │   ├── persistence.py
    │   ├── record_extreme.py
    │   ├── refinement.py
    │   ├── spike.py
    │   ├── time_series.py
    │   └── utils.py
    ├── download_ISD.py
    ├── quality_control.py
    ├── raw_ISD_to_hourly.py
    └── station_merging.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.nc filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | matrix:
12 | python-version: ["3.8", "3.9", "3.10"]
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v3
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install pylint
23 | - name: Analysing the code with pylint
24 | run: |
25 | pylint $(git ls-files '*.py')
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/Data/WeatherReal-ISD-2021.nc:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b137bff773931b938f649015134d73964b12da8274066631e7f71132d44cde86
3 | size 535031735
4 |
--------------------------------------------------------------------------------
/Data/WeatherReal-ISD-2022.nc:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b5a3a11ed0c2d7b790f1adc4b48df779883b99de8da38dc0fe91d371b79addc6
3 | size 530269322
4 |
--------------------------------------------------------------------------------
/Data/WeatherReal-ISD-2023.nc:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2ae3267ea9cdbb6fa2551265f519afe06ab4f215b19a9b66117f32aede1eee45
3 | size 546501126
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WeatherReal: A Benchmark Based on In-Situ Observations for Evaluating Weather Models
2 |
3 | Welcome to the official GitHub repository and leaderboard page for the WeatherReal weather forecasting benchmark!
4 | This repository includes code for running the evaluation of weather forecasting models on the WeatherReal benchmark dataset,
5 | as well as links to the LFS storage for the data files.
6 | If you would like to contribute a model evaluation result to this page, please open an Issue with the tag "Submission".
7 | If you use this repository or its data, please cite the following:
8 | - [WeatherReal](https://arxiv.org/html/2409.09371): A Benchmark Based on In-Situ Observations for Evaluating Weather Models
9 | - Synoptic Data (for the Synoptic data used in the benchmark)
10 | - [NCEI ISD database](https://journals.ametsoc.org/view/journals/bams/92/6/2011bams3015_1.xml)
11 |
12 | ## Why WeatherReal?
13 |
14 | Weather forecasting is a critical application for many scenarios, including disaster preparedness, agriculture, energy,
15 | and day-to-day life. Recent advances in AI applications for weather forecasting have demonstrated enormous potential,
16 | with models like GraphCast, Pangu-Weather, and FengWu achieving state-of-the-art performance on benchmarks like
17 | WeatherBench-2, even outperforming traditional numerical weather prediction models such as the ECMWF IFS. However, these models
18 | have not been fully tested on real-world data collected from observing stations; rather, the evaluation has focused on
19 | reanalysis datasets such as ERA5 that are generated by NWP models and thus include those models' biases and approximations.
20 |
21 | WeatherReal is the first benchmark to provide comprehensive evaluation against observations from the
22 | [Integrated Surface Database (ISD)](https://www.ncei.noaa.gov/products/land-based-station/integrated-surface-database)
23 | and the [Synoptic Data](https://www.synopticdata.com/) API, in addition to aggregated user-reported observations collected
24 | from MSN Weather's reporting platform. We provide these datasets separately and offer separate benchmark
25 | leaderboards for each, in addition to separating leaderboards by tasks such as short-term and medium-range forecasting.
26 | We also provide evaluation code that can be run on a variety of model forecast schemas to generate scores for the benchmark.
27 | By interpolating gridded forecasts to station locations, and using nearest-neighbor interpolation to match point forecasts
28 | to station locations, we can fairly evaluate models that produce either gridded or point forecasts.
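
As a minimal sketch of the gridded case (names here are illustrative assumptions; the reformat functions actually used by `evaluate.py` live in `evaluation/forecast_reformat_catalog.py`), interpolation to station locations looks like:

```python
# Illustrative sketch: bilinear interpolation of a gridded forecast to station points.
# File names, variable names and column names are assumptions, not the benchmark's API.
import pandas as pd
import xarray as xr

grid_forecast = xr.open_dataset("grid_forecast.nc")     # dims: issue_time, lead_time, lat, lon
stations = pd.read_csv("station_metadata.csv")          # columns: station, lat, lon

# Index the station metadata by station ID so interpolation is vectorised over stations
targets = xr.Dataset.from_dataframe(stations[["station", "lat", "lon"]].set_index("station"))

# Pointwise linear interpolation of the grid onto each station's (lat, lon)
point_forecast = grid_forecast.interp(lat=targets["lat"], lon=targets["lon"], method="linear")
```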
29 |
30 | ### What WeatherReal is
31 |
32 | - A benchmark dataset of quality-controlled weather observations spanning the year 2023, with ongoing updates planned
33 | - An evaluation framework to score many tasks for either grid-based or point-based forecasts
34 | - Still in its infancy - we welcome feedback on how to improve the benchmark, the evaluation code, and the leaderboards
35 |
36 | ### What WeatherReal is not
37 |
38 | - A dataset for training weather forecasting models
39 | - An inference platform for running models
40 | - An exclusive benchmark for weather forecasting
41 |
42 | ## Leaderboards
43 |
44 | ### [Provisional] Global medium-range weather forecasting
45 |
46 | Task: Forecasts are initialized twice daily at 00 and 12 UTC over the entire evaluation
47 | year 2023. Forecasts are evaluated every 6 hours of lead time up to 168 hours (7 days). Headline metric is
48 | the RMSE (ETS for precipitation) for each predicted variable, averaged over all forecasts and lead
49 | times. **Note**: The leaderboard is provisional due to incomplete forecast initializations for the provided models, and
50 | therefore is subject to change.
51 |
52 | | **WeatherReal-ISD** | 2-m temperature<br>(RMSE, K) | 10-m wind speed<br>(RMSE, m/s) | mean sea-level pressure<br>(RMSE, hPa) | total cloud cover<br>(RMSE, okta) | 6-hour precipitation > 1 mm<br>(ETS) |
53 | |---------------------|--------------------------------|----------------------------------|------------------------------------------|-------------------------------------|----------------------------------------|
54 | | Microsoft-Point | **2.258** | **1.753** | - | **2.723** | - |
55 | | Aurora-9km | 2.417 | 2.186 | **2.939** | - | - |
56 | | ECMWF | 2.766 | 2.251 | 3.098 | 3.319 | 0.248 |
57 | | GFS | 3.168 | 2.455 | 3.480 | - | - |
58 |
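For reference, the equitable threat score (ETS) used as the precipitation headline metric can be computed from a binary contingency table. Below is a minimal sketch of the textbook formulation; the benchmark's own implementation lives in `evaluation/metric_catalog.py` and may differ in details such as missing-data handling.

```python
# Minimal ETS sketch for a precipitation threshold (e.g. 6-hour precipitation > 1 mm).
# Textbook formulation only; not necessarily identical to evaluation/metric_catalog.py.
import numpy as np

def equitable_threat_score(fcst_mm: np.ndarray, obs_mm: np.ndarray, threshold: float = 1.0) -> float:
    fcst_event = np.asarray(fcst_mm) > threshold
    obs_event = np.asarray(obs_mm) > threshold
    hits = np.sum(fcst_event & obs_event)
    false_alarms = np.sum(fcst_event & ~obs_event)
    misses = np.sum(~fcst_event & obs_event)
    # Hits expected by chance, given the forecast and observed event frequencies
    hits_random = (hits + false_alarms) * (hits + misses) / fcst_event.size
    return float((hits - hits_random) / (hits + false_alarms + misses - hits_random))
```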
59 |
60 | ### Other tasks
61 |
62 | We welcome feedback from the modeling community on how best to use the data for evaluating forecasts in a way that best
63 | reflects the end consumer experience with various forecasting models. We propose the following common tasks:
64 |
65 | - **Short-range forecasting.** Forecasts are initialized four times daily at 00, 06, 12, and 18 UTC. Forecasts are evaluated every 1 hour of lead time up to 72 hours. Headline metric is the RMSE (ETS for precipitation) for each predicted variable averaged over all forecasts and lead times.
66 | - **Nowcasting.** Forecasts are initialized every hour. Forecasts are evaluated every 1 hour of lead time up to 24 hours. Headline metric is the RMSE (ETS for precipitation) for each predicted variable averaged over all forecasts and lead times.
67 | - **Sub-seasonal-to-seasonal forecasts.** Following the schedule of ECMWF's long-range forecasts prior to June 2023, forecasts are initialized twice weekly at 00 UTC on Mondays and Thursdays. Forecasts are averaged either daily or weekly for lead times every 6 hours. Ideally, forecasts should be probabilistic, enabling the use of proper scoring methods such as the continuous ranked probability score (CRPS); a minimal CRPS sketch follows this list. Headline metrics are week 3-4 and week 5-6 average scores.
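
A minimal sketch of the ensemble CRPS mentioned above, using the standard empirical estimator E|X - y| - 0.5 E|X - X'| (illustrative only; the scoring ultimately adopted for this task may differ):

```python
# Minimal ensemble CRPS sketch: E|X - y| - 0.5 * E|X - X'| over ensemble members X.
# Illustrative only; fair-CRPS variants and the benchmark's eventual choice may differ.
import numpy as np

def crps_ensemble(members: np.ndarray, obs: float) -> float:
    members = np.asarray(members, dtype=float)
    skill = np.abs(members - obs).mean()
    spread = np.abs(members[:, None] - members[None, :]).mean()
    return float(skill - 0.5 * spread)
```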
68 |
69 |
70 |
71 | ## About the data
72 |
73 | WeatherReal includes several versions, all derived from global near-surface in-situ observations:
74 | (1) WeatherReal-ISD: An observational dataset based on the Integrated Surface Database (ISD), which has been subjected to rigorous post-processing and quality control through our independently developed algorithms.
75 | (2) WeatherReal-Synoptic: An observational dataset from Synoptic Data PBC, a data service platform for 150,000+ in-situ surface weather stations, offering a much more densely distributed network.
76 | A quality control system is also provided as additional attributes delivered alongside the data from their API services.
77 | The following table lists available variables in WeatherReal-ISD and WeatherReal-Synoptic.
78 |
79 | | Variable | Short Name | Unit¹ | Variable | Short Name | Unit |
80 | |---------------------------------------|------------|---------------------------------|-----------------------|------------|--------------------------|
81 | | 2m Temperature | t | °C | Total Cloud Cover | c | okta⁴ |
82 | | 2m Dewpoint Temperature | td | °C | 1-hour Precipitation | ra1 | mm |
83 | | Surface Pressure² | sp | hPa | 3-hour Precipitation | ra3 | mm |
84 | | Mean Sea-level Pressure | msl | hPa | 6-hour Precipitation | ra6 | mm |
85 | | 10m Wind Speed | ws | m/s | 12-hour Precipitation | ra12 | mm |
86 | | 10m Wind Direction | wd | degree³ | 24-hour Precipitation | ra24 | mm |
87 |
88 | **1**: Refers to the units used in the WeatherReal-ISD we publish. For the units provided by the raw ISD and Synoptic, please consult their respective documentation.
89 | **2**: For in-situ weather stations, surface pressure is measured at the sensor's height, typically 2 meters above ground level at the weather station.
90 | **3**: The direction is measured clockwise from true north, ranging from 1° (north-northeast) to 360° (north), with 0° indicating calm winds.
91 | **4**: Okta is a unit of measurement used to describe the amount of cloud cover, with the data range being from 0 (clear sky) to 8 (completely overcast).
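
As an example of the okta convention, a fractional cloud cover can be mapped to okta as sketched below (an assumption for illustration; the conversion applied by `evaluate.py --convert-fcst-cloud-to-okta` is defined in the evaluation package and may differ in details):

```python
# Illustrative sketch: map a cloud-cover fraction in [0, 1] to the 0-8 okta scale.
# This rounding rule is an assumption, not necessarily the benchmark's exact conversion.
import numpy as np

def cloud_fraction_to_okta(fraction):
    return np.clip(np.round(np.asarray(fraction) * 8), 0, 8).astype(int)
```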
92 |
93 | ### WeatherReal-ISD
94 |
95 | The data source of WeatherReal-ISD, ISD [Smith et al., 2011], is a global near-surface observation dataset compiled by the National Centers for Environmental Information (NCEI).
96 | More than 100 original data sources, including SYNOP (surface synoptic observations) and METAR (meteorological aerodrome report) weather reports, are incorporated.
97 |
98 | There are currently more than 14,000 active reporting stations in ISD and it already includes the majority of known station observation data, making it an ideal data source for WeatherReal.
99 | However, the observational data have only undergone basic quality control, resulting in numerous erroneous data points.
100 | Therefore, to improve data fidelity, we performed extensive post-processing on it, including station selection and merging, and comprehensive quality control.
101 | For more details on the data processing, please refer to the paper.
102 |
103 | ### WeatherReal-Synoptic
104 |
105 | Data of WeatherReal-Synoptic is obtained from Synoptic Data PBC, which brings together observation data from hundreds of public and private station networks worldwide, providing a comprehensive and accessible data service platform for critical environmental information.
106 | For further details, please refer to [Synoptic Data’s official site](https://synopticdata.com/solutions/ai-ml-weather/).
107 | The WeatherReal-Synoptic dataset utilized in this paper was retrieved in real-time from their Time Series API services in 2023 to address our operational requirements, and the same data is available from them as a historical dataset.
108 | For precipitation, Synoptic also supports an advanced API that allows data retrieval through custom accumulation and interval windows.
109 | WeatherReal-Synoptic encompasses a greater volume of data, a more extensive observation network, and a larger number of stations compared to ISD.
110 | Note that Synoptic provides a quality control system as an additional attribute alongside the data from their API services, thus the quality control algorithm we developed independently has not been applied to the WeatherReal-Synoptic dataset.
111 |
112 |
113 | ## Acquiring data
114 |
115 | The WeatherReal datasets are available from the following locations:
116 | - WeatherReal-ISD: One netCDF file per year (2021-2023) in the `Data` directory, e.g. [WeatherReal-ISD-2023.nc via GitHub LFS](https://github.com/microsoft/WeatherReal/blob/main/Data/WeatherReal-ISD-2023.nc); a short loading sketch is shown below.
117 | - WeatherReal-Synoptic: Please reach out directly to [Synoptic Data PBC](https://synopticdata.com/) for access to the data.
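
Once downloaded, the WeatherReal-ISD files are ordinary netCDF and can be inspected with xarray. A minimal loading sketch (coordinate names should be confirmed by printing the dataset; the evaluation code expects station IDs, times and lat/lon information to be present):

```python
# Minimal sketch: open one year of WeatherReal-ISD and select 2m temperature.
# Variable short names (t, td, sp, msl, ws, wd, c, ra1, ...) follow the table above;
# confirm coordinate names by inspecting the printed dataset.
import xarray as xr

ds = xr.open_dataset("Data/WeatherReal-ISD-2023.nc")
print(ds)          # list variables, coordinates and attributes
t2m = ds["t"]      # 2m temperature in degrees Celsius
```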
118 |
119 |
120 | ## Evaluation code
121 |
122 | The evaluation code is written in Python and is available in the `evaluation` directory. The code is designed to be
123 | flexible and can be used to evaluate a wide range of forecast schemas. The launch script `evaluate.py` can be used to run
124 | the evaluation. The following example illustrates how to evaluate temperature forecasts from gridded and point-based models:
125 |
126 | ```bash
127 | python evaluate.py \
128 | --forecast-paths /path/to/grid_forecast_1.zarr /path/to/grid_forecast_2.zarr /path/to/point_forecast_1.zarr \
129 | --forecast-names GridForecast1 GridForecast2 PointForecast1 \
130 | --forecast-var-names t2m t2m t \
131 | --forecast-reformat-funcs grid_v1 grid_v1 point_standard \
132 | --obs-path /path/to/weatherreal-isd.nc \
133 | --obs-var-name t \
134 | --variable-type temperature \
135 | --convert-fcst-temperature-k-to-c \
136 | --output-directory /path/to/output
137 | ```
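
For each evaluated region (the default region is `all`), `evaluate.py` writes a `metrics.csv` file and one plot per metric into `<output-directory>/<region>/`; see `metrics_to_csv` and `plot_metric` in `evaluate.py`.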
138 |
139 | ## Submitting metrics
140 |
141 | We welcome all submissions of evaluation results using WeatherReal data!
142 | To submit your model's evaluation metrics to the leaderboard, please open an Issue with the tag "Submission" and include
143 | all metrics/tasks you would like to submit. Please also include a reference paper or link to a public repository that can
144 | be used to peer-review your results. We will review your submission and add it to the leaderboard if it meets the
145 | requirements.
146 |
147 | ## Contributing
148 |
149 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
150 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
151 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
152 |
153 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
154 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
155 | provided by the bot. You will only need to do this once across all repos using our CLA.
156 |
157 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
158 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
159 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
160 |
161 | ## Trademarks
162 |
163 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
164 | trademarks or logos is subject to and must follow
165 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
166 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
167 | Any use of third-party trademarks or logos is subject to those third parties' policies.
168 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 |
9 | ## Microsoft Support Policy
10 |
11 | Support for this project is limited to the resources listed above.
12 |
--------------------------------------------------------------------------------
/conda.yml:
--------------------------------------------------------------------------------
1 | name: weatherreal
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2024.3.11=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_1
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.13=h7f8727e_1
15 | - pip=24.0=py39h06a4308_0
16 | - python=3.9.19=h955ad1f_1
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=69.5.1=py39h06a4308_0
19 | - sqlite=3.45.3=h5eee18b_0
20 | - tk=8.6.14=h39e8969_0
21 | - tzdata=2024a=h04d1e81_0
22 | - wheel=0.43.0=py39h06a4308_0
23 | - xz=5.4.6=h5eee18b_1
24 | - zlib=1.2.13=h5eee18b_1
25 | - pip:
26 | - adal==1.2.7
27 | - annotated-types==0.6.0
28 | - antlr4-python3-runtime==4.9.3
29 | - applicationinsights==0.11.10
30 | - argcomplete==3.3.0
31 | - asciitree==0.3.3
32 | - attrs==23.2.0
33 | - azure-common==1.1.28
34 | - azure-core==1.30.1
35 | - azure-graphrbac==0.61.1
36 | - azure-identity==1.16.0
37 | - azure-mgmt-authorization==4.0.0
38 | - azure-mgmt-containerregistry==10.3.0
39 | - azure-mgmt-core==1.4.0
40 | - azure-mgmt-keyvault==10.3.0
41 | - azure-mgmt-network==25.3.0
42 | - azure-mgmt-resource==23.1.1
43 | - azure-mgmt-storage==21.1.0
44 | - azure-ml==0.0.1
45 | - azure-ml-component==0.9.18.post2
46 | - azure-storage-blob==12.19.0
47 | - azureml-contrib-services==1.56.0
48 | - azureml-core==1.56.0
49 | - azureml-dataprep==5.1.6
50 | - azureml-dataprep-native==41.0.0
51 | - azureml-dataprep-rslex==2.22.2
52 | - azureml-dataset-runtime==1.56.0
53 | - azureml-defaults==1.56.0.post1
54 | - azureml-inference-server-http==1.2.1
55 | - azureml-mlflow==1.56.0
56 | - azureml-telemetry==1.56.0
57 | - backports-tempfile==1.0
58 | - backports-weakref==1.0.post1
59 | - bcrypt==4.1.3
60 | - blinker==1.8.2
61 | - bytecode==0.15.1
62 | - cachetools==5.3.3
63 | - certifi==2024.2.2
64 | - cffi==1.16.0
65 | - cftime==1.6.3
66 | - charset-normalizer==3.3.2
67 | - click==8.1.7
68 | - cloudpickle==2.2.1
69 | - contextlib2==21.6.0
70 | - contourpy==1.2.1
71 | - cryptography==42.0.7
72 | - cycler==0.12.1
73 | - dask==2023.3.2
74 | - distributed==2023.3.2
75 | - docker==7.0.0
76 | - entrypoints==0.4
77 | - fasteners==0.19
78 | - flask==2.3.2
79 | - flask-cors==3.0.10
80 | - fonttools==4.51.0
81 | - fsspec==2024.3.1
82 | - fusepy==3.0.1
83 | - gitdb==4.0.11
84 | - gitpython==3.1.43
85 | - google-api-core==2.19.0
86 | - google-auth==2.29.0
87 | - googleapis-common-protos==1.63.0
88 | - gunicorn==22.0.0
89 | - humanfriendly==10.0
90 | - idna==3.7
91 | - importlib-metadata==7.1.0
92 | - importlib-resources==6.4.0
93 | - inference-schema==1.7.2
94 | - isodate==0.6.1
95 | - itsdangerous==2.2.0
96 | - jeepney==0.8.0
97 | - jinja2==3.1.4
98 | - jmespath==1.0.1
99 | - joblib==1.4.2
100 | - jsonpickle==3.0.4
101 | - jsonschema==4.22.0
102 | - jsonschema-specifications==2023.12.1
103 | - kiwisolver==1.4.5
104 | - knack==0.11.0
105 | - locket==1.0.0
106 | - markupsafe==2.1.5
107 | - matplotlib==3.7.1
108 | - metpy==1.3.1
109 | - mlflow-skinny==2.12.2
110 | - msal==1.28.0
111 | - msal-extensions==1.1.0
112 | - msgpack==1.0.8
113 | - msrest==0.7.1
114 | - msrestazure==0.6.4
115 | - ndg-httpsclient==0.5.1
116 | - netcdf4==1.6.3
117 | - numcodecs==0.12.1
118 | - numpy==1.23.5
119 | - oauthlib==3.2.2
120 | - omegaconf==2.3.0
121 | - opencensus==0.11.4
122 | - opencensus-context==0.1.3
123 | - opencensus-ext-azure==1.1.13
124 | - packaging==24.0
125 | - pandas==1.5.3
126 | - paramiko==3.4.0
127 | - partd==1.4.2
128 | - pathspec==0.12.1
129 | - pillow==10.3.0
130 | - pint==0.23
131 | - pkginfo==1.10.0
132 | - platformdirs==4.2.1
133 | - pooch==1.8.1
134 | - portalocker==2.8.2
135 | - proto-plus==1.23.0
136 | - protobuf==4.25.3
137 | - psutil==5.9.8
138 | - pyarrow==16.0.0
139 | - pyasn1==0.6.0
140 | - pyasn1-modules==0.4.0
141 | - pycparser==2.22
142 | - pydantic==2.7.1
143 | - pydantic-core==2.18.2
144 | - pydantic-settings==2.2.1
145 | - pydash==8.0.1
146 | - pygments==2.18.0
147 | - pyjwt==2.8.0
148 | - pynacl==1.5.0
149 | - pyopenssl==24.1.0
150 | - pyparsing==3.1.2
151 | - pyproj==3.6.1
152 | - pysocks==1.7.1
153 | - python-dateutil==2.9.0.post0
154 | - python-dotenv==1.0.1
155 | - pytz==2024.1
156 | - pyyaml==6.0.1
157 | - referencing==0.35.1
158 | - requests==2.31.0
159 | - requests-oauthlib==2.0.0
160 | - rpds-py==0.18.1
161 | - rsa==4.9
162 | - ruamel-yaml==0.17.16
163 | - ruamel-yaml-clib==0.2.8
164 | - scikit-learn==1.4.2
165 | - scipy==1.13.0
166 | - secretstorage==3.3.3
167 | - six==1.16.0
168 | - smmap==5.0.1
169 | - sortedcontainers==2.4.0
170 | - sqlparse==0.5.0
171 | - tabulate==0.9.0
172 | - tblib==3.0.0
173 | - threadpoolctl==3.5.0
174 | - toolz==0.12.1
175 | - tornado==6.4
176 | - tqdm==4.66.4
177 | - traitlets==5.14.3
178 | - typing-extensions==4.11.0
179 | - urllib3==2.2.1
180 | - werkzeug==2.3.8
181 | - wrapt==1.16.0
182 | - xarray==2023.1.0
183 | - zarr==2.14.2
184 | - zict==3.0.0
185 | - zipp==3.18.1
186 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import reduce
3 | import logging
4 | import os
5 | from pathlib import Path
6 | from typing import Any, Dict, List, Optional, Tuple, Union
7 | import warnings
8 |
9 | import matplotlib.pyplot as plt
10 | import pandas as pd
11 | import xarray as xr
12 | import yaml
13 |
14 | from evaluation.forecast_reformat_catalog import reformat_forecast
15 | from evaluation.obs_reformat_catalog import get_interp_station_list, obs_to_verification, reformat_and_filter_obs
16 | from evaluation.metric_catalog import get_metric_func
17 | from evaluation.utils import configure_logging, get_metric_multiple_stations, generate_forecast_cache_path, \
18 | cache_reformat_forecast, load_reformat_forecast, ForecastData, ForecastInfo, MetricData, get_ideal_xticks
19 |
20 |
21 | warnings.filterwarnings("ignore")
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | def intersect_all_forecast(forecast_list: List[ForecastData]) -> List[ForecastData]:
27 | """
28 | Generates new versions of the forecast data objects where the dataframes have been aligned between all
29 | forecasts in the sequence. If only one forecast is provided, the data is returned as is.
30 | """
31 | raise NotImplementedError("Forecast alignment is not yet implemented.")
32 |
33 |
34 | def get_forecast_data(forecast_info: ForecastInfo, cache_forecast: bool) -> ForecastData:
35 | """
36 | Open all the forecasts from the forecast_info_list and reformat them.
37 |
38 | Parameters
39 | ----------
40 | forecast_info: ForecastInfo
41 | The forecast information.
42 | cache_forecast: bool
43 | If true, cache the reformat forecast data, next time will load from cache.
44 |
45 | Returns
46 | -------
47 | forecast: ForecastData
48 | Forecast data and the forecast information.
49 | """
50 | info = forecast_info
51 | cache_path = generate_forecast_cache_path(info)
52 | if not os.path.exists(cache_path) or not cache_forecast:
53 | logger.info(f"open forecast file: {info.path}")
54 | if info.file_type is None:
55 | if Path(info.path).is_dir():
56 | forecast = xr.open_zarr(info.path)
57 | else:
58 | forecast = xr.open_dataset(info.path, chunks={})
59 | elif info.file_type == 'zarr':
60 | forecast = xr.open_zarr(info.path)
61 | else:
62 | forecast = xr.open_dataset(info.path, chunks={})
63 | forecast_ds = reformat_forecast(forecast, info)
64 | if cache_forecast:
65 | cache_reformat_forecast(forecast_ds, cache_path)
66 | logger.info(f"save forecast file to cache: {cache_path}")
67 | else:
68 | forecast_ds = load_reformat_forecast(cache_path)
69 | logger.info(f"load forecast: {info.forecast_name}, from cache: {cache_path}")
70 | logger.debug(f"opened forecast dataset: {forecast_ds}")
71 | return ForecastData(info=info, forecast=forecast_ds)
72 |
73 |
74 | def get_observation_data(obs_base_path: str, obs_var_name: str, station_metadata_path: str,
75 | obs_file_type: str, obs_start_month: str, obs_end_month: str,
76 | precip_threshold: Optional[float] = None) -> xr.Dataset:
77 | """
78 | Open the observation file and reformat it. Required fields: station, valid_time, obs_var_name.
79 |
80 | Parameters
81 | ----------
82 | obs_base_path: str
83 | Path to the observation file.
84 | obs_var_name: str
85 | Name of the observation variable.
86 | station_metadata_path: str
87 | Path to the station metadata file.
88 | obs_file_type: str
89 | Type of the observation file.
90 | obs_start_month: str
91 | Obs start month, for multi-file netCDF data.
92 | obs_end_month: str
93 | Obs end month, for multi-file netCDF data.
94 | precip_threshold: float, optional
95 | Threshold for converting precipitation amount to binary. Default is no conversion.
96 |
97 | Returns
98 | -------
99 | obs: xr.Dataset
100 | Observation data with required fields.
101 | """
102 |     if obs_start_month is not None or obs_end_month is not None:
103 | if obs_start_month is None or obs_end_month is None:
104 | raise ValueError("Both obs_start_month and obs_end_month must be provided.")
105 | month_list = pd.date_range(obs_start_month, obs_end_month, freq='MS')
106 | suffix = obs_file_type or 'nc'
107 | obs_path = [os.path.join(obs_base_path, month.strftime(f'%Y%m.{suffix}')) for month in month_list]
108 | obs_path_filter = []
109 | for path in obs_path:
110 | if not os.path.exists(path):
111 | logger.warning(f"expected observation path does not exist: {path}")
112 | else:
113 | obs_path_filter.append(path)
114 | obs = xr.open_mfdataset(obs_path_filter, chunks={})
115 | else:
116 | if obs_file_type is None:
117 | if Path(obs_base_path).is_dir():
118 | obs = xr.open_zarr(obs_base_path)
119 | else:
120 | obs = xr.open_dataset(obs_base_path, chunks={})
121 | elif obs_file_type == 'zarr':
122 | obs = xr.open_zarr(obs_base_path)
123 | else:
124 | obs = xr.open_dataset(obs_base_path, chunks={})
125 |
126 | obs = reformat_and_filter_obs(obs, obs_var_name, station_metadata_path, precip_threshold)
127 | logger.debug(f"opened observation dataset: {obs}")
128 | return obs
129 |
130 |
131 | def merge_forecast_obs(forecast: ForecastData, obs: xr.Dataset) -> ForecastData:
132 | """
133 | Merge the forecast and observation data.
134 | """
135 | new_obs = obs_to_verification(
136 | obs,
137 | steps=forecast.forecast.lead_time.values,
138 | max_lead=forecast.forecast.lead_time.values.max(),
139 | issue_times=forecast.forecast.issue_time.values
140 | )
141 | merge_data = xr.merge([forecast.forecast, new_obs], compat='override')
142 | merge_data['delta'] = merge_data['fc'] - merge_data['obs']
143 | result = ForecastData(info=forecast.info, forecast=forecast.forecast, merge_data=merge_data)
144 | logger.debug(f"after merge forecast and obs: {result.merge_data}")
145 | return result
146 |
147 |
148 | def filter_by_region(forecast: ForecastData, region_name: str, station_list: List[str]) \
149 | -> ForecastData:
150 | """
151 | Apply a selection on forecast based on the region_name and station_list.
152 | """
153 | if region_name == 'all':
154 | return forecast
155 | merge_data = forecast.merge_data
156 | filtered_merge_data = ForecastData(
157 | merge_data=merge_data.sel(station=merge_data.station.values.isin(station_list)),
158 | info=forecast.info
159 | )
160 | return filtered_merge_data
161 |
162 |
163 | def calculate_all_metrics(forecast_data: ForecastData, group_dim: str, metrics_params: Dict[str, Any]) \
164 | -> MetricData:
165 | """
166 | Calculate all the metrics together for dask graph efficiency.
167 |
168 | Parameters
169 | ----------
170 | forecast_data: ForecastData
171 | The forecast data.
172 | group_dim: str
173 | The dimension to group the metric calculation.
174 | metrics_params: dict
175 | Dictionary containing the metrics configs.
176 |
177 | Returns
178 | -------
179 | metric_data: MetricData
180 | Metric data and the forecast information.
181 | """
182 | metrics = MetricData(info=forecast_data.info, metric_data=xr.Dataset())
183 | for metric_name in metrics_params.keys():
184 | metric_func = get_metric_func(metrics_params[metric_name])
185 | metrics.metric_data[metric_name] = metric_func(forecast_data.merge_data, group_dim)
186 | metrics.metric_data = metrics.metric_data.compute()
187 | return metrics
188 |
189 |
190 | def get_plot_detail(forecast_data: ForecastData, group_dim: str):
191 | """
192 | Get some added data to show on plots
193 | """
194 | merge_data = forecast_data.merge_data
195 | counts = {key: coord.size for key, coord in merge_data.coords.items() if key not in ['lat', 'lon']}
196 | counts.update({'fc': merge_data.fc.size})
197 | all_dims = ['valid_time', 'issue_time', 'lead_time']
198 | all_dims.remove(group_dim)
199 | dim_info = []
200 | for dim in all_dims:
201 | if dim == 'lead_time':
202 | vmax, vmin = merge_data[dim].max(), merge_data[dim].min()
203 | else:
204 | vmax, vmin = pd.Timestamp(merge_data[dim].values.max()).strftime("%Y-%m-%d %H:%M:%S"), \
205 | pd.Timestamp(merge_data[dim].values.min()).strftime("%Y-%m-%d %H:%M:%S")
206 | dim_info.append(f'{dim} min: {vmin}, max: {vmax}')
207 | data_distribution = f"dim count: {str(counts)}\n{dim_info[0]}\n{dim_info[1]}"
208 | return data_distribution
209 |
210 |
211 | def plot_metric(
212 | example_data: ForecastData,
213 | metric_data_list: List[MetricData],
214 | group_dim: str,
215 | metric_name: str,
216 | base_plot_setting: Dict[str, Any],
217 | metrics_params: Dict[str, Any],
218 | output_dir: str,
219 | region_name: str,
220 | plot_save_format: Optional[str] = 'png'
221 | ) -> plt.Figure:
222 | """
223 | A generic, basic plot for a single metric.
224 |
225 | Parameters
226 | ----------
227 | example_data: ForecastData
228 | Example forecast data to get some extra information for the plot.
229 | metric_data_list: list of MetricData
230 | List of MetricData objects containing the metric data and the forecast information.
231 | group_dim: str
232 | The dimension to group the metric calculation.
233 | metric_name: str
234 | The name of the metric.
235 | base_plot_setting: dict
236 | Dictionary containing the base plot settings.
237 | metrics_params: dict
238 | Dictionary containing the metric method and other kwargs.
239 | output_dir: str
240 | The output directory for the plots.
241 | region_name: str
242 | The name of the region.
243 | plot_save_format: str, optional
244 | The format to save the plot in. Default is 'png'.
245 |
246 | Returns
247 | -------
248 | fig: plt.Figure
249 | The plot figure.
250 | """
251 | data_distribution = get_plot_detail(example_data, group_dim)
252 |
253 | fig = plt.figure(figsize=(5.5, 6.5))
254 | font = {'weight': 'medium', 'fontsize': 11}
255 | title = base_plot_setting['title']
256 | xlabel = base_plot_setting['xlabel']
257 | if 'plot_setting' in metrics_params:
258 | plot_setting = metrics_params['plot_setting']
259 | title = plot_setting.get('title', title)
260 | xlabel = plot_setting.get('xlabel', xlabel)
261 | plt.title(title)
262 | plt.suptitle(data_distribution, fontsize=7)
263 | plt.gca().set_xlabel(xlabel[group_dim], fontdict=font)
264 | plt.gca().set_ylabel(metric_name, fontdict=font)
265 |
266 | for metrics in metric_data_list:
267 | metric_data = metrics.metric_data
268 | forecast_name = metrics.info.forecast_name
269 | plt.plot(metric_data[group_dim], metric_data[metric_name], label=forecast_name, linewidth=1.5)
270 |
271 | if group_dim == 'lead_time':
272 | plt.gca().set_xticks(get_ideal_xticks(metric_data[group_dim].min(), metric_data[group_dim].max(), 8))
273 | plt.grid(linestyle=':')
274 | plt.legend(loc='upper center', bbox_to_anchor=(0.45, -0.14), frameon=False, ncol=3, fontsize=10)
275 | plt.tight_layout()
276 | plt.subplots_adjust(top=0.85)
277 | plot_path = os.path.join(output_dir, region_name)
278 | os.makedirs(plot_path, exist_ok=True)
279 | plt.savefig(os.path.join(plot_path, f"{metric_name}.{plot_save_format}"))
280 | return fig
281 |
282 |
283 | def metrics_to_csv(
284 | metric_data_list: List[MetricData],
285 | group_dim: str,
286 | output_dir: str,
287 | region_name: str
288 | ):
289 | """
290 | A generic, basic function to save the metric data to a CSV file.
291 |
292 | Parameters
293 | ----------
294 | metric_data_list: list of MetricData
295 | List of MetricData objects containing the metric data and the forecast information.
296 | group_dim: str
297 | The dimension to group the metric calculation.
298 | output_dir: str
299 | The output directory for the CSV files.
300 | region_name: str
301 | The name of the region.
302 |
303 | Returns
304 | -------
305 | merged_df: pd.DataFrame
306 | The merged DataFrame with metric data.
307 | """
308 | df_list = []
309 | for metrics in metric_data_list:
310 | df_list.append(metrics.metric_data.rename(
311 | {metric: f"{metrics.info.forecast_name}_{metric}" for metric in metrics.metric_data.data_vars.keys()}
312 | ).to_dataframe())
313 | merged_df = reduce(lambda left, right: pd.merge(left, right, on=group_dim, how='inner'), df_list)
314 |
315 | output_path = os.path.join(output_dir, region_name)
316 | os.makedirs(output_path, exist_ok=True)
317 | merged_df.to_csv(os.path.join(output_path, "metrics.csv"))
318 | return merged_df
319 |
320 |
321 | def parse_args(args: argparse.Namespace) -> Tuple[List[ForecastInfo], Any, Any, Union[
322 | Dict[str, List[Any]], Any], Any, Any, Any, Any, Any, Any, Any, str, bool, bool, Optional[float]]:
323 | forecast_info_list = []
324 | forecast_name_list = args.forecast_names
325 | forecast_var_name_list = args.forecast_var_names
326 | forecast_reformat_func_list = args.forecast_reformat_funcs
327 | station_metadata_path = args.station_metadata_path
328 | for index, forecast_path in enumerate(args.forecast_paths):
329 | forecast_info = ForecastInfo(
330 | path=forecast_path,
331 | forecast_name=forecast_name_list[index] if index < len(forecast_name_list) else f"forecast_{index}",
332 | fc_var_name=forecast_var_name_list[index] if index < len(forecast_var_name_list) else
333 | forecast_var_name_list[0],
334 | reformat_func=forecast_reformat_func_list[index] if index < len(forecast_reformat_func_list) else
335 | forecast_reformat_func_list[0],
336 | file_type=args.forecast_file_types[index] if index < len(args.forecast_file_types) else None,
337 | station_metadata_path=station_metadata_path,
338 | interp_station_path=station_metadata_path,
339 | output_directory=args.output_directory,
340 | start_date=args.start_date,
341 | end_date=args.end_date,
342 | issue_time_freq=args.issue_time_freq,
343 | start_lead=args.start_lead,
344 | end_lead=args.end_lead,
345 | convert_temperature=args.convert_fcst_temperature_k_to_c,
346 | convert_pressure=args.convert_fcst_pressure_pa_to_hpa,
347 | convert_cloud=args.convert_fcst_cloud_to_okta,
348 | precip_proba_threshold=args.precip_proba_threshold_conversion,
349 | )
350 | forecast_info_list.append(forecast_info)
351 |
352 | metrics_settings_path = args.config_path if args.config_path is not None else os.path.join(
353 | os.path.dirname(os.path.abspath(__file__)), 'metric_config.yml')
354 | with open(metrics_settings_path, 'r') as fs: # pylint: disable=unspecified-encoding
355 | metrics_settings = yaml.safe_load(fs)
356 |
357 | try:
358 | metrics_settings = metrics_settings[args.variable_type]
359 | except KeyError as exc:
360 | raise ValueError(f"Unknown variable type: {args.variable_type}. Check config file {metrics_settings_path}") \
361 | from exc
362 | metrics_dict = metrics_settings['metrics']
363 | base_plot_setting = metrics_settings['base_plot_setting']
364 | if args.eval_region_files is not None:
365 | try:
366 | region_dict = get_metric_multiple_stations(','.join(args.eval_region_files))
367 | except Exception as e: # pylint: disable=broad-exception-caught
368 | logger.info(f"get_metric_multiple_stations failed, use default region: all {e}")
369 | region_dict = {}
370 | else:
371 | region_dict = {}
372 | region_dict['all'] = []
373 |
374 | return (forecast_info_list, metrics_dict, base_plot_setting,
375 | region_dict, args.group_dim, args.obs_var_name, args.obs_path, args.obs_file_type,
376 | args.obs_start_month, args.obs_end_month, args.output_directory, station_metadata_path,
377 | bool(args.cache_forecast), bool(args.align_forecasts),
378 | args.precip_proba_threshold_conversion)
379 |
380 |
381 | def main(args):
382 | logger.info("===================== parse args =====================")
383 | (forecast_info_list, metrics_dict, base_plot_setting, region_dict, group_dim, obs_var_name,
384 | obs_base_path, obs_file_type, obs_start_month, obs_end_month, output_dir, station_metadata_path,
385 | cache_forecast, align_forecasts, precip_threshold) = parse_args(args)
386 |
387 | logger.info("===================== start get_observation_data =====================")
388 | obs_ds = get_observation_data(obs_base_path, obs_var_name, station_metadata_path, obs_file_type, obs_start_month,
389 | obs_end_month, precip_threshold=precip_threshold)
390 |
391 | # Get metadata and set it on all the forecast info objects
392 | if station_metadata_path is not None:
393 | metadata = get_interp_station_list(station_metadata_path)
394 | else:
395 | try:
396 | metadata = pd.DataFrame({'lat': obs_ds.lat.values, 'lon': obs_ds.lon.values,
397 | 'station': obs_ds.station.values})
398 | except KeyError as exc:
399 | raise ValueError("--station-metadata-path is required if lat/lon/station keys are not in the observation "
400 | "file") from exc
401 | for forecast_info in forecast_info_list:
402 | forecast_info.metadata = metadata
403 |
404 | if align_forecasts:
405 | # First load all forecasts, then compute and return metrics.
406 | logger.info("===================== start get_forecast_data =====================")
407 | forecast_list = [get_forecast_data(fi, cache_forecast) for fi in forecast_info_list]
408 | logger.info("===================== start intersect_all_forecast =====================")
409 | forecast_list = intersect_all_forecast(forecast_list)
410 | else:
411 | forecast_list = [None] * len(forecast_info_list)
412 |
413 | # For each forecast, compute all its metrics in every region.
414 | metric_data = {r: [] for r in region_dict}
415 | for forecast, forecast_info in zip(forecast_list, forecast_info_list):
416 |         try:  # release the previous iteration's merged data (if any) to limit peak memory
417 | del merged_forecast # noqa: F821
418 | except NameError:
419 | pass
420 | logger.info(f"===================== compute metrics for forecast {forecast_info.forecast_name} "
421 | f"=====================")
422 | if forecast is None:
423 | logger.info("===================== get_forecast_data =====================")
424 | forecast = get_forecast_data(forecast_info, cache_forecast)
425 |
426 | logger.info("===================== start merge_forecast_obs =====================")
427 | merged_forecast = merge_forecast_obs(forecast, obs_ds)
428 |
429 | for region_name, region_data in region_dict.items():
430 | logger.info(f"===================== filter_by_region: {region_name} =====================")
431 | filtered_forecast = filter_by_region(merged_forecast, region_name, region_data)
432 | if region_name != 'all':
433 | logger.info(f"after filter_by_region: {region_name}; "
434 | f"stations: {filtered_forecast.merge_data.station.size}")
435 |
436 | logger.info(f"start calculate_metrics, region: {region_name}")
437 | metric_data[region_name].append(
438 | calculate_all_metrics(filtered_forecast, group_dim, metrics_dict)
439 | )
440 | forecast = None
441 |
442 | # Plot all metrics and save data
443 | for region_name in region_dict:
444 | for metric_name in metrics_dict:
445 | logger.info(f"===================== plot_metric: {metric_name}, region: {region_name} "
446 | f"=====================")
447 | plot_metric(merged_forecast, metric_data[region_name], group_dim, metric_name,
448 | base_plot_setting, metrics_dict[metric_name], output_dir, region_name)
449 |
450 | metrics_to_csv(metric_data[region_name], group_dim, output_dir, region_name)
451 |
452 |
453 | if __name__ == '__main__':
454 | parser = argparse.ArgumentParser(
455 | description="Forecast evaluation script. Given a set of forecasts and a file of reference observations, "
456 |                     "computes requested metrics as specified in the `metric_config.yml` file. Includes the ability "
457 | "to interpret either grid-based or point-based forecasts. Grid-based forecasts are interpolated "
458 | "to observation locations. Point-based forecasts are directly compared to nearest observations."
459 | )
460 | parser.add_argument(
461 | "--forecast-paths",
462 | nargs='+',
463 | type=str,
464 | required=True,
465 | help="List of paths containing forecasts. If a directory is provided, assumes forecast is a zarr store, "
466 | "and calls xarray's `open_zarr` method. "
467 | "Required dimensions: lead_time (or step), issue_time (or time), lat (or latitude), lon (or longitude)."
468 | )
469 | parser.add_argument(
470 | '--forecast-names',
471 | type=str,
472 | nargs='+',
473 | default=[],
474 | help="List of names to assign to the forecasts. If there are more forecast paths than names, fills in the "
475 | "remaining names with 'forecast_{index}'"
476 | )
477 | parser.add_argument(
478 | '--forecast-var-names',
479 | type=str,
480 | nargs='+',
481 | required=True,
482 | help="List of names (one per forecast path) of the forecast variable of interest in each file. If only one "
483 | "value is provided, assumes all forecast files have the same variable name. Raises an error if the "
484 | "number of listed values is less than the number of forecast paths."
485 | )
486 | parser.add_argument(
487 | '--forecast-reformat-funcs',
488 | type=str,
489 | nargs='+',
490 | required=True,
491 | help="For each forecast path, provide the name of the reformat function to apply. This function is based on "
492 | "the schema of the forecast file. Can be only a single value to apply to all forecasts. Options: "
493 | "\n - 'grid_standard': input is a grid forecast with dimensions lead_time, issue_time, lat, lon."
494 | "\n - 'point_standard': input is a point forecast with dimensions lead_time, issue_time, station."
495 | "\n - 'grid_v1': custom reformat function for grid forecasts with dims time, step, latitude, longitude."
496 | )
497 | parser.add_argument(
498 | '--forecast-file-types',
499 | type=str,
500 | nargs='+',
501 | default=[],
502 | help="List of file types for each forecast path. Options: 'nc', 'zarr'. If not provided, or not enough "
503 | "entries, will assume zarr store if forecast is a directory, and otherwise will use xarray's "
504 | "`open_dataset` method."
505 | )
506 | parser.add_argument(
507 | "--obs-path",
508 | type=str,
509 | required=True,
510 | help="Path to the verification folder or file"
511 | )
512 | parser.add_argument(
513 | "--obs-file-type",
514 | type=str,
515 | default=None,
516 | help="Type of the observation file. Options: 'nc', 'zarr'. If not provided, will assume zarr store if this is "
517 | "a directory, and otherwise will use xarray's `open_dataset` method."
518 | )
519 | parser.add_argument(
520 | "--obs-start-month",
521 | type=str,
522 | default=None,
523 | help="Option to read multiple netCDF files as a single dataset. These files are named 'YYYYMM.nc'. Provide "
524 | "the start month in the format 'YYYY-MM'. Not needed if obs-path is a single nc/zarr store."
525 | )
526 | parser.add_argument(
527 | "--obs-end-month",
528 | type=str,
529 | default=None,
530 | help="Option to read multiple netCDF files as a single dataset. These files are named 'YYYYMM.nc'. Provide "
531 | "the end month in the format 'YYYY-MM'. Not needed if obs-path is a single nc/zarr store."
532 | )
533 | parser.add_argument(
534 | "--obs-var-name",
535 | type=str,
536 | help="Name of the variable of interest in the observation data.",
537 | required=True
538 | )
539 | parser.add_argument(
540 | "--station-metadata-path",
541 | type=str,
542 | help="Path to the station list containing metadata. Must include columns 'station', 'lat', 'lon'. "
543 | "If not provided, assumes the station lat/lon are coordinates in the observation file.",
544 | required=False
545 | )
546 | parser.add_argument(
547 | "--config-path",
548 | type=str,
549 | help="Path to custom config yml file containing metric settings. Defaults to `metric_config.yml` in this "
550 | "script directory.",
551 | default=None
552 | )
553 | parser.add_argument(
554 | "--variable-type",
555 | type=str,
556 | help="The type of the variable, as used in `--config-path` to select the appropriate metric settings. For "
557 | "example, 'temperature' or 'wind'.",
558 | required=True
559 | )
560 | parser.add_argument(
561 | "--output-directory",
562 | type=str,
563 | help="Output directory for all evaluation artifacts",
564 | required=True
565 | )
566 | parser.add_argument(
567 | "--eval-region-files",
568 | type=str,
569 | default=None,
570 | nargs='+',
571 | help="A list of files containing station lists for evaluation in certain regions"
572 | )
573 | parser.add_argument(
574 | "--start-date",
575 | type=pd.Timestamp,
576 | default=None,
577 | help="First forecast issue time (as Timestamp) to include in evaluation"
578 | )
579 | parser.add_argument(
580 | "--end-date",
581 | type=pd.Timestamp,
582 | default=None,
583 | help="Last forecast issue time (as Timestamp) to include in evaluation"
584 | )
585 | parser.add_argument(
586 | "--issue-time-freq",
587 | type=str,
588 | default=None,
589 | help="Frequency of issue times (e.g., '1D') to include in evaluation. Default is None (all issue times)"
590 | )
591 | parser.add_argument(
592 | "--start-lead",
593 | type=int,
594 | default=None,
595 | help="First lead time (in hours) to include in evaluation"
596 | )
597 | parser.add_argument(
598 | "--end-lead",
599 | type=int,
600 | default=None,
601 | help="Last lead time (in hours) to include in evaluation"
602 | )
603 | parser.add_argument(
604 | "--group-dim",
605 | type=str,
606 | default="lead_time",
607 | help="Group dimension for metric computation, options: lead_time, issue_time, valid_time"
608 | )
609 | parser.add_argument(
610 | "--precip-proba-threshold-conversion",
611 | type=float,
612 | default=None,
613 | help="Convert observation and forecast fields from precipitation rate to probability of precipitation. Provide"
614 | " a threshold in mm/hr to use as positive precipitation class. Use only for evaluating precipitation!"
615 | )
616 | parser.add_argument(
617 | "--convert-fcst-temperature-k-to-c",
618 | action='store_true',
619 | help="Convert forecast field from Kelvin to Celsius. Use only for evaluating temperature!"
620 | )
621 | parser.add_argument(
622 | "--convert-fcst-pressure-pa-to-hpa",
623 | action='store_true',
624 | help="Convert forecast field from Pa to hPa. Use only for evaluating pressure!"
625 | )
626 | parser.add_argument(
627 | "--convert-fcst-cloud-to-okta",
628 | action='store_true',
629 | help="Convert forecast field from cloud fraction to okta. Use only for evaluating cloud!"
630 | )
631 | parser.add_argument(
632 | "--cache-forecast",
633 | action='store_true',
634 | help="If true, cache the intermediate interpolated forecast data in the output directory."
635 | )
636 | parser.add_argument(
637 | "--align-forecasts",
638 | action='store_true',
639 | help="If set, load all forecasts first and then align them based on the intersection of issue/lead times. "
640 | "Note this uses substantially more memory to store all data at once."
641 | )
642 | parser.add_argument(
643 | '--verbose',
644 | type=int,
645 | default=1,
646 | help="Verbosity level for logging. Options are 0 (WARNING), 1 (INFO), 2 (DEBUG), 3 (NOTSET). Default is 1."
647 | )
648 |
649 | run_args = parser.parse_args()
650 | configure_logging(run_args.verbose)
651 | main(run_args)
652 |
--------------------------------------------------------------------------------
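A quick note on the date arguments above: argparse simply calls `pd.Timestamp` on the raw string, so any date format pandas can parse is accepted. A minimal, standalone sketch (flag names taken from the parser above; the example dates are arbitrary):

import argparse

import pandas as pd

# argparse applies type=pd.Timestamp to the raw value of --start-date/--end-date
parser = argparse.ArgumentParser()
parser.add_argument("--start-date", type=pd.Timestamp, default=None)
parser.add_argument("--end-date", type=pd.Timestamp, default=None)
args = parser.parse_args(["--start-date", "2023-01-01", "--end-date", "2023-01-07T12:00"])
print(args.start_date, args.end_date)  # 2023-01-01 00:00:00 2023-01-07 12:00:00
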
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/forecast_reformat_catalog.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import xarray as xr
6 |
7 | from .utils import convert_to_binary, ForecastInfo
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def convert_grid_to_point(grid_forecast: xr.Dataset, metadata: pd.DataFrame) -> xr.Dataset:
13 | """
14 | Convert grid forecast to point dataframe via interpolation
15 | input data dims must be: lead_time, issue_time, lat, lon
16 |
17 | Parameters
18 | ----------
19 | grid_forecast: xarray Dataset: grid forecast data
20 | metadata: pd.DataFrame: station metadata
21 |
22 | Returns
23 | -------
24 | xr.Dataset: interpolated forecast data
25 | """
26 | # grid_forecast = grid_forecast.load()
27 | grid_forecast = grid_forecast.assign_coords(lon=[lon if (lon < 180) else (lon - 360)
28 | for lon in grid_forecast['lon'].values])
29 | # Optionally roll the longitude if the minimum longitude is not at index 0. This should make longitudes
30 | # monotonically increasing.
31 | if grid_forecast['lon'].argmin().values != 0:
32 | grid_forecast = grid_forecast.roll(lon=grid_forecast['lon'].argmin().values, roll_coords=True)
33 |
34 | # Default interpolate
35 | interp_meta = xr.Dataset.from_dataframe(metadata[['station', 'lon', 'lat']].set_index(['station']))
36 | interp_forecast = grid_forecast.interp(lon=interp_meta['lon'], lat=interp_meta['lat'], method='linear')
37 | return interp_forecast
38 |
39 |
40 | def get_lead_time_slice(start_lead, end_lead):
41 |     start = pd.Timedelta(start_lead, 'h') if start_lead is not None else None
42 |     end = pd.Timedelta(end_lead, 'h') if end_lead is not None else None
43 | return slice(start, end)
44 |
45 |
46 | def update_unit_conversions(forecast: xr.Dataset, info: ForecastInfo):
47 | """
48 | Check if the forecast data needs to be converted to the required units. Will not perform correction if
49 | not specified by user, however, if specified and data are already in converted units, then do not perform
50 | the conversion.
51 |
52 | Parameters
53 | ----------
54 | forecast: xarray Dataset: forecast data
55 | info: ForecastInfo: forecast info
56 | """
57 | if info.reformat_func == 'omg_v1':
58 | logger.info(f"Unit conversion not needed for forecast {info.forecast_name}")
59 | return
60 | unit = forecast['fc'].attrs.get('units', '') or forecast['fc'].attrs.get('unit', '')
61 | if info.convert_temperature:
62 | if 'C' in unit:
63 | info.convert_temperature = False
64 | logger.info(f"Temperature conversion not needed for forecast {info.forecast_name}")
65 | if info.convert_pressure:
66 | if any(u in unit.lower() for u in ['hpa', 'mb', 'millibar']):
67 | info.convert_pressure = False
68 | logger.info(f"Pressure conversion not needed for forecast {info.forecast_name}")
69 | if info.convert_cloud:
70 | if any(u in unit.lower() for u in ['okta', '0-8']):
71 | info.convert_cloud = False
72 | logger.info(f"Cloud cover conversion not needed for forecast {info.forecast_name}")
73 | if info.precip_proba_threshold is not None:
74 | if not any(u in unit.lower() for u in ['mm', 'milli']):
75 | info.precip_proba_threshold /= 1e3
76 |             logger.info(f"Scaled precipitation threshold by 1/1000 to match forecast units for {info.forecast_name}")
77 |
78 |
79 | def convert_cloud(forecast: xr.Dataset):
80 | """
81 | Convert cloud cover inplace from percentage or fraction to okta. Assumes fraction by default unless units
82 | attribute says otherwise.
83 |
84 | Parameters
85 | ----------
86 | forecast: xarray Dataset: forecast data
87 |
88 | Returns
89 | -------
90 | xr.Dataset: converted forecast data
91 | """
92 | unit = forecast['fc'].attrs.get('units', '') or forecast['fc'].attrs.get('unit', '')
93 | if any(u in unit.lower() for u in ['percent', '%', '100']):
94 | forecast['fc'] *= 8 / 100.
95 | else:
96 | forecast['fc'] *= 8
97 | return forecast
98 |
99 |
100 | def convert_precip_binary(forecast: xr.Dataset, info: ForecastInfo):
101 | """
102 | Convert precipitation forecast to binary based on the threshold in the forecast info
103 |
104 | Parameters
105 | ----------
106 | forecast: xarray Dataset: forecast data
107 | info: ForecastInfo: forecast info
108 |
109 | Returns
110 | -------
111 | xr.Dataset: converted forecast data
112 | """
113 | if 'pp' in info.fc_var_name:
114 | logger.info(f"Forecast {info.forecast_name} is already in probability (%) format. Skipping thresholding.")
115 | forecast['fc'] /= 100.
116 | return forecast
117 | forecast['fc'] = convert_to_binary(forecast['fc'], info.precip_proba_threshold)
118 | return forecast
119 |
120 |
121 | def reformat_forecast(forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset:
122 | """
123 | Format the forecast data to the required format for evaluation, and keep only the required stations.
124 |
125 | Parameters
126 | ----------
127 | forecast: xarray Dataset: forecast data
128 | info: ForecastInfo: forecast info
129 |
130 | Returns
131 | -------
132 | xr.Dataset: formatted forecast data
133 | """
134 | if info.reformat_func in ['grid_v1']:
135 | reformat_data = reformat_grid_v1(forecast, info)
136 | elif info.reformat_func in ['omg_v1', 'grid_v2']:
137 | reformat_data = reformat_grid_v2(forecast, info)
138 | elif info.reformat_func in ['grid_standard']:
139 | reformat_data = reformat_grid_standard(forecast, info)
140 | elif info.reformat_func in ['point_standard']:
141 | reformat_data = reformat_point_standard(forecast, info)
142 | else:
143 | raise ValueError(f"Unknown reformat method {info.reformat_func} for forecast {info.forecast_name}")
144 |
145 | # Update unit conversions
146 | update_unit_conversions(reformat_data, info)
147 | if info.convert_temperature:
148 | reformat_data['fc'] -= 273.15
149 | if info.convert_pressure:
150 | reformat_data['fc'] /= 100
151 | if info.convert_cloud:
152 | convert_cloud(reformat_data)
153 | if info.precip_proba_threshold is not None:
154 | convert_precip_binary(reformat_data, info)
155 |
156 | # Convert coordinates
157 | reformat_data = reformat_data.assign_coords(valid_time=reformat_data['issue_time'] + reformat_data['lead_time'])
158 | reformat_data = reformat_data.assign_coords(lead_time=reformat_data['lead_time'] / np.timedelta64(1, 'h'))
159 | return reformat_data
160 |
161 |
162 | def select_forecasts(forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset:
163 | """
164 | Select the forecast data based on the issue time and lead time
165 |
166 | Parameters
167 | ----------
168 | forecast: xarray Dataset: forecast data
169 | info: ForecastInfo: forecast info
170 |
171 | Returns
172 | -------
173 | xr.Dataset: selected forecast data
174 | """
175 | forecast = forecast.sel(
176 | lead_time=get_lead_time_slice(info.start_lead, info.end_lead),
177 | issue_time=slice(info.start_date, info.end_date)
178 | )
179 | if info.issue_time_freq is not None:
180 | forecast = forecast.resample(issue_time=info.issue_time_freq).nearest()
181 | return forecast
182 |
183 |
184 | def reformat_grid_v1(grid_forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset:
185 | """
186 | Standard grid forecast format following ECMWF schema
187 | input nc|zarr file dims must be: time, step, latitude, longitude
188 |
189 | Parameters
190 | ----------
191 | grid_forecast: xarray Dataset: grid forecast data
192 |     info: ForecastInfo: forecast info
193 |
194 | Returns
195 | -------
196 | xr.Dataset: formatted forecast data
197 | """
198 | grid_forecast = grid_forecast.rename(
199 | {
200 | 'latitude': 'lat',
201 | 'longitude': 'lon',
202 | 'step': 'lead_time',
203 | 'time': 'issue_time',
204 | info.fc_var_name: 'fc'
205 | }
206 | )
207 | grid_forecast = select_forecasts(grid_forecast, info)
208 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata)
209 | return interp_forecast
210 |
211 |
212 | def reformat_grid_v2(grid_forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset:
213 | """
214 | Grid format following OMG schema
215 | input nc|zarr file dims must be: lead_time, issue_time, y, x, index
216 | Must have lat, lon, index, var_name as coordinates
217 |
218 | Parameters
219 | ----------
220 | grid_forecast: xarray Dataset: grid forecast data
221 |     info: ForecastInfo: forecast info
222 |
223 | Returns
224 | -------
225 | xr.Dataset: formatted forecast data
226 | """
227 | lat = grid_forecast['lat'].isel(index=0).lat.values
228 | lon = grid_forecast['lon'].isel(index=0).lon.values
229 | issue_time = grid_forecast['issue_time'].values
230 | grid_forecast = grid_forecast.drop_vars(['lat', 'lon', 'index', 'var_name'])
231 | grid_forecast = grid_forecast.rename({'y': 'lat', 'x': 'lon', 'index': 'issue_time'})
232 | grid_forecast = grid_forecast.assign_coords(lat=lat, lon=lon, issue_time=issue_time)
233 | grid_forecast = grid_forecast.squeeze(dim='var_name')
234 |
235 | grid_forecast = grid_forecast.rename({info.fc_var_name: 'fc'})
236 | grid_forecast = select_forecasts(grid_forecast, info)
237 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata)
238 | return interp_forecast
239 |
240 |
241 | def reformat_grid_standard(grid_forecast: xr.Dataset, info: ForecastInfo) -> \
242 | xr.Dataset:
243 | """
244 | Grid format following standard schema
245 | input nc|zarr file dims must be: lead_time, issue_time, lat, lon
246 |
247 | Parameters
248 | ----------
249 | grid_forecast: xarray Dataset: grid forecast data
250 |     info: ForecastInfo: forecast info
251 |
252 | Returns
253 | -------
254 | xr.Dataset: formatted forecast data
255 | """
256 | fc_var_name = info.fc_var_name
257 | grid_forecast = grid_forecast.rename({fc_var_name: 'fc'})
258 | grid_forecast = select_forecasts(grid_forecast, info)
259 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata)
260 | return interp_forecast
261 |
262 |
263 | def reformat_point_standard(point_forecast: xr.Dataset, info: ForecastInfo) \
264 | -> xr.Dataset:
265 | """
266 | Standard point forecast format
267 | input nc|zarr file dims must be: lead_time, issue_time, station
268 |
269 | Parameters
270 | ----------
271 |     point_forecast: xarray Dataset: point forecast data
272 |     info: ForecastInfo: forecast info
273 |
274 | Returns
275 | -------
276 | xr.Dataset: formatted forecast data
277 | """
278 | fc_var_name = info.fc_var_name
279 | point_forecast = point_forecast.rename({fc_var_name: 'fc'})
280 | point_forecast = select_forecasts(point_forecast, info)
281 | return point_forecast[['fc']]
282 |
--------------------------------------------------------------------------------
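The core of the reformat functions above is `convert_grid_to_point`: wrap longitudes into [-180, 180), make them monotonically increasing, then bilinearly interpolate onto the station coordinates. A self-contained sketch of those same steps on synthetic data (grid resolution, station names and values are made up; xarray's `interp` relies on scipy):

import numpy as np
import pandas as pd
import xarray as xr

# Toy grid forecast on a 0-360 longitude grid with the dims convert_grid_to_point expects
grid = xr.Dataset(
    {"fc": (("issue_time", "lead_time", "lat", "lon"),
            np.random.rand(1, 2, 5, 8).astype(np.float32))},
    coords={
        "issue_time": pd.to_datetime(["2023-01-01"]),
        "lead_time": pd.to_timedelta([0, 6], unit="h"),
        "lat": np.linspace(-10, 10, 5),
        "lon": np.arange(0.0, 360.0, 45.0),
    },
)
stations = pd.DataFrame({"station": ["A", "B"], "lat": [1.5, -3.2], "lon": [100.0, -120.0]})

# Wrap longitudes to [-180, 180) and roll so they increase monotonically
grid = grid.assign_coords(lon=[lon if lon < 180 else lon - 360 for lon in grid["lon"].values])
if grid["lon"].argmin().values != 0:
    grid = grid.roll(lon=int(grid["lon"].argmin().values), roll_coords=True)

# Bilinear interpolation onto the station locations
targets = xr.Dataset.from_dataframe(stations[["station", "lat", "lon"]].set_index("station"))
point_fc = grid.interp(lon=targets["lon"], lat=targets["lat"], method="linear")
print(point_fc["fc"].dims)  # the station dimension replaces the lat/lon grid dims
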
/evaluation/metric_catalog.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import numpy as np
4 | import xarray as xr
5 |
6 |
7 | def _get_mean_dims(data, group_dim):
8 | """
9 | Get the dimensions over which to compute mean.
10 | """
11 | return [dim for dim in data.dims if dim != group_dim]
12 |
13 |
14 | def get_metric_func(settings):
15 | """
16 | Find the appropriate metric function, specified by 'settings', and
17 | customize it according to the rest of the 'settings'.
18 |
19 |     All metric functions expect two arguments:
20 |     a dataset with 3 variables, 'obs', 'fc' and 'delta',
21 |     and with 3 dimensions: 'station', 'issue_time' and 'lead_time';
22 |     and a group dimension, which is one of the dimensions of the dataset.
23 |
24 | They return a dataset with metric values averaged over all dims
25 | except the group dim.
26 | """
27 | name = settings['method']
28 | if name == 'rmse':
29 | return rmse
30 | if name == 'count':
31 | return count
32 | if name == 'mean_error':
33 | return mean_error
34 | if name == 'mae':
35 | return mae
36 | if name == 'sde':
37 | return sde
38 | if name == 'min_record_error':
39 | return min_record_error
40 | if name == 'max_record_error':
41 | return max_record_error
42 | if name == 'step_function':
43 | th = settings['threshold']
44 | invert = settings.get('invert', False)
45 | return partial(step_function, threshold=th, invert=invert)
46 | if name == 'accuracy':
47 | th = settings['threshold']
48 | return partial(accuracy, threshold=th)
49 | if name == 'error':
50 | th = settings['threshold']
51 | return partial(error, threshold=th)
52 | if name == 'f1_thresholded':
53 | th = settings.get('threshold', None)
54 | return partial(f1_thresholded, threshold=th)
55 | if name == 'f1_class_averaged':
56 | return f1_class_averaged
57 | if name == 'threat_score':
58 | th = settings.get('threshold', None)
59 | equitable = settings.get('equitable', False)
60 | return partial(ts_thresholded, threshold=th, equitable=equitable)
61 | if name == 'reliability':
62 | b = settings.get('bins', None)
63 | return partial(reliability, bins=b)
64 | if name == 'brier':
65 | return mse
66 | if name == 'pod':
67 | th = settings.get('threshold')
68 | return partial(pod, threshold=th)
69 | if name == 'far':
70 | th = settings.get('threshold')
71 | return partial(far, threshold=th)
72 | if name == 'csi':
73 | th = settings.get('threshold')
74 | return partial(csi, threshold=th)
75 | raise ValueError(f'Unknown metric: {name}')
76 |
77 |
78 | def rmse(data: xr.Dataset, group_dim: str) -> xr.DataArray:
79 | src = pow(data['delta'], 2)
80 | return np.sqrt(src.mean(_get_mean_dims(src, group_dim))).rename('metric')
81 |
82 |
83 | def mse(data: xr.Dataset, group_dim: str) -> xr.DataArray:
84 | src = pow(data['delta'], 2)
85 | return src.mean(_get_mean_dims(src, group_dim)).rename('metric')
86 |
87 |
88 | def accuracy(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
89 | src = abs(data['delta']) < threshold
90 | cnt = count(data, group_dim)
91 | # Weight the average to ignore missing data
92 | r = src.sum(_get_mean_dims(src, group_dim))
93 | r = r.rename('metric')
94 | r /= cnt
95 | return r
96 |
97 |
98 | def error(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
99 | src = abs(data['delta']) > threshold
100 | cnt = count(data, group_dim)
101 | # Weight the average to ignore missing data
102 | r = src.sum(_get_mean_dims(src, group_dim))
103 | r = r.rename('metric')
104 | r /= cnt
105 | return r
106 |
107 |
108 | def mean_error(data: xr.Dataset, group_dim: str) -> xr.DataArray:
109 | src = data['delta']
110 | me = src.mean(_get_mean_dims(src, group_dim))
111 | return me.rename('metric')
112 |
113 |
114 | def mae(data: xr.Dataset, group_dim: str) -> xr.DataArray:
115 | src = abs(data['delta'])
116 | src = src.mean(_get_mean_dims(src, group_dim))
117 | return src.rename('metric')
118 |
119 |
120 | def sde(data: xr.Dataset, group_dim: str) -> xr.DataArray:
121 | src = data['delta']
122 | me_all = src.mean()
123 | src = pow((src-me_all), 2)
124 | src = src.mean(_get_mean_dims(src, group_dim))
125 | return np.sqrt(src).rename('metric')
126 |
127 |
128 | def min_record_error(data: xr.Dataset, group_dim: str) -> xr.DataArray:
129 | src = data['delta']
130 | mre = src.min(_get_mean_dims(src, group_dim))
131 | return mre.rename('metric')
132 |
133 |
134 | def max_record_error(data: xr.Dataset, group_dim: str) -> xr.DataArray:
135 | src = data['delta']
136 | mre = src.max(_get_mean_dims(src, group_dim))
137 | return mre.rename('metric')
138 |
139 |
140 | def step_function(data: xr.Dataset, group_dim: str, threshold: float, invert: bool) -> xr.DataArray:
141 | src = abs(data['delta']) < threshold
142 | cnt = count(data, group_dim)
143 |
144 | # Weight the average to ignore missing data
145 | r = src.sum(_get_mean_dims(src, group_dim))
146 | r = r.rename('metric')
147 | r /= cnt
148 | if invert:
149 | r = 1.0 - r
150 | return r
151 |
152 |
153 | def count(data: xr.Dataset, group_dim: str) -> xr.DataArray:
154 | src = ~data['delta'].isnull()
155 | cnt = src.sum(_get_mean_dims(src, group_dim))
156 | return cnt.rename('metric')
157 |
158 |
159 | def _true_positives_and_false_negatives(data, labels):
160 | """
161 | Compute the true positives and false negatives for a given set of labels. Filter keeps only where truth is True and
162 | neither prediction nor truth is NaN.
163 | """
164 | r_list = []
165 | for label in labels:
166 | recall_filter = xr.where(np.logical_and(~np.isnan(data['fc']), data['obs'] == label), 1, np.nan)
167 | r_list.append((data['fc'] == label) * recall_filter)
168 | return xr.concat(r_list, dim='category').rename('metric')
169 |
170 |
171 | def _true_positives_and_false_positives(data, labels):
172 | """
173 | Compute the true positives and false positives for a given set of labels. Filter keeps only where prediction is
174 | True and neither prediction nor truth is NaN.
175 | """
176 | p_list = []
177 | for label in labels:
178 | precision_filter = xr.where(np.logical_and(~np.isnan(data['obs']), data['fc'] == label), 1, np.nan)
179 | p_list.append((data['obs'] == label) * precision_filter)
180 | return xr.concat(p_list, dim='category').rename('metric')
181 |
182 |
183 | def _harmonic_mean(a, b):
184 |     return 2 * (a * b) / (a + b)  # harmonic mean of a and b, as used in the F1 score
185 |
186 |
187 | def _count_binary_matches(data):
188 | data['tp'] = (data['obs'] == 1) & (data['fc'] == 1)
189 | data['fp'] = (data['obs'] == 0) & (data['fc'] == 1)
190 | data['tn'] = (data['obs'] == 0) & (data['fc'] == 0)
191 | data['fn'] = (data['obs'] == 1) & (data['fc'] == 0)
192 | return data
193 |
194 |
195 | def _threshold_digitize(data, threshold):
196 | new_data = xr.Dataset()
197 | new_data['fc'] = xr.apply_ufunc(np.digitize, data['fc'], [threshold], dask='allowed')
198 | new_data['obs'] = xr.apply_ufunc(np.digitize, data['obs'], [threshold], dask='allowed')
199 |
200 | # Re-assign NaN where appropriate, since they are converted to 1 by digitize
201 | new_data['fc'] = xr.where(np.isnan(data['fc']), np.nan, new_data['fc'])
202 | new_data['obs'] = xr.where(np.isnan(data['obs']), np.nan, new_data['obs'])
203 |
204 | return new_data
205 |
206 |
207 | def f1_thresholded(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
208 | """
209 | Compute the F1 score for a True/False classification based on values meeting or exceeding a defined threshold.
210 | :param data: Dataset
211 | :param group_dim: str: group dimension
212 | :param threshold: float threshold value
213 | :return: DataArray of scores
214 | """
215 | new_data = _threshold_digitize(data, threshold)
216 |
217 | # Precision/recall computation
218 | r = _true_positives_and_false_negatives(new_data, [1])
219 | p = _true_positives_and_false_positives(new_data, [1])
220 |
221 | # Compute and return F1
222 |     f1 = _harmonic_mean(
223 | p.mean(_get_mean_dims(p, group_dim)),
224 | r.mean(_get_mean_dims(r, group_dim))
225 | ).rename('metric')
226 |
227 | return f1.mean('category') # mean over the single category
228 |
229 |
230 | def ts_thresholded(data: xr.Dataset, group_dim: str, threshold: float, equitable: bool) -> xr.DataArray:
231 | """
232 | Compute the threat score for a True/False classification based on values meeting or exceeding a defined threshold.
233 | :param data: Dataset
234 | :param group_dim: str: group dimension
235 | :param threshold: float threshold value
236 | :param equitable: bool: use ETS formulation
237 | :return: DataArray of scores
238 | """
239 | def _count_equitable_random_chance(ds):
240 | n = ds['tp'] + ds['fp'] + ds['fn'] + ds['tn']
241 | correction = (ds['tp'] + ds['fp']) * (ds['tp'] + ds['fn']) / n
242 | return correction
243 |
244 | new_data = _threshold_digitize(data, threshold)
245 | new_data = _count_binary_matches(new_data)
246 |
247 | new_data = new_data.sum(_get_mean_dims(new_data, group_dim))
248 | ar = _count_equitable_random_chance(new_data) if equitable else 0
249 | ts = (new_data['tp'] - ar) / (new_data['tp'] + new_data['fp'] + new_data['fn'] - ar)
250 | return ts.rename('metric')
251 |
252 |
253 | def f1_class_averaged(data: xr.Dataset, group_dim: str) -> xr.DataArray:
254 | """
255 | Compute the F1 score as one-vs-rest for each unique category in the data. Averages the scores for each class
256 | equally.
257 | :param data: Dataset
258 | :param group_dim: str: group dimension
259 | :return: DataArray of scores
260 | """
261 | # F1 score must be aggregated in at least one dimension
262 | labels = np.unique(data['obs'].values[~np.isnan(data['obs'].values)])
263 |
264 | # Per-label metrics
265 | r = _true_positives_and_false_negatives(data, labels)
266 | p = _true_positives_and_false_positives(data, labels)
267 |
268 | # Compute and return F1
269 |     f1 = _harmonic_mean(
270 | p.mean(_get_mean_dims(p, group_dim)),
271 | r.mean(_get_mean_dims(r, group_dim))
272 | ).rename('metric')
273 |
274 | return f1.mean('category')
275 |
276 |
277 | def reliability(data: xr.Dataset, group_dim: str, bins: list) -> xr.DataArray:
278 | relative_freq = [0 for i in range(len(bins))]
279 | mean_predicted_value = [0 for i in range(len(bins))]
280 |
281 | truth = data['obs']
282 | # replace nan with 0 in truth
283 | truth = truth.where(truth == 1, 0)
284 | prediction = data['fc']
285 |
286 | for b, bin_ in enumerate(bins):
287 | # replace predicted prob with 0 / 1 for given range in bin
288 | pred = prediction.where((prediction <= bin_[1]) & (bin_[0] < prediction), 0)
289 | # sum up to calculate mean later
290 | pred_sum = float(np.sum(pred))
291 | pred = pred.where(pred == 0, 1)
292 |
293 | # how many days have prediction fall in the given bin
294 | n_prediction_in_bin = np.count_nonzero(pred)
295 |
296 | # how many days rain in prediction set
297 | correct_pred = np.logical_and(pred, truth)
298 | n_corr_pred = np.count_nonzero(correct_pred)
299 |
300 | if n_prediction_in_bin != 0:
301 | mean_predicted_value[b] = 100 * pred_sum / n_prediction_in_bin
302 | relative_freq[b] = 100 * n_corr_pred / n_prediction_in_bin
303 | else:
304 | mean_predicted_value[b] = 0
305 | relative_freq[b] = 0
306 |
307 | res = xr.Dataset({'metric': ([group_dim], relative_freq)}, coords={group_dim: mean_predicted_value})
308 | res = res.isel({group_dim: ~res.get_index(group_dim).duplicated()})
309 |
310 | return res['metric']
311 |
312 |
313 | def csi(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
314 | return ts_thresholded(data, group_dim, threshold, equitable=False)
315 |
316 |
317 | def pod(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
318 | new_data = _threshold_digitize(data, threshold)
319 | r = _true_positives_and_false_negatives(new_data, [1])
320 | return r.mean(_get_mean_dims(r, group_dim)).rename('metric')
321 |
322 |
323 | def far(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray:
324 | new_data = _threshold_digitize(data, threshold)
325 | r = _true_positives_and_false_positives(new_data, [1])
326 | return 1 - r.mean(_get_mean_dims(r, group_dim)).rename('metric')
327 |
--------------------------------------------------------------------------------
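All of the metric functions above share the same input contract: an xarray Dataset with variables 'obs', 'fc' and 'delta' over dims (station, issue_time, lead_time), reduced over every dimension except the group dimension. A minimal synthetic sketch that mirrors the `rmse` computation (random numbers, purely illustrative):

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
shape = (3, 4, 5)  # station, issue_time, lead_time
obs = rng.normal(15.0, 5.0, size=shape)
fc = obs + rng.normal(0.0, 1.5, size=shape)
data = xr.Dataset(
    {
        "obs": (("station", "issue_time", "lead_time"), obs),
        "fc": (("station", "issue_time", "lead_time"), fc),
        "delta": (("station", "issue_time", "lead_time"), fc - obs),
    }
)

# Same pattern as metric_catalog.rmse: average the squared error over every
# dimension except the group dimension, then take the square root.
group_dim = "lead_time"
mean_dims = [d for d in data.dims if d != group_dim]
rmse_by_lead = np.sqrt((data["delta"] ** 2).mean(mean_dims)).rename("metric")
print(rmse_by_lead.values)  # one value per lead time
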
/evaluation/obs_reformat_catalog.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, Sequence, Union
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import xarray as xr
7 | try:
8 | from metpy.calc import specific_humidity_from_dewpoint
9 | from metpy.units import units
10 | except ImportError:
11 | specific_humidity_from_dewpoint = None
12 | units = None
13 |
14 | from .utils import convert_to_binary
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def _convert_time_step(dt): # pylint: disable=invalid-name
20 | return pd.Timedelta(hours=dt) if isinstance(dt, (float, int)) else pd.Timedelta(dt)
21 |
22 |
23 | def reformat_and_filter_obs(obs: xr.Dataset, obs_var_name: str, interp_station_path: Optional[str],
24 | precip_threshold: Optional[float] = None) -> xr.Dataset:
25 | """
26 | Reformat and filter the observation data, and return the DataFrame with required fields.
27 | Required fields: station, valid_time, obs
28 | Filter removes the stations that are not in the interp_station list
29 |
30 | Parameters
31 | ----------
32 | obs: xarray Dataset: observation data
33 | obs_var_name: str: observation variable name
34 | interp_station_path: str: path to the interpolation station list
35 | precip_threshold: float: threshold for precipitation
36 |
37 | Returns
38 | -------
39 | xr.Dataset: formatted observation data
40 | """
41 | if interp_station_path is not None:
42 | interp_station = get_interp_station_list(interp_station_path)
43 | intersect_station = np.intersect1d(interp_station['station'].values, obs['station'].values)
44 |         logger.debug(f"intersect_station count: {len(intersect_station)}, "
45 |                      f"obs_station count: {len(obs['station'].values)}, "
46 |                      f"interp_station count: {len(interp_station['station'].values)}")
47 | obs = obs.sel(station=intersect_station)
48 | if 'valid_time' not in obs.dims:
49 | obs = obs.rename({'time': 'valid_time'})
50 |
51 | if obs_var_name in ['u10', 'v10']:
52 | obs = calculate_u_v(obs, ws_name='ws', wd_name='wd', u_name='u10', v_name='v10')
53 | elif obs_var_name == 'q':
54 | obs['q'] = xr.apply_ufunc(calculate_q, obs['td'])
55 | elif obs_var_name == 'pp':
56 | precip_var = 'ra' if 'ra' in obs.data_vars else 'ra1'
57 | logger.info(f"User requested precipitation probability from obs. Using variable '{precip_var}' with "
58 | f"threshold of 0.1 mm/hr.")
59 | obs['pp'] = convert_to_binary(obs[precip_var], 0.1)
60 |
61 | if precip_threshold is not None:
62 | obs[obs_var_name] = convert_to_binary(obs[obs_var_name], precip_threshold)
63 |
64 | return obs[[obs_var_name]].rename({obs_var_name: 'obs'})
65 |
66 |
67 | def obs_to_verification(
68 | obs: Union[xr.Dataset, xr.DataArray],
69 | max_lead: Union[pd.Timedelta, int] = 168,
70 | steps: Optional[Sequence[pd.Timestamp]] = None,
71 | issue_times: Optional[Sequence[pd.Timestamp]] = None
72 | ) -> Union[xr.Dataset, xr.DataArray]:
73 | """
74 | Convert a Dataset or DataArray of continuous time-series observations
75 | according to the obs data spec into the forecast data spec for direct
76 | comparison to forecasts.
77 |
78 | Parameters
79 | ----------
80 | obs: xarray Dataset or DataArray of observation data
81 | max_lead: maximum lead time for verification dataset. If int, interpreted
82 | as hours.
83 | steps: optional sequence of lead times to retain
84 | issue_times: issue times for the forecast result. If
85 | not specified, uses all available obs times.
86 | """
87 | issue_dim = 'issue_time'
88 | lead_dim = 'lead_time'
89 | time_dim = 'valid_time'
90 | max_lead = _convert_time_step(max_lead)
91 | if issue_times is None:
92 | issue_times = obs[time_dim].values
93 | obs_series = []
94 | for issue in issue_times:
95 | try:
96 | obs_series.append(
97 | obs.sel(**{time_dim: slice(issue, issue + max_lead)}).rename({time_dim: lead_dim})
98 | )
99 | obs_series[-1] = obs_series[-1].assign_coords(
100 | **{issue_dim: [issue], lead_dim: obs_series[-1][lead_dim] - issue})
101 | except Exception as e: # pylint: disable=broad-exception-caught
102 |             logger.warning(f'Failed to select issue time {issue} due to {e}')
103 | continue
104 | verification_ds = xr.concat(obs_series, dim=issue_dim)
105 | if steps is not None:
106 | steps = [_convert_time_step(s) for s in steps]
107 | verification_ds = verification_ds.sel(**{lead_dim: steps})
108 | verification_ds = verification_ds.assign_coords({lead_dim: verification_ds[lead_dim] / np.timedelta64(1, 'h')})
109 |
110 | return verification_ds
111 |
112 |
113 | def calculate_q(td):
114 | if specific_humidity_from_dewpoint is None:
115 | raise ImportError('metpy is not installed, specific_humidity_from_dewpoint cannot be calculated')
116 | q = specific_humidity_from_dewpoint(1013.25 * units.hPa, td * units.degC).magnitude
117 | return q
118 |
119 |
120 | def calculate_u_v(data, ws_name='ws', wd_name='wd', u_name='u10', v_name='v10'):
121 | data[u_name] = data[ws_name] * np.sin(data[wd_name] / 180 * np.pi - np.pi)
122 | data[v_name] = data[ws_name] * np.cos(data[wd_name] / 180 * np.pi - np.pi)
123 | return data
124 |
125 |
126 | def get_interp_station_list(interp_station_path: str) -> pd.DataFrame:
127 | """
128 | Read the station metadata from the interpolation station list file, and return a dataset for interpolation.
129 |
130 | Parameters
131 | ----------
132 | interp_station_path: str: path to the interpolation station list
133 |
134 | Returns
135 | -------
136 | pd.DataFrame: station metadata with columns station, lat, lon
137 | """
138 | interp_station = pd.read_csv(interp_station_path)
139 | interp_station = interp_station.rename({c: c.lower() for c in interp_station.columns}, axis=1)
140 | interp_station['lon'] = interp_station['lon'].apply(lambda lon: lon if (lon < 180) else (lon - 360))
141 | if 'station' not in interp_station.columns:
142 | interp_station['station'] = interp_station['id']
143 | return interp_station
144 |
--------------------------------------------------------------------------------
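A quick worked check of the wind decomposition in `calculate_u_v` (meteorological convention: `wd` is the direction the wind blows from, in degrees; the wind speeds here are arbitrary):

import numpy as np

# Wind from due north (wd = 0) should give u ~ 0 and v = -ws (flow towards the south)
ws, wd = 10.0, 0.0
u = ws * np.sin(wd / 180 * np.pi - np.pi)
v = ws * np.cos(wd / 180 * np.pi - np.pi)
print(f"u={u:.2f}, v={v:.2f}")  # u=-0.00, v=-10.00 (u is zero up to floating point)

# Wind from due west (wd = 270) should give u = +ws (flow towards the east) and v ~ 0
ws, wd = 10.0, 270.0
u = ws * np.sin(wd / 180 * np.pi - np.pi)
v = ws * np.cos(wd / 180 * np.pi - np.pi)
print(f"u={u:.2f}, v={v:.2f}")  # u=10.00, v=0.00
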
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import logging
3 | import os
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import xarray as xr
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def configure_logging(verbose=1):
15 | verbose_levels = {
16 | 0: logging.WARNING,
17 | 1: logging.INFO,
18 | 2: logging.DEBUG,
19 | 3: logging.NOTSET
20 | }
21 | if verbose not in verbose_levels:
22 | verbose = 1
23 | logger.setLevel(verbose_levels[verbose])
24 | handler = logging.StreamHandler()
25 | handler.setFormatter(logging.Formatter(
26 | "[%(asctime)s] [PID=%(process)d] "
27 | "[%(levelname)s %(filename)s:%(lineno)d] %(message)s"))
28 | handler.setLevel(verbose_levels[verbose])
29 | logger.addHandler(handler)
30 |
31 |
32 | @dataclass
33 | class ForecastInfo:
34 | path: str
35 | forecast_name: str
36 | fc_var_name: str
37 | reformat_func: str
38 | file_type: str
39 | station_metadata_path: str
40 | interp_station_path: str
41 | output_directory: str
42 | start_date: Optional[pd.Timestamp] = None
43 | end_date: Optional[pd.Timestamp] = None
44 | issue_time_freq: Optional[str] = None
45 | start_lead: Optional[int] = None
46 | end_lead: Optional[int] = None
47 | convert_temperature: Optional[bool] = False
48 | convert_pressure: Optional[bool] = False
49 | convert_cloud: Optional[bool] = False
50 | precip_proba_threshold: Optional[float] = None
51 | metadata: Optional[pd.DataFrame] = None
52 | cache_path: Optional[str] = None
53 |
54 |
55 | @dataclass
56 | class ForecastData:
57 | info: ForecastInfo
58 | forecast: Optional[xr.Dataset] = None
59 | merge_data: Optional[xr.Dataset] = None
60 |
61 |
62 | @dataclass
63 | class MetricData:
64 | info: ForecastInfo
65 | metric_data: Optional[xr.Dataset] = None
66 |
67 |
68 | def get_metric_multiple_stations(files):
69 | data = {}
70 | files = files.split(',')
71 | for f in files:
72 | if os.path.exists(f):
73 | try:
74 | key = os.path.basename(f).replace(".csv", "")
75 | data[key] = [str(id) for id in pd.read_csv(f)['Station'].tolist()]
76 | except Exception:
77 | logger.error(f"Error opening {f}!")
78 | raise
79 | else:
80 |             raise FileNotFoundError(f'File {f} does not exist!')
81 | return data
82 |
83 |
84 | def generate_forecast_cache_path(info: ForecastInfo):
85 | if info.cache_path is not None:
86 | return info.cache_path
87 | file_name = Path(info.path).stem
88 | forecast_name = info.forecast_name
89 | fc_var_name = info.fc_var_name
90 | reformat_func = info.reformat_func
91 | interp_station = Path(info.interp_station_path).stem if info.interp_station_path is not None else 'default'
92 | cache_directory = os.path.join(info.output_directory, 'cache')
93 | os.makedirs(cache_directory, exist_ok=True)
94 | cache_file_name = "##".join([file_name, forecast_name, fc_var_name, reformat_func, interp_station, 'cache'])
95 | info.cache_path = os.path.join(cache_directory, cache_file_name)
96 | return info.cache_path
97 |
98 |
99 | def cache_reformat_forecast(forecast_ds, cache_path):
100 | logger.info(f"saving forecast to cache at {cache_path}")
101 | forecast_ds.to_zarr(cache_path, mode='w')
102 |
103 |
104 | def load_reformat_forecast(cache_path):
105 | forecast_ds = xr.open_zarr(cache_path)
106 | return forecast_ds
107 |
108 |
109 | def get_ideal_xticks(min_lead, max_lead, tick_count=8):
110 | """
111 | Pick the best interval for the x axis (hours) that optimizes the number of ticks
112 | """
113 | candidate_intervals = [1, 3, 6, 12, 24]
114 | tick_counts = []
115 | for interval in candidate_intervals:
116 | num_ticks = (max_lead - min_lead) / interval
117 | tick_counts.append(abs(num_ticks - tick_count))
118 | best_interval = candidate_intervals[tick_counts.index(min(tick_counts))]
119 | return np.arange(min_lead, max_lead + best_interval, best_interval)
120 |
121 |
122 | def convert_to_binary(da, threshold):
123 | result = xr.where(da >= threshold, np.float32(1.0), np.float32(0.0))
124 | result = result.where(~np.isnan(da), np.nan)
125 | return result
126 |
--------------------------------------------------------------------------------
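A small illustration of `convert_to_binary`: values at or above the threshold become 1.0, values below become 0.0, and missing values stay NaN (the precipitation values and the 0.1 mm/hr threshold here are only examples):

import numpy as np
import xarray as xr

ra = xr.DataArray([0.0, 0.05, 0.3, np.nan, 2.0], dims="valid_time")
threshold = 0.1
# Same two steps as convert_to_binary: threshold to {0, 1}, then restore the NaNs
binary = xr.where(ra >= threshold, np.float32(1.0), np.float32(0.0))
binary = binary.where(~np.isnan(ra), np.nan)
print(binary.values)  # [ 0.  0.  1. nan  1.]
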
/metric_config.yml:
--------------------------------------------------------------------------------
1 | temperature:
2 | base_plot_setting:
3 | title: Temperature
4 | xlabel:
5 | lead_time: Lead Hour (H)
6 | issue_time: Issue Time (UTC)
7 | valid_time: Valid Time (UTC)
8 |
9 | metrics:
10 | RMSE:
11 | method: rmse
12 |
13 | MAE:
14 | method: mae
15 |
16 | ACCURACY_1.7:
17 | method: accuracy
18 | threshold: 1.7
19 |
20 | ERROR_5.6:
21 | method: error
22 | threshold: 5.6
23 |
24 | cloud:
25 | base_plot_setting:
26 | title: Cloud
27 | xlabel:
28 | lead_time: Lead Hour (H)
29 | issue_time: Issue Time (UTC)
30 | valid_time: Valid Time (UTC)
31 |
32 | metrics:
33 | RMSE:
34 | method: rmse
35 |
36 | MAE:
37 | method: mae
38 |
39 | ACCURACY_2:
40 | method: accuracy
41 | threshold: 2
42 |
43 | ERROR_5:
44 |       method: error
45 | threshold: 5
46 |
47 | ETS_1p5:
48 | method: threat_score
49 | threshold: 1.5
50 | equitable: True
51 |
52 | ETS_6p5:
53 | method: threat_score
54 | threshold: 6.5
55 | equitable: True
56 |
57 | wind:
58 | base_plot_setting:
59 |     title: Wind
60 | xlabel:
61 | lead_time: Lead Hour (H)
62 | issue_time: Issue Time (UTC)
63 | valid_time: Valid Time (UTC)
64 |
65 | metrics:
66 | RMSE:
67 | method: rmse
68 |
69 | MAE:
70 | method: mae
71 |
72 | ACCURACY_1:
73 | method: accuracy
74 | threshold: 1
75 |
76 | ERROR_3:
77 | method: error
78 | threshold: 3
79 |
80 | specific_humidity:
81 | base_plot_setting:
82 |     title: Specific humidity
83 | xlabel:
84 | lead_time: Lead Hour (H)
85 | issue_time: Issue Time (UTC)
86 | valid_time: Valid Time (UTC)
87 |
88 | metrics:
89 | RMSE:
90 | method: rmse
91 |
92 | MAE:
93 | method: mae
94 |
95 | precipitation:
96 | base_plot_setting:
97 |     title: Precipitation (mm)
98 | xlabel:
99 | lead_time: Lead Hour (H)
100 | issue_time: Issue Time (UTC)
101 | valid_time: Valid Time (UTC)
102 |
103 | metrics:
104 | RMSE:
105 | method: rmse
106 |
107 | ETS_0p1:
108 | method: threat_score
109 | threshold: 0.1
110 | equitable: True
111 |
112 | ETS_1:
113 | method: threat_score
114 | threshold: 1.0
115 | equitable: True
116 |
117 | POD_1:
118 | method: pod
119 | threshold: 1.0
120 |
121 | FAR_1:
122 | method: far
123 | threshold: 1.0
124 |
125 | precip_proba:
126 | base_plot_setting:
127 |     title: Precipitation probability (%)
128 | xlabel:
129 | lead_time: Lead Hour (H)
130 | issue_time: Issue Time (UTC)
131 | valid_time: Valid Time (UTC)
132 |
133 | metrics:
134 | Brier_Score:
135 | method: brier
136 |
137 | Occurrence_Frequency(%):
138 | method: reliability
139 | bins: [[0, 0.05], [0.05, 0.15], [0.15, 0.25], [0.25, 0.35], [0.35, 0.45], [0.45, 0.55], [0.55, 0.65], [0.65, 0.75], [0.75, 0.85], [0.85, 0.95], [0.95, 1]]
140 | plot_setting:
141 | xlabel:
142 | lead_time: Bins of probability(%)
143 |
144 | ETS_0.1:
145 | method: threat_score
146 | threshold: 0.1
147 | equitable: True
148 |
149 | ETS_0.35:
150 | method: threat_score
151 | threshold: 0.35
152 | equitable: True
153 |
154 | ETS_0.4:
155 | method: threat_score
156 | threshold: 0.4
157 | equitable: True
158 |
159 | ETS_0.5:
160 | method: threat_score
161 | threshold: 0.5
162 | equitable: True
163 |
164 | precip_binary:
165 | base_plot_setting:
166 |     title: Precipitation binary
167 | xlabel:
168 | lead_time: Lead Hour (H)
169 | issue_time: Issue Time (UTC)
170 | valid_time: Valid Time (UTC)
171 |
172 | metrics:
173 | ETS:
174 | method: threat_score
175 | threshold: 0.5
176 | equitable: True
177 |
178 | POD:
179 | method: pod
180 | threshold: 0.5
181 |
182 | FAR:
183 | method: far
184 | threshold: 0.5
185 |
186 | Accuracy:
187 | method: accuracy
188 | threshold: 0.1
189 |
190 | pressure:
191 | base_plot_setting:
192 | title: Mean sea-level pressure
193 | xlabel:
194 | lead_time: Lead Hour (H)
195 | issue_time: Issue Time (UTC)
196 | valid_time: Valid Time (UTC)
197 |
198 | metrics:
199 | RMSE:
200 | method: rmse
201 |
202 | MAE:
203 | method: mae
204 |
205 | ACCURACY_2:
206 | method: accuracy
207 | threshold: 2
208 |
209 | ERROR_5:
210 | method: error
211 | threshold: 5
--------------------------------------------------------------------------------
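Each metric entry above is just a settings mapping that `get_metric_func` in `evaluation/metric_catalog.py` turns into a callable: the `method` key selects the function and the remaining keys parameterise it. A minimal sketch (PyYAML assumed to be available; the in-repo import is commented out):

import yaml

snippet = """
ETS_1:
  method: threat_score
  threshold: 1.0
  equitable: True
"""
settings = yaml.safe_load(snippet)["ETS_1"]
print(settings)  # {'method': 'threat_score', 'threshold': 1.0, 'equitable': True}

# Inside the repository this becomes a metric callable, e.g.:
# from evaluation.metric_catalog import get_metric_func
# metric_fn = get_metric_func(settings)  # partial(ts_thresholded, threshold=1.0, equitable=True)
# scores = metric_fn(data, group_dim="lead_time")
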
/pylintrc:
--------------------------------------------------------------------------------
1 | [MAIN]
2 |
3 | max-line-length=120
4 |
5 | [MESSAGES CONTROL]
6 |
7 | disable=
8 | missing-module-docstring,
9 | missing-class-docstring,
10 | missing-function-docstring,
11 | missing-docstring,
12 | invalid-name,
13 | import-error,
14 | logging-fstring-interpolation,
15 | too-many-arguments,
16 | too-many-locals,
17 | too-many-branches,
18 | too-many-statements,
19 | too-many-return-statements,
20 | too-many-instance-attributes,
21 | too-many-positional-arguments,
22 | duplicate-code
23 |
--------------------------------------------------------------------------------
/quality_control/README.md:
--------------------------------------------------------------------------------
1 | The quality control code in this directory is written in Python and is used by WeatherReal to download, post-process, and quality-control ISD data. Its modules can also be applied to observation data from other sources. The directory includes four launch scripts:
2 |
3 | 1. `download_ISD.py`, used to download ISD data from the NCEI server;
4 | 2. `raw_ISD_to_hourly.py`, used to convert ISD data into hourly data;
5 | 3. `station_merging.py`, used to merge data from duplicate stations;
6 | 4. `quality_control.py`, used to perform quality control on hourly data.
7 |
8 | For the specific workflow, please refer to the WeatherReal paper.
--------------------------------------------------------------------------------
/quality_control/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/quality_control/__init__.py
--------------------------------------------------------------------------------
/quality_control/algo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/quality_control/algo/__init__.py
--------------------------------------------------------------------------------
/quality_control/algo/cluster.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.cluster import DBSCAN
3 | from .utils import intra_station_check, quality_control_statistics, get_config
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def _cluster_check(ts, reanalysis, min_samples_ratio, eps_scale, max_std_scale=None, min_num=None):
10 | """
11 | Perform a cluster-based check on time series data compared to reanalysis data.
12 |
13 | This function uses DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
14 | to identify clusters in the difference between the time series and reanalysis data.
15 | It then flags outliers based on cluster membership.
16 |
17 | Parameters:
18 | -----------
19 | ts : array-like
20 | The time series data to be checked.
21 | reanalysis : array-like
22 | The corresponding reanalysis data for comparison.
23 | min_samples_ratio : float
24 | The ratio of minimum samples required to form a cluster in DBSCAN.
25 | eps_scale : float
26 | The scale factor for epsilon in DBSCAN, relative to the standard deviation of reanalysis.
27 | max_std_scale : float, optional
28 | If the ratio of standard deviations between ts and reanalysis is less than max_std_scale,
29 | the check is not performed
30 | min_num : int, optional
31 | Minimum number of valid data points required to perform the check.
32 |
33 | Returns:
34 | --------
35 | flag : np.ndarray
36 | 1D array with the same length as ts, containing flags
37 | """
38 | flag = np.full(len(ts), CONFIG["flag_missing"], dtype=np.int8)
39 | isnan = np.isnan(ts)
40 | flag[~isnan] = CONFIG["flag_normal"]
41 | both_valid = (~isnan) & (~np.isnan(reanalysis))
42 |
43 | if min_num is not None and both_valid.sum() < min_num:
44 | return flag
45 |
46 | if max_std_scale is not None:
47 | if np.std(ts[both_valid]) / np.std(reanalysis[both_valid]) <= max_std_scale:
48 | return flag
49 |
50 | indices = np.argwhere(both_valid).flatten()
51 | values1 = ts[indices]
52 | values2 = reanalysis[indices]
53 |
54 |     cluster = DBSCAN(min_samples=int(indices.size * min_samples_ratio), eps=np.std(values2) * eps_scale)
55 | labels = cluster.fit((values1 - values2).reshape(-1, 1)).labels_
56 |
57 | # If all data points except noise are in the same cluster, just remove the noise
58 | if np.max(labels) <= 0:
59 | indices_outliers = indices[np.argwhere(labels == -1).flatten()]
60 | flag[indices_outliers] = CONFIG["flag_error"]
61 | # If there is more than one cluster, select the cluster nearest to the reanalysis
62 | else:
63 | cluster_labels = np.unique(labels)
64 |         cluster_labels = cluster_labels[cluster_labels >= 0]
65 | best_cluster = 0
66 | min_median = np.inf
67 | for label in cluster_labels:
68 | indices_cluster = indices[np.argwhere(labels == label).flatten()]
69 | median = np.median(ts[indices_cluster] - reanalysis[indices_cluster])
70 | if np.abs(median) < min_median:
71 | best_cluster = label
72 | min_median = np.abs(median)
73 | indices_outliers = indices[np.argwhere(labels != best_cluster).flatten()]
74 | flag[indices_outliers] = CONFIG["flag_error"]
75 |
76 | return flag
77 |
78 |
79 | def run(da, reanalysis, varname):
80 | """
81 | To ensure the accuracy of the parameters in the distributional gap method, a DBSCAN is used first,
82 | so that the median and MAD used for filtering are only calculated by the normal data
83 | """
84 | flag = intra_station_check(
85 | da,
86 | reanalysis,
87 | qc_func=_cluster_check,
88 | input_core_dims=[["time"], ["time"]],
89 | kwargs=CONFIG["cluster"][varname],
90 | )
91 | quality_control_statistics(da, flag)
92 | return flag.rename("cluster")
93 |
--------------------------------------------------------------------------------
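A synthetic, self-contained illustration of the idea behind `_cluster_check`: DBSCAN clusters the obs-minus-reanalysis differences, and points that fall outside the cluster closest to zero difference (or are labelled -1 as noise) are the candidates for flagging. The numbers are made up; `min_samples_ratio=0.1` and `eps_scale=2` mirror the defaults in `config.yaml`:

import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(42)
reanalysis = rng.normal(10.0, 5.0, size=500)
ts = reanalysis + rng.normal(0.0, 0.5, size=500)
ts[450:] += 15.0  # a block of systematically biased observations

diff = (ts - reanalysis).reshape(-1, 1)
labels = DBSCAN(min_samples=int(0.1 * ts.size), eps=2 * np.std(reanalysis)).fit(diff).labels_
print(np.unique(labels, return_counts=True))  # two clusters: ~450 good points, ~50 biased ones
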
/quality_control/algo/config.yaml:
--------------------------------------------------------------------------------
1 | ############################################################
2 | # The flag values for the data quality check
3 | ############################################################
4 |
5 | flag_error: 2
6 | flag_suspect: 1
7 | flag_normal: 0
8 | flag_missing: -1
9 |
10 | ############################################################
11 | # The parameters for the data quality check
12 | ############################################################
13 |
14 | # record extreme check
15 | record:
16 | t:
17 | upper: 57.8
18 | lower: -89.2
19 | td:
20 | upper: 57.8
21 | lower: -100.0
22 | ws:
23 | upper: 113.2
24 | lower: 0
25 | wd:
26 | upper: 360
27 | lower: 0
28 | sp:
29 | upper: 1100.0
30 | lower: 300.0
31 | msl:
32 | upper: 1083.3
33 | lower: 870.0
34 | c:
35 | upper: 8
36 | lower: 0
37 | ra1:
38 | upper: 305.0
39 | lower: 0
40 | ra3:
41 | upper: 915.0
42 | lower: 0
43 | ra6:
44 | upper: 1144.0
45 | lower: 0
46 | ra12:
47 | upper: 1144.0
48 | lower: 0
49 | ra24:
50 | upper: 1825.0
51 | lower: 0
52 |
53 | # Persistence check
54 | persistence:
55 | defaults: &persistence_defaults
56 | min_num: 24
57 | max_window: 72
58 | min_var: 0.1
59 | error_length: 72
60 | t:
61 | <<: *persistence_defaults
62 | td:
63 | <<: *persistence_defaults
64 | sp:
65 | <<: *persistence_defaults
66 | msl:
67 | <<: *persistence_defaults
68 | ws:
69 | <<: *persistence_defaults
70 | exclude_value: 0
71 | wd:
72 | <<: *persistence_defaults
73 | exclude_value: 0
74 | c:
75 | <<: *persistence_defaults
76 | exclude_value:
77 | - 0
78 | - 8
79 | ra1:
80 | <<: *persistence_defaults
81 | exclude_value: 0
82 | ra3:
83 | <<: *persistence_defaults
84 | exclude_value: 0
85 | ra6:
86 | <<: *persistence_defaults
87 | exclude_value: 0
88 | ra12:
89 | <<: *persistence_defaults
90 | exclude_value: 0
91 | ra24:
92 | <<: *persistence_defaults
93 | exclude_value: 0
94 |
95 | # Spike check
96 | spike:
97 | t:
98 | max_change:
99 | - 6
100 | - 8
101 | - 10
102 | td:
103 | max_change:
104 | - 5
105 | - 7
106 | - 9
107 | sp:
108 | max_change:
109 | - 3
110 | - 5
111 | - 7
112 | msl:
113 | max_change:
114 | - 3
115 | - 5
116 | - 7
117 |
118 | # Distributional gap check with ERA5
119 | distribution:
120 | defaults: &distribution_defaults
121 | shift_step: 1
122 | gap_scale: 2
123 | default_mad: 1
124 | suspect_std_scale: 2.72
125 | min_num: 365
126 | t:
127 | <<: *distribution_defaults
128 | td:
129 | <<: *distribution_defaults
130 | sp:
131 | <<: *distribution_defaults
132 | default_mad: 0.5
133 | msl:
134 | <<: *distribution_defaults
135 | default_mad: 0.5
136 |
137 | # Cluster check
138 | cluster:
139 | defaults: &cluster_defaults
140 | min_samples_ratio: 0.1
141 | eps_scale: 2
142 | max_std_scale: 2
143 | min_num: 365
144 | t:
145 | <<: *cluster_defaults
146 |
147 | td:
148 | <<: *cluster_defaults
149 | sp:
150 | <<: *cluster_defaults
151 | msl:
152 | <<: *cluster_defaults
153 |
154 | # Neighbouring station check
155 | neighbouring:
156 | defaults: &neighbouring_defaults
157 | <<: *distribution_defaults
158 | max_dist: 300
159 | max_elev_diff: 500
160 | min_data_overlap: 0.3
161 | t:
162 | <<: *neighbouring_defaults
163 | td:
164 | <<: *neighbouring_defaults
165 | sp:
166 | <<: *neighbouring_defaults
167 | default_mad: 0.5
168 | msl:
169 | <<: *neighbouring_defaults
170 | default_mad: 0.5
171 |
172 | # Flag refinement
173 | refinement:
174 | t:
175 | check_monotonic: true
176 | check_ridge_trough: false
177 | td:
178 | check_monotonic: true
179 | check_ridge_trough: false
180 | sp:
181 | check_monotonic: true
182 | check_ridge_trough: true
183 | msl:
184 | check_monotonic: true
185 | check_ridge_trough: true
186 |
187 | diurnal:
188 | t:
189 | max_bias: 0.5
190 |
--------------------------------------------------------------------------------
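The config above leans on YAML anchors (`&persistence_defaults`) and merge keys (`<<:`), so each variable inherits the shared defaults and only overrides what differs. A small sketch showing how that expands when loaded (the snippet copies a fragment of the config; PyYAML's safe loader handles merge keys):

import yaml

snippet = """
persistence:
  defaults: &persistence_defaults
    min_num: 24
    max_window: 72
    min_var: 0.1
    error_length: 72
  ws:
    <<: *persistence_defaults
    exclude_value: 0
"""
cfg = yaml.safe_load(snippet)
print(cfg["persistence"]["ws"])
# {'min_num': 24, 'max_window': 72, 'min_var': 0.1, 'error_length': 72, 'exclude_value': 0}
# (key order may vary)
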
/quality_control/algo/cross_variable.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 | from .utils import get_config
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def supersaturation(t, td):
10 | flag = xr.where(t < td, CONFIG["flag_error"], CONFIG["flag_normal"])
11 | isnan = np.isnan(t) | np.isnan(td)
12 | flag = flag.where(~isnan, CONFIG["flag_missing"])
13 | return flag.astype(np.int8)
14 |
15 |
16 | def wind_consistency(ws, wd):
17 | zero_ws = ws == 0
18 | zero_wd = wd == 0
19 | inconsistent = zero_ws != zero_wd
20 | flag = xr.where(inconsistent, CONFIG["flag_error"], CONFIG["flag_normal"])
21 | isnan = np.isnan(ws) | np.isnan(wd)
22 | flag = flag.where(~isnan, CONFIG["flag_missing"])
23 | return flag.astype(np.int8)
24 |
25 |
26 | def ra_consistency(ds):
27 | """
28 | Precipitation in different period length is cross validated
29 | If there is a conflict (ra1 at 18:00 is 5mm, while ra3 at 19:00 is 0mm),
30 | both of them are flagged
31 | Parameters:
32 | -----------
33 | ds: xarray dataset with precipitation data including ra1, ra3, ra6, ra12 and ra24
34 | Returns:
35 | --------
36 |     flag : xarray.Dataset
37 |         Flags with the same shape as ds, one flag value per precipitation value.
38 | """
39 | flag = xr.full_like(ds, CONFIG["flag_normal"], dtype=np.int8)
40 | periods = [3, 6, 12, 24]
41 | for period in periods:
42 | da_longer = ds[f"ra{period}"]
43 | for shift in range(1, period):
44 | # Shift the precipitation in the longer period to align with the shorter period
45 | shifted = da_longer.roll(time=-shift)
46 | shifted[:, -shift:] = np.nan
47 | for target_period in [1, 3, 6, 12, 24]:
48 | # Check if the two periods are overlapping
49 | if target_period >= period or target_period + shift > period:
50 | continue
51 | # If the precipitation in the shorter period is larger than the longer period, flag both
52 | flag_indices = np.where(ds[f"ra{target_period}"].values - shifted.values > 0.101)
53 | if len(flag_indices[0]) == 0:
54 | continue
55 |                 flag[f"ra{target_period}"].values[flag_indices] = CONFIG["flag_suspect"]
56 |                 flag[f"ra{period}"].values[(flag_indices[0], flag_indices[1]+shift)] = CONFIG["flag_suspect"]
57 | flag = flag.where(ds.notnull(), CONFIG["flag_missing"])
58 | return flag
59 |
--------------------------------------------------------------------------------
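A miniature version of the `supersaturation` check: a dew point above the air temperature is physically inconsistent, so that timestep is flagged as an error. The flag values 2/0/-1 follow `config.yaml` (error/normal/missing); the temperatures are invented:

import numpy as np
import xarray as xr

t = xr.DataArray([20.0, 10.0, np.nan], dims="time")
td = xr.DataArray([15.0, 12.0, 5.0], dims="time")
flag = xr.where(t < td, 2, 0)                          # 2 where td exceeds t
flag = flag.where(~(np.isnan(t) | np.isnan(td)), -1)   # -1 where either input is missing
print(flag.astype(np.int8).values)  # [ 0  2 -1]
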
/quality_control/algo/distributional_gap.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | from .utils import get_config, intra_station_check, quality_control_statistics
3 | from .time_series import _time_series_comparison
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def _distributional_gap(
10 | ts,
11 | reanalysis,
12 | mask,
13 | shift_step,
14 | gap_scale,
15 | default_mad,
16 | suspect_std_scale,
17 | min_mad=0.1,
18 | min_num=None,
19 | ):
20 | flag = _time_series_comparison(
21 | ts1=ts,
22 | ts2=reanalysis,
23 | shift_step=shift_step,
24 | gap_scale=gap_scale,
25 | default_mad=default_mad,
26 | suspect_std_scale=suspect_std_scale,
27 | min_mad=min_mad,
28 | min_num=min_num,
29 | mask=mask,
30 | )
31 | return flag
32 |
33 |
34 | def run(da, reanalysis, varname, mask=None):
35 | """
36 | Perform a distributional gap check on time series data compared to reanalysis data.
37 | """
38 | flag = intra_station_check(
39 | da,
40 | reanalysis,
41 | mask if mask is not None else xr.full_like(da, True, dtype=bool),
42 | qc_func=_distributional_gap,
43 | input_core_dims=[["time"], ["time"], ["time"]],
44 | kwargs=CONFIG["distribution"][varname],
45 | )
46 | quality_control_statistics(da, flag)
47 | return flag.rename("distributional_gap")
48 |
--------------------------------------------------------------------------------
/quality_control/algo/diurnal_cycle.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import xarray as xr
4 | import bottleneck as bn
5 | from suntime import Sun, SunTimeException
6 | from .utils import get_config, intra_station_check
7 |
8 |
9 | CONFIG = get_config()
10 |
11 |
12 | def _diurnal_cycle_check_daily(ts, lat, lon, date, max_bias):
13 | """
14 | Fit the daily temperature time series with a sine curve
15 | The amplitude is estimated by the daily range of the time series while the phase is estimated by the sunrise time
16 | If the time series is well fitted by the sine curve, it is considered as normal
17 | Returns:
18 | --------
19 | flag: int, a single flag indicating the result of the daily series
20 | """
21 | if ts.size != 24:
22 | raise ValueError("The time series should have 24 values")
23 | # Only check stations between 60S and 60N with significant diurnal cycles
24 | if abs(lat) > 60:
25 | return CONFIG["flag_missing"]
26 |     # Only check samples with at least 2 valid data points in each quarter (6-hour block) of the day
27 | if (np.isfinite(ts.reshape(-1, 6)).sum(axis=1) < 2).any():
28 | return CONFIG["flag_missing"]
29 |
30 | maxv = bn.nanmax(ts)
31 | minv = bn.nanmin(ts)
32 | amplitude = (maxv - minv) / 2
33 |
34 | # Only check samples with significant amplitude
35 | if amplitude < 5:
36 | return CONFIG["flag_missing"]
37 |
38 | timestep = pd.Timestamp(date)
39 | try:
40 | sunrise = Sun(float(lat), float(lon)).get_sunrise_time(timestep)
41 | except SunTimeException:
42 | return CONFIG["flag_missing"]
43 |
44 | # Normalize the time series by the max and min values
45 | normed = (ts - maxv + amplitude) / amplitude
46 |
47 | # Assume the diurnal cycle is a sine curve and the valley is 1H before the sunrise time
48 | shift = (sunrise.hour + sunrise.minute / 60) / 24 * 2 * np.pi - 20 / 12 * np.pi - timestep.hour / 12 * np.pi
49 | # A tolerance of 3 hours is allowed
50 | tolerance_values = np.arange(-np.pi / 4, np.pi / 4 + 1e-5, np.pi / 12)
51 | for tolerance in tolerance_values:
52 | # Try to find a best fitted phase of the sine curve
53 | sine_curve = np.sin((2 * np.pi / 24) * np.arange(24) - shift - tolerance)
54 | if bn.nanmax(np.abs(normed - sine_curve)) < max_bias:
55 | return CONFIG["flag_normal"]
56 |
57 | return CONFIG["flag_suspect"]
58 |
59 |
60 | def _diurnal_cycle_check(ts, flagged, lat, lon, dates, max_bias):
61 | """
62 | Check the diurnal cycle of the temperature time series only for short period flagged by `flagged`
63 | Parameters:
64 | -----------
65 | ts: 1D np.array, the daily temperature time series
66 | flagged: 1D np.array, the flag array from other algorithms
67 | lat: float, the latitude of the station
68 | lon: float, the longitude of the station
69 |     dates: 1D np.array of numpy.datetime64, the dates of the time series
70 | max_bias: float, the maximum bias allowed for the sine curve fitting (suggested: 0.5-1)
71 | Returns:
72 | --------
73 | new_flag: 1D np.array, only the checked days are flagged as either normal or suspect
74 | """
75 | new_flag = np.full_like(flagged, CONFIG["flag_missing"])
76 | length = len(flagged)
77 | error_flags = np.argwhere((flagged == CONFIG["flag_error"]) | (flagged == CONFIG["flag_suspect"])).flatten()
78 | end_idx = 0
79 | for idx, start_idx in enumerate(error_flags):
80 | if start_idx <= end_idx:
81 | continue
82 | # Combine these short erroneous/suspect periods into a longer one
83 | end_idx = start_idx
84 | for next_idx in error_flags[idx+1:]:
85 | if (
86 |                 (flagged[end_idx+1: next_idx] == CONFIG["flag_normal"]).any() or
87 |                 (flagged[end_idx+1: next_idx] == CONFIG["flag_suspect"]).any()
88 | ):
89 | break
90 | end_idx = next_idx
91 | period_length = end_idx - start_idx + 1
92 | if period_length > 12:
93 | continue
94 | # Select the daily series centered at the short erroneous/suspect period
95 | if length % 2 == 1:
96 | num_left, num_right = (24-period_length) // 2, (24-period_length) // 2 + 1
97 | else:
98 | num_left, num_right = (24-period_length) // 2, (24-period_length) // 2
99 | if (flagged[start_idx-num_left: start_idx] != CONFIG["flag_normal"]).all():
100 | continue
101 | if (flagged[end_idx+1: end_idx+1+num_right] != CONFIG["flag_normal"]).all():
102 | continue
103 |
104 | daily_start_idx = start_idx - num_left
105 | if daily_start_idx < 0:
106 | daily_start_idx = 0
107 | elif daily_start_idx > length - 24:
108 | daily_start_idx = length - 24
109 | daily_ts = ts[daily_start_idx: daily_start_idx + 24]
110 | daily_flag = _diurnal_cycle_check_daily(daily_ts, lat, lon, dates[daily_start_idx], max_bias=max_bias)
111 | new_flag[start_idx: end_idx+1] = daily_flag
112 |
113 | new_flag[np.isnan(ts)] = CONFIG["flag_missing"]
114 | return new_flag
115 |
116 |
117 | def adjust_by_diurnal(target_flag, diurnal_flag):
118 | """
119 |     The diurnal cycle check can be used to refine flags from other algorithms
120 | """
121 | normal_diurnal = diurnal_flag == CONFIG["flag_normal"]
122 | error_target = target_flag == CONFIG["flag_error"]
123 | new_flag = target_flag.copy()
124 | new_flag = xr.where(normal_diurnal & error_target, CONFIG["flag_suspect"], new_flag)
125 | return new_flag
126 |
127 |
128 | def run(da, checked_flag):
129 | """
130 | Perform a diurnal cycle check on Temperature time series
131 | Returns:
132 | --------
133 | The adjusted flags
134 | """
135 | flag = intra_station_check(
136 | da,
137 | checked_flag,
138 | da["lat"],
139 | da["lon"],
140 | da["time"],
141 | qc_func=_diurnal_cycle_check,
142 | input_core_dims=[["time"], ["time"], [], [], ["time"]],
143 | kwargs=CONFIG["diurnal"]["t"],
144 | )
145 | new_flag = adjust_by_diurnal(checked_flag, flag)
146 | return new_flag.rename("diurnal_cycle")
147 |
--------------------------------------------------------------------------------
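A minimal, self-contained sketch of the sine-fit idea used by `_diurnal_cycle_check_daily` above: the 24-hour series is normalized by its amplitude and accepted as a plausible diurnal cycle if some phase-shifted sine curve stays within `max_bias` of it. The `phase_shift` argument, the default threshold and the synthetic series are illustrative assumptions; the repository code derives the phase from the sunrise time instead.

    import numpy as np

    def plausible_diurnal_cycle(daily_ts, phase_shift, max_bias=0.75):
        """Return True if the 24-hour series roughly follows a sine-shaped diurnal cycle."""
        maxv, minv = np.nanmax(daily_ts), np.nanmin(daily_ts)
        amplitude = (maxv - minv) / 2
        if amplitude < 5:  # same significance threshold as the repository code
            return False
        normed = (daily_ts - maxv + amplitude) / amplitude  # roughly within [-1, 1]
        # Allow the phase to deviate by up to +/- 3 hours (pi/4), in 1-hour steps (pi/12)
        for tol in np.arange(-np.pi / 4, np.pi / 4 + 1e-5, np.pi / 12):
            sine = np.sin(2 * np.pi / 24 * np.arange(24) - phase_shift - tol)
            if np.nanmax(np.abs(normed - sine)) < max_bias:
                return True
        return False

    # Synthetic example: a clean diurnal cycle with a 6-degree amplitude passes the check
    hours = np.arange(24)
    ts = 20 + 6 * np.sin(2 * np.pi / 24 * hours - np.pi / 2)
    print(plausible_diurnal_cycle(ts, phase_shift=np.pi / 2))  # True
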
/quality_control/algo/fine_tuning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import bottleneck as bn
3 | import xarray as xr
4 | from .utils import get_config, intra_station_check
5 |
6 |
7 | CONFIG = get_config()
8 |
9 |
10 | def _suspect_upgrade(ts_flag):
11 | """
12 |     Upgrade suspect flag periods if they are surrounded by error flags
13 | Returns:
14 | --------
15 |     new_flag: np.array, the upgraded flag array in which qualifying suspect flags are upgraded to erroneous
16 | """
17 | new_flag = ts_flag.copy()
18 |
19 |     # Get the start and end indices of each continuous suspect period
20 | is_suspect = (ts_flag == CONFIG["flag_suspect"]) | (ts_flag == CONFIG["flag_missing"])
21 | is_suspect = np.insert(is_suspect, 0, False)
22 | is_suspect = np.append(is_suspect, False)
23 | start = np.where(is_suspect & ~np.roll(is_suspect, 1))[0] - 1
24 | end = np.where(is_suspect & ~np.roll(is_suspect, -1))[0] - 1
25 |
26 |     # Keep only periods that contain at least one suspect flag (drop purely missing periods)
27 | periods = [item for item in list(zip(start, end)) if (ts_flag[item[0]: item[1]+1] == CONFIG["flag_suspect"]).any()]
28 |
29 | length = len(ts_flag)
30 | for start_idx, end_idx in periods:
31 | if (start_idx == 0) or (end_idx == length - 1):
32 | continue
33 | is_error_left = ts_flag[start_idx-1] == CONFIG["flag_error"]
34 | is_error_right = ts_flag[end_idx+1] == CONFIG["flag_error"]
35 | if is_error_left and is_error_right:
36 | new_flag[start_idx:end_idx+1] = CONFIG["flag_error"]
37 |
38 | new_flag[ts_flag == CONFIG["flag_missing"]] = CONFIG["flag_missing"]
39 | return new_flag
40 |
41 |
42 | def _upgrade_flags_window(flags, da):
43 | """
44 | Upgrade suspect flags based on the proportion of flagged points among valid data points in a sliding window
45 | """
46 | window_size = 720 # One month sliding window
47 | threshold = 0.5 # More than half of the data points
48 |
49 | # Convert to numpy arrays for faster computation
50 | flags_array = flags.values
51 | data_array = da.values
52 |
53 | mask = (flags_array == CONFIG["flag_suspect"]) | (flags_array == CONFIG["flag_error"])
54 | valid = ~np.isnan(data_array)
55 |
56 | # Calculate rolling sums
57 | rolling_sum = bn.move_sum(mask.astype(np.float64), window=window_size, min_count=1, axis=1)
58 | rolling_sum_valid = bn.move_sum(valid.astype(np.float64), window=window_size, min_count=1, axis=1)
59 | rolling_sum_valid = np.where(rolling_sum_valid == 0, np.nan, rolling_sum_valid)
60 |
61 | # Calculate proportion of flagged points among valid data points
62 | proportion_flagged = rolling_sum / rolling_sum_valid
63 |
64 | # Create a padded array for centered calculation
65 | pad_width = window_size // 2
66 | exceed_threshold = np.pad(proportion_flagged > threshold, ((0, 0), (pad_width, pad_width)), mode='edge')
67 |
68 | # Use move_max on the padded array
69 | expanded_mask = bn.move_max(exceed_threshold.astype(np.float64), window=window_size, min_count=1, axis=1)
70 |
71 | # Remove padding
72 | expanded_mask = expanded_mask[:, pad_width:-pad_width]
73 |
74 | # Upgrade suspect flags to error flags where the expanded mask is True
75 | upgraded_flags = np.where(
76 | (flags_array == CONFIG["flag_suspect"]) & (expanded_mask == 1),
77 | CONFIG["flag_error"],
78 | flags_array
79 | )
80 |
81 | # Convert back to xarray DataArray
82 | return xr.DataArray(upgraded_flags, coords=flags.coords, dims=flags.dims)
83 |
84 |
85 | def _upgrade_flags_all(flags, da):
86 | """
87 | If more than half of the data points at a station are flagged as erroneous,
88 | all the data points at this station are flagged as erroneous
89 | """
90 | threshold = 0.5 # More than half of the data points
91 |
92 | mask = flags == CONFIG["flag_error"]
93 | valid = da.notnull()
94 | proportion_flagged = mask.sum(dim="time") / valid.sum(dim="time")
95 |
96 | upgraded_flags = flags.where(proportion_flagged < threshold, CONFIG["flag_error"])
97 |     upgraded_flags = upgraded_flags.where(valid, CONFIG["flag_missing"])
98 |
99 | return upgraded_flags
100 |
101 |
102 | def run(flag, da):
103 | flag = intra_station_check(flag, qc_func=_suspect_upgrade)
104 | flag = _upgrade_flags_window(flag, da)
105 | flag = _upgrade_flags_all(flag, da)
106 | return flag.rename("fine_tuning")
107 |
--------------------------------------------------------------------------------
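A toy sketch of the run-upgrade rule implemented by `_suspect_upgrade` above: a contiguous run of suspect (or missing) flags is promoted to error when both of its neighbours are error flags. The flag codes and the helper name below are placeholders, not the values from `config.yaml`.

    import numpy as np

    NORMAL, SUSPECT, ERROR, MISSING = 0, 1, 2, -1   # hypothetical flag codes

    def upgrade_bracketed_suspects(flags):
        """Promote suspect runs to error when both neighbours are error flags."""
        new = flags.copy()
        n = len(flags)
        i = 0
        while i < n:
            if flags[i] in (SUSPECT, MISSING):
                j = i
                while j < n and flags[j] in (SUSPECT, MISSING):
                    j += 1
                run_has_suspect = (flags[i:j] == SUSPECT).any()
                bounded = i > 0 and j < n and flags[i - 1] == ERROR and flags[j] == ERROR
                if run_has_suspect and bounded:
                    new[i:j] = ERROR
                i = j
            else:
                i += 1
        new[flags == MISSING] = MISSING   # never overwrite missing values
        return new

    flags = np.array([NORMAL, ERROR, SUSPECT, SUSPECT, ERROR, NORMAL, SUSPECT])
    print(upgrade_bracketed_suspects(flags))   # the bracketed suspect run becomes ERROR
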
/quality_control/algo/neighbouring_stations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 | from .time_series import _time_series_comparison
4 | from .utils import get_config, quality_control_statistics
5 |
6 |
7 | CONFIG = get_config()
8 |
9 |
10 | def _select_neighbouring_stations(similarity, stn_data, lat, lon, data, max_dist, max_elev_diff, min_data_overlap):
11 | """
12 |     Select neighbouring stations based on distance, elevation difference and data overlap
13 |     At most 8 stations are selected, up to 2 for each quadrant (northwest, northeast, southwest, southeast)
14 | Parameters:
15 | similarity: xr.Dataset generated at the previous step with only the checked station
16 | stn_data: xr.DataArray with the data of the station
17 | lat: latitude of the station
18 | lon: longitude of the station
19 | data: xr.Dataset with the data of all stations
20 | max_dist: maximum distance in km
21 | max_elev_diff: maximum elevation difference in m
22 | min_data_overlap: minimum data overlap in fraction
23 | Return:
24 | candidates: list of neighbouring station names
25 | """
26 | similar_position = similarity["dist"] >= np.exp(-max_dist / 25)
27 | similar_elevation = similarity["elev"] >= np.exp(-max_elev_diff / 100)
28 | candidates = similarity["target"][similar_position & similar_elevation].values
29 | candidates = [item for item in candidates if item != similarity["station"].item()]
30 |
31 | # Filter out stations with insufficient data overlap
32 | is_valid = stn_data.notnull()
33 | if not is_valid.any():
34 | return []
35 | is_valid_neighboring = data.sel(station=candidates).notnull()
36 | overlap = (is_valid & is_valid_neighboring).sum(dim="time") / is_valid.sum()
37 | candidates = overlap["station"].values[overlap >= min_data_overlap]
38 |
39 | similarity = similarity.sel(target=candidates).sortby("dist", ascending=False)
40 | lon_diff = similarity["lon"] - lon
41 |     # Handle longitude differences that wrap around the +/-180 degree meridian
42 | iswest = ((lon_diff < 0) & (np.abs(lon_diff) < 180)) | ((lon_diff > 0) & (np.abs(lon_diff) > 180))
43 | isnorth = (similarity["lat"] - lat) > 0
44 | northwest_stations = similarity["target"][iswest & isnorth].values[:2]
45 | northeast_stations = similarity["target"][~iswest & isnorth].values[:2]
46 | southwest_stations = similarity["target"][iswest & ~isnorth].values[:2]
47 | southeast_stations = similarity["target"][~iswest & ~isnorth].values[:2]
48 | return np.concatenate([northwest_stations, northeast_stations, southwest_stations, southeast_stations])
49 |
50 |
51 | def _load_similarity(fpath, dataset):
52 | """
53 | Load the similarity calculated at the previous step
54 | """
55 | similarity = xr.load_dataset(fpath)
56 | similarity = similarity.sel(station1=dataset["station"].values, station2=dataset["station"].values)
57 | similarity = similarity.assign_coords(lon=("station2", dataset["lon"].values))
58 | similarity = similarity.assign_coords(lat=("station2", dataset["lat"].values))
59 | similarity = similarity.rename({"station1": "station", "station2": "target"})
60 | return similarity
61 |
62 |
63 | def _neighbouring_stations_check_base(
64 | flag,
65 | similarity,
66 | data,
67 | max_dist,
68 | max_elev_diff,
69 | min_data_overlap,
70 | shift_step,
71 | gap_scale,
72 | default_mad,
73 | suspect_std_scale,
74 | min_num,
75 | ):
76 | """
77 | Check all stations in data by comparing them with neighbouring stations
78 | For each time step, when there are at least 3 neighbouring values available,
79 | and 2 / 3 of them are in agreement (either verified or dubious), save the result
80 | Parameters:
81 | -----------
82 |     flag: an initialized DataArray with the same station dimension as the data
83 | similarity: similarity matrix between stations used for selecting neighbouring stations
84 | data: DataArray to be checked
85 | max_dist/max_elev_diff/min_data_overlap: parameters for selecting neighbouring stations
86 | shift_step/gap_scale/default_mad/suspect_std_scale: parameters for the time series comparison
87 | min_num: minimum number of valid data points to perform the check
88 | Returns:
89 | --------
90 | flag: Updated flag DataArray
91 | """
92 | for station in flag["station"].values:
93 | neighbours = _select_neighbouring_stations(
94 | similarity.sel(station=station),
95 | data.sel(station=station),
96 | lat=data.sel(station=station)["lat"].item(),
97 | lon=data.sel(station=station)["lon"].item(),
98 | data=data,
99 | max_dist=max_dist,
100 | max_elev_diff=max_elev_diff,
101 | min_data_overlap=min_data_overlap,
102 | )
103 | # Only apply the check to stations with at least 3 neighbouring stations
104 | if len(neighbours) < 3:
105 | continue
106 | results = []
107 |
108 | for target in neighbours:
109 | ith_flag = _time_series_comparison(
110 | data.sel(station=station).values,
111 | data.sel(station=target).values,
112 | shift_step=shift_step,
113 | gap_scale=gap_scale,
114 | default_mad=default_mad,
115 | suspect_std_scale=suspect_std_scale,
116 | min_num=min_num,
117 | )
118 | results.append(ith_flag)
119 | results = np.stack(results, axis=0)
120 | # For each time step, count the number of valid neighbouring data points
121 | num_neighbours = np.sum(results != CONFIG["flag_missing"], axis=0)
122 |         # For each time step, only consider data points with at least 3 valid neighbouring comparisons
123 | num_neighbours = np.where(num_neighbours < 3, np.nan, num_neighbours)
124 |         # Set the flag when at least 2 / 3 of the neighbouring comparisons are in agreement
125 |         # Specifically, set erroneous only when there are no normal flags
126 | min_stn = num_neighbours * 2 / 3
127 | normal = (results == CONFIG["flag_normal"]).sum(axis=0)
128 | suspect = ((results == CONFIG["flag_suspect"]) | (results == CONFIG["flag_error"])).sum(axis=0)
129 | erroneous = (results == CONFIG["flag_error"]).sum(axis=0)
130 | aggregated = np.full_like(normal, CONFIG["flag_missing"])
131 | aggregated = np.where(normal >= min_stn, CONFIG["flag_normal"], aggregated)
132 | aggregated = np.where(suspect >= min_stn, CONFIG["flag_suspect"], aggregated)
133 | aggregated = np.where((erroneous >= min_stn) & (normal == 0), CONFIG["flag_error"], aggregated)
134 | flag.loc[{"station": station}] = aggregated
135 | return flag
136 |
137 |
138 | def run(da, f_similarity, varname):
139 | """
140 | Check the data by comparing with neighbouring stations
141 | Time series comparison is performed for each station with at least 3 neighbouring stations
142 | `map_blocks` is used to parallelize the calculation
143 | """
144 | similarity = _load_similarity(f_similarity, da)
145 |
146 | flag = xr.DataArray(
147 | np.full(da.shape, CONFIG["flag_missing"], dtype=np.int8),
148 | dims=["station", "time"],
149 | coords={k: da.coords[k].values for k in da.dims}
150 | )
151 | ret = xr.map_blocks(
152 | _neighbouring_stations_check_base,
153 | flag.chunk({"station": 500}),
154 | args=(similarity.chunk({"station": 500}), ),
155 | kwargs={"data": da, **CONFIG["neighbouring"][varname]},
156 | template=flag.chunk({"station": 500}),
157 | ).compute(scheduler='processes')
158 | ret = ret.where(da.notnull(), CONFIG["flag_missing"])
159 | quality_control_statistics(da, ret)
160 | return ret.rename("neighbouring_stations")
161 |
--------------------------------------------------------------------------------
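A compact numpy sketch of the voting step inside `_neighbouring_stations_check_base`: per time step, the per-neighbour comparison flags are aggregated, and a verdict is issued only where at least 3 neighbours have data and 2/3 of them agree (erroneous additionally requires that no neighbour votes normal). The flag codes are placeholders for the config values.

    import numpy as np

    NORMAL, SUSPECT, ERROR, MISSING = 0, 1, 2, -1   # hypothetical flag codes

    def aggregate_neighbour_flags(results):
        """results: (n_neighbours, n_time) array of per-neighbour comparison flags."""
        num = (results != MISSING).sum(axis=0).astype(float)
        num = np.where(num < 3, np.nan, num)        # need at least 3 neighbours with data
        min_agree = num * 2 / 3
        normal = (results == NORMAL).sum(axis=0)
        dubious = ((results == SUSPECT) | (results == ERROR)).sum(axis=0)
        erroneous = (results == ERROR).sum(axis=0)
        out = np.full(results.shape[1], MISSING)
        out = np.where(normal >= min_agree, NORMAL, out)
        out = np.where(dubious >= min_agree, SUSPECT, out)
        # erroneous only when no neighbour says the value is normal
        out = np.where((erroneous >= min_agree) & (normal == 0), ERROR, out)
        return out

    results = np.array([[NORMAL, ERROR,  ERROR],
                        [NORMAL, ERROR,  SUSPECT],
                        [NORMAL, NORMAL, ERROR]])
    # Columns resolve to normal, suspect and error respectively
    print(aggregate_neighbour_flags(results))
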
/quality_control/algo/persistence.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import bottleneck as bn
3 | from .utils import get_config, intra_station_check, quality_control_statistics
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def _persistence_main(ts, unset_flag, min_num, max_window, min_var, error_length, exclude_value=None):
10 | """
11 | Perform persistence check on a time series.
12 |
13 | This function identifies periods of low variability in the time series and flags them as suspect or error.
14 |
15 | Parameters:
16 | -----------
17 | ts (array-like): The input time series.
18 | unset_flag (array-like): Pre-existing flags for the time series.
19 | min_num (int): Minimum number of valid values required in a window.
20 | max_window (int): Maximum size of the moving window.
21 | min_var (float): Minimum allowed standard deviation in a window.
22 | error_length (int): Minimum length of a period to be flagged as error.
23 | exclude_value (float or list, optional): Value(s) to exclude from the analysis.
24 |
25 | Algorithm:
26 | ----------
27 | 1. Scan the time series with moving windows of decreasing size (max_window to min_num).
28 | 2. Flag windows with standard deviation < min_var as suspect.
29 | 3. Merge overlapping suspect windows.
30 | 4. Flag merged windows as error if their length >= error_length.
31 | 5. Flag remaining suspect windows based on their length and pre-existing flags.
32 |
33 | Returns:
34 | --------
35 | numpy.ndarray: 1D array of flags with the same length as the input time series.
36 | """
37 | flag = np.full(ts.size, CONFIG["flag_missing"], dtype=np.int8)
38 | suspect_windows = []
39 |
40 |     # Bottleneck provides >100x faster moving-window implementations, but it only works well with float64
41 | ts = ts.astype(np.float64)
42 | if exclude_value is not None:
43 | if isinstance(exclude_value, list):
44 | for value in exclude_value:
45 | ts[ts == value] = np.nan
46 | else:
47 | ts[ts == exclude_value] = np.nan
48 |
49 | for window_size in range(max_window, min_num-1, -1):
50 | # min_count will ensure that the std is calculated only when there are enough valid values
51 | std = bn.move_std(ts, window=window_size, min_count=min_num)[window_size-1:]
52 | valid_indices = np.argwhere(~np.isnan(std)).flatten()
53 | for shift in range(window_size):
54 | # Values checked in at least one window are considered as valid
55 | flag[valid_indices + shift] = CONFIG["flag_normal"]
56 | error_index = np.argwhere(std < min_var).flatten()
57 | if error_index.size == 0:
58 | continue
59 | suspect_windows.extend([(i, i+window_size) for i in error_index])
60 |
61 | isvalid = ~np.isnan(ts)
62 | if len(suspect_windows) == 0:
63 | flag[~isvalid] = CONFIG["flag_missing"]
64 | return flag
65 |
66 | # trim the NaNs at both ends of each window
67 | for idx, (start, end) in enumerate(suspect_windows):
68 | start = start + np.argmax(isvalid[start:end])
69 | end = end - np.argmax(isvalid[start:end][::-1])
70 | suspect_windows[idx] = (start, end)
71 |
72 | # Combine the overlapping windows
73 | suspect_windows.sort(key=lambda x: x[0])
74 | suspect_windows_merged = [suspect_windows[0]]
75 | for current in suspect_windows[1:]:
76 | last = suspect_windows_merged[-1]
77 | if current[0] < last[1]:
78 | suspect_windows_merged[-1] = (last[0], max(last[1], current[1]))
79 | else:
80 | suspect_windows_merged.append(current)
81 |
82 | for start, end in suspect_windows_merged:
83 | if end - start >= error_length:
84 | flag[start:end] = CONFIG["flag_error"]
85 | else:
86 |             # Set the error flag only when more than 5% of the values are flagged by other methods
87 | num_suspend = (unset_flag[start:end] == CONFIG["flag_suspect"]).sum()
88 | num_error = (unset_flag[start:end] == CONFIG["flag_error"]).sum()
89 | num_valid = isvalid[start:end].sum()
90 | if num_suspend + num_error > num_valid * 0.05:
91 | flag[start:end] = CONFIG["flag_error"]
92 | else:
93 | flag[start:end] = CONFIG["flag_suspect"]
94 |
95 | flag[~isvalid] = CONFIG["flag_missing"]
96 | return flag
97 |
98 |
99 | def run(da, unset_flag, varname):
100 | flag = intra_station_check(
101 | da,
102 | unset_flag,
103 | qc_func=_persistence_main,
104 | input_core_dims=[["time"], ["time"]],
105 | kwargs=CONFIG["persistence"][varname],
106 | )
107 | quality_control_statistics(da, flag)
108 | return flag.rename("persistence")
109 |
--------------------------------------------------------------------------------
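A minimal sketch of the core test in the persistence check: bottleneck's `move_std` over a trailing window, with `min_count` guarding against sparse data, marks windows whose standard deviation falls below `min_var`. The window size, thresholds and toy series are illustrative, not the `config.yaml` values.

    import numpy as np
    import bottleneck as bn

    def low_variability_windows(ts, window=6, min_num=4, min_var=0.1):
        """Return (start, end) index pairs of trailing windows with std below min_var."""
        ts = ts.astype(np.float64)                      # bottleneck works best on float64
        std = bn.move_std(ts, window=window, min_count=min_num)[window - 1:]
        starts = np.argwhere(std < min_var).flatten()
        return [(int(s), int(s) + window) for s in starts]

    # A series that sticks at the same value for a while is flagged as low-variability
    ts = np.array([1.0, 1.2, 0.9, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.1, 0.8])
    print(low_variability_windows(ts))   # one window covering the constant stretch
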
/quality_control/algo/record_extreme.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .utils import get_config, intra_station_check, quality_control_statistics
3 |
4 |
5 | CONFIG = get_config()
6 |
7 |
8 | def _record_extreme_main(ts, upper, lower):
9 | """
10 | Flag outliers as erroneous based on the record extremes.
11 |
12 | Parameters:
13 | -----------
14 | ts : np.ndarray
15 | 1D time series to be checked
16 | upper : float
17 | Upper bound of the record extreme
18 | lower : float
19 | Lower bound of the record extreme
20 |
21 | Returns:
22 | --------
23 | flag : np.ndarray
24 | 1D array with the same length as ts, containing flags
25 | """
26 | flag_upper = ts > upper
27 | flag_lower = ts < lower
28 | flag = np.full(ts.shape, CONFIG["flag_normal"], dtype=np.int8)
29 | flag[np.logical_or(flag_upper, flag_lower)] = CONFIG["flag_error"]
30 | flag[np.isnan(ts)] = CONFIG["flag_missing"]
31 | return flag
32 |
33 |
34 | def run(da, varname):
35 | flag = intra_station_check(
36 | da,
37 | qc_func=_record_extreme_main,
38 | kwargs=CONFIG["record"][varname],
39 | )
40 | quality_control_statistics(da, flag)
41 | return flag.rename("record_extreme")
42 |
--------------------------------------------------------------------------------
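For illustration, a standalone version of the bound check in `_record_extreme_main` with made-up flag codes and placeholder bounds (the real bounds come from `config.yaml`):

    import numpy as np

    NORMAL, ERROR, MISSING = 0, 2, -1   # hypothetical flag codes

    def record_extreme_flags(ts, upper, lower):
        """Flag values outside [lower, upper] as error; NaNs stay missing."""
        flag = np.full(ts.shape, NORMAL, dtype=np.int8)
        flag[(ts > upper) | (ts < lower)] = ERROR
        flag[np.isnan(ts)] = MISSING
        return flag

    ts = np.array([15.0, 72.0, np.nan, -95.0])
    # Flags: normal, error, missing, error
    print(record_extreme_flags(ts, upper=60.0, lower=-90.0))
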
/quality_control/algo/refinement.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .utils import get_config, intra_station_check
3 |
4 |
5 | CONFIG = get_config()
6 |
7 |
8 | def _is_monotonic(ts, start_idx, end_idx):
9 | # Search the left and right nearest non-error values
10 | for left_idx in range(start_idx-1, max(0, start_idx-4), -1):
11 | # Note that here the found value is bound to be non-error
12 | if np.isfinite(ts[left_idx]):
13 | break
14 | else:
15 | left_idx = None
16 | for right_idx in range(end_idx, min(end_idx+3, len(ts))):
17 | if np.isfinite(ts[right_idx]):
18 | break
19 | else:
20 | right_idx = None
21 |     # If these values are monotonic, downgrade the error flag to suspect
22 | if left_idx is not None and right_idx is not None:
23 | ts_period = ts[left_idx: right_idx + 1]
24 | ts_period = ts_period[~np.isnan(ts_period)]
25 | diff = np.diff(ts_period)
26 | if (diff > 0).all() or (diff < 0).all():
27 | return True
28 | return False
29 |
30 |
31 | def _is_ridge_or_trough(ts, start_idx, end_idx):
32 |     # Search the nearest 4 non-error values on the left and right sides
33 | left_values = ts[max(0, start_idx-12): max(0, start_idx)]
34 | left_values = left_values[np.isfinite(left_values)]
35 | if len(left_values) < 4:
36 | return False
37 | left_values = left_values[-4:]
38 | right_values = ts[min(len(ts), end_idx): min(len(ts), end_idx+12)]
39 | right_values = right_values[np.isfinite(right_values)]
40 | if len(right_values) < 4:
41 | return False
42 | right_values = right_values[:4]
43 |
44 | ts_period = ts[start_idx: end_idx]
45 | ts_period = np.concatenate([left_values, ts_period[np.isfinite(ts_period)], right_values])
46 | min_idx = np.argmin(ts_period)
47 | if 3 < min_idx < len(ts_period) - 4:
48 | diff = np.diff(ts_period)
49 | # Check if it is a trough (e.g., low pressure)
50 | if (diff[:min_idx] < 0).all() and (diff[min_idx:] > 0).all():
51 | return True
52 | max_idx = np.argmax(ts_period)
53 | if 3 < max_idx < len(ts_period) - 4:
54 | diff = np.diff(ts_period)
55 | # Check if it is a ridge (e.g., temperature peak)
56 | if (diff[:max_idx] > 0).all() and (diff[max_idx:] < 0).all():
57 | return True
58 | return False
59 |
60 |
61 | def _refine_flag(cur_flag, ts, check_monotonic, check_ridge_trough):
62 | """
63 | Refine the flags from other algorithms:
64 | check_monotonic: If True, downgrade error flags to suspect if they are
65 | situated in a monotonic period in-between non-error values
66 | check_ridge_trough: If True, downgrade error flags to suspect if they are
67 | situated in a ridge or trough in-between non-error values
68 | Returns:
69 | --------
70 | new_flag: The refined flags
71 | """
72 | new_flag = cur_flag.copy()
73 | if not check_monotonic and not check_ridge_trough:
74 | return new_flag
75 | length = len(cur_flag)
76 | error_flags = np.argwhere(cur_flag == CONFIG["flag_error"]).flatten()
77 | cur_idx = 0
78 | for idx in error_flags:
79 | if cur_idx > idx:
80 | continue
81 | cur_idx = idx + 1
82 | num_nan = 0
83 | num_error = 1
84 | # Search for the next non-error value
85 | while num_nan <= 2 and cur_idx < length:
86 |             # If there are 3 or more consecutive missing values, or more than 3 error flags, do nothing
87 | if np.isnan(ts[cur_idx]):
88 | num_nan += 1
89 | cur_idx += 1
90 | continue
91 | if cur_flag[cur_idx] == CONFIG["flag_error"]:
92 | num_nan = 0
93 | num_error += 1
94 | cur_idx += 1
95 | continue
96 | # If a non-error value is found, check if it is monotonic
97 | if num_error > 3:
98 | break
99 | if check_monotonic and _is_monotonic(ts, idx, cur_idx):
100 | new_flag[idx:cur_idx] = CONFIG["flag_suspect"]
101 | if check_ridge_trough and _is_ridge_or_trough(ts, idx, cur_idx):
102 | new_flag[idx:cur_idx] = CONFIG["flag_suspect"]
103 | break
104 | new_flag[np.isnan(ts)] = CONFIG["flag_missing"]
105 | return new_flag
106 |
107 |
108 | def run(da, flag, varname):
109 | flag = intra_station_check(
110 | flag,
111 | da,
112 | qc_func=_refine_flag,
113 | input_core_dims=[["time"], ["time"]],
114 | output_core_dims=[["time"]],
115 | kwargs={
116 | "check_monotonic": CONFIG["refinement"][varname]["check_monotonic"],
117 | "check_ridge_trough": CONFIG["refinement"][varname]["check_ridge_trough"]
118 | },
119 | )
120 | return flag.rename("refinement")
121 |
--------------------------------------------------------------------------------
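A toy illustration of the monotonicity test behind `_is_monotonic`: an error-flagged segment that sits on a strictly monotonic ramp between its non-error neighbours is a candidate for downgrading to suspect. The helper below is a simplification that ignores the bounded search for neighbouring values in the repository code.

    import numpy as np

    def is_monotonic_segment(ts, start_idx, end_idx):
        """True if the values from just before start_idx up to end_idx change monotonically."""
        segment = ts[max(0, start_idx - 1): end_idx + 1]
        segment = segment[~np.isnan(segment)]
        diff = np.diff(segment)
        return bool((diff > 0).all() or (diff < 0).all())

    ts = np.array([10.0, 12.0, 14.5, 17.0, 19.0])   # a smooth warming ramp
    # Indices 2-3 lie on the ramp, so an error flag there would be downgraded to suspect
    print(is_monotonic_segment(ts, start_idx=2, end_idx=3))   # True
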
/quality_control/algo/spike.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import bottleneck as bn
3 | from .utils import get_config, intra_station_check, quality_control_statistics
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def _spike_check_forward(ts, unset_flag, max_change):
10 | """
11 | Perform a forward spike check on a time series.
12 |
13 | A spike is defined as:
14 | 1. An abrupt change in value
15 | 2. Followed by an abrupt change back
16 | 3. With no abrupt changes during the spike
17 |
18 | This function flags values as suspect at both ends of abrupt changes and all values within a spike.
19 | Due to the high sensitivity of spike detection, if all values in a spike are at least suspect
20 | in unset_flag, they will be flagged as errors.
21 |
22 | Parameters:
23 | -----------
24 | ts : array-like
25 | The input time series to check for spikes.
26 | unset_flag : array-like
27 | Reference flags from other methods.
28 | max_change : array-like
29 | 1D array with 3 elements, representing the maximum allowed change for 1, 2, and 3 steps.
30 |
31 | Returns:
32 | --------
33 | flag : numpy.ndarray
34 | 1D array with the same length as ts, containing flags for each value.
35 | """
36 | length = len(ts)
37 | flag = np.full(length, CONFIG["flag_missing"], dtype=np.int8)
38 | isnan = np.isnan(ts).astype(int)
39 | indices = []
40 | gaps = []
41 | # Get all the indices of potential spikes
42 | for step in range(1, 4, 1):
43 | diff = np.abs(ts[step:] - ts[:-step])
44 | # Flag the values where the variation is checked
45 | flag[step:] = np.where(
46 | ~np.isnan(diff) & (flag[step:] == CONFIG["flag_missing"]),
47 | CONFIG["flag_normal"],
48 | flag[step:],
49 | )
50 | condition = diff > max_change[step - 1]
51 | if step > 1:
52 |             # For steps larger than 1, only cases where all the in-between values are NaN are considered
53 | allnan = bn.move_sum(isnan, window=step - 1)[step - 1: -1] == step - 1
54 | condition = condition & allnan
55 | indices_step = np.argwhere(condition).flatten() + step
56 | # flag both sides of the abrupt change as suspect
57 | flag[indices_step] = CONFIG["flag_suspect"]
58 | flag[indices_step - step] = CONFIG["flag_suspect"]
59 | # Save the gap between abnormal variations to be used later
60 | gaps.extend(np.full(len(indices_step), step))
61 | indices.extend(indices_step)
62 | # Combine the potential spikes
63 | sorted_indices = np.argsort(indices)
64 | indices = np.array(indices)[sorted_indices]
65 | gaps = np.array(gaps)[sorted_indices]
66 |
67 | cur_idx = 0
68 | # Iterate all potential spikes
69 | for case_idx, start_idx in enumerate(indices):
70 |         # Skip indices that fall inside a spike that has already been processed
71 | if start_idx <= cur_idx:
72 | continue
73 | # The value before the spike
74 | leftv = ts[start_idx - gaps[case_idx]]
75 | # The newest value in the spike
76 | lastv = ts[start_idx]
77 | # A threshold for detecting the end of a spike
78 | return_diff = np.abs((lastv - leftv) / 2)
79 |         # The direction of the variation: positive (negative) for decreasing (increasing)
80 | change_sign = np.sign(leftv - lastv)
81 | cur_idx = start_idx + 1
82 | num_nan = 0
83 | # Start to search for the end of the spike. The spike should be shorter than 72 steps
84 | while cur_idx < length and cur_idx - start_idx < 72:
85 | if np.isnan(ts[cur_idx]):
86 | num_nan += 1
87 | cur_idx += 1
88 | # Consider 3 continuous NaNs as end of a spike
89 | if num_nan >= 3:
90 |                     # In experimental tests, if the value changes drastically and then recording stops,
91 |                     # this single value is considered a spike
92 | if cur_idx - start_idx <= 4:
93 | break
94 | # Else, it is not considered as a spike
95 | cur_idx = start_idx - 1
96 | break
97 | continue
98 | num_nan = 0
99 |
100 | isabrupt = np.abs(ts[cur_idx] - lastv) > max_change[num_nan]
101 | isopposite = (lastv - ts[cur_idx]) * change_sign < 0
102 | isnear = np.abs(ts[cur_idx] - leftv) <= return_diff
103 | if not isabrupt:
104 | # if the value changes back slowly, it is considered as normal variation
105 | if isnear:
106 | cur_idx = start_idx - 1
107 | break
108 | # if there is no abrupt change, and the value is still far from the original value
109 | # continue searching for the end of the spike
110 | lastv = ts[cur_idx]
111 | cur_idx += 1
112 | continue
113 | # if there is an abrupt change, stop searching
114 | # If the value goes back with an opposite abrupt change, it is considered as the end of the spike
115 | if isopposite and isnear:
116 | break
117 | # Else, skip this case to avoid too complex situations
118 | cur_idx = start_idx - 1
119 | break
120 | else:
121 | cur_idx = start_idx - 1
122 | # Only flag the spike as erroneous when all the values are at least suspect
123 | if (unset_flag[start_idx:cur_idx] != CONFIG["flag_normal"]).all():
124 | flag[start_idx:cur_idx] = CONFIG["flag_error"]
125 | else:
126 | flag[start_idx:cur_idx] = CONFIG["flag_suspect"]
127 | flag[isnan.astype(bool)] = CONFIG["flag_missing"]
128 | return flag
129 |
130 |
131 | def _bidirectional_spike_check(ts, unset_flag, max_change):
132 | """
133 | Perform bidirectional spike check on the time series so that when there is an abrupt change,
134 | both directions will be checked
135 | The combined result is the higher flag of the two directions
136 | """
137 | flag_forward = _spike_check_forward(ts, unset_flag, max_change)
138 | flag_backward = _spike_check_forward(ts[::-1], unset_flag[::-1], max_change)[::-1]
139 | flag = np.full(len(ts), CONFIG["flag_missing"], dtype=np.int8)
140 | for flag_type in ["normal", "suspect", "error"]:
141 | for direction in [flag_forward, flag_backward]:
142 | flag[direction == CONFIG[f"flag_{flag_type}"]] = CONFIG[f"flag_{flag_type}"]
143 | return flag
144 |
145 |
146 | def run(da, unset_flag, varname):
147 | flag = intra_station_check(
148 | da,
149 | unset_flag,
150 | qc_func=_bidirectional_spike_check,
151 | input_core_dims=[["time"], ["time"]],
152 | kwargs=CONFIG["spike"][varname],
153 | )
154 | quality_control_statistics(da, flag)
155 | return flag.rename("spike")
156 |
--------------------------------------------------------------------------------
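A small sketch of the merge step in `_bidirectional_spike_check`: the forward and backward passes are combined by keeping, per time step, the more severe of the two flags (normal < suspect < error), which is what the ordered overwrites above achieve. The flag codes are placeholders.

    import numpy as np

    NORMAL, SUSPECT, ERROR, MISSING = 0, 1, 2, -1   # hypothetical flag codes

    def combine_directions(flag_forward, flag_backward):
        """Per time step, keep the more severe flag of the two passes."""
        combined = np.full(len(flag_forward), MISSING, dtype=np.int8)
        # Later assignments win, so iterate from least to most severe
        for level in (NORMAL, SUSPECT, ERROR):
            for direction in (flag_forward, flag_backward):
                combined[direction == level] = level
        return combined

    fwd = np.array([NORMAL, SUSPECT, ERROR,   MISSING])
    bwd = np.array([NORMAL, NORMAL,  SUSPECT, SUSPECT])
    # Per step: normal, suspect, error, suspect
    print(combine_directions(fwd, bwd))
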
/quality_control/algo/time_series.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import stats
3 | from .utils import get_config
4 |
5 |
6 | CONFIG = get_config()
7 |
8 |
9 | def _get_mean_and_mad(values):
10 | mean = np.median(values)
11 | mad = stats.median_abs_deviation(values)
12 | return mean, mad
13 |
14 |
15 | def _time_series_comparison(
16 | ts1,
17 | ts2,
18 | shift_step,
19 | gap_scale,
20 | default_mad,
21 | suspect_std_scale,
22 | min_mad=0.1,
23 | min_num=None,
24 | mask=None,
25 | ):
26 | """
27 | Perform time series comparison between two datasets and flag potential errors.
28 |
29 | This function compares two time series (ts1 and ts2) and identifies suspect and erroneous values
30 | based on their differences. It uses a robust statistical approach to handle outliers and
31 | can accommodate temporal shifts between the series.
32 |
33 | Parameters:
34 | -----------
35 | ts1 : array-like
36 | The primary time series to be checked.
37 | ts2 : array-like
38 | The reference time series for comparison.
39 | shift_step : int
40 | The maximum number of time steps to shift ts2 when comparing with ts1.
41 | gap_scale : float
42 | The scale factor applied to the median absolute deviation (MAD) to determine the gap threshold.
43 | default_mad : float
44 | The default MAD value to use when the calculated MAD is too small or when there are insufficient data points.
45 | suspect_std_scale : float
46 | The number of standard deviations from the mean to set the initial suspect threshold.
47 | min_mad : float, optional
48 | The minimum MAD value to calculate standard deviation, default is 0.1.
49 | min_num : int, optional
50 | The minimum number of valid data points required for robust statistics calculation.
51 | If the number of valid points is less than this, default values are used.
52 | mask : array-like, optional
53 | Boolean mask to select a subset of ts1 for calculating bounds. If None, all values are used.
54 | True for normal values, False for outliers
55 |
56 | Returns:
57 | --------
58 | flag : numpy.ndarray
59 |         1D array with the same length as ts1, containing flags for each value.
60 | """
61 | diff = ts1 - ts2
62 | values = diff[~np.isnan(diff)]
63 |
64 | # Apply mask to diff if provided
65 | if mask is not None:
66 | masked_diff = diff[mask]
67 | values = masked_diff[~np.isnan(masked_diff)]
68 | else:
69 | values = diff[~np.isnan(diff)]
70 |
71 | if values.size == 0:
72 | return np.full(diff.size, CONFIG["flag_missing"], dtype=np.int8)
73 |
74 | if min_num is not None and values.size < min_num:
75 | fixed_mean = 0
76 | mad = default_mad
77 | else:
78 | # An estimate of the Gaussian distribution of the data which is calculated by the median and MAD
79 | # so that it is robust to outliers
80 | fixed_mean, mad = _get_mean_and_mad(values)
81 | mad = max(min_mad, mad)
82 |
83 | # Get the suspect threshold by the distance to the mean in the unit of standard deviation
84 | # Reference: y = 0.1, scale = 1.67; y = 0.05, scale = 2.04; y = 0.01, scale = 2.72
85 | # If the standard deviation estimated by MAD is too small, a default value is used
86 | fixed_std = max(default_mad, mad) * 1.4826
87 | init_upper_bound = fixed_mean + fixed_std * suspect_std_scale
88 |     # For observations whose actual precision is integer, the upper and lower bounds are rounded outward to integers
89 | is_integer = np.nanmax(ts1 % 1) < 0.1 or np.nanmax(ts2 % 1) < 0.1
90 | if is_integer:
91 | init_upper_bound = np.ceil(init_upper_bound)
92 |     # Set the erroneous threshold by finding a gap larger than a multiple of the MAD
93 | larger_values = np.insert(np.sort(values[values > init_upper_bound]), 0, init_upper_bound)
94 |     # Try to find the first value after which the gap to the next value exceeds the threshold
95 | gap = mad * gap_scale
96 | large_gap = np.diff(larger_values) > gap
97 | # If a gap is not found, no erroneous threshold is set
98 | upper_bound = larger_values[np.argmax(large_gap)] if large_gap.any() else np.max(values)
99 | if is_integer:
100 | upper_bound = np.ceil(upper_bound)
101 |
102 | init_lower_bound = fixed_mean - fixed_std * suspect_std_scale
103 | if is_integer:
104 | init_lower_bound = np.floor(init_lower_bound)
105 | smaller_values = np.insert(np.sort(values[values < init_lower_bound])[::-1], 0, init_lower_bound)
106 | small_gap = np.diff(smaller_values) < -gap
107 | lower_bound = smaller_values[np.argmax(small_gap)] if small_gap.any() else np.min(values)
108 | if is_integer:
109 | lower_bound = np.floor(lower_bound)
110 |
111 | min_diff = diff
112 |     # If shift_step > 0, the values in ts1 are also compared to neighbouring values in ts2
113 |     # The minimum difference is kept so that a certain degree of temporal deviation is tolerated
114 | for shift in np.arange(-shift_step, shift_step+1, 1):
115 | diff_shifted = ts1 - np.roll(ts2, shift)
116 | if shift == 0:
117 | continue
118 | if shift > 0:
119 | diff_shifted[:shift] = np.nan
120 | elif shift < 0:
121 | diff_shifted[shift:] = np.nan
122 | min_diff = np.where(np.abs(diff_shifted - fixed_mean) < np.abs(min_diff - fixed_mean), diff_shifted, min_diff)
123 |
124 | flag = np.full_like(min_diff, CONFIG["flag_normal"], dtype=np.int8)
125 | flag[min_diff < init_lower_bound] = CONFIG["flag_suspect"]
126 | flag[min_diff > init_upper_bound] = CONFIG["flag_suspect"]
127 | flag[min_diff < lower_bound] = CONFIG["flag_error"]
128 | flag[min_diff > upper_bound] = CONFIG["flag_error"]
129 | flag[np.isnan(min_diff)] = CONFIG["flag_missing"]
130 | return flag
131 |
--------------------------------------------------------------------------------
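A condensed sketch of how `_time_series_comparison` builds its thresholds: the difference series is summarized by the median and MAD (std ≈ 1.4826 × MAD), the suspect bound is a multiple of that robust std, and the error bound is pushed out to the last value before the first large gap in the sorted exceedances. Parameter values and the synthetic data are illustrative; the temporal-shift handling and integer rounding are omitted.

    import numpy as np
    from scipy import stats

    def robust_bounds(diff, suspect_std_scale=2.0, gap_scale=3.0, min_mad=0.1):
        """Return (suspect_upper, error_upper) thresholds for the difference series."""
        values = diff[~np.isnan(diff)]
        center = np.median(values)
        mad = max(min_mad, stats.median_abs_deviation(values))
        std = mad * 1.4826                     # Gaussian-consistent std estimate
        suspect_upper = center + std * suspect_std_scale
        # Error bound: the last value before the sorted exceedances jump by more than the gap
        exceed = np.insert(np.sort(values[values > suspect_upper]), 0, suspect_upper)
        gaps = np.diff(exceed) > mad * gap_scale
        error_upper = exceed[np.argmax(gaps)] if gaps.any() else values.max()
        return suspect_upper, error_upper

    # A mostly Gaussian difference series with two clear outliers
    diff = np.concatenate([np.random.default_rng(0).normal(0, 1, 500), [8.0, 9.0]])
    print(robust_bounds(diff))
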
/quality_control/algo/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | import xarray as xr
5 | import yaml
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def configure_logging(verbose=1):
12 | verbose_levels = {
13 | 0: logging.WARNING,
14 | 1: logging.INFO,
15 | 2: logging.DEBUG,
16 | 3: logging.NOTSET
17 | }
18 | if verbose not in verbose_levels:
19 | verbose = 1
20 | root_logger = logging.getLogger()
21 | root_logger.setLevel(verbose_levels[verbose])
22 | handler = logging.StreamHandler()
23 | handler.setFormatter(logging.Formatter(
24 | "[%(asctime)s] [PID=%(process)d] "
25 | "[%(levelname)s %(filename)s:%(lineno)d] %(message)s"))
26 | handler.setLevel(verbose_levels[verbose])
27 | root_logger.addHandler(handler)
28 |
29 |
30 | class Config:
31 | _instance = None
32 |
33 | def __new__(cls, config_path=None):
34 | if cls._instance is None:
35 | if config_path is None:
36 | config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
37 | cls._instance = super(Config, cls).__new__(cls)
38 | cls._instance.__init__(config_path)
39 | return cls._instance
40 |
41 | def __init__(self, config_path):
42 | if not hasattr(self, 'config'):
43 | self._load_config(config_path)
44 |
45 | def _load_config(self, config_path):
46 | with open(config_path, "r", encoding="utf-8") as file:
47 | self.config = yaml.safe_load(file)
48 |
49 | def get(self, key):
50 | if key not in self.config:
51 | raise KeyError(f"Key '{key}' not found in configuration")
52 | return self.config[key]
53 |
54 |
55 | def get_config(config_path=None):
56 | return Config(config_path).config
57 |
58 |
59 | def intra_station_check(
60 | *dataarrays,
61 | qc_func=lambda da: da,
62 | input_core_dims=None,
63 | output_core_dims=None,
64 | kwargs=None,
65 | ):
66 | """
67 | A wrapper function to apply the quality control functions to each station in a DataArray
68 | Multiprocessing is implemented by Dask
69 | Parameters:
70 | dataarrays: DataArrays to be checked, with more auxiliary DataArrays if needed
71 | qc_func: quality control function to be applied
72 |         input_core_dims: core dimensions (passed through to the function) of each input DataArray
73 |         output_core_dims: core dimensions (returned by the function) of each function result
74 | kwargs: keyword arguments for the quality control function
75 | Return:
76 | flag: DataArray with the same shape as the first input DataArray
77 | """
78 | dataarrays_chunked = []
79 | for item in dataarrays:
80 | if isinstance(item, xr.DataArray):
81 | dataarrays_chunked.append(
82 | item.chunk({k: 100 if k == "station" else -1 for k in item.dims})
83 | )
84 | else:
85 | dataarrays_chunked.append(item)
86 | if input_core_dims is None:
87 | input_core_dims = [["time"]]
88 | if output_core_dims is None:
89 | output_core_dims = [["time"]]
90 | if kwargs is None:
91 | kwargs = {}
92 | flag = xr.apply_ufunc(
93 | qc_func,
94 | *dataarrays_chunked,
95 | input_core_dims=input_core_dims,
96 | output_core_dims=output_core_dims,
97 | kwargs=kwargs,
98 | vectorize=True,
99 | dask="parallelized",
100 | output_dtypes=[np.int8],
101 | ).compute(scheduler="processes")
102 | return flag
103 |
104 |
105 | CONFIG = get_config()
106 |
107 |
108 | def merge_flags(flags, priority=None):
109 | """
110 | Merge flags from different quality control functions in the order of priority
111 |     Flag types earlier in the priority list are overwritten by those later in the list
112 | """
113 | ret = xr.full_like(flags[0], CONFIG["flag_missing"], dtype=np.int8)
114 | if priority is None:
115 | priority = ["normal", "suspect", "error"]
116 | for flag_type in priority:
117 | for item in flags:
118 | ret = xr.where(item == CONFIG[f"flag_{flag_type}"], CONFIG[f"flag_{flag_type}"], ret)
119 | return ret
120 |
121 |
122 | def quality_control_statistics(data, flag):
123 | num_valid = data.notnull().sum().item()
124 | num_normal = (flag == CONFIG["flag_normal"]).sum().item()
125 | num_suspect = (flag == CONFIG["flag_suspect"]).sum().item()
126 | num_error = (flag == CONFIG["flag_error"]).sum().item()
127 | num_checked = num_normal + num_suspect + num_error
128 | logger.debug(f"{num_valid / data.size:.5%} of the data are valid")
129 | logger.debug(f"{num_checked / num_valid:.5%} of the valid data are checked")
130 | logger.debug(
131 | "%s/%s/%s of the checked data are flagged as normal/suspect/erroneous",
132 | f"{num_normal / num_checked:.5%}",
133 | f"{num_suspect / num_checked:.5%}",
134 | f"{num_error / num_checked:.5%}"
135 | )
136 | return num_valid, num_normal, num_suspect, num_error
137 |
--------------------------------------------------------------------------------
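A toy numpy version of the `merge_flags` priority semantics: flag types later in the priority list overwrite earlier ones, so the final priority used in `quality_control.py` (["normal", "suspect", "error"]) lets a single error verdict win, while the reversed order lets any normal verdict win. The flag codes are placeholders.

    import numpy as np

    NORMAL, SUSPECT, ERROR, MISSING = 0, 1, 2, -1   # hypothetical flag codes
    CODES = {"normal": NORMAL, "suspect": SUSPECT, "error": ERROR}

    def merge_flags_simple(flags, priority=("normal", "suspect", "error")):
        """Numpy sketch of the merge: later priority levels overwrite earlier ones."""
        merged = np.full_like(flags[0], MISSING)
        for level in priority:
            for item in flags:
                merged = np.where(item == CODES[level], CODES[level], merged)
        return merged

    check_a = np.array([NORMAL, NORMAL,  ERROR])
    check_b = np.array([NORMAL, SUSPECT, NORMAL])
    # Error last: the single error verdict wins -> normal, suspect, error
    print(merge_flags_simple([check_a, check_b]))
    # Normal last: any normal verdict wins -> all normal
    print(merge_flags_simple([check_a, check_b], priority=("error", "suspect", "normal")))
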
/quality_control/download_ISD.py:
--------------------------------------------------------------------------------
1 | """
2 | Download ISD data from NOAA
3 | All available stations on the server will be downloaded
4 | """
5 |
6 | import argparse
7 | import logging
8 | import multiprocessing
9 | from functools import partial
10 | from pathlib import Path
11 |
12 | import pandas as pd
13 | import requests
14 | from bs4 import BeautifulSoup
15 | from tqdm import tqdm
16 | from algo.utils import configure_logging
17 |
18 |
19 | URL_DATA = "https://www.ncei.noaa.gov/data/global-hourly/access/"
20 |
21 | # Disable the logging from urllib3
22 | logging.getLogger("urllib3.connectionpool").setLevel(logging.CRITICAL)
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | def download_file(station, url, output_dir, overwrite=False):
27 |     output_path = Path(output_dir) / f"{station}.csv"
28 |     if output_path.exists() and not overwrite:
29 |         return True
30 |     response = requests.get(url + f"{station}.csv", timeout=30)
31 | if response.status_code == 200:
32 | with open(output_path, "wb") as file:
33 | file.write(response.content)
34 | return True
35 | return False
36 |
37 |
38 | def download_ISD(year, station_list, output_dir, num_proc, overwrite=False):
39 | """
40 | Download the data from the source for the stations in station_list.
41 | """
42 | logger.info("Start downloading")
43 | output_dir = Path(output_dir) / str(year)
44 | if not output_dir.exists():
45 | output_dir.mkdir(parents=True, exist_ok=True)
46 |
47 | url = URL_DATA + f"/{year}/"
48 |
49 | if num_proc == 1:
50 |         results = [download_file(item, url, output_dir, overwrite)
51 |                    for item in tqdm(station_list, desc="Downloading files")]
52 | else:
53 | func = partial(download_file, url=url, output_dir=output_dir, overwrite=overwrite)
54 | with multiprocessing.Pool(num_proc) as pool:
55 | results = list(tqdm(pool.imap(func, station_list), total=len(station_list), desc="Downloading files"))
56 |
57 | successful_downloads = sum(results)
58 | logger.info(f"Successfully downloaded {successful_downloads} out of {len(station_list)} files")
59 |
60 |
61 | def main():
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument(
64 | "-o", "--output-dir", type=str, required=True, help="Parent directory to save the downloaded data"
65 | )
66 | parser.add_argument("-y", "--year", type=int, default=pd.Timestamp.now().year, help="Year to download")
67 | parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files")
68 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)")
69 | parser.add_argument("--num-proc", type=int, default=16, help="Number of parallel processes")
70 | args = parser.parse_args()
71 | configure_logging(args.verbose)
72 |
73 | response = requests.get(URL_DATA + f"/{args.year}/", timeout=30)
74 | soup = BeautifulSoup(response.text, "html.parser")
75 | file_list = [link.get("href") for link in soup.find_all("a") if link.get("href").endswith(".csv")]
76 | station_list = [item.split(".")[0] for item in file_list]
77 | logger.info(f"Found {len(station_list)} stations for year {args.year}")
78 | download_ISD(args.year, station_list, args.output_dir, args.num_proc, args.overwrite)
79 |
80 |
81 | if __name__ == "__main__":
82 | main()
83 |
--------------------------------------------------------------------------------
/quality_control/quality_control.py:
--------------------------------------------------------------------------------
1 | """
2 | This script performs quality control checks on weather observation data using various algorithms.
3 | It flags and filters out erroneous or suspicious data points.
4 |
5 | Usage:
6 | python quality_control.py --obs-path OBS_PATH --rnl-path RNL_PATH --similarity-path SIM_PATH \
7 | --output-path OUTPUT_PATH [--config-path CONFIG_PATH] [--verbose VERBOSE]
8 |
9 | Input Files:
10 | - Observation data (--obs-path): NetCDF file generated at the previous step by `station_merging.py`.
11 |       You can also use your own observation data. The dimensions should be 'station' and 'time' and
12 |       include all the hours of a single year,
13 | e.g., the dimension of WeatherReal-ISD in 2023: {'station': 13297, 'time': 8760}
14 | Supported variables: t, td, sp, msl, c, ws, wd, ra1, ra3, ra6, ra12, ra24.
15 | - Reanalysis data (--rnl-path): NetCDF file containing reanalysis data in the same format as obs.
16 |       You can refer to `convert_grid_to_point` in `evaluation/forecast_reformat_catalog.py` for interpolating gridded data to stations.
17 | - Similarity matrix (--similarity-path): File containing station similarity information.
18 | It is also generated at the previous step by `station_merging.py`.
19 | You can also generate your own similarity matrix with `calc_similarity` in `station_merging.py`.
20 | - Config file (--config-path): YAML file containing algorithm parameters.
21 | You can refer to the default settings - `config.yaml` in the `algo` directory as an example.
22 | Output:
23 | - Quality controlled observation data (--output-path): NetCDF file with erroneous data removed.
24 |
25 | Please refer to the WeatherReal paper for more details.
26 | """
27 |
28 |
29 | import argparse
30 | import logging
31 | import os
32 | import numpy as np
33 | import pandas as pd
34 | import xarray as xr
35 | from algo import (
36 | record_extreme,
37 | cluster,
38 | distributional_gap,
39 | neighbouring_stations,
40 | spike,
41 | persistence,
42 | cross_variable,
43 | refinement,
44 | diurnal_cycle,
45 | fine_tuning,
46 | )
47 | from algo.utils import merge_flags, Config, get_config, configure_logging
48 |
49 |
50 | logger = logging.getLogger(__name__)
51 | CONFIG = get_config()
52 |
53 |
54 | def load_data(obs_path, rnl_path):
55 | """
56 | Load observation and reanalysis data
57 | The reanalysis data should be in the same format as observation data (interpolated to the same stations)
58 | """
59 | logger.info("Loading observation and reanalysis data")
60 | obs = xr.load_dataset(obs_path)
61 | year = obs["time"].dt.year.values[0]
62 | if not (obs["time"].dt.year == year).all():
63 | raise ValueError("Data contains multiple years, which is not supported yet")
64 |
65 | full_hours = pd.date_range(
66 | start=pd.Timestamp(f"{year}-01-01"), end=pd.Timestamp(f"{year}-12-31 23:00:00"), freq='h')
67 | if obs["time"].size != full_hours.size or (obs["time"] != full_hours).any():
68 | logger.warning("Reindexing observation data to match full hours in the year")
69 | obs = obs.reindex(time=full_hours)
70 |
71 | # Please prepare the Reanalysis data in the same format as obs
72 | rnl = xr.load_dataset(rnl_path)
73 | if obs.sizes != rnl.sizes:
74 | raise ValueError("The sizes of obs and rnl are different")
75 | return obs, rnl
76 |
77 |
78 | def cross_variable_check(obs):
79 | flag_cross = xr.Dataset()
80 | # Super-saturation check
81 | flag_cross["t"] = cross_variable.supersaturation(obs["t"], obs["td"])
82 | flag_cross["td"] = flag_cross["t"].copy()
83 | # Wind consistency check
84 | flag_cross["ws"] = cross_variable.wind_consistency(obs["ws"], obs["wd"])
85 | flag_cross["wd"] = flag_cross["ws"].copy()
86 | flag_ra = cross_variable.ra_consistency(obs[["ra1", "ra3", "ra6", "ra12", "ra24"]])
87 | for varname in ["ra1", "ra3", "ra6", "ra12", "ra24"]:
88 | flag_cross[varname] = flag_ra[varname].copy()
89 | return flag_cross
90 |
91 |
92 | def quality_control(obs, rnl, f_similarity):
93 | varlist = obs.data_vars.keys()
94 | result = obs.copy()
95 |
96 | # Record extreme check
97 | flag_extreme = xr.Dataset()
98 | for varname in CONFIG["record"]:
99 | if varname not in varlist:
100 | continue
101 | logger.info(f"Record extreme check for {varname}...")
102 | flag_extreme[varname] = record_extreme.run(result[varname], varname)
103 | # For extreme value check, outliers are directly removed
104 | result[varname] = result[varname].where(flag_extreme[varname] != CONFIG["flag_error"])
105 |
106 | # Cluster deviation check
107 | flag_cluster = xr.Dataset()
108 | for varname in CONFIG["cluster"]:
109 | if varname not in varlist:
110 | continue
111 | logger.info(f"Cluster deviation check for {varname}...")
112 | flag_cluster[varname] = cluster.run(result[varname], rnl[varname], varname)
113 |
114 | # Distributional gap check
115 | flag_distribution = xr.Dataset()
116 | for varname in CONFIG["distribution"]:
117 | if varname not in varlist:
118 | continue
119 | logger.info(f"Distributional gap check for {varname}...")
120 | # Mask from cluster deviation check is used to exclude abnormal data in the following distributional gap check
121 | mask = flag_cluster[varname] == CONFIG["flag_normal"]
122 | flag_distribution[varname] = distributional_gap.run(result[varname], rnl[varname], varname, mask)
123 |
124 | # Neighbouring station check
125 | flag_neighbour = xr.Dataset()
126 | for varname in CONFIG["neighbouring"]:
127 | if varname not in varlist:
128 | continue
129 | logger.info(f"Neighbouring station check for {varname}...")
130 | flag_neighbour[varname] = neighbouring_stations.run(result[varname], f_similarity, varname)
131 |
132 | # Spike check
133 | flag_dis_neigh = xr.Dataset()
134 | flag_spike = xr.Dataset()
135 | for varname in CONFIG["spike"]:
136 | if varname not in varlist:
137 | continue
138 | logger.info(f"Spike check for {varname}...")
139 | # Merge flags from distributional gap and neighbouring station check for spike and persistence check
140 | flag_dis_neigh[varname] = merge_flags(
141 | [flag_distribution[varname], flag_neighbour[varname]], priority=["error", "suspect", "normal"]
142 | )
143 | flag_spike[varname] = spike.run(result[varname], flag_dis_neigh[varname], varname)
144 |
145 | # Persistence check
146 | flag_persistence = xr.Dataset()
147 | for varname in CONFIG["persistence"]:
148 | if varname not in varlist:
149 | continue
150 | logger.info(f"Persistence check for {varname}...")
151 | # Some variables are not checked by distributional gap or neighbouring station check
152 | if varname not in flag_dis_neigh:
153 | flag_dis_neigh_cur = xr.full_like(result[varname], CONFIG["flag_missing"], dtype=np.int8)
154 | else:
155 | flag_dis_neigh_cur = flag_dis_neigh[varname]
156 | flag_persistence[varname] = persistence.run(result[varname], flag_dis_neigh_cur, varname)
157 |
158 | # Cross variable check
159 | flag_cross = cross_variable_check(result)
160 |
161 | # Merge all flags
162 | flags = xr.Dataset()
163 | for varname in varlist:
164 | logger.info(f"Merging flags for {varname}...")
165 | merge_list = [
166 | item[varname]
167 | for item in [flag_extreme, flag_dis_neigh, flag_spike, flag_persistence, flag_cross]
168 | if varname in item
169 | ]
170 | flags[varname] = merge_flags(merge_list, priority=["normal", "suspect", "error"])
171 |
172 | # Flag refinement
173 | flags_refined = xr.Dataset()
174 | for varname in varlist:
175 | if varname not in CONFIG["refinement"].keys():
176 | flags_refined[varname] = flags[varname].copy()
177 | else:
178 | logger.info(f"Flag refinement for {varname}...")
179 | flags_refined[varname] = refinement.run(result[varname], flags[varname], varname)
180 | if varname == "t":
181 | flags_refined[varname] = diurnal_cycle.run(result[varname], flags_refined[varname])
182 |
183 | # Fine-tuning the flags
184 | flags_final = xr.Dataset()
185 | for varname in varlist:
186 | logger.info(f"Fine-tuning (upgrade suspect) flags for {varname}...")
187 | flags_final[varname] = fine_tuning.run(flags_refined[varname], obs[varname])
188 |
189 | flags = {
190 | "flag_extreme": flag_extreme,
191 | "flag_cluster": flag_cluster,
192 | "flag_distribution": flag_distribution,
193 | "flag_neighbour": flag_neighbour,
194 | "flag_spike": flag_spike,
195 | "flag_persistence": flag_persistence,
196 | "flag_cross": flag_cross,
197 | "flags_refined": flags_refined,
198 | "flags_final": flags_final,
199 | }
200 |
201 | return flags
202 |
203 |
204 | def main(args):
205 | obs, rnl = load_data(args.obs_path, args.rnl_path)
206 | flags = quality_control(obs, rnl, args.similarity_path)
207 | if args.output_flags_dir:
208 | logger.info(f"Saving flags to {args.output_flags_dir}")
209 | for algo_name, flag_spec in flags.items():
210 | flag_spec.to_netcdf(os.path.join(args.output_flags_dir, f"{algo_name}.nc"))
211 | flags_final = flags["flags_final"]
212 | for varname in obs.data_vars.keys():
213 | obs[varname] = obs[varname].where(flags_final[varname] != CONFIG["flag_error"])
214 | obs.to_netcdf(args.output_path)
215 | logger.info(f"Quality control finished. The results are saved to {args.output_path}")
216 |
217 |
218 | if __name__ == "__main__":
219 | parser = argparse.ArgumentParser()
220 | parser.add_argument(
221 | "--obs-path", type=str, required=True, help="Data path of observation to be quality controlled"
222 | )
223 | parser.add_argument(
224 | "--rnl-path", type=str, required=True, help="Data path of reanalysis data to be used for quality control"
225 | )
226 | parser.add_argument("--similarity-path", type=str, required=True, help="Data path of similarity matrix")
227 | parser.add_argument("--output-path", type=str, required=True, help="Data path of output data")
228 | parser.add_argument("--output-flags-dir", type=str, help="If specified, flags will also be saved")
229 | parser.add_argument(
230 | "--config-path", type=str, help="Path to the configuration file, default is config.yaml in the algo directory"
231 | )
232 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)")
233 | parsed_args = parser.parse_args()
234 | configure_logging(parsed_args.verbose)
235 | Config(parsed_args.config_path)
236 | main(parsed_args)
237 |
--------------------------------------------------------------------------------
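A tiny xarray illustration of the final masking step in `main` above: values whose merged flag equals the error code are replaced with NaN before the cleaned dataset is written out, while suspect values are kept. The dimensions, values and error code below are placeholders.

    import numpy as np
    import xarray as xr

    ERROR = 2   # hypothetical error code

    obs = xr.DataArray([[20.1, 21.3, 55.0], [18.7, 19.2, 19.5]],
                       dims=("station", "time"), name="t")
    flag = xr.DataArray([[0, 0, ERROR], [0, 1, 0]], dims=("station", "time"))

    cleaned = obs.where(flag != ERROR)   # 55.0 becomes NaN, the suspect value is kept
    print(cleaned.values)
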
/quality_control/raw_ISD_to_hourly.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used for parsing ISD files and converting them to hourly data
3 | 1. Parse corresponding columns of each variable
4 | 2. Aggregate / reorganize columns if needed
5 | 3. Simple unit conversion
6 | 4. Aggregate to hourly data. Rows that represent the same hour are merged by rules
7 | 5. No quality control is applied except removing records with original erroneous flags
8 | The outputs are still saved as csv files
9 | """
10 |
11 | import argparse
12 | import multiprocessing
13 | import os
14 | from functools import partial
15 |
16 | import pandas as pd
17 | from tqdm import tqdm
18 |
19 |
20 | # Quality code for erroneous values flagged by ISD
21 | # To collect as much data as possible and avoid excluding values flagged with unknown codes,
22 | # we choose the strategy "only values marked with known error tags will be rejected"
23 | # instead of "only values marked with known correct tags will be accepted"
24 | ERRONEOUS_FLAGS = ["3", "7"]
25 |
26 |
27 | def parse_temperature_col(data):
28 | """
29 | Process temperature and dew point temperature columns
30 | TMP/DEW column format: -0100,1
31 | Steps:
32 | 1. Set values flagged as erroneous/missing to NaN
33 | 2. Convert to float in Celsius
34 | """
35 | if "TMP" in data.columns and data["TMP"].notnull().any():
36 | data[["t", "t_qc"]] = data["TMP"].str.split(",", expand=True)
37 | data["t"] = data["t"].where(data["t"] != "+9999", pd.NA)
38 | data["t"] = data["t"].where(~data["t_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
39 | # Scaling factor: 10
40 | data["t"] = data["t"].astype("Float32") / 10
41 | if "DEW" in data.columns and data["DEW"].notnull().any():
42 | data[["td", "td_qc"]] = data["DEW"].str.split(",", expand=True)
43 | data["td"] = data["td"].where(data["td"] != "+9999", pd.NA).astype("Float32") / 10
44 | data["td"] = data["td"].where(~data["td_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
45 | data = data.drop(columns=["TMP", "DEW", "t_qc", "td_qc"], errors="ignore")
46 | return data
47 |
48 |
49 | def parse_wind_col(data):
50 | """
51 | Process wind speed and direction column
52 | WND column format: 267,1,N,0142,1
53 | N indicates normal (other values include Beaufort, Calm, etc.). Not used currently
54 | Steps:
55 | 1. Set values flagged as erroneous/missing to NaN
56 | Note that if one of ws or wd is missing, both are set to NaN
57 | Exception: If ws is 0 and wd is missing, wd is set to 0 (calm)
58 | 2. Convert wd to integer and ws to float in m/s
59 | """
60 | if "WND" in data.columns and data["WND"].notnull().any():
61 | data[["wd", "wd_qc", "wt", "ws", "ws_qc"]] = data["WND"].str.split(",", expand=True)
62 | # First, set wd to 0 if ws is valid 0 and wd is missing
63 | calm = (data["ws"] == "0000") & (~data["ws_qc"].isin(ERRONEOUS_FLAGS)) & (data["wd"] == "999")
64 | data.loc[calm, "wd"] = "000"
65 | data.loc[calm, "wd_qc"] = "1"
66 | # After that, if one of ws or wd is missing/erroneous, both are set to NaN
67 | non_missing = (data["wd"] != "999") & (data["ws"] != "9999")
68 | non_error = (~data["wd_qc"].isin(ERRONEOUS_FLAGS)) & (~data["ws_qc"].isin(ERRONEOUS_FLAGS))
69 | valid = non_missing & non_error
70 | data["wd"] = data["wd"].where(valid, pd.NA)
71 | data["ws"] = data["ws"].where(valid, pd.NA)
72 | data["wd"] = data["wd"].astype("Int16")
73 | # Scaling factor: 10
74 | data["ws"] = data["ws"].astype("Float32") / 10
75 | data = data.drop(columns=["WND", "wd_qc", "wt", "ws_qc"], errors="ignore")
76 | return data
77 |
78 |
79 | def parse_cloud_col(data):
80 | """
81 | Process total cloud cover column
82 |     All known columns including GA1-6, GD1-6, GF1 and GG1-6 are parsed and the maximum value among them is selected
83 | 1. GA1-6 column format: 07,1,+00800,1,06,1
84 | The 1st and 2nd items are c and its quality
85 | 2. GD1-6 column format: 3,99,1,+05182,9,9
86 |     The 1st item is cloud cover on a 0-4 scale and is converted to octas by multiplying by 2
87 |     The 2nd and 3rd items are c in octas and its quality
88 | 3. GF1 column format: 07,99,1,07,1,99,9,01000,1,99,9,99,9
89 | The 1st and 3rd items are total coverage and its quality
90 | 4. GG1-6 column format: 01,1,01200,1,06,1,99,9
91 | The 1st and 2nd items are c and its quality
92 | Cloud/sky-condition related data is very complex and worth further investigation
93 | Steps:
94 | 1. Set values flagged as erroneous/missing to NaN
95 | 2. Select the maximum value of all columns
96 | """
97 | num = 0
98 | for group in ["GA", "GG"]:
99 | for col in [f"{group}{i}" for i in range(1, 7)]:
100 | if col in data.columns and data[col].notnull().any():
101 | data[[f"c{num}", "c_qc", "remain"]] = data[col].str.split(",", n=2, expand=True)
102 | # 99 will be removed later
103 | data[f"c{num}"] = data[f"c{num}"].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
104 | data[f"c{num}"] = data[f"c{num}"].astype("Int16")
105 | num += 1
106 | else:
107 | break
108 | for col in [f"GD{i}" for i in range(1, 7)]:
109 | if col in data.columns and data[col].notnull().any():
110 | data[[f"c{num}", f"c{num+1}", "c_qc", "remain"]] = data[col].str.split(",", n=3, expand=True)
111 | c_cols = [f"c{num}", f"c{num+1}"]
112 | data[c_cols] = data[c_cols].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
113 | data[c_cols] = data[c_cols].astype("Int16")
114 |             # The first item is 5-level cloud cover and is converted to octas by multiplying by 2
115 | data[f"c{num}"] = data[f"c{num}"] * 2
116 | num += 2
117 | else:
118 | break
119 | if "GF1" in data.columns and data["GF1"].notnull().any():
120 | data[[f"c{num}", "opa", "c_qc", "remain"]] = data["GF1"].str.split(",", n=3, expand=True)
121 | data[f"c{num}"] = data[f"c{num}"].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
122 | data[f"c{num}"] = data[f"c{num}"].astype("Int16")
123 | num += 1
124 | c_cols = [f"c{i}" for i in range(num)]
125 |     # Mask all values larger than 8 (e.g., 99 for missing) so that they do not affect the maximum below
126 | data[c_cols] = data[c_cols].where(data[c_cols] <= 8, pd.NA)
127 | # Maximum value of all columns is selected to represent the total cloud cover
128 | data["c"] = data[c_cols].max(axis=1)
129 | data = data.drop(
130 | columns=[
131 | "GF1",
132 | *[f"GA{i}" for i in range(1, 7)],
133 | *[f"GG{i}" for i in range(1, 7)],
134 | *[f"GD{i}" for i in range(1, 7)],
135 | *[f"c{i}" for i in range(num)],
136 | "c_5",
137 | "opa",
138 | "c_qc",
139 | "remain",
140 | ],
141 | errors="ignore",
142 | )
143 | return data
144 |
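# Illustrative sketch (comment only, not executed by this script): with the docstring examples
# above, GA1 = "07,1,+00800,1,06,1" yields 7 octas and GD1 = "3,99,1,+05182,9,9" yields
# 3 * 2 = 6 octas (its second value 99 is masked as > 8), so the maximum, c = 7, is kept.
#
#   parse_cloud_col(pd.DataFrame({"GA1": ["07,1,+00800,1,06,1"],
#                                 "GD1": ["3,99,1,+05182,9,9"]}))["c"].iloc[0]   # -> 7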
145 |
146 | def parse_surface_pressure_col(data):
147 | """
148 | Process surface pressure (station-level pressure) column
149 | Currently MA1 column is used. Column format: 99999,9,09713,1
150 | The 3rd and 4th items are station pressure and its quality
151 | The 1st and 2nd items are altimeter setting and its quality which are not used currently
152 | Steps:
153 | 1. Set values flagged as erroneous/missing to NaN
154 | 2. Convert to float in hPa
155 | """
156 | if "MA1" in data.columns and data["MA1"].notnull().any():
157 | data[["MA1_remain", "sp", "sp_qc"]] = data["MA1"].str.rsplit(",", n=2, expand=True)
158 | data["sp"] = data["sp"].where(data["sp"] != "99999", pd.NA)
159 | data["sp"] = data["sp"].where(~data["sp_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
160 | # Scaling factor: 10
161 | data["sp"] = data["sp"].astype("Float32") / 10
162 | data = data.drop(columns=["MA1", "MA1_remain", "sp_qc"], errors="ignore")
163 | return data
164 |
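# Illustrative sketch (comment only, not executed by this script): the docstring example
# "99999,9,09713,1" rsplits into an unused altimeter part "99999,9", the station pressure
# "09713" and its flag "1", giving sp = 9713 / 10 = 971.3 hPa.
#
#   parse_surface_pressure_col(pd.DataFrame({"MA1": ["99999,9,09713,1"]}))["sp"].iloc[0]   # -> ~971.3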
165 |
166 | def parse_sea_level_pressure_col(data):
167 | """
168 | Process mean sea level pressure column
169 |     SLP column format: 09725,1
170 | Steps:
171 | 1. Set values flagged as erroneous/missing to NaN
172 | 2. Convert to float in hPa
173 | """
174 | if "SLP" in data.columns and data["SLP"].notnull().any():
175 | data[["msl", "msl_qc"]] = data["SLP"].str.rsplit(",", expand=True)
176 | data["msl"] = data["msl"].where(data["msl"] != "99999", pd.NA)
177 | data["msl"] = data["msl"].where(~data["msl_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
178 | # Scaling factor: 10
179 | data["msl"] = data["msl"].astype("Float32") / 10
180 | data = data.drop(columns=["SLP", "msl_qc"], errors="ignore")
181 | return data
182 |
183 |
184 | def parse_single_precipitation_col(data, col):
185 | """
186 | Parse one of the precipitation columns AA1-4
187 | """
188 | if data[col].isnull().all():
189 | return pd.DataFrame()
190 | datacol = data[[col]].copy()
191 | # Split the column to get the period first
192 | datacol[["period", f"{col}_remain"]] = datacol[col].str.split(",", n=1, expand=True)
193 | # Remove weird periods to avoid unexpected errors
194 | datacol = datacol[datacol["period"].isin(["01", "03", "06", "12", "24"])]
195 | if len(datacol) == 0:
196 | return pd.DataFrame()
197 | # Set the period as index and unstack so that different periods are converted to different columns
198 | datacol = datacol.set_index("period", append=True)[f"{col}_remain"]
199 | datacol = datacol.unstack("period")
200 | # Rename the columns according to the period, e.g., 03 -> ra3
201 | datacol.columns = [f"ra{item.lstrip('0')}" for item in datacol.columns]
202 | # Further split the remaining sections
203 | for var in datacol.columns:
204 | datacol[[var, f"{var}_cond", f"{var}_qc"]] = datacol[var].str.split(",", expand=True)
205 | return datacol
206 |
207 |
208 | def parse_precipitation_col(data):
209 | """
210 | Process precipitation columns
211 | Currently AA1-4 columns are used. Column format: 24,0073,3,1
212 | The items are period, depth, condition, quality. Condition is not used currently
213 |     It is more complex than other variables as values for different periods are stored in the same columns
214 |     They need to be separated into different columns
215 | Steps:
216 | 1. Separate and recombine columns by period
217 | 2. Set values flagged as erroneous/missing to NaN
218 | 3. Convert to float in mm
219 | """
220 | for col in ["AA1", "AA2", "AA3", "AA4"]:
221 | if col in data.columns:
222 | datacol = parse_single_precipitation_col(data, col)
223 |             # The same variable (e.g., ra24) may be stored in different original columns
224 |             # Use combine_first so that the same variable is merged into a single column
225 | data = data.combine_first(datacol)
226 | else:
227 | # Assuming that the remaining columns are also not present
228 | break
229 |     # Mask missing values and values whose quality flag (3/7) indicates an erroneous record
230 | for col in [item for item in data.columns if item.startswith("ra") and item[2:].isdigit()]:
231 | data[col] = data[col].where(data[col] != "9999", pd.NA)
232 | data[col] = data[col].where(~data[f"{col}_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
233 | data[col] = data[col].astype("Float32") / 10
234 | data = data.drop(columns=[f"{col}_cond", f"{col}_qc"])
235 | data = data.drop(columns=["AA1", "AA2", "AA3", "AA4"], errors="ignore")
236 | return data
237 |
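# Illustrative sketch (comment only, not executed by this script): the docstring example
# AA1 = "24,0073,3,1" is a 24-hour accumulation, so it is routed to a new "ra24" column;
# the quality flag "1" is accepted, the condition code "3" is ignored, and the depth becomes
# 73 / 10 = 7.3 mm.
#
#   parse_precipitation_col(pd.DataFrame({"AA1": ["24,0073,3,1"]}))["ra24"].iloc[0]   # -> ~7.3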
238 |
239 | def parse_single_file(fpath, fpath_last_year):
240 | """
241 | Parse columns of each variable in a single ISD file
242 | """
243 | # Gxn for cloud cover, MA1 for surface pressure, AAn for precipitation
244 | cols_var = [
245 | "TMP",
246 | "DEW",
247 | "WND",
248 | "SLP",
249 | "MA1",
250 | "AA1",
251 | "AA2",
252 | "AA3",
253 | "AA4",
254 | "GF1",
255 | *[f"GA{i}" for i in range(1, 7)],
256 | *[f"GD{i}" for i in range(1, 7)],
257 | *[f"GG{i}" for i in range(1, 7)],
258 | ]
259 | cols = ["DATE"] + list(cols_var)
260 |
261 | def _load_csv(fpath):
262 | return pd.read_csv(fpath, parse_dates=["DATE"], usecols=lambda c: c in set(cols), low_memory=False)
263 |
264 | data = _load_csv(fpath)
265 | if fpath_last_year is not None and os.path.exists(fpath_last_year):
266 | data_last_year = _load_csv(fpath_last_year)
267 |         # Keep only the last day of the previous year for better hourly aggregation
268 | data_last_year = data_last_year.loc[
269 | (data_last_year["DATE"].dt.month == 12) & (data_last_year["DATE"].dt.day == 31)
270 | ]
271 | data = pd.concat([data_last_year, data], ignore_index=True)
272 | data = data[[item for item in cols if item in data.columns]]
273 |
274 | data = parse_temperature_col(data)
275 | data = parse_wind_col(data)
276 | data = parse_cloud_col(data)
277 | data = parse_surface_pressure_col(data)
278 | data = parse_sea_level_pressure_col(data)
279 | data = parse_precipitation_col(data)
280 |
281 | data = data.rename(columns={"DATE": "time"})
282 | value_cols = [col for col in data.columns if col != "time"]
283 | # drop all-NaN rows
284 | data = data[["time"] + value_cols].sort_values("time").dropna(how="all", subset=value_cols)
285 | return data
286 |
287 |
288 | def aggregate_to_hourly(data):
289 | """
290 |     Aggregate rows that represent the same hour into one row
291 |     Order the rows of each hour by their difference from the top of the hour,
292 |     then bfill within each hour so the nearest valid value is kept for each variable
293 |     Specifically, for t/td, avoid combining values taken from two different records
294 | """
295 | data["hour"] = data["time"].dt.round("h")
296 | # Sort data by difference from the top of the hour so that bfill can be applied
297 | # to give priority to the closer records
298 | data["hour_dist"] = (data["time"] - data["hour"]).dt.total_seconds().abs() // 60
299 | data = data.sort_values(["hour", "hour_dist"])
300 |
301 | if data["hour"].duplicated().any():
302 |         # Construct a new column of (t, td) tuples. Values are not NaN only when both of them are valid
303 | data["t_td"] = data.apply(
304 | lambda row: (row["t"], row["td"]) if row[["t", "td"]].notnull().all() else pd.NA, axis=1
305 | )
306 |         # For each hour, backfill NaNs and keep the first row, i.e. the record closest to the top of the hour
307 | data = data.groupby("hour").apply(lambda df: df.bfill().iloc[0], include_groups=False)
308 |
309 |         # 1st priority: for hours that have both valid t and td originally (decided by t_td),
310 |         # fill the values into t_new and td_new
311 |         # Specifically, for the corner case where every t_td is NaN, convert pd.NA to (pd.NA, pd.NA)
312 | # so that to_list() will not raise an error
313 | data["t_td"] = data["t_td"].apply(lambda item: (pd.NA, pd.NA) if pd.isna(item) else item)
314 | data[["t_new", "td_new"]] = pd.DataFrame(data["t_td"].to_list(), index=data.index)
315 | # 2nd priority: Remaining hours can only provide at most one of t and td. Try to fill t first
316 | rows_to_fill = data[["t_new", "td_new"]].isnull().all(axis=1)
317 | data.loc[rows_to_fill, "t_new"] = data.loc[rows_to_fill, "t"]
318 |         # 3rd priority: remaining hours have no t during the time window. Try to fill td
319 | rows_to_fill = data[["t_new", "td_new"]].isnull().all(axis=1)
320 | data.loc[rows_to_fill, "td_new"] = data.loc[rows_to_fill, "td"]
321 |
322 | data = data.drop(columns=["t", "td", "t_td"]).rename(columns={"t_new": "t", "td_new": "td"})
323 |
324 | data = data.reset_index(drop=True)
325 | data["time"] = data["time"].dt.round("h")
326 | return data
327 |
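# Illustrative note (comment only): suppose two hypothetical reports at 09:55 (t = 10.0,
# td missing) and 10:10 (t = 10.4, td = 5.0) both round to hour 10:00. For most variables the
# closer 09:55 report would win, but t and td are kept as a pair: only the 10:10 report has
# both, so the hour gets t = 10.4 and td = 5.0 rather than t from one report and td from another.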
328 |
329 | def post_process(data, year):
330 | """
331 | Some post-processing steps after aggregation
332 | """
333 | data = data.set_index("time")
334 | sorted_ra_columns = sorted([col for col in data.columns if col.startswith("ra")], key=lambda x: int(x[2:]))
335 | other_columns = [item for item in ["t", "td", "ws", "wd", "sp", "msl", "c"] if item in data.columns]
336 | data = data[other_columns + sorted_ra_columns]
337 | data = data[f"{year}-01-01":f"{year}-12-31"]
338 | return data
339 |
340 |
341 | def pipeline(input_path, output_dir, year, overwrite=True):
342 | """
343 | The pipeline function for processing a single ISD file
344 | """
345 | output_path = os.path.join(output_dir, os.path.basename(input_path))
346 | if not overwrite and os.path.exists(output_path):
347 | return
348 | input_dir = os.path.dirname(input_path)
349 | if input_dir.endswith(str(year)):
350 | input_path_last_year = os.path.join(input_dir[:-4] + str(year - 1), os.path.basename(input_path))
351 | else:
352 | input_path_last_year = None
353 | data = parse_single_file(input_path, input_path_last_year)
354 | data = aggregate_to_hourly(data)
355 | data = post_process(data, year)
356 | data.astype("Float32").to_csv(output_path, float_format="%.1f")
357 |
358 |
359 | def main(args):
360 | output_dir = os.path.join(args.output_dir, str(args.year))
361 | os.makedirs(output_dir, exist_ok=True)
362 | input_dir = os.path.join(args.input_dir, str(args.year))
363 | input_list = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".csv")]
364 | if args.num_proc == 1:
365 | for fpath in tqdm(input_list):
366 | pipeline(fpath, output_dir=output_dir, year=args.year, overwrite=args.overwrite)
367 | else:
368 | func = partial(pipeline, output_dir=output_dir, year=args.year, overwrite=args.overwrite)
369 | with multiprocessing.Pool(args.num_proc) as pool:
370 | for _ in tqdm(pool.imap(func, input_list), total=len(input_list), desc="Processing files"):
371 | pass
372 |
373 |
374 | if __name__ == "__main__":
375 | parser = argparse.ArgumentParser()
376 | parser.add_argument("-i", "--input-dir", type=str, required=True, help="Directory of ISD csv files")
377 |     parser.add_argument("-o", "--output-dir", type=str, required=True, help="Directory of output csv files")
378 | parser.add_argument("-y", "--year", type=int, required=True, help="Target year")
379 | parser.add_argument("--num-proc", type=int, default=16, help="Number of parallel processes")
380 | parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files")
381 | parsed_args = parser.parse_args()
382 | main(parsed_args)
383 |
--------------------------------------------------------------------------------
/quality_control/station_merging.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. Load the metadata of ISD data online. Post-processing includes:
3 | 1. Drop rows with missing values. These stations are considered not trustworthy
4 | 2. Use USAF and WBAN to create a unique station ID
5 | 2. Calculate pairwise similarity between stations and save it as a NetCDF file
6 |    - Distance and elevation similarities are calculated using an exponential decay
7 |    - The distance between two stations is calculated as the great-circle distance
8 | 3. Load ISD hourly station data with valid metadata, and merge stations based on both metadata and data similarity
9 | - Metadata similarity is calculated based on a weighted sum of distance, elevation, and name similarity
10 | - Data similarity is calculated based on the proportion of identical values among common data points
11 | 4. The merged data and the similarity will be saved in NetCDF format
12 | """
13 |
14 | import argparse
15 | import logging
16 | import os
17 |
18 | import numpy as np
19 | import pandas as pd
20 | import xarray as xr
21 | from algo.utils import configure_logging
22 | from geopy.distance import great_circle as geodist
23 | from tqdm import tqdm
24 |
25 | # URL of the official ISD metadata file
26 | URL_ISD_HISTORY = "https://www.ncei.noaa.gov/pub/data/noaa/isd-history.txt"
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | def load_metadata(station_list):
31 | """
32 | Load the metadata of ISD data
33 | """
34 | meta = pd.read_fwf(
35 | URL_ISD_HISTORY,
36 | skiprows=20,
37 | usecols=["USAF", "WBAN", "STATION NAME", "CTRY", "CALL", "LAT", "LON", "ELEV(M)"],
38 | dtype={"USAF": str, "WBAN": str},
39 | )
40 | # Drop rows with missing values. These stations are considered not trustworthy
41 | meta = meta.dropna(how="any", subset=["LAT", "LON", "STATION NAME", "ELEV(M)", "CTRY"])
42 | meta["STATION"] = meta["USAF"] + meta["WBAN"]
43 | meta = meta[["STATION", "CALL", "STATION NAME", "CTRY", "LAT", "LON", "ELEV(M)"]].set_index("STATION", drop=True)
44 | return meta[meta.index.isin(station_list)]
45 |
46 |
47 | def calc_distance_similarity(latlon1, latlon2, scale_dist=25):
48 | distance = geodist(latlon1, latlon2).kilometers
49 | similarity = np.exp(-distance / scale_dist)
50 | return similarity
51 |
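# Illustrative note (comment only): with the default e-folding scale of 25 km, co-located
# stations score exp(0) = 1.0, stations 25 km apart score exp(-1) ~= 0.37, and stations
# 50 km apart score exp(-2) ~= 0.14.
#
#   calc_distance_similarity((52.5, 13.4), (52.5, 13.4))   # -> 1.0 (hypothetical coordinates)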
52 |
53 | def calc_elevation_similarity(elevation1, elevation2, scale_elev=100):
54 | similarity = np.exp(-abs(elevation1 - elevation2) / scale_elev)
55 | return similarity
56 |
57 |
58 | def calc_id_similarity(ids1, ids2):
59 | """
60 | Compare the USAF/WBAN/CALL IDs of two stations
61 | """
62 | usaf1, wban1, call1, ctry1 = ids1
63 | usaf2, wban2, call2, ctry2 = ids2
64 | if usaf1 != "999999" and usaf1 == usaf2:
65 | return 1
66 | if wban1 != "99999" and wban1 == wban2:
67 | return 1
68 | if call1 == call2:
69 | return 1
70 |     # A special case for the CALL ID, e.g., KAIO and AIO are the same station
71 | if isinstance(call1, str) and len(call1) == 3 and ("K" + call1) == call2:
72 | return 1
73 | if isinstance(call2, str) and len(call2) == 3 and ("K" + call2) == call1:
74 | return 1
75 |     # A special case in Germany: USAF IDs 09xxxx and 10xxxx refer to the same station
76 | # See https://gi.copernicus.org/articles/5/473/2016/gi-5-473-2016.html
77 | if usaf1.startswith("09") and usaf2.startswith("10") and usaf1[2:] == usaf2[2:] and ctry1 == ctry2 == "DE":
78 | return 1
79 | return 0
80 |
81 |
82 | def calc_name_similarity(name1, name2):
83 | """
84 |     Jaccard index of the character sets of the two station names
85 | """
86 | set1 = set(name1)
87 | set2 = set(name2)
88 | intersection = set1.intersection(set2)
89 | union = set1.union(set2)
90 | jaccard_index = len(intersection) / len(union)
91 | return jaccard_index
92 |
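# Illustrative note (comment only): the index is computed on the sets of characters in the
# two names; for the hypothetical names "ABC" and "ABD", the intersection {A, B} has 2
# characters and the union {A, B, C, D} has 4, so calc_name_similarity("ABC", "ABD") == 0.5.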
93 |
94 | def post_process_similarity(similarity, ids):
95 | """
96 | Post-process the similarity matrix
97 | """
98 | # Fill the lower triangle of the matrix
99 | similarity = similarity + similarity.T
100 | # Set the diagonal (self-similarity) to 1
101 | np.fill_diagonal(similarity, 1)
102 | similarity = xr.DataArray(similarity, dims=["station1", "station2"], coords={"station1": ids, "station2": ids})
103 | return similarity
104 |
105 |
106 | def calc_similarity(meta, scale_dist=25, scale_elev=100):
107 | """
108 | Calculate pairwise similarity between stations
109 | 1. Distance similarity: great circle distance between two stations using an exponential decay
110 | 2. Elevation similarity: absolute difference of elevation between two stations using an exponential decay
111 | 3. ID similarity: whether the USAF/WBAN/CALL IDs of two stations are the same
112 | """
113 | latlon = meta[["LAT", "LON"]].apply(tuple, axis=1).values
114 | elev = meta["ELEV(M)"].values
115 | usaf = meta.index.str[:6].values
116 | wban = meta.index.str[6:].values
117 | name = meta["STATION NAME"].values
118 | ids = list(zip(usaf, wban, meta["CALL"].values, meta["CTRY"].values))
119 | num = len(meta)
120 | dist_similarity = np.zeros((num, num))
121 | elev_similarity = np.zeros((num, num))
122 | id_similarity = np.zeros((num, num))
123 | name_similarity = np.zeros((num, num))
124 | for idx1 in tqdm(range(num - 1), desc="Calculating similarity"):
125 | for idx2 in range(idx1 + 1, num):
126 | dist_similarity[idx1, idx2] = calc_distance_similarity(latlon[idx1], latlon[idx2], scale_dist)
127 | elev_similarity[idx1, idx2] = calc_elevation_similarity(elev[idx1], elev[idx2], scale_elev)
128 | id_similarity[idx1, idx2] = calc_id_similarity(ids[idx1], ids[idx2])
129 | name_similarity[idx1, idx2] = calc_name_similarity(name[idx1], name[idx2])
130 | dist_similarity = post_process_similarity(dist_similarity, meta.index.values)
131 | elev_similarity = post_process_similarity(elev_similarity, meta.index.values)
132 | id_similarity = post_process_similarity(id_similarity, meta.index.values)
133 | name_similarity = post_process_similarity(name_similarity, meta.index.values)
134 | similarity = xr.merge(
135 | [
136 | dist_similarity.rename("dist"),
137 | elev_similarity.rename("elev"),
138 | id_similarity.rename("id"),
139 | name_similarity.rename("name"),
140 | ]
141 | )
142 | return similarity
143 |
144 |
145 | def load_raw_data(data_dir, station_list):
146 | """
147 |     Load raw hourly ISD data from csv files
148 | """
149 | data = []
150 | for stn in tqdm(station_list, desc="Loading data"):
151 | df = pd.read_csv(os.path.join(data_dir, f"{stn}.csv"), index_col="time", parse_dates=["time"])
152 | df["station"] = stn
153 | df = df.set_index("station", append=True)
154 | data.append(df.to_xarray())
155 | data = xr.concat(data, dim="station")
156 | data["time"] = pd.to_datetime(data["time"])
157 | return data
158 |
159 |
160 | def calc_meta_similarity(similarity):
161 | """
162 | Calculate the metadata similarity according to horizontal distance, elevation, and name similarity
163 | """
164 | meta_simi = (similarity["dist"] * 9 + similarity["elev"] * 1 + similarity["name"] * 5) / 15
165 |     # meta_simi is set to 1 if the IDs match,
166 |     # otherwise it is the weighted sum of distance, elevation, and name similarity above
167 | meta_simi = np.maximum(meta_simi, similarity["id"])
168 | # set the diagonal and lower triangle to NaN to avoid duplicated pairs
169 | rows, cols = np.indices(meta_simi.shape)
170 | meta_simi.values[rows >= cols] = np.nan
171 | return meta_simi
172 |
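# Illustrative note (comment only): for a hypothetical pair with dist = 0.8, elev = 0.9 and
# name = 0.5, the weighted sum is (0.8 * 9 + 0.9 * 1 + 0.5 * 5) / 15 = 10.6 / 15 ~= 0.71;
# it is raised to 1 only if the id similarity for that pair is 1.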
173 |
174 | def need_merge(da, stn_source, stn_target, threshold=0.7):
175 | """
176 |     Decide whether two stations need to be merged based on the similarity of their data
177 |     If one of them only has a few data points, merging effectively removes a low-quality station
178 |     A looser tolerance (0.5) is used when either series is reported at integer precision, otherwise 0.1
179 | """
180 | ts1 = da.sel(station=stn_source)
181 | ts2 = da.sel(station=stn_target)
182 | diff = np.abs(ts1 - ts2)
183 | if (ts1.dropna(dim="time") % 1 < 1e-3).all() or (ts2.dropna(dim="time") % 1 < 1e-3).all():
184 | max_diff = 0.5
185 | else:
186 | max_diff = 0.1
187 | data_simi = (diff <= max_diff).sum() / diff.notnull().sum()
188 | return data_simi.item() >= threshold
189 |
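# Illustrative note (comment only): if two stations share 100 hours with data and 95 of the
# absolute differences are within the tolerance, data_simi = 0.95 >= 0.7, so the pair is merged;
# the tolerance is 0.5 when either series looks integer-valued (e.g. temperatures reported in
# whole degrees) and 0.1 otherwise.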
190 |
191 | def merge_pairs(ds1, ds2):
192 | """
193 | Merge two stations. Each of the two ds should have only one station
194 |     If there is only one variable, fill the missing values in ds1 with ds2
195 |     If there is more than one variable, to ensure that all variables are from the same station,
196 | for each timestep, the ds with more valid variables will be selected
197 | """
198 | if len(ds1.data_vars) == 1:
199 | return ds1.fillna(ds2)
200 | da1 = ds1.to_array()
201 | da2 = ds2.to_array()
202 | mask = da1.count(dim="variable") >= da2.count(dim="variable")
203 | return xr.where(mask, da1, da2).to_dataset(dim="variable")
204 |
205 |
206 | def merge_stations(ds, meta_simi, main_var, appendant_var=None, meta_simi_th=0.35, data_simi_th=0.7):
207 | """
208 | For ds, merge stations based on metadata similarity and data similarity
209 | """
210 | result = []
211 |     # Flags to avoid merging a station more than once
212 | is_merged = xr.DataArray(
213 | np.full(ds["station"].size, False), dims=["station"], coords={"station": ds["station"].values}
214 | )
215 | for station in tqdm(ds["station"].values, desc=f"Merging {main_var}"):
216 | if is_merged.sel(station=station).item():
217 | continue
218 | # Station list to be merged
219 | merged_stations = [station]
220 | # Candidates that pass the metadata similarity threshold
221 | candidates = meta_simi["station2"][meta_simi.sel(station1=station) >= meta_simi_th].values
222 | # Stack to store the station pairs to be checked
223 | stack = [(station, item) for item in candidates]
224 | # Search for all stations that need to be merged
225 | # If A and B should be merged, and B and C should be merged, then all of them are merged together
226 | while stack:
227 | stn_source, stn_target = stack.pop()
228 | if stn_target in merged_stations:
229 | continue
230 | if need_merge(ds[main_var], stn_source, stn_target, threshold=data_simi_th):
231 | is_merged.loc[stn_target] = True
232 | merged_stations.append(stn_target)
233 | candidates = meta_simi["station2"][meta_simi.sel(station1=stn_target) >= meta_simi_th].values
234 | stack.extend([(stn_target, item) for item in candidates])
235 | # Merge stations according to the number of valid data points
236 | num_valid = ds[main_var].sel(station=merged_stations).notnull().sum(dim="time")
237 | sorted_stns = num_valid["station"].sortby(num_valid).values
238 | variables = [main_var] + appendant_var if appendant_var is not None else [main_var]
239 | stn_data = ds[variables].sel(station=sorted_stns[0])
240 | for target_stn in sorted_stns[1:]:
241 | stn_data = merge_pairs(stn_data, ds[variables].sel(station=target_stn))
242 | stn_data = stn_data.assign_coords(station=station)
243 | result.append(stn_data)
244 | result = xr.concat(result, dim="station")
245 | return result
246 |
247 |
248 | def merge_all_variables(data, meta_simi, meta_simi_th=0.35, data_simi_th=0.7):
249 | """
250 | Merge stations for each variable
251 |     In the dict below, the key is the main variable used for comparison, and the value is the list of appendant variables
252 | """
253 | variables = {
254 | "t": ["td"],
255 | "ws": ["wd"],
256 | "sp": [],
257 | "msl": [],
258 | "c": [],
259 | "ra1": ["ra3", "ra6", "ra12", "ra24"],
260 | }
261 | merged = []
262 | for var, app_vars in variables.items():
263 | ret = merge_stations(data, meta_simi, var, app_vars, meta_simi_th=meta_simi_th, data_simi_th=data_simi_th)
264 | merged.append(ret)
265 | merged = xr.merge(merged).dropna(dim="station", how="all")
266 | return merged
267 |
268 |
269 | def assign_meta_coords(ds, meta):
270 | meta = meta.loc[ds["station"].values]
271 | ds = ds.assign_coords(
272 | call=("station", meta["CALL"].values),
273 | name=("station", meta["STATION NAME"].values),
274 | lat=("station", meta["LAT"].values.astype(np.float32)),
275 | lon=("station", meta["LON"].values.astype(np.float32)),
276 | elev=("station", meta["ELEV(M)"].values.astype(np.float32)),
277 | )
278 | return ds
279 |
280 |
281 | def main(args):
282 | station_list = [item.rsplit(".", 1)[0] for item in os.listdir(args.data_dir) if item.endswith(".csv")]
283 | meta = load_metadata(station_list)
284 |
285 | similarity = calc_similarity(meta)
286 | os.makedirs(args.output_dir, exist_ok=True)
287 | similarity_path = os.path.join(args.output_dir, f"similarity_{args.scale_dist}km_{args.scale_elev}m.nc")
288 | similarity.astype("float32").to_netcdf(similarity_path)
289 | logger.info(f"Saved similarity to {similarity_path}")
290 |
291 | data = load_raw_data(args.data_dir, meta.index.values)
292 |     # some stations have no data; they have already been removed in load_raw_data
293 | meta = meta.loc[data["station"].values]
294 |
295 | similarity = similarity.sel(station1=meta.index.values, station2=meta.index.values)
296 | meta_simi = calc_meta_similarity(similarity)
297 | merged = merge_all_variables(data, meta_simi, meta_simi_th=args.meta_simi_th, data_simi_th=args.data_simi_th)
298 | # Save all metadata information in the NetCDF file
299 | merged = assign_meta_coords(merged, meta)
300 | data_path = os.path.join(args.output_dir, "data.nc")
301 | merged.astype(np.float32).to_netcdf(data_path)
302 | logger.info(f"Saved data to {data_path}")
303 |
304 |
305 | if __name__ == "__main__":
306 | parser = argparse.ArgumentParser()
307 | parser.add_argument(
308 | "-o", "--output-dir", type=str, required=True, help="Directory of output similarity and data files"
309 | )
310 | parser.add_argument(
311 | "-d", "--data-dir", type=str, required=True, help="Directory of ISD csv files from the previous step"
312 | )
313 | parser.add_argument("--scale-dist", type=int, default=25, help="e-fold scale of distance similarity")
314 | parser.add_argument("--scale-elev", type=int, default=100, help="e-fold scale of elevation similarity")
315 | parser.add_argument("--meta-simi-th", type=float, default=0.35, help="Threshold of metadata similarity")
316 | parser.add_argument("--data-simi-th", type=float, default=0.7, help="Threshold of data similarity")
317 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)")
318 | parsed_args = parser.parse_args()
319 | configure_logging(parsed_args.verbose)
320 | main(parsed_args)
321 |
--------------------------------------------------------------------------------